
retnet_vit-rmt-'s Introduction

Retentive Networks Meet Vision Transformers

This is a repository for "Retentive Networks Meet Vision Transformers" based on the PyTorch framework. It can be used to train on your own image datasets. The code is mainly adapted from the official code.

Preparation

Download the dataset: flower_dataset.

Project Structure(VisRetNet)

├── datasets: Load datasets
│   ├── my_dataset.py: Custom dataset reader; defines the transforms used for data augmentation
│   ├── split_data.py: Reads the image dataset and splits it into a training set and a test set
│   └── threeaugment.py: Additional data augmentation methods
├── models: VisRetNet models
│   ├── VisRetNet.py: Constructs the "VisRetNet" models
│   └── VisRetNet_token_label_release.py: Constructs the "VisRetNet_token_label_release" models
├── util:
│   ├── engine.py: Functions for the training/validation process
│   ├── losses.py: Knowledge distillation loss, combined with a teacher model (if any)
│   ├── optimizer.py: Defines the Sophia optimizer
│   ├── samplers.py: Defines the "sampler" argument of the DataLoader
│   └── utils.py: Records and prints metric information; sets up the distributed environment
├── estimate_model.py: Visualizes evaluation metrics: ROC curve, confusion matrix, classification report, etc.
└── train_gpu.py: Training entry point

Precautions

Before training on your own dataset, open train_gpu.py and set the data_root, batch_size, num_workers and nb_classes parameters. To draw the confusion matrix and ROC curve, set the predict parameter to True. Finally, to use VisRetNet without token_label, remember to set the token_label parameter to False.
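As a rough illustration of the parameters listed above, here is a minimal sketch of how they could be declared with argparse. The parameter names come from the text; the flag style, the str2bool helper and every default value except batch_size=4 are assumptions made for this example, so check train_gpu.py for the real definitions.

```python
import argparse


def str2bool(value: str) -> bool:
    """Parse common true/false spellings (hypothetical helper)."""
    lowered = value.lower()
    if lowered in {"true", "1", "yes"}:
        return True
    if lowered in {"false", "0", "no"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    # Names follow the README; the defaults are placeholders, not the
    # values shipped in train_gpu.py.
    parser = argparse.ArgumentParser("VisRetNet training (sketch)")
    parser.add_argument("--data_root", type=str, default="./flower_dataset")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--nb_classes", type=int, default=5)
    parser.add_argument("--predict", type=str2bool, default=False)
    parser.add_argument("--token_label", type=str2bool, default=True)
    return parser


# Example: 10 classes, draw the evaluation plots, train without token_label.
args = build_parser().parse_args(
    ["--nb_classes", "10", "--predict", "True", "--token_label", "False"]
)
print(args.data_root, args.batch_size, args.nb_classes, args.predict, args.token_label)
```

If the script only reads defaults, editing the default= values in the file has the same effect as passing the flags.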

Use Sophia Optimizer (in util/optimizer.py)

You can also use another optimizer, Sophia. Just change the optimizer in train_gpu.py as follows; for this training sample, it can achieve better results.

# optimizer = create_optimizer(args, model_without_ddp)
optimizer = SophiaG(model.parameters(), lr=2e-4, betas=(0.965, 0.99), rho=0.01, weight_decay=args.weight_decay)

Train this model

Parameter meanings:

1. nproc_per_node: the number of GPUs to use on each node (machine/server)
2. CUDA_VISIBLE_DEVICES: the indices of the GPUs to use on a single node (machine/server), starting from 0
3. nnodes: the number of nodes (machines/servers)
4. node_rank: the serial number of the current node (machine/server)
5. master_addr: the IP address of the master node (machine/server)
6. master_port: the port number of the master node (machine/server)
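To make the relationship between these parameters concrete, the sketch below computes the total number of processes and the global rank of each process. It assumes one process per GPU and the same nproc_per_node on every node; the function names are illustrative and are not part of this repository.

```python
def world_size(nnodes: int, nproc_per_node: int) -> int:
    """Total number of training processes (one per GPU) across all nodes."""
    if nnodes < 1 or nproc_per_node < 1:
        raise ValueError("nnodes and nproc_per_node must both be >= 1")
    return nnodes * nproc_per_node


def global_rank(node_rank: int, nproc_per_node: int, local_rank: int) -> int:
    """Rank of one process among all processes in the job."""
    if node_rank < 0:
        raise ValueError("node_rank must be >= 0")
    if not 0 <= local_rank < nproc_per_node:
        raise ValueError("local_rank must be in [0, nproc_per_node)")
    return node_rank * nproc_per_node + local_rank


# Example layout: 2 nodes with 2 GPUs each.
for node in range(2):
    for local in range(2):
        print(f"node_rank={node} local_rank={local} -> rank {global_rank(node, 2, local)}"
              f" of {world_size(2, 2)}")
```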

Note:

If you train on multiple GPUs, whether on a single machine or across multiple machines, every GPU uses the same batch_size. For example, batch_size=4 in my train_gpu.py. If I train on 2 GPUs, the batch_size on each GPU is 4, so 8 images are processed per step in total. Do not let the batch_size on each GPU be 1, otherwise the BN layers may report an error. If you receive an error like "unrecognized arguments: --local-rank=1" during distributed multi-GPU training, replace "torch.distributed.launch" with "torch.distributed.run" in the command.
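The two points in this note can be sketched in a few lines of Python. The first helper takes the example above at face value (batch_size is the per-GPU value) and computes the total batch per step. The second shows one common way for a script to tolerate both spellings of the local-rank flag, falling back to the LOCAL_RANK environment variable that torch.distributed.run sets. Both function names are hypothetical, and this is not necessarily how train_gpu.py handles either case.

```python
import argparse
import os


def global_batch_size(per_gpu_batch_size: int, num_gpus: int) -> int:
    """Total samples per step when every GPU uses the same batch size."""
    if per_gpu_batch_size < 2:
        # The note above warns that a per-GPU batch of 1 can break BN layers.
        raise ValueError("use a per-GPU batch size of at least 2")
    return per_gpu_batch_size * num_gpus


def parse_local_rank(argv=None) -> int:
    """Read the local rank from --local_rank, --local-rank, or LOCAL_RANK."""
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument(
        "--local_rank", "--local-rank",  # accept both spellings
        dest="local_rank",
        type=int,
        default=int(os.environ.get("LOCAL_RANK", 0)),
    )
    # Ignore any other arguments; they belong to the real training parser.
    args, _unknown = parser.parse_known_args(argv)
    return args.local_rank


print(global_batch_size(4, 2))
print(parse_local_rank(["--local-rank=1"]))
print(parse_local_rank(["--local_rank=1"]))
```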

train model with single-machine single-GPU:

python train_gpu.py

train model with single-machine multi-GPU:

python -m torch.distributed.launch --nproc_per_node=8 train_gpu.py

train model with single-machine multi-GPU, using only some of the GPUs (for example, the second and fourth GPUs):

CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.launch --nproc_per_node=2 train_gpu.py
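One detail worth keeping in mind with this command: inside the launched processes the visible GPUs are renumbered from 0, so local rank 0 runs on physical GPU 1 and local rank 1 on physical GPU 3. The sketch below illustrates that mapping. It assumes the variable holds plain integer indices (CUDA also accepts other device identifiers), and the helper names are illustrative only.

```python
from typing import List


def visible_devices(env_value: str) -> List[int]:
    """Parse a CUDA_VISIBLE_DEVICES value such as "1,3" into [1, 3]."""
    return [int(token) for token in env_value.split(",") if token.strip()]


def physical_index(env_value: str, local_rank: int) -> int:
    """Physical GPU index that a given local rank ends up using."""
    devices = visible_devices(env_value)
    if not 0 <= local_rank < len(devices):
        raise ValueError("local_rank exceeds the number of visible GPUs")
    return devices[local_rank]


# With CUDA_VISIBLE_DEVICES=1,3 and --nproc_per_node=2:
for rank in range(2):
    print(f"local_rank {rank} -> physical GPU {physical_index('1,3', rank)}")
```

This is also why --nproc_per_node should not exceed the number of indices listed in CUDA_VISIBLE_DEVICES.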

train model with multi-machine multi-GPU:

(To change the number of GPUs used on each machine, modify the value of --nproc_per_node. To select specific GPUs, prefix each command with CUDA_VISIBLE_DEVICES= followed by the GPU indices. The principle is the same as for single-machine multi-GPU training.)

On the first machine: python -m torch.distributed.launch --nproc_per_node=1 --nnodes=2 --node_rank=0 --master_addr=<Master node IP address> --master_port=<Master node port number> train_gpu.py

On the second machine: python -m torch.distributed.launch --nproc_per_node=1 --nnodes=2 --node_rank=1 --master_addr=<Master node IP address> --master_port=<Master node port number> train_gpu.py
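Since the two commands differ only in --node_rank, they can be generated from one template. The following is a small sketch of that idea; the address 192.0.2.10 and port 29500 are placeholders chosen for the example, and launch_command is a hypothetical helper, not something provided by this repository.

```python
import shlex
from typing import List


def launch_command(
    node_rank: int,
    nnodes: int,
    nproc_per_node: int,
    master_addr: str,
    master_port: int,
    script: str = "train_gpu.py",
    module: str = "torch.distributed.launch",
) -> List[str]:
    """Build the launch command for one node as an argument list."""
    if not 0 <= node_rank < nnodes:
        raise ValueError("node_rank must be in [0, nnodes)")
    return [
        "python", "-m", module,
        f"--nproc_per_node={nproc_per_node}",
        f"--nnodes={nnodes}",
        f"--node_rank={node_rank}",
        f"--master_addr={master_addr}",
        f"--master_port={master_port}",
        script,
    ]


# Print the command to run on each of two machines (placeholder address/port).
for rank in range(2):
    print(shlex.join(launch_command(rank, 2, 1, "192.0.2.10", 29500)))
```

Passing module="torch.distributed.run" produces the variant suggested in the note for the "--local-rank" error.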

Citation

@article{fan2023rmt,
  title={Rmt: Retentive networks meet vision transformers},
  author={Fan, Qihang and Huang, Huaibo and Chen, Mingrui and Liu, Hongmin and He, Ran},
  journal={arXiv preprint arXiv:2309.11523},
  year={2023}
}
@article{besta2023graph,
  title={Graph of thoughts: Solving elaborate problems with large language models},
  author={Besta, Maciej and Blach, Nils and Kubicek, Ales and Gerstenberger, Robert and Gianinazzi, Lukas and Gajda, Joanna and Lehmann, Tomasz and Podstawski, Michal and Niewiadomski, Hubert and Nyczyk, Piotr and others},
  journal={arXiv preprint arXiv:2308.09687},
  year={2023}
}

retnet_vit-rmt-'s People

Contributors

jiaowoguanren0615


retnet_vit-rmt-'s Issues

Dataset

Can you please point me to the dataset used to test this model?
Thank you in advance.

Reproduction Implementation

It's an excellent work! Your implementation of RMT is truly impressive!
Nevertheless, I have a couple of questions regarding the code implementation. Was the "1D to 2D" and the "Decomposed ReSA in Early Stage" described in the paper also implemented in the code? If they were, could you please indicate their specific location in the code? I would greatly appreciate your assistance in clarifying this matter.

Softmax in Retention?

Hey,
I am working on a similar project and wanted to understand RMT paper more clearly. And your work helped a lot! Thanks! It is very well written! I had a few doubts.

  1. You have used a swish gate and GroupNorm in places (which is correct according to the original retention paper). But because you said that this is an implementation of RMT, I am confused as to why you implemented GroupNorm and not softmax, since the RMT authors claim to use softmax.
  2. Your D mask/gamma implementation seems to make the retention matrix 2D, just as in RMT, but I could not figure out where exactly you make the matrix bidirectional. RMT claims to make the retention matrix both bidirectional and 2D, so I am kind of confused about that.
  3. You have used XPos encoding. Yes, it is exactly like pure retention, but RMT does not mention using XPos AFAIK.

Thanks for publishing this code. This has been very useful.
Thank you!
