Measuring the Mixing of Contextual Information in the Transformer

Abstract

The Transformer architecture aggregates input information through the self-attention mechanism, but there is no clear understanding of how this information is mixed across the entire model. Additionally, recent works have demonstrated that attention weights alone are not enough to describe the flow of information. In this paper, we consider the whole attention block (multi-head attention, residual connection, and layer normalization) and define a metric to measure token-to-token interactions within each layer. Then, we aggregate layer-wise interpretations to provide input attribution scores for model predictions. Experimentally, we show that our method, ALTI (Aggregation of Layer-wise Token-to-token Interactions), provides more faithful explanations and increased robustness than gradient-based methods.
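As a rough illustration of the aggregation step only, the sketch below assumes each layer yields a token-to-token contribution matrix whose rows sum to 1 (how these matrices are obtained from the full attention block is the core of the paper); it is a minimal sketch, not the repository's implementation.

import torch

def aggregate_contributions(layer_matrices):
    """Compose layer-wise token-to-token contribution matrices down to the input tokens."""
    rollout = layer_matrices[0]
    for contributions in layer_matrices[1:]:
        # Each row of `contributions` mixes the outputs of the previous layer,
        # so composing the matrices traces influence back to the input tokens.
        rollout = contributions @ rollout
    return rollout  # rollout[i, j]: influence of input token j on final-layer token i

# Toy example: 3 layers, 5 tokens. For a classifier, input attributions would be read
# from the row of the position used for prediction (e.g. [CLS]).
layers = [torch.softmax(torch.randn(5, 5), dim=-1) for _ in range(3)]
cls_attributions = aggregate_contributions(layers)[0]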



Environment Setup

Clone this repository:

git clone https://github.com/javiferran/transformer_contributions.git

Create a conda environment using the environment.yml file, and activate it:

conda env create -f ./environment.yml && \
conda activate alti

Text Classification

To analyze model predictions on the SST-2 dataset with the proposed interpretability method and the baselines, use the following notebook:

Text_classification.ipynb
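
The notebook can be opened locally once the conda environment is active (assuming Jupyter is installed in it):

jupyter notebook Text_classification.ipynb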

Usage with Transformers

In our paper we use BERT, DistilBERT, and RoBERTa models from Hugging Face's transformers library.

The method can easily be extended to other encoder-based Transformers: just add the names of the required layers to the configuration file ./src/config.yaml.
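
To find the layer names a new encoder exposes, one option is to list its module names with transformers and copy the relevant entries into ./src/config.yaml (the exact keys the file expects are repository-specific; mirror the existing entries). A minimal sketch:

from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")  # replace with the new encoder
for name, _ in model.named_modules():
    print(name)  # e.g. encoder.layer.0.attention.self, encoder.layer.0.attention.output.LayerNorm, ...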

We compare our method with gradient-based attribution methods, using their Captum implementations.
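
A minimal sketch of one such gradient-based baseline computed with Captum (the checkpoint name and example sentence are illustrative; the repository's scripts may wire this up differently):

import torch
from captum.attr import LayerIntegratedGradients
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Illustrative SST-2 fine-tuned checkpoint; any BERT-style classifier works the same way.
model_name = "textattack/bert-base-uncased-SST-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).eval()

def forward_logits(input_ids, attention_mask):
    return model(input_ids=input_ids, attention_mask=attention_mask).logits

enc = tokenizer("a gripping and funny film", return_tensors="pt")
with torch.no_grad():
    pred = forward_logits(enc["input_ids"], enc["attention_mask"]).argmax(-1)

# Integrated Gradients on the embedding layer, from an all-[PAD] baseline to the input.
lig = LayerIntegratedGradients(forward_logits, model.bert.embeddings)
attrs = lig.attribute(enc["input_ids"],
                      baselines=torch.full_like(enc["input_ids"], tokenizer.pad_token_id),
                      additional_forward_args=(enc["attention_mask"],),
                      target=pred)
token_scores = attrs.sum(dim=-1).squeeze(0)  # one attribution score per input token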

To obtain the attributions from each method for the specified model and dataset, run:

# -model: bert/distilbert/roberta
# -dataset: dataset to use (sst2/IMDB/Yelp/sva)
# -samples: number of samples
python ./attributions.py \
  -model $model_name \
  -dataset $dataset \
  -samples $num_samples
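
For example, with an illustrative sample count:

python ./attributions.py -model bert -dataset sst2 -samples 500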

The results in Table 2 can be obtained by running the following command:

# --model: bert/distilbert/roberta
# --dataset: dataset to use (sst2/IMDB/Yelp/sva)
# --samples: number of samples
# --fidelity-type: fidelity metric (comp/suff)
# --bins: use bins (1%, 5%, 10%, 20%, 50%) rather than one-by-one token deletion
python aupc.py --model $model_name \
               --dataset $dataset \
               --samples $num_samples \
               --fidelity-type $faith_metric \
               --bins
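
For example, with the comp fidelity metric and an illustrative sample count:

python aupc.py --model bert --dataset sst2 --samples 500 --fidelity-type comp --bins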

Please use $MODELS_DIR to store fine-tuned models.
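
For example (the directory is illustrative):

export MODELS_DIR=~/models/alti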

Citation

If you use ALTI in your work, please consider citing:

@misc{alti,
  title = {Measuring the Mixing of Contextual Information in the Transformer},
  author = {Ferrando, Javier and Gállego, Gerard I. and Costa-jussà, Marta R.},
  doi = {10.48550/ARXIV.2203.04212},
  url = {https://arxiv.org/abs/2203.04212},
  keywords = {Computation and Language (cs.CL), FOS: Computer and information sciences},
  publisher = {arXiv},
  year = {2022},
  copyright = {arXiv.org perpetual, non-exclusive license}
}

Contributors

javiferran, gegallego, lucadinidue
