This project forked from ucasligang/semmae

[NeurIPS 2022] code for the paper, SemMAE: Semantic-guided masking for learning masked autoencoders

Introduction

Paper accepted at NeurIPS 2022.

This is the official repository of SemMAE. Our code builds on MAE; many thanks to its authors for their outstanding work! For details of our work, see Semantic-Guided Masking for Learning Masked Autoencoders.

Citation

@article{li2022semmae,
  title={SemMAE: Semantic-Guided Masking for Learning Masked Autoencoders},
  author={Li, Gang and Zheng, Heliang and Liu, Daqing and Wang, Chaoyue and Su, Bing and Zheng, Changwen},
  journal={arXiv preprint arXiv:2206.10207},
  year={2022}
}

This implementation is in PyTorch+GPU.

  • This repo is based on timm==0.3.2, for which a fix is needed to work with PyTorch 1.8.1+.
  • tensorboard may also be needed for this repository; it can be installed with pip.

Processed ImageNet dataset (including part masks and pixel values):

size    16x16 patch    8x8 patch
link    download       waiting
md5     losed          waiting

Pretrained models

800-epochs               ViT-Base 16x16 patch    ViT-Base 8x8 patch
pretrained checkpoint    download                download
md5                      1482ae                  322b6a
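
A downloaded checkpoint can be checked against the md5 values listed above. The following is a minimal sketch, assuming the six-character values in the table are the leading characters of the full 32-character digest (the README does not say so explicitly). The helper names `md5_hex` and `matches_listed_md5` are hypothetical and not part of this repository.

```python
import hashlib


def md5_hex(path, chunk_size=1 << 20):
    """Compute the md5 hex digest of a file, reading it in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # Read 1 MiB at a time so large checkpoints do not need to fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_listed_md5(path, listed_prefix):
    """Return True if the file's digest starts with the listed value.

    Assumption: the table lists a prefix of the digest, not the whole digest.
    """
    return md5_hex(path).startswith(listed_prefix.lower())
```

For example, `matches_listed_md5("checkpoint.pth", "1482ae")` would be the check for the 16x16 pretrained checkpoint, where `checkpoint.pth` stands for whatever local file name the download was saved under.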

Evaluation

As a sanity check, run evaluation using our ImageNet fine-tuned models:

800-epochs                     ViT-Base 16x16 patch    ViT-Base 8x8 patch
fine-tuned checkpoint          download                download
md5                            bbc5ef                  6abd9e
reference ImageNet accuracy    83.352                  84.444

Evaluate ViT-Base_16 on a single GPU (${IMAGENET_DIR} is a directory containing the {train, val} sets of ImageNet):

python main_finetune.py --eval --resume SemMAE_epoch799_vit_base_checkpoint-99.pth --model vit_base_patch16 --batch_size 16 --data_path ${IMAGENET_DIR}

This should give:

* Acc@1 83.352 Acc@5 96.494 loss 0.745
Accuracy of the network on the 50000 test images: 83.4%
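
For readers unfamiliar with the metrics in this output: Acc@1 and Acc@5 are top-1 and top-5 accuracy in percent, so 83.352% of 50,000 validation images corresponds to 41,676 correct top-1 predictions. The sketch below shows how such numbers can be computed. It is a dependency-free illustration with a hypothetical `topk_accuracy` helper, not the evaluation code used by `main_finetune.py`.

```python
def topk_accuracy(scores, targets, ks=(1, 5)):
    """Percentage of samples whose target is among the k highest-scoring classes.

    scores:  list of per-class score lists, one list per sample
    targets: list of ground-truth class indices
    """
    hits = {k: 0 for k in ks}
    for row, target in zip(scores, targets):
        # Class indices sorted from highest to lowest score.
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        for k in ks:
            if target in ranked[:k]:
                hits[k] += 1
    n = len(targets)
    return {k: 100.0 * hits[k] / n for k in ks}


# Toy example with 3 samples and 3 classes (made-up scores, for illustration only).
toy_scores = [
    [0.1, 0.7, 0.2],  # predicts class 1
    [0.5, 0.3, 0.2],  # predicts class 0, class 1 is second
    [0.2, 0.3, 0.5],  # predicts class 2, class 0 is last
]
toy_targets = [1, 1, 0]
print(topk_accuracy(toy_scores, toy_targets, ks=(1, 2)))
```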

Evaluate ViT-Base_8 on a single GPU (${IMAGENET_DIR} is a directory containing the {train, val} sets of ImageNet):

python main_finetune.py --eval --resume SemMAE_epoch799_vit_base_checkpoint_patch8-78.pth --model vit_base_patch8 --batch_size 8 --data_path ${IMAGENET_DIR}

This should give:

* Acc@1 84.444 Acc@5 97.032 loss 0.683
Accuracy of the network on the 50000 test images: 84.44%. 

Note that all of our results are obtained with the 800-epoch pre-training setting. The best checkpoint for vit_base_patch8 is lost: the paper reports 84.5% top-1 accuracy, whereas the released checkpoint, taken from the 78th epoch, reaches 84.44%.
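
To put the two evaluated variants in perspective, the sketch below counts the patch tokens each model processes per image. It assumes a 224x224 input, which is the usual ViT/MAE default but is not stated in this README, and `num_patches` is a hypothetical helper. Under that assumption the 8x8 model sees four times as many tokens as the 16x16 model, which is one plausible reason the commands above use a smaller batch size for it.

```python
def num_patches(image_size, patch_size):
    """Number of non-overlapping square patches in a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side * per_side


IMAGE_SIZE = 224  # assumption: standard ViT/MAE input resolution

tokens_16 = num_patches(IMAGE_SIZE, 16)  # vit_base_patch16
tokens_8 = num_patches(IMAGE_SIZE, 8)    # vit_base_patch8
print(tokens_16, tokens_8, tokens_8 // tokens_16)
```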

Pre-training

To pre-train ViT-Base (the model selected by --model mae_vit_base_patch16 below) with multi-node distributed training, run the following on 8 nodes with 8 GPUs each:

python -m torch.distributed.launch --nproc_per_node=8 --master_port=${MASTER_PORT} \
        --nnodes=${NNODES} --node_rank=${SLURM_NODEID} --master_addr=${MASTER_ADDR} \
        --use_env main_pretrain_setting3.py \
        --output_dir ${OUTPUT_DIR} --log_dir=${OUTPUT_DIR} \
        --batch_size 128 \
        --model mae_vit_base_patch16 \
        --norm_pix_loss \
        --mask_ratio 0.75 \
        --epochs 800 \
        --warmup_epochs 40 \
        --blr 1.5e-4 --weight_decay 0.05 \
        --setting 3 \
        --data_path ${DATA_DIR}

Note that ${DATA_DIR} is the path to our processed dataset (including part masks and pixel values).
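
The hyperparameters in the command above can be turned into concrete quantities. The sketch below does so under three assumptions that should be verified against `main_pretrain_setting3.py`: that `--batch_size` is per GPU, that the script keeps the upstream MAE rule `lr = blr * effective_batch_size / 256`, and that a 224x224 input with 16x16 patches yields 196 tokens per image. The helper names are hypothetical.

```python
def effective_batch_size(per_gpu_batch, gpus_per_node, nodes, accum_iter=1):
    """Total number of images per optimizer step across all processes."""
    return per_gpu_batch * gpus_per_node * nodes * accum_iter


def absolute_lr(blr, eff_batch):
    """Linear scaling rule as used in upstream MAE (assumed unchanged here)."""
    return blr * eff_batch / 256


def visible_tokens(num_tokens, mask_ratio):
    """Tokens left unmasked for the encoder at a given mask ratio."""
    return int(num_tokens * (1 - mask_ratio))


eff = effective_batch_size(per_gpu_batch=128, gpus_per_node=8, nodes=8)
lr = absolute_lr(blr=1.5e-4, eff_batch=eff)
kept = visible_tokens(num_tokens=196, mask_ratio=0.75)  # 196 = (224 / 16) ** 2, assumed

print(f"effective batch size: {eff}")
print(f"absolute learning rate: {lr:.2e}")
print(f"visible tokens per image: {kept}")
```

When running on fewer GPUs, the same functions show how the absolute learning rate would change if the per-GPU batch size is kept fixed.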

Contact

This repo is currently maintained by Gang Li (@ucasligang).
