Contrastive Masked Autoencoders are Stronger Vision Learners

The official implementation of the paper Contrastive Masked Autoencoders are Stronger Vision Learners (arXiv: https://arxiv.org/abs/2207.13532, IEEE TPAMI: https://ieeexplore.ieee.org/document/10330745).

@article{huang2023contrastive,
  title={Contrastive masked autoencoders are stronger vision learners},
  author={Huang, Zhicheng and Jin, Xiaojie and Lu, Chengze and Hou, Qibin and Cheng, Ming-Ming and Fu, Dongmei and Shen, Xiaohui and Feng, Jiashi},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
  year={2023},
  publisher={IEEE}
}

Installation

Step 0. Download and install Miniconda from the official website

Step 1. Create a conda environment and activate it

conda create --name cmae python=3.9 -y
conda activate cmae

Step 2. Install PyTorch following the official instructions, e.g.

conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=12.1 -c pytorch -c nvidia

Step 3. Install MMCV and the other required Python packages

pip3 install boto3 addict tqdm regex pyyaml opencv-python opencv-contrib-python future tensorboard filelock yapf attrs einops opencv-python-headless numpy scipy mmcv==2.1.0 psutil mmengine timm

Step 4. Install CMAE

git clone https://github.com/ZhichengHuang/CMAE
cd CMAE
pip3 install -v -e .
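
Once the package is installed, a quick import check such as the minimal sketch below (the file name is arbitrary) can confirm that the core dependencies resolve and that CUDA is visible:

# check_env.py -- verify that the main dependencies import and CUDA is visible
import torch
import torchvision
import mmcv
import mmengine
import timm

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("torchvision:", torchvision.__version__)
print("mmcv:", mmcv.__version__)
print("mmengine:", mmengine.__version__)
print("timm:", timm.__version__)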

Prepare ImageNet

  1. Register an account and log in to the download page

  2. Find the download links for ILSVRC2012 and download the following three files:

    • ILSVRC2012_img_train.tar
    • ILSVRC2012_img_val.tar
    • ILSVRC2012_devkit_t12.tar.gz
  3. Untar the downloaded files

    mkdir train && tar -xvf ILSVRC2012_img_train.tar -C train && for x in `ls train/*tar`; do fn=train/`basename $x .tar`; mkdir $fn; tar -xvf $x -C $fn; rm -f $fn.tar; done
    # if successful, this should return 1,281,167 (a Python count-check sketch is given after this list)
    ls -lR|grep "^-"|wc -l
    
    # extract the validation set
    mkdir val
    tar xvf ILSVRC2012_img_val.tar -C ./val
    
    tar -xzf ILSVRC2012_devkit_t12.tar.gz
    
  4. The project has the following structure.

    CMAE
    ├── cmae
    ├── tools
    ├── configs
    ├── docs
    ├── work_dirs
    │   ├──init
    ├── data
    │   ├── ImageNet
    │   │   ├── ImageNet_train.json
    │   │   ├── ImageNet_val.json
    │   │   ├── train
    │   │   ├── val
    
    
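If you want to double-check the extracted data, the following minimal Python sketch (assuming the data/ImageNet layout shown above; the file name is arbitrary) counts the images per split. The expected totals for ImageNet-1k are 1,281,167 training images and 50,000 validation images.

# count_images.py -- rough sanity check of the layout listed above
import os

def count_images(root):
    """Count image files under root, recursively."""
    exts = (".jpeg", ".jpg", ".png")
    total = 0
    for _, _, files in os.walk(root):
        total += sum(1 for name in files if name.lower().endswith(exts))
    return total

# paths follow the directory structure shown above
print("train:", count_images("data/ImageNet/train"))  # expected 1,281,167
print("val:  ", count_images("data/ImageNet/val"))    # expected 50,000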

Fine-tuning with pre-trained checkpoints

The following table provides the pre-trained (PT) and fine-tuned (FT) checkpoints:

Size      Model  Epoch  PT        FT        ImageNet Acc.
ViT-Base  MAE    300    -         -         82.9
ViT-Base  MAE    1600   -         -         83.6
ViT-Base  CMAE   300    aws|ali   aws|ali   83.8
ViT-Base  CMAE   1600   aws|ali   aws|ali   84.7
# download the pre-trained weights to "work_dirs/init/"
cd work_dirs/init/
wget https://cmae.s3.us-west-1.amazonaws.com/weight/cmae_vit-base-p16_32xb128-coslr-300e_in1k.pth
cd ../..  # return to the repository root
# run the fine-tuning script with 8 GPUs
bash tools/dist_train.sh configs/fine-tuning/cmae/cmae_base_pre300_8x128_100e_in1k_amp.py 8
# test directly with the released fine-tuned model (no fine-tuning required)
mkdir -p work_dirs/fine-tuning/cmae_base_pre300_8x128_100e_in1k_amp
cd work_dirs/fine-tuning/cmae_base_pre300_8x128_100e_in1k_amp
# download the fine-tuned weights
wget https://cmae.s3.us-west-1.amazonaws.com/weight/cmae_base_pre300_8x128_100e_in1k_amp_fine.pth
cd ../../..  # return to the repository root
# test the model
bash tools/dist_test.sh configs/fine-tuning/cmae/cmae_base_pre300_8x128_100e_in1k_amp.py work_dirs/fine-tuning/cmae_base_pre300_8x128_100e_in1k_amp/cmae_base_pre300_8x128_100e_in1k_amp_fine.pth 8
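
If a downloaded checkpoint looks suspect (for example, an interrupted wget), a minimal sketch like the one below loads it on CPU and lists its top-level keys. The 'state_dict' key is an assumption about how the checkpoint file is organized (common for mmengine-based checkpoints), not verified against these releases.

# inspect_checkpoint.py -- peek inside a downloaded checkpoint
import torch

# path follows the wget command above
ckpt = torch.load(
    "work_dirs/init/cmae_vit-base-p16_32xb128-coslr-300e_in1k.pth",
    map_location="cpu",
)
print("top-level keys:", list(ckpt.keys()))
# 'state_dict' is an assumed key name; fall back to the whole dict otherwise
state = ckpt.get("state_dict", ckpt)
print("number of parameter tensors:", len(state))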

Pre-training

# pre-train the model on 4 nodes with 8 GPUs each (4*8 GPUs in total)
# node 0
bash tools/dist_train.sh configs/pre-training/cmae/cmae_vit-base-p16_32xb128-coslr-1600e_in1k.py 8 4 0 
# node 1
bash tools/dist_train.sh configs/pre-training/cmae/cmae_vit-base-p16_32xb128-coslr-1600e_in1k.py 8 4 1
# node 2
bash tools/dist_train.sh configs/pre-training/cmae/cmae_vit-base-p16_32xb128-coslr-1600e_in1k.py 8 4 2
# node 3
bash tools/dist_train.sh configs/pre-training/cmae/cmae_vit-base-p16_32xb128-coslr-1600e_in1k.py 8 4 3

# after pre-training, use this script to extract the backbone weights for downstream fine-tuning
python3 tools/extract_backbone_weights.py pre-training.pth init_weight.pth
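
For reference, the general idea behind this backbone extraction can be sketched as below. This is illustrative only: the 'backbone.' key prefix is an assumed naming convention common in mmengine-based models, and tools/extract_backbone_weights.py remains the authoritative script.

# extract_backbone_sketch.py -- illustrative only; use tools/extract_backbone_weights.py
import torch

ckpt = torch.load("pre-training.pth", map_location="cpu")
state = ckpt.get("state_dict", ckpt)

# keep only backbone parameters and strip the assumed 'backbone.' prefix
prefix = "backbone."
backbone = {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}

torch.save({"state_dict": backbone}, "init_weight.pth")
print(f"saved {len(backbone)} backbone tensors to init_weight.pth")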

For more details on training with multiple machines, refer to mmengine.

Our previous code was based on mmcv. Upgrades to several external libraries introduced incompatibilities, so we redesigned the code base on mmengine and also upgraded PyTorch to 2.1.

Thanks to mmengine and mmselfsup: the code organization and parts of the code in this project are adapted from these two open-source projects.
