ozzie00 / deterministic-uncertainty-quantification

This project forked from y0ast/deterministic-uncertainty-quantification

Code for "Uncertainty Estimation Using a Single Deep Deterministic Neural Network"

Home Page: https://arxiv.org/abs/2003.02037

License: MIT License

Python 71.70% Jupyter Notebook 28.30%

Deterministic Uncertainty Quantification (DUQ)

This repo contains the code for Uncertainty Estimation Using a Single Deep Deterministic Neural Network, which is accepted for publication at ICML 2020.

If the code or the paper has been useful in your research, please add a citation to our work:

@inproceedings{van2020uncertainty,
  title={Uncertainty Estimation Using a Single Deep Deterministic Neural Network},
  author={van Amersfoort, Joost and Smith, Lewis and Teh, Yee Whye and Gal, Yarin},
  booktitle={International Conference on Machine Learning},
  year={2020}
}

Dependencies

The code is based on PyTorch and requires a few further dependencies, listed in environment.yml. The code was tested with the versions specified in the environment file, but should work with newer versions as well (except for ignite=0.4.3). If you find an incompatibility, please let me know and I'll gladly update the code for the newest version of each library.

Datasets

Most datasets will be downloaded on the fly by Torchvision. Only NotMNIST needs to be downloaded in advance in a subfolder called data/:

mkdir -p data && cd data && curl -O "http://yaroslavvb.com/upload/notMNIST/notMNIST_small.mat"
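Before starting a run that evaluates on NotMNIST, it can help to confirm that the file ended up where the command above puts it. The following is a minimal sketch, not part of the repo: the helper name find_notmnist is hypothetical, and the only thing it assumes is the data/notMNIST_small.mat location given above.

```python
from pathlib import Path


def find_notmnist(root="data"):
    """Return the path to notMNIST_small.mat under `root`, or raise if it is missing.

    Hypothetical helper for illustration; the training scripts do not call it.
    """
    path = Path(root) / "notMNIST_small.mat"
    if not path.is_file():
        raise FileNotFoundError(
            f"{path} not found; download it into '{root}/' before running"
        )
    return path
```

Calling find_notmnist() from the repository root either returns the path or fails early with a message that points at the missing download.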

FastFashionMNIST is based on this script. The default Torchvision implementation first converts each sample to a PIL image (see here), which creates a CPU bottleneck when training on a GPU. The FastFashionMNIST class provides a significant speedup.

Running

The Two Moons experiments can be replicated using the Two Moons notebook. The FashionMNIST experiment is implemented in train_duq_fm.py. For both experiments, the paper's defaults are hardcoded and can be changed in place.

The ResNet18-based CIFAR experiments are implemented in train_duq_cifar.py. All command line parameter defaults are as listed in the experimental details in Appendix A of the paper. We additionally include a Wide ResNet-based architecture.

For example, to train on CIFAR-10 with a gradient penalty of weight 0.5, using the full training set:

python train_duq_cifar.py --final_model --l_gradient_penalty 0.5

Note that omitting --final_model causes 20% of the training data to be held out for validation, so that hyperparameter selection can be done in a responsible manner. The code also supports the Wide ResNet with --architecture WRN.
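To make the 80/20 hold-out concrete, here is a minimal sketch of such a split over example indices. It is an illustration only: the function name train_val_split, the seed, and the use of a shuffled split are assumptions, and the script may select its validation set differently.

```python
import random


def train_val_split(num_examples, val_fraction=0.2, seed=0):
    """Return (train_indices, val_indices) for a shuffled hold-out split.

    Illustrative helper; names and seeding are assumptions, not the repo's code.
    """
    indices = list(range(num_examples))
    # A dedicated Random instance keeps the split reproducible
    # without touching the global random state.
    random.Random(seed).shuffle(indices)
    num_val = round(num_examples * val_fraction)
    return indices[num_val:], indices[:num_val]


# The CIFAR-10 training set has 50,000 images.
train_idx, val_idx = train_val_split(50_000)
print(len(train_idx), len(val_idx))  # 40000 10000
```

Fixing the seed matters here: comparing hyperparameters is only meaningful if every run is validated on the same held-out examples.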

I also include code for my implementation of Deep Ensembles. It's a very simple implementation that achieves good results (95% accuracy in 75 epochs using 5 models).

python train_deep_ensemble.py --dataset CIFAR10

This command will train a Deep Ensemble consisting of 5 models (the default) on CIFAR10.
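For readers new to Deep Ensembles, the sketch below shows one common way to combine the members at prediction time: average the per-model softmax outputs and use the entropy of the averaged distribution as an uncertainty score. It is written in plain Python for illustration; the function names are hypothetical and it is not the code in train_deep_ensemble.py.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_predict(member_logits):
    """Combine the members' outputs for a single input.

    member_logits: one list of class logits per ensemble member.
    Returns (mean_probs, predicted_class, entropy).
    """
    probs = [softmax(logits) for logits in member_logits]
    num_members = len(probs)
    num_classes = len(probs[0])
    # Average the predicted probabilities (not the logits) across members.
    mean_probs = [
        sum(p[c] for p in probs) / num_members for c in range(num_classes)
    ]
    predicted_class = max(range(num_classes), key=mean_probs.__getitem__)
    # Entropy of the averaged distribution: higher means less certain.
    entropy = -sum(p * math.log(p) for p in mean_probs if p > 0)
    return mean_probs, predicted_class, entropy
```

With five members that agree on a class, the entropy stays low; when members disagree, the averaged distribution flattens and the entropy rises, which is what makes it usable as an out-of-distribution score.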

Questions

For questions about the code or the paper, feel free to open an issue or email me directly. My email can be found on my GitHub profile, my website and the paper above.

Deep Ensembles vs DUQ
