Confidence of Out-of-Distribution for TensorFlow

This is an unofficial TensorFlow implementation of, and experiment with, estimating the confidence of neural network predictions, from "Learning Confidence for Out-of-Distribution Detection in Neural Networks".

*(Figure: overview of the work)*

*(Figure: results of the work)*

How to measure the confidence of a network is an interesting topic, and it has been attracting more and more attention recently; see, for example, the Uncertainty and Robustness in Deep Learning workshop at ICML 2019.

In this paper, the authors propose a confidence branch that estimates a confidence score in order to perform out-of-distribution detection.
This confidence can also be used to estimate the uncertainty of the model.

In brief, the idea of this paper is intuitive and easy to apply:

  • Simple and intuitive.
  • Quick proof of concept!
  • Very little additional computation required.

However, in my opinion, there are still some issues that should be discussed and explored.
The authors have proposed some tricks to tackle these issues, and I am still working on reproducing them.

  • The hyper-parameters are somewhat too sensitive, even when using the budget proposed in the paper.
  • Incompatible with highly accurate models, due to insufficient negative samples.

All in all, if you are looking for a way to estimate the uncertainty of a model, this work is worth a try because it won't take much time to test.

Environment

This code is implemented and tested with TensorFlow 1.13.0.
It does not use any special operators, so it should also work with other versions of TensorFlow.

In a Nutshell

Just run a simple example of a classification problem on the MNIST dataset:

```bash
python run_example.py
```

The code runs a very simple fully-connected model with a relatively small branch network to estimate confidence.
You can also turn off the confidence branch by setting WITH_CONF to False in run_example.py.

To check out how this works in TensorBoard:

```bash
tensorboard --logdir=logs
```

Usage

  1. Set up the parameters of the confidence branch:

```python
from ood._conf.conf_net import base_conf

conf_admin = base_conf.BaseConfNet(lambda1=1.0)
```

  2. While building your own network, branch out at the penultimate layer:

```python
confidence = conf_admin.forward(hidden)
```

  3. During training, hint your outputs with the value of the confidence:

```python
conf_admin.hinting(outputs, labels)
```

  4. Add the confidence penalty to your final losses:

```python
losses += conf_admin.calculate_loss_conf()
```
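
Putting the four steps together, here is a minimal end-to-end sketch of how they might fit into a TF 1.x graph. The backbone, the placeholder names, and the assumption that hinting returns the interpolated outputs are my own illustration, not guaranteed by the repository:

```python
import tensorflow as tf
from ood._conf.conf_net import base_conf

# Illustrative MNIST-style placeholders (names are assumptions).
images = tf.placeholder(tf.float32, [64, 784])
labels = tf.placeholder(tf.float32, [64, 10])

# 1. Set up the confidence branch.
conf_admin = base_conf.BaseConfNet(lambda1=1.0)

# A simple fully-connected backbone.
hidden = tf.layers.dense(images, 128, activation=tf.nn.relu)
outputs = tf.nn.softmax(tf.layers.dense(hidden, 10))

# 2. Branch out at the penultimate layer.
confidence = conf_admin.forward(hidden)

# 3. Hint the outputs with the estimated confidence
#    (assuming hinting returns the interpolated predictions).
hinted = conf_admin.hinting(outputs, labels)

# 4. Task loss on the hinted outputs, plus the confidence penalty.
losses = tf.losses.log_loss(labels, hinted)
losses += conf_admin.calculate_loss_conf()

train_op = tf.train.AdamOptimizer(1e-3).minimize(losses)
```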

Implementation Details

The outline of the approach is illustrated in the example.

*(Figure: outline of the approach)*

Small and simple branch network

The sub-net is branched out from the penultimate layer, as shown in the figure.
The authors use a very lightweight sub-net with a small fully-connected layer, adding negligible extra computation.
However, other network structures also work.
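
As a rough sketch (not the repository's exact layers), such a branch can be as small as one hidden layer followed by a sigmoid output:

```python
import tensorflow as tf

def confidence_branch(penultimate, scope="confidence"):
    """Tiny sub-net mapping penultimate features to a single
    confidence score in (0, 1). Illustrative sizes only."""
    with tf.variable_scope(scope):
        x = tf.layers.dense(penultimate, 64, activation=tf.nn.relu)
        # Sigmoid keeps the confidence in (0, 1).
        return tf.layers.dense(x, 1, activation=tf.nn.sigmoid)
```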

The interpolation for hinting

The paper uses a simple linear interpolation function to give hints to the model.
In my opinion, different problems should use different interpolation functions to generate smooth interpolated results.
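
For reference, the linear rule from the paper interpolates between the prediction and the one-hot label, so lower confidence means a stronger hint:

```python
# c in (0, 1): c = 1 keeps the prediction, c = 0 copies the label.
hinted_outputs = confidence * outputs + (1.0 - confidence) * labels
```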

The weight of confidence penalty

The weight of the confidence loss (penalty) is very critical and sensitive. With a large penalty, the model will avoid outputting low-confidence predictions.
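
Concretely, the confidence penalty from the paper is the negative log of the predicted confidence, and lambda1 is the weight in question (the epsilon below is my own addition for numerical stability):

```python
# Confidence penalty: -log(c), averaged over the batch.
loss_conf = tf.reduce_mean(-tf.log(confidence + 1e-12))

# Total loss: the larger lambda1 is, the more the model is
# pushed toward always predicting confidence close to 1.
total_loss = task_loss + lambda1 * loss_conf
```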

Half batch hinting

As proposed in the paper, hinting is applied to only half of each batch.
Activate this feature by setting half_batch to True, and use a specific (not None) batch size.
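
A minimal sketch of the idea, assuming a static batch size (the slicing below is my own illustration, not the repository's exact code):

```python
# Hint only the first half of the batch; the second half keeps the
# raw predictions. Requires a fixed (non-None) batch size.
half = batch_size // 2
hinted_half = (confidence[:half] * outputs[:half]
               + (1.0 - confidence[:half]) * labels[:half])
hinted_outputs = tf.concat([hinted_half, outputs[half:]], axis=0)
```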

Budget for auto-tunning

The authors propose a budget in [0.1, 1] for auto-tuning the weight of the confidence penalty:
if the confidence loss is greater than the budget, increase the weight; otherwise, decrease it.

The code for using the budget will be updated soon.
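
In the meantime, a sketch of the rule described above (the multiplicative factor is an illustrative choice, not a value from this repository) could run on the Python side each training step:

```python
# beta: budget in [0.1, 1]; loss_conf_value: the confidence loss
# evaluated for the current step (e.g. via sess.run).
if loss_conf_value > beta:
    lambda1 /= 0.99  # too many hints requested: penalize more
else:
    lambda1 *= 0.99  # under budget: relax the penalty
```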

Experimental results

In brief, the results on MNIST look similar to those in the paper.

*(Figure: results of a moderately accurate model)*

However, building a high-accuracy model degrades the OOD detection performance a little, which is caused by insufficient negative samples. I'm still working on this issue.

*(Figure: results of a highly accurate model)*

More experimental results on the MNIST dataset will be updated soon.

Although a confidence score for the classification problem can also be obtained from the softmax values, I found this approach works in a more elegant manner and is quite simple.

By the way, I have also conducted experiments on other tasks, such as 3D hand tracking.

TODO

  1. Update the budget for auto-tuning the weight of the confidence penalty.
  2. Detailed analysis on MNIST.
  3. The issue of insufficient negative samples.

Please let me know if you have any suggestions; I'll be very grateful.

Contact & Copyright

Code work by Jia-Yau Shiau [email protected].
