stablenet's Introduction

StableNet

StableNet is a deep stable learning method for out-of-distribution generalization.

This is the official repo for the CVPR 2021 paper "Deep Stable Learning for Out-Of-Distribution Generalization"; the arXiv version is available at https://arxiv.org/abs/2104.07876.

Please note that some hyper-parameters (such as lrbl, epochb, and lambdap) may affect performance, and their effect can vary across tasks, environments, software, hardware, and random seeds, so careful tuning is required. As with other DG (domain generalization) repositories, directly transferring the code to a new setup may lead to results that differ from ours. We apologize for this and are working to address the problem in future work.

Introduction

Approaches based on deep neural networks have achieved striking performance when testing data and training data share a similar distribution, but can fail significantly otherwise. Therefore, eliminating the impact of distribution shifts between training and testing data is crucial for building deep models with reliable performance. Conventional methods assume either known heterogeneity of the training data (e.g. domain labels) or approximately equal capacities of different domains. In this paper, we consider a more challenging case where neither of these assumptions holds. We propose to address this problem by removing the dependencies between features via learning weights for training samples, which helps deep models get rid of spurious correlations and, in turn, concentrate more on the true connection between discriminative features and labels. Extensive experiments on distribution generalization benchmarks, including PACS, VLCS, MNIST-M, and NICO, demonstrate the effectiveness of our method compared with state-of-the-art counterparts.
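The key idea, learning per-sample weights so that features become decorrelated under the reweighted training distribution, can be illustrated with a toy example. This is a minimal sketch, not the StableNet implementation: it uses a plain weighted covariance on two binary features instead of random Fourier features of deep representations, the function name `weighted_cov` is hypothetical, and the weights are picked by hand rather than learned.

```python
def weighted_cov(a, b, w):
    """Weighted covariance of two feature columns a and b.

    The weights w are normalized to sum to 1, so they act as a
    reweighted empirical distribution over the samples.
    """
    total = sum(w)
    p = [wi / total for wi in w]
    mean_a = sum(pi * ai for pi, ai in zip(p, a))
    mean_b = sum(pi * bi for pi, bi in zip(p, b))
    return sum(pi * (ai - mean_a) * (bi - mean_b) for pi, ai, bi in zip(p, a, b))


# Toy training set: feature b is spuriously correlated with feature a
# because most samples have a == b.
a = [1, 1, 1, -1, -1, -1, 1, -1]
b = [1, 1, 1, -1, -1, -1, -1, 1]

uniform = [1.0] * 8
# Up-weight the two samples where a != b so that all four (a, b)
# combinations carry equal total weight (hand-picked, not learned).
balanced = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 3.0, 3.0]

print(weighted_cov(a, b, uniform))   # 0.5: a and b are correlated
print(weighted_cov(a, b, balanced))  # approximately 0: the reweighting removes the dependence
```

StableNet instead optimizes the weights so that such dependencies vanish for all pairs of (transformed) features at once.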

Installation

Requirements

  • Linux with Python >= 3.6
  • PyTorch >= 1.1.0
  • torchvision >= 0.3.0
  • tensorboard >= 1.14.0

Quick Start

Train StableNet

python main_stablenet.py --gpu 0

To see more options, run:

python main_stablenet.py -h

Result files will be saved in results/.

Performance and trained models

| Setting | Dataset | Source domains | Target domain | Network | Dataset split | Accuracy (%) | Trained model |
|---|---|---|---|---|---|---|---|
| unbalanced(5:1:1) | PACS | A, C, S | photo | ResNet18 | split file | 94.864 | model file |
| unbalanced(5:1:1) | PACS | C, S, P | art_painting | ResNet18 | split file | 80.344 | model file |
| unbalanced(5:1:1) | PACS | A, S, P | cartoon | ResNet18 | split file | 74.249 | model file |
| unbalanced(5:1:1) | PACS | A, C, P | sketch | ResNet18 | split file | 71.046 | model file |
| unbalanced(5:1:1) | VLCS | L, P, S | caltech | ResNet18 | split file | 88.776 | model file |
| unbalanced(5:1:1) | VLCS | C, P, S | labelme | ResNet18 | split file | 63.243 | model file |
| unbalanced(5:1:1) | VLCS | C, L, S | pascal | ResNet18 | split file | 66.383 | model file |
| unbalanced(5:1:1) | VLCS | C, L, P | sun | ResNet18 | split file | 55.459 | model file |
| flexible(5:1:1) | PACS | - | - | ResNet18 | split file | 45.964 | model file |
| flexible(5:1:1) | VLCS | - | - | ResNet18 | split file | 81.157 | model file |

Domain abbreviations: PACS: A = art_painting, C = cartoon, P = photo, S = sketch; VLCS: C = caltech, L = labelme, P = pascal, S = sun.

Citing StableNet

If you find this repo useful for your research, please consider citing the paper.

@inproceedings{zhang2021deep,
  title={Deep Stable Learning for Out-Of-Distribution Generalization},
  author={Zhang, Xingxuan and Cui, Peng and Xu, Renzhe and Zhou, Linjun and He, Yue and Shen, Zheyan},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={5372--5382},
  year={2021}
}

stablenet's People

Contributors

windxrz, xxgege

stablenet's Issues

covariance matrix with w or without w

Hello, in the cov method of loss_reweighting.py, you define the covariance matrix both with and without w, but the weighted calculation does not seem to match the definition in your paper. I am not sure whether this is an error in the code or a misunderstanding on my part.

Seems like a mismatch between implementation code & mathematical formulation in the paper

My understanding:

  • RFF: random_fourier_features_gpu() adds a new dimension of size num_f rather than performing a dimension transformation.
  • lossb_expect(cfeaturec, num_f) computes the squared Frobenius norm of the original feature vector (of size batch_size × feature_dimension) at each of the num_f slices, then accumulates these and subtracts the trace of the covariance matrix to give the final loss.

Question:

  • Does cfeaturec refer to A_i and B_i, and n to the batch size, in Eq. (3)?
  • Does num_f refer to n_A and n_B in Eq. (4)?
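To make the shape discussion concrete, here is a minimal sketch of a standard random Fourier feature (RFF) map that, like the behavior described above, adds a trailing dimension of size num_f instead of transforming the feature dimension. It is an assumption-laden illustration, not the repo's random_fourier_features_gpu: the frequencies are drawn from N(0, 1/sigma²) and the phases from U(0, 2π), both shared across samples and feature dimensions for simplicity, and the function name is hypothetical. The repo may sample or scale differently.

```python
import math
import random


def random_fourier_features(x, num_f, sigma=1.0, seed=0):
    """Map a [B][D] nested list to a [B][D][num_f] nested list.

    Each scalar v becomes [sqrt(2) * cos(omega_k * v + phase_k) for k in range(num_f)],
    with omega_k ~ N(0, 1 / sigma**2) and phase_k ~ U(0, 2*pi). Averaging
    z_k(u) * z_k(v) over k approximates the Gaussian kernel
    exp(-(u - v)**2 / (2 * sigma**2)).
    """
    rng = random.Random(seed)
    omegas = [rng.gauss(0.0, 1.0 / sigma) for _ in range(num_f)]
    phases = [rng.uniform(0.0, 2.0 * math.pi) for _ in range(num_f)]
    scale = math.sqrt(2.0)
    return [
        [[scale * math.cos(om * v + ph) for om, ph in zip(omegas, phases)] for v in row]
        for row in x
    ]


features = random_fourier_features([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], num_f=5)
print(len(features), len(features[0]), len(features[0][0]))  # 2 3 5, i.e. [B, D, num_f]
```

Taking `[[cell[k] for cell in row] for row in features]` selects one [B, D] slice, the plain-Python analogue of indexing `[:, :, k]` on a tensor.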

Reproduction about NICO dataset

Thanks for your great work!

In your paper, you report results on the NICO dataset, but this repo has no dataset split file for it. I tried to split the dataset according to the description in the paper and then tried to reproduce the experiments for the ResNet-18 baseline and StableNet.

The best accuracy of the ResNet-18 baseline is 47.71, while the paper reports 51.71; that gap seems small. However, the best accuracy of StableNet is 48.20, while the paper reports 59.76, which is confusing.

I understand there is some variance due to the randomness of the data split and differences in hyperparameter tuning. Could you please provide the NICO dataset split file and the recommended hyperparameter settings for it? Thank you!

can't reproduce the result

Hi,
Thank you for your excellent work. I've been working with it recently, but I can't exactly reproduce the results in the paper with your open-source code. In my experiments, the dataset partitioning follows the provided split files, and all parameters are left at their defaults. Are the results in your paper averaged over the last few epochs, or are they the best results obtained? And are the experimental parameter settings exactly the same for the PACS and VLCS datasets?
Looking forward to your reply!
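The two reporting conventions asked about here can give noticeably different numbers from the same training run. A small helper makes the difference explicit; the function name and the default last_k=5 are hypothetical choices for illustration, and nothing here says which convention the paper used.

```python
from statistics import mean


def summarize_accuracies(history, last_k=5):
    """Return (best accuracy over all epochs, mean accuracy over the last `last_k` epochs).

    `history` is a list of per-epoch test accuracies in training order.
    """
    if not history:
        raise ValueError("history must contain at least one accuracy")
    tail = history[-last_k:]
    return max(history), mean(tail)
```

If a run peaks early and then drifts down, the first value can be well above the second, which is one reason to state which convention a reproduction uses.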

can not reproduce the results

Hi, thanks for your great work. I enjoyed reading your paper, and thank you for sharing your code and model checkpoints.
When running main_stablenet.py, I cannot reproduce the results.
Example:

Dataset: PACS (unbalanced(5:1:1) | PACS | C,S,P | art_painting | ResNet18); we also used the dataset split files you provided.
The best results we reproduced were attached as a screenshot.

Do you have a clue why that is the case?
Thanks in advance!

Shape of weight in RFF

Hi, according to the code, specifically line 40 of loss_reweighting.py, the shape of the weight w is (1, 1). Does this mean every sample shares the same weight in RFF?
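The question hinges on broadcasting: in NumPy and PyTorch, multiplying a [B, D] feature matrix by a weight of shape (1, 1) scales every sample by the same number, whereas a (B, 1) weight scales each row by its own value. The plain-Python sketch below mimics that rule (the helper name is hypothetical); it cannot settle whether the (1, 1) tensor at line 40 is only an initialization or the weight actually used.

```python
def broadcast_row_weights(x, w):
    """Mimic `w * x` for x of shape [B][D] and w of shape [1][1] or [B][1]."""
    if len(w) == 1:
        w = w * len(x)  # a (1, 1) weight is repeated for every sample
    if len(w) != len(x):
        raise ValueError("w must have shape (1, 1) or (B, 1)")
    return [[wi[0] * v for v in row] for wi, row in zip(w, x)]


x = [[1.0, 2.0], [3.0, 4.0]]
print(broadcast_row_weights(x, [[2.0]]))         # [[2.0, 4.0], [6.0, 8.0]]: same weight for all samples
print(broadcast_row_weights(x, [[1.0], [0.5]]))  # [[1.0, 2.0], [1.5, 2.0]]: per-sample weights
```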

Some Questions about loss_reweighting.py

Hello, in the cov function of loss_reweighting.py, I think there is a small difference from the definition in the paper:
Your code: cov = torch.matmul((w * x).t(), x)
Your paper: (formula attached as a screenshot)
Why is the code not cov = torch.matmul((w * x).t(), w * x)?
Thank you very much.
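On a tiny example, the two expressions differ as follows. For x of shape [n, d] and per-sample weights w of shape [n, 1], torch.matmul((w * x).t(), x) equals Σ_i w_i x_i x_iᵀ, while torch.matmul((w * x).t(), w * x) equals Σ_i w_i² x_i x_iᵀ. The plain-Python sketch below (no mean-centering, exactly as in the two quoted expressions; the function name is hypothetical) computes both and does not say which form the paper intends.

```python
def weighted_gram(x, w, square_weights=False):
    """Return sum_i c_i * outer(x_i, x_i) with c_i = w_i (or w_i ** 2 if square_weights)."""
    d = len(x[0])
    out = [[0.0] * d for _ in range(d)]
    for xi, wi in zip(x, w):
        c = wi * wi if square_weights else wi
        for r in range(d):
            for s in range(d):
                out[r][s] += c * xi[r] * xi[s]
    return out


x = [[1.0, 2.0], [3.0, 4.0]]
w = [0.5, 2.0]
print(weighted_gram(x, w))                       # [[18.5, 25.0], [25.0, 34.0]], like (w * x).t() @ x
print(weighted_gram(x, w, square_weights=True))  # [[36.25, 48.5], [48.5, 65.0]], like (w * x).t() @ (w * x)
```

The two agree only when every w_i is 0 or 1, since w_i = w_i² exactly for those values.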

confused with code

cfeaturecs = random_fourier_features_gpu(cfeaturec, num_f=num_f, sum=sum).cuda()

I think cfeaturecs has shape [B, D, H_rff], where B is the batch size, D is the number of features, and H_rff is the size of the RFF space.
However, in the subsequent for loop

for i in range(cfeaturecs.size()[-1]):
        cfeaturec = cfeaturecs[:, :, i]
        cov1 = cov(cfeaturec, weight)

Here cfeaturec is [B, D]; the last dimension is dropped. This does not correspond to the original equation, where the partial cross-covariance is calculated in the RFF space.
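A plain-Python sketch of what such a per-slice loop computes. It assumes one reading of the lossb_expect description in the "Seems like a mismatch" issue: for each RFF index k, take the [B, D] slice, form its weighted D × D covariance, and sum the squared off-diagonal entries (squared Frobenius norm minus squared diagonal). The function names are hypothetical, and this is not the repo's lossb_expect.

```python
def weighted_cov_matrix(feat, w):
    """Weighted D x D covariance of a [B][D] matrix; weights are normalized to sum to 1."""
    total = sum(w)
    p = [wi / total for wi in w]
    d = len(feat[0])
    mean = [sum(pi * row[j] for pi, row in zip(p, feat)) for j in range(d)]
    return [
        [sum(pi * (row[r] - mean[r]) * (row[s] - mean[s]) for pi, row in zip(p, feat)) for s in range(d)]
        for r in range(d)
    ]


def sliced_decorrelation_loss(feats, w):
    """feats has shape [B][D][K]; sum the squared off-diagonal covariances of every slice k."""
    num_slices = len(feats[0][0])
    loss = 0.0
    for k in range(num_slices):
        # [B][D] slice, the analogue of cfeaturecs[:, :, k]
        slice_k = [[cell[k] for cell in row] for row in feats]
        cov = weighted_cov_matrix(slice_k, w)
        frob_sq = sum(v * v for row in cov for v in row)
        diag_sq = sum(cov[j][j] ** 2 for j in range(len(cov)))
        loss += frob_sq - diag_sq
    return loss


correlated = [[[1.0], [1.0]], [[-1.0], [-1.0]]]  # B=2, D=2, K=1
independent = [[[1.0], [1.0]], [[1.0], [-1.0]], [[-1.0], [1.0]], [[-1.0], [-1.0]]]
print(sliced_decorrelation_loss(correlated, [1.0, 1.0]))  # 2.0
print(sliced_decorrelation_loss(independent, [1.0] * 4))  # 0.0
```

In this sketch each slice pairs the k-th random feature of one dimension only with the k-th random feature of another, so cross terms between different RFF indices (k ≠ l) never enter the loss; that is one way to see the difference the question raises.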

args.n_feature

Hi,
May I ask whether args.n_feature should be set equal to the batch size?

About the decay factor $\alpha_i$

Hi,
In the paper, it is said that $\alpha_i$ differs when fusing the global information $Z_{G_i}$ (in Eq. 10).
However, I found that the $\alpha_i$'s are actually all the same (in reweighting.py):

        pre_features = pre_features * args.presave_ratio + cfeatures * (1 - args.presave_ratio)
        pre_weight1 = pre_weight1 * args.presave_ratio + weight * (1 - args.presave_ratio)

Have I missed anything?
Thanks.
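The quoted update is an exponential moving average with a single ratio r, so every saved quantity decays at the same rate: starting from zeros, the batch seen j updates ago contributes (1 - r)·r^j to the fused state. The sketch below checks this by fusing one-hot "batches", and adds a hypothetical per-group variant in which each saved group i gets its own ratio α_i, as the issue describes Eq. 10. The names, the zero initialization, and the per-group form are illustrative assumptions, not code from the repo.

```python
def fuse(pre, cur, ratio):
    """Element-wise pre * ratio + cur * (1 - ratio), the form of the quoted update lines."""
    return [p * ratio + c * (1.0 - ratio) for p, c in zip(pre, cur)]


def contributions(num_batches, ratio):
    """Weight of each batch (oldest first) in the fused state, starting from zeros."""
    state = [0.0] * num_batches
    for j in range(num_batches):
        one_hot = [1.0 if k == j else 0.0 for k in range(num_batches)]
        state = fuse(state, one_hot, ratio)
    return state


def fuse_groups(pre_groups, cur, ratios):
    """Hypothetical per-group variant: group i is updated with its own ratio alpha_i."""
    return [fuse(pre, cur, alpha) for pre, alpha in zip(pre_groups, ratios)]


print(contributions(3, 0.9))  # approximately [0.081, 0.09, 0.1]
```

With one shared args.presave_ratio, all saved groups would have the same memory length; distinct ratios would let some groups track long-term and others short-term statistics.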

set config.py

Hello, thank you for your great work. May I ask whether the parameter settings given in config.py are the ones used for the final training runs reported in the paper?

two

1. Is A the current features and B the pre_features?
2. In loss = criterion(output, target).view(1, -1).mm(weight1).view(1), will gradients be back-propagated to weight1 when loss.backward() is called?
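Assuming criterion was built with reduction='none' (so it returns one loss per sample, which the .view(1, -1) reshape suggests) and weight1 has shape [B, 1], the expression is the weighted sum Σ_i l_i·w_i. The plain-Python sketch below (hypothetical names) spells out that value and its partial derivatives.

```python
def weighted_loss(losses, weights):
    """Scalar value of losses.view(1, -1).mm(weights).view(1): sum_i l_i * w_i."""
    if len(losses) != len(weights):
        raise ValueError("one weight per sample is required")
    return sum(loss_i * w_i for loss_i, w_i in zip(losses, weights))


def weighted_loss_grads(losses, weights):
    """Partial derivatives of the weighted sum: d/dl_i = w_i and d/dw_i = l_i."""
    return list(weights), list(losses)


per_sample = [0.5, 2.0, 1.0]
w = [0.2, 0.3, 0.5]
print(weighted_loss(per_sample, w))  # approximately 1.2
```

In PyTorch, autograd produces the d/dw_i = l_i term for weight1 only if weight1 requires gradients and has not been detached from the graph; if it was created under torch.no_grad() or passed through .detach(), the weights act as constants and only the network parameters receive gradients from this loss.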

Confusing about the calculation of accuracy

The author's approach to calculating accuracy seems to differ from what I expected. I think classification should be based on the highest-probability prediction, but the code appears to count a prediction as correct if any of the top-k predictions matches the target label, which seems inappropriate.
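To make the distinction concrete: standard (top-1) accuracy counts a sample as correct only when the highest-scoring class is the target, while top-k accuracy accepts the target anywhere among the k highest scores, so it can never be lower than top-1. The sketch below is a generic illustration with hypothetical names, not the repo's accuracy function, and it does not establish which k the reported numbers use.

```python
def topk_accuracy(scores, targets, k=1):
    """Percentage of samples whose target class is among the k highest-scoring classes."""
    if not targets:
        raise ValueError("targets must not be empty")
    correct = 0
    for row, target in zip(scores, targets):
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        if target in ranked[:k]:
            correct += 1
    return 100.0 * correct / len(targets)


scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
targets = [1, 1, 0]
print(round(topk_accuracy(scores, targets, k=1), 2))  # 33.33
print(round(topk_accuracy(scores, targets, k=2), 2))  # 66.67
```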

Problem with reproduction

Hi, sorry to bother you here.
I would like to discuss some problems I encountered while reproducing the results.
The problem is:
FileNotFoundError: [WinError 3] The system cannot find the path specified: '/DATA/DATANAS1/windxrz/dataset/PACS/split_compositional_with_val_sketch\train'

I downloaded the split file linked in the README, but couldn't find this directory in it.
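The path in the error combines the author's Linux data prefix (/DATA/DATANAS1/windxrz/...) with a Windows separator before train, which suggests an absolute data root is hard-coded somewhere and would need to point to the local PACS copy instead. A hedged sketch of building such a path portably with pathlib from a configurable root follows; split_dir and data_root are hypothetical names, and only the directory names split_compositional_with_val_sketch and train come from the error message (the pattern for other target domains is an assumption).

```python
from pathlib import Path


def split_dir(data_root, target_domain, subset):
    """Return <data_root>/split_compositional_with_val_<target_domain>/<subset>.

    The directory-name pattern is inferred from the error message; data_root
    should point at the local PACS copy instead of the author's /DATA/... path.
    """
    return Path(data_root) / f"split_compositional_with_val_{target_domain}" / subset


print(split_dir("data/PACS", "sketch", "train"))  # data/PACS/split_compositional_with_val_sketch/train (backslashes on Windows)
```

pathlib inserts the platform's own separator, so the same call works on Linux and Windows without mixing "/" and "\".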
