
aum's Introduction

AUM

PyTorch library for Area Under the Margin (AUM) Ranking, as proposed in the paper: Identifying Mislabeled Data using the Area Under the Margin Ranking

Install

pip install -U aum

Usage

Instantiate an AUMCalculator object:

from aum import AUMCalculator

save_dir = '~/Desktop'
aum_calculator = AUMCalculator(save_dir, compressed=True)

Note: you can set compressed to False if you want to store the AUM metrics at every call to the update method. This will require considerably more space, however.

You can then update aum rankings on batches of data during training with:

model.train()
for batch in loader:
    inputs, targets, sample_ids = batch

    logits = model(inputs)

    records = aum_calculator.update(logits, targets, sample_ids)

    ...

records is a dictionary mapping a sample_id to an AUMRecord containing the information below, including the AUM for the sample at this point in time.

from dataclasses import dataclass
from typing import Union

@dataclass
class AUMRecord:
    """
    Class for holding info around an aum update for a single sample
    """
    sample_id: Union[int, str]  # an id may be an int or a str
    num_measurements: int
    target_logit: int
    target_val: float
    other_logit: int
    other_val: float
    margin: float
    aum: float

And once you are done training, you can generate a csv of ranked samples with their aum scores with:

aum_calculator.finalize()

If you have a dataset that does not return sample_ids, you can wrap it in DatasetWithIndex. The last element of the tuple returned for a given sample will be its sample_id.

from aum import DatasetWithIndex
from torch.utils.data import Dataset

my_dataset = Dataset(...)
my_dataset_with_index = DatasetWithIndex(my_dataset)
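To make the "last element is the sample_id" convention concrete, here is a minimal pure-Python sketch of a wrapper that behaves this way. The class name IndexedDataset and the toy list are hypothetical and used only for illustration; this is not the library's DatasetWithIndex implementation, and it assumes the integer index itself serves as the sample_id.

```python
class IndexedDataset:
    """Hypothetical stand-in for illustration, not the library's DatasetWithIndex."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample = self.dataset[index]
        # Normalise to a tuple so the index can always be appended last.
        if not isinstance(sample, tuple):
            sample = (sample,)
        return sample + (index,)


# A plain list of (input, target) pairs stands in for a real dataset.
toy_dataset = [([0.1, 0.2], 0), ([0.3, 0.4], 1)]
indexed = IndexedDataset(toy_dataset)

# The unpacking mirrors the training loop: inputs, targets, sample_ids.
inputs, target, sample_id = indexed[1]
```

With a wrapper of this shape, each item unpacks into the same three names used in the training loop above.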

Example Outputs

Calling finalize() on an AUMCalculator creates one or two csv files, depending on whether compressed was set to True or False.

If AUMCalculator was instantiated with compressed = True, you will find a csv file titled aum_values.csv in the following format:

sample_id aum
sample_1 1.205
sample_3 1.145
sample_2 -3.785
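Since the paper uses the AUM ranking to identify mislabeled data, a natural next step is to look at the samples with the lowest AUM. The sketch below does this with the standard csv module. It assumes the file is comma-separated with a sample_id,aum header row; the helper name lowest_aum is hypothetical, and an in-memory string stands in for the real file.

```python
import csv
import io

# Assumed file contents: comma-separated, using the example values above.
csv_text = "sample_id,aum\nsample_1,1.205\nsample_3,1.145\nsample_2,-3.785\n"


def lowest_aum(csv_file, k):
    """Return the k (sample_id, aum) pairs with the smallest AUM."""
    rows = [(row["sample_id"], float(row["aum"])) for row in csv.DictReader(csv_file)]
    return sorted(rows, key=lambda pair: pair[1])[:k]


# In practice, pass an open handle to aum_values.csv instead of StringIO.
suspects = lowest_aum(io.StringIO(csv_text), k=1)
print(suspects)  # [('sample_2', -3.785)]
```

How many low-AUM samples to inspect or discard is a separate decision that this sketch does not address.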

If AUMCalculator was instantiated with compressed = False, you will find a csv file titled full_aum_records.csv in addition to the aum_values.csv. full_aum_records.csv is in the following format:

sample_id num_measurements target_logit target_val other_logit other_val margin aum
sample_1 1 0 3.74 10 2.48 1.26 1.26
sample_1 2 0 4.59 10 3.44 1.15 1.205
sample_2 1 1 -0.09 0 3.11 -3.20 -3.20
sample_2 2 1 -1.12 0 3.25 -4.37 -3.785
sample_3 1 6 3.39 10 1.62 1.77 1.77
sample_3 2 6 2.63 2 2.11 0.52 1.145
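The columns in this table appear to be related as follows: margin is target_val minus other_val, and aum is the mean of the margins recorded so far for that sample. The sketch below reproduces the sample_1 rows under that reading. The function names are hypothetical, and taking other_val to be the largest non-target logit is an assumption based on the paper's definition of the margin, not on the library's code.

```python
def margin(logits, target):
    """Target logit value minus the largest other logit value (assumed definition)."""
    target_val = logits[target]
    other_val = max(v for i, v in enumerate(logits) if i != target)
    return target_val - other_val


def running_aum(margins):
    """Mean of the margins observed so far, reported after each measurement."""
    total = 0.0
    history = []
    for count, m in enumerate(margins, start=1):
        total += m
        history.append(total / count)
    return history


# Toy 3-class logits whose target (class 0) and runner-up values match
# the first sample_1 row: 3.74 - 2.48 is roughly 1.26.
first_margin = margin([3.74, 0.50, 2.48], target=0)

# Margins of sample_1 at its two measurements, taken from the table.
aums = running_aum([1.26, 1.15])  # roughly [1.26, 1.205]
```

The final entry of the running mean corresponds to the value reported for sample_1 in aum_values.csv.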

Replicate results from the paper

To replicate results, please refer to the examples/paper_replication section.

Example usage

For a more basic example of using the AUMCalculator and DatasetWithIndex in a training script, please refer to the examples/cifar100 section.

Cite

@article{pleiss2020identifying,
  title={Identifying Mislabeled Data using the Area Under the Margin Ranking},
  author={Geoff Pleiss and Tianyi Zhang and Ethan R. Elenberg and Kilian Q. Weinberger},
  journal={arXiv preprint arXiv:2001.10528},
  year={2020}
}

aum's Issues

ERROR: No matching distribution found for aum

Hi. I am trying to install aum on ubuntu, using pip install -U aum. However, I am getting this error:
ERROR: Could not find a version that satisfies the requirement aum (from versions: none)
ERROR: No matching distribution found for aum

I have torch: 1.9.1, torchvision: 0.10.1, pandas: 1.1.5.
Do you have any idea why this is happening?

Typo: `threshold_samples_set_idx` is not a bool

In examples/paper_replication/runner.py:
:param bool threshold_samples_set_idx: (default 1) Which set of threshold samples to use (based on index)
However in examples/paper_replication/small_dataset_aum.sh, the value of threshold_samples_set_idx was provided as 2.

Milestones

Can the milestones generated for producing the results be made available?
