
This project is forked from pengzhiliang/mae-pytorch.



Unofficial PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners

This repository is built upon BEiT, thanks very much!

We now implement the pretraining and finetuning process according to the paper, but we still can't guarantee that the performance reported in the paper can be reproduced!

Differences

shuffle and unshuffle

The shuffle and unshuffle operations are not directly accessible in PyTorch, so we use another method to realize this process:

  • For shuffle, we use BEiT's method of randomly generating a mask-map (14x14), where mask=0 means the token is kept and mask=1 means the token is dropped (it does not participate in the encoder computation). All visible tokens (mask=0) are then fed into the encoder network.
  • For unshuffle, we look up the position embeddings of all masked tokens according to the mask-map, add them to the shared mask token, concatenate them with the visible tokens coming from the encoder, and feed the result into the decoder network for reconstruction. A minimal sketch of both steps follows.
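
The following is a minimal sketch of this mask-map based shuffle/unshuffle, assuming illustrative shapes and random stand-ins for the patch embeddings, positional embeddings, and encoder/decoder (the real modules live in modeling_pretrain.py):

import torch

# Illustrative sketch, not the repository's actual code.
B, N, D = 2, 196, 768                  # batch, tokens (14x14 patch grid), dim
num_mask = int(N * 0.75)               # mask_ratio = 0.75

tokens = torch.randn(B, N, D)          # patch embeddings
pos_embed = torch.randn(1, N, D)       # positional embeddings
mask_token = torch.zeros(1, 1, D)      # shared mask token (learnable in practice)

# "Shuffle": a random boolean mask-map per sample; True marks a dropped token.
ids = torch.rand(B, N).argsort(dim=1)
mask = torch.zeros(B, N, dtype=torch.bool)
mask.scatter_(1, ids[:, :num_mask], True)

# Only the visible tokens (mask == False) enter the encoder.
visible = (tokens + pos_embed)[~mask].reshape(B, N - num_mask, D)
# visible = encoder(visible)

# "Unshuffle": each masked position gets (mask token + its position embedding),
# then is concatenated with the encoded visible tokens for the decoder.
masked = mask_token + pos_embed.expand(B, -1, -1)[mask].reshape(B, num_mask, D)
decoder_input = torch.cat([visible, masked], dim=1)
# reconstruction = decoder(decoder_input)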

sine-cosine positional embeddings

The positional embeddings mentioned in the paper are the sine-cosine version. We adopt the implementation from here, but it seems to be a 1-D embedding rather than a 2-D one, and we don't know what effect this brings. A 2-D sine-cosine positional embedding can be found in MoCo v3; if someone is interested, you can try it. A sketch of the 1-D version is shown below.
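
For reference, here is a minimal sketch of a fixed 1-D sine-cosine table in the style of the original Transformer; the function name and shapes are illustrative, not this repository's API:

import math
import torch

# Illustrative 1-D sine-cosine positional embedding table (fixed, not learned).
def sincos_pos_embed_1d(n_pos: int, dim: int) -> torch.Tensor:
    position = torch.arange(n_pos, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32)
                         * (-math.log(10000.0) / dim))
    pe = torch.zeros(n_pos, dim)
    pe[:, 0::2] = torch.sin(position * div_term)  # even channels: sine
    pe[:, 1::2] = torch.cos(position * div_term)  # odd channels: cosine
    return pe

pos_embed = sincos_pos_embed_1d(196, 768)  # one row per token of a 14x14 patch grid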

TODO

  • implement the finetune process
  • reuse the model in modeling_pretrain.py
  • calculate the normalized pixel target
  • add the cls token in the encoder
  • visualization of the reconstructed image
  • kNN and linear probing
  • ...

Setup

pip install -r requirements.txt

Run

  1. Pretrain
# Set the path to save checkpoints
OUTPUT_DIR='output/pretrain_mae_base_patch16_224'
# path to imagenet-1k train set
DATA_PATH='/path/to/ImageNet_ILSVRC2012/train'


# batch_size can be adjusted according to the GPU memory
OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 run_mae_pretraining.py \
        --data_path ${DATA_PATH} \
        --mask_ratio 0.75 \
        --model pretrain_mae_base_patch16_224 \
        --batch_size 128 \
        --opt adamw \
        --opt_betas 0.9 0.95 \
        --warmup_epochs 40 \
        --epochs 1600 \
        --output_dir ${OUTPUT_DIR}
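
Note that with --nproc_per_node=8 and --batch_size 128 per process, the effective batch size is 8 * 128 = 1024 images per optimizer step.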
  2. Finetune
# Set the path to save checkpoints
OUTPUT_DIR='output/'
# path to imagenet-1k set
DATA_PATH='/path/to/ImageNet_ILSVRC2012'
# path to pretrain model
MODEL_PATH='/path/to/pretrain/checkpoint.pth'

# batch_size can be adjusted according to the GPU memory
OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 run_class_finetuning.py \
    --model vit_base_patch16_224 \
    --data_path ${DATA_PATH} \
    --finetune ${MODEL_PATH} \
    --output_dir ${OUTPUT_DIR} \
    --batch_size 128 \
    --opt adamw \
    --opt_betas 0.9 0.999 \
    --weight_decay 0.05 \
    --epochs 100 \
    --dist_eval
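
Here --finetune points the script at the pretrained checkpoint produced in step 1, and --dist_eval (inherited, we believe, from the BEiT codebase this repository builds upon) shards evaluation across the distributed processes.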
  3. Visualization of reconstruction
# Set the path to save images
OUTPUT_DIR='output/'
# path to image for visualization
IMAGE_PATH='files/ILSVRC2012_val_00031649.JPEG'
# path to pretrain model
MODEL_PATH='/path/to/pretrain/checkpoint.pth'

# Currently, only pretrained models with normalized pixel targets are supported
python run_mae_vis.py ${IMAGE_PATH} ${OUTPUT_DIR} ${MODEL_PATH}
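
As noted in the comment above, the visualization script currently assumes a model pretrained with the normalized pixel target. Per the MAE paper, this target normalizes each patch by its own mean and standard deviation before the MSE loss; a minimal sketch (the function name and eps value are illustrative):

import torch

# Illustrative sketch of the per-patch normalized pixel target from the MAE paper.
def normalized_pixel_target(patches: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # patches: (batch, num_patches, patch_dim), e.g. patch_dim = 16 * 16 * 3
    mean = patches.mean(dim=-1, keepdim=True)
    var = patches.var(dim=-1, keepdim=True)
    return (patches - mean) / (var + eps).sqrt()

# The loss is then the MSE between the decoder's prediction and this target,
# computed over the masked patches only.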

Results

model       pretrain   finetune   accuracy   log                   weight
vit-base    400e       100e       83.1%      pretrain / finetune   Google Drive, BaiduYun (code: mae6)
vit-large   400e       50e        84.5%      pretrain / finetune   unavailable

Due to limited GPUs, it is a real challenge for us to pretrain with the larger models or longer schedules mentioned in the paper. (The pretraining and end-to-end finetuning of the vit-large model were finished by this enthusiastic handsome guy with many V100s, but the weights are unavailable.)

So if anyone can finish it, please feel free to report it in an issue or open a PR, thank you!

And your star is my motivation, thank you~
