
Comments (10)

peihaowang commented on June 30, 2024

Dear Jiahao,

Thanks for your interest.

  1. Yes, I guess you can get a similar visualization of the spectrum using the pre-trained DeiT-S if it is computed correctly.
  2. Figure 5 plots the post-softmax attention matrix.

Best,
Peihao


techmonsterwang commented on June 30, 2024

Hi, Peihao!

Thanks for the nice reply.

I have plotted the post-softmax attention matrix using the open-source DeiT-S pretrained model. I found that the result at layer 5 looks like this:

[image: layer-5 visualization]

I also found that the attention matrices after layer 2 show a similar pattern, which does not match Figure 5.

Is it possible that I went wrong somewhere? Could you please provide the full code for this part? Thanks a lot.


techmonsterwang commented on June 30, 2024

Besides, I plot the attention matrix with the following code:

from torch import nn

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        # scaled dot-product attention; softmax over the key dimension
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn   # also return the post-softmax attention map

Am I doing the same as you? Thanks for the reply~

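For reference, a minimal sketch of extracting the same post-softmax maps from an unmodified timm DeiT-S: the classic timm Attention applies its attn_drop module to the post-softmax map, so a forward hook on that submodule receives the attention matrix as its input (newer timm versions with fused attention may bypass it, hence the guard):

import timm
import torch

model = timm.create_model('deit_small_patch16_224', pretrained=True).eval()
for blk in model.blocks:
    if hasattr(blk.attn, 'fused_attn'):
        blk.attn.fused_attn = False  # newer timm may skip attn_drop when fused

attn_maps = []
def grab_attention(module, inputs, output):
    # attn_drop receives the post-softmax attention as input: (B, heads, N, N)
    attn_maps.append(inputs[0].detach())

hooks = [blk.attn.attn_drop.register_forward_hook(grab_attention) for blk in model.blocks]
with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))  # stand-in input; use a real preprocessed image
for h in hooks:
    h.remove()

print(len(attn_maps), attn_maps[0].shape)  # 12 layers, each (1, 6, 197, 197)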

peihaowang commented on June 30, 2024

Hi Jiahao,

Thanks for the follow-up. I guess we are using the same code for computing attention. I do think the figure you showed above has the shape we expect: a peak around the DC component and a long tail across the high-frequency bands. However, I'm not convinced that the frequencies other than DC have zero response. That is not possible; otherwise the attention map would turn out to be a (normalized) all-one matrix. I would suggest you first visualize the attention map directly, and try different batches of images.

Best,
Peihao

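One plausible way to plot such a spectrum is a DFT magnitude of the head-averaged post-softmax map; this is a sketch under that assumption, not necessarily the paper's exact plotting code. It also shows the degenerate case mentioned above: a uniform attention map puts all its energy into the single DC bin.

import torch

def attention_spectrum(attn):
    # attn: (heads, N, N) post-softmax attention from one layer
    A = attn.mean(dim=0)                # average over heads: (N, N)
    F = torch.fft.fft2(A)               # 2-D DFT over token indices
    return torch.fft.fftshift(F).abs()  # magnitude, DC moved to the centre

# a uniform (all-one/N) map has energy only at DC:
N = 197
uniform = torch.full((1, N, N), 1.0 / N)
spec = attention_spectrum(uniform)
print(spec.max().item(), int((spec > 1e-6).sum()))  # 197.0, exactly 1 nonzero bin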

techmonsterwang commented on June 30, 2024

Hi Peihao,

Thanks for the quick and nice reply!
I saw the paper says "Below we provide a complete spectral visualization of attention maps computed from a random sample in ImageNet validation set", which means that Figure 5 was drawn from only one sample.

Do you mean that Figure 5 is obtained from a batch of validation data? If so, how large is the batch?


techmonsterwang commented on June 30, 2024

Hi Peihao,

Yes, as you say, the attention matrix after layer 2 indeed turns out to be a (normalized) all-one matrix, like this:

tensor([[0.0051, 0.0051, 0.0051,  ..., 0.0051, 0.0051, 0.0051],
        [0.0051, 0.0051, 0.0051,  ..., 0.0051, 0.0051, 0.0051],
        [0.0051, 0.0051, 0.0051,  ..., 0.0051, 0.0051, 0.0051],
        ...,
        [0.0051, 0.0051, 0.0051,  ..., 0.0051, 0.0051, 0.0051],
        [0.0051, 0.0051, 0.0051,  ..., 0.0051, 0.0051, 0.0051],
        [0.0051, 0.0051, 0.0051,  ..., 0.0051, 0.0051, 0.0051]],
       grad_fn=<SoftmaxBackward0>)

I wonder why it is like this, since I directly use the pretrained open-source DeiT-S checkpoint.

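The printed value itself is telling: 0.0051 ≈ 1/197, i.e., uniform over DeiT-S's 196 patch tokens plus the class token. A quick check (is_uniform is a hypothetical helper, not from the repo):

import torch

def is_uniform(attn, tol=1e-4):
    # true if every row is (approximately) the uniform distribution 1/N
    N = attn.shape[-1]
    return torch.allclose(attn, torch.full_like(attn, 1.0 / N), atol=tol)

print(is_uniform(torch.full((197, 197), 1 / 197)))  # True
print(1 / 197)  # 0.005076..., matching the 0.0051 printed above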

peihaowang commented on June 30, 2024

Hi Jiahao,

Thanks for your continued interest, and sorry for the late reply. Getting uniformly distributed attention maps does not make sense to me. Have you checked the test accuracy?

Fig. 5 is visualized for one sample (i.e., one image), not a batch. As far as I remember, it was not cherry-picked; an arbitrary sample produces similar results.

Peihao

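A minimal sketch of the suggested accuracy sanity check, assuming a torchvision ImageFolder pointed at the ImageNet validation split (the path is hypothetical); pretrained DeiT-S should land around 0.80 top-1:

import timm
import torch
from torchvision import datasets, transforms

tfm = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
val = datasets.ImageFolder('/path/to/imagenet/val', tfm)  # hypothetical path
loader = torch.utils.data.DataLoader(val, batch_size=64)

model = timm.create_model('deit_small_patch16_224', pretrained=True).eval()
correct = total = 0
with torch.no_grad():
    for x, y in loader:
        correct += (model(x).argmax(dim=1) == y).sum().item()
        total += y.numel()
        if total >= 1000:  # a small subsample is enough to spot a broken pipeline
            break
print(f'top-1 on {total} images: {correct / total:.3f}')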

gcxamy commented on June 30, 2024

Hello, I really love your paper. Both the results and the visualizations are impressive. I am very curious about how you visualized Figure 5. Could you share your code for it?


peihaowang commented on June 30, 2024

Hi, thanks for your interest. The code to visualize the spectrum of attention can be found here: #1 (comment). Hope you find this helpful!


peihaowang commented on June 30, 2024

Feel free to reopen this thread if there are any further issues.
