
Official PyTorch Implementation of Guarding Barlow Twins Against Overfitting with Mixed Samples

Home Page: https://arxiv.org/abs/2312.02151

License: MIT License


Mixed Barlow Twins for Self-Supervised Representation Learning

Guarding Barlow Twins Against Overfitting with Mixed Samples

arXiv Hugging Face Model Card

Wele Gedara Chaminda Bandara (Johns Hopkins University), Celso M. De Melo (U.S. Army Research Laboratory), and Vishal M. Patel (Johns Hopkins University)

1 Overview of Mixed Barlow Twins

TL;DR

  • Mixed Barlow Twins aims to improve sample interaction during Barlow Twins training via linearly interpolated samples.
  • We introduce an additional regularization term to the original Barlow Twins objective, assuming linear interpolation in the input space translates to linearly interpolated features in the feature space.
  • Pre-training with this regularization effectively mitigates feature overfitting and further enhances the downstream performance on CIFAR-10, CIFAR-100, TinyImageNet, STL-10, and ImageNet datasets.
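The input-space mixing described above follows standard mixup: a convex combination of one view with a batch-shuffled copy of the other. A minimal PyTorch sketch (illustrative only; function and argument names are assumptions, not the repository's code):

```python
import torch

def mix_batch(x_a, x_b, alpha=1.0):
    """Form the mixed view x_m = lam * x_a + (1 - lam) * shuffled(x_b).

    Illustrative sketch: lam is drawn from a Beta(alpha, alpha)
    distribution, as in standard mixup.
    """
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    shuffle_idx = torch.randperm(x_b.size(0))  # random permutation over the batch
    x_m = lam * x_a + (1.0 - lam) * x_b[shuffle_idx]
    return x_m, lam, shuffle_idx
```

The returned `lam` and `shuffle_idx` are needed later to build the ground-truth correlation targets for the regularizer.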

Here $Z^A$ and $Z^B$ are the embeddings of the two augmented views, $Z^M$ is the embedding of the mixed batch $X^M = \lambda X^A + (1-\lambda)\,\mathtt{Shuffle}^*(X^B)$, and the regularizer drives the mixed cross-correlation matrices toward their linearly interpolated ground truths:

$C^{MA} = (Z^M)^T Z^A$

$C^{MB} = (Z^M)^T Z^B$

$C^{MA}_{gt} = \lambda\, (Z^A)^T Z^A + (1-\lambda)\,(\mathtt{Shuffle}^*(Z^B))^T Z^A$

$C^{MB}_{gt} = \lambda\, (Z^A)^T Z^B + (1-\lambda)\,(\mathtt{Shuffle}^*(Z^B))^T Z^B$
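Under the linearity assumption stated in the TL;DR, these targets are computed directly from the clean embeddings. A minimal PyTorch sketch of the resulting penalty (illustrative only; names and normalization are assumptions, not the repository's exact code):

```python
import torch

def mixup_regularization(z_a, z_b, z_m, lam, shuffle_idx):
    """Align mixed cross-correlations with their interpolated targets.

    z_a, z_b : (N, D) embeddings of the two augmented views
    z_m      : (N, D) embedding of the mixed batch
               x_m = lam * x_a + (1 - lam) * x_b[shuffle_idx]
    """
    n = z_a.size(0)
    z_b_s = z_b[shuffle_idx]                 # Shuffle*(Z^B)
    # Cross-correlations of the mixed embedding with each clean view
    c_ma = z_m.T @ z_a / n
    c_mb = z_m.T @ z_b / n
    # Ground-truth targets, assuming interpolation carries to feature space
    c_ma_gt = (lam * z_a.T @ z_a + (1 - lam) * z_b_s.T @ z_a) / n
    c_mb_gt = (lam * z_a.T @ z_b + (1 - lam) * z_b_s.T @ z_b) / n
    # Penalize deviation from the interpolated targets
    return ((c_ma - c_ma_gt) ** 2).sum() + ((c_mb - c_mb_gt) ** 2).sum()
```

If the encoder behaved exactly linearly, `z_m` would equal `lam * z_a + (1 - lam) * z_b[shuffle_idx]` and the penalty would vanish; in practice the term pushes the encoder toward that behavior, which is what mitigates feature overfitting.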

2 Usage

2.1 Requirements

Before using this repository, make sure you have PyTorch installed. On Linux, you can install it with:

conda install pytorch torchvision pytorch-cuda=11.8 -c pytorch -c nvidia

2.2 Installation

To get started, clone this repository:

git clone https://github.com/wgcban/mix-bt.git

Next, create the conda environment named ssl-aug by executing the following command:

conda env create -f environment.yml

All train/val/test statistics are automatically uploaded to wandb; please refer to the wandb quick-start documentation if you are unfamiliar with wandb.

2.3 Supported Pre-training Datasets

This repository supports the following pre-training datasets:

CIFAR-10, CIFAR-100, and STL-10 datasets are directly available in PyTorch.

To use TinyImageNet, please follow the preprocessing instructions provided in the TinyImageNet-Script. Download these datasets and place them in the data directory.

2.4 Supported Transfer Learning Datasets

You can download the transfer learning datasets and place them under their respective paths, such as data/DTD. The supported transfer learning datasets include:

2.5 Supported SSL Methods

This repository supports the following Self-Supervised Learning (SSL) methods:

  • SimCLR: contrastive learning for SSL
  • BYOL: distillation for SSL
  • Whitening MSE: infomax for SSL
  • Barlow Twins: infomax for SSL
  • Mixed Barlow Twins (ours): infomax + mixed samples for SSL
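For context, the base Barlow Twins objective that Mixed Barlow Twins extends drives the cross-correlation matrix of the two views toward the identity. A minimal sketch (illustrative only; not the repository's exact code):

```python
import torch

def barlow_twins_loss(z_a, z_b, lambda_bt):
    """Barlow Twins: push the cross-correlation matrix toward identity."""
    n, d = z_a.shape
    # Standardize each embedding dimension over the batch
    z_a = (z_a - z_a.mean(0)) / z_a.std(0, unbiased=False)
    z_b = (z_b - z_b.mean(0)) / z_b.std(0, unbiased=False)
    c = z_a.T @ z_b / n                                  # (d, d) cross-correlation
    on_diag = (torch.diagonal(c) - 1) ** 2               # invariance term
    off_diag = (c - torch.diag(torch.diagonal(c))) ** 2  # redundancy reduction
    return on_diag.sum() + lambda_bt * off_diag.sum()
```

The `lambda_bt` weight here corresponds to the $\lambda_{BT}$ values listed in the checkpoint tables below.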

2.6 Pre-Training with Mixed Barlow Twins

To start pre-training and obtain k-NN evaluation results for Mixed Barlow Twins on CIFAR-10, CIFAR-100, TinyImageNet, and STL-10 with ResNet-18/50 backbones, please run:

sh scripts-pretrain-resnet18/[dataset].sh
sh scripts-pretrain-resnet50/[dataset].sh
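The k-NN evaluation these scripts report is commonly implemented as a weighted vote over the cosine-nearest training features. A minimal sketch of that protocol (illustrative only; names and the temperature `t` are assumptions, not the repository's exact code):

```python
import torch

@torch.no_grad()
def knn_accuracy(train_feats, train_labels, test_feats, test_labels, k=200, t=0.1):
    """Weighted k-NN evaluation on frozen features (illustrative sketch)."""
    train_feats = torch.nn.functional.normalize(train_feats, dim=1)
    test_feats = torch.nn.functional.normalize(test_feats, dim=1)
    sim = test_feats @ train_feats.T               # cosine similarities
    topv, topi = sim.topk(k, dim=1)                # k nearest training samples
    w = (topv / t).exp()                           # temperature-scaled weights
    nn_labels = train_labels[topi]                 # (N_test, k) neighbor labels
    num_classes = int(train_labels.max()) + 1
    scores = torch.zeros(test_feats.size(0), num_classes)
    scores.scatter_add_(1, nn_labels, w)           # accumulate votes per class
    pred = scores.argmax(1)
    return (pred == test_labels).float().mean().item()
```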

To start the pre-training on ImageNet with ResNet-50 backbone, please run:

sh scripts-pretrain-resnet50/imagenet.sh

2.7 Linear Evaluation of Pre-trained Models

Before running linear evaluation, ensure that you specify the model_path argument correctly in the corresponding .sh file.

To obtain linear evaluation results on CIFAR-10, CIFAR-100, TinyImageNet, STL-10 with ResNet-18/50 backbones, please run:

sh scripts-linear-resnet18/[dataset].sh
sh scripts-linear-resnet50/[dataset].sh
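Conceptually, linear evaluation freezes the pre-trained encoder and trains only a linear classifier on its features. One training step can be sketched as follows (illustrative only; the module and variable names are assumptions):

```python
import torch
import torch.nn as nn

def linear_probe_step(backbone, classifier, optimizer, x, y):
    """One linear-evaluation step: frozen backbone, trainable linear head.

    In linear evaluation only `classifier`'s parameters are given
    to the optimizer, so the backbone is never updated.
    """
    backbone.eval()
    with torch.no_grad():
        feats = backbone(x)          # features from the frozen encoder
    logits = classifier(feats)
    loss = nn.functional.cross_entropy(logits, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```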

To obtain linear evaluation results on ImageNet with ResNet-50 backbone, please run:

sh scripts-linear-resnet50/imagenet_sup.sh

2.8 Transfer Learning of Pre-trained Models

To perform transfer learning from pre-trained models on CIFAR-10, CIFAR-100, and STL-10 to fine-grained classification datasets, execute the following command, making sure to specify the model_path argument correctly:

sh scripts-transfer-resnet18/[dataset]-to-x.sh

3 Pre-Trained Checkpoints

Download the pre-trained models from GitHub (Releases v1.0.0) and store them in checkpoints/. This repository provides pre-trained checkpoints for both ResNet-18 and ResNet-50 architectures.

3.1 ResNet-18 [CIFAR-10, CIFAR-100, TinyImageNet, and STL-10]

| Dataset | $d$ | $\lambda_{BT}$ | $\lambda_{reg}$ | Pretrained Model | KNN Acc. | Linear Acc. |
| --- | --- | --- | --- | --- | --- | --- |
| CIFAR-10 | 1024 | 0.0078125 | 4.0 | 4wdhbpcf_cifar10.pth | 90.52 | 92.58 |
| CIFAR-100 | 1024 | 0.0078125 | 4.0 | 76kk7scz_cifar100.pth | 61.25 | 69.31 |
| TinyImageNet | 1024 | 0.0009765 | 4.0 | 02azq6fs_tiny_imagenet.pth | 38.11 | 51.67 |
| STL-10 | 1024 | 0.0078125 | 2.0 | i7det4xq_stl10.pth | 88.94 | 91.02 |

3.2 ResNet-50 [CIFAR-10, CIFAR-100, TinyImageNet, and STL-10]

| Dataset | $d$ | $\lambda_{BT}$ | $\lambda_{reg}$ | Pretrained Model | KNN Acc. | Linear Acc. |
| --- | --- | --- | --- | --- | --- | --- |
| CIFAR-10 | 1024 | 0.0078125 | 4.0 | v3gwgusq_cifar10.pth | 91.39 | 93.89 |
| CIFAR-100 | 1024 | 0.0078125 | 4.0 | z6ngefw7_cifar100.pth | 64.32 | 72.51 |
| TinyImageNet | 1024 | 0.0009765 | 4.0 | kxlkigsv_tiny_imagenet.pth | 42.21 | 51.84 |
| STL-10 | 1024 | 0.0078125 | 2.0 | pbknx38b_stl10.pth | 87.79 | 91.70 |

3.3 ResNet-50 on ImageNet (300 epochs)

Setting: epochs = 300, $d$ = 8192, $\lambda_{BT}$ = 0.0051

| $\lambda_{reg}$ | Linear Acc. | Pretrained Model | Train Log | Linear-Probed Model | Val. Log |
| --- | --- | --- | --- | --- | --- |
| 0.0 (BT) | 71.3 | 3on0l4wl_resnet50.pth | train_log | checkpoint_3tb4tcvp.pth | val_log |
| 0.0025 | 70.9 | l418b9zw_resnet50.pth | train_log | checkpoint_09g7ytcz.pth | val_log |
| 0.1 | 71.6 | 13awtq23_resnet50.pth | train_log | checkpoint_pgawzr4e.pth | val_log |
| 1.0 | 72.2 (best) | 3fb1op86_resnet50.pth | train_log | checkpoint_wvi0hle8.pth | val_log |
| 2.0 | 72.1 | 5n9yqio0_resnet50.pth | train_log | checkpoint_p9aeo8ga.pth | val_log |
| 3.0 | 72.0 | q03u2xjz_resnet50.pth | train_log | checkpoint_00atvp6x.pth | val_log |

3.4 ResNet-50 on ImageNet (1000 epochs)

Setting: epochs = 1000, $d$ = 8192, $\lambda_{BT}$ = 0.0051, $\lambda_{reg}$=2.0

| Linear Eval. Top1 | Linear Eval. Top5 | Pretrained Model | Train Log | Linear-Probed Model | Val. Log |
| --- | --- | --- | --- | --- | --- |
| 74.06 (best) | 91.47 | 4wpu8wmd_resnet50.pth | train_log | vfd2nu64_checkpoint.pth | val_log |

4 Training/Val Logs

4.1 Pre-training for 300 epochs

Logs are available on wandb and can be accessed via the following links:

Here we provide some training and validation (linear probing) statistics for Barlow Twins vs. Mixed Barlow Twins with a ResNet-50 backbone on ImageNet:

4.2 Pre-training for 1000 epochs

We also provide training/val statistics for our model pre-trained for 1000 epochs.

🔥 Access pre-training statistics on wandb: wandb-imagenet-pretrain

5 Disclaimer

A large portion of the code comes from Barlow Twins HSIC (for experiments on the small datasets: CIFAR-10, CIFAR-100, TinyImageNet, and STL-10) and from the official implementation of Barlow Twins here (for experiments on ImageNet), both of which are great resources for academic development.

Also, note that the implementations of the SOTA methods (SimCLR, BYOL, and Whitening-MSE) in ssl-sota are taken from Whitening-MSE.

We would like to thank all of them for making their repositories publicly available for the research community. 🙏

6 Reference

If you find our work useful, please consider citing it. Thanks!

@misc{bandara2023guarding,
      title={Guarding Barlow Twins Against Overfitting with Mixed Samples}, 
      author={Wele Gedara Chaminda Bandara and Celso M. De Melo and Vishal M. Patel},
      year={2023},
      eprint={2312.02151},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

7 License

This code is released under the MIT License; you can find the complete license file here.


