Masked Autoencoders: A PyTorch Implementation

Steps

  1. git clone https://github.com/Seeeeeyo/mae.git

0.1)

Use Python 3.8 (should work). I only tried Python 3.10, which required the following changes:

In “/usr/local/lib/python3.10/dist-packages/timm/models/layers/helpers.py”, replace the line "from torch._six import container_abcs" with:

import torch 
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])

if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
    from torch._six import container_abcs
else:
    import collections.abc as container_abcs

In “/content/mae/util/misc.py”, replace "from torch._six import inf" with "from torch import inf".
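
The timm path above is specific to one Python 3.10 install. As a minimal sketch, the file to patch can be located without hard-coding the site-packages directory. The helper name `locate_package_file` is hypothetical, and the relative path is the one quoted in this step.

```python
import importlib.util
from pathlib import Path


def locate_package_file(package: str, relative_path: str) -> Path:
    """Return the on-disk path of a file inside an installed top-level package.

    find_spec() only locates a top-level package, it does not import it, so
    this works even when importing the package itself would fail.
    """
    spec = importlib.util.find_spec(package)
    if spec is None or not spec.submodule_search_locations:
        raise ModuleNotFoundError(f"{package!r} is not an installed package")
    package_dir = Path(next(iter(spec.submodule_search_locations)))
    target = package_dir / relative_path
    if not target.is_file():
        raise FileNotFoundError(target)
    return target
```

Calling `locate_package_file("timm", "models/layers/helpers.py")` should then print the file to edit, assuming timm==0.3.2 is installed in the active environment.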


1) 

* Download 'data_60k' (the full MedMNIST dataset) and 'data_sampled.zip' (fractions of the MedMNIST dataset) from the drive.
* Download a medical classification dataset. I'll let you choose one, as you know these better than I do. Let's try to find one that
 is reasonably similar to MedMNIST, so that we can hopefully reach decent performance. The data structure should be as follows:
	- eval_data
		- train 
			-class1
				-img1
				-img2
				-...
			-class2
			-...
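
Before launching anything, it may be worth checking that the folder really has this one-subfolder-per-class shape. The sketch below counts image files per class using only the standard library. The function name and the list of image suffixes are assumptions, not part of the repo. As far as I know, upstream MAE also reads a 'val' folder next to 'train', so the `split` argument is there to check both.

```python
from pathlib import Path
from typing import Dict

# Assumed set of image extensions; extend it if the dataset uses others.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}


def summarize_split(data_path: str, split: str = "train") -> Dict[str, int]:
    """Count image files per class folder under data_path/split."""
    split_dir = Path(data_path) / split
    if not split_dir.is_dir():
        raise FileNotFoundError(f"missing split folder: {split_dir}")
    counts = {}
    for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts
```

For example, `summarize_split("eval_data")` returns a mapping from class name to image count. An empty mapping or a class with zero images usually means the archive was unpacked one level too deep.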

  1. cd mae

  1. wget -nc https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth

!pip install submitit
!pip install timm==0.3.2

  1. Evaluate the mae_vit_base on eval_data
python main_finetune.py --eval --resume mae_finetuned_vit_base.pth --model vit_base_patch16 --batch_size 32 --data_path 'eval_data'

  1. FINETUNE
python main_finetune.py \
    --accum_iter 1 \
    --batch_size 32 \
    --model vit_base_patch16 \
    --finetune 'mae_pretrain_vit_base.pth' \
    --epochs 50 \
    --blr 5e-4 --layer_decay 0.65 \
    --weight_decay 0.05 --drop_path 0.1 --mixup 0.8 --cutmix 1.0 --reprob 0.25 \
    --dist_eval --data_path 'AGI/data_60k'    

OR, on a cluster with submitit (untested: I had no cluster or multiple GPUs). Set --nodes to the number of nodes available:

python submitit_finetune.py \
    --job_dir ${JOB_DIR} \
    --nodes 4 \
    --batch_size 32 \
    --model vit_base_patch16 \
    --finetune 'mae_pretrain_vit_base.pth' \
    --epochs 50 \
    --blr 1e-3 --layer_decay 0.75 \
    --weight_decay 0.05 --drop_path 0.3 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \
    --dist_eval --data_path 'AGI/data_60k'

OR, on 1 node with 8 GPUs (untested: I had no cluster or multiple GPUs):

OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \
    --accum_iter 4 \
    --batch_size 32 \
    --model vit_base_patch16 \
    --finetune 'mae_pretrain_vit_base.pth' \
    --epochs 50 \
    --blr 5e-4 --layer_decay 0.65 \
    --weight_decay 0.05 --drop_path 0.1 --mixup 0.8 --cutmix 1.0 --reprob 0.25 \
    --dist_eval --data_path 'AGI/data_60k'
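
The three variants above do not train with the same learning rate, even where --blr is identical. The sketch below assumes this fork keeps upstream MAE's linear scaling rule, where the effective batch size is batch_size × accum_iter × number of GPUs and the absolute learning rate is blr × effective batch size / 256. The function names are hypothetical, and 8 GPUs per node for the submitit variant is an assumption.

```python
def effective_batch_size(batch_size, accum_iter=1, nodes=1, gpus_per_node=1):
    """Samples contributing to one optimizer step across all processes."""
    return batch_size * accum_iter * nodes * gpus_per_node


def absolute_lr(blr, eff_batch_size):
    """Linear scaling rule assumed from upstream MAE: lr = blr * eff / 256."""
    return blr * eff_batch_size / 256


# Single process, as in the first command (batch 32, accum_iter 1).
single = effective_batch_size(32, accum_iter=1)

# submitit variant: 4 nodes, assuming 8 GPUs per node.
cluster = effective_batch_size(32, nodes=4, gpus_per_node=8)

# One node with 8 GPUs and accum_iter 4.
one_node = effective_batch_size(32, accum_iter=4, gpus_per_node=8)

for name, eff, blr in [("single", single, 5e-4),
                       ("cluster", cluster, 1e-3),
                       ("one_node", one_node, 5e-4)]:
    print(f"{name}: effective batch {eff}, lr {absolute_lr(blr, eff):.3g}")
```

Under these assumptions the single-process run uses an effective batch of 32, while both multi-GPU variants use 1024, so results from the different launch modes are not directly comparable unless --accum_iter is raised on the single-GPU run.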

  1. Evaluate the fine-tuned mae_vit_base on eval_data
model_path = 'mae/output_dir/checkpoint-49.pth'
python main_finetune.py --eval --resume {model_path} --model vit_base_patch16 --batch_size TODO --data_path 'eval_data'
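
The hard-coded 'checkpoint-49.pth' only exists after a full 50-epoch run. A minimal sketch for picking the highest-numbered checkpoint instead is below. It assumes checkpoints are named 'checkpoint-<epoch>.pth', as in the path above, and `latest_checkpoint` is a hypothetical helper.

```python
import re
from pathlib import Path
from typing import Optional

# Assumed naming scheme, taken from the checkpoint path used in this step.
_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)\.pth$")


def latest_checkpoint(output_dir: str) -> Optional[Path]:
    """Return the checkpoint with the largest epoch number, or None."""
    best = None
    best_epoch = -1
    for path in Path(output_dir).glob("checkpoint-*.pth"):
        match = _CHECKPOINT_RE.match(path.name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch = int(match.group(1))
            best = path
    return best
```

The epoch is compared numerically, so 'checkpoint-9.pth' does not win over 'checkpoint-49.pth' as it would with a plain string sort.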

  1. Repeat the fine-tuning and evaluation steps above for the other dataset sizes ('data_sampled_6k', 'data_sampled_36k', 'data_sampled_600').
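
Repeating the runs by hand is error-prone, so here is a sketch that builds one fine-tuning command per dataset size. It reuses the single-process flags from the FINETUNE step. The 'AGI' parent folder for the sampled datasets, the per-dataset output folders, and the helper names are assumptions.

```python
from typing import Dict, List

# Dataset folder names from this README.
DATASETS = ["data_60k", "data_sampled_36k", "data_sampled_6k", "data_sampled_600"]


def finetune_command(data_path: str, output_dir: str) -> List[str]:
    """Argument list mirroring the single-process fine-tuning command."""
    return [
        "python", "main_finetune.py",
        "--accum_iter", "1",
        "--batch_size", "32",
        "--model", "vit_base_patch16",
        "--finetune", "mae_pretrain_vit_base.pth",
        "--epochs", "50",
        "--blr", "5e-4", "--layer_decay", "0.65",
        "--weight_decay", "0.05", "--drop_path", "0.1",
        "--mixup", "0.8", "--cutmix", "1.0", "--reprob", "0.25",
        "--dist_eval",
        "--data_path", data_path,
        # Separate folder per dataset so checkpoints do not overwrite each other.
        "--output_dir", output_dir,
    ]


def build_plan(data_root: str = "AGI") -> Dict[str, List[str]]:
    """Map each dataset name to the command that fine-tunes on it."""
    return {
        name: finetune_command(f"{data_root}/{name}", f"output_dir/{name}")
        for name in DATASETS
    }
```

Each command can then be run with `subprocess.run(cmd, check=True)` from inside the mae folder, followed by the evaluation step with `--resume` pointing into the matching output folder.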

IF NEEDED

9)a) If the results are poor, we might need to pre-train the model on MedMNIST and then fine-tune on eval_data.

python submitit_pretrain.py \
    --job_dir ${JOB_DIR} \
    --nodes 8 \
    --use_volta32 \
    --batch_size 64 \
    --model mae_vit_large_patch16 \
    --norm_pix_loss \
    --mask_ratio 0.75 \
    --epochs 800 \
    --warmup_epochs 40 \
    --blr 1.5e-4 --weight_decay 0.05 \
    --data_path 'AGI/data_60k'
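
To make --mask_ratio 0.75 concrete, here is a pure-Python sketch of per-image random masking: shuffle the patch indices, keep the first quarter for the encoder, and mask the rest. Upstream does this on tensors with random noise and argsort. The function below is a simplified stand-in, and the 224×224 input size is an assumption.

```python
import random
from typing import List, Tuple


def random_masking(num_patches: int, mask_ratio: float,
                   rng: random.Random) -> Tuple[List[int], List[int]]:
    """Split patch indices into (kept, masked) for one image."""
    len_keep = int(num_patches * (1 - mask_ratio))
    shuffled = list(range(num_patches))
    rng.shuffle(shuffled)
    keep = sorted(shuffled[:len_keep])
    masked = sorted(shuffled[len_keep:])
    return keep, masked


# A 224x224 image cut into 16x16 patches gives a 14x14 grid of patches.
num_patches = (224 // 16) ** 2
keep, masked = random_masking(num_patches, 0.75, random.Random(0))
print(len(keep), "patches visible to the encoder,", len(masked), "masked")
```

With these numbers the encoder sees 49 of 196 patches per image, which is where most of the pre-training speed-up comes from.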

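9)b) Then fine-tune on eval_data. Point --finetune at the checkpoint produced in 9)a); the command below still references the original 'mae_pretrain_vit_base.pth'.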

python main_finetune.py \
    --accum_iter 1 \
    --batch_size 32 \
    --model vit_base_patch16 \
    --finetune 'mae_pretrain_vit_base.pth' \
    --epochs 50 \
    --blr 5e-4 --layer_decay 0.65 \
    --weight_decay 0.05 --drop_path 0.1 --mixup 0.8 --cutmix 1.0 --reprob 0.25 \
    --dist_eval --data_path 'eval_data'

This is a PyTorch/GPU re-implementation of the paper Masked Autoencoders Are Scalable Vision Learners:

@Article{MaskedAutoencoders2021,
  author  = {Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Doll{\'a}r and Ross Girshick},
  journal = {arXiv:2111.06377},
  title   = {Masked Autoencoders Are Scalable Vision Learners},
  year    = {2021},
}
  • The original implementation was in TensorFlow+TPU. This re-implementation is in PyTorch+GPU.

  • This repo is a modification on the DeiT repo. Installation and preparation follow that repo.

  • This repo is based on timm==0.3.2, for which a fix is needed to work with PyTorch 1.8.1+.

Catalog

  • Visualization demo
  • Pre-trained checkpoints + fine-tuning code
  • Pre-training code

Visualization demo

Run our interactive visualization demo using the Colab notebook (no GPU needed).

Fine-tuning with pre-trained checkpoints

The following table provides the pre-trained checkpoints used in the paper, converted from TF/TPU to PT/GPU:

                        ViT-Base   ViT-Large   ViT-Huge
pre-trained checkpoint  download   download    download
md5                     8cad7c     b8b06e      9bdbb0
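
The md5 values in the table are short prefixes, so a downloaded checkpoint can be checked by comparing the start of its digest. This is a minimal sketch with hypothetical helper names, and it assumes the ViT-Base prefix belongs to 'mae_pretrain_vit_base.pth'.

```python
import hashlib


def md5_hex(path, chunk_size=1 << 20):
    """Compute the md5 digest of a file, reading it in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_md5_prefix(path, expected_prefix):
    """True if the file's md5 digest starts with the given hex prefix."""
    return md5_hex(path).startswith(expected_prefix.lower())
```

For example, `matches_md5_prefix("mae_pretrain_vit_base.pth", "8cad7c")` returning False would suggest a truncated or wrong download.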

The fine-tuning instruction is in FINETUNE.md.

By fine-tuning these pre-trained models, we rank #1 in these classification tasks (detailed in the paper):

                                  ViT-B   ViT-L   ViT-H   ViT-H448   prev best
ImageNet-1K (no external data)    83.6    85.9    86.9    87.8       87.1
following are evaluation of the same model weights (fine-tuned in original ImageNet-1K):
ImageNet-Corruption (error rate)  51.7    41.8    33.8    36.8       42.5
ImageNet-Adversarial              35.9    57.1    68.2    76.7       35.8
ImageNet-Rendition                48.3    59.9    64.4    66.5       48.7
ImageNet-Sketch                   34.5    45.3    49.6    50.9       36.0
following are transfer learning by fine-tuning the pre-trained MAE on the target dataset:
iNaturalists 2017                 70.5    75.7    79.3    83.4       75.4
iNaturalists 2018                 75.4    80.1    83.0    86.8       81.2
iNaturalists 2019                 80.5    83.4    85.7    88.3       84.1
Places205                         63.9    65.8    65.9    66.8       66.0
Places365                         57.9    59.4    59.8    60.3       58.0

Pre-training

The pre-training instruction is in PRETRAIN.md.

License

This project is under the CC-BY-NC 4.0 license. See LICENSE for details.
