
jacobgil / vit-explain


Explainability for Vision Transformers

License: MIT License

vision-transformer pytorch explainable-ai deep-learning transformer

vit-explain's Introduction

Explainability for Vision Transformers (in PyTorch)

This repository implements methods for explainability in Vision Transformers.

See also https://jacobgil.github.io/deeplearning/vision-transformer-explainability

Currently implemented:

  • Attention Rollout.

  • Gradient Attention Rollout for class-specific explainability. This is our attempt to further build upon and improve Attention Rollout.

  • Attention Flow: work in progress (TBD).

Includes some tweaks and tricks to get it working:

  • Different Attention Head fusion methods.
  • Removing the lowest attentions.

Usage

  • From code:

import torch
from vit_grad_rollout import VITAttentionGradRollout

# Load the default DeiT-Tiny model from torch hub.
model = torch.hub.load('facebookresearch/deit:main',
                       'deit_tiny_patch16_224', pretrained=True)

# input_tensor: a preprocessed [1, 3, 224, 224] image batch.
grad_rollout = VITAttentionGradRollout(model, discard_ratio=0.9, head_fusion='max')
mask = grad_rollout(input_tensor, category_index=243)
  • From the command line:
python vit_explain.py --image_path <image path> --head_fusion <mean, min or max> --discard_ratio <number between 0 and 1> --category_index <category_index>

If category_index isn't specified, Attention Rollout will be used, otherwise Gradient Attention Rollout will be used.

Notice that by default, this uses the 'Tiny' model from Training data-efficient image transformers & distillation through attention hosted on torch hub.
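For reference, a minimal end-to-end sketch (the image path and exact preprocessing values below are assumptions; check vit_explain.py for the actual transform):

import torch
from PIL import Image
from torchvision import transforms

from vit_rollout import VITAttentionRollout

# Assumed ImageNet-style preprocessing; vit_explain.py may use slightly different values.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

model = torch.hub.load('facebookresearch/deit:main',
                       'deit_tiny_patch16_224', pretrained=True).eval()

img = Image.open('examples/dog_cat.png').convert('RGB')   # hypothetical image path
input_tensor = preprocess(img).unsqueeze(0)               # [1, 3, 224, 224]

attention_rollout = VITAttentionRollout(model, head_fusion='max', discard_ratio=0.9)
mask = attention_rollout(input_tensor)                    # coarse spatial attention mask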

Where did the Transformer pay attention to in this image?

[Image grid: input image | vanilla Attention Rollout | rollout with discard_ratio + max fusion]

Gradient Attention Rollout for class-specific explainability

The attention that flows through the transformer passes along information belonging to different classes. Attention Rollout lets us see which locations the network paid attention to, but it tells us nothing about whether it ended up using those locations for the final classification.

We can multiply the attention with the gradient of the target class output, and take the average among the attention heads (while masking out negative attentions) to keep only attention that contributes to the target category (or categories).
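In pseudo-code, the per-layer fusion looks roughly like this (a simplified sketch of the idea, not the exact implementation in vit_grad_rollout.py):

import torch

def grad_weighted_fusion(attention, gradient):
    # attention, gradient: [batch, heads, tokens, tokens]
    weighted = attention * gradient      # weight attention by the target-class gradient
    fused = weighted.mean(dim=1)         # average over the attention heads
    return fused.clamp(min=0)            # mask out negative (counter-evidence) attention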

Where does the Transformer see a Dog (category 243), and a Cat (category 282)?

Where does the Transformer see a Musket dog (category 161) and a Parrot (category 87):

Tricks and Tweaks to get this working

Filtering the lowest attentions in every layer

--discard_ratio <value between 0 and 1>

Removes noise by keeping the strongest attentions.
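A sketch of the idea for one fused attention map (simplified; the repository implementation may differ in details such as protecting the class-token entry from being discarded):

import torch

def discard_lowest(attention_heads_fused, discard_ratio):
    # attention_heads_fused: [batch, tokens, tokens], already fused over heads
    flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
    num_discard = int(flat.size(-1) * discard_ratio)
    _, idx = flat.topk(num_discard, dim=-1, largest=False)  # the weakest attentions
    flat.scatter_(-1, idx, 0)                               # zero them out in place
    return attention_heads_fused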

Results for different values:

Different Attention Head Fusions

The Attention Rollout method suggests taking the average attention across the attention heads,

but empirically it looks like taking the minimum value, or the maximum value combined with --discard_ratio, works better.

--head_fusion <mean, min or max>
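For reference, the three options reduce the per-head attention like this (a sketch; attention is a [batch, heads, tokens, tokens] tensor):

import torch

def fuse_heads(attention, head_fusion="mean"):
    if head_fusion == "mean":
        return attention.mean(dim=1)
    if head_fusion == "max":
        return attention.max(dim=1)[0]   # torch.max over heads returns (values, indices)
    if head_fusion == "min":
        return attention.min(dim=1)[0]
    raise ValueError(f"Unsupported head fusion: {head_fusion}")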

[Image grid: input image | mean fusion | min fusion]

References

Requirements

pip install timm


vit-explain's Issues

How do we know the index category?

I trained a ViT from scratch on a face dataset and I am doing face verification. How do I know the correct index number? I have 200+ classes.
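One common way to pick category_index (a sketch, assuming the model's output logits are ordered the same way as the training labels) is to use the model's own top prediction:

import torch

# Hypothetical: `model` is the fine-tuned ViT, `input_tensor` a preprocessed [1, 3, H, W] batch.
with torch.no_grad():
    logits = model(input_tensor)
category_index = logits.argmax(dim=-1).item()   # index of the highest-scoring class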

TypeError, IndexError, ValueError occurred

Greetings, thank you for your nice transformer visualization code.

I am trying to visualize my model, but I am stuck on some errors.

model_name = 'vit_base_patch16_224'
model = timm.create_model(model_name, pretrained=True, num_classes=3)

I trained my model and used your library.
I have put the vit_explain.py, vit_grad_rollout.py, and vit_rollout.py files into the same directory as my main Python file,

and typed the following:

from vit_grad_rollout import VITAttentionGradRollout
grad_rollout = VITAttentionGradRollout(model, discard_ratio=0.9, head_fusion='max')
mask = grad_rollout(input_tensor = inputs, category_index=243)

and I got the error below


TypeError Traceback (most recent call last)
in <cell line: 4>()
2 grad_rollout = VITAttentionRollout(model, discard_ratio = 0.8, head_fusion='max')
3 new_image = class_images[0].unsqueeze(0)
----> 4 mask = grad_rollout(input_tensor = new_image, category_index=3)

TypeError: VITAttentionRollout.__call__() got an unexpected keyword argument 'category_index'

I have checked VITAttentionRollout.__call__() but it is well defined and I didn't change anything...

def __call__(self, input_tensor, category_index):

I don't know why this error happens.

To work around the error, I overrode your VITAttentionRollout.__call__() like below:
def __call__(self, input_tensor, category_index=3):

and changed the mask code:

mask = grad_rollout(input_tensor=inputs)

Then the TypeError was solved, but a ValueError came up:


ValueError Traceback (most recent call last)
in <cell line: 4>()
2 grad_rollout = VITAttentionRollout(model, discard_ratio = 0.8, head_fusion='max')
3 new_image = class_images[0]
----> 4 mask = grad_rollout(input_tensor = new_image)

5 frames
in __call__(self, input_tensor)
17 self.attentions = []
18 with torch.no_grad():
---> 19 output = self.model(input_tensor)
20
21 return rollout(self.attentions, self.discard_ratio, self.head_fusion)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.10/dist-packages/timm/models/vision_transformer.py in forward(self, x)
650
651 def forward(self, x):
--> 652 x = self.forward_features(x)
653 x = self.forward_head(x)
654 return x

/usr/local/lib/python3.10/dist-packages/timm/models/vision_transformer.py in forward_features(self, x)
631
632 def forward_features(self, x):
--> 633 x = self.patch_embed(x)
634 x = self._pos_embed(x)
635 x = self.patch_drop(x)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.10/dist-packages/timm/layers/patch_embed.py in forward(self, x)
67
68 def forward(self, x):
---> 69 B, C, H, W = x.shape
70 if self.img_size is not None:
71 if self.strict_img_size:

ValueError: not enough values to unpack (expected 4, got 3)

My input size is [3, 224, 224].

It seems there is some problem with the tensor size; to solve it, I applied unsqueeze(0) to the inputs:

inputs = inputs.unsqueeze(0)

and I got the error below


IndexError Traceback (most recent call last)
in <cell line: 4>()
2 grad_rollout = VITAttentionRollout(model, discard_ratio = 0.8, head_fusion='max')
3 new_image = class_images[0].unsqueeze(0)
----> 4 mask = grad_rollout(input_tensor = new_image)

1 frames
in __call__(self, input_tensor)
19 output = self.model(input_tensor)
20
---> 21 return rollout(self.attentions, self.discard_ratio, self.head_fusion)

in rollout(attentions, discard_ratio, head_fusion)
32 # Look at the total attention between the class token,
33 # and the image patches
---> 34 mask = result[0, 0 , 1 :] # [196,196]
35 # In case of 224x224 image, this brings us from 196 to 14
36 width = int(mask.size(-1)**0.5)

IndexError: too many indices for tensor of dimension 2

What was your input size?
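For reference, a usage sketch that avoids the first two errors above (the rejected keyword argument and the missing batch dimension), assuming `model` and `inputs` as defined in this post; category_index is only accepted by VITAttentionGradRollout:

from vit_rollout import VITAttentionRollout
from vit_grad_rollout import VITAttentionGradRollout

input_tensor = inputs.unsqueeze(0)   # [3, 224, 224] -> [1, 3, 224, 224]

# Plain Attention Rollout: no category_index argument.
rollout = VITAttentionRollout(model, discard_ratio=0.8, head_fusion='max')
mask = rollout(input_tensor)

# Class-specific rollout: pass category_index to the gradient variant instead.
grad_rollout = VITAttentionGradRollout(model, discard_ratio=0.8)
mask = grad_rollout(input_tensor, category_index=2)   # any index in [0, num_classes)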

Visualization Results on DeiT-Base

Hi, thanks for your work. When I use rollout on DeiT-Base, it gets worse results and it does not focus on the object. Have you ever tried your method on DeiT-Base?

Rollout for different ViT models

Hello,
I would like to use it with different models, and currently I use:

elif args.model_name == 'dino_vit':
    model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16').eval().to(device)

However, it gives errors such as the ones below. So the question is whether it is applicable to:

  1. model = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50').eval().to(device)

IndexError: list index out of range

  2. model = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_medium_24_p16').eval().to(device)

(myenv) python vit_explain_foolbox.py --model_name dino_xcit --attack_name LinfPGD --use_cuda --head_fusion "min" --discard_ratio 0.9
Using cache found in C:\Users.cache\torch\hub\facebookresearch_dino_main
Using cache found in C:\Users.cache\torch\hub\facebookresearch_xcit_main
epsilons
[0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Doing Attention Rollout
Traceback (most recent call last):
File "vit_explain_foolbox.py", line 202, in
mask = attention_rollout(perturbed_data) ###############
File "\VisionTransformer\VisionTransformer\VisionTransformer\vit_rollout.py", line 68, in call
return rollout(self.attentions, self.discard_ratio, self.head_fusion)
File "\VisionTransformer\VisionTransformer\VisionTransformer\vit_rollout.py", line 33, in rollout
result = torch.matmul(a, result)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x64 and 197x197)

normalize (sum to 1) attention score seems not right

Hi, thanks for sharing this nice work.

I noticed that you normalize the attention scores (each row sums to 1), as mentioned in the original Attention Rollout paper:

I = torch.eye(attention_heads_fused.size(-1))
a = (attention_heads_fused + 1.0*I)/2
a = a / a.sum(dim=-1)

But it seems that when dividing by the row-wise sum of attention scores, keepdim=True should be applied to ensure that each row sums to 1 after normalization:

a = a / a.sum(dim=-1,keepdim=True)
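For example, with a small matrix the difference is easy to see (plain PyTorch broadcasting, just for illustration):

import torch

a = torch.tensor([[0.2, 0.8],
                  [0.6, 0.6]])

wrong = a / a.sum(dim=-1)                  # broadcasts over the wrong axis
right = a / a.sum(dim=-1, keepdim=True)    # each row now sums to 1

print(wrong.sum(dim=-1))   # tensor([0.8667, 1.1000])
print(right.sum(dim=-1))   # tensor([1., 1.])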

Maybe I'm wrong, please double check this issue.
Thanks

No separate `head_fusion` strategies for Gradient Attention Rollout?

Should there be a condition to allow users to pass the head_fusion method in grad_rollout()?

Something like -

...
    for attention, grad in zip(attentions, gradients):
        weights = grad
        if head_fusion == "mean":
            attention_heads_fused = (attention * weights).mean(axis=1)
        elif head_fusion == "max":
            attention_heads_fused = (attention * weights).max(axis=1)[0]
        elif head_fusion == "min":
            attention_heads_fused = (attention * weights).min(axis=1)[0]
        else:
            raise ValueError("Attention head fusion type not supported")

        attention_heads_fused[attention_heads_fused < 0] = 0
...

"unexpected keyword argument 'head_fusion'" and "required positional argument: 'category_index'"

Hi,
I tried to run the example from the README.md and I run into the following issues:

  • grad_rollout = VITAttentionGradRollout(model, discard_ratio=0.9, head_fusion='max') causes TypeError: __init__() got an unexpected keyword argument 'head_fusion'.
  • mask = grad_rollout(input_tensor) causes TypeError: __call__() missing 1 required positional argument: 'category_index' although the README suggests it is optional (If category_index isn't specified, Attention Rollout will be used, otherwise Gradient Attention Rollout will be used.). Is this a mistake or should I pass None for example?

Thanks a lot and best regards
Verena

Error in tensor size mismatch

Hi @jacobgil,
I am using this project with my Swin Transformer, but it is giving the error shown below:

     35             # print("a : ",a)
     36             # print(a.size())
---> 37             a = a / a.sum(dim=-1)
     38 
     39             result = torch.matmul(a, result)

RuntimeError: The size of tensor a (49) must match the size of tensor b (64) at non-singleton dimension 1

Please give a solution for this.

Shape mismatch for timm vit model

While applying attention rollout on a fine-tuned timm ViT model (base_patch32_224), I am getting the following error with an input tensor of shape torch.Size([1, 3, 224, 224]):

RuntimeError Traceback (most recent call last)
in ()
----> 1 mask_1 = attention_rollout(test_image_1_tensor)

8 frames
in reshape_transform(tensor, height, width)
1 def reshape_transform(tensor, height=7, width=7):
2 result = tensor[:, 1 : , :].reshape(tensor.size(0),
----> 3 height, width, tensor.size(2))
4
5 # Bring the channels to the first dimension,

RuntimeError: shape '[1, 7, 7, 7]' is invalid for input of size 37583

Kindly advise on how to properly apply this to the model, as I am facing the same issue with FullGrad in pytorch-grad-cam on the same model.

cv2 is bgr whereas matplotlib is rgb

There's a tiny mistake in the function show_mask_on_image. The heatmap you get from cv2 is in BGR format, so you need to convert it to RGB before adding it to the img:

def show_mask_on_image(img, mask):
    ...
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    cam = heatmap + np.float32(img)
    ...

What if the visual transformer does not have a class token?

I see the code in the VITAttentionGradRollout code requires a class token. What if the model architecture does not have a class token?

If for example my attention layer is 196x196 (corresponding to 14x14 spatial resolution), can one take the mean of all other patches w.r.t. each patch, as follows: mask = result[0].mean(0)? I've tried this and I didn't get very meaningful results. Is there another way to deal with transformers without class tokens?

visualize timesformer

Hello, can you show an example of using your code for TimeSformer visualization?

Code for Google's ViT and complete example

Hi @jacobgil!

Thank you for this amazing piece of work. I was wondering if you plan to open-source the code to try out your experiments on Google's ViT (An Image is Worth ...) as well. If it's already there inside the repo, could you point me to it?

Update: I was able to use timm and make use of the ViT model it comes with:

timm_vit_model = timm.create_model('vit_large_patch16_384', pretrained=True)
timm_vit_model.eval()
roller = VITAttentionGradRollout(timm_vit_model, discard_ratio=0.9)
mask = roller(x.unsqueeze(0), label_idx)

However, I am still a bit unsure as to how to actually visualize the mask. Could you help?
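One way to visualize it is along the lines of show_mask_on_image in vit_explain.py: resize the mask to the image size, turn it into a heatmap, and blend it with the input image (a sketch; the exact colormap and scaling in the repository may differ):

import cv2
import numpy as np

def overlay_mask(img, mask):
    # img: HxWx3 float array in [0, 1]; mask: 2-D rollout mask (e.g. 24x24 for patch16_384)
    mask = cv2.resize(mask, (img.shape[1], img.shape[0]))
    mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)  # normalize to [0, 1]
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    return np.uint8(255 * cam / np.max(cam))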

Cannot use block.attn.fused_attn = False in another ViT model

I am trying to run the code for another ViT model, and more specifically:

    #  Get pretrained weights for ViT-Base
    pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT # requires torchvision >= 0.13, "DEFAULT" means best available
    pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights).to(device)
    
    #pretrained_vit_1 = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True).to(device)

In this model I have noticed that I cannot use the following code:

for block in model.blocks:
            block.attn.fused_attn = False
            

This is because the model does not have the same structure as the deit_tiny_patch16_224 one. I am also not sure how to do the same fused_attn change in this model. Can you explain a bit what this code is about and why it produces different results when I comment it out?
