
deq's Introduction

Deep Equilibrium Models

(Version 2.0 released now! 😀)

News

💥 2021/6: Repo updated with the multiscale DEQ (MDEQ) code, Jacobian-related analysis & regularization support, and the new, faster and simpler implicit differentiation implementation through PyTorch's backward hook! (See here.)

  • For those who would like to start with a toy version of the DEQ, the NeurIPS 2020 tutorial on "Deep Implicit Layers" has a detailed step-by-step introduction: tutorial video & colab notebooks here.

  • A JAX version of the DEQ, including a JAX implementation of Broyden's method, is available here.


This repository contains the code for the deep equilibrium (DEQ) model, an implicit-depth architecture that directly solves for and backpropagates through the (fixed-point) equilibrium state of an (effectively) infinitely deep network. Importantly, compared to prior implicit-depth approaches (e.g., ODE-based methods), in this work we also demonstrate the potential power and compatibility of this implicit model with modern, structured layers like Transformers, which enable the DEQ networks to achieve results on par with the SOTA deep networks (in NLP and vision) without using a "deep" stacking (and thus O(1) memory). Moreover, we also provide tools for regularizing the stability of these implicit models.

Specifically, this repo contains the code from the following papers (see the bibtex at the end of this README):

  • Deep Equilibrium Models (NeurIPS 2019)
  • Multiscale Deep Equilibrium Models (NeurIPS 2020)
  • Stabilizing Equilibrium Models by Jacobian Regularization (ICML 2021)

Prerequisites

Python >= 3.6 and PyTorch >= 1.10. We strongly recommend 4 GPUs for computational efficiency.

Data

We provide more detailed instructions for downloading/processing the datasets (WikiText-103, ImageNet, Cityscapes, etc.) in the DEQ-Sequence/ and MDEQ-Vision/ subfolders.

How to build/train a DEQ model?

Starting in 2021/6, we partitioned the repo into two sections, containing the sequence-model DEQ (i.e., DEQ-Sequence/) and the vision-model DEQ (i.e., MDEQ-Vision/) networks, respectively. As these two tasks require different input processing and loss objectives, they do not directly share the training framework.

However, both frameworks share the same utility code, such as:

  • lib/solvers.py: Advanced fixed-point solvers (e.g., Anderson acceleration and Broyden's method)
  • lib/jacobian.py: Jacobian-related estimations (e.g., Hutchinson estimator and the Power method)
  • lib/optimization.py: Regularizations (e.g., weight normalization and variational dropout)
  • lib/layer_utils.py: Layer utilities

Moreover, the repo has been significantly simplified from the previous version so that users can easily extend it. In particular:

Theorem 2 (Universality of "single-layer" DEQs, very informal): Stacking multiple DEQs (with potentially different classes of transformations) does not create extra representational power over a single DEQ.

(See the paper for a formal statement.) By the theorem above, designing a better DEQ model boils down to designing a better stable transformation f_\theta. Creating and playing with a DEQ is easy, and we recommend the following 3 steps (which we adopt in this repo):

Step 1: Defining a layer f=f_\theta that we'd like to iterate until equilibrium.

Typically, this is just like any deep network layer, and should be a subclass of torch.nn.Module. Evaluating this layer requires the hidden unit z and the input injection x; e.g.:

class Layer(nn.Module):
    def __init__(self, ...):
        ...
    def forward(self, z, x, **kwargs):
        return new_z
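
For concreteness, here is one hypothetical instantiation: a small residual conv block with group normalization, written in the spirit of the MDEQ layers but not the repo's exact architecture:

import torch.nn as nn
import torch.nn.functional as F

class ConvLayer(nn.Module):
    """Hypothetical f_theta; the hidden state z and the input injection x share one shape."""
    def __init__(self, channels, num_groups=8):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.norm1 = nn.GroupNorm(num_groups, channels)
        self.norm2 = nn.GroupNorm(num_groups, channels)

    def forward(self, z, x, **kwargs):
        y = self.norm1(F.relu(self.conv1(z)))
        # The input injection x re-enters additively at every iteration.
        return self.norm2(F.relu(z + x + self.conv2(y)))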

Step 2: Preparing the fixed-point solver to use for the DEQ model.

A DEQ model can use any black-box root solver. We provide PyTorch fixed-point solver implementations anderson(...) and broyden(...) in lib/solvers.py that output a dictionary containing the basic information of the optimization process. By default, we use the relative residual difference (i.e., |f(z)-z|/|z|) as the criterion for stopping the iterative process.

The forward pass can then be reduced to 2 lines:

with torch.no_grad():
    # x is the input injection; z0 is the initial estimate of the fixed point.
    z_star = self.solver(lambda z: f(z, x, *args), z0, threshold=f_thres)['result']

where we note that the forward pass does not need to store any intermediate state, so we put it in a torch.no_grad() block.

Step 3: Engaging with the autodiff tape to use implicit differentiation

Finally, we need to ensure there is a way to compute the backward pass of a DEQ, which relies on the implicit function theorem. To do this, we can use PyTorch's register_hook function, which registers a hook to be executed during the backward pass. As we noted in the paper, the backward pass simply solves for the fixed point of a linear system involving the Jacobian at the equilibrium:

new_z_star = self.f(z_star.requires_grad_(), x, *args)

def backward_hook(grad):
    if self.hook is not None:
        self.hook.remove()
        torch.cuda.synchronize()   # To avoid infinite recursion
    # Compute the fixed point of y = yJ + grad, where J = J_f is the Jacobian of f at z_star
    new_grad = self.solver(lambda y: autograd.grad(new_z_star, z_star, y, retain_graph=True)[0] + grad, \
                           torch.zeros_like(grad), threshold=b_thres)['result']
    return new_grad

self.hook = new_z_star.register_hook(backward_hook)

(Optional) Additional Step: Jacobian Regularization.

The fixed-point formulation of DEQ models means their stability is directly characterized by the Jacobian matrix J_f at the equilibrium point. Therefore, we provide code for analyzing and regularizing the Jacobian properties (based on the ICML'21 paper Stabilizing Equilibrium Models by Jacobian Regularization). Specifically, we added the following flags to the training script (a usage sketch follows this list):

  • jac_loss_weight: The strength of Jacobian regularization, where we regularize ||J_f||_F.
  • jac_loss_freq: The frequency p of the stochastic Jacobian regularization (i.e., we only apply this loss with probability p during training).
  • jac_incremental: If >0, then we increase the jac_loss_weight by 0.1 after every jac_incremental training steps.
  • spectral_radius_mode: If True, estimate the DEQ models' spectral radius when evaluating on the validation set.
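
As a hedged illustration of how these flags might enter a training step (the names model, criterion, opt and the compute_jac_loss keyword are placeholders here, not the exact API of the DEQ-Sequence/ or MDEQ-Vision/ training scripts):

import random

def train_step(model, x, target, criterion, opt, step,
               jac_loss_weight=2.0, jac_loss_freq=0.4, jac_incremental=0):
    # Stochastically decide whether to regularize this step (probability p = jac_loss_freq).
    compute_jac = random.random() < jac_loss_freq
    z_star, jac_loss = model(x, compute_jac_loss=compute_jac)
    loss = criterion(z_star, target)
    if compute_jac:
        weight = jac_loss_weight
        if jac_incremental > 0:
            # Grow the regularization strength by 0.1 every jac_incremental steps.
            weight += 0.1 * (step // jac_incremental)
        loss = loss + weight * jac_loss
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()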

A full DEQ model implementation is therefore as simple as follows:

from lib.solvers import anderson, broyden
from lib.jacobian import jac_loss_estimate

class DEQModel(nn.Module):
    def __init__(self, ...):
        ...
        self.f = Layer(...)
        self.solver = broyden
        ...
    
    def forward(self, x, ..., **kwargs):
        z0 = torch.zeros(...)

        # Forward pass
        with torch.no_grad():
            z_star = self.solver(lambda z: self.f(z, x, *args), z0, threshold=f_thres)['result']   # See step 2 above
            new_z_star = z_star

        # (Prepare for) Backward pass, see step 3 above
        if self.training:
            new_z_star = self.f(z_star.requires_grad_(), x, *args)
            
            # Jacobian-related computations, see additional step above. For instance:
            jac_loss = jac_loss_estimate(new_z_star, z_star, vecs=1)

            def backward_hook(grad):
                if self.hook is not None:
                    self.hook.remove()
                    torch.cuda.synchronize()   # To avoid infinite recursion
                # Compute the fixed point of y = yJ + grad, where J = J_f is the Jacobian of f at z_star
                new_grad = self.solver(lambda y: autograd.grad(new_z_star, z_star, y, retain_graph=True)[0] + grad, \
                                       torch.zeros_like(grad), threshold=b_thres)['result']
                return new_grad

            self.hook = new_z_star.register_hook(backward_hook)
        return new_z_star, ...

Fixed-point Solvers

We provide PyTorch implementations of two generic solvers, broyden(...) (based on Broyden's method) and anderson(...) (based on Anderson acceleration), in lib/solvers.py. Both functions take in the transformation f whose fixed point we would like to solve for, and return a dictionary of the following format:

{
 "result": ... (The closest estimate to the fixed point),
 "nstep": ... (The step that gives us this closest estimate),
 "abs_trace": ... (Absolute residuals along the trajectory),
 "rel_trace": ... (Relative residuals along the trajectory),
 ...
}
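
For reference, here is a minimal Anderson acceleration sketch in the spirit of the NeurIPS 2020 tutorial; it is a simplified stand-in for lib/solvers.anderson (the tolerances, argument names, and returned fields here are illustrative, not the repo's exact interface):

import torch

def anderson_sketch(f, x0, m=5, lam=1e-4, max_iter=50, tol=1e-3, beta=1.0):
    """Minimal Anderson acceleration for z = f(z)."""
    bsz = x0.shape[0]
    d = x0[0].numel()
    X = torch.zeros(bsz, m, d, dtype=x0.dtype, device=x0.device)  # past iterates
    F = torch.zeros(bsz, m, d, dtype=x0.dtype, device=x0.device)  # past f(iterate) values
    X[:, 0], F[:, 0] = x0.reshape(bsz, -1), f(x0).reshape(bsz, -1)
    X[:, 1], F[:, 1] = F[:, 0], f(F[:, 0].reshape_as(x0)).reshape(bsz, -1)
    rel_trace = []
    for k in range(2, max_iter):
        n = min(k, m)
        G = F[:, :n] - X[:, :n]            # residuals of the stored iterates
        # Mixing weights: minimize ||alpha^T G|| subject to sum(alpha) = 1 (regularized by lam).
        H = torch.einsum('bnd,bmd->bnm', G, G) + lam * torch.eye(n, dtype=x0.dtype, device=x0.device)
        alpha = torch.linalg.solve(H, torch.ones(bsz, n, 1, dtype=x0.dtype, device=x0.device))
        alpha = alpha / alpha.sum(dim=1, keepdim=True)            # (bsz, n, 1)
        X[:, k % m] = beta * (alpha.transpose(1, 2) @ F[:, :n])[:, 0] \
                      + (1 - beta) * (alpha.transpose(1, 2) @ X[:, :n])[:, 0]
        F[:, k % m] = f(X[:, k % m].reshape_as(x0)).reshape(bsz, -1)
        rel_trace.append((F[:, k % m] - X[:, k % m]).norm().item()
                         / (1e-8 + X[:, k % m].norm().item()))    # |f(z)-z| / |z|
        if rel_trace[-1] < tol:
            break
    return {"result": X[:, k % m].reshape_as(x0), "nstep": k, "rel_trace": rel_trace}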

Pretrained Models

See DEQ-Sequence/ and MDEQ-Vision/ sub-directories for the links.

Credits

  • The transformer implementation as well as the extra modules (e.g., adaptive embeddings) were based on the Transformer-XL repo.

  • Some utility code (e.g., model summary and yaml processing) in this repo was adapted from the HRNet repo.

  • We also added the RAdam optimizer as an option for training (but did not make it the default). The RAdam implementation is from the RAdam repo.

Bibtex

If you find this repository useful for your research, please consider citing our work(s):

  1. Deep Equilibrium Models
@inproceedings{bai2019deep,
  author    = {Shaojie Bai and J. Zico Kolter and Vladlen Koltun},
  title     = {Deep Equilibrium Models},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
  year      = {2019},
}
  2. Multiscale Deep Equilibrium Models
@inproceedings{bai2020multiscale,
  author    = {Shaojie Bai and Vladlen Koltun and J. Zico Kolter},
  title     = {Multiscale Deep Equilibrium Models},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
  year      = {2020},
}
  3. Stabilizing Equilibrium Models by Jacobian Regularization
@inproceedings{bai2021stabilizing,
  title     = {Stabilizing Equilibrium Models by Jacobian Regularization},
  author    = {Shaojie Bai and Vladlen Koltun and J. Zico Kolter},
  booktitle = {International Conference on Machine Learning (ICML)},
  year      = {2021}
}

deq's People

Contributors

jerrybai1995, tesfaldet


deq's Issues

RuntimeError: einsum(): the number of subscripts in the equation (3) does not match the number of dimensions (4) for operand 0 and no ellipsis was given

Dear authors,

I am writing to raise an issue about the broyden method. My environment is python=3.8 + torch=1.9.0 (which I fear might be too high). I attach example code and the error below. Any help would be appreciated!

from scipy.optimize import root
import time
from termcolor import colored
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# toy network
class ResNetLayer(nn.Module):
    def __init__(self):
        super(ResNetLayer, self).__init__()
        self.conv1 = nn.Conv2d(3, 9, 3, padding=3//2)
        self.conv2 = nn.Conv2d(9, 3, 5, padding=5//2)
        
    def forward(self, z, x):
        y = F.relu(self.conv1(z))
        return F.relu(z + F.relu(x + self.conv2(y)))



def _safe_norm(v):
    if not torch.isfinite(v).all():
        return np.inf
    return torch.norm(v)


def scalar_search_armijo(phi, phi0, derphi0, c1=1e-4, alpha0=1, amin=0):
    ite = 0
    phi_a0 = phi(alpha0)    # First do an update with step size 1
    if phi_a0 <= phi0 + c1*alpha0*derphi0:
        return alpha0, phi_a0, ite

    # Otherwise, compute the minimizer of a quadratic interpolant
    alpha1 = -(derphi0) * alpha0**2 / 2.0 / (phi_a0 - phi0 - derphi0 * alpha0)
    phi_a1 = phi(alpha1)

    # Otherwise loop with cubic interpolation until we find an alpha which
    # satisfies the first Wolfe condition (since we are backtracking, we will
    # assume that the value of alpha is not too small and satisfies the second
    # condition.)
    while alpha1 > amin:       # we are assuming alpha>0 is a descent direction
        factor = alpha0**2 * alpha1**2 * (alpha1-alpha0)
        a = alpha0**2 * (phi_a1 - phi0 - derphi0*alpha1) - \
            alpha1**2 * (phi_a0 - phi0 - derphi0*alpha0)
        a = a / factor
        b = -alpha0**3 * (phi_a1 - phi0 - derphi0*alpha1) + \
            alpha1**3 * (phi_a0 - phi0 - derphi0*alpha0)
        b = b / factor

        alpha2 = (-b + torch.sqrt(torch.abs(b**2 - 3 * a * derphi0))) / (3.0*a)
        phi_a2 = phi(alpha2)
        ite += 1

        if (phi_a2 <= phi0 + c1*alpha2*derphi0):
            return alpha2, phi_a2, ite

        if (alpha1 - alpha2) > alpha1 / 2.0 or (1 - alpha2/alpha1) < 0.96:
            alpha2 = alpha1 / 2.0

        alpha0 = alpha1
        alpha1 = alpha2
        phi_a0 = phi_a1
        phi_a1 = phi_a2

    # Failed to find a suitable step length
    return None, phi_a1, ite


def line_search(update, x0, g0, g, nstep=0, on=True):
    """
    `update` is the proposed direction of update.
    Code adapted from scipy.
    """
    tmp_s = [0]
    tmp_g0 = [g0]
    tmp_phi = [torch.norm(g0)**2]
    s_norm = torch.norm(x0) / torch.norm(update)

    def phi(s, store=True):
        if s == tmp_s[0]:
            return tmp_phi[0]    # If the step size is so small... just return something
        x_est = x0 + s * update
        g0_new = g(x_est)
        phi_new = _safe_norm(g0_new)**2
        if store:
            tmp_s[0] = s
            tmp_g0[0] = g0_new
            tmp_phi[0] = phi_new
        return phi_new
    
    if on:
        s, phi1, ite = scalar_search_armijo(phi, tmp_phi[0], -tmp_phi[0], amin=1e-2)
    if (not on) or s is None:
        s = 1.0
        ite = 0

    x_est = x0 + s * update
    if s == tmp_s[0]:
        g0_new = tmp_g0[0]
    else:
        g0_new = g(x_est)
    return x_est, g0_new, x_est - x0, g0_new - g0, ite

def rmatvec(part_Us, part_VTs, x):
    # Compute x^T(-I + UV^T)
    # x: (N, 2d, L')
    # part_Us: (N, 2d, L', threshold)
    # part_VTs: (N, threshold, 2d, L')
    if part_Us.nelement() == 0:
        return -x
    xTU = torch.einsum('bij, bijd -> bd', x, part_Us)   # (N, threshold)
    return -x + torch.einsum('bd, bdij -> bij', xTU, part_VTs)    # (N, 2d, L'), but should really be (N, 1, (2d*L'))


def matvec(part_Us, part_VTs, x):
    # Compute (-I + UV^T)x
    # x: (N, 2d, L')
    # part_Us: (N, 2d, L', threshold)
    # part_VTs: (N, threshold, 2d, L')
    if part_Us.nelement() == 0:
        return -x
    VTx = torch.einsum('bdij, bij -> bd', part_VTs, x)  # (N, threshold)
    return -x + torch.einsum('bijd, bd -> bij', part_Us, VTx)     # (N, 2d, L'), but should really be (N, (2d*L'), 1)


def broyden(f, x0, threshold, eps=1e-3, stop_mode="rel", ls=False, name="unknown"):
    bsz, total_hsize, H, W = x0.size()
    seq_len = H * W
    g = lambda y: f(y) - y
    dev = x0.device
    alternative_mode = 'rel' if stop_mode == 'abs' else 'abs'
    
    x_est = x0           # (bsz, 2d, L')
    gx = g(x_est)        # (bsz, 2d, L')
    nstep = 0
    tnstep = 0
    
    # For fast calculation of inv_jacobian (approximately)
    Us = torch.zeros(bsz, total_hsize, seq_len, threshold).to(dev)     # One can also use an L-BFGS scheme to further reduce memory
    VTs = torch.zeros(bsz, threshold, total_hsize, seq_len).to(dev)
    update = -matvec(Us[:,:,:,:nstep], VTs[:,:nstep], gx)      # Formally should be -torch.matmul(inv_jacobian (-I), gx)
    prot_break = False
    
    # To be used in protective breaks
    protect_thres = (1e6 if stop_mode == "abs" else 1e3) * seq_len
    new_objective = 1e8

    trace_dict = {'abs': [],
                  'rel': []}
    lowest_dict = {'abs': 1e8,
                   'rel': 1e8}
    lowest_step_dict = {'abs': 0,
                        'rel': 0}
    nstep, lowest_xest, lowest_gx = 0, x_est, gx

    while nstep < threshold:
        x_est, gx, delta_x, delta_gx, ite = line_search(update, x_est, gx, g, nstep=nstep, on=ls)
        nstep += 1
        tnstep += (ite+1)
        abs_diff = torch.norm(gx).item()
        rel_diff = abs_diff / (torch.norm(gx + x_est).item() + 1e-9)
        diff_dict = {'abs': abs_diff,
                     'rel': rel_diff}
        trace_dict['abs'].append(abs_diff)
        trace_dict['rel'].append(rel_diff)
        for mode in ['rel', 'abs']:
            if diff_dict[mode] < lowest_dict[mode]:
                if mode == stop_mode: 
                    lowest_xest, lowest_gx = x_est.clone().detach(), gx.clone().detach()
                lowest_dict[mode] = diff_dict[mode]
                lowest_step_dict[mode] = nstep

        new_objective = diff_dict[stop_mode]
        if new_objective < eps: break
        if new_objective < 3*eps and nstep > 30 and np.max(trace_dict[stop_mode][-30:]) / np.min(trace_dict[stop_mode][-30:]) < 1.3:
            # if there's hardly been any progress in the last 30 steps
            break
        if new_objective > trace_dict[stop_mode][0] * protect_thres:
            prot_break = True
            break

        part_Us, part_VTs = Us[:,:,:,:nstep-1], VTs[:,:nstep-1]
        vT = rmatvec(part_Us, part_VTs, delta_x)
        u = (delta_x - matvec(part_Us, part_VTs, delta_gx)) / torch.einsum('bij, bij -> b', vT, delta_gx)[:,None,None]
        vT[vT != vT] = 0
        u[u != u] = 0
        VTs[:,nstep-1] = vT
        Us[:,:,:,nstep-1] = u
        update = -matvec(Us[:,:,:,:nstep], VTs[:,:nstep], gx)

    # Fill everything up to the threshold length
    for _ in range(threshold+1-len(trace_dict[stop_mode])):
        trace_dict[stop_mode].append(lowest_dict[stop_mode])
        trace_dict[alternative_mode].append(lowest_dict[alternative_mode])

    return {"result": lowest_xest,
            "lowest": lowest_dict[stop_mode],
            "nstep": lowest_step_dict[stop_mode],
            "prot_break": prot_break,
            "abs_trace": trace_dict['abs'],
            "rel_trace": trace_dict['rel'],
            "eps": eps,
            "threshold": threshold}

test = torch.rand(128, 3, 32, 32).to(device)
f = ResNetLayer().to(device)
zz = broyden(lambda Z : f(Z,test), torch.zeros_like(test), threshold=100, eps=1e-1)['result']
Traceback (most recent call last):
  File "mwe.py", line 251, in <module>
    zz = broyden(lambda Z : f(Z,test), torch.zeros_like(test), threshold=100, eps=1e-1)['result']
  File "mwe.py", line 188, in broyden
    u = (delta_x - matvec(part_Us, part_VTs, delta_gx)) / torch.einsum('bij, bij -> b', vT, delta_gx)[:,None,None]
  File "/opt/conda/lib/python3.8/site-packages/torch/functional.py", line 299, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: einsum(): the number of subscripts in the equation (3) does not match the number of dimensions (4) for operand 0 and no ellipsis was given
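
For context: the einsum subscripts in matvec/rmatvec ('bij', 'bijd', ...) assume 3-D tensors of shape (N, 2d, L'), but the broyden above unpacks a 4-D (N, C, H, W) input and never flattens it, which is exactly what the error reports. One plausible fix, sketched under that assumption (this is not an official patch), is to flatten the spatial dimensions at the top of broyden and un-flatten only when evaluating f:

def broyden_flattened(f, x0, threshold, eps=1e-3, stop_mode="rel", ls=False, name="unknown"):
    bsz, total_hsize, H, W = x0.size()
    seq_len = H * W
    # Keep the solver's own state 3-D, i.e. (bsz, total_hsize, seq_len), so the
    # einsums in matvec/rmatvec see the shapes they expect.
    g = lambda y: f(y.view(bsz, total_hsize, H, W)).view(bsz, total_hsize, seq_len) - y
    x_est = x0.view(bsz, total_hsize, seq_len)
    gx = g(x_est)
    # ... the remainder is identical to the broyden above; reshape the returned
    # "result" back to (bsz, total_hsize, H, W) at the end.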

Mismatch between a pretrained ImageNet model and a config file

Hi!

I am trying to run some initial experiments using the pretrained MDEQ model on ImageNet dataset.

As instructed in the README, I download the (small) model from here and then run the evaluation script via the following command:

python tools/cls_valid.py --testModel pretrained_models/MDEQ_Small_Cls.pkl --cfg experiments/imagenet/cls_mdeq_SMALL.yaml

However, this results in the following error

RuntimeError: Error(s) in loading state_dict for MDEQClsNet: Unexpected key(s) in state_dict: "fullstage_copy.branches.0.blocks.0.conv1.weight", ...

indicating that there is a mismatch between the saved model and the model specs in the config file.

Anyone else experiencing this issue when using pretrained ImageNet models in this repo? If so, could the authors update the pretrained models? Thanks in advance!

Hyperparameters for MDEQ-XL on ImageNet

Hi,

I've been trying to reproduce the results reported in the paper, and noticed that Table 4 in Appendix A does not incorporate the hyperparameters used for training MDEQ-XL on ImageNet. In particular, I'm curious about the following:

  • In general, is the stop mode "rel" or "abs"?
  • What epsilon is used as the threshold in the Broyden solver? Should I assume it was 1e-3 as is the default value?
  • What were the forward and backward quasi-Newton thresholds $T_f, T_b$?

Thanks so much!

Higher order derivatives

Greetings!

For my application, I need higher-order (>= 2) derivatives of the equilibrium point. I realize that doing this with the current implementation is quite inefficient, as it requires backpropagating through the custom backward hook, i.e., through the vjp solver. I believe I know how to compute the second derivative, but I am not sure how to do it using backward hooks (only using the functional approach).

It was mentioned in the NeurIPS'20 tutorial that "the function above will not work with double backprop, though again this can be addressed with some additional effort if needed". So is there some solution that I'm not seeing?

Also, I couldn't find any implementations of such higher order derivatives/double backprop of DEQs, so it would be very helpful if you could share some references (if any).

Memory consumption on CIFAR-10

Hello! When I trained the tiny MDEQ model on CIFAR-10 (batch size set to 32), the memory usage was 2099 MB on a single RTX 3090, but the paper says it is 0.7 GB. How can I reproduce that number?

Below are the details of the configuration file I use.

GPUS: (0,)
LOG_DIR: 'log/'
DATA_DIR: ''
OUTPUT_DIR: 'output/'
WORKERS: 2
PRINT_FREQ: 100

MODEL: 
  NAME: mdeq
  NUM_LAYERS: 10
  NUM_CLASSES: 10
  NUM_GROUPS: 8
  DROPOUT: 0.25
  WNORM: true
  DOWNSAMPLE_TIMES: 0
  EXPANSION_FACTOR: 5
  POST_GN_AFFINE: false
  IMAGE_SIZE: 
    - 32
    - 32
  EXTRA:
    FULL_STAGE:
      NUM_MODULES: 1
      NUM_BRANCHES: 2
      BLOCK: BASIC
      BIG_KERNELS:
      - 0
      - 0
      HEAD_CHANNELS:
      - 8
      - 16
      FINAL_CHANSIZE: 200
      NUM_BLOCKS:
      - 1
      - 1
      NUM_CHANNELS:
      - 24
      - 24
      FUSE_METHOD: SUM
DEQ:
  F_SOLVER: 'broyden'
  B_SOLVER: 'broyden'
  STOP_MODE: 'rel'
  F_THRES: 18
  B_THRES: 20
  SPECTRAL_RADIUS_MODE: false
CUDNN:
  BENCHMARK: true
  DETERMINISTIC: false
  ENABLED: true
LOSS:
  JAC_LOSS_WEIGHT: 0.0
DATASET:
  DATASET: 'cifar10'
  DATA_FORMAT: 'jpg'
  ROOT: 'data/cifar10/'
  TEST_SET: 'val'
  TRAIN_SET: 'train'
  AUGMENT: False
TEST:
  BATCH_SIZE_PER_GPU: 32
  MODEL_FILE: ''
TRAIN:
  BATCH_SIZE_PER_GPU: 32
  BEGIN_EPOCH: 0
  END_EPOCH: 50
  RESUME: false
  LR_SCHEDULER: 'cosine'
  PRETRAIN_STEPS: 3000
  LR_FACTOR: 0.1
  LR_STEP:
  - 30
  - 60
  - 90
  OPTIMIZER: adam
  LR: 0.001
  WD: 0.0000025
  MOMENTUM: 0.95
  NESTEROV: true
  SHUFFLE: true
DEBUG:
  DEBUG: false

License?

This is very nice work. Can you please add a license file?

UserWarning: resource_tracker: There appear to be 14 leaked semaphore objects to clean up at shutdown

Hello,

I am trying to train a MDEQ on the image classification task. Here is the command I used to train the image classifier
python tools/cls_train.py --cfg experiments/cifar/cls_mdeq_TINY.yaml.
Everything works fine during the pretraining stage, but when actual training starts, I get an error
UserWarning: resource_tracker: There appear to be 14 leaked semaphore objects to clean up at shutdown and the training terminates. I have tried decreasing the BATCH_SIZE_PER_GPU to 16 but still cannot solve the issue. Can anyone help me with this problem? Thanks!

Simple Model Training Issue

Hi, so I've been trying to create a simple adder model using some of the code in the deq.py file, but my model doesn't seem to be learning at all. I've even tried using a linear layer as the function for the DEQ forward pass, with the training inputs identical to the outputs, and it still couldn't learn it. I've attached the script for my model ("add.py") along with my versions of the broyden.py and deq.py scripts, which have some minor modifications made for debugging purposes (found under the comments "#Diego: debugging purposes"). I was hoping you could help me understand why this is happening. Thank you for your time!

Scripts.zip

CIFAR-10 Reproduction

Hi Shaojie,
I could not reproduce the result for MDEQ on CIFAR-10 image classification.

I only obtained 91.56% using MDEQ_large.
I'm using the same parameters as in cls_mdeq_LARGE_reg.yaml, except for the batch size and the number of GPUs:
the batch size per GPU is 512, with 2 GPUs (RTX 3090 graphics cards).

Hope that you can give me some advice.

Thanks Shaojie.

TrellisNetDEQModule class inheritence

Hi, I noticed that in deq_trellisnet_module.py, you wrote

super(TransformerDEQModule, self).__init__(func, func_copy)

for class TrellisNetDEQModule. That doesn't quite make sense to me. Should it actually be the following line instead?

super(TrellisNetDEQModule, self).__init__(func, func_copy)

Thanks for any clarification!

Broyden defeats the purpose of DEQs?

Heya,

Thanks for your continued work in building better DEQs.

The main selling point of DEQs is that the solver can take as many steps as required to converge without increasing the memory. This isn't true for your implementation of broyden, which starts off with:

Us = torch.zeros(bsz, total_hsize, seq_len, max_iters).to(dev)
VTs = torch.zeros(bsz, max_iters, total_hsize, seq_len).to(dev)

and therefore has a memory cost linear in max_iters, even though the ops aren't tracked. Anderson also keeps the previous m states in memory, where m is usually larger than the number of solver iterations needed anyway. Don't those solvers contradict the claim of constant memory cost?

On a related note, I've found it quite hard to modify these solvers even after going over the theory. Are there any notes or resources you could point to that would help people understand your implementation? Thanks!

Language modeling inference complexity and training objective?

Dear Shaojie, I am trying to understand the time complexity of generating a sentence from the DEQ model and the complexity of finding the perplexity of a sentence.

  1. The DEQ model takes a sequence of observations, and generates a sequence of latent states of equal size. If I want to generate a sequence of length T, then I will need to first generate a sequence of length 1, then length 2 and so on till length T. So the generation time will be quadratic in the length of the sentence. Is this correct?

  2. During training, we can generate a complete sequence of latent states for the full input sequence in one shot. But then the i-th hidden state z_i can depend on x_{i+1}. Does that mean that even at training time, we need to split a sentence into all of its prefixes? Also, what is the training objective? Is it the log-likelihood of predicting x_{i+1} given z_i through a linear+softmax layer? The paper does not go into these details, and I couldn't easily figure it out from the code.

Updated code with load_state_dict

Hi, I wanted to know more about the codebase. Why does it have cloning and copying of data instead of just having one Transformer + Residual layer that is iteratively queried?

Also, do you have a more up to date codebase that uses load_state_dict instead of cloning and copying?

Thank you so much!

Does MDEQ have different inference results for different batch sizes?

I'm running some experiments with MDEQ on the ImageNet validation set, and I get different activations (variable new_z1 in mdeq_core.py) from the DEQ layer for different batch sizes. I can see in the broyden function that there's no loop over the batch, but since I'm not familiar with Broyden's method and its implementation, I do not know whether different images within a batch can interfere with each other, directly or indirectly (for instance, by affecting the number of iterations in the solver). Should I run inference on 1 image at a time?
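
(A quick empirical check, as a sketch in which model stands in for the MDEQ forward pass in eval mode:)

import torch

# Hypothetical diagnostic: does image 0's equilibrium depend on its batch mates?
model.eval()
with torch.no_grad():
    x = torch.randn(8, 3, 224, 224, device='cuda')
    z_batch = model(x)         # solve with a batch of 8
    z_single = model(x[:1])    # solve the same first image alone
print((z_batch[:1] - z_single).abs().max().item())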

Two slightly different processes for DEQ

Hi Shaojie,

I found that there are two slightly different forward-backward processes for DEQ. One is in Chapter 4: Deep Equilibrium Models.

class DEQFixedPoint(nn.Module):
    def __init__(self, f, solver, **kwargs):
        super().__init__()
        self.f = f
        self.solver = solver
        self.kwargs = kwargs
        
    def forward(self, x):
        # compute forward pass and re-engage autograd tape
        with torch.no_grad():
            z, self.forward_res = self.solver(lambda z : self.f(z, x), torch.zeros_like(x), **self.kwargs)
        z = self.f(z,x)
        
        # set up Jacobian vector product (without additional forward calls)
        z0 = z.clone().detach().requires_grad_()
        f0 = self.f(z0,x)
        def backward_hook(grad):
            g, self.backward_res = self.solver(lambda y : autograd.grad(f0, z0, y, retain_graph=True)[0] + grad,
                                               grad, **self.kwargs)
            return g
                
        z.register_hook(backward_hook)
        return z

The second one is in this repo.

with torch.no_grad():
    result = self.f_solver(lambda z: self.func(z, *func_args), z1s, threshold=f_thres, stop_mode=self.stop_mode)
    z1s = result['result']
new_z1s = z1s
if (not self.training) and spectral_radius_mode:
    with torch.enable_grad():
        z1s.requires_grad_()
        new_z1s = self.func(z1s, *func_args)
    _, sradius = power_method(new_z1s, z1s, n_iters=150)
if self.training:
    z1s.requires_grad_()
    new_z1s = self.func(z1s, *func_args)
    if compute_jac_loss:
        jac_loss = jac_loss_estimate(new_z1s, z1s, vecs=1)

    def backward_hook(grad):
        if self.hook is not None:
            # To avoid infinite loop
            self.hook.remove()
            torch.cuda.synchronize()
        new_grad = self.b_solver(lambda y: autograd.grad(new_z1s, z1s, y, retain_graph=True)[0] + grad, \
                                 torch.zeros_like(grad), threshold=b_thres)['result']
        return new_grad

    self.hook = new_z1s.register_hook(backward_hook)

I tried torch.autograd.gradcheck on both methods, using the exact same process from Chapter 4, on Colab.
gradcheck(deq, torch.randn(1,2,3,3).cuda().double().requires_grad_(), check_undefined_grad=False)

Interestingly, only method 1 works properly. The second method breaks my experiment session.
Here is my experiment code https://colab.research.google.com/drive/19vGpV16nbF5HRRKlFGScO-N1Js3NC4hj#scrollTo=kg2UmSW1x1R3

I also tried it on my workstation. I found that method 2 slowly ate all the GPU memory and eventually returned this message: SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7f65d099e770> returned NULL without setting an error.
I think I triggered an infinite loop in the backward solver, although I already called torch.cuda.synchronize() in the backward_hook function.

In this repo, I do not find similar code related to gradient checking. Moreover, method 2 is used in your Transformer-XL examples. I wonder whether this means the memory-hunger issue rarely happens in practical cases, like training a transformer.

Thanks :-)

My experiment environment:
workstation:

  • python: 3.8
  • pytorch: 1.9
  • cuda: 11.1
  • GPU: NVIDIA 3090

Google colab:
default environment with GPU.

Test ImageNet Pre-trained Model

Hi, I tried to test the pretrained models MDEQ_XL_Cls.pkl. However, I got size mismatch errors between the weights of the checkpoint model and the model in the code.

I download and run the command: python tools/cls_valid.py --testModel pretrained_models/MDEQ_XL_Cls.pkl --cfg experiments/imagenet/cls_mdeq_XL.yaml

Parts of error log:
size mismatch for downsample.0.weight: copying a param with shape torch.Size([88, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 3, 3, 3]).
size mismatch for downsample.1.weight: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
size mismatch for downsample.1.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
size mismatch for downsample.1.running_mean: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
size mismatch for downsample.1.running_var: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
size mismatch for downsample.3.weight: copying a param with shape torch.Size([88, 88, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 80, 3, 3]).
size mismatch for downsample.4.weight: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
size mismatch for downsample.4.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
size mismatch for downsample.4.running_mean: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
size mismatch for downsample.4.running_var: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
size mismatch for stage0.0.weight: copying a param with shape torch.Size([88, 88, 1, 1]) from checkpoint, the shape in current model is torch.Size([80, 80, 1, 1]).
size mismatch for stage0.1.weight: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
size mismatch for stage0.1.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
size mismatch for stage0.1.running_mean: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
size mismatch for stage0.1.running_var: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]).
size mismatch for fullstage.branches.0.blocks.0.conv1.weight_g: copying a param with shape torch.Size([528, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([400, 1, 1, 1]).
size mismatch for fullstage.branches.0.blocks.0.conv1.weight_v: copying a param with shape torch.Size([528, 88, 3, 3]) from checkpoint, the shape in current model is torch.Size([400, 80, 3, 3]).
size mismatch for fullstage.branches.0.blocks.0.gn1.weight: copying a param with shape torch.Size([528]) from checkpoint, the shape in current model is torch.Size([400]).
size mismatch for fullstage.branches.0.blocks.0.gn1.bias: copying a param with shape torch.Size([528]) from checkpoint, the shape in current model is torch.Size([400]).
size mismatch for fullstage.branches.0.blocks.0.conv2.weight_g: copying a param with shape torch.Size([88, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([80, 1, 1, 1]).
size mismatch for fullstage.branches.0.blocks.0.conv2.weight_v: copying a param with shape torch.Size([88, 528, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 400, 3, 3]).
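
(A hedged way to see exactly which parameters disagree, as a sketch in which model is a placeholder for the MDEQ classifier built from the yaml config:)

import torch

# Hypothetical shape diff between the checkpoint and a freshly constructed model.
state = torch.load('pretrained_models/MDEQ_XL_Cls.pkl', map_location='cpu')
own = model.state_dict()
for k in sorted(set(state) & set(own)):
    if state[k].shape != own[k].shape:
        print(k, tuple(state[k].shape), '!=', tuple(own[k].shape))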

Question about Remove Hook

Dear Shaojie,

Hi there! This is Qiyao, a huge fan of your work! I am writing to ask a question about these lines. I notice that if I remove these lines the training does not work, but I am having a hard time figuring out why. In my understanding, the program should never create more than one hook in a single forward pass, so I don't see the purpose of having this check here. For example, this tutorial does not check for the hook, so I am confused as to what is happening here.

Expected a 'cuda' device type for generator (related to speed issues?)

Heya, thanks for the great paper(s) :)

Initially I had to fix a few things to make your code run, but now I find it very slow and I'm wondering if I broke anything.
The cls_mdeq_LARGE_reg.yaml experiment runs at 130 samples/s post-pretraining on a GTX 2080, which means it takes hours to reach ~90% test accuracy (while a WideResNet would take 10 minutes for that performance).

The main error I had to fix was this:

Traceback (most recent call last):
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/git/official_untouched/MDEQ-Vision/tools/cls_train.py", line 257, in <module>
    main()
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/git/official_untouched/MDEQ-Vision/tools/cls_train.py", line 220, in main
    final_output_dir, tb_log_dir, writer_dict, topk=topk)
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/git/official_untouched/MDEQ-Vision/tools/../lib/core/cls_function.py", line 42, in train
    for i, (input, target) in enumerate(train_loader):
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 359, in __iter__
    return self._get_iterator()
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 305, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 944, in __init__
    self._reset(loader, first_iter=True)
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 975, in _reset
    self._try_put_index()
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1209, in _try_put_index
    index = self._next_index()
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 512, in _next_index
    return next(self._sampler_iter)  # may raise StopIteration
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 229, in __iter__
    for idx in self.sampler:
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 126, in __iter__
    yield from torch.randperm(n, generator=generator).tolist()
RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'

which according to this issue seems to be caused by this line in your code: torch.set_default_tensor_type('torch.cuda.FloatTensor'), which I removed. After setting all the needed things to .cuda() manually, I get the performance mentioned above. Is this normal or did I break something? Thanks!

Specs:
PyTorch 1.10
Tried on both Windows (RTX 3070) and Ubuntu 20 (GTX 2080)

DEQ for Vision Transformer

Since the DEQ technique has achieved performance competitive with state-of-the-art deep networks on Transformer-based language modeling and CNN-based image recognition tasks, do the authors have plans to adapt DEQ to the vision transformer architecture?

Segmentation Fault when Loss Backward CIFAR cls_mdeq_LARGE_reg

Hi, I encounter a Segmentation Fault (core dumped) when training cls_mdeq_LARGE_reg.
The bug happens at epoch 61, iteration 79, right before the line (loss + factor*jac_loss).backward().

I'm following this suggestion to trace back the errors using gdb and here is the error:
#4 0x000055555568e989 in PyObject_GetAttrString () at /tmp/build/80754af9/python-split_1628000493704/work/Objects/object.c:846
#5 0x00005555555ce5ab in PyObject_HasAttrString (v=<optimised out>, name=<optimised out>) at /tmp/build/80754af9/python-split_1628000493704/work/Objects/object.c:854
#6 0x00007ffff4ebb42b in hook_name(_object*) () from /home/hieu/anaconda3/envs/deq/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#7 0x00007ffff4ebb84e in check_single_result(_object*, _object*, _object*) ()

I guess the hook function does not like something here?

I'm using:
Python v3.8.11
Pytorch v1.7.1+cu110
CUDA v11.1
RTX 3090 graphic cards.
