
reds-lab / lava

This is an official repository for "LAVA: Data Valuation without Pre-Specified Learning Algorithms" (ICLR2023).

Home Page: https://openreview.net/forum?id=JJuP86nBl4q

License: MIT License

lava's Introduction

LAVA: Data Valuation without Pre-Specified Learning Algorithms

Python 3.8.10

This repository is the official implementation of the "LAVA: Data Valuation without Pre-Specified Learning Algorithms" (ICLR 2023).

We propose LAVA, a novel model-agnostic framework for data valuation based on a non-conventional, class-wise Wasserstein discrepancy. We further introduce an efficient way to measure each datapoint's contribution at no extra cost, directly from the solution of the optimization problem.
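To make "class-wise Wasserstein discrepancy" concrete, here is a minimal sketch of an OTDD-style ground cost between labeled points, assuming p = 2: the feature distance is combined with a Wasserstein distance between the feature distributions of the two classes. All function names are hypothetical, and the brute-force solver is exact only for tiny, equal-sized class samples; it is not how LAVA or OTDD compute these distances in practice.

```python
import itertools
import math


def euclidean(a, b):
    """Euclidean distance between two feature vectors."""
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))


def exact_w2(xs, ys):
    """Exact 2-Wasserstein distance between two equal-size uniform samples.

    With uniform weights and equal sizes, some optimal coupling is a permutation,
    so enumerating permutations is exact (only feasible for a handful of points).
    """
    if len(xs) != len(ys) or not xs:
        raise ValueError("expects two non-empty samples of equal size")
    n = len(xs)
    best = min(
        sum(euclidean(xs[i], ys[perm[i]]) ** 2 for i in range(n))
        for perm in itertools.permutations(range(n))
    )
    return math.sqrt(best / n)


def label_distances(train_by_class, valid_by_class):
    """W2 distance between every pair of class-conditional feature samples."""
    return {
        (ya, yb): exact_w2(xa, xb)
        for ya, xa in train_by_class.items()
        for yb, xb in valid_by_class.items()
    }


def ground_cost(x, y, x2, y2, label_dist):
    """Cost between labeled points (x, y) and (x2, y2): feature term plus label term."""
    return math.sqrt(euclidean(x, x2) ** 2 + label_dist[(y, y2)] ** 2)


# Usage: two classes whose features are 4 units apart along the first axis.
train = {0: [(0.0, 0.0), (0.0, 1.0)], 1: [(4.0, 0.0), (4.0, 1.0)]}
valid = {0: [(0.0, 0.0), (0.0, 1.0)], 1: [(4.0, 0.0), (4.0, 1.0)]}
label_dist = label_distances(train, valid)
print(label_dist[(0, 1)])                                      # 4.0
print(ground_cost((0.0, 0.0), 0, (0.0, 0.0), 1, label_dist))  # 4.0
```

The label term is what makes the discrepancy class-wise: a training point labeled 1 stays far from a validation point labeled 0 even when their features coincide.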

Limitations of traditional data valuation methods

Traditional data valuation methods assume knowledge of the underlying learning algorithm.

  • The learning algorithm is often unknown before valuation
  • Stochastic training makes the resulting values unstable
  • Repeated model training imposes a heavy computational burden

Data Valuation via Optimal Transport

We propose data valuation via optimal transport (OT) to replace current data valuation frameworks, which rely on the underlying learning algorithm.

[Figure: LAVA_OT_Valuation]

Strong analytical properties of OT:

  • well-defined distance metric
  • computationally tractable
  • computable from finite samples
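The last property is easy to see in one dimension. As a minimal sketch (the function name is hypothetical, and LAVA itself works with high-dimensional embeddings via OTDD), the exact W_p distance between two equal-size empirical distributions with uniform weights is obtained by sorting both samples and matching them in order:

```python
def wasserstein_1d(xs, ys, p=1):
    """Exact W_p between two equal-size 1-D empirical distributions.

    In one dimension the optimal coupling pairs the i-th smallest point of one
    sample with the i-th smallest point of the other.
    """
    if len(xs) != len(ys) or not xs:
        raise ValueError("expects two non-empty samples of equal size")
    total = sum(abs(a - b) ** p for a, b in zip(sorted(xs), sorted(ys)))
    return (total / len(xs)) ** (1.0 / p)


a = [2.0, 0.0, 1.0]
b = [1.0, 3.0, 2.0]
c = [5.0, 3.0, 4.0]
print(wasserstein_1d(a, b))  # 1.0: b is a shifted by one unit
print(wasserstein_1d(a, a))  # 0.0
# Triangle inequality, one of the metric properties listed above.
print(wasserstein_1d(a, c) <= wasserstein_1d(a, b) + wasserstein_1d(b, c))  # True
```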

LAVA: Individual Datapoint Valuation

To compute the value of an individual datapoint, we propose the notion of a calibrated gradient, which measures how sensitive the dataset distance is to that datapoint by shifting probability mass onto it in the dual OT formulation.

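$$\mathrm{Value}(z_i) = \frac{\partial\,\mathrm{OT}(\mu_t,\mu_v)}{\partial\mu_t(z_i)} = f_i^* - \sum_{j\in\{1,\dots,N\}\setminus\{i\}} \frac{f_j^*}{N-1}$$

where μ_t and μ_v are the training and validation distributions, N is the number of training points, and f^* is the optimal dual potential associated with μ_t.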
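One way to see where the subtracted average comes from, as a sketch assuming the optimal dual potential f^* is unique up to an additive constant (so an envelope argument applies). The ground cost C, the validation points z'_k, and their number M are symbols introduced here for illustration. Moving mass ε onto z_i while removing ε/(N-1) from every other training point keeps the total mass at one, and because the dual objective is linear in μ_t, the derivative picks out exactly the calibrated gradient:

```latex
\begin{aligned}
\mathrm{OT}(\mu_t,\mu_v) &= \max_{f,g}\ \sum_{i=1}^{N} f_i\,\mu_t(z_i) + \sum_{k=1}^{M} g_k\,\mu_v(z'_k)
  \quad\text{s.t.}\quad f_i + g_k \le C(z_i, z'_k),\\
\mu_t^{\varepsilon} &= \mu_t + \varepsilon\Big(\delta_{z_i} - \frac{1}{N-1}\sum_{j\ne i}\delta_{z_j}\Big),\\
\frac{d}{d\varepsilon}\,\mathrm{OT}(\mu_t^{\varepsilon},\mu_v)\Big|_{\varepsilon=0}
  &= f_i^* - \sum_{j\ne i}\frac{f_j^*}{N-1},\\
(f_i^*+c) - \sum_{j\ne i}\frac{f_j^*+c}{N-1} &= f_i^* - \sum_{j\ne i}\frac{f_j^*}{N-1}.
\end{aligned}
```

The last line shows why the calibration matters: dual potentials are only defined up to a constant c, and the calibrated gradient is unaffected by that constant.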

  • Exactly the gradient of the dual formulation
  • Obtained for free when solving the original OT problem

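Given the optimal dual potentials f^* of the N training points, the formula above takes a few lines of code. A minimal sketch; the function name is hypothetical, and in the Getting Started steps the corresponding result is the calibrated_gradient returned by lava.compute_values_and_visualize:

```python
def calibrated_gradient(f):
    """Value of each training point from its optimal dual potential f[i]:

        value_i = f_i - sum_{j != i} f_j / (N - 1)

    Computed in O(N) by reusing the total sum of the potentials.
    """
    n = len(f)
    if n < 2:
        raise ValueError("need at least two training points")
    total = sum(f)
    return [fi - (total - fi) / (n - 1) for fi in f]


f_star = [1.0, 2.0, 3.0]
print(calibrated_gradient(f_star))                      # [-1.5, 0.0, 1.5]
print(calibrated_gradient([x + 10.0 for x in f_star]))  # [-1.5, 0.0, 1.5]
```

Equivalently, value_i = N/(N-1) · (f_i - mean(f)), so the values sum to zero and do not change when a constant is added to every potential.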
Applications

LAVA can be applied to a variety of data quality problems:

  • Mislabeled Data
  • Noisy Features
  • Dataset Redundancy
  • Dataset Bias
  • Irrelevant Data
  • and more.

Requirements

Create a virtual environment with conda:

conda env create -f environment.yaml python=3.8

Getting Started

Import the package.

import lava
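The snippets in this section refer to resize, training_size, valid_size, portion, and device without defining them. A hypothetical setup is sketched below; every value is an illustrative assumption, not the authors' experimental configuration.

```python
# Hypothetical settings for the Getting Started snippets (values are assumptions).
resize = 32            # image size passed as resize= (assumed: side length in pixels)
training_size = 1000   # number of training points to value
valid_size = 200       # number of clean validation points, passed as test_size=
portion = 0.1          # fraction of training points to corrupt, passed as currupt_por=
device = "cpu"         # e.g. "cuda" if a GPU is available
```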

Create a corrupted dataset together with the list of indices of its corrupted data points, or supply your own.

loaders, shuffle_ind = lava.load_data_corrupted(corrupt_type='shuffle', dataname='CIFAR10', resize=resize, training_size=training_size, test_size=valid_size, currupt_por=portion)
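If you supply your own corrupted dataset, you need the corrupted labels together with the list of corrupted indices (the counterpart of shuffle_ind). A minimal sketch of one such corruption, label flipping, assuming plain integer class labels; the function name is hypothetical, and this is not necessarily what load_data_corrupted(corrupt_type='shuffle') does internally:

```python
import random


def flip_labels(labels, portion, num_classes, seed=0):
    """Reassign a fraction `portion` of labels to a different, random class.

    Returns (corrupted_labels, corrupted_indices), with indices sorted.
    """
    rng = random.Random(seed)
    n_corrupt = int(len(labels) * portion)
    corrupted_idx = sorted(rng.sample(range(len(labels)), n_corrupt))
    corrupted = list(labels)
    for i in corrupted_idx:
        # Choose any class except the original one so every flip is a real error.
        corrupted[i] = rng.choice([c for c in range(num_classes) if c != labels[i]])
    return corrupted, corrupted_idx


clean = [i % 10 for i in range(100)]
noisy, corrupted_idx = flip_labels(clean, portion=0.1, num_classes=10)
print(len(corrupted_idx))  # 10
```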

Load a pretrained feature embedder.

feature_extractor = lava.load_pretrained_feature_extractor('cifar10_embedder_preact_resnet18.pth', device)

Compute the dual solution of the optimal transport problem.

dual_sol, trained_with_flag = lava.compute_dual(feature_extractor, loaders['train'], loaders['test'], training_size, shuffle_ind, resize=resize)
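This repository relies on OTDD for the OT solve. To illustrate what a dual solution is, here is a minimal stand-in that solves an entropy-regularized OT problem with Sinkhorn iterations and reads off the dual potentials as f = eps * log(u) and g = eps * log(v). This is a different (regularized) solver than the one the repository uses, and the function and parameter names are hypothetical.

```python
import math


def sinkhorn_duals(cost, a, b, eps=0.5, iters=500):
    """Entropy-regularized OT between weights a (N points) and b (M points).

    cost is an N x M list of lists. Returns (plan, f, g) where
    plan[i][j] = u[i] * K[i][j] * v[j] with K = exp(-cost / eps),
    and the dual potentials are f = eps * log(u), g = eps * log(v).
    """
    n, m = len(a), len(b)
    K = [[math.exp(-cost[i][j] / eps) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(iters):
        # Alternately rescale rows and columns to match the two marginals.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    plan = [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]
    f = [eps * math.log(ui) for ui in u]
    g = [eps * math.log(vj) for vj in v]
    return plan, f, g


a = [0.2, 0.3, 0.5]  # training weights
b = [0.6, 0.4]       # validation weights
cost = [[0.0, 1.0], [0.5, 0.2], [1.0, 0.1]]
plan, f, g = sinkhorn_duals(cost, a, b)
print([round(sum(row), 6) for row in plan])  # row marginals, close to a
```

The potentials f (one per training point) are what the calibrated-gradient formula in the valuation section consumes.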

Compute the data values with LAVA and visualize them.

calibrated_gradient = lava.compute_values_and_visualize(dual_sol, trained_with_flag, training_size, portion)
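To check how well the values surface the corrupted points, one can rank training points by calibrated gradient and count how many indices from shuffle_ind land among the top k. A minimal sketch, assuming the values are available as a flat list of floats aligned with training indices and that, as in the LAVA paper, a larger calibrated gradient marks a point that is more likely to be harmful; the function name is hypothetical:

```python
def recall_at_k(values, corrupted_idx, k):
    """Fraction of corrupted points found among the k highest-valued points.

    Assumes a larger calibrated gradient means a more suspicious point.
    """
    if not corrupted_idx:
        raise ValueError("corrupted_idx must be non-empty")
    ranked = sorted(range(len(values)), key=lambda i: values[i], reverse=True)
    flagged = set(ranked[:k])
    return len(flagged & set(corrupted_idx)) / len(corrupted_idx)


values = [0.9, -0.2, 0.5, -0.1, 0.7]  # toy calibrated gradients
corrupted = [0, 4]                    # toy stand-in for shuffle_ind
print(recall_at_k(values, corrupted, k=1))  # 0.5
print(recall_at_k(values, corrupted, k=2))  # 1.0
```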

Examples

To better understand how to apply LAVA to data valuation, see the provided examples on CIFAR-10 and STL-10.

Checkpoints

The pretrained embedders are included in the folder 'checkpoint'.

Optimal Transport Solver

This repo relies on the OTDD implementation to compute the class-wise Wasserstein distance.
We are immensely grateful to the authors of that project.

lava's People

Contributors

redsgnaoh

lava's Issues

Logical bug in the code

Hi folks, thanks for your excellent work!
I have been using the repo and found a key bug at the location linked below.
In your model architecture there is no layer called ".fc", so replacing it with an identity layer does not turn the network into an embedder. In fact, it does not change the network architecture at all; it only adds an unused identity layer to the network class. It would be good to fix this bug in the repo.
https://github.com/ruoxi-jia-group/LAVA/blob/84e8bd55e62239625d8728240ec041d2fca5b2f4/lava.py#L88