
fish's Introduction

Gradient Matching for Domain Generalisation

This is the official PyTorch implementation of Gradient Matching for Domain Generalisation. In our paper, we propose an inter-domain gradient matching (IDGM) objective that targets domain generalization by maximizing the inner product between gradients from different domains. To avoid computing the expensive second-order derivative of the IDGM objective, we derive a simpler first-order algorithm named Fish that approximates its optimization.
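Schematically, one Fish meta-step clones the current weights, takes sequential gradient steps on the clone using one minibatch from each domain, then moves the original weights a fraction of the way toward the clone. The following is a minimal NumPy sketch on a toy linear-regression problem, for intuition only: the quadratic loss, variable names, and hyperparameter values here are our own illustrative choices, not the repository's actual API or settings.

```python
import numpy as np

def grad(theta, X, y):
    # Gradient of the mean-squared-error loss for a linear model X @ theta.
    return X.T @ (X @ theta - y) / len(y)

def fish_step(theta, domains, inner_lr=0.05, meta_lr=0.5):
    # Inner loop: sequential gradient steps on a clone of the weights,
    # one minibatch per domain (this is what avoids the second-order term).
    theta_tilde = theta.copy()
    for X, y in domains:
        theta_tilde -= inner_lr * grad(theta_tilde, X, y)
    # Outer (meta) step: move the original weights toward the clone.
    return theta + meta_lr * (theta_tilde - theta)

# Toy setup: three "domains" sharing the same underlying linear relation.
rng = np.random.default_rng(0)
w_true = np.array([1.0, -2.0])
domains = []
for _ in range(3):
    X = rng.normal(size=(32, 2))
    domains.append((X, X @ w_true))

theta = np.zeros(2)
for _ in range(300):
    theta = fish_step(theta, domains)
```

Because the inner loop visits the domains in sequence, later gradients are taken at weights already updated by earlier domains, which is what implicitly encourages gradient alignment across domains.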

This repository contains code to reproduce the main results of our paper.

Dependencies

(Recommended) You can set up a conda environment with all required dependencies using environment.yml:

conda env create -f environment.yml
conda activate fish

Otherwise, you can install the following packages manually:

python=3.7.10
numpy=1.20.2
pytorch=1.8.1
torchaudio=0.8.1
torchvision=0.9.1
torch-cluster=1.5.9
torch-geometric=1.7.0
torch-scatter=2.0.6
torch-sparse=0.6.9
wilds=1.1.0
scikit-learn=0.24.2
scipy=1.6.3
seaborn=0.11.1
tqdm=4.61.0

Running Experiments

We offer options to train using our proposed method, Fish, or using the Empirical Risk Minimisation (ERM) baseline. This can be specified via the --algorithm flag (either fish or erm).

CdSprites-N

We propose this simple shape-colour dataset based on the dSprites dataset, which contains a collection of white 2D sprites of different shapes, scales, rotations and positions. The dataset contains N domains, where N can be specified. The goal is to classify the shape of the sprites, and there is a deterministic shape-colour matching that is specific to each domain. This makes shape the invariant feature and colour the spurious feature. On the test set, however, this correlation between colour and shape is removed. See the image below for an illustration.

cdsprites

The CdSprites-N dataset can be downloaded here. After downloading, please extract the zip file to your preferred data directory (e.g. <your_data_dir>/cdsprites). The following command runs an experiment using Fish with the number of domains N=15:

python main.py --dataset cdsprites --algorithm fish --data-dir <your_data_dir> --num-domains 15

The supported numbers of domains are N = 5, 10, 15, 20, 25, 30, 35, 40, 45 and 50.

WILDS

We include the following 6 datasets from the WILDS benchmark: amazon, camelyon, civil, fmow, iwildcam, poverty. The datasets can be downloaded automatically to a specified data folder. For instance, to train with Fish on the Amazon dataset, simply run:

python main.py --dataset amazon --algorithm fish --data-dir <your_data_dir>

This should automatically download the Amazon dataset to <your_data_dir>/wilds. Experiments on the other datasets can be run with the following commands:

python main.py --dataset camelyon --algorithm fish --data-dir <your_data_dir>
python main.py --dataset civil --algorithm fish --data-dir <your_data_dir>
python main.py --dataset fmow --algorithm fish --data-dir <your_data_dir>
python main.py --dataset iwildcam --algorithm fish --data-dir <your_data_dir>
python main.py --dataset poverty --algorithm fish --data-dir <your_data_dir>

Alternatively, you can download the datasets to <your_data_dir>/wilds manually by following the instructions here. See the current results on WILDS here.

DomainBed

For experiments on datasets including CMNIST, RMNIST, VLCS, PACS, OfficeHome, TerraInc and DomainNet, we implemented Fish on the DomainBed benchmark (see here), where you can compare our algorithm against up to 20 SOTA baselines. See the current results on DomainBed here.


Citation

If you make use of this code in your research, we would appreciate it if you cited the paper that is most relevant to your work:

@article{shi2021gradient,
  title={Gradient Matching for Domain Generalization},
  author={Shi, Yuge and Seely, Jeffrey and Torr, Philip H. S. and Siddharth, N. and Hannun, Awni and Usunier, Nicolas and Synnaeve, Gabriel},
  journal={arXiv preprint arXiv:2104.09937},
  year={2021}
}

Contributions

We welcome contributions via pull requests. Please email [email protected] or [email protected] with any questions or requests.

fish's People

Contributors

yugeten


fish's Issues

Question about model selection for the evaluation in FMoW-wilds

Hi Fish authors,

Thanks for the nice work! When I was trying to reproduce the results of Fish on FMoW-wilds, I found that the function save_best_model (code) always saves the model from the last epoch. It seems Fish uses last-epoch model selection instead of selecting models according to validation performance, which appears to differ from the WILDS evaluation protocol (I also could not find a corresponding description in the paper).

Maybe I missed something; could you help look into this problem? Thank you very much.

Best, Andrew

Question About proof for theorem 3.1

Hi, I am a reader of your paper. I'm very intrigued by it; it inspired me a lot. But I have a question about how to interpret Theorem 3.1.

[screenshot of the proof of Theorem 3.1]

I think this is an excellent proof, but I have a question about how to interpret it: if alpha (the inner-loop learning rate) is exactly 0 instead of approaching 0, the limit is 0, not 1. So as alpha approaches 0 the limit is 1, but at alpha = 0 the value is 0. How should this kind of "discontinuity" be understood, and which step in the proof causes it?
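The behaviour described in this question resembles a removable singularity: a ratio whose limit as alpha approaches 0 is 1 even though substituting alpha = 0 directly gives a different (or undefined) value. As a toy numeric illustration of this phenomenon (this is sin(a)/a, not the theorem's actual quantity):

```python
import math

# sin(a)/a tends to 1 as a -> 0, yet substituting a = 0 gives 0/0:
# a classic removable singularity, analogous to the behaviour asked about.
values = [math.sin(10.0 ** -k) / 10.0 ** -k for k in range(1, 7)]
```

Here every element of `values` gets closer to 1 as the argument shrinks, even though the ratio is undefined at exactly 0.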

The results on the OfficeHome dataset.

Hi, I tried running your code on DomainBed and got only 55% average accuracy on the OfficeHome dataset, whereas you reported 58% in your paper. I was wondering whether you have tested the performance of the open-source code on this dataset. Is there something wrong with the code?

Error in training on cdsprites with ERM

There is a filter of the indices that selects only latents[:, 0] != 3.

fish/src/models/datasets.py

Lines 266 to 271 in 333efa2

if split == 'val':
    self.latents = self.latents[domain_indices[-2]]
    self.images = self.images[domain_indices[-2]]
elif split == 'test':
    self.latents = self.latents[domain_indices[-1]]
    self.images = self.images[domain_indices[-1]]

In the code above, this condition is applied on val and test, but not on train. This results in an error in training with ERM. Could you provide the correct way to perform ERM on cdsprites?

Thanks.

DomainBed results with "oracle" model selection

Congratulations on this really interesting work.
I was wondering whether you could provide the DomainBed results with the best hyperparameters chosen on a validation set from the test domain (i.e. the oracle model selection).
That would be of great help for including your paper as a comparable approach in upcoming papers.
Best regards,
Alexandre

Domain Sampling Issue during Reproducing Results in Civil

Hi fish authors,

Thanks for the nice work!

However, when I was trying to reproduce the results on civilcomments using the current code, there seems to be a bug regarding sample_domains in train_fish. Specifically, a RuntimeError was raised in get_batch saying "stack expects a non-empty TensorList". After looking into the sampled domains, it seems some domains with only one batch left were sampled when no batch index actually remained.
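One defensive workaround for this kind of error is to sample only from domains that still have unconsumed batches. The sketch below assumes a hypothetical `batches_left` bookkeeping dict (domain id mapped to remaining batch count); it is not the repository's actual data structure, just an illustration of the guard:

```python
import random

def sample_nonempty_domains(batches_left, n_sample, rng=random):
    # batches_left: hypothetical dict mapping domain id -> number of
    # unconsumed batches (not the repository's actual internals).
    nonempty = [d for d, n in batches_left.items() if n > 0]
    # Never request more domains than actually have batches remaining,
    # which avoids stacking an empty list of tensors downstream.
    return rng.sample(nonempty, min(n_sample, len(nonempty)))
```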

Could you help look into this issue? Thank you very much.

Best, Andrew
