apollohuang1 / transformer-explainability

This project forked from hila-chefer/transformer-explainability


[CVPR 2021] Official PyTorch implementation for Transformer Interpretability Beyond Attention Visualization, a novel method to visualize classifications by Transformer-based networks.

License: MIT License

Python 20.53% Jupyter Notebook 79.47%


Our follow-up work is faster, more general, and can be applied to any type of attention! Among its features:

  • We remove LRP for a simple and quick solution, and show that the strong results from our first paper still hold!
  • We expand our work to any type of Transformer, not just self-attention-based encoders but also co-attention encoders and encoder-decoders!
  • We show that VQA models can actually understand both image and text and make connections between them!
  • We use a DETR object detector and create segmentation masks from our explanations!
  • We provide a Colab notebook with all the examples. You can easily add your own images and questions!


ViT explainability notebook:

Open In Colab

BERT explainability notebook:

Open In Colab

Updates

April 5 2021: Check out this new post about our paper! A great resource for understanding the main concepts behind our work.

March 15 2021: A Colab notebook for BERT for sentiment analysis added!

Feb 28 2021: Our paper was accepted to CVPR 2021!

Feb 17 2021: A Colab notebook with all examples added!

Jan 5 2021: A Jupyter notebook for DeiT added!

Introduction

Official implementation of Transformer Interpretability Beyond Attention Visualization.

We introduce a novel method for visualizing classifications made by Transformer-based models for both vision and NLP tasks. Our method can also visualize explanations per class.

The method consists of three phases:
  1. Calculating relevance for each attention matrix using our novel formulation of LRP.

  2. Backpropagation of gradients for each attention matrix with respect to the visualized class. The gradients are used to weight the attention heads when averaging them.

  3. Layer aggregation with rollout.
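Phases 2 and 3 can be sketched in NumPy for a single input. This is a minimal illustration, not the repository's code: it assumes the per-block attention gradients and LRP relevance maps from phase 1 are already computed (LRP itself is not implemented here), and combines them following the rule as we read it in the paper, A_bar = I + mean over heads of max(0, grad * relevance), then multiplies the per-block matrices together. The function name `aggregate_relevance` is hypothetical.

```python
import numpy as np

def aggregate_relevance(grads, relevances):
    """Combine per-block, per-head maps into one token-to-token relevance matrix.

    grads, relevances: lists with one array per Transformer block (first block
    first), each of shape (heads, tokens, tokens). Both are assumed inputs.
    """
    num_tokens = grads[0].shape[-1]
    result = np.eye(num_tokens)
    for grad, rel in zip(grads, relevances):
        # Phase 2: weight relevance by gradients, keep the positive part,
        # then average over attention heads.
        weighted = np.clip(grad * rel, 0.0, None).mean(axis=0)
        # The identity accounts for the residual (skip) connection.
        block_map = np.eye(num_tokens) + weighted
        # Phase 3: rollout, i.e. multiply the block matrices in order.
        result = result @ block_map
    return result
```

In a ViT classifier, the row of the result belonging to the [CLS] token (excluding its own column) would then give one relevance score per image patch.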

See our Jupyter notebook, where you can run the two class-specific examples from the paper.


To use another input image, add it to the samples folder and call the generate_visualization function for your class of interest by passing class_index={class_idx}. If you do not specify the index, the top predicted class is visualized.
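The internals of generate_visualization are not shown in this README. As a hedged sketch of the general idea, the hypothetical helpers below pick the top class when no index is given, then turn a flat per-patch relevance vector into an image-sized heatmap by reshaping, min-max normalizing and upsampling it. They assume a ViT with 16x16 patches on a 224x224 input (a 14x14 patch grid) and use nearest-neighbor upsampling for simplicity; the notebook may interpolate differently.

```python
import numpy as np

def pick_class(logits, class_index=None):
    # No index means "explain the top predicted class".
    return int(np.argmax(logits)) if class_index is None else class_index

def relevance_to_heatmap(patch_relevance, patch_size=16):
    """Turn a flat per-patch relevance vector into an image-sized heatmap."""
    patch_relevance = np.asarray(patch_relevance, dtype=float)
    side = int(round(np.sqrt(patch_relevance.size)))
    if side * side != patch_relevance.size:
        raise ValueError("expected a square number of patches")
    grid = patch_relevance.reshape(side, side)
    # Min-max normalize to [0, 1]; a constant map becomes all zeros.
    span = grid.max() - grid.min()
    grid = (grid - grid.min()) / span if span > 0 else np.zeros_like(grid)
    # Nearest-neighbor upsampling: each patch covers patch_size x patch_size pixels.
    return np.repeat(np.repeat(grid, patch_size, axis=0), patch_size, axis=1)
```

For example, relevance_to_heatmap(np.arange(196)) returns a 224x224 array whose top-left patch is 0.0 and bottom-right patch is 1.0.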

Credits

ViT implementation is based on:

BERT implementation is taken from the huggingface Transformers library: https://huggingface.co/transformers/

ERASER benchmark code adapted from the ERASER GitHub implementation: https://github.com/jayded/eraserbenchmark

Text visualizations in supplementary were created using TAHV heatmap generator for text: https://github.com/jiesutd/Text-Attention-Heatmap-Visualization

Reproducing results on ViT

Section A. Segmentation Results

Example:

CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python3 baselines/ViT/imagenet_seg_eval.py --method transformer_attribution --imagenet-seg-path /path/to/gtsegs_ijcv.mat

Link to download dataset.

In the example above, we run a segmentation test with our method. You can choose which method to run with the --method argument. You must provide the path to the ImageNet segmentation data with --imagenet-seg-path.
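The README does not say how the segmentation test scores an explanation. As an illustration only, one common way to compare a relevance heatmap with a ground-truth mask is to binarize the heatmap and compute pixel accuracy and the mean IoU over foreground and background. The mean-value threshold and the function name `segmentation_scores` are assumptions; the evaluation script's exact thresholding and averaging may differ.

```python
import numpy as np

def segmentation_scores(heatmap, gt_mask, threshold=None):
    """Binarize a relevance heatmap and compare it with a ground-truth mask.

    Returns (pixel_accuracy, mean_iou) where the mean IoU averages the
    foreground and background IoU.
    """
    heatmap = np.asarray(heatmap, dtype=float)
    if threshold is None:
        threshold = heatmap.mean()  # assumption: mean-value threshold
    pred = heatmap > threshold
    gt = np.asarray(gt_mask).astype(bool)
    pixel_acc = float((pred == gt).mean())
    ious = []
    for cls_pred, cls_gt in ((pred, gt), (~pred, ~gt)):  # foreground, background
        union = np.logical_or(cls_pred, cls_gt).sum()
        inter = np.logical_and(cls_pred, cls_gt).sum()
        ious.append(inter / union if union else 1.0)
    return pixel_acc, float(np.mean(ious))
```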

Section B. Perturbation Results

Example:

CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python3 baselines/ViT/generate_visualizations.py --method transformer_attribution --imagenet-validation-path /path/to/imagenet_validation_directory

You can choose whether to visualize the target class or the top predicted class with the --vis-cls argument.

To run the perturbation test, use the following command:

CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python3 baselines/ViT/pertubation_eval_from_hdf5.py --method transformer_attribution

Use the --neg argument to choose between positive and negative perturbation.
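The README does not spell out the perturbation procedure. As we understand the paper's test, tokens are removed in order of relevance (most relevant first for positive perturbation, least relevant first for negative perturbation) in growing fractions, and the area under the resulting performance curve is reported; a lower area is better for positive perturbation and a higher one for negative. The sketch below is not the evaluation script: `score_fn` is a hypothetical stand-in for the model's accuracy, and "removal" is modeled as zeroing a token, both assumptions.

```python
import numpy as np

def removal_order(relevance, negative=False):
    """Indices of tokens in the order they are removed."""
    ascending = np.argsort(relevance)
    # Positive perturbation: most relevant first. Negative: least relevant first.
    return ascending if negative else ascending[::-1]

def perturbation_auc(tokens, relevance, score_fn, negative=False,
                     fractions=np.linspace(0.0, 0.9, 10)):
    """Score the input after removing growing fractions of tokens; return (AUC, scores)."""
    order = removal_order(relevance, negative)
    scores = []
    for frac in fractions:
        k = int(round(frac * len(tokens)))
        masked = np.array(tokens, dtype=float)  # fresh copy for each step
        masked[order[:k]] = 0.0  # assumption: removing a token = zeroing it
        scores.append(float(score_fn(masked)))
    # Trapezoid-rule area under the score-vs-fraction curve.
    area = 0.0
    for i in range(1, len(fractions)):
        area += (fractions[i] - fractions[i - 1]) * (scores[i] + scores[i - 1]) / 2
    return float(area), scores
```

A good explanation should make the positive-perturbation curve drop quickly, so its area comes out smaller than the negative-perturbation area for the same input.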

Reproducing results on BERT

  1. Download the pretrained weights:
  2. Download the dataset pkl file:
  3. Download the dataset:
  4. Now you can run the model.

Example:

CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python3 BERT_rationale_benchmark/models/pipeline/bert_pipeline.py --data_dir data/movies/ --output_dir bert_models/movies/ --model_params BERT_params/movies_bert.json

To choose which explanation algorithm is used, change the method variable in BERT_rationale_benchmark/models/pipeline/bert_pipeline.py (it defaults to 'transformer_attribution', which is our method). Running this command creates a directory for the method at bert_models/movies/<method_name>.

To run the F1 test for a given k, run the following command:

PYTHONPATH=./:$PYTHONPATH python3 BERT_rationale_benchmark/metrics.py --data_dir data/movies/ --split test --results bert_models/movies/<method_name>/identifier_results_k.json
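The README does not define k. Assuming it is the number of highest-relevance tokens kept as the predicted rationale, a token-level F1 against a human-annotated rationale could be computed as below. This is a sketch of the idea, not the ERASER metrics.py implementation, and `top_k_token_f1` is a hypothetical name.

```python
import numpy as np

def top_k_token_f1(relevance, gold_mask, k):
    """F1 between the k most relevant tokens and a gold rationale mask."""
    relevance = np.asarray(relevance, dtype=float)
    pred = np.zeros(relevance.size, dtype=bool)
    pred[np.argsort(relevance)[::-1][:k]] = True  # keep the top-k tokens
    gold = np.asarray(gold_mask, dtype=bool)
    true_pos = np.logical_and(pred, gold).sum()
    if true_pos == 0:
        return 0.0
    precision = true_pos / pred.sum()
    recall = true_pos / gold.sum()
    return float(2 * precision * recall / (precision + recall))
```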

The method directory will also contain .tex files with the explanations extracted for each example. These correspond to the visualizations in our supplementary material.
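The exact format of these .tex files is not described here. As an illustration only, a token heatmap can be written as LaTeX by shading each word with \colorbox and the xcolor red!NN mixing syntax, in the spirit of the TAHV tool credited above. The helper `tokens_to_latex` is hypothetical and is neither TAHV's code nor the repository's output format; the resulting snippet needs \usepackage{xcolor}.

```python
LATEX_SPECIAL = {"&": r"\&", "%": r"\%", "$": r"\$", "#": r"\#",
                 "_": r"\_", "{": r"\{", "}": r"\}"}

def escape_latex(token):
    # Escape characters that would otherwise break compilation.
    return "".join(LATEX_SPECIAL.get(ch, ch) for ch in token)

def tokens_to_latex(tokens, scores, color="red"):
    """Shade each token by its min-max normalized score (0 to 100 percent)."""
    lo, hi = min(scores), max(scores)
    boxes = []
    for token, score in zip(tokens, scores):
        pct = 0 if hi == lo else round(100 * (score - lo) / (hi - lo))
        boxes.append(f"\\colorbox{{{color}!{pct}}}{{\\strut {escape_latex(token)}}}")
    return " ".join(boxes)
```

For example, tokens_to_latex(["good", "movie"], [1.0, 0.0]) produces \colorbox{red!100}{\strut good} \colorbox{red!0}{\strut movie}.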

Citing our paper

If you make use of our work, please cite our paper:

@InProceedings{Chefer_2021_CVPR,
    author    = {Chefer, Hila and Gur, Shir and Wolf, Lior},
    title     = {Transformer Interpretability Beyond Attention Visualization},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2021},
    pages     = {782-791}
}

Contributors

hila-chefer, shirgur
