xxxnell / how-do-vits-work

(ICLR 2022 Spotlight) Official PyTorch implementation of "How Do Vision Transformers Work?"

Home Page: https://arxiv.org/abs/2202.06709

License: Apache License 2.0

Jupyter Notebook 19.06% Python 80.94%
vision-transformer transformer self-attention loss-landscape pytorch

how-do-vits-work's Introduction

How Do Vision Transformers Work?

[arxiv, poster, slides]

This repository provides a PyTorch implementation of "How Do Vision Transformers Work? (ICLR 2022 Spotlight)" In the paper, we show that the success of multi-head self-attentions (MSAs) for computer vision does NOT lie in their weak inductive bias and the capturing of long-range dependencies. MSAs are not merely generalized Convs, but rather generalized spatial smoothings that complement Convs. In particular, we address the following three key questions of MSAs and Vision Transformers (ViTs):

Q1. What properties of MSAs do we need to better optimize NNs?

A1. MSAs have their pros and cons. MSAs improve NNs by flattening the loss landscapes. A key feature is their data specificity (data dependency), not long-range dependency. On the other hand, ViTs suffer from non-convex losses.

Q2. Do MSAs act like Convs?

A2. MSAs and Convs exhibit opposite behaviors; for example, MSAs are low-pass filters, but Convs are high-pass filters. This suggests that MSAs are shape-biased, whereas Convs are texture-biased. Therefore, MSAs and Convs are complementary.

Q3. How can we harmonize MSAs with Convs?

A3. MSAs at the end of a stage (not a model) significantly improve the accuracy. Based on this, we introduce AlterNet by replacing Convs at the end of a stage with MSAs. AlterNet outperforms CNNs not only in large data regimes but also in small data regimes.

👇 Let's find the detailed answers below!

I. What Properties of MSAs Do We Need to Improve Optimization?

MSAs improve not only accuracy but also generalization by flattening the loss landscapes (i.e., reducing the magnitude of Hessian eigenvalues). This improvement is primarily attributable to their data specificity, NOT long-range dependency 😱. On the other hand, ViTs suffer from non-convex losses (negative Hessian eigenvalues). Their weak inductive bias and long-range dependency produce negative Hessian eigenvalues in small data regimes, and these non-convex points disrupt NN training. Large datasets and loss landscape smoothing methods alleviate this problem.
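For readers who want to reproduce this kind of measurement, the snippet below is a generic power-iteration estimate of the largest Hessian eigenvalue using Hessian-vector products. It is only an illustrative sketch, not the exact code behind the paper's figures; `model`, `criterion`, and the batch `(x, y)` are assumed to be supplied by the user.

import torch

def hessian_max_eigenvalue(model, criterion, x, y, iters=20):
    # Estimate the largest Hessian eigenvalue of the loss w.r.t. the parameters
    # by power iteration with Hessian-vector products (illustrative sketch).
    params = [p for p in model.parameters() if p.requires_grad]
    loss = criterion(model(x), y)
    grads = torch.autograd.grad(loss, params, create_graph=True)

    v = [torch.randn_like(p) for p in params]          # random start direction
    norm = torch.sqrt(sum((u * u).sum() for u in v))
    v = [u / norm for u in v]

    eigenvalue = None
    for _ in range(iters):
        # Hessian-vector product via a second backward pass
        hv = torch.autograd.grad(grads, params, grad_outputs=v, retain_graph=True)
        eigenvalue = sum((u * h).sum() for u, h in zip(v, hv)).item()  # Rayleigh quotient
        norm = torch.sqrt(sum((h * h).sum() for h in hv))
        v = [h / (norm + 1e-12) for h in hv]
    return eigenvalue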

II. Do MSAs Act Like Convs?

MSAs and Convs exhibit opposite behaviors; therefore, MSAs and Convs are complementary. For example, MSAs are low-pass filters, but Convs are high-pass filters. Likewise, Convs are vulnerable to high-frequency noise, whereas MSAs are vulnerable to low-frequency noise; this suggests that MSAs are shape-biased, whereas Convs are texture-biased. In addition, Convs transform feature maps, and MSAs aggregate the transformed feature map predictions. Thus, it is effective to place MSAs after Convs.

III. How Can We Harmonize MSAs With Convs?

Multi-stage neural networks behave like a series connection of small individual models. In addition, MSAs at the end of a stage (not the end of a model) play a key role in prediction. Considering these insights, we propose design rules to harmonize MSAs with Convs. NN stages using this design pattern consist of a number of CNN blocks and one (or a few) MSA block. The design pattern naturally derives the structure of the canonical Transformer, which has one MLP block for one MSA block.

Based on these design rules, we introduce AlterNet (code) by replacing Conv blocks at the end of a stage with MSA blocks. Surprisingly, AlterNet outperforms CNNs not only in large data regimes but also in small data regimes, e.g., CIFAR. This contrasts with canonical ViTs, which perform poorly on small amounts of data. For more details, see the "How to Apply MSA to Your Own Model" section below.

But why do Vision Transformers work that way? Our recent paper, "Blurs Behave Like Ensembles: Spatial Smoothings to Improve Accuracy, Uncertainty, and Robustness (ICML 2022)" (code and summary :octocat:, poster), shows that even a simple (non-trainable) 2 × 2 box blur filter has the same properties. Spatial smoothing improves accuracy, uncertainty, and robustness simultaneously by ensembling spatially nearby feature maps of CNNs and flattening loss landscapes, and self-attention can be regarded as a trainable, importance-weighted ensemble of feature maps. In conclusion, an MSA is not simply a generalized Conv, but rather a generalized (trainable) blur filter that complements Conv. Please check it out!
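For intuition, the spatial smoothing discussed above is nothing more exotic than stride-1 average pooling. The snippet below applies a simple (non-trainable) 2 × 2 box blur to a feature map; it is a minimal sketch of the idea, not the implementation from the ICML paper.

import torch.nn.functional as F

def box_blur_2x2(feature_map):
    # Average each 2x2 neighborhood of a (B, C, H, W) feature map while
    # keeping its spatial size (replicate padding on the right and bottom).
    x = F.pad(feature_map, (0, 1, 0, 1), mode="replicate")
    return F.avg_pool2d(x, kernel_size=2, stride=1)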

Getting Started

The following packages are required:

  • pytorch
  • matplotlib
  • notebook
  • ipywidgets
  • timm
  • einops
  • tensorboard
  • seaborn (optional)

We mainly use the Docker image pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime for the code.

See classification.ipynb (Colab notebook) for image classification. Run all cells to train and test models on CIFAR-10, CIFAR-100, and ImageNet.

Metrics. We provide several metrics for measuring accuracy and uncertainty: Accuracy (Acc, ↑) and Acc for 90% certain results (Acc-90, ↑), negative log-likelihood (NLL, ↓), Expected Calibration Error (ECE, ↓), Intersection-over-Union (IoU, ↑) and IoU for certain results (IoU-90, ↑), Unconfidence (Unc-90, ↑), and Frequency for certain results (Freq-90, ↑). We also define a method to plot a reliability diagram for visualization.
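As a reference for how one of these metrics is typically computed, the sketch below estimates the expected calibration error with equal-width confidence bins. It is a simplified stand-in; the repository's own metric code may differ in details.

import torch

def expected_calibration_error(confidences, predictions, labels, n_bins=15):
    # ECE: weighted average of |accuracy - confidence| over confidence bins.
    # `confidences`, `predictions`, and `labels` are 1-D tensors of equal length.
    bin_edges = torch.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        prop = in_bin.float().mean()
        if prop > 0:
            acc = (predictions[in_bin] == labels[in_bin]).float().mean()
            conf = confidences[in_bin].mean()
            ece = ece + (acc - conf).abs() * prop
    return float(ece)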

Models. We provide AlexNet, VGG, pre-activation VGG, ResNet, pre-activation ResNet, ResNeXt, WideResNet, ViT, PiT, Swin, MLP-Mixer, and Alter-ResNet by default. timm implementations can also be used.

Pretrained models for CIFAR-100 are also provided: ResNet-50, ViT-Ti, PiT-Ti, and Swin-Ti. We recommend using timm for ImageNet-1K for the sake of simplicity (e.g., please refer to fourier_analysis.ipynb).
The code below shows snippets for (a) loading pretrained models and (b) converting them into block sequences.
# ResNet-50
import torch
import models
  
# a. download and load a pretrained model for CIFAR-100
url = "https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/resnet_50_cifar100_691cc9a9e4.pth.tar"
path = "checkpoints/resnet_50_cifar100_691cc9a9e4.pth.tar"
models.download(url=url, path=path)

name = "resnet_50"
model = models.get_model(name, num_classes=100,  # timm does not provide a ResNet for CIFAR
                         stem=False)  # the repo reads this flag from its config (`model_args`); False is the default
map_location = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(path, map_location=map_location)
model.load_state_dict(checkpoint["state_dict"])

# b. model โ†’ blocks. `blocks` is a sequence of blocks
blocks = [
    model.layer0,
    *model.layer1,
    *model.layer2,
    *model.layer3,
    *model.layer4,
    model.classifier,
]
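Once a model is expressed as such a block sequence, intermediate feature maps can be collected by simply chaining the blocks. A usage sketch (assuming `xs` is a batch of input images on the same device as the model):

# usage sketch: run a batch through the block sequence and keep every
# intermediate feature map for later analysis (e.g., Fourier analysis)
feature_maps = []
x = xs
with torch.no_grad():
    for block in blocks:
        x = block(x)
        feature_maps.append(x)
logits = x  # output of the final (classifier) block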
# ViT-Ti
import copy
import timm
import torch
import torch.nn as nn
import models

# a. download and load a pretrained model for CIFAR-100
url = "https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/vit_ti_cifar100_9857b21357.pth.tar"
path = "checkpoints/vit_ti_cifar100_9857b21357.pth.tar"
models.download(url=url, path=path)

model = timm.models.vision_transformer.VisionTransformer(
    num_classes=100, img_size=32, patch_size=2,  # for CIFAR
    embed_dim=192, depth=12, num_heads=3, qkv_bias=False,  # for ViT-Ti 
)
model.name = "vit_ti"
models.stats(model)
map_location = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(path, map_location=map_location)
model.load_state_dict(checkpoint["state_dict"])


# b. model โ†’ blocks. `blocks` is a sequence of blocks

class PatchEmbed(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = copy.deepcopy(model)
        
    def forward(self, x, **kwargs):
        x = self.model.patch_embed(x)
        cls_token = self.model.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = self.model.pos_drop(x + self.model.pos_embed)
        return x


class Residual(nn.Module):
    def __init__(self, *fn):
        super().__init__()
        self.fn = nn.Sequential(*fn)
        
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x
    
    
class Lambda(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        
    def forward(self, x):
        return self.fn(x)


def flatten(xs_list):
    return [x for xs in xs_list for x in xs]


# model โ†’ blocks. `blocks` is a sequence of blocks
blocks = [
    PatchEmbed(model),
    *flatten([[Residual(b.norm1, b.attn), Residual(b.norm2, b.mlp)] 
              for b in model.blocks]),
    nn.Sequential(model.norm, Lambda(lambda x: x[:, 0]), model.head),
]
# PiT-Ti
import copy
import math
import timm
import models

import torch
import torch.nn as nn

# a. download and load a pretrained model for CIFAR-100
url = "https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/pit_ti_cifar100_0645889efb.pth.tar"
path = "checkpoints/pit_ti_cifar100_0645889efb.pth.tar"
models.download(url=url, path=path)

model = timm.models.pit.PoolingVisionTransformer(
    num_classes=100, img_size=32, patch_size=2, stride=1,  # for CIFAR-100
    base_dims=[32, 32, 32], depth=[2, 6, 4], heads=[2, 4, 8], mlp_ratio=4,  # for PiT-Ti
)
model.name = "pit_ti"
models.stats(model)
map_location = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(path, map_location=map_location)
model.load_state_dict(checkpoint["state_dict"])


# b. model โ†’ blocks. `blocks` is a sequence of blocks

class PatchEmbed(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = copy.deepcopy(model)
        
    def forward(self, x, **kwargs):
        x = self.model.patch_embed(x)
        x = self.model.pos_drop(x + self.model.pos_embed)
        cls_tokens = self.model.cls_token.expand(x.shape[0], -1, -1)

        return (x, cls_tokens)

    
class Concat(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = copy.deepcopy(model)
        
    def forward(self, x, **kwargs):
        x, cls_tokens = x
        B, C, H, W = x.shape
        token_length = cls_tokens.shape[1]

        x = x.flatten(2).transpose(1, 2)
        x = torch.cat((cls_tokens, x), dim=1)

        return x
    
    
class Pool(nn.Module):
    def __init__(self, block, token_length):
        super().__init__()
        self.block = copy.deepcopy(block)
        self.token_length = token_length
        
    def forward(self, x, **kwargs):
        cls_tokens = x[:, :self.token_length]
        x = x[:, self.token_length:]
        B, N, C = x.shape
        H, W = int(math.sqrt(N)), int(math.sqrt(N))
        x = x.transpose(1, 2).reshape(B, C, H, W)

        x, cls_tokens = self.block(x, cls_tokens)
        
        return x, cls_tokens
    
    
class Classifier(nn.Module):
    def __init__(self, norm, head):
        super().__init__()
        self.head = copy.deepcopy(head)
        self.norm = copy.deepcopy(norm)
        
    def forward(self, x, **kwargs):
        x = x[:,0]
        x = self.norm(x)
        x = self.head(x)
        return x

    
class Residual(nn.Module):
    def __init__(self, *fn):
        super().__init__()
        self.fn = nn.Sequential(*fn)
        
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

    
def flatten(xs_list):
    return [x for xs in xs_list for x in xs]


blocks = [
    nn.Sequential(PatchEmbed(model), Concat(model),),
    *flatten([[Residual(b.norm1, b.attn), Residual(b.norm2, b.mlp)] 
              for b in model.transformers[0].blocks]),
    nn.Sequential(Pool(model.transformers[0].pool, 1), Concat(model),),
    *flatten([[Residual(b.norm1, b.attn), Residual(b.norm2, b.mlp)] 
              for b in model.transformers[1].blocks]),
    nn.Sequential(Pool(model.transformers[1].pool, 1), Concat(model),),
    *flatten([[Residual(b.norm1, b.attn), Residual(b.norm2, b.mlp)] 
              for b in model.transformers[2].blocks]),
    Classifier(model.norm, model.head),
]
# Swin-Ti
import copy
import timm
import models

import torch
import torch.nn as nn

# a. download and load a pretrained model for CIFAR-100
url = "https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/swin_ti_cifar100_ec2894492b.pth.tar"
path = "checkpoints/swin_ti_cifar100_ec2894492b.pth.tar"
models.download(url=url, path=path)

model = timm.models.swin_transformer.SwinTransformer(
    num_classes=100, img_size=32, patch_size=1, window_size=4,  # for CIFAR-100
    embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), qkv_bias=False,  # for Swin-Ti
)
model.name = "swin_ti"
models.stats(model)
map_location = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(path, map_location=map_location)
model.load_state_dict(checkpoint["state_dict"])


# b. model โ†’ blocks. `blocks` is a sequence of blocks

class Attn(nn.Module):
    def __init__(self, block):
        super().__init__()
        self.block = copy.deepcopy(block)
        self.block.mlp = nn.Identity()
        self.block.norm2 = nn.Identity()
        
    def forward(self, x, **kwargs):
        x = self.block(x)  # with mlp and norm2 replaced by Identity, the block's second residual doubles the output
        x = x / 2          # so halve to recover x + attention output
        
        return x

class MLP(nn.Module):
    def __init__(self, block):
        super().__init__()
        block = copy.deepcopy(block)
        self.mlp = block.mlp
        self.norm2 = block.norm2
        
    def forward(self, x, **kwargs):
        x = x + self.mlp(self.norm2(x))

        return x

    
class Classifier(nn.Module):
    def __init__(self, norm, head):
        super().__init__()
        self.norm = copy.deepcopy(norm)
        self.head = copy.deepcopy(head)
        
    def forward(self, x, **kwargs):
        x = self.norm(x)
        x = x.mean(dim=1)
        x = self.head(x)

        return x

    
def flatten(xs_list):
    return [x for xs in xs_list for x in xs]


blocks = [
    model.patch_embed,
    *flatten([[Attn(block), MLP(block)] for block in model.layers[0].blocks]),
    model.layers[0].downsample,
    *flatten([[Attn(block), MLP(block)] for block in model.layers[1].blocks]),
    model.layers[1].downsample,
    *flatten([[Attn(block), MLP(block)] for block in model.layers[2].blocks]),
    model.layers[2].downsample,
    *flatten([[Attn(block), MLP(block)] for block in model.layers[3].blocks]),
    Classifier(model.norm, model.head)
]

Fourier Analysis of Representations

Refer to fourier_analysis.ipynb (Colab notebook) to analyze feature maps through the lens of the Fourier transform. Run all cells to visualize Fourier-transformed feature maps. Fourier analysis shows that MSAs reduce high-frequency signals, while Convs amplify high-frequency components.
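The core of the measurement is roughly as follows: take a 2-D Fourier transform of each feature map and compare the log amplitude at the lowest and highest normalized frequencies. The snippet below is a simplified sketch, not necessarily the notebook's exact code.

import torch

def delta_log_amplitude(feature_map):
    # feature_map: (B, C, H, W). Returns the mean difference between the log
    # amplitude at the highest frequency (boundary) and the lowest (center).
    f = torch.fft.fftshift(torch.fft.fft2(feature_map), dim=(-2, -1))
    log_amp = torch.log(f.abs() + 1e-12)
    h, w = log_amp.shape[-2:]
    center = log_amp[..., h // 2, w // 2]   # ~0.0π (lowest frequency)
    boundary = log_amp[..., 0, 0]           # ~1.0π (highest frequency)
    return (boundary - center).mean().item()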

Measuring Feature Map Variances

Refer to featuremap_variance.ipynb (Colab notebook) to measure feature map variance. Run all cells to visualize feature map variances. Feature map variance shows that MSAs aggregate feature maps, but Convs and MLPs diversify them.
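One plausible way to compute such a quantity is sketched below, assuming the variance is taken over the channel and spatial dimensions of each sample's feature map and then averaged over the batch; the notebook defines the exact measurement used in the paper.

def feature_map_variance(feature_map):
    # feature_map: (B, C, H, W). Variance of each sample's feature map values,
    # averaged over the batch (a simple proxy; see the notebook for the exact definition).
    b = feature_map.shape[0]
    return feature_map.reshape(b, -1).var(dim=1).mean().item()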

Visualizing the Loss Landscapes

Refer to losslandscape.ipynb (Colab notebook) or the original repo to explore the loss landscapes. Run all cells to get the predictive performance of the model on a weight-space grid. The loss landscape visualization shows that ViT has a flatter loss than ResNet.
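Conceptually, the visualization evaluates the loss on a two-dimensional grid in weight space, w + a*d1 + b*d2, for two direction vectors d1 and d2. The sketch below illustrates that grid evaluation with a hypothetical `evaluate_loss(model)` helper; the repository's loss_landscapes.py does the real work, including the L2 term discussed further below.

import copy
import numpy as np

def loss_grid(model, d1, d2, evaluate_loss, span=1.0, steps=11):
    # Evaluate the loss on a 2-D grid around the trained weights w0.
    # `d1` and `d2` are dicts of direction tensors keyed like the state dict;
    # `evaluate_loss(model)` is a hypothetical helper returning a scalar loss.
    w0 = copy.deepcopy(model.state_dict())
    coords = np.linspace(-span, span, steps)
    grid = np.zeros((steps, steps))
    for i, a in enumerate(coords):
        for j, b in enumerate(coords):
            w = dict(w0)
            for k in d1:                      # perturb only the listed weights
                w[k] = w0[k] + a * d1[k] + b * d2[k]
            model.load_state_dict(w)
            grid[i, j] = evaluate_loss(model)
    model.load_state_dict(w0)                 # restore the original weights
    return coords, grid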

Evaluating Robustness on Corrupted Datasets

Refer to robustness.ipynb (Colab notebook) for evaluating corruption robustness on corrupted datasets such as CIFAR-10-C and CIFAR-100-C. Run all cells to get the predictive performance of the model on datasets corrupted by 15 different corruption types at 5 levels of intensity each.
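For orientation, CIFAR-10-C and CIFAR-100-C ship as one .npy array per corruption type (10,000 images per severity level stacked along the first axis) plus a labels.npy file. The loop below is a rough evaluation sketch with a hypothetical `evaluate_accuracy` helper and a truncated corruption list.

import numpy as np

CORRUPTIONS = ["gaussian_noise", "shot_noise", "impulse_noise"]  # ... 15 types in total

def corruption_accuracies(model, root, evaluate_accuracy):
    # Accuracy per (corruption type, severity level) on CIFAR-10/100-C-style data.
    labels = np.load(f"{root}/labels.npy")
    results = {}
    for corruption in CORRUPTIONS:
        images = np.load(f"{root}/{corruption}.npy")
        for severity in range(1, 6):
            sl = slice((severity - 1) * 10000, severity * 10000)
            results[(corruption, severity)] = evaluate_accuracy(model, images[sl], labels[sl])
    return results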

How to Apply MSA to Your Own Model

We find that MSA complements Conv (rather than replacing it), and that an MSA closer to the end of a stage improves predictive performance significantly. Based on these insights, we propose the following build-up rules:

  1. Alternately replace Conv blocks with MSA blocks from the end of a baseline CNN model.
  2. If the added MSA block does not improve predictive performance, replace a Conv block located at the end of an earlier stage with an MSA.
  3. Use more heads and higher hidden dimensions for MSA blocks in late stages.

In the animation above, we replace Convs of ResNet with MSAs one by one according to the build-up rules. Note that several MSAs in c3 harm the accuracy, but the MSA at the end of c2 improves it. As a result, surprisingly, the model with MSAs following the appropriate build-up rule outperforms CNNs even in the small data regimes, e.g., CIFAR-100!
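Purely as a schematic illustration of the first rule (with hypothetical block placeholders, not the actual AlterNet code in models/alternet.py):

def apply_buildup_rule(stage_blocks, make_msa_block):
    # Rule 1 (schematic): replace the last Conv block of a stage with an MSA
    # block, keeping the earlier Conv blocks intact. If the new MSA block does
    # not help, apply the same replacement at the end of an earlier stage.
    return stage_blocks[:-1] + [make_msa_block()]

# e.g., a stage [Conv, Conv, Conv, Conv] becomes [Conv, Conv, Conv, MSA]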

Investigate Loss Landscapes and Hessians With L2 Regularization on Augmented Datasets

Two common mistakes are investigating loss landscapes and Hessians (1) 'without considering L2 regularization' on (2) 'clean datasets'. However, note that NNs are optimized with L2 regularization on augmented datasets. Therefore, it is appropriate to visualize 'NLL + L2' on 'augmented datasets'. Measuring criteria without L2 on clean datasets would give incorrect results.
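In code, that simply means evaluating the same objective the network was trained with. A sketch (the weight_decay value and the 1/2 convention should match your training configuration, and the batch (x_aug, y) should come from the augmented training pipeline):

def regularized_nll(model, nll_criterion, x_aug, y, weight_decay=5e-4):
    # NLL plus the L2 term, evaluated on an augmented batch, mirroring the
    # objective the network was actually optimized with.
    nll = nll_criterion(model(x_aug), y)
    l2 = sum((p ** 2).sum() for p in model.parameters())
    return nll + 0.5 * weight_decay * l2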

Citation

If you find this useful, please consider citing 📑 the paper and starring 🌟 this repository. Please do not hesitate to contact Namuk Park (email: namuk.park at gmail dot com, twitter: xxxnell) with any comments or feedback.

@inproceedings{park2022how,
  title={How Do Vision Transformers Work?},
  author={Namuk Park and Songkuk Kim},
  booktitle={International Conference on Learning Representations},
  year={2022}
}

License

All code is available to you under the Apache License 2.0. The CNN models build off the torchvision models, which are BSD licensed. The ViTs build off PyTorch Image Models and Vision Transformer - Pytorch, which are Apache 2.0 and MIT licensed, respectively.

Copyright the maintainers.

how-do-vits-work's People

Contributors

xxxnell


how-do-vits-work's Issues

code about robustness for noise frequency exp (Fig. 2b)

Hello, and thank you for the great paper full of insights.

I have been reading your paper and running several analyses with your code.
Among them, I would like to reproduce the robustness-to-noise-frequency analysis of Fig. 2b.
However, this part does not seem to be included in the code, so I am asking here.

It appears to rely on the FreqAttack class.
Could you share the experiment code that applies random noise at each frequency so that this experiment can be reproduced?

Thank you.

Conclusion about long-range dependency seems not true

The authors state in the paper that "Contrary to popular belief, the long-range dependency hinders NN optimization." However, recent models that adopt long-range dependency achieve really great results, e.g., VAN, ConvNeXt, and RepLKNet.

Therefore, the statement above seems a little bit wrong? I know there is an issue that discusses large-kernel Convs, but that issue did not address this statement.

Moreover, in the experiments in Fig. 7 you use Convolutional SANs. This model has two variants: 1D-CSANs and 2D-CSANs. The one you run the experiments on is 2D-CSANs, right? It considers not only the interaction among tokens within a single head but also the interaction among different heads. The "long-range dependency" is still very beneficial in 1D-CSANs (figure below), which is what I consider the true "long-range dependency" in self-attention.

[image]

When using 2D-CSANs, both aspects are considered, interaction among heads and among tokens, which brings negative performance when scaling up window sizes. The results align with the Convolutional SANs paper.

[image]

However, I do not consider the negative performance of 2D-CSANs when scaling up window sizes to mean that "long-range dependency hinders NN optimization," since it considers two aspects of the model. Sorry for writing this long; if any part of my question is unclear, I can clarify it for you.

relative log magnitude

Hello! How is the relative log magnitude calculated? Is the value of the first layer subtracted from that of each layer's feature map?

A question about the Hessian max eigenvalue spectra code

Hello,
Your paper got me interested in neural network visualization. Thank you for writing such a good paper.

The other visualization tutorials are nicely provided on GitHub, but for the Hessian max eigenvalue code it seems you send the related materials by email upon request, so I am leaving this issue. Could you send the relevant materials to the email below?

My email address is [email protected]. Thank you.

A question about Δ log amplitude

Hello. I read the paper with great interest; thank you for the great work.

I would like to ask a question about Δ log amplitude.

In the Figure 2(a) plot, I am a bit confused about what exactly the Δ log amplitude at a particular frequency means.

For example, in ResNet the Δ log amplitude at 0.5π is about -6, and I understood this -6 to be the relative difference between the amplitude at 0.0π and the amplitude at 0.5π.

However, looking at the sentence in Figure 2, "Δ log amplitude is the difference between the log amplitude at normalized frequency 0.0π (center) and at 1.0π (boundary)," I thought the Δ log amplitude simply means the relative difference between the amplitude at 0.0π and at 1.0π everywhere in the plot.

So my understanding of Δ log amplitude seems to differ from what is written in the paper; could you tell me whether I am misunderstanding something?

Thanks in advance.

Lesion study

Hi @xxxnell ,

I find it hard to understand the conclusions of the lesion study. For example, ViT does not seem to satisfy your conclusion (i.e., that the later MSAs are more important).
[image]

Can I get a guideline for hessian eigenvalue visualization?

Hi, I am a combined M.S./Ph.D. student at Yonsei University. I am very interested in your research and have been looking into your code. However, I could not find the code for the Hessian eigenvalues and realized that you do not share it at the moment.

I would really appreciate it if you could share the code or a guideline for Hessian eigenvalue visualization.

my email: [email protected]

Thank you.

How would the MSA build-up rules differ for upsampling stages?

  1. Alternately replace Conv blocks with MSA blocks from the end of a baseline CNN model.
  2. If the added MSA block does not improve predictive performance, replace a Conv block located at the end of an earlier stage with an MSA.
  3. Use more heads and higher hidden dimensions for MSA blocks in late stages.

I suppose the above rules apply to high-level computer vision tasks such as classification that involve only downsampling. I wonder how these rules differ for tasks involving upsampling stages, such as image generation from a latent or segmentation with a U-Net. In particular, I am interested in (1) the ordering of Conv and MSA blocks and (2) the number of heads and hidden dimensions in upsampling stages.

Based on your findings that Convs are high-pass and MSAs are low-pass filters, I suppose the ordering of Conv-MSA blocks should hold for both downsampling and upsampling stages instead of MSA-Conv blocks in upsampling.

Since downsampling stages usually reduce the spatial resolution and increase the channel dimension, the third rule makes sense.
However, since upsampling stages usually increase the spatial resolution and reduce the channel dimension, does the third rule still hold for upsampling? Or should it be flipped to fewer heads and lower hidden dimensions in late stages?

I would appreciate your valuable insights on applying these build-up rules to upsampling stages.

How do I implement Figure 2(b) in detail?

I have obtained the performance evaluation results of ResNet and ViT under different types and intensities of noise using the robustness code you provided. But I do not know how the results in Figure 2(b), on the effect of different noise frequencies on model performance, were obtained. Could you share the specific code?

AlterNet on CIFAR10

Hi, while trying to set up an alternet_18 to train on CIFAR-10, I used the default config in models/alternet.py, which is the following.

AlterNet(preresnet_dnn.BasicBlock, AttentionBasicBlockB, stem=partial(StemB, pool=stem),
         num_blocks=(2, 2, 2, 2), num_blocks2=(0, 1, 1, 1), heads=(3, 6, 12, 24),
         num_classes=num_classes, name=name, **block_kwargs)

Upon doing so I get the following error:

Input tensor shape: torch.Size([128, 128, 4, 4]). Additional info: {'p1': 7, 'p2': 7}.
Shape mismatch, can't divide axis of length 4 in chunks of 7

which is thrown by

x = rearrange(x, "b c (n1 p1) (n2 p2) -> (b n1 n2) c p1 p2", p1=p, p2=p)

in the LocalAttention class. This happens because the default window size is 7, which does not work for the 3 x 32 x 32 input images of CIFAR-10. Could you point me to the setup used to train AlterNet on CIFAR-10/100 images?
Thank you

how is robustness calculated?

Hi,

thank you for this wonderful work on vision transformers and how to understand them. I have some simple questions, which I must apologize for.
I tried to reproduce Figure 12 independently of your code base, but I struggle a bit to understand the code. Is it correct that you define robustness as robustness = mean(accuracy(y_val_true, y_val_pred))?
Related to this, do I understand correctly that you compute this accuracy on batches of the validation dataset? These batches are of size 256, right?

Thanks.

Question about a detail: the drop probability parameter sd

Thanks for your great work!
I noticed that you drop some values during training with sd=0.1.
Did you run any experiments analyzing the influence of the drop ratio?

Potential mistake in loss landscape visualization.

Hi, thanks for your great work. I'd like to discuss the L2 loss problem in the loss landscape visualization. I found that your calculated L2 loss is significantly (about 10x) larger than the classification loss, so the landscape visualization is basically a visualization of the L2 loss.
In fact, "weight decay" is implemented slightly differently from an explicit "L2 loss" in PyTorch. Simply adding the sum of squared norms as an L2 loss is not the same as applying weight decay in Adam-like optimizers in PyTorch. See the blog at https://bbabenko.github.io/weight-decay/.
So although one might find that the L2 loss is significantly larger than the classification loss, in practice the weight-decay loss of ViT does not dominate the classification loss, due to how weight decay is implemented in PyTorch.

Question about Figure 2(a)

[image]
Looking at this figure, I see that the early layers of ResNet have many low-frequency components, and the deeper ResNet goes, the more high-frequency components it contains. Am I interpreting this figure correctly?

If I am right, isn't this a little contradictory to popular belief and visualization, namely that early layers in a ConvNet tend to learn high-frequency components?

What are the attributes of a large-kernel CNN?

Great analysis! I wonder about the attributes of large-kernel CNNs. In your paper, the basic 3x3 ResNet is fully explored. A 3x3 Conv extracts detailed local patterns and thus may contribute to high-pass filtering. However, recent works investigate the effect of larger kernels. Might the attributes of the 3x3 ResNet change and become similar to ViT?

Findings not compatible with other work?

In Figure 1 of the paper, the authors state that MSA flattens the loss landscape. However, in "When Vision Transformers Outperform ResNets without Pre-training or Strong Data Augmentations," the authors state that ViT converges at sharp local minima, which contrasts with your findings?

Furthermore, you claim that "The magnitude of the Hessian eigenvalues of ViT is smaller than that of ResNet during training phase" (still Fig. 1). However, in the above paper, the "Hessian dominant eigenvalue" of ViT is "orders of magnitude larger than that of ResNet" (Table 1).

Loss landscape and Hessian max eigenvalue of your work: [image]

Loss landscape and Hessian max eigenvalue of the other work: [image]

In convit.py, where does ConViT come from, really?

"""
This ConViT is ViT with two-dimensional convolutional MSA, NOT [1]!
[1] d'Ascoli, Stéphane, et al. "Convit: Improving vision transformers with soft convolutional inductive biases."
arXiv preprint arXiv:2103.10697 (2021).
"""

You said it is not the same as the ConViT by d'Ascoli et al. Then where does this ConViT come from? I ask because if I reuse this code, I want to know whom I should cite.

Trained models

Could you make the trained models from this work accessible? Thank you very much in advance.

pretrained models

Can you provide the pretrained AlterNet model for ImageNet-1K-C? Thanks!

Why is feed-forward not present in the paper and the code?

The original ViT and many ViT variants have feed-forward blocks in their architectures. I noticed that feed-forward is neither mentioned in the paper nor implemented in the code of AlterNet. It would be interesting to learn about the intuition behind this design choice.

Understanding loss landscape

I understood that in the loss landscape visualization the z-axis is the NLL. I am curious what the x-axis and y-axis mean. Of course, we can see in loss_landscapes.py how the x and y values participate in the calculation, but I do not have an intuitive understanding of it.

    xs = np.linspace(x_min, x_max, n_x)
    ys = np.linspace(y_min, y_max, n_y)
    ratio_grid = np.stack(np.meshgrid(xs, ys), axis=0).transpose((1, 2, 0))
    print(ratio_grid)
    metrics_grid = {}
    for ratio in ratio_grid.reshape([-1, 2]):
        print(ratio)
        ws = copy.deepcopy(ws0)
        gs = [{k: r * bs[k] for k in bs} for r, bs in zip(ratio, bases)]
        gs = {k: torch.sum(torch.stack([g[k] for g in gs]), dim=0) + ws[k] for k in gs[0]}
        print(gs)
        model.load_state_dict(gs)

        print("Grid: ", ratio, end=", ")
        *metrics, cal_diag = tests.test(model, n_ff, dataset, transform=transform,
                                        cutoffs=cutoffs, bins=bins, verbose=verbose, period=period, gpu=gpu)
        l1, l2 = norm.l1(model, gpu).item(), norm.l2(model, gpu).item()
        metrics_grid[tuple(ratio)] = (l1, l2, *metrics)

    return metrics_grid

Thank you sincerely.

A question about the Hessian

Hello, I really enjoyed reading this paper.

Relatedly, I would like to reproduce the Hessian max eigenvalue spectra as in the paper. Where can I find the related code?

Thanks in advance for your reply!

model size

Hello, I have a question about why you use ViT-S and ViT-Ti while the counterpart is ResNet-50; these sizes are not equal. I know you have explained this on OpenReview. I would like to know whether ViT-Base's eigenvalue spectrum looks like ViT-Ti's in your paper, just stretched to the right.

Code for Alter-ResNet-50

Hi, awesome work and really good points about MSAs! I am very interested in the AlterNet mentioned in the paper (based on ResNet-50 and Swin Transformer blocks), but I cannot find its implementation in this repo. Did I miss it? If not, could you release the code?

Thanks a lot!

Running time for visualizing the loss landscape

Thank you for the amazing work! The summary organizes the main points of the paper really well and helps facilitate future research.

I would like to visualize the loss landscape for my model, and I am trying the Python notebook from this repository on Colab. Apparently it is a long process, and I would like a rough estimate of the time it will take for my model.

Could you kindly advise how long it took for you to run the notebook on Colab or on your machine?

ViT vs ResNet: Did you use SAM optimizer?

In Section 1.1 (Related Works), I see that you compare the loss landscapes of ResNet and ViT. Did you use the SAM optimizer when training both models for this comparison?
[image]

TensorFlow implementation

Does there exist a TF implementation of AlterNet?

It would be a great contribution to the field, as there are so many people who use TF, myself included.

What factors determine if a model or a layer behaves like a low- or high-pass filter?

Your paper reports that MSAs generally behave like low-pass filters (shape-biased) and Convs behave like high-pass filters (texture-biased). Recently I came across papers that report shape bias in their findings, and I wonder about your thoughts on them.

Low-pass filters (shape-biased)

High-pass filters (texture-biased)

  • 3x3 Convs in ResNet

These findings suggest that the factors affecting this behavior can be spatial aggregation, kernel size, training data, or training procedures. It seems that only 3x3 Convs behave like high-pass filters, or I may be missing something. In another thread you mentioned that group size also makes a difference. I wonder how ResNet and ResNeXt differ, and I suppose ResNeXt is also texture-biased.

I would appreciate your insights on what factors determine whether a model or a layer behaves like a low- or high-pass filter.

Question about harmonizing Convs with MSAs

In the paper the authors state that "MSAs are low-pass filters, but Convs are high-pass filters," and propose to harmonize Convs with MSAs by replacing Convs at (preferably) the end of a stage. The authors also suggest the idea of using Convs in early stages and MSAs in late stages.

Sorry in advance if these following questions of mine are dumb.

In the late stages, adding Convs after MSAs should decrease the performance of a model, right? Since the late stages produce low-frequency features, adding Convs there would suppress those features? I ran an experiment: I trained a hierarchical ViT, SegFormer, and then replaced the last-stage 1x1 Conv in the decoder with a 3x3 Conv (picture below).
[image]

I trained the model on a polyp segmentation dataset; the results are reported below:

Model Dice Score
Segformer 84.95
Modified Segformer 84.61

I have not tested whether replacing the 1x1 Conv in stages 1-2 with a 3x3 Conv increases the performance, but is my conclusion above correct?

Hi

When I run the forward function of the LocalAttention class, some errors occur.

x.shape = [1, 128, 84, 64] and self.window_size = 8.
The rearrange call cannot run correctly because n1 = 84 // 8 is not exact (84 is not divisible by 8).

If I change window_size to 7, 6, or 5, there are other images whose height or width is not divisible either.

I also tried setting window_size dynamically, but it did not succeed.

The images come from the COCO dataset.

Do you have any good suggestions?

The code is

      b, c, h, w = x.shape
      p = self.window_size
      n1 = h // p
      n2 = w // p

      mask = torch.zeros(p ** 2, p ** 2, device=x.device) if mask is None else mask
      mask = mask + self.pos_embedding[self.rel_index[:, :, 0], self.rel_index[:, :, 1]]

      x = rearrange(x, "b c (n1 p1) (n2 p2) -> (b n1 n2) c p1 p2", p1=p, p2=p)
      x, attn = self.attn(x, mask)
      x = rearrange(x, "(b n1 n2) c p1 p2 -> b c (n1 p1) (n2 p2)", n1=n1, n2=n2, p1=p, p2=p)

Frequency Analysis for MoCo-v3

Hi, thank you for your great work. We used your code to analyze the features of MoCo-v3 from the frequency perspective, but we obtained the following trend: [image]
I am a little bit confused, because I think there should be a decreasing trend.

How to plot the Trajectories in polar coordinates?

Hello, your work has taught me a lot. While reading the paper, I became very interested in how Figure 1(b) (trajectories in polar coordinates) was drawn. Could you open-source it?
Best regards

how to compute feature map variances?

hello,

Thank you for your great work!

I wonder how you get the feature map variances. According to my understanding, you first need to extract the representations of all the samples, which gives a vector of length D for each sample (let's just flatten the 2D tensor or concatenate all tokens). Then you calculate the variance of each element of this vector over all the samples, which gives D variances. Finally, you take the mean of these D variances and report that value.

Did I understand you correctly? Sorry if I missed it in your existing documentation or description.

Thank you, and I am looking forward to your reply.

Best,

What exactly makes MSAs data specificity?

In the paper, the authors state that "A key feature of MSAs is data specificity (not long-range dependency)."

Could you explain the "data specificity" part? What is it, and how does it behave?

Furthermore, could you elaborate (through visualizations, formulas, etc.) on how MSAs achieve data specificity?
