GithubHelp home page GithubHelp logo

pkurainbow / dino Goto Github PK

View Code? Open in Web Editor NEW

This project forked from facebookresearch/dino

0.0 1.0 0.0 24.37 MB

PyTorch code for Vision Transformers training with the Self-Supervised learning method DINO

License: Other

Python 100.00%

dino's Introduction

Self-Supervised Vision Transformers with DINO

PyTorch implementation and pretrained models for DINO. For details, see Emerging Properties in Self-Supervised Vision Transformers.
[blogpost] [arXiv]

DINO illustration

Pretrained models

You can choose to download only the weights of the pretrained backbone used for downstream tasks, or the full checkpoint which contains backbone and projection head weights for both student and teacher networks. We also provide the training and evaluation logs.

arch params k-nn linear download
DeiT-S/16 21M 74.5% 77.0% backbone only full checkpoint args logs eval logs
DeiT-S/8 21M 78.3% 79.7% backbone only full checkpoint args logs eval logs
ViT-B/16 85M 76.1% 78.2% backbone only full checkpoint args logs eval logs
ViT-B/8 85M 77.4% 80.1% backbone only full checkpoint args logs eval logs
ResNet-50 23M 67.5% 75.3% backbone only full checkpoint args logs eval logs

The pretrained models are available on PyTorch Hub.

import torch
deits16 = torch.hub.load('facebookresearch/dino', 'dino_deits16')
deits8 = torch.hub.load('facebookresearch/dino', 'dino_deits8')
vitb16 = torch.hub.load('facebookresearch/dino', 'dino_vitb16')
vitb8 = torch.hub.load('facebookresearch/dino', 'dino_vitb8')
resnet50 = torch.hub.load('facebookresearch/dino', 'dino_resnet50')

Training

Documentation

Please install PyTorch and download the ImageNet dataset. This codebase has been developed with python version 3.6, PyTorch version 1.7.1, CUDA 11.0 and torchvision 0.8.2. The exact arguments to reproduce the models presented in our paper can be found in the args column of the pretrained models section. For a glimpse at the full documentation of DINO training please run:

python main_dino.py --help

Vanilla DINO training ๐Ÿฆ•

Run DINO with DeiT-small network on a single node with 8 GPUs for 100 epochs with the following command. Training time is 1.75 day and the resulting checkpoint should reach ~69.3% on k-NN eval and ~73.8% on linear eval. We will shortly provide training and linear evaluation logs for this run to help reproducibility.

python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch deit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir

Multi-node training

We use Slurm and submitit (pip install submitit). To train on 2 nodes with 8 GPUs each (total 16 GPUs):

python run_with_submitit.py --nodes 2 --ngpus 8 --arch deit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
DINO with ViT-base network.
python run_with_submitit.py --nodes 2 --ngpus 8 --use_volta32 --arch vit_base  --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir

Boosting DINO performance ๐Ÿฆ–

You can improve the performance of the vanilla run by:

  • training for more epochs: --epochs 300,
  • increasing the teacher temperature: --teacher_temp 0.07 --warmup_teacher_temp_epochs 30.
  • removing last layer normalization (only safe with --arch deit_small): --norm_last_layer false,
Full command.
python run_with_submitit.py --arch deit_small --epochs 300 --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --norm_last_layer false --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir

The resulting pretrained model should reach ~73.4% on k-NN eval and ~76.1% on linear eval. Training time is 2.6 days with 16 GPUs. We will shortly provide training and linear evaluation logs for this run to help reproducibility.

ResNet-50 and other convnets trainings

This code also works for training DINO on convolutional networks, like ResNet-50 for example. We highly recommend to adapt some optimization arguments in this case. For example here is a command to train DINO on ResNet-50 on a single node with 8 GPUs for 100 epochs:

python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch resnet50 --optimizer sgd --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir

Evaluation: k-NN classification on ImageNet

To evaluate a simple k-NN classifier with a single GPU on a pre-trained model, run:

python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --data_path /path/to/imagenet

If you choose not to specify --pretrained_weights, then DINO reference weights are used by default. If you want instead to evaluate checkpoints from a run of your own, you can run for example:

python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --pretrained_weights /path/to/checkpoint.pth --checkpoint_key teacher --data_path /path/to/imagenet 

Evaluation: Linear classification on ImageNet

To train a supervised linear classifier on frozen weights on a single node with 8 gpus, run:

python -m torch.distributed.launch --nproc_per_node=8 eval_linear.py --data_path /path/to/imagenet

Self-attention visualization

You can look at the self-attention of the [CLS] token on the different heads of the last layer by running:

python visualize_attention.py
Self-attention from a Vision Transformer with 8x8 patches trained with DINO

License

See the LICENSE file for more details.

Citation

If you find this repository useful, please consider giving a star โญ and citation ๐Ÿฆ–:

@article{caron2021emerging,
  title={Emerging Properties in Self-Supervised Vision Transformers},
  author={Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\'egou, Herv\'e  and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand},
  journal={arXiv preprint arXiv:2104.14294},
  year={2021}
}

dino's People

Watchers

 avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.