billpsomas / simpool

This repo contains the official implementation of ICCV 2023 paper "Keep It SimPool: Who Said Supervised Transformers Suffer from Attention Deficit?"

Home Page: https://arxiv.org/abs/2309.06891

License: Apache License 2.0

attention-mechanism computer-vision convolutional-neural-networks deep-learning neural-networks pooling vision-transformer self-supervised-learning supervised-learning

Keep It SimPool: Who Said Supervised Transformers Suffer from Attention Deficit?

Official PyTorch implementation and pretrained models for ICCV 2023 SimPool. [arXiv], [paper], [poster], [demo]

[Figure: SimPool illustration]

Overview

Motivation

  • Convolutional networks and vision transformers have different forms of pairwise interactions, pooling across layers and pooling at the end of the network. Does the latter really need to be different❓ What would happen if we completely discarded the [CLS]❓

  • As a by-product of pooling, vision transformers provide spatial attention for free, but this is most often of low quality unless self-supervised, which is not well studied. Is supervision really the problem❓ Can we obtain high-quality attention maps in a supervised setting❓

Approach

We introduce SimPool, a simple attention-based pooling method at the end of the network that yields clean attention maps under supervision or self-supervision, for both convolutional and transformer encoders.

  • Attention maps of ViT-S on ImageNet-1k:
    [Figure: SimPool attention maps]
    Note that when using SimPool with Vision Transformers, the [CLS] token is completely discarded.
  • Attention maps of ResNet-50 and ConvNeXt-S on ImageNet-1k:
    [Figure: SimPool attention maps]

📢 NOTE: Considering integrating SimPool into your workflow?
Use SimPool when you need high-quality attention maps that delineate object boundaries, or as an alternative pooling mechanism. It's super easy to try!

SimPool Attention Map Visualizer 🌌

Check out the SimPool interactive [demo] for attention map visualization:

[Figure: demo of the SimPool Attention Map Visualizer]

Integration

SimPool is by definition plug and play.

To integrate SimPool into any architecture (convolutional network or transformer) or any setting (supervised, self-supervised, etc.), follow the steps below:

1. Initialization (__init__ method):

from sp import SimPool

# this part goes into your model's __init__()
self.simpool = SimPool(dim, gamma=None) # dim is depth (channels)

NOTE: Remember to adapt the value of gamma to the architecture, e.g. gamma=2.0 for convolutional networks. Here we consider the simplest case, which does not use gamma (gamma=None).

2. Model Forward Pass (forward method):

Assuming the input tensor x has dimensions:

  • (B, d, H, W) for convolutional networks
  • (B, N, d) for transformers

where B = batch size, d = depth (channels), H = height of the feature map, W = width of the feature map, and N = number of patch tokens.

# this part goes into your model's forward()
cls = self.simpool(x) # (B, d)

NOTE: Remember to integrate the above code snippets into the appropriate locations in your model definition.
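
To make the two steps concrete, here is a minimal, self-contained sketch of where the pooling call sits in a model. It does not import the repository's `sp.SimPool`. Instead it uses `AttentionPoolSketch`, a hypothetical single-head stand-in written for this example (the mean of the features is the query, no gamma), so that the snippet runs on its own. `TinyConvNet` and its layer sizes are also illustrative assumptions. In a real project you would construct `SimPool(dim, gamma=...)` in place of the stand-in.

```python
import torch
import torch.nn as nn


class AttentionPoolSketch(nn.Module):
    """Hypothetical stand-in for sp.SimPool (NOT the repository's code).

    Single-head attention pooling: the average of all features acts as the
    query, the features themselves act as keys and values.
    """

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.scale = dim ** -0.5

    def forward(self, x):
        if x.dim() == 4:  # convolutional features: (B, d, H, W) -> (B, H*W, d)
            x = x.flatten(2).transpose(1, 2)
        q = self.wq(self.norm(x.mean(dim=1, keepdim=True)))  # (B, 1, d)
        v = self.norm(x)  # (B, N, d)
        k = self.wk(v)  # (B, N, d)
        attn = (q @ k.transpose(1, 2) * self.scale).softmax(dim=-1)  # (B, 1, N)
        return (attn @ v).squeeze(1)  # (B, d)


class TinyConvNet(nn.Module):
    """Illustrative model: a single conv layer, attention pooling, linear head."""

    def __init__(self, dim=32, num_classes=10):
        super().__init__()
        self.stem = nn.Conv2d(3, dim, kernel_size=3, stride=2, padding=1)
        self.simpool = AttentionPoolSketch(dim)  # step 1: create in __init__()
        self.head = nn.Linear(dim, num_classes)

    def forward(self, images):
        x = self.stem(images)  # (B, d, H, W)
        cls = self.simpool(x)  # step 2: pool in forward() -> (B, d)
        return self.head(cls)


# Transformer-style input: (B, N, d) patch tokens, no [CLS] token.
pool = AttentionPoolSketch(32)
pooled_tokens = pool(torch.randn(2, 49, 32))

# Convolutional input: images go through the stem, then the pooling layer.
logits = TinyConvNet()(torch.randn(2, 3, 16, 16))
```

The same pooling module accepts both layouts because the 4-D case is flattened to (B, N, d) before attention is computed, so the calling code in `forward()` is identical for convolutional networks and transformers.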

Experiments

We provide experiments on ImageNet in both supervised and self-supervised learning. Have a look at the respective folders for pre-trained models, reproduction recipes, etc.

Preliminaries

We use two different Anaconda environments, both utilizing PyTorch. For both, you will first need to download ImageNet.

Self-supervised learning environment

Create this environment for self-supervised learning experiments.

conda create -n simpoolself python=3.8 -y
conda activate simpoolself
pip3 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
pip3 install timm==0.3.2 tensorboardX six

Supervised learning environment

Create this environment for supervised learning experiments.

conda create -n simpoolsuper python=3.9 -y
conda activate simpoolsuper
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
pip3 install pyyaml
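
After creating either environment, it can help to confirm which package versions were actually installed. The following is a small sketch using only the Python standard library. The helper name `installed_versions` is an assumption made for this example and is not part of the repository.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package: version string, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Packages named in the two environment recipes above.
    wanted = ["torch", "torchvision", "timm", "tensorboardX", "six", "pyyaml"]
    for name, version in installed_versions(wanted).items():
        print(f"{name}: {version or 'not installed'}")
```

Run it inside the activated environment. For the self-supervised setup, the reported versions should match the pins in the install commands (for example timm 0.3.2).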

Acknowledgement

This repository is built using the Attmask, DINO, ConvNeXt, DETR, timm and Metrix repositories.

NTUA thanks NVIDIA for the support with the donation of GPU hardware. Bill thanks IARAI for the hardware support.

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Citation

If you find this repository useful, please consider giving it a star 🌟 and citing it:

@InProceedings{psomas2023simpool,
    author    = {Psomas, Bill and Kakogeorgiou, Ioannis and Karantzalos, Konstantinos and Avrithis, Yannis},
    title     = {Keep It SimPool: Who Said Supervised Transformers Suffer from Attention Deficit?},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2023},
    pages     = {5350-5360}
}
