GithubHelp home page GithubHelp logo

deq's Issues

RuntimeError: einsum(): the number of subscripts in the equation (3) does not match the number of dimensions (4) for operand 0 and no ellipsis was given

Dear authors,

I am writing to raise an issue about the Broyden method. My environment is Python 3.8 + PyTorch 1.9.0 (which I fear might be too new). I attach example code and the error below. Any help would be appreciated!

from scipy.optimize import root
import time
from termcolor import colored
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# toy network
class ResNetLayer(nn.Module):
    """Toy residual block used as the fixed-point map ``f(z, x)``.

    ``conv1`` (3 -> 9 channels, 3x3 kernel) and ``conv2`` (9 -> 3 channels,
    5x5 kernel) are both padded by ``kernel_size // 2`` to preserve the spatial
    size. Their output is combined with the injected input ``x`` and added back
    onto ``z``.
    """

    def __init__(self):
        super().__init__()
        hidden_channels = 9
        self.conv1 = nn.Conv2d(3, hidden_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(hidden_channels, 3, kernel_size=5, padding=2)

    def forward(self, z, x):
        """Return ``relu(z + relu(x + conv2(relu(conv1(z)))))``."""
        hidden = F.relu(self.conv1(z))
        injected = F.relu(x + self.conv2(hidden))
        return F.relu(z + injected)



def _safe_norm(v):
    if not torch.isfinite(v).all():
        return np.inf
    return torch.norm(v)


def scalar_search_armijo(phi, phi0, derphi0, c1=1e-4, alpha0=1, amin=0):
    """Backtracking line search for a step size satisfying the Armijo condition.

    Adapted from scipy's ``scalar_search_armijo``. It tries ``alpha0``, then the
    minimizer of a quadratic interpolant, then repeatedly the minimizer of a
    cubic interpolant, until ``phi(alpha) <= phi0 + c1 * alpha * derphi0``.

    Args:
        phi: objective as a function of the step size.
        phi0: value of ``phi(0)``.
        derphi0: derivative of ``phi`` at 0 (negative for a descent direction).
        c1: sufficient-decrease constant.
        alpha0: initial trial step.
        amin: the cubic-interpolation loop stops once the step is <= ``amin``.

    Returns:
        ``(alpha, phi(alpha), ite)`` on success, or ``(None, phi_last, ite)``
        when no acceptable step was found. ``ite`` counts only the
        cubic-interpolation iterations.
    """
    ite = 0
    phi_a0 = phi(alpha0)    # First try the full step alpha0
    if phi_a0 <= phi0 + c1*alpha0*derphi0:
        return alpha0, phi_a0, ite

    # Otherwise, compute the minimizer of a quadratic interpolant
    alpha1 = -(derphi0) * alpha0**2 / 2.0 / (phi_a0 - phi0 - derphi0 * alpha0)
    phi_a1 = phi(alpha1)

    # Accept the quadratic step if it already satisfies the Armijo condition
    # (scipy does the same). Without this check an acceptable step is
    # discarded. When phi is exactly quadratic, the cubic fit below then has
    # a == 0, gives alpha2 = nan, and the search reports failure.
    if phi_a1 <= phi0 + c1*alpha1*derphi0:
        return alpha1, phi_a1, ite

    # Otherwise loop with cubic interpolation until we find an alpha which
    # satisfies the first Wolfe condition (since we are backtracking, we will
    # assume that the value of alpha is not too small and satisfies the second
    # condition).
    while alpha1 > amin:       # we are assuming alpha>0 is a descent direction
        factor = alpha0**2 * alpha1**2 * (alpha1-alpha0)
        a = alpha0**2 * (phi_a1 - phi0 - derphi0*alpha1) - \
            alpha1**2 * (phi_a0 - phi0 - derphi0*alpha0)
        a = a / factor
        b = -alpha0**3 * (phi_a1 - phi0 - derphi0*alpha1) + \
            alpha1**3 * (phi_a0 - phi0 - derphi0*alpha0)
        b = b / factor

        alpha2 = (-b + torch.sqrt(torch.abs(b**2 - 3 * a * derphi0))) / (3.0*a)
        phi_a2 = phi(alpha2)
        ite += 1

        if (phi_a2 <= phi0 + c1*alpha2*derphi0):
            return alpha2, phi_a2, ite

        # Safeguard: fall back to halving when the cubic step is too large or too small.
        if (alpha1 - alpha2) > alpha1 / 2.0 or (1 - alpha2/alpha1) < 0.96:
            alpha2 = alpha1 / 2.0

        alpha0 = alpha1
        alpha1 = alpha2
        phi_a0 = phi_a1
        phi_a1 = phi_a2

    # Failed to find a suitable step length
    return None, phi_a1, ite


def line_search(update, x0, g0, g, nstep=0, on=True):
    """Take one step ``x0 + s * update`` along the proposed direction ``update``.

    Code adapted from scipy.

    Args:
        update: proposed update direction (same shape as ``x0``).
        x0: current iterate.
        g0: residual ``g(x0)``.
        g: residual function.
        nstep: unused; kept for API compatibility.
        on: if True, choose ``s`` with an Armijo backtracking search on
            ``phi(s) = ||g(x0 + s * update)||^2``. If False, or if the search
            fails, use ``s = 1``.

    Returns:
        ``(x_est, g(x_est), x_est - x0, g(x_est) - g0, ite)``, where ``ite`` is
        the iteration count reported by ``scalar_search_armijo`` (0 when off).
    """
    # Single-element lists act as mutable cells, so ``phi`` can cache the last
    # evaluated step/residual/objective and the step chosen below can reuse it.
    tmp_s = [0]
    tmp_g0 = [g0]
    tmp_phi = [torch.norm(g0)**2]

    def phi(s, store=True):
        if s == tmp_s[0]:
            return tmp_phi[0]    # Same step as last time: reuse the cached value
        x_est = x0 + s * update
        g0_new = g(x_est)
        phi_new = _safe_norm(g0_new)**2
        if store:
            tmp_s[0] = s
            tmp_g0[0] = g0_new
            tmp_phi[0] = phi_new
        return phi_new

    if on:
        # phi'(0) is taken as -phi(0), i.e. -||g0||^2.
        s, phi1, ite = scalar_search_armijo(phi, tmp_phi[0], -tmp_phi[0], amin=1e-2)
    if (not on) or s is None:
        s = 1.0
        ite = 0

    x_est = x0 + s * update
    if s == tmp_s[0]:
        g0_new = tmp_g0[0]       # Residual at this step was already computed by phi
    else:
        g0_new = g(x_est)
    return x_est, g0_new, x_est - x0, g0_new - g0, ite

def rmatvec(part_Us, part_VTs, x):
    """Compute the row-vector product ``x^T (-I + U V^T)``.

    Any number of trailing feature dimensions ``*S`` is accepted. The
    original 3-D layout (N, 2d, L') is the special case ``S = (2d, L')``.

    Shapes:
        x:        (N, *S)
        part_Us:  (N, *S, k)
        part_VTs: (N, k, *S)

    Returns:
        Tensor of shape (N, *S). When ``part_Us`` is empty (k == 0) this is ``-x``.
    """
    if part_Us.nelement() == 0:
        return -x
    # Contract all feature dims at once via the ellipsis.
    xTU = torch.einsum('b...,b...d->bd', x, part_Us)                 # (N, k)
    return -x + torch.einsum('bd,bd...->b...', xTU, part_VTs)        # (N, *S)


def matvec(part_Us, part_VTs, x):
    """Compute the matrix-vector product ``(-I + U V^T) x``.

    Any number of trailing feature dimensions ``*S`` is accepted. The
    original 3-D layout (N, 2d, L') is the special case ``S = (2d, L')``.

    Shapes:
        x:        (N, *S)
        part_Us:  (N, *S, k)
        part_VTs: (N, k, *S)

    Returns:
        Tensor of shape (N, *S). When ``part_Us`` is empty (k == 0) this is ``-x``.
    """
    if part_Us.nelement() == 0:
        return -x
    # Contract all feature dims at once via the ellipsis.
    VTx = torch.einsum('bd...,b...->bd', part_VTs, x)                # (N, k)
    return -x + torch.einsum('b...d,bd->b...', part_Us, VTx)         # (N, *S)


def broyden(f, x0, threshold, eps=1e-3, stop_mode="rel", ls=False, name="unknown"):
    """Find a fixed point ``z* = f(z*)`` with a limited-memory Broyden method.

    The solver works internally on tensors of shape (bsz, C, L), where L is the
    product of all trailing dimensions of ``x0``. ``x0`` may therefore be
    (bsz, C, L), (bsz, C, H, W), or have any other trailing shape. ``f`` is
    always called with tensors of ``x0``'s original shape and must return that
    shape. Passing the 4-D tensor straight to the 3-index einsums used to raise
    ``einsum(): the number of subscripts ... does not match``.

    Args:
        f: callable mapping a tensor shaped like ``x0`` to one of the same shape.
        x0: initial guess with at least 2 dims, (bsz, C, ...).
        threshold: maximum number of Broyden steps (also the number of stored
            low-rank update pairs).
        eps: stop as soon as the ``stop_mode`` residual drops below this.
        stop_mode: 'rel' or 'abs'; residual used for stopping and for picking
            the returned iterate.
        ls: run the Armijo line search on every step.
        name: unused; kept for API compatibility.

    Returns:
        dict with:
            result: lowest-residual iterate, shaped like ``x0``.
            lowest: its residual.
            nstep: the step at which it was reached.
            prot_break: whether the divergence guard fired.
            abs_trace, rel_trace: residual traces, padded to ``threshold + 1``.
            eps, threshold: the arguments as passed.

    Raises:
        ValueError: if ``x0`` has fewer than 2 dimensions.
    """
    if x0.dim() < 2:
        raise ValueError(f"x0 must have shape (bsz, C, ...); got {tuple(x0.shape)}")
    orig_shape = x0.shape
    bsz, total_hsize = orig_shape[0], orig_shape[1]
    seq_len = 1
    for dim in orig_shape[2:]:
        seq_len *= dim
    flat_shape = (bsz, total_hsize, seq_len)

    def g(y):
        # Residual f(y) - y, evaluated in f's native shape and returned flattened.
        y_full = y.reshape(orig_shape)
        return (f(y_full) - y_full).reshape(flat_shape)

    dev = x0.device
    alternative_mode = 'rel' if stop_mode == 'abs' else 'abs'

    x_est = x0.reshape(flat_shape)   # (bsz, C, L)
    gx = g(x_est)                    # (bsz, C, L)
    nstep = 0

    # Low-rank factors of the approximate inverse Jacobian, -I + U V^T.
    # Memory grows linearly with `threshold`; an L-BFGS-style scheme could reduce it.
    Us = torch.zeros(bsz, total_hsize, seq_len, threshold, dtype=x0.dtype, device=dev)
    VTs = torch.zeros(bsz, threshold, total_hsize, seq_len, dtype=x0.dtype, device=dev)
    update = -matvec(Us[:, :, :, :nstep], VTs[:, :nstep], gx)   # Formally -J^{-1} gx with J^{-1} ~ -I
    prot_break = False

    # Used by the protective break when the residual blows up.
    protect_thres = (1e6 if stop_mode == "abs" else 1e3) * seq_len

    trace_dict = {'abs': [],
                  'rel': []}
    lowest_dict = {'abs': 1e8,
                   'rel': 1e8}
    lowest_step_dict = {'abs': 0,
                        'rel': 0}
    lowest_xest = x_est

    while nstep < threshold:
        x_est, gx, delta_x, delta_gx, ite = line_search(update, x_est, gx, g, nstep=nstep, on=ls)
        nstep += 1
        abs_diff = torch.norm(gx).item()
        rel_diff = abs_diff / (torch.norm(gx + x_est).item() + 1e-9)   # gx + x_est == f(x_est)
        diff_dict = {'abs': abs_diff,
                     'rel': rel_diff}
        trace_dict['abs'].append(abs_diff)
        trace_dict['rel'].append(rel_diff)
        for mode in ['rel', 'abs']:
            if diff_dict[mode] < lowest_dict[mode]:
                if mode == stop_mode:
                    lowest_xest = x_est.clone().detach()
                lowest_dict[mode] = diff_dict[mode]
                lowest_step_dict[mode] = nstep

        new_objective = diff_dict[stop_mode]
        if new_objective < eps:
            break
        if new_objective < 3 * eps and nstep > 30:
            # Stop if there has been hardly any progress in the last 30 steps.
            # float64 torch reductions replace np.max/np.min; numpy is not
            # imported in this file.
            window = torch.tensor(trace_dict[stop_mode][-30:], dtype=torch.float64)
            if (window.max() / window.min()).item() < 1.3:
                break
        if new_objective > trace_dict[stop_mode][0] * protect_thres:
            prot_break = True
            break

        # Rank-one Broyden update of the inverse-Jacobian approximation.
        part_Us, part_VTs = Us[:, :, :, :nstep - 1], VTs[:, :nstep - 1]
        vT = rmatvec(part_Us, part_VTs, delta_x)
        u = (delta_x - matvec(part_Us, part_VTs, delta_gx)) / torch.einsum('bij, bij -> b', vT, delta_gx)[:, None, None]
        vT[vT != vT] = 0   # zero out NaNs
        u[u != u] = 0
        VTs[:, nstep - 1] = vT
        Us[:, :, :, nstep - 1] = u
        update = -matvec(Us[:, :, :, :nstep], VTs[:, :nstep], gx)

    # Fill everything up to the threshold length
    for _ in range(threshold + 1 - len(trace_dict[stop_mode])):
        trace_dict[stop_mode].append(lowest_dict[stop_mode])
        trace_dict[alternative_mode].append(lowest_dict[alternative_mode])

    return {"result": lowest_xest.reshape(orig_shape),
            "lowest": lowest_dict[stop_mode],
            "nstep": lowest_step_dict[stop_mode],
            "prot_break": prot_break,
            "abs_trace": trace_dict['abs'],
            "rel_trace": trace_dict['rel'],
            "eps": eps,
            "threshold": threshold}

# Minimal reproduction from the issue. A random 4-D input of shape
# (128, 3, 32, 32) is injected into the toy ResNetLayer. broyden then searches,
# starting from zeros, for Z with f(Z, test) == Z, using at most 100 steps and
# eps=1e-1.
test = torch.rand(128, 3, 32, 32).to(device)
f = ResNetLayer().to(device)
zz = broyden(lambda Z : f(Z,test), torch.zeros_like(test), threshold=100, eps=1e-1)['result']
Traceback (most recent call last):
  File "mwe.py", line 251, in <module>
    zz = broyden(lambda Z : f(Z,test), torch.zeros_like(test), threshold=100, eps=1e-1)['result']
  File "mwe.py", line 188, in broyden
    u = (delta_x - matvec(part_Us, part_VTs, delta_gx)) / torch.einsum('bij, bij -> b', vT, delta_gx)[:,None,None]
  File "/opt/conda/lib/python3.8/site-packages/torch/functional.py", line 299, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: einsum(): the number of subscripts in the equation (3) does not match the number of dimensions (4) for operand 0 and no ellipsis was given

License?

This is very nice work. Can you please add a license file?

DEQ for Vision Transformer

Since the DEQ tech has achieved performance competitive with the state-of-the-art deep networks on Transformer based LM and CNN-based Image Recognition tasks, do the authors have plans to adapt DEQ for vision transformer architecture?

Memory consumption on CIFAR-10

Hello, when I trained the tiny MDEQ model on CIFAR-10 (batch size 32), the memory usage was 2099 MB on a single RTX 3090, but the paper reports 0.7 GB. How can I reproduce that number?

Below are the details of the configuration file I use.

GPUS: (0,)
LOG_DIR: 'log/'
DATA_DIR: ''
OUTPUT_DIR: 'output/'
WORKERS: 2
PRINT_FREQ: 100

MODEL: 
  NAME: mdeq
  NUM_LAYERS: 10
  NUM_CLASSES: 10
  NUM_GROUPS: 8
  DROPOUT: 0.25
  WNORM: true
  DOWNSAMPLE_TIMES: 0
  EXPANSION_FACTOR: 5
  POST_GN_AFFINE: false
  IMAGE_SIZE: 
    - 32
    - 32
  EXTRA:
    FULL_STAGE:
      NUM_MODULES: 1
      NUM_BRANCHES: 2
      BLOCK: BASIC
      BIG_KERNELS:
      - 0
      - 0
      HEAD_CHANNELS:
      - 8
      - 16
      FINAL_CHANSIZE: 200
      NUM_BLOCKS:
      - 1
      - 1
      NUM_CHANNELS:
      - 24
      - 24
      FUSE_METHOD: SUM
DEQ:
  F_SOLVER: 'broyden'
  B_SOLVER: 'broyden'
  STOP_MODE: 'rel'
  F_THRES: 18
  B_THRES: 20
  SPECTRAL_RADIUS_MODE: false
CUDNN:
  BENCHMARK: true
  DETERMINISTIC: false
  ENABLED: true
LOSS:
  JAC_LOSS_WEIGHT: 0.0
DATASET:
  DATASET: 'cifar10'
  DATA_FORMAT: 'jpg'
  ROOT: 'data/cifar10/'
  TEST_SET: 'val'
  TRAIN_SET: 'train'
  AUGMENT: False
TEST:
  BATCH_SIZE_PER_GPU: 32
  MODEL_FILE: ''
TRAIN:
  BATCH_SIZE_PER_GPU: 32
  BEGIN_EPOCH: 0
  END_EPOCH: 50
  RESUME: false
  LR_SCHEDULER: 'cosine'
  PRETRAIN_STEPS: 3000
  LR_FACTOR: 0.1
  LR_STEP:
  - 30
  - 60
  - 90
  OPTIMIZER: adam
  LR: 0.001
  WD: 0.0000025
  MOMENTUM: 0.95
  NESTEROV: true
  SHUFFLE: true
DEBUG:
  DEBUG: false

UserWarning: resource_tracker: There appear to be 14 leaked semaphore objects to clean up at shutdown

Hello,

I am trying to train a MDEQ on the image classification task. Here is the command I used to train the image classifier
python tools/cls_train.py --cfg experiments/cifar/cls_mdeq_TINY.yaml.
Everything works fine during the pretraining stage, but when actual training starts, I get an error
UserWarning: resource_tracker: There appear to be 14 leaked semaphore objects to clean up at shutdown and the training terminates. I have tried decreasing the BATCH_SIZE_PER_GPU to 16 but still cannot solve the issue. Can anyone help me with this problem? Thanks!

Simple Model Training Issue

Hi, so I've been trying to create a simple adder model using some of the code in the deq.py file but my model doesn't seem to be learning at all. I've even tried to use a linear layer as the function for the deq forward pass with the inputs being the same as the outputs for the training and it still couldn't learn it. I've attached the script for my model ("add.py") with my versions of the broyden.py and deq.py scripts which have some minor modifications made for debugging purposes (found under the comments "#Diego: debugging purposes"). I was hoping you could help me understand why this is happening. Thank you for your time!

Scripts.zip

Hyperparameters for MDEQ-XL on ImageNet

Hi,

I've been trying to reproduce the results reported in the paper, and noticed that Table 4 in Appendix A does not incorporate the hyperparameters used for training MDEQ-XL on ImageNet. In particular, I'm curious about the following:

  • In general, is the stop mode "rel" or "abs"?
  • What epsilon is used as the threshold in the Broyden solver? Should I assume it was 1e-3 as is the default value?
  • What were the forward and backward quasi-Newton thresholds $T_f, T_b$?

Thanks so much!

Test ImageNet Pre-trained Model

Hi, I tried to test the pretrained models MDEQ_XL_Cls.pkl. However, I got size mismatch errors between the weights of the checkpoint model and the model in the code.

I download and run the command: python tools/cls_valid.py --testModel pretrained_models/MDEQ_XL_Cls.pkl --cfg experiments/imagenet/cls_mdeq_XL.yaml

Parts of error log:
size mismatch for downsample.0.weight: copying a param with shape torch.Size([88, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 3, 3, 3]). size mismatch for downsample.1.weight: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]). size mismatch for downsample.1.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]). size mismatch for downsample.1.running_mean: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]). size mismatch for downsample.1.running_var: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]). size mismatch for downsample.3.weight: copying a param with shape torch.Size([88, 88, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 80, 3, 3]). size mismatch for downsample.4.weight: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]). size mismatch for downsample.4.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]). size mismatch for downsample.4.running_mean: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]). size mismatch for downsample.4.running_var: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]). size mismatch for stage0.0.weight: copying a param with shape torch.Size([88, 88, 1, 1]) from checkpoint, the shape in current model is torch.Size([80, 80, 1, 1]). size mismatch for stage0.1.weight: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]). size mismatch for stage0.1.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]). 
size mismatch for stage0.1.running_mean: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]). size mismatch for stage0.1.running_var: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([80]). size mismatch for fullstage.branches.0.blocks.0.conv1.weight_g: copying a param with shape torch.Size([528, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([400, 1, 1, 1]). size mismatch for fullstage.branches.0.blocks.0.conv1.weight_v: copying a param with shape torch.Size([528, 88, 3, 3]) from checkpoint, the shape in current model is torch.Size([400 , 80, 3, 3]). size mismatch for fullstage.branches.0.blocks.0.gn1.weight: copying a param with shape torch.Size([528]) from checkpoint, the shape in current model is torch.Size([400]). size mismatch for fullstage.branches.0.blocks.0.gn1.bias: copying a param with shape torch.Size([528]) from checkpoint, the shape in current model is torch.Size([400]). size mismatch for fullstage.branches.0.blocks.0.conv2.weight_g: copying a param with shape torch.Size([88, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([80, 1 , 1, 1]). size mismatch for fullstage.branches.0.blocks.0.conv2.weight_v: copying a param with shape torch.Size([88, 528, 3, 3]) from checkpoint, the shape in current model is torch.Size([80, 400, 3, 3]).

Updated code with load_state_dict

Hi, I wanted to know more about the codebase. Why does it have cloning and copying of data instead of just having one Transformer + Residual layer that is iteratively queried?

Also, do you have a more up to date codebase that uses load_state_dict instead of cloning and copying?

Thank you so much!

Does MDEQ have different inference results for different batch sizes?

I'm running some experiments with MDEQ on ImageNet validation set and I get different activations (variable new_z1 in the mdeq_core.py) for the DEQ layer for different batch sizes. I can see in the broyden function that there's no loop over the batch but since I'm not familiar with Broyden's method and its implementation, I do not know if different images (within a batch) can interfere with each other or not directly or indirectly (by having an effect on the number of iterations in the solver for instance). Should I run inference on 1 image at a time?

Language modeling inference complexity and training objective?

Dear Shaojie, I am trying to understand the time complexity of generating a sentence from the DEQ model and the complexity of finding the perplexity of a sentence.

  1. The DEQ model takes a sequence of observations, and generates a sequence of latent states of equal size. If I want to generate a sequence of length T, then I will need to first generate a sequence of length 1, then length 2 and so on till length T. So the generation time will be quadratic in the length of the sentence. Is this correct?

  2. During training, we can generate a complete sequence of latent states for the full input sequence in one shot. But then the i-th hidden state z_i can depend on x_{i+1}. Does that mean that even at training time we need to split a sentence into all of its prefixes? Also, what is the training objective? Is it the log-likelihood of predicting x_{i+1} given z_i through a linear+softmax layer? The paper does not go into these details, and I couldn't easily figure it out from the code.

Broyden defeats the purpose of DEQs?

Heya,

Thanks for your continued work in building better DEQs.

The main selling point of DEQs is that the solver can take as many steps as required to converge without increasing the memory. This isn't true for your implementation of broyden, which starts off with:

Us = torch.zeros(bsz, total_hsize, seq_len, max_iters).to(dev)
VTs = torch.zeros(bsz, max_iters, total_hsize, seq_len).to(dev)

and therefore has a memory cost linear with max_iters, even though the ops aren't tracked. Anderson also keeps the previous m states in memory, where m is usually larger than the number of solver iterations needed anyways. Don't those solvers contradict the claim of constant memory cost?

On a related note, I've found it quite hard to modify these solvers even after going over the theory. Are there any notes or resources you could point to that would help people understand your implementation? Thanks!

Segmentation Fault when Loss Backward CIFAR cls_mdeq_LARGE_reg

Hi, I encounter Segmentation Fault (core dump) when training cls_mdeq_LARGE_reg.
The bug happens at epoch 61, iteration 79, right before the code: (loss + factor*jac_loss).backward().

I'm following this suggestion to trace back the errors using gdb and here is the error:
#4 0x000055555568e989 in PyObject_GetAttrString () at /tmp/build/80754af9/python-split_1628000493704/work/Objects/object.c:846
#5 0x00005555555ce5ab in PyObject_HasAttrString (v=<optimised out>, name=<optimised out>) at /tmp/build/80754af9/python-split_1628000493704/work/Objects/object.c:854
#6 0x00007ffff4ebb42b in hook_name(_object*) () from /home/hieu/anaconda3/envs/deq/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#7 0x00007ffff4ebb84e in check_single_result(_object*, _object*, _object*) ()

I guess the hook function does not like something here?

I'm using:
Python v3.8.11
Pytorch v1.7.1+cu110
CUDA v11.1
RTX 3090 graphic cards.

Two slightly different process for Deq

Hi Shaojie,

I found that there were two slightly different forward-backward process for Deq. One was in Chapter 4: Deep Equilibrium.

class DEQFixedPoint(nn.Module):
    """Deep-equilibrium layer: returns z* with z* = f(z*, x) and implicit gradients.

    Args:
        f: callable ``f(z, x)``, the fixed-point map.
        solver: callable ``solver(func, z0, **kwargs)`` returning ``(z, info)``.
        **kwargs: forwarded to every ``solver`` call (forward and backward).
    """
    def __init__(self, f, solver, **kwargs):
        super().__init__()
        self.f = f
        self.solver = solver
        self.kwargs = kwargs

    def forward(self, x):
        # Compute the forward fixed point without recording an autograd graph...
        with torch.no_grad():
            z, self.forward_res = self.solver(lambda z: self.f(z, x), torch.zeros_like(x), **self.kwargs)
        # ...then re-engage the autograd tape with one application of f.
        z = self.f(z, x)

        # Nothing to differentiate (e.g. called under torch.no_grad()).
        # register_hook would raise on a tensor that does not require grad.
        if not z.requires_grad:
            return z

        # Set up the vector-Jacobian product y -> y^T df/dz at the fixed point,
        # without additional forward calls during backward.
        z0 = z.clone().detach().requires_grad_()
        f0 = self.f(z0, x)

        def backward_hook(grad):
            # Solve g = g (df/dz) + grad, i.e. g = grad (I - df/dz)^{-1}.
            # torch.autograd is referenced through the imported `torch`
            # module; no bare `autograd` name is in scope here.
            g, self.backward_res = self.solver(
                lambda y: torch.autograd.grad(f0, z0, y, retain_graph=True)[0] + grad,
                grad, **self.kwargs)
            return g

        z.register_hook(backward_hook)
        return z

And, the second one was in this repo.

with torch.no_grad():
result = self.f_solver(lambda z: self.func(z, *func_args), z1s, threshold=f_thres, stop_mode=self.stop_mode)
z1s = result['result']
new_z1s = z1s
if (not self.training) and spectral_radius_mode:
with torch.enable_grad():
z1s.requires_grad_()
new_z1s = self.func(z1s, *func_args)
_, sradius = power_method(new_z1s, z1s, n_iters=150)
if self.training:
z1s.requires_grad_()
new_z1s = self.func(z1s, *func_args)
if compute_jac_loss:
jac_loss = jac_loss_estimate(new_z1s, z1s, vecs=1)
def backward_hook(grad):
if self.hook is not None:
# To avoid infinite loop
self.hook.remove()
torch.cuda.synchronize()
new_grad = self.b_solver(lambda y: autograd.grad(new_z1s, z1s, y, retain_graph=True)[0] + grad, \
torch.zeros_like(grad), threshold=b_thres)['result']
return new_grad
self.hook = new_z1s.register_hook(backward_hook)

I tried torch.autograd.gradcheck on both methods, using exactly the same procedure as in Chapter 4, on Colab.
gradcheck(deq, torch.randn(1,2,3,3).cuda().double().requires_grad_(), check_undefined_grad=False)

Interestingly, only method 1 works properly. The second method crashes my experiment session.
Here is my experiment code https://colab.research.google.com/drive/19vGpV16nbF5HRRKlFGScO-N1Js3NC4hj#scrollTo=kg2UmSW1x1R3

I also tried it on my workstation. I found that method 2 slowly ate all GPU memory and eventually return this message SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7f65d099e770> returned NULL without setting an error.
I think I triggered an infinite loop in backward solver although I already called torch.cuda.synchronize() in backward_hook function.

In this repo, I could not find similar gradient-checking code. Moreover, method 2 is used in your Transformer-XL examples. I wonder whether this means the memory-hungry behavior rarely happens in practical cases, such as training a Transformer.

Thanks :-)

My experiment environment:
workstation:

  • python: 3.8
  • pytorch: 1.9
  • cuda: 11.1
  • GPU: Nvidia3090

Google colab:
default environment with GPU.

Mismatch between a pretrained ImageNet model and a config file

Hi!

I am trying to run some initial experiments using the pretrained MDEQ model on ImageNet dataset.

As instructed in the README, I download the (small) model from here and then run the evaluation script via the following command:

python tools/cls_valid.py --testModel pretrained_models/MDEQ_Small_Cls.pkl --cfg experiments/imagenet/cls_mdeq_SMALL.yaml

However, this results in the following error

RuntimeError: Error(s) in loading state_dict for MDEQClsNet: Unexpected key(s) in state_dict: "fullstage_copy.branches.0.blocks.0.conv1.weight", ...

indicating that there is a mismatch between the saved model and the model specs in the config file.

Anyone else experiencing this issue when using pretrained ImageNet models in this repo? If so, could the authors update the pretrained models? Thanks in advance!

CIFAR-10 Reproduction

Hi Shaojie,
I could not reproduce the result for MDEQ on CIFAR-10 image classification.

I only obtained 91.56% using MDEQ_large.
I'm using the same parameters in cls_mdeq_LARGE_reg.yaml, except the batch size and the number of GPUs.
Batch size per gpu are 512 with 2 GPUs.
I'm using 2 RTX 3090 graphic cards.

Hope that you can give me some advice.

Thanks Shaojie.

Expected a 'cuda' device type for generator (related to speed issues?)

Heya, thanks for the great paper(s) :)

Initially I've had to fix a few things to make your code run, but now I find it very slow and I'm wondering if I broke anything.
The cls_mdeq_LARGE_reg.yaml experiment runs at 130 samples/s post pretraining on a GTX 2080, which means that it takes hours to reach ~90% test acc (while a WideResNet will take 10min for that perf).

The main error I had to fix was this:

Traceback (most recent call last):
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/git/official_untouched/MDEQ-Vision/tools/cls_train.py", line 257, in <module>
    main()
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/git/official_untouched/MDEQ-Vision/tools/cls_train.py", line 220, in main
    final_output_dir, tb_log_dir, writer_dict, topk=topk)
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/git/official_untouched/MDEQ-Vision/tools/../lib/core/cls_function.py", line 42, in train
    for i, (input, target) in enumerate(train_loader):
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 359, in __iter__
    return self._get_iterator()
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 305, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 944, in __init__
    self._reset(loader, first_iter=True)
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 975, in _reset
    self._try_put_index()
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1209, in _try_put_index
    index = self._next_index()
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 512, in _next_index
    return next(self._sampler_iter)  # may raise StopIteration
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 229, in __iter__
    for idx in self.sampler:
  File "/afs/inf.ed.ac.uk/user/s17/s1771851/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 126, in __iter__
    yield from torch.randperm(n, generator=generator).tolist()
RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'

According to this issue, the error seems to be caused by this line in your code: torch.set_default_tensor_type('torch.cuda.FloatTensor'), which I removed. After manually moving everything needed to the GPU with .cuda(), I get the performance mentioned above. Is this normal, or did I break something? Thanks!

Specs
Pytorch 1.10
Windows (RTX3070) and ubuntu 20 (GTX 2080) both tried

TrellisNetDEQModule class inheritence

Hi, I noticed that in deq_trellisnet_module.py, you wrote

super(TransformerDEQModule, self).__init__(func, func_copy)

for class TrellisNetDEQModule. That doesn't quite make sense to me. Shouldn't it be the following line instead?

super(TrellisNetDEQModule, self).__init__(func, func_copy)

Thanks for any clarification!

Higher order derivatives

Greetings!

For my application, I need higher order (>= 2) derivatives of the equilibrium point. I realize that doing this with the current implementation is quite inefficient, as it requires to backprop through the custom backward hook, i.e., backprop through the vjp solver. I believe I know how to compute the second derivative, but I am not sure how to do it using backward hooks (only using the functional approach).

It was mentioned in the NeurIPS'20 tutorial that "the function above will not work with double backprop, though again this can be adressed with some additional effect if needed". So is there some solution that I'm not seeing?

Also, I couldn't find any implementations of such higher order derivatives/double backprop of DEQs, so it would be very helpful if you could share some references (if any).

Question about Remove Hook

Dear Shaojie,

Hi there! This is Qiyao, a huge fan of your works! I am writing to ask a question about the lines. I notice that if I remove these lines the training does not work, but I am having a hard time figuring out why? In my understanding, the program should never be creating more than one hook in a single forward pass, so I don't see the purpose of having this check here? For example, this tutorial does not check for the hook, so I am confused as to what is happening here?

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.