This project forked from gpleiss/temperature_scaling

A simple way to calibrate your neural network.

License: MIT License

Temperature Scaling

A simple way to calibrate your neural network. The temperature_scaling.py module can be used to calibrate any trained model.

Based on results from On Calibration of Modern Neural Networks.

Motivation

TLDR: Neural networks tend to output overconfident probabilities. Temperature scaling is a post-processing method that fixes it.

Long:

Neural networks output "confidence" scores along with predictions in classification. Ideally, these confidence scores should match the true correctness likelihood. For example, if we assign 80% confidence to 100 predictions, then we'd expect that 80% of the predictions are actually correct. If this is the case, we say the network is calibrated.

A simple way to visualize calibration is plotting accuracy as a function of confidence. Since confidence should reflect accuracy, we'd like for the plot to be an identity function. If accuracy falls below the main diagonal, then our network is overconfident. This happens to be the case for most neural networks, such as this ResNet trained on CIFAR100.

[Figure: Uncalibrated ResNet]
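The accuracy-versus-confidence comparison described above can be sketched in a few lines of plain Python. This is only an illustration: the helper names `reliability_bins` and `expected_calibration_error` are hypothetical and not part of this repository, and expected calibration error (ECE) is a commonly used summary metric rather than something this README defines.

```python
def reliability_bins(confidences, correct, n_bins=10):
    """Group predictions into equal-width confidence bins.

    Returns (mean_confidence, accuracy, count) for each non-empty bin.
    """
    sums = [[0.0, 0, 0] for _ in range(n_bins)]  # [confidence sum, num correct, count]
    for conf, ok in zip(confidences, correct):
        # A confidence of exactly 1.0 is placed in the last bin.
        idx = min(int(conf * n_bins), n_bins - 1)
        sums[idx][0] += conf
        sums[idx][1] += int(ok)
        sums[idx][2] += 1
    return [(s / n, c / n, n) for s, c, n in sums if n > 0]


def expected_calibration_error(confidences, correct, n_bins=10):
    """Count-weighted average of |accuracy - confidence| over the bins."""
    total = len(confidences)
    return sum(n / total * abs(acc - conf)
               for conf, acc, n in reliability_bins(confidences, correct, n_bins))


# Toy data: 10 predictions made at 80% confidence, only 6 of them correct.
confidences = [0.8] * 10
correct = [True] * 6 + [False] * 4
for conf, acc, n in reliability_bins(confidences, correct):
    print(f"confidence={conf:.2f} accuracy={acc:.2f} count={n}")
print(f"ECE={expected_calibration_error(confidences, correct):.2f}")
```

In the toy data the accuracy (0.6) falls below the confidence (0.8), which is the overconfident case: the point lies below the main diagonal.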

Temperature scaling is a post-processing technique to make neural networks calibrated. After temperature scaling, you can trust the probabilities output by a neural network:

[Figure: Calibrated ResNet]

Temperature scaling divides the logits (the inputs to the softmax function) by a learned scalar parameter T. That is,

softmax(z)_k = e^(z_k/T) / sum_i e^(z_i/T)

where z is the vector of logits, z_k is the logit for class k, and T is the learned temperature. T is learned on a validation set by minimizing the negative log-likelihood (NLL).
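As a rough illustration of the formula, the sketch below applies a temperature to a list of logits and then picks T by minimizing validation NLL. It uses only the standard library. The function names are hypothetical, and the grid search is a stand-in chosen for simplicity; temperature_scaling.py may well fit T with a different optimizer.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over logits divided by a temperature T > 0."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def nll(logit_rows, labels, temperature):
    """Mean negative log-likelihood of the true labels at a given temperature."""
    losses = [-math.log(softmax_with_temperature(row, temperature)[y])
              for row, y in zip(logit_rows, labels)]
    return sum(losses) / len(losses)


def fit_temperature(logit_rows, labels, candidates=None):
    """Pick the candidate T with the lowest validation NLL (plain grid search)."""
    if candidates is None:
        candidates = [0.05 * k for k in range(1, 201)]  # 0.05, 0.10, ..., 10.0
    return min(candidates, key=lambda t: nll(logit_rows, labels, t))


# Toy validation set: very confident logits, but one prediction in four is wrong.
val_logits = [[5.0, 0.0], [5.0, 0.0], [5.0, 0.0], [5.0, 0.0]]
val_labels = [0, 0, 0, 1]
T = fit_temperature(val_logits, val_labels)
print("learned T:", round(T, 2))
print("before:", softmax_with_temperature(val_logits[0]))
print("after: ", softmax_with_temperature(val_logits[0], T))
```

Dividing every logit by the same positive T does not change which class has the largest probability, so the predictions themselves are unaffected; only the confidence attached to them changes. A learned T above 1 softens overconfident outputs.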

Demo

First train a DenseNet on CIFAR100, and save the validation indices:

python train.py --data <path_to_data> --save <save_folder_dest>

Then temperature scale it:

python demo.py --data <path_to_data> --save <save_folder_dest>

To use in a project

Copy the file temperature_scaling.py to your repo. Train a model, and save the validation set. (You must use the same validation set for temperature scaling as you used during training.) You can do something like this:

from temperature_scaling import ModelWithTemperature

orig_model = ... # create an uncalibrated model somehow
valid_loader = ... # Create a DataLoader from the SAME VALIDATION SET used to train orig_model

scaled_model = ModelWithTemperature(orig_model)
scaled_model.set_temperature(valid_loader)
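One way to make sure the same validation set is reused is to persist the validation indices at training time and reload them before calibration. The sketch below is an assumption about how that could be done, not a description of what train.py does; the helper names and the file name `valid_indices.json` are hypothetical.

```python
import json
import os
import random
import tempfile


def split_indices(num_examples, valid_size, seed=0):
    """Shuffle example indices deterministically and split into (train, valid)."""
    indices = list(range(num_examples))
    random.Random(seed).shuffle(indices)
    return indices[valid_size:], indices[:valid_size]


def save_indices(path, indices):
    """Write a list of integer indices to a JSON file."""
    with open(path, "w") as f:
        json.dump(indices, f)


def load_indices(path):
    """Read back a list of indices written by save_indices."""
    with open(path) as f:
        return json.load(f)


# At training time: split once and store the validation indices.
# (A temporary directory is used here only to keep the demo self-contained.)
train_idx, valid_idx = split_indices(num_examples=100, valid_size=10)
path = os.path.join(tempfile.mkdtemp(), "valid_indices.json")
save_indices(path, valid_idx)

# At calibration time: reload exactly the same indices.
reloaded = load_indices(path)
print(reloaded == valid_idx)
```

In PyTorch, a list of indices like this can be passed to torch.utils.data.Subset to build the dataset behind valid_loader.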
