
A-ViT: Adaptive Tokens for Efficient Vision Transformer

This repository is the official PyTorch implementation of A-ViT: Adaptive Tokens for Efficient Vision Transformer presented at CVPR 2022.

The code enables adaptive inference of tokens for vision transformers. The current version includes training, evaluation, and visualization of models on ImageNet-1K. We plan to update the repository and release run-time token zipping for inference, distillation models, and base models in coming versions.

Useful links:
[Camera ready]
[ArXiv Full PDF]
[Project page]

(Teaser figure)

Requirements

The code was tested in a virtual environment with Python 3.6. Install the dependencies listed in the requirements.txt file.
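With the environment activated, the dependencies can typically be installed with:

pip install -r requirements.txt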

Commands

Data Preparation

Please prepare the ImageNet dataset into the following structure:

  imagenet
  ├── train
  │   ├── class1
  │   │   ├── img1.jpeg
  │   │   ├── img2.jpeg
  │   │   └── ...
  │   ├── class2
  │   │   ├── img3.jpeg
  │   │   └── ...
  │   └── ...
  └── val
      ├── class1
      │   ├── img4.jpeg
      │   ├── img5.jpeg
      │   └── ...
      ├── class2
      │   ├── img6.jpeg
      │   └── ...
      └── ...
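This is the standard torchvision ImageFolder layout. As an optional sanity check, a short sketch (assuming torchvision is installed and the dataset sits at ./imagenet):

from torchvision import datasets

train = datasets.ImageFolder('imagenet/train')
val = datasets.ImageFolder('imagenet/val')
print(len(train.classes), len(val.classes))  # both should print 1000 for ImageNet-1K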

Training an A-ViT on ImageNet-1K

The commands below train A-ViT on ImageNet-1K. Please place the ImageNet dataset as described above. With the --pretrained flag, the code automatically loads pretrained DeiT weights as the starting point. If you face loading or downloading issues, please refer to the URLs in the models_act.py code.

The code was tested on a cluster of four NVIDIA V100 GPUs, each with 32 GB of memory. Training takes 100 epochs.

For tiny model:

python -m torch.distributed.launch --nproc_per_node=4 --use_env main_act.py --model avit_tiny_patch16_224 --data-path <data to imagenet folder> --output_dir ./results/<name your exp for tensorboard files and ckpt> --pretrained --batch-size 128 --lr 0.0005 --tensorboard --epochs 100 --gate_scale 10.0 --gate_center 30 --warmup-epochs 5 --ponder_token_scale 0.0005 --distr_prior_alpha 0.001

For small model:

python -m torch.distributed.launch --nproc_per_node=4 --use_env main_act.py --model avit_small_patch16_224 --data-path <data to imagenet folder> --output_dir ./results/<name your exp name> --pretrained --batch-size 96 --lr 0.0003 --tensorboard --epochs 100 --gate_scale 10.0 --gate_center 75 --warmup-epochs 5 --ponder_token_scale 0.0005 --distr_prior_alpha 0.001

Arguments:

  • batch-size - batch size.
  • tensorboard - write training and testing stats for tracking purposes.
  • lr - learning rate.
  • epochs - number of epochs to fine-tune. If --pretrained is not set, this is the total number of training epochs.
  • pretrained - flag to load full pretrained static models as the starting point for learning adaptive inference.
  • ponder_token_scale - regularization constant that scales the token ponder loss.
  • gate_scale - scaling factor for the H gate, a constant shared across all tokens at all layers (see the sketch after this list).
  • gate_center - absolute value of the negative bias for the H gate, a constant shared across all tokens at all layers.
  • distr_prior_alpha - regularization constant for the KL-divergence distributional prior.
  • finetune - path to A-ViT weights to load.
  • demo - set to save and visualize the token depth distribution after loading A-ViT weights via the finetune argument. This saves (i) the original image, (ii) the token depth distribution, and (iii) their left-right concatenation as jpg files.
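To make the two gate constants concrete, here is a minimal sketch of a sigmoid halting gate formed from a single embedding channel, with gate_center entering as the absolute value of a negative bias. This is an illustration based on the argument descriptions above, not the repository's code; the exact expression lives in models_act.py.

import torch

def halting_scores(tokens, gate_scale=10.0, gate_center=30.0):
    # tokens: (batch, num_tokens, dim). For illustration we assume the gate
    # reads the first embedding channel of every token; the same two constants
    # are shared across all tokens at all layers.
    return torch.sigmoid(tokens[:, :, 0] * gate_scale - gate_center)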

Pretrained Weights on ImageNet-1K

Pretrained A-ViT model weights can be downloaded here.

Use the following command to extract the archive into the ./a-vit-weights folder in the repository root, so that the training script can load the weights directly.

tar -xzf <path to your downloaded folder>/a-vit-weights.tar.gz ./a-vit-weights

  Name     Acc@1  Acc@5  Resolution  #Params (M)  FLOPs (G)  Path
  A-ViT-T  71.4   90.4   224x224     5            0.8        a-vit-weights/tiny-10-30.pth
  A-ViT-S  78.8   93.9   224x224     22           3.6        a-vit-weights/small-10-75.pth

For evaluation, simply append the following snippet to the training command above, with nproc_per_node set to 1:

--eval --finetune <path to ckpt folder>/{name of checkpoint}
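For example, the tiny-model training command from above becomes the following, with the checkpoint path taken from the table:

python -m torch.distributed.launch --nproc_per_node=1 --use_env main_act.py --model avit_tiny_patch16_224 --data-path <data to imagenet folder> --output_dir ./results/<name your exp> --pretrained --batch-size 128 --lr 0.0005 --tensorboard --epochs 100 --gate_scale 10.0 --gate_center 30 --warmup-epochs 5 --ponder_token_scale 0.0005 --distr_prior_alpha 0.001 --eval --finetune ./a-vit-weights/tiny-10-30.pth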

Visualization of Token Depths

For a quick visualization of the learned adaptive inference, download the pretrained weights as instructed above and set the --demo flag.

This will plot the token depth distribution in a new folder named token_act_visualization, with one demo image for each of the 1K validation classes, saving (i) the original image, (ii) the token depth map, and (iii) their left-to-right concatenation. A full run generates 3K images.

Batch size is set to 50 for per-class analysis. Note that this function is not fully optimized and can be slow.

python -m torch.distributed.launch --nproc_per_node=1 --use_env main_act.py --model avit_tiny_patch16_224 --data-path <data to imagenet folder> --finetune <path to ckpt folder>/visualization-tiny.pth --demo

Some examples (left: ImageNet validation image, 224x224, unseen during training; right: token depth, where whiter is deeper; legend: depth, where 12 is full depth). More examples are in the example_images folder.

Please feel free to generate more examples using the --demo flag, or via the visualize function in engine_act.py. Note that this snippet is a very quick demo from an intermediate checkpoint of the tiny A-ViT; we observe that the distribution continues to converge slightly as accuracy improves. Analyzing other checkpoints is easily supported by changing which .pth file is loaded, and the semantic meaning of the distribution holds.

License

Copyright (C) 2022 NVIDIA Corporation. All rights reserved.

This work is made available under the Nvidia Source Code License (1-Way Commercial). To view a copy of this license, visit the LICENSE file in this repository.

The pre-trained models are shared under CC-BY-NC-SA-4.0. If you remix, transform, or build upon the material, you must distribute your contributions under the same license as the original.

For license information regarding the timm repository, please refer to the official website.

For license information regarding the DEIT repository, please refer to the official website.

For license information regarding the ImageNet dataset, please refer to the official website.

Citation

@InProceedings{Yin_2022_CVPR,
    author    = {Yin, Hongxu and Vahdat, Arash and Alvarez, Jose M. and Mallya, Arun and Kautz, Jan and Molchanov, Pavlo},
    title     = {{A}-{V}i{T}: {A}daptive Tokens for Efficient Vision Transformer},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2022},
    pages     = {10809-10818}
}


a-vit's Issues

Training accuracy

Thanks for your interesting and excellent work.
I reran the training code using avit-tiny but only got 68.26% top-1 accuracy on ImageNet; would different training runs cause that much of a difference?
Additionally, how can the stopped tokens actually be 'removed' at the inference stage to reduce inference time when the batch size is greater than 1?

token nums

Hi, thanks for your excellent work.
The paper mentions that the number of tokens changes across Transformer layers, but when I debugged the code I found that the number of tokens is the same in each layer. What is the reason for this?
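(For context: the released code keeps tensor shapes fixed and zeroes out halted tokens with a mask rather than physically removing them, which is why every layer shows the same token count. A minimal illustration of the difference, with hypothetical names:)

import torch

tokens = torch.randn(1, 197, 192)             # (batch, num_tokens, dim)
keep = torch.rand(1, 197) > 0.5               # hypothetical active-token mask

masked = tokens * keep.unsqueeze(-1).float()  # masking: same shape, halted tokens zeroed
removed = tokens[:, keep[0], :]               # removal: smaller tensor, fewer FLOPs
print(masked.shape, removed.shape)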

A software supply-chain vulnerability detected

Hi,

I'm a cybersecurity researcher developing Packj [1]. Our tool has detected a supply-chain vulnerability in this repository. In order for me to disclose it, kindly enable GitHub private vulnerability reporting, which allows security researchers to responsibly disclose vulnerabilities.

Thanks!

Packj detects malicious/"risky" NPM/PyPI/Ruby dependencies: https://github.com/ossillate-inc/packj

The inference time of A-ViT is the same as DeiT's

Thanks for this interesting work; I believe it will be valuable for people in this area.
I have some questions. Could the authors provide some explanation?
(1) Why is the inference time of A-ViT the same as DeiT's? According to the paper, A-ViT is faster than DeiT, but I find that the inference time is the same regardless of whether the pretrained model is loaded or not.

(2) According to the paper, different samples should retain different tokens, i.e., the tensor shapes should differ in the testing phase. So how is the validation set evaluated with a batch size of 64? To the best of my knowledge, it is almost impossible to combine such samples into a batch.

FLOPs Reduction & Calculation

I am particularly interested in understanding the mechanisms behind the reduction in GFLOPS achieved by the adaptive computation framework, and I have some questions I was not able to clarify myself:

My understanding is that GFLOPS are primarily tied to the number of computations performed, including matrix multiplications. However, I am unclear about how the reduction in the number of tokens affects the GFLOPS calculation, especially when the dimensions of the input tensors remain the same and the unused tokens are assigned a value of 0.
In the act_vision_transformer.py file:

class VisionTransformer(nn.Module):
    def forward_features_act_token(self, x):
        ...
        # halted tokens are assigned a value of 0
        out.data = out.data * mask_token.float().view(bs, self.total_token_cnt, 1)

class Block_ACT(nn.Module):
    def forward_act(self, x, mask=None):
        ...
        # multiplications keep the original dimensions, with no reduction
        x = x + self.drop_path(self.attn(self.norm1(x*(1-mask).view(bs, token, 1))*(1-mask).view(bs, token, 1), mask=mask))
        x = x + self.drop_path(self.mlp(self.norm2(x*(1-mask).view(bs, token, 1))*(1-mask).view(bs, token, 1)))

Could you kindly provide insights into how the adaptive computation framework in A-ViT leads to GFLOPs reduction, even when the token dimensions are unchanged?

Could you also elaborate on how you measure GFLOPs?

Thanks in advance.
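(As background for this question: once halted tokens are actually removed, per-block FLOPs scale with the number of active tokens. A rough counting sketch, not the repository's measurement code:)

def block_flops(n_tokens, dim, mlp_ratio=4):
    # approximate multiply-accumulate count for one transformer encoder block
    qkv_and_proj = 4 * n_tokens * dim * dim       # Q, K, V and output projections
    attention = 2 * n_tokens * n_tokens * dim     # QK^T and the attention-weighted sum of V
    mlp = 2 * n_tokens * dim * (mlp_ratio * dim)  # two linear layers
    return qkv_and_proj + attention + mlp

# With adaptive halting, each layer l sees its own active-token count n_l, so the
# total is sum(block_flops(n_l, dim) for n_l in tokens_per_layer), which shrinks
# as tokens halt even though the embedding dimension stays the same.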

A question about the halting score distribution code

In the paper, the halting score distribution is defined as below:
(equation screenshot from the paper, not reproduced here)

However, the corresponding code seems wrong.

if self.args.distr_prior_alpha > 0.:
    self.halting_score_layer.append(torch.mean(h_lst[1][1:]))

The shape of h_lst[1] is [B, N], so this code averages over the batch dimension (dropping the first sample of the batch) rather than over the token dimension.
I think the right code is:

    self.halting_score_layer.append(torch.mean(h_lst[1][:, 1:], dim=-1))

Can you tell me which one is correct? Thanks!
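(A quick shape check illustrating the difference between the two variants:)

import torch

h = torch.rand(4, 197)                     # (B=4 samples, N=197 tokens) of halting scores
print(torch.mean(h[1:]))                   # scalar: drops sample 0, averages samples 1..3 over all tokens
print(torch.mean(h[:, 1:], dim=-1).shape)  # torch.Size([4]): drops the first token, averages per sample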

Cannot Reproduce Reported Accuracy

Hello,

I am unable to reproduce the validation results for A-ViT-Tiny. Training on four RTX A6000 GPUs with the exact same training script produces 68.17% top-1 and 88.816% top-5 accuracy. It seems other people in this thread have the same issue, with similar results. Could the authors please address this?

Thank you!

About inference stage

Dear authors, I'm a student from China. I'm wondering if I could have the run-time token zipping code for inference. I would appreciate your help.

A question about inference

Hi, thanks for your excellent work.
I have a question about the inference-stage code. According to the paper, the halted tokens are simply removed from computation at inference time. However, I can't find the corresponding code; the network seems to operate the same way during training and testing (the related code is pasted below).

def forward_features_act_token(self, x):
    x = self.patch_embed(x)
    cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
    if self.dist_token is None:
        x = torch.cat((cls_token, x), dim=1)
    else:
        x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
    x = self.pos_drop(x + self.pos_embed)

    # now start the act part
    bs = x.size()[0]  # The batch size

    # this part needs to be modified for higher GPU utilization
    if self.c_token is None or bs != self.c_token.size()[0]:
        self.c_token = Variable(torch.zeros(bs, self.total_token_cnt).cuda())
        self.R_token = Variable(torch.ones(bs, self.total_token_cnt).cuda())
        self.mask_token = Variable(torch.ones(bs, self.total_token_cnt).cuda())
        self.rho_token = Variable(torch.zeros(bs, self.total_token_cnt).cuda())
        self.counter_token = Variable(torch.ones(bs, self.total_token_cnt).cuda())

    c_token = self.c_token.clone()
    R_token = self.R_token.clone()
    mask_token = self.mask_token.clone()
    self.rho_token = self.rho_token.detach() * 0.
    self.counter_token = self.counter_token.detach() * 0 + 1.

    # Will contain the output of this residual layer (weighted sum of outputs of the residual blocks)
    output = None
    # Use out to backbone
    out = x

    if self.args.distr_prior_alpha > 0.:
        self.halting_score_layer = []

    for i, l in enumerate(self.blocks):
        # block out all the parts that are not used
        out.data = out.data * mask_token.float().view(bs, self.total_token_cnt, 1)

        # evaluate layer and get halting probability for each sample
        # block_output, h_lst = l.forward_act(out)  # h is a vector of length bs, block_output a 3D tensor
        block_output, h_lst = l.forward_act(out, 1. - mask_token.float())  # h is a vector of length bs, block_output a 3D tensor

        if self.args.distr_prior_alpha > 0.:
            self.halting_score_layer.append(torch.mean(h_lst[1][1:]))

        out = block_output.clone()  # Deep copy needed for the next layer
        _, h_token = h_lst  # h is layer halting score, h_token is token halting score, first position discarded

        # here, 1 is remaining, 0 is blocked
        block_output = block_output * mask_token.float().view(bs, self.total_token_cnt, 1)

        # Is this the last layer in the block?
        if i == len(self.blocks) - 1:
            h_token = Variable(torch.ones(bs, self.total_token_cnt).cuda())

        # for token part
        c_token = c_token + h_token
        self.rho_token = self.rho_token + mask_token.float()

        # Case 1: threshold reached in this iteration (token part)
        reached_token = c_token > 1 - self.eps
        reached_token = reached_token.float() * mask_token.float()
        delta1 = block_output * R_token.view(bs, self.total_token_cnt, 1) * reached_token.view(bs, self.total_token_cnt, 1)
        self.rho_token = self.rho_token + R_token * reached_token

        # Case 2: threshold not reached (token part)
        not_reached_token = c_token < 1 - self.eps
        not_reached_token = not_reached_token.float()
        R_token = R_token - (not_reached_token.float() * h_token)
        delta2 = block_output * h_token.view(bs, self.total_token_cnt, 1) * not_reached_token.view(bs, self.total_token_cnt, 1)
        self.counter_token = self.counter_token + not_reached_token  # These data points will need at least one more layer

        # Update the mask
        mask_token = c_token < 1 - self.eps

        if output is None:
            output = delta1 + delta2
        else:
            output = output + (delta1 + delta2)

    x = self.norm(output)

    if self.dist_token is None:
        return self.pre_logits(x[:, 0])
    else:
        return x[:, 0], x[:, 1]

Am I missing something?

BTW, I've tested the model on ImageNet-1K and got the same inference time whether or not the weights are loaded (the results are listed below). To my understanding, the model should predict faster due to adaptive token halting when the weights are loaded. Could you please tell me why this happens?

  • the result without weights (timing screenshot)

  • the result with the official weights loaded (timing screenshot)

Unable to Reproduce top-1 Accuracy

Hi

I trained avit_tiny with the provided hyperparameters, and the validation accuracy on ImageNet is only 68.2% instead of the reported 71.4%.

Could you please let me know what hyperparameters to use to reproduce the results in the paper?

Non-zero outputs from discarded tokens

Hi! Congratulations on the great paper.

These lines are concerning to me:

else:
    x = x + self.drop_path(self.attn(self.norm1(x*(1-mask).view(bs, token, 1))*(1-mask).view(bs, token, 1), mask=mask))
    x = x + self.drop_path(self.mlp(self.norm2(x*(1-mask).view(bs, token, 1))*(1-mask).view(bs, token, 1)))

I can see two issues here:

  1. Masking of layer normalization inputs seems redundant. To see what I mean:
In [54]: some_tokens = torch.randn(2, 5, 10)

In [55]: ln = torch.nn.LayerNorm(10)

In [56]: continue_mask = torch.zeros(2, 5, 1)

In [57]: continue_mask[0, :3] = 1.0

In [58]: continue_mask[1, 2:] = 1.0

In [59]: ln(continue_mask * some_tokens) * continue_mask == ln(some_tokens) * continue_mask
Out[59]:
tensor([[[True, True, True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True]],

        [[True, True, True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True]]])
  2. On the other hand, the outputs of the MLP and MHA are not masked. Since the MLP consists of two linear layers with biases (and an activation function in between, of course) - see here and here - the outputs of that module are not zero. While this can be seen as a kind of bias added to the output after a token has been dropped, it goes against the spirit of (1) not adding any parameters and (2) reducing compute/FLOPs, as described in your paper. This seems like a bug to me (a minimal demonstration follows).
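(A minimal demonstration of the bias point, using a hypothetical two-layer MLP of the kind found in a transformer block:)

import torch

mlp = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.GELU(),
    torch.nn.Linear(16, 8),
)
zero_token = torch.zeros(1, 8)  # a fully masked (halted) token
print(mlp(zero_token))          # generally non-zero, because of the bias terms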

Please correct me if I am wrong.
