ova-l2d

Rajeev Verma and Eric Nalisnick. "Calibrated Learning to Defer with One-vs-All Classifiers." ICML 2022. arXiv: https://arxiv.org/abs/2202.03673

In this paper, we propose an alternative One-vs-All (OvA) parameterization of the loss function for the Learning to Defer (L2D) problem. The proposed loss is a consistent surrogate for the 0-1 L2D misclassification loss and yields better-calibrated confidence estimates of expert correctness.
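For reference, with per-class classifier logits g_1, ..., g_K, a rejector logit g_perp, expert prediction m, and a binary surrogate phi (e.g. the logistic loss phi[z] = log(1 + e^{-z})), the proposed OvA surrogate has roughly the form below; see the paper for the exact statement and for how the --alpha scaling used in this implementation enters.

\psi_{\mathrm{OvA}}(g_1,\dots,g_K,g_\perp; x, y, m) = \phi[g_y(x)] + \sum_{y' \neq y} \phi[-g_{y'}(x)] + \phi[-g_\perp(x)] + \mathbb{1}[m = y]\,\big(\phi[g_\perp(x)] - \phi[-g_\perp(x)]\big)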

Setup

First, set up a conda environment; we provide the environment file defer.yml. Next, create a directory ./Data and place the datasets in it.
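For example, the setup might look like the following (a minimal sketch; the environment name is defined inside defer.yml, so the name "defer" used below is an assumption):

  conda env create -f defer.yml   # create the environment from the provided yml file
  conda activate defer            # "defer" is an assumed name; use the name defined in defer.yml
  mkdir -p ./Data                 # datasets go here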

Starter Guide

The implementation of the loss function(s) is available in ./losses. The main training script is ./main.py; the usage guide is below, followed by an example invocation. Note in particular the flag --loss_type, which can be set to softmax or ova. Trained models are saved in a subdirectory named after the loss_type inside the ckp_dir directory.

usage: main.py [-h] [--batch_size BATCH_SIZE] [--alpha ALPHA] [--epochs EPOCHS] [--patience PATIENCE] [--expert_type EXPERT_TYPE]
               [--n_classes N_CLASSES] [--k K] [--lr LR] [--weight_decay WEIGHT_DECAY] [--warmup_epochs WARMUP_EPOCHS]
               [--loss_type LOSS_TYPE] [--ckp_dir CKP_DIR] [--experiment_name EXPERIMENT_NAME]

optional arguments:
  -h, --help            show this help message and exit
  --batch_size BATCH_SIZE
  --alpha ALPHA         scaling parameter for the loss function, default=1.0.
  --epochs EPOCHS
  --patience PATIENCE   number of patience steps before early stopping of training.
  --expert_type EXPERT_TYPE
                        specify the expert type. For the available expert types, see models -> experts. default=predict.
  --n_classes N_CLASSES
                        K for K-class classification.
  --k K
  --lr LR               learning rate.
  --weight_decay WEIGHT_DECAY
  --warmup_epochs WARMUP_EPOCHS
  --loss_type LOSS_TYPE
                        surrogate loss type for learning to defer.
  --ckp_dir CKP_DIR     directory name to save the checkpoints.
  --experiment_name EXPERIMENT_NAME
                        specify the experiment name. Checkpoints will be saved with this name.
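For example, a training run with the OvA loss might look like the following (the flag values are illustrative, not prescribed settings):

  python main.py --loss_type ova --n_classes 10 --batch_size 128 --epochs 200 --lr 0.1 --ckp_dir ./checkpoints --experiment_name ova_run

Following the convention described above, checkpoints from this run would be written under ./checkpoints/ova/.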

Citation

@inproceedings{Verma2022Calibrated,
  title = {Calibrated Learning to Defer with One-vs-All Classifiers},
  author = {Verma, Rajeev and Nalisnick, Eric},
  booktitle = {Proceedings of the 39th International Conference on Machine Learning (ICML)},
  year = {2022}
}

Acknowledgements

As with everything else, the code in this repository builds on the excellent work of other researchers. We gratefully acknowledge Hussein Mozannar and David Sontag: their code for the paper Consistent Estimators for Learning to Defer (https://github.com/clinicalml/learn-to-defer) formed the basis of this repository. We also use code from Nastaran Okati et al. (https://github.com/Networks-Learning/differentiable-learning-under-triage) and from Matthijs Hollemans (https://github.com/hollance/reliability-diagrams).
