
BootMAE, ECCV2022

This repo is the official implementation of "Bootstrapped Masked Autoencoders for Vision BERT Pretraining".

Introduction

We propose bootstrapped masked autoencoders (BootMAE), a new approach for vision BERT pretraining. BootMAE improves the original masked autoencoders (MAE) with two core designs:

  1. a momentum encoder that provides online features as extra BERT prediction targets (see the sketch below);
  2. a target-aware decoder that reduces the pressure on the encoder to memorize target-specific information during BERT pretraining.
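
Both designs are described at the paper level here; as a point of reference, the momentum encoder relies on the standard exponential-moving-average (EMA) parameter update used in self-distillation methods such as DINO. A minimal PyTorch sketch of that update is below; `online_encoder`, `momentum_encoder`, and `build_vit_base` are illustrative names, not this repository's API.

```python
import torch

@torch.no_grad()
def ema_update(online_encoder, momentum_encoder, decay=0.999):
    """Exponential moving average update of the momentum encoder.

    Called after each optimizer step; the momentum encoder drifts slowly
    toward the online encoder, and its features serve as the extra
    prediction targets mentioned above.
    """
    for p_online, p_momentum in zip(online_encoder.parameters(),
                                    momentum_encoder.parameters()):
        p_momentum.mul_(decay).add_(p_online, alpha=1.0 - decay)

# Typical initialization (hypothetical constructor name):
#   import copy
#   online_encoder = build_vit_base()
#   momentum_encoder = copy.deepcopy(online_encoder)
#   for p in momentum_encoder.parameters():
#       p.requires_grad_(False)
```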

[Figure: BootMAE pipeline]

Requirements

timm==0.3.4, pytorch>=1.7, opencv, ... To install the requirements, run:

bash setup.sh

Results

| Model | Pretrain Epochs | Pretrain Model | Linear acc@1 | Finetune Model | Finetune acc@1 |
|-------|-----------------|----------------|--------------|----------------|----------------|
| ViT-B | 800             | model          | 66.1         | model          | 84.2           |
| ViT-L | 800             | model          | 77.1         | model          | 85.9           |

See Segmentation for segmentation results and configs.

Pretrain

The BootMAE-base model can be pretrained on ImageNet-1k using 16 V100-32GB GPUs:

MODEL=bootmae_base
OUTPUT_DIR=/path/to/save/your_model
DATA_PATH=/path/to/imagenet

OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=16 run_pretraining.py \
    --data_path ${DATA_PATH} \
    --output_dir ${OUTPUT_DIR} \
    --model ${MODEL} \
    --model_ema --model_ema_decay 0.999 --model_ema_dynamic \
    --batch_size 256 --lr 1.5e-4 --min_lr 1e-4 \
    --epochs 801 --warmup_epochs 40 --update_freq 1 \
    --mask_num 147 --feature_weight 1 --weight_mask 
  • --mask_num: number of input patches to be masked.
  • --batch_size: batch size per GPU.
  • Effective batch size = number of GPUs * --batch_size. So in the above example, the effective batch size is 16*256 = 4096.
  • --lr: learning rate.
  • --warmup_epochs: learning rate warmup epochs.
  • --epochs: total pre-training epochs.
  • --model_ema_decay: the initial model EMA decay; it is increased to 0.9999 over the first 100 epochs (see the sketch after this list).
  • --model_ema_dynamic: if set, further increase the EMA decay from 0.9999 to 0.99999 over the first 400 epochs.
  • --feature_weight: weight of the feature prediction branch.
  • --weight_mask: if set, assign a larger loss weight to the center of the masked block region.
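
Taken together, --model_ema_decay and --model_ema_dynamic describe an EMA decay that ramps up during pretraining. A minimal sketch of such a schedule, assuming simple linear ramps over the stated epoch ranges (the actual schedule in run_pretraining.py may differ):

```python
def ema_decay_schedule(epoch, start_decay=0.999, dynamic=True):
    """Illustrative EMA decay ramp matching the flag descriptions above.

    Assumes a linear increase from `start_decay` to 0.9999 over the first
    100 epochs and, when `dynamic` (--model_ema_dynamic) is set, a further
    linear increase to 0.99999 by epoch 400.
    """
    if epoch < 100:
        return start_decay + (0.9999 - start_decay) * epoch / 100
    if dynamic and epoch < 400:
        return 0.9999 + (0.99999 - 0.9999) * (epoch - 100) / 300
    return 0.99999 if dynamic else 0.9999
```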

See scripts/pretrain for more configurations.

Finetuning

To fine-tune BootMAE-base on ImageNet-1K:

MODEL=bootmae_base
OUTPUT_DIR=/path/to/save/your_model
DATA_PATH=/path/to/imagenet
FINE=/path/to/your_pretrain_model

OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 run_class_finetuning.py \
    --model ${MODEL} --data_path $DATA_PATH \
    --input_size 224 \
    --finetune ${FINE} \
    --num_workers 8 \
    --output_dir ${OUTPUT_DIR} \
    --batch_size 256 --lr 5e-3 --update_freq 1 \
    --warmup_epochs 20 --epochs 100 \
    --layer_decay 0.6 --backbone_decay 1 \
    --drop_path 0.1 \
    --abs_pos_emb --disable_rel_pos_bias \
    --weight_decay 0.05 --mixup 0.8 --cutmix 1.0 \
    --nb_classes 1000 --model_key model \
    --enable_deepspeed \
    --model_ema --model_ema_decay 0.9998
  • --batch_size: batch size per GPU.
  • Effective batch size = number of GPUs * --batch_size * --update_freq. So in the above example, the effective batch size is 8*256*1 = 2048.
  • --lr: learning rate.
  • --warmup_epochs: learning rate warmup epochs.
  • --epochs: total fine-tuning epochs.
  • --clip_grad: clip gradient norm.
  • --drop_path: stochastic depth rate.
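
The --layer_decay 0.6 flag in the command above applies layer-wise learning-rate decay, as in BEiT/MAE-style fine-tuning: blocks closer to the input receive geometrically smaller learning rates. A minimal sketch of how such per-layer scales are commonly computed (illustrative helper, not the repository's implementation):

```python
def layerwise_lr_scales(num_layers=12, layer_decay=0.6):
    """Return a learning-rate scale per parameter group.

    Index 0 is the patch/positional embedding, indices 1..num_layers are the
    transformer blocks, and index num_layers + 1 is the classification head.
    Block b (0-indexed) is scaled by layer_decay ** (num_layers - b), so the
    embedding gets the smallest scale and the head gets scale 1.0.
    """
    return [layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)]

# Example: with --layer_decay 0.6 and a 12-block ViT-B, the patch embedding
# learning rate is 0.6**13 ≈ 1.3e-3 of the head learning rate.
scales = layerwise_lr_scales()
```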

See scripts/finetune for more configurations.

Linear Probing

To evaluate the linear probing accuracy of BootMAE-base on ImageNet-1K with 8 GPUs:

OUTPUT_DIR=/path/to/save/your_model
DATA_PATH=/path/to/imagenet
FINETUNE=/path/to/your_pretrain_model

LAYER=9

OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
        main_linprobe.py \
        --batch_size 1024 --accum_iter 2 \
        --data_path ${DATA_PATH} --output_dir ${OUTPUT_DIR} \
        --model base_patch16_224 --depth ${LAYER} \
        --finetune ${FINETUNE} \
        --global_pool \
        --epochs 90 \
        --blr 0.1 \
        --weight_decay 0.0 \
        --dist_eval 
  • --batch_size: batch size per GPU.
  • Effective batch size = number of GPUs * --batch_size * --accum_iter. So in the above example, the effective batch size is 8*1024*2 = 16384.
  • --blr: base learning rate. The actual learning rate is --blr * effective batch size / 256 (see the example below).
  • --epochs: total linear probing epochs.
  • --depth: index of the layer whose features are evaluated.
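
As a quick sanity check of the scaling rule above, the absolute learning rate implied by this example can be computed directly (illustrative arithmetic, not repository code):

```python
# Linear lr scaling rule for linear probing: lr = blr * effective_batch / 256
gpus, batch_size, accum_iter, blr = 8, 1024, 2, 0.1
eff_batch = gpus * batch_size * accum_iter  # 8 * 1024 * 2 = 16384
lr = blr * eff_batch / 256                  # 0.1 * 16384 / 256 = 6.4
print(f"effective batch size = {eff_batch}, absolute lr = {lr}")
```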

See scripts/linear for more configurations.

Acknowledgments

This repository is modified from BEiT and built using the timm library, the DeiT repository, and the DINO repository. The linear probing part is modified from MAE.

Citation

If you use this code for your research, please cite our paper.

@article{dong2022bootstrapped,
  title={Bootstrapped Masked Autoencoders for Vision BERT Pretraining},
  author={Dong, Xiaoyi and Bao, Jianmin and Zhang, Ting and Chen, Dongdong and Zhang, Weiming and Yuan, Lu and Chen, Dong and Wen, Fang and Yu, Nenghai},
  journal={arXiv preprint arXiv:2207.07116},
  year={2022}
}


