vita-group / vit-anti-oversmoothing

[ICLR 2022] "Anti-Oversmoothing in Deep Vision Transformers via the Fourier Domain Analysis: From Theory to Practice" by Peihao Wang, Wenqing Zheng, Tianlong Chen, Zhangyang Wang

License: MIT License

Python 100.00%
vision-transformer oversmoothing fourier-analysis signal-processing

vit-anti-oversmoothing's Introduction

Anti-Oversmoothing in Deep Vision Transformers via the Fourier Domain Analysis: From Theory to Practice

The official implementation of the ICLR 2022 paper Anti-Oversmoothing in Deep Vision Transformers via the Fourier Domain Analysis: From Theory to Practice.

Peihao Wang, Wenqing Zheng, Tianlong Chen, Zhangyang (Atlas) Wang

This repository is built on the official DeiT and CaiT repositories.

Introduction

Vision Transformer (ViT) has recently demonstrated promise in computer vision problems. However, unlike Convolutional Neural Networks (CNNs), the performance of ViT is known to saturate quickly as depth increases, due to the observed attention collapse or patch uniformity. Despite a couple of empirical solutions, a rigorous framework for studying this scalability issue remains elusive. In this paper, we first establish a rigorous theoretical framework to analyze ViT features in the Fourier spectrum domain. We show that the self-attention mechanism inherently amounts to a low-pass filter, which implies that when ViT scales up its depth, excessive low-pass filtering causes feature maps to preserve only their Direct-Current (DC) component. We then propose two straightforward yet effective techniques to mitigate this undesirable low-pass limitation. The first technique, termed AttnScale, decomposes a self-attention block into low-pass and high-pass components, then rescales and combines these two filters to produce an all-pass self-attention matrix. The second technique, termed FeatScale, re-weights feature maps on separate frequency bands to amplify the high-frequency signals. Both techniques are efficient and hyperparameter-free, while effectively overcoming relevant ViT training artifacts such as attention collapse and patch uniformity. By seamlessly plugging our techniques into multiple ViT variants, we demonstrate that they consistently help ViTs benefit from deeper architectures, bringing up to 1.1% performance gains "for free" (i.e., with little parameter overhead).
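The AttnScale idea described above can be sketched in a few lines of NumPy. This is a minimal illustration rather than the repository's implementation: it assumes the low-pass component is the uniform averaging matrix (1/n) 11^T, treats the rescaling weight as a plain scalar, and uses the hypothetical names attn_scale and omega.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attn_scale(attn, omega):
    """Rescale the high-pass part of a row-stochastic attention matrix.

    attn:  [n, n] attention matrix whose rows sum to 1
    omega: scalar rescaling weight (hypothetical name; assumed learnable in practice)
    """
    n = attn.shape[-1]
    attn_lp = np.full((n, n), 1.0 / n)  # assumed low-pass part: uniform averaging
    attn_hp = attn - attn_lp            # high-pass part: whatever remains
    return attn_lp + (1.0 + omega) * attn_hp

rng = np.random.default_rng(0)
attn = softmax(rng.normal(size=(5, 5)))
scaled = attn_scale(attn, omega=0.5)
```

Under these assumptions, omega = 0 leaves the attention matrix unchanged, and the rows of the rescaled matrix still sum to 1 because each row of the high-pass part sums to 0.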

Getting Started

Dependency

First of all, clone our repository locally:

git clone https://github.com/VITA-Group/ViT-Anti-Oversmoothing.git

Then, install the following packages, which are required to run our code:

pytorch 1.7.0
cudatoolkit 11.0
torchvision 0.8.0
timm 0.4.12

Data Preparation

Download and extract the ImageNet train and val images from the official website. The directory structure follows the standard layout for torchvision's datasets.ImageFolder, with the training and validation data expected in the train/ and val/ folders, respectively:

/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg

To automatically collate the dataset directory, you may find these shell scripts useful.
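Before launching a long training run, it can help to confirm that the dataset root matches the layout above. The following is a small standard-library sketch; the helper name check_imagefolder_layout is hypothetical and not part of this repository, and the toy tree is built in a temporary directory purely for demonstration.

```python
import tempfile
from pathlib import Path

def check_imagefolder_layout(root):
    """Return {split: number of class folders}; raise if a split is missing or empty."""
    counts = {}
    for split in ("train", "val"):
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split folder: {split_dir}")
        classes = [d for d in split_dir.iterdir() if d.is_dir()]
        if not classes:
            raise ValueError(f"no class sub-folders found in {split_dir}")
        counts[split] = len(classes)
    return counts

# Build a toy tree that mirrors the layout shown above, then check it.
with tempfile.TemporaryDirectory() as tmp:
    for split in ("train", "val"):
        for cls in ("class1", "class2"):
            (Path(tmp) / split / cls).mkdir(parents=True)
    counts = check_imagefolder_layout(tmp)
```

On a real dataset, pass the same path that is given to --data-path.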

Usage

Training

Training AttnScale and FeatScale from scratch usually requires multiple GPUs. Please use the following command to train our model with distributed data parallel:

python -m torch.distributed.launch --nproc_per_node=<num_gpus> --master_port <port> --use_env \
main.py --auto_reload --model <model_name> --batch-size <batch_size> \
--data-path <data_path> --data-set IMNET --input-size 224 \
--output_dir <log_dir>

where <num_gpus> is the number of processes (one per GPU) to launch on the machine, and <model_name> specifies the name of the model to build. To use our techniques, choose a name of the form attnscale_<size>_<depth> or featscale_<size>_<depth>, where <size> is either base or small, and <depth> is either 12 or 24.
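The naming pattern can be enumerated with a short helper. This is only an illustration of the pattern described above; build_model_name is a hypothetical function, not part of the repository, and the CaiT variants used for fine-tuning follow a different naming scheme.

```python
VALID_METHODS = ("attnscale", "featscale")
VALID_SIZES = ("small", "base")
VALID_DEPTHS = (12, 24)

def build_model_name(method, size, depth):
    """Compose a --model value of the form <method>_<size>_<depth>."""
    if method not in VALID_METHODS:
        raise ValueError(f"method must be one of {VALID_METHODS}, got {method!r}")
    if size not in VALID_SIZES:
        raise ValueError(f"size must be one of {VALID_SIZES}, got {size!r}")
    if depth not in VALID_DEPTHS:
        raise ValueError(f"depth must be one of {VALID_DEPTHS}, got {depth!r}")
    return f"{method}_{size}_{depth}"

# All combinations covered by the pattern.
names = [
    build_model_name(m, s, d)
    for m in VALID_METHODS
    for s in VALID_SIZES
    for d in VALID_DEPTHS
]
```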

To enable multi-node training, please refer to these instructions.

To reproduce our results, please follow the command lines below:

12-layer DeiT-S + AttnScale

python -m torch.distributed.launch --nproc_per_node=2 --master_port 29700 --use_env \
main.py --auto_reload --model attnscale_small_12 --batch-size 512 \
--data-path <data_path> --data-set IMNET --input-size 224 \
--output_dir ./logs/imnet1k_attnscale_small_12

24-layer DeiT-S + AttnScale

python -m torch.distributed.launch --nproc_per_node=2 --master_port 29701 --use_env \
main.py --auto_reload --model attnscale_small_24 --batch-size 256 --drop 0.2 \
--data-path <data_path> --data-set IMNET --input-size 224 \
--output_dir ./logs/imnet1k_attnscale_small_24

12-layer DeiT-S + FeatScale

python -m torch.distributed.launch --nproc_per_node=2 --master_port 29702 --use_env \
main.py --auto_reload --model featscale_small_12 --batch-size 512 \
--data-path <data_path> --data-set IMNET --input-size 224 \
--output_dir ./logs/imnet1k_featscale_small_12

24-layer DeiT-S + FeatScale

python -m torch.distributed.launch --nproc_per_node=2 --master_port 29703 --use_env \
main.py --auto_reload --model featscale_small_24 --batch-size 256 --drop 0.2 \
--data-path <data_path> --data-set IMNET --input-size 224 \
--output_dir ./logs/imnet1k_featscale_small_24

Fine-tune

You can also fine-tune a pre-trained model with our add-ons. To train AttnScale or FeatScale from a checkpoint, pass the checkpoint path via the --resume argument.

To reproduce our results, please follow the command lines below:

24-layer CaiT-S + AttnScale

python -m torch.distributed.launch --nproc_per_node=2 --master_port 29704 --use_env \
main.py --auto_reload --model attnscale_cait_S24_224 --batch-size 128 \
--epochs 60 --lr 5e-5 --weight-decay 5e-4 --min-lr 1e-6 --warmup-epochs 1 --decay-epochs 5 \
--data-path <data_path> --data-set IMNET --input-size 224 \
--output_dir ./logs/imnet1k_attnscale_cait_s24_224 \
--resume <ckpt_path>

24-layer CaiT-S + FeatScale

python -m torch.distributed.launch --nproc_per_node=2 --master_port 29705 --use_env \
main.py --auto_reload --model featscale_cait_S24_224 --batch-size 128 \
--epochs 60 --lr 5e-5 --weight-decay 5e-4 --min-lr 1e-6 --warmup-epochs 1 --decay-epochs 5 \
--data-path <data_path> --data-set IMNET --input-size 224 \
--output_dir ./logs/imnet1k_featscale_cait_s24_224 \
--resume <ckpt_path>

Pre-trained Models

Our pre-trained model parameters can be downloaded from HuggingFace Hub. To evaluate a pre-trained model, pass the --eval flag and set --resume to the path of the checkpoint. For example, to reproduce our results for DeiT-S + AttnScale, run the following command:

python -m torch.distributed.launch --nproc_per_node=2 --master_port 29701 --use_env \
main.py --model attnscale_small_12 --batch-size 256 --drop 0.2 \
--data-path <data_path> --data-set IMNET --input-size 224 \
--output_dir ./logs/imnet1k_attnscale_small_12 \
--eval --resume <ckpt_dir>/attnscale_small_12.pth

Citation

If you find this work or our code implementation helpful for your own research or work, please cite our paper.

@inproceedings{wang2022antioversmooth,
title={Anti-Oversmoothing in Deep Vision Transformers via the Fourier Domain Analysis: From Theory to Practice},
author={Wang, Peihao and Zheng, Wenqing and Chen, Tianlong and Wang, Zhangyang},
booktitle={International Conference on Learning Representations},
year={2022},
url={https://openreview.net/forum?id=O476oWmiNNp},
}

vit-anti-oversmoothing's Issues

About freq_decompose

Hi, the paper provides a novel frequency-domain analysis of the low-pass filtering behavior of the Transformer self-attention matrix and includes extensive visualizations. I have some questions about freq_decompose.

featscale.py, lines 42-45:

    def freq_decompose(self, x):
        x_d = torch.mean(x, -2, keepdim=True)  # [bs, 1, dim]
        x_h = x - x_d  # high freq [bs, len, dim]
        return x_d, x_h

When decomposing the self-attention matrix, I have not found any code for a frequency-domain transform, such as torch.fft or a DFT matrix. freq_decompose appears to operate in image space, yet it still separates the DC and high-frequency components. Or did I miss the code for the frequency-domain transform?

Thank you!
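One way to check this numerically: the DC component of a sequence is its zero-frequency DFT coefficient, which after the inverse transform is simply the mean over the sequence. The NumPy sketch below compares a mean-based decomposition, mirroring the freq_decompose snippet above, against an explicit FFT along the token axis. It assumes the relevant transform is a 1D DFT over tokens and is offered as an illustration, not as the authors' answer.

```python
import numpy as np

def freq_decompose(x):
    """Split token features into DC and high-frequency parts along the token axis."""
    x_d = x.mean(axis=-2, keepdims=True)  # [bs, 1, dim]
    x_h = x - x_d                         # [bs, len, dim]
    return x_d, x_h

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8, 4))  # toy features: [bs, len, dim]
x_d, x_h = freq_decompose(x)

# Reference: explicit DFT along the token axis, keeping only frequency 0.
spec = np.fft.fft(x, axis=-2)
dc_only = np.zeros_like(spec)
dc_only[:, :1, :] = spec[:, :1, :]
x_d_fft = np.fft.ifft(dc_only, axis=-2).real  # [bs, len, dim], constant over tokens
x_h_fft = x - x_d_fft
```

In this toy setting both routes give the same split, which suggests why no explicit FFT call is needed to isolate the DC component.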

FLOPs about the attention scale

Hi, Peihao. Thanks for the nice work.

The FLOPs calculation code can be found in attn_scale.py:

def flops(self, N):
    # calculate flops for 1 window with token length of N
    flops = 0
    # qkv = self.qkv(x)
    flops += N * self.dim * 3 * self.dim
    # attn = (q @ k.transpose(-2, -1))
    flops += self.num_heads * N * (self.dim // self.num_heads) * N
    # attnscale
    flops += self.num_heads * N * N
    #  x = (attn @ v)
    flops += self.num_heads * N * N * (self.dim // self.num_heads)
    # x = self.proj(x)
    flops += N * self.dim * self.dim
    return flops
  1. Is the computational cost of the scaling and softmax operations being ignored?

  2. Is the computation of the averaging operation for attn_d ignored?

Thanks for your reply. :-)
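To get a sense of scale, the formula quoted above can be evaluated as a standalone function. The numbers below assume a DeiT-S-like setting (N = 197 tokens, dim = 384, 6 heads); these values and the function name attn_scale_flops are illustrative assumptions, not taken from the repository.

```python
def attn_scale_flops(N, dim, num_heads):
    """Per-term breakdown of the FLOPs formula quoted above."""
    head_dim = dim // num_heads
    terms = {
        "qkv": N * dim * 3 * dim,
        "q @ k^T": num_heads * N * head_dim * N,
        "attnscale": num_heads * N * N,
        "attn @ v": num_heads * N * N * head_dim,
        "proj": N * dim * dim,
    }
    terms["total"] = sum(terms.values())
    return terms

# Assumed DeiT-S-like configuration: 196 patches + 1 class token.
breakdown = attn_scale_flops(N=197, dim=384, num_heads=6)
for name, value in breakdown.items():
    print(f"{name:>10}: {value:,}")
```

With these assumed values the attnscale term is 232,854 out of a total of 146,233,494, or roughly 0.16%. The quoted formula contains no separate term for the scaling, softmax, or averaging steps; each of those would be on the order of num_heads * N * N operations, the same order as the attnscale term.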

How to plot Figure 4 in the paper?

Hi, the paper is quite interesting. Could you provide the code for plotting the cosine similarity in Fig. 4 of the paper?

Thanks!
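A common way to quantify patch uniformity is the average cosine similarity over all distinct pairs of patch tokens, computed per layer and then plotted against depth. The sketch below shows that metric in NumPy. It is not the authors' plotting script, the exact definition used for Fig. 4 may differ, and mean_pairwise_cosine is a hypothetical helper name.

```python
import numpy as np

def mean_pairwise_cosine(feats, eps=1e-8):
    """Average cosine similarity over all distinct token pairs.

    feats: [n, d] array with one row per patch token.
    """
    norms = np.linalg.norm(feats, axis=-1, keepdims=True)
    unit = feats / np.maximum(norms, eps)  # normalize each token to unit length
    sim = unit @ unit.T                    # [n, n] cosine similarity matrix
    n = feats.shape[0]
    off_diag = sim.sum() - np.trace(sim)   # drop self-similarity on the diagonal
    return off_diag / (n * (n - 1))

identical = mean_pairwise_cosine(np.ones((4, 3)))  # all tokens the same
orthogonal = mean_pairwise_cosine(np.eye(3))       # mutually orthogonal tokens
```

Collecting this value from the output of every block, for example with forward hooks, gives one point per layer for the plot.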

How can I visualise the spectrum of the attention maps?

I really enjoyed your paper, and I would like to know how to produce Figure 5 in your paper. I have tried the scipy.signal.freqz function, but the results are very different from yours. I would be grateful if you could help me with the code.
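One generic way to inspect the spectrum of an attention map is to take its 2D DFT, shift the zero frequency to the center, and plot the log-amplitude. The sketch below does this in NumPy. Whether Figure 5 was produced this way is an assumption, and attention_log_spectrum is a hypothetical helper name.

```python
import numpy as np

def attention_log_spectrum(attn, eps=1e-8):
    """Centered log-amplitude of the 2D DFT of an [n, n] attention matrix."""
    spec = np.fft.fftshift(np.fft.fft2(attn))  # move the zero frequency to the center
    return np.log(np.abs(spec) + eps)

# Sanity check: a fully collapsed (uniform) attention map has only a DC component.
n = 8
uniform = np.full((n, n), 1.0 / n)
log_amp = attention_log_spectrum(uniform)
peak = np.unravel_index(np.argmax(log_amp), log_amp.shape)
```

For the uniform map the only significant coefficient sits at the center of the shifted spectrum; a map with more high-frequency content would show energy away from the center.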

About FeatScale's code

Hello, I have a question about the implementation.
In featscale.py, lines 53-55:

        x_attn = x_attn + x_d + x_h

        x = x + self.drop_path(x_attn + x_d + x_h)

You add x_d & x_h twice.
Why does the implementation need to be done this way? Is there any special purpose?

Cosine Similarity Graph

Hi, thank you very much for the wonderful work.
I am wondering how the cosine similarity graphs in the paper were implemented. I would appreciate it if you could share the code for them.

Question about Figure 5

Hi, Peihao!

This is an amazing paper with fantastic visualization results.

I have two questions about visualizing the spectrum of attention maps (Figure 5):

  1. Is the plot in Figure 5 based on the pre-trained DeiT-S model? I produced some plots using the visualization code you provided and the publicly released pre-trained DeiT weights, but they do not seem to match Figure 5.
  2. Is the matrix plotted in Figure 5 the attention map after softmax, or is it taken from another location?

Thanks for replying!
