
mstc-xqp / cossl


This project is forked from yue-fan/cossl.


Official PyTorch Implementation of "CoSSL: Co-Learning of Representation and Classifier for Imbalanced Semi-Supervised Learning" (CVPR 2022)

License: MIT License



CoSSL: Co-Learning of Representation and Classifier for Imbalanced Semi-Supervised Learning


This repository contains the PyTorch implementation of the CVPR 2022 paper "CoSSL: Co-Learning of Representation and Classifier for Imbalanced Semi-Supervised Learning" by Yue Fan, Dengxin Dai, Anna Kukleva, and Bernt Schiele.

If you have any questions about this repository or the related paper, feel free to open an issue or send me an email.


Introduction

Standard semi-supervised learning (SSL) on class-balanced datasets has made great progress in leveraging unlabeled data effectively. However, the more realistic setting of class-imbalanced data, called imbalanced SSL, remains largely underexplored, and standard SSL methods tend to underperform in it. In this paper, we propose a novel co-learning framework (CoSSL), which decouples representation learning and classifier learning while coupling them closely. To handle the data imbalance, we devise Tail-class Feature Enhancement (TFE) for classifier learning. Furthermore, the current evaluation protocol for imbalanced SSL focuses only on balanced test sets, which limits its practicality in real-world scenarios. Therefore, we further conduct a comprehensive evaluation under various shifted test distributions. In experiments, we show that our approach outperforms other methods over a large range of shifted distributions, achieving state-of-the-art performance on benchmark datasets including CIFAR-10, CIFAR-100, ImageNet, and Food-101.
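Why a balanced test set alone can mislead is easy to see with a small calculation. Assuming a pure label shift (the class-conditional distributions stay fixed and only the class frequencies change), a model's accuracy on a shifted test set is its per-class accuracy weighted by the new class frequencies. The sketch below is illustrative only: the function name `accuracy_under_label_shift` and the per-class accuracies are hypothetical, and this is not the paper's evaluation protocol.

```python
import numpy as np


def accuracy_under_label_shift(per_class_acc, test_class_counts):
    """Expected accuracy on a test set with the given class counts.

    Assumes only the class frequencies change (label shift): each class keeps
    its per-class accuracy, so overall accuracy is a frequency-weighted average.
    """
    acc = np.asarray(per_class_acc, dtype=float)
    counts = np.asarray(test_class_counts, dtype=float)
    return float(np.dot(counts / counts.sum(), acc))


if __name__ == "__main__":
    # Hypothetical per-class accuracies of a head-biased model (class 0 = head).
    per_class_acc = [0.95, 0.90, 0.60, 0.40]
    head_heavy = [400, 200, 100, 50]
    for name, counts in [("balanced", [100, 100, 100, 100]),
                         ("head-heavy", head_heavy),
                         ("tail-heavy", head_heavy[::-1])]:
        print(name, round(accuracy_under_label_shift(per_class_acc, counts), 4))
```

The same model scores about 0.7125 on the balanced test set, 0.8533 on the head-heavy one, and 0.5567 on the tail-heavy one, so a single balanced number hides how much performance moves as the test distribution shifts.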

Figure: Our co-learning framework CoSSL decouples the training of the representation and the classifier while coupling them in a non-gradient manner. CoSSL consists of three modules: a semi-supervised representation learning module, a balanced classifier learning module, and a carefully designed pseudo-label generation module. The representation module provides a momentum encoder for feature extraction in the other two modules, and the classifier module produces a balanced classifier using our novel Tail-class Feature Enhancement (TFE). The pseudo-label module then generates pseudo-labels for the representation module using the momentum encoder and the balanced classifier. The modules reinforce one another, leading to both a more powerful representation and a more balanced classifier. Additionally, our framework is flexible, as it can accommodate any standard SSL method and any classifier learning method.
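To make the interplay between the modules concrete, here is a minimal NumPy sketch of the three ingredients named in the caption. It rests on assumptions that the caption does not state. The momentum encoder is taken to be an exponential moving average (EMA) of the representation model's weights. TFE is modeled as a mixup-style blend of a tail-class feature with a randomly chosen unlabeled feature, with the weight λ drawn from U(lam_min, 1) and the tail-class label kept. Pseudo-labels use a FixMatch-style confidence threshold. All names (`ema_update`, `tail_feature_enhancement`, `lam_min`, `pseudo_labels`, `threshold`) are hypothetical. This is not the repository's code, and we have not verified how `lam_min` relates to the `--max_lam` flag in the example commands.

```python
import numpy as np


def ema_update(momentum_params, online_params, decay=0.999):
    """Momentum-encoder update (assumed EMA): m <- decay * m + (1 - decay) * online."""
    for name, value in online_params.items():
        momentum_params[name] = decay * momentum_params[name] + (1.0 - decay) * value
    return momentum_params


def tail_feature_enhancement(tail_feats, unlabeled_feats, lam_min=0.6, rng=None):
    """Hypothetical TFE sketch: blend each tail-class feature with a random unlabeled one.

    lam ~ U(lam_min, 1), so the tail-class feature dominates; the caller keeps
    the original tail-class labels for the blended features.
    """
    rng = np.random.default_rng() if rng is None else rng
    n = tail_feats.shape[0]
    idx = rng.integers(0, unlabeled_feats.shape[0], size=n)  # random unlabeled partners
    lam = rng.uniform(lam_min, 1.0, size=(n, 1))             # one blend weight per sample
    blended = lam * tail_feats + (1.0 - lam) * unlabeled_feats[idx]
    return blended, lam.ravel()


def pseudo_labels(logits, threshold=0.95):
    """Softmax the balanced classifier's logits and keep only confident predictions."""
    z = logits - logits.max(axis=1, keepdims=True)  # numerically stable softmax
    probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    return probs.argmax(axis=1), probs.max(axis=1) >= threshold


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Toy features: tail rows are all ones and unlabeled rows all zeros,
    # so every entry of a blended row equals that row's lam.
    blended, lam = tail_feature_enhancement(np.ones((4, 8)), np.zeros((16, 8)), rng=rng)
    print(np.allclose(blended, lam[:, None]))  # True
    labels, mask = pseudo_labels(np.array([[10.0, 0.0, 0.0], [1.0, 1.0, 1.0]]))
    print(labels, mask)  # [0 0] [ True False]
```

Keeping λ close to 1 is what lets the blended feature keep its tail-class label while borrowing variation from the rest of the data. In a PyTorch model the EMA update would run over `model.parameters()` under `torch.no_grad()` rather than over a dict of arrays.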

Installation

  • Python 3.7
  • PyTorch == 1.0.0
  • torchvision == 0.2.2.post3
  • randAugment (PyTorch re-implementation: https://github.com/ildoonet/pytorch-randaugment)
  • progressbar

Running Experiments

We provide run*.sh scripts for reproducing the results in our paper.

Example

Here is an example of running CoSSL on CIFAR-10 at imbalance ratio 150 with the FixMatch backbone:

Run the pretraining phase:

python train_cifar_fix.py --ratio 2 --num_max 1500 --imb_ratio_l 150 --imb_ratio_u 150 --epoch 500 --val-iteration 500 --out ./results/cifar10/fixmatch/baseline/wrn28_N1500_r150_seed1 --manualSeed 1 --gpu 2

Then apply CoSSL, resuming from a checkpoint saved during pretraining:

python train_cifar_fix_cossl.py --ratio 2 --num_max 1500 --imb_ratio_l 150 --imb_ratio_u 150 --epoch 100 --val-iteration 500 --resume ./results/cifar10/fixmatch/baseline/wrn28_N1500_r150_seed1/checkpoint_401.pth.tar --out ./results/cifar10/fixmatch/cossl/wrn28_N1500_r150_lam06_seed1 --max_lam 0.6 --manualSeed 1 --gpu 0
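The flags in these commands define the long-tailed split. Here is a minimal sketch of how they could map to per-class counts, assuming the exponential profile common in imbalanced-SSL benchmarks: class k (0-indexed, C classes) gets num_max · (1/imb_ratio)^(k/(C−1)) labeled samples, and the unlabeled set follows the same profile with --ratio times as many samples per class and --imb_ratio_u as its ratio. This is our reading of the flags, not code taken from the repository. `make_imb_counts` is a hypothetical name, and the exact rounding in train_cifar_fix.py may differ.

```python
def make_imb_counts(num_max, num_classes, imb_ratio):
    """Per-class sample counts decaying exponentially from num_max to num_max / imb_ratio.

    Hypothetical helper modeled on common imbalanced-SSL data loaders; the last
    class is set explicitly so floating-point error cannot truncate the tail count.
    """
    mu = (1.0 / imb_ratio) ** (1.0 / (num_classes - 1))
    counts = [int(num_max * mu ** k) for k in range(num_classes - 1)]
    counts.append(int(num_max / imb_ratio))
    return counts


if __name__ == "__main__":
    # Flags from the CIFAR-10 example: --num_max 1500 --ratio 2 --imb_ratio_l 150 --imb_ratio_u 150
    num_classes, num_max, ratio = 10, 1500, 2
    labeled = make_imb_counts(num_max, num_classes, imb_ratio=150)
    unlabeled = make_imb_counts(ratio * num_max, num_classes, imb_ratio=150)
    print("labeled:", labeled)
    print("unlabeled:", unlabeled)
```

Under these assumptions the most frequent class has 1500 labeled and 3000 unlabeled images, and the rarest has 10 and 20. That gives a 150:1 ratio in both sets, matching --imb_ratio_l and --imb_ratio_u.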

Performance

Test accuracy (%); γ denotes the imbalance ratio.

Method | CIFAR-10-LT γ=50 | CIFAR-10-LT γ=100 | CIFAR-10-LT γ=150
MixMatch+CoSSL | 80.3 ± 0.31 | 76.4 ± 1.14 | 73.5 ± 1.25
ReMixMatch+CoSSL | 87.7 ± 0.21 | 84.1 ± 0.56 | 81.3 ± 0.83
FixMatch+CoSSL | 86.8 ± 0.30 | 83.2 ± 0.49 | 80.3 ± 0.55

Method | CIFAR-100-LT γ=20 | CIFAR-100-LT γ=50 | CIFAR-100-LT γ=100
ReMixMatch+CoSSL | 55.8 ± 0.62 | 48.9 ± 0.61 | 44.1 ± 0.59
FixMatch+CoSSL | 53.9 ± 0.78 | 47.6 ± 0.57 | 43.0 ± 0.61

Method | Food-101-LT γ=50 | Food-101-LT γ=100
FixMatch+CoSSL | 49.0 | 40.4

Method | Small-ImageNet-127 32x32 | Small-ImageNet-127 64x64
FixMatch+CoSSL | 43.7 | 54.4

  • The results for the PyTorch version are still being verified.

Citation

Please cite our paper if it is helpful to your work:

@inproceedings{fan2021cossl,
  title     = {CoSSL: Co-Learning of Representation and Classifier for Imbalanced Semi-Supervised Learning},
  author    = {Fan, Yue and Dai, Dengxin and Kukleva, Anna and Schiele, Bernt},
  booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year      = {2022}
}

Acknowledgements

Our implementation uses source code from the following repositories and users:
