
2D Positional Encodings for Vision Transformers (ViT)

Implemented 2-D Positional Embeddings: No Position, Learnable, Sinusoidal (Absolute), Relative, and Rotary (RoPE).

  • Works by splitting the embedding dimensions into two parts and applying a 1-D positional encoding to each part.
  • One part uses the x-position sequence, and the other uses the y-position sequence. Check below for more details.
  • The classification token is handled differently by each method. Check below for more details.
  • The network used here is a scaled-down version of the original ViT with only about 800k parameters.
  • Works with small datasets by using a smaller patch size of 4.
  • Datasets tested: CIFAR10 and CIFAR100.

Appreciate any feedback I can get on this.
If you want me to include any new positional encoding, feel free to raise it as an issue.

Run commands (also available in scripts.sh):

Different positional encodings can be chosen using the pos_embed argument. Examples:

| Positional Encoding Type | Run command |
| --- | --- |
| No Position | `python main.py --dataset cifar10 --pos_embed none` |
| Learnable | `python main.py --dataset cifar10 --pos_embed learn` |
| Sinusoidal (Absolute) | `python main.py --dataset cifar10 --pos_embed sinusoidal` |
| Relative | `python main.py --dataset cifar10 --pos_embed relative --max_relative_distance 2` |
| Rotary (RoPE) | `python main.py --dataset cifar10 --pos_embed rope` |
Relative positional encoding uses the max_relative_distance hyper-parameter to clamp relative distances to the range [-max_relative_distance, max_relative_distance] (referred to as k in the paper).
The dataset can be changed using the dataset argument.
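
A minimal sketch of this clamping (an illustration only, not the repo's actual code, which may differ in detail):

```python
import torch

def clamped_relative_indices(num_tokens: int, k: int) -> torch.Tensor:
    """Pairwise relative distances clamped to [-k, k], shifted to [0, 2k] for table lookup."""
    positions = torch.arange(num_tokens)              # 0, 1, ..., N-1
    rel = positions[None, :] - positions[:, None]     # (N, N) signed distances j - i
    rel = rel.clamp(-k, k)                            # distances beyond k are clamped
    return rel + k                                    # indices into a (2k + 1)-entry table

# Example: 4 tokens along one axis, k = 2
print(clamped_relative_indices(4, 2))
```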

Results

Test set accuracy when the ViT is trained with different positional encodings.

| Positional Encoding Type | CIFAR10 | CIFAR100 |
| --- | --- | --- |
| No Position | 79.63 | 53.25 |
| Learnable | 86.52 | 60.87 |
| Sinusoidal (Absolute) | 86.09 | 59.73 |
| Relative | 90.57 | 65.11 |
| Rotary (RoPE) | 88.49 | 62.88 |

Splitting the X and Y Axes into Two 1-D Positional Encodings:

A naive way to apply 1-D positional encoding is to apply it directly to the sequence of flattened patches. However, this does not capture the 2-D spatial positions of the patches.

To handle this, the encoding dimensions are split into two parts. One part uses the x-axis position sequence, and the other part uses the y-axis position sequence, so the 2-D position is expressed as two 1-D positions.

The x- and y-position sequences are built using get_x_positions and get_y_positions in utils.py. This provides the Vision Transformer with the combined 2-D spatial position of each patch.
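
As a rough illustration of the idea (a sketch only; get_x_positions and get_y_positions in the repo may differ in signature and detail), the flattened 8×8 patch grid yields per-axis position sequences, and each half of the embedding receives a 1-D sinusoidal encoding over its own axis:

```python
import torch

def xy_positions(grid_size: int):
    """Per-axis positions for a grid of patches flattened row by row."""
    coords = torch.arange(grid_size)
    x = coords.repeat(grid_size)               # 0,1,...,7, 0,1,...,7, ...
    y = coords.repeat_interleave(grid_size)    # 0,0,...,0, 1,1,...,1, ...
    return x, y

def sinusoidal_1d(positions: torch.Tensor, dim: int) -> torch.Tensor:
    """Standard 1-D sinusoidal encoding of shape (len(positions), dim)."""
    i = torch.arange(0, dim, 2, dtype=torch.float32)
    freqs = torch.pow(10000.0, -i / dim)
    angles = positions[:, None].float() * freqs[None, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

# First half of the embedding encodes x positions, second half encodes y positions.
embed_dim, grid = 128, 8
x, y = xy_positions(grid)
pos_embed = torch.cat([sinusoidal_1d(x, embed_dim // 2), sinusoidal_1d(y, embed_dim // 2)], dim=-1)
print(pos_embed.shape)  # torch.Size([64, 128])
```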


Handling the Classification Token:

Many of these positional encodings were designed for sequences without a classification token. When a classification token is present, some techniques (Sinusoidal, Relative, and Rotary) must be adapted.

| Positional Encoding Type | Classification Token's Positional Encoding |
| --- | --- |
| No Position | No positional encoding is added. |
| Learnable | The classification token learns its own positional encoding. |
| Sinusoidal (Absolute) | Sinusoidal positional encoding is applied to the patch tokens only; the classification token learns its own positional encoding. |
| Relative | Relative positional encoding uses the relative distance between tokens to provide positional information. The classification token, however, is always first and should not be included when computing relative distances. One option is to ignore distances involving the classification token; instead, a fixed separate index (0 here) in the encoding lookup tables represents the distance between the classification token and every other token. |
| Rotary (RoPE) | X and Y positions start at 1 instead of 0. Position 0 is reserved for the classification token and results in no rotation being applied to it. The remaining tokens are handled normally. |
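
A hedged sketch of the relative-distance lookup indices with the extra classification-token index (the exact index layout in the repo may differ), along with the RoPE convention of starting patch positions at 1:

```python
import torch

def relative_indices_with_cls(num_patches: int, k: int) -> torch.Tensor:
    """Lookup indices for [CLS] + patch tokens.

    Index 0 is reserved for any pair involving the classification token;
    patch-to-patch distances are clamped to [-k, k] and shifted to [1, 2k + 1],
    giving a table of 2k + 2 entries (matching the "+ 1 + 1" in the count below).
    """
    n = num_patches + 1                                        # +1 for the classification token
    pos = torch.arange(n)
    rel = (pos[None, :] - pos[:, None]).clamp(-k, k) + k + 1   # patch-pair indices in [1, 2k + 1]
    rel[0, :] = 0                                              # distances to/from [CLS] share index 0
    rel[:, 0] = 0
    return rel

# For RoPE, patch positions start at 1; position 0 (the classification token) then
# corresponds to a zero rotation angle, leaving that token unchanged.
x = torch.arange(1, 9).repeat(8)              # x positions 1..8 for an 8x8 grid
y = torch.arange(1, 9).repeat_interleave(8)   # y positions 1..8
```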

Parameter Comparison:

Comparison of the additional parameters added by each positional encoding.

| Positional Encoding Type | Additional Parameters Explanation | Parameter Count |
| --- | --- | --- |
| No Position | N/A | 0 |
| Learnable | Number of Patches × Embed Dim | 64 × 128 = 8192 |
| Sinusoidal (Absolute) | No learned parameters | 0 |
| Relative | (2 × max_relative_distance + 1 + 1) × Embed_dim / (2 × Number_of_attention_heads) × 2 × 2 × Number_of_encoder_blocks | (2 × 2 + 1 + 1) × 128 / (2 × 4) × 2 × 2 × 6 = 2304 |
| Rotary (RoPE) | No learned parameters | 0 |
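
A quick arithmetic check of the two non-zero counts, mirroring the expressions in the table above:

```python
# Learnable: one positional embedding vector per patch token.
num_patches, embed_dim = 64, 128
print(num_patches * embed_dim)  # 8192

# Relative: the table's formula evaluated term by term with the base config.
k, heads, blocks = 2, 4, 6
print((2 * k + 1 + 1) * (embed_dim // (2 * heads)) * 2 * 2 * blocks)  # 2304
```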

Base Transformer Config:

Below are the base training and network details used in the experiments.

| Parameter | Value | Parameter | Value |
| --- | --- | --- | --- |
| Input Size | 3 × 32 × 32 | Epochs | 200 |
| Patch Size | 4 | Batch Size | 128 |
| Sequence Length | 8 × 8 = 64 | Optimizer | AdamW |
| Embedding Dim | 128 | Learning Rate | 5e-4 |
| Num of Layers | 6 | Weight Decay | 1e-3 |
| Num of Heads | 4 | Warmup Epochs | 10 |
| Forward Multiplier | 2 | Warmup Schedule | Linear |
| Dropout | 0.1 | Learning Rate Decay Schedule | Cosine |
| Parameters | 820k | Minimum Learning Rate | 1e-5 |
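
The same settings gathered into a plain Python dictionary for quick reference (key names here are illustrative, not the repo's actual argument names):

```python
# Illustrative summary of the base config above; key names are not the repo's argument names.
base_config = {
    "input_size": (3, 32, 32),
    "patch_size": 4,
    "seq_len": 8 * 8,             # 64 patch tokens
    "embed_dim": 128,
    "num_layers": 6,
    "num_heads": 4,
    "forward_multiplier": 2,
    "dropout": 0.1,
    "epochs": 200,
    "batch_size": 128,
    "optimizer": "AdamW",
    "learning_rate": 5e-4,
    "weight_decay": 1e-3,
    "warmup_epochs": 10,
    "warmup_schedule": "linear",
    "lr_decay_schedule": "cosine",
    "min_learning_rate": 1e-5,
}
```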

Note: This repo is built upon the following GitHub repo: Vision Transformers from Scratch in PyTorch

Citations

@article{vaswani2017attention,
  title={Attention is all you need},
  author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and Uszkoreit, Jakob and Jones, Llion and Gomez, Aidan N and Kaiser, {\L}ukasz and Polosukhin, Illia},
  journal={Advances in neural information processing systems},
  volume={30},
  year={2017}
}
@inproceedings{dosovitskiy2020image,
  title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
  author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and others},
  booktitle={International Conference on Learning Representations},
  year={2020}
}
@article{shaw2018self,
  title={Self-attention with relative position representations},
  author={Shaw, Peter and Uszkoreit, Jakob and Vaswani, Ashish},
  journal={arXiv preprint arXiv:1803.02155},
  year={2018}
}
@article{su2024roformer,
  title={Roformer: Enhanced transformer with rotary position embedding},
  author={Su, Jianlin and Ahmed, Murtadha and Lu, Yu and Pan, Shengfeng and Bo, Wen and Liu, Yunfeng},
  journal={Neurocomputing},
  volume={568},
  pages={127063},
  year={2024},
  publisher={Elsevier}
}
