reiinakano / invariant-risk-minimization

Implementation of Invariant Risk Minimization https://arxiv.org/abs/1907.02893

License: MIT License

Python 9.83% Jupyter Notebook 90.17%
machine-learning deep-learning neural-network invariant-risk-minimization causality

invariant-risk-minimization's Introduction

Implementation of Invariant Risk Minimization (https://arxiv.org/abs/1907.02893)

This is an attempt to reproduce the "Colored MNIST" experiments from the paper "Invariant Risk Minimization" by Arjovsky et al.
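For context, the paper builds Colored MNIST by giving each digit a binary label (0 for digits 0-4, 1 for digits 5-9), flipping that label with probability 0.25, and then assigning a color that matches the final label except with an environment-dependent flip probability: 0.1 and 0.2 for the two training environments, 0.9 for the test environment. Below is a minimal sketch of that labeling step only. It works on a stand-in list of digit values rather than loading MNIST or recoloring images, and the function names `colored_mnist_labels` and `agreement` are hypothetical, not taken from this repo's `main.py`. A classifier that reads only the color bit would score about 90% and 80% on the training environments but only about 10% on the test environment, which is the trap IRM is meant to avoid.

```python
import random


def colored_mnist_labels(digits, color_flip_prob, label_noise=0.25, seed=0):
    """Return (labels, colors) for a list of MNIST digit values 0-9.

    Sketch of the Colored MNIST labeling described in the IRM paper:
    1. Binarize the digit: y_tilde = 0 for digits 0-4, 1 for digits 5-9.
    2. Flip y_tilde with probability label_noise to get the final label y.
    3. Flip y with probability color_flip_prob to get the color bit z
       (which of 0/1 means red or green is an arbitrary choice here).
    """
    rng = random.Random(seed)
    labels, colors = [], []
    for d in digits:
        y_tilde = 0 if d < 5 else 1
        y = y_tilde ^ int(rng.random() < label_noise)
        z = y ^ int(rng.random() < color_flip_prob)
        labels.append(y)
        colors.append(z)
    return labels, colors


def agreement(a, b):
    """Fraction of positions where two equal-length bit lists agree."""
    return sum(x == y for x, y in zip(a, b)) / len(a)


# Stand-in for real MNIST digits; only the digit values matter here.
digits = [i % 10 for i in range(10000)]
envs = [("train_1", 0.1), ("train_2", 0.2), ("test", 0.9)]
for seed, (name, p) in enumerate(envs):
    labels, colors = colored_mnist_labels(digits, p, seed=seed)
    print(name, "color/label agreement:", round(agreement(labels, colors), 3))
```

The colors are derived from the noisy final label, not from the digit, so in the training environments color is a stronger predictor (80-90%) than the digit shape (75%), while the test environment reverses the color correlation.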

After trying lots of hyperparameters and various tricks, this implementation comes close to the values reported in the paper (train accuracy > 70%, test accuracy > 60%), though training can be quite unstable depending on the random seed.

The most common failure case is when the gradient norm penalty term is weighted too heavily relative to the ERM term. In this case, Φ converges to a function that returns the same value for every input. The classifier cannot recover from this point, and accuracy is stuck at 50% in all environments. This makes sense mathematically: if the intermediate representation is the same regardless of input, the classifier on top of it cannot do better than chance, and the same trivial classifier is optimal in every environment, so the penalty gradient is 0.
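To make the collapse concrete, here is a minimal sketch of the IRMv1 penalty from the paper for a binary cross-entropy loss: the classifier is treated as a dummy scalar w multiplying Φ's output, and the penalty is the squared derivative of the environment's risk with respect to w at w = 1.0. It computes that derivative in closed form in plain Python instead of with autograd, and `irm_penalty` is a hypothetical name, not the function in this repo. If Φ outputs the same logit c for every input, the derivative reduces to c * (sigmoid(c) - mean label), which is exactly 0 at c = 0 and also at the constant that minimizes the ERM term (sigmoid(c) equal to the positive-label rate), so a collapsed representation incurs no penalty at all.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def irm_penalty(logits, labels):
    """IRMv1-style penalty for one environment with binary cross-entropy.

    Treat the classifier as a dummy scalar w multiplying the logits and return
    the squared derivative of the mean loss with respect to w at w = 1.0.
    For l(w*z, y) = log(1 + exp(w*z)) - y*w*z the derivative is
    z * (sigmoid(w*z) - y).
    """
    grad = sum(z * (sigmoid(z) - y) for z, y in zip(logits, labels)) / len(logits)
    return grad ** 2


labels = [0, 1, 0, 1]
# Collapsed representation: every input maps to the same logit of 0.
print(irm_penalty([0.0, 0.0, 0.0, 0.0], labels))  # 0.0
# An input-dependent representation generally has a nonzero penalty.
print(irm_penalty([2.0, -1.5, 0.5, 3.0], labels))
```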

Another failure case is when the gradient norm penalty is weighted too lightly, so the optimization essentially reduces to ERM (train accuracy > 80%, test accuracy ~10%).

The most important trick I used to get this working is a scheduled increase of the gradient norm penalty weight. The weight starts at 0, so training begins as plain ERM, and then increases slowly each epoch.
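A linear ramp is one way to implement such a schedule; the README does not specify the exact shape, and the names `penalty_weight`, `max_weight`, and `ramp_epochs` are hypothetical. The sketch also shows how the weight enters the objective: risk plus weighted penalty, averaged over the training environments (summing instead only changes the scale). With the weight at 0 the objective is exactly ERM.

```python
def penalty_weight(epoch, max_weight, ramp_epochs):
    """Linearly ramp the penalty weight from 0 (pure ERM) up to max_weight."""
    if ramp_epochs <= 0:
        return max_weight
    return max_weight * min(epoch / ramp_epochs, 1.0)


def irm_objective(env_risks, env_penalties, weight):
    """Mean over training environments of risk + weight * penalty."""
    terms = [r + weight * p for r, p in zip(env_risks, env_penalties)]
    return sum(terms) / len(terms)


# Hypothetical values: ramp to 1000 over 50 epochs, then hold.
for epoch in (0, 10, 25, 50, 80):
    print(epoch, penalty_weight(epoch, max_weight=1000.0, ramp_epochs=50))
```

A slow ramp gives Φ time to learn useful features under ERM before the penalty becomes strong enough to push it toward a collapsed solution.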

I use early stopping to stop training once the accuracy on all environments, including the test set, reaches an acceptable value. Yes, stopping training based on performance on the test set is not good practice, but I could not find a principled way to decide when to stop by observing performance on the training environments alone. When applying IRM to real-world datasets, it may be necessary to hold out a separate environment as a validation set for early stopping. The downside is that IRM would then need at least 4 environments (2 train, 1 validation, 1 test).
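A minimal sketch of the held-out-environment variant: stop once every monitored environment reaches its threshold, where the monitored set contains the training environments and a validation environment but never the test environment. The environment names, thresholds, and accuracy values below are made up for illustration, and `should_stop` is a hypothetical helper, not part of this repo.

```python
def should_stop(env_accuracies, thresholds):
    """Return True once every monitored environment reaches its threshold.

    env_accuracies: dict of environment name -> accuracy for this epoch.
    thresholds: dict of environment name -> minimum acceptable accuracy.
    Only environments listed in thresholds are checked, so the test
    environment can simply be left out.
    """
    return all(env_accuracies[name] >= t for name, t in thresholds.items())


# Hypothetical names and thresholds; "val" is the held-out environment.
thresholds = {"train_1": 0.70, "train_2": 0.70, "val": 0.60}
# Made-up per-epoch accuracies, purely to exercise the check.
history = [
    {"train_1": 0.85, "train_2": 0.80, "val": 0.30},
    {"train_1": 0.72, "train_2": 0.71, "val": 0.64},
]
for epoch, accs in enumerate(history):
    if should_stop(accs, thresholds):
        print("stopping at epoch", epoch)
        break
```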

Feel free to open an issue if you find a bug or a set of hyperparameters that makes training stable. Otherwise, let's all just wait for the authors' code, which they say will be available soon.

The authors' original code is here: https://github.com/facebookresearch/InvariantRiskMinimization, and was apparently posted two months before I started this. For some reason, I wasn't able to find it when I first searched. It looks like instead of gradually increasing the gradient norm penalty, they start at 0 for a few iterations and then jump straight to a higher value for the rest of training. I think the important thing is to make sure training effectively starts as ERM (0 penalty) before the IRM penalty term is added.
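For comparison with a gradual ramp, the step schedule described here can be sketched as follows. The names and values are hypothetical, and the warm-up weight is a parameter defaulting to 0 as described above. The `combined_loss` helper divides the whole loss by the weight once it exceeds 1, so that gradient magnitudes do not jump by orders of magnitude at the switch; that rescaling is an assumption of this sketch, not something this README specifies.

```python
def step_penalty_weight(step, anneal_steps, high_weight, warmup_weight=0.0):
    """Hold the penalty weight at warmup_weight (ERM-like) for anneal_steps
    steps, then jump straight to high_weight."""
    return warmup_weight if step < anneal_steps else high_weight


def combined_loss(erm_loss, penalty, weight):
    """ERM loss plus weighted penalty; when the weight exceeds 1, divide the
    total by the weight to keep the overall gradient scale comparable."""
    loss = erm_loss + weight * penalty
    if weight > 1.0:
        loss /= weight
    return loss


# Hypothetical values: 100 warm-up steps, then a weight of 10000.
for step in (0, 99, 100, 500):
    w = step_penalty_weight(step, anneal_steps=100, high_weight=10000.0)
    print(step, w, combined_loss(0.69, 0.01, w))
```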

How to run

You can run the provided notebook in Google Colaboratory.

Alternatively, you can run main.py locally. main.py also includes an ERM implementation if you want to run a baseline. The code depends on PyTorch.

invariant-risk-minimization's People

Contributors

reiinakano

invariant-risk-minimization's Issues

Question about the labels in Colored MNIST

Hi reiinakano, this is fantastic work. I am trying to reproduce IRM with your code, but I have a question about the labels in Colored MNIST. Why do you assign a binary label y to each image based on the digit (y = 0 if digit < 5), turning the task into a binary classification problem? It seems the code cannot handle multi-class classification, that is, it cannot recognize the specific digit in each image. Am I misunderstanding the code?
I am looking forward to your reply.

Why full-batch GD and not SGD?

Hi, I was wondering why we have to use full-batch GD rather than SGD.
I've seen the same approach in the paper and the original repo. Any ideas?
If I turn the implementation here into minibatch SGD, autograd complains and throws errors. Any insights?
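One relevant detail for minibatch training: squaring a penalty gradient computed on a single minibatch is a biased estimate, because E[g²] equals the squared full-data gradient plus the variance of g. The IRM paper's example implementation instead splits each minibatch into two halves, computes the gradient on each, and multiplies them; for independent halves, E[g1 * g2] is the squared gradient. The sketch below illustrates this with the closed-form BCE gradient in plain Python rather than autograd, using a synthetic population of logits and labels; it does not diagnose the autograd errors mentioned in the issue, and all names are hypothetical.

```python
import math
import random


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def mean_grad_w(logits, labels):
    """Derivative of the mean BCE loss of w * logits with respect to w, at w = 1.0."""
    return sum(z * (sigmoid(z) - y) for z, y in zip(logits, labels)) / len(logits)


def minibatch_penalty(logits, labels):
    """Split-batch estimate of the squared gradient: multiply the gradients
    computed on the even- and odd-indexed halves of the minibatch."""
    g1 = mean_grad_w(logits[0::2], labels[0::2])
    g2 = mean_grad_w(logits[1::2], labels[1::2])
    return g1 * g2


def compare_estimators(seed=0, pop_size=2000, batch_size=8, trials=20000):
    """Compare the split estimate with naively squaring each minibatch
    gradient on a synthetic population of logits and binary labels.

    Returns (full-population penalty, mean split estimate, mean naive estimate).
    """
    rng = random.Random(seed)
    pop_z = [rng.gauss(0.0, 1.0) for _ in range(pop_size)]
    pop_y = [rng.randint(0, 1) for _ in range(pop_size)]
    target = mean_grad_w(pop_z, pop_y) ** 2
    split_sum = naive_sum = 0.0
    for _ in range(trials):
        # Sampling with replacement keeps the two halves independent.
        idx = rng.choices(range(pop_size), k=batch_size)
        z = [pop_z[i] for i in idx]
        y = [pop_y[i] for i in idx]
        split_sum += minibatch_penalty(z, y)
        naive_sum += mean_grad_w(z, y) ** 2
    return target, split_sum / trials, naive_sum / trials


print(compare_estimators())
```

On average the split estimate tracks the full-population penalty, while the naive squared estimate overshoots it by roughly the minibatch gradient's variance, which grows as batches shrink.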