FLatten Transformer

This repo contains the official PyTorch code and pre-trained models for FLatten Transformer (ICCV 2023).

Introduction

Motivation

The quadratic computational complexity of self-attention, $\mathcal{O}(N^2)$, has been a long-standing problem when applying Transformer models to vision tasks. Apart from reducing attention regions, linear attention is also considered an effective way to avoid excessive computation costs. By approximating Softmax with carefully designed mapping functions, linear attention can switch the computation order in the self-attention operation and achieve linear complexity $\mathcal{O}(N)$. Nevertheless, current linear attention approaches either suffer from a severe performance drop or introduce additional computation overhead from the mapping functions. In this paper, we propose a novel Focused Linear Attention module to achieve both high efficiency and expressiveness.
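
The sketch below illustrates the computation-order switch that gives linear attention its $\mathcal{O}(N)$ complexity. It is a minimal, generic example, not this repository's code; the ReLU kernel is only a placeholder for a general mapping function, not the mapping proposed in the paper:

# Minimal sketch of the O(N) computation order used by linear attention.
# The ReLU kernel is a placeholder for a general mapping function phi.
import torch
N, d = 4096, 64                                   # sequence length, head dim
Q, K, V = (torch.randn(N, d) for _ in range(3))
# Softmax attention materializes an N x N score matrix: O(N^2 * d) cost.
softmax_out = torch.softmax(Q @ K.T / d ** 0.5, dim=-1) @ V
# Linear attention: compute phi(K)^T V once (a d x d summary), then reuse it
# for every query, so the overall cost is O(N * d^2).
phi = torch.nn.functional.relu
Qp, Kp = phi(Q), phi(K)
kv = Kp.T @ V
linear_out = (Qp @ kv) / (Qp @ Kp.sum(dim=0, keepdim=True).T + 1e-6)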

Method

In this paper, we first perform a detailed analysis of the inferior performance of linear attention from two perspectives: focus ability and feature diversity. Then, we introduce a simple yet effective mapping function and an efficient rank restoration module, and propose our Focused Linear Attention (FLatten) module, which adequately addresses these concerns and achieves both high efficiency and expressive capability.
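
As a rough sketch of these two ingredients (simplified from the paper's description and assumed, not copied from the repository code): the focused mapping raises ReLU-activated features to an element-wise power and rescales them to keep their original norm, while rank restoration adds a depthwise convolution of V back to the attention output. The helper names and the power p=3 below are illustrative assumptions; see the repository's model code for the actual implementation.

# Assumed sketch of Focused Linear Attention; names and p=3 are illustrative.
import torch
import torch.nn.functional as F

def focused_map(x, p=3, eps=1e-6):
    # Sharpen the feature direction with an element-wise power, then rescale
    # so that the original feature norm is preserved.
    x = F.relu(x)
    x_p = x ** p
    return x_p * (x.norm(dim=-1, keepdim=True) / (x_p.norm(dim=-1, keepdim=True) + eps))

def focused_linear_attention(q, k, v, H, W, dwc):
    # q, k, v: (B, N, d) with N = H * W; dwc: a depthwise nn.Conv2d over v.
    q, k = focused_map(q), focused_map(k)
    kv = torch.einsum('bnd,bne->bde', k, v)                  # shared d x d summary
    z = 1.0 / (torch.einsum('bnd,bd->bn', q, k.sum(dim=1)) + 1e-6)
    out = torch.einsum('bnd,bde,bn->bne', q, kv, z)
    # Rank restoration: a local depthwise convolution of V restores feature
    # diversity that the low-rank linear attention map alone cannot provide.
    v_map = v.transpose(1, 2).reshape(v.shape[0], -1, H, W)
    return out + dwc(v_map).flatten(2).transpose(1, 2)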

Results

  • Comparison of different models on ImageNet-1K.

  • Accuracy-Runtime curve on ImageNet.

Dependencies

  • Python 3.9
  • PyTorch == 1.11.0
  • torchvision == 0.12.0
  • numpy
  • timm == 0.4.12
  • einops
  • yacs
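
One way to set up these dependencies (an assumed command, not part of the repository; pick the torch/torchvision build that matches your CUDA version):

pip install torch==1.11.0 torchvision==0.12.0 timm==0.4.12 numpy einops yacs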

Data preparation

The ImageNet dataset should be prepared as follows:

$ tree data
imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...
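
This is the standard class-per-subdirectory layout. As a quick sanity check (an assumed snippet, not part of the repository; adjust the path to your own setup), it can be read with torchvision's ImageFolder:

# Assumed sanity check for the directory layout above; adjust the path.
from torchvision import datasets, transforms
train_set = datasets.ImageFolder(
    'imagenet/train',
    transform=transforms.Compose([transforms.Resize(256),
                                  transforms.CenterCrop(224),
                                  transforms.ToTensor()]))
print(len(train_set), 'images in', len(train_set.classes), 'classes')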

Pretrained Models

We provide pretrained models for several architectures, as listed below.

model             resolution  acc@1 (%)    config  pretrained weights
FLatten-PVT-T     $224^2$     77.8 (+2.7)  config  TsinghuaCloud
FLatten-PVTv2-B0  $224^2$     71.1 (+0.6)  config  TsinghuaCloud
FLatten-Swin-T    $224^2$     82.1 (+0.8)  config  TsinghuaCloud
FLatten-Swin-S    $224^2$     83.5 (+0.5)  config  TsinghuaCloud
FLatten-Swin-B    $224^2$     83.8 (+0.3)  config  TsinghuaCloud
FLatten-Swin-B    $384^2$     85.0 (+0.5)  config  TsinghuaCloud
FLatten-CSwin-T   $224^2$     83.1 (+0.4)  config  TsinghuaCloud

Evaluate one model on ImageNet:

python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg <path-to-config-file> --data-path <imagenet-path> --output <output-path> --eval --resume <path-to-pretrained-weights>

The evaluation outputs of the four tiny (T/B0) pretrained models are:

[2023-07-21 07:50:09 flatten_pvt_tiny] (main.py 294): INFO  * Acc@1 77.758 Acc@5 93.910
[2023-07-21 07:50:09 flatten_pvt_tiny] (main.py 149): INFO Accuracy of the network on the 50000 test images: 77.8%

[2023-07-21 07:51:36 flatten_pvt_v2_b0] (main.py 294): INFO  * Acc@1 71.098 Acc@5 90.596
[2023-07-21 07:51:36 flatten_pvt_v2_b0] (main.py 149): INFO Accuracy of the network on the 50000 test images: 71.1%

[2023-07-21 07:46:13 flatten_swin_tiny_patch4_224] (main.py 294): INFO  * Acc@1 82.106 Acc@5 95.900
[2023-07-21 07:46:13 flatten_swin_tiny_patch4_224] (main.py 149): INFO Accuracy of the network on the 50000 test images: 82.1%

[2023-07-21 07:52:46 FLatten_CSWin_tiny](main.py 294): INFO  * Acc@1 83.130 Acc@5 96.376
[2023-07-21 07:52:46 FLatten_CSWin_tiny](main.py 149): INFO Accuracy of the network on the 50000 test images: 83.1%
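
For a quick check on a single GPU, the same flags should also work with one process (an assumption; only the 8-GPU command above is provided by the authors):

python -m torch.distributed.launch --nproc_per_node=1 main.py --cfg <path-to-config-file> --data-path <imagenet-path> --output <output-path> --eval --resume <path-to-pretrained-weights>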

Train Models from Scratch

  • To train FLatten-PVT-T/S/M/B on ImageNet from scratch, run:
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_t.yaml --data-path <imagenet-path> --output <output-path> --find-unused-params
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_s.yaml --data-path <imagenet-path> --output <output-path> --find-unused-params
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_m.yaml --data-path <imagenet-path> --output <output-path> --find-unused-params
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_b.yaml --data-path <imagenet-path> --output <output-path> --find-unused-params
  • To train FLatten-PVT-v2-b0/1/2/3/4 on ImageNet from scratch, run:
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b0.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b1.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b2.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b3.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_pvt_v2_b4.yaml --data-path <imagenet-path> --output <output-path>
  • To train FLatten-Swin-T/S/B on ImageNet from scratch, run:
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_t.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_s.yaml --data-path <imagenet-path> --output <output-path>
python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_b.yaml --data-path <imagenet-path> --output <output-path>
  • To train FLatten-CSwin-T/S/B on ImageNet from scratch, run:
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_t.yaml --data-path <imagenet-path> --output <output-path> --model-ema --model-ema-decay 0.99984
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_s.yaml --data-path <imagenet-path> --output <output-path> --model-ema --model-ema-decay 0.99984
python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_b.yaml --data-path <imagenet-path> --output <output-path> --model-ema --model-ema-decay 0.99982

Fine-tuning on higher resolution

Fine-tune a FLatten-Swin-B model pre-trained on 224x224 resolution to 384x384 resolution:

python -m torch.distributed.launch --nproc_per_node=8 main.py --cfg ./cfgs/flatten_swin_b_384.yaml --data-path <imagenet-path> --output <output-path> --pretrained <path-to-224x224-pretrained-weights>

Fine-tune a FLatten-CSwin-B model pre-trained on 224x224 resolution to 384x384 resolution:

python -m torch.distributed.launch --nproc_per_node=8 main_ema.py --cfg ./cfgs/flatten_cswin_b_384.yaml --data-path <imagenet-path> --output <output-path> --pretrained <path-to-224x224-pretrained-weights> --model-ema --model-ema-decay 0.99982

Acknowledgements

This code is developed on top of Swin Transformer. The computational resources supporting this work are provided by Hangzhou High-Flyer AI Fundamental Research Co., Ltd.

Citation

If you find this repo helpful, please consider citing us.

@InProceedings{han2023flatten,
  title={FLatten Transformer: Vision Transformer using Focused Linear Attention},
  author={Han, Dongchen and Pan, Xuran and Han, Yizeng and Song, Shiji and Huang, Gao},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  year={2023}
}

Contact

If you have any questions, please feel free to contact the authors.

Dongchen Han: [email protected]

Xuran Pan: [email protected]
