
yongzx / concept-erasure


This project forked from eleutherai/concept-erasure


Erasing concepts from neural representations with provable guarantees

License: MIT License

Python 100.00%

concept-erasure's Introduction

Least-Squares Concept Erasure (LEACE)

Concept erasure aims to remove specified features from a representation. It can be used to improve fairness (e.g. preventing a classifier from using gender or race) and interpretability (e.g. removing a concept to observe changes in model behavior). This is the repo for LEAst-squares Concept Erasure (LEACE), a closed-form method which provably prevents all linear classifiers from detecting a concept while inflicting the least possible damage to the representation. You can check out the paper here.

Installation

We require Python 3.10 or later. You can install the package from PyPI:

pip install concept-erasure

Usage

ConceptEraser is the central class in this repo. It keeps track of the covariance and cross-covariance statistics needed to erase a concept, and lazily computes the LEACE parameters when needed.
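
To make the role of these statistics concrete, here is a minimal NumPy sketch of the closed form as we understand it from the LEACE paper: whiten the features, project out the directions correlated with the concept, then un-whiten. It is an illustration under stated assumptions and not the implementation used by ConceptEraser; the function name leace_erase, the helper cross_cov, and the eigenvalue tolerance are hypothetical choices made for this example.

```python
import numpy as np


def leace_erase(X, Z):
    """Sketch of the LEACE closed form for features X (n, d) and labels Z (n, k).

    Hypothetical helper for illustration; not part of the concept_erasure API.
    """
    n = X.shape[0]
    mu_x = X.mean(axis=0)
    Xc = X - mu_x
    Zc = Z - Z.mean(axis=0)

    sigma_xx = Xc.T @ Xc / n  # covariance of X, shape (d, d)
    sigma_xz = Xc.T @ Zc / n  # cross-covariance of X and Z, shape (d, k)

    # Whitening matrix W = sigma_xx^(-1/2) and its pseudoinverse, built from
    # the eigendecomposition; near-zero eigenvalues are dropped (assumed tolerance).
    eigvals, eigvecs = np.linalg.eigh(sigma_xx)
    keep = eigvals > 1e-10 * eigvals.max()
    L, V = eigvals[keep], eigvecs[:, keep]
    W = (V / np.sqrt(L)) @ V.T
    W_pinv = (V * np.sqrt(L)) @ V.T

    # Orthogonal projector onto the column space of the whitened cross-covariance
    A = W @ sigma_xz
    P = A @ np.linalg.pinv(A)

    # r(x) = x - W^+ P W (x - mean), applied to every row of X
    return X - Xc @ (W_pinv @ P @ W).T


def cross_cov(X, Z):
    """Empirical cross-covariance between the columns of X and Z."""
    return (X - X.mean(axis=0)).T @ (Z - Z.mean(axis=0)) / X.shape[0]


# Synthetic data: two features leak a binary concept z
rng = np.random.default_rng(0)
n, d = 1000, 8
z = rng.integers(0, 2, size=n).astype(float)
X = rng.normal(size=(n, d))
X[:, 0] += 2.0 * z
X[:, 3] -= 1.0 * z
Z = z[:, None]

X_erased = leace_erase(X, Z)
before = np.abs(cross_cov(X, Z)).max()
after = np.abs(cross_cov(X_erased, Z)).max()
print(f"max |cross-covariance| before: {before:.3f}, after: {after:.2e}")
```

After erasure the cross-covariance between the features and the concept is zero up to floating-point error, which is the condition under which no linear classifier can do better than a constant prediction.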

Batch usage

In most cases, you probably have a batch of feature vectors X and concept labels Z and want to erase the concept from X. The easiest way to do this is to call ConceptEraser.fit() and then apply the fitted eraser to X, which invokes ConceptEraser.forward():

import torch
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

from concept_erasure import ConceptEraser

n, d, k = 2048, 128, 2

X, Y = make_classification(
    n_samples=n,
    n_features=d,
    n_classes=k,
    random_state=42,
)
X_t = torch.from_numpy(X)
Y_t = torch.from_numpy(Y)

# Logistic regression does learn something before concept erasure
real_lr = LogisticRegression(max_iter=1000).fit(X, Y)
beta = torch.from_numpy(real_lr.coef_)
assert beta.norm(p=torch.inf) > 0.1

eraser = ConceptEraser.fit(X_t, Y_t)
X_ = eraser(X_t)

# But learns nothing after
null_lr = LogisticRegression(max_iter=1000, tol=0.0).fit(X_.numpy(), Y)
beta = torch.from_numpy(null_lr.coef_)
assert beta.norm(p=torch.inf) < 1e-4

Streaming usage

If you have a stream of data, you can use ConceptEraser.update() to update the statistics and ConceptEraser.forward() to erase the concept. This is useful if you have a large dataset and want to avoid storing it all in memory.

from concept_erasure import ConceptEraser
from sklearn.datasets import make_classification
import torch

n, d, k = 2048, 128, 2

X, Y = make_classification(
    n_samples=n,
    n_features=d,
    n_classes=k,
    random_state=42,
)
X_t = torch.from_numpy(X)
Y_t = torch.from_numpy(Y)

eraser = ConceptEraser(d, 1, dtype=X_t.dtype)

# Compute cross-covariance matrix using batched updates
for x, y in zip(X_t.chunk(2), Y_t.chunk(2)):
    eraser.update(x, y)

# Erase the concept from the data
x_ = eraser(X_t[0])
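
For intuition about what a batched update has to do, the sketch below accumulates the mean and the cross-covariance one chunk at a time, so the full dataset never needs to be held in memory at once. It uses NumPy and a standard batched Welford-style update; the class name RunningCrossCov and its attributes are hypothetical and are not the library's API or necessarily how ConceptEraser.update() is implemented.

```python
import numpy as np


class RunningCrossCov:
    """Hypothetical accumulator for the cross-covariance of X and Z over batches."""

    def __init__(self, x_dim, z_dim):
        self.n = 0
        self.mean_x = np.zeros(x_dim)
        self.mean_z = np.zeros(z_dim)
        # Running sum of centered outer products (the unnormalized cross-covariance)
        self.m_xz = np.zeros((x_dim, z_dim))

    def update(self, x, z):
        """Fold a batch x (m, x_dim), z (m, z_dim) into the running statistics."""
        self.n += x.shape[0]

        # Deviations of x from the OLD mean, then move both means forward
        delta_x = x - self.mean_x
        self.mean_x = self.mean_x + delta_x.sum(axis=0) / self.n
        self.mean_z = self.mean_z + (z - self.mean_z).sum(axis=0) / self.n

        # Deviations of z from the NEW mean; pairing old and new deviations
        # makes the batched update exact
        delta_z2 = z - self.mean_z
        self.m_xz += delta_x.T @ delta_z2

    @property
    def sigma_xz(self):
        return self.m_xz / self.n


rng = np.random.default_rng(0)
n, d = 2048, 16
Z = rng.integers(0, 2, size=(n, 1)).astype(float)
X = rng.normal(size=(n, d)) + 1.5 * Z

stats = RunningCrossCov(d, 1)
for x, z in zip(np.array_split(X, 4), np.array_split(Z, 4)):
    stats.update(x, z)

# Reference value computed from the full dataset in one pass
full = (X - X.mean(axis=0)).T @ (Z - Z.mean(axis=0)) / n
print(np.allclose(stats.sigma_xz, full))
```

The same update pattern extends to the covariance of X by pairing the old-mean and new-mean deviations of x with each other.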

Paper replication

Scripts used to generate the part-of-speech tags for the concept scrubbing experiments can be found in this repo. We plan to upload the tagged datasets to the HuggingFace Hub shortly.

Concept scrubbing

The concept scrubbing code is a bit messy right now, and will probably be refactored soon. We found it necessary to write bespoke implementations for different HuggingFace model families. So far we've implemented LLaMA and GPT-NeoX. These can be found in the concept_erasure.scrubbing submodule.

concept-erasure's People

Contributors

norabelrose
