Agent Attention

This repo contains the official PyTorch code and pre-trained models for Agent Attention.

Introduction

The attention module is the key component in Transformers. While the global attention mechanism offers robust expressiveness, its excessive computational cost constrains its applicability in various scenarios. In this paper, we propose a novel attention paradigm, Agent Attention, to strike a favorable balance between computational efficiency and representation power. Specifically, Agent Attention, denoted as a quadruple $(Q, A, K, V)$, introduces an additional set of agent tokens $A$ into the conventional attention module. The agent tokens first act as the agent for the query tokens $Q$ to aggregate information from $K$ and $V$, and then broadcast the information back to $Q$. Given that the number of agent tokens can be designed to be much smaller than the number of query tokens, agent attention is significantly more efficient than the widely adopted Softmax attention, while preserving global context modelling capability. Interestingly, we show that the proposed agent attention is equivalent to a generalized form of linear attention. Therefore, agent attention seamlessly integrates the powerful Softmax attention and the highly efficient linear attention.
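The two steps described above (agent aggregation, then agent broadcast) can be sketched in a few lines of PyTorch. This is a minimal single-head illustration, not the implementation shipped in this repo: the function names `softmax_attention` and `agent_attention`, the tensor shapes, and the random agent tokens are all assumptions made for the example.

```python
import torch


def softmax_attention(q, k, v):
    """Standard scaled dot-product attention: softmax(q k^T / sqrt(d)) v."""
    scale = q.shape[-1] ** -0.5
    return torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1) @ v


def agent_attention(q, a, k, v):
    """Agent attention on the quadruple (Q, A, K, V), single head, no bias terms.

    q, k, v: (batch, N, d) query/key/value tokens
    a:       (batch, n, d) agent tokens, with n much smaller than N
    """
    # Agent aggregation: the agents act as queries and collect information from K and V.
    agent_v = softmax_attention(a, k, v)      # (batch, n, d)
    # Agent broadcast: the agents act as keys, and the aggregated features as values.
    return softmax_attention(q, a, agent_v)   # (batch, N, d)


# Usage with hypothetical sizes: 196 image tokens, 49 agent tokens, 32 channels.
q, k, v = (torch.randn(2, 196, 32) for _ in range(3))
a = torch.randn(2, 49, 32)
out = agent_attention(q, a, k, v)
print(out.shape)
```

Neither step forms an $N \times N$ attention map; the largest intermediate matrices are $n \times N$ and $N \times n$. Regrouping the product as $(\sigma(QA^T)\,\sigma(AK^T))\,V$ gives the same result, which is the sense in which agent attention can be read as a generalized linear attention.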

Method

An illustration of our agent attention and agent attention module. (a) Agent attention uses agent tokens to aggregate global information and distribute it to individual image tokens, resulting in a practical integration of Softmax and linear attention. $\rm{\sigma}(\cdot)$ denotes the Softmax function. (b) Information flow of the agent attention module. As a showcase, we acquire agent tokens through pooling. Subsequently, the agent tokens are used to aggregate information from $V$, and $Q$ queries features from the agent features. In addition, agent bias and depthwise convolution (DWC) are adopted to add positional information and maintain feature diversity.
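The information flow in (b) can be approximated with the following simplified module. It is a sketch under stated assumptions, not the code in this repo: it is single-head, it omits the agent bias entirely, and the class name `SimpleAgentAttention`, the `agent_size` parameter, and the 3x3 depthwise kernel are hypothetical choices made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleAgentAttention(nn.Module):
    """Single-head agent attention with pooled agent tokens and a DWC branch (agent bias omitted)."""

    def __init__(self, dim, agent_size=7):
        super().__init__()
        self.scale = dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        # Agent tokens are obtained by pooling the query map to agent_size x agent_size.
        self.pool = nn.AdaptiveAvgPool2d(agent_size)
        # Depthwise convolution on V (groups == channels), assumed 3x3 here.
        self.dwc = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, h, w):
        b, n, c = x.shape  # expects n == h * w
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        # Pool Q over the spatial grid to get the agent tokens A: (b, agent_size**2, c).
        q_map = q.transpose(1, 2).reshape(b, c, h, w)
        a = self.pool(q_map).flatten(2).transpose(1, 2)

        # Agent aggregation: agents gather information from K and V.
        agent_v = F.softmax(a @ k.transpose(-2, -1) * self.scale, dim=-1) @ v
        # Agent broadcast: Q queries features from the agent features.
        out = F.softmax(q @ a.transpose(-2, -1) * self.scale, dim=-1) @ agent_v

        # DWC branch on V, added to the attention output.
        v_map = v.transpose(1, 2).reshape(b, c, h, w)
        out = out + self.dwc(v_map).flatten(2).transpose(1, 2)
        return self.proj(out)


# Usage with a hypothetical 14x14 feature map and 32 channels.
module = SimpleAgentAttention(dim=32, agent_size=7)
y = module(torch.randn(2, 14 * 14, 32), h=14, w=14)
print(y.shape)
```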

Results

Classification

See the agent_transformer folder for detailed documentation.

  • Comparison of different models on ImageNet-1K.

  • Accuracy-Runtime curve on ImageNet.

  • Increasing resolution to $\{256^2, 288^2, 320^2, 352^2, 384^2\}$.
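As a rough back-of-the-envelope illustration of why higher resolutions favour agent attention, the snippet below counts multiply-accumulates for the attention products only. The stride of 16, the 49 agent tokens, and the 64 channels are assumptions chosen for the example and do not describe any specific model in this repo; projections, DWC and bias terms are ignored.

```python
def softmax_attention_macs(num_tokens, dim):
    # Q K^T and the attention-weighted sum over V each cost N * N * d multiply-accumulates.
    return 2 * num_tokens * num_tokens * dim


def agent_attention_macs(num_tokens, num_agents, dim):
    # Agent aggregation (A K^T, then times V) and agent broadcast (Q A^T, then times
    # the agent features) are four products of N * n * d multiply-accumulates each.
    return 4 * num_tokens * num_agents * dim


for side in (256, 288, 320, 352, 384):
    num_tokens = (side // 16) ** 2  # assumed overall stride of 16
    ratio = softmax_attention_macs(num_tokens, 64) / agent_attention_macs(num_tokens, 49, 64)
    print(f"{side}x{side}: {num_tokens} tokens, Softmax/agent cost ratio {ratio:.2f}")
```

Under these assumptions the ratio simplifies to $N / (2n)$, so with a fixed number of agent tokens the saving grows linearly with the number of image tokens.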

Downstream tasks

See the detection and segmentation folders for detailed documentation.

AgentSD

When applied to Stable Diffusion, agent attention accelerates generation and substantially improves image quality without any additional training. See the agentsd folder for detailed documentation.

  • Quantitative Results of Stable Diffusion, ToMeSD and our AgentSD.

  • Samples generated by Stable Diffusion, ToMeSD ($r=0.4$) and AgentSD ($r=0.4$).

TODO

  • Classification
  • Segmentation
  • Detection
  • Agent Attention for Stable Diffusion

Acknowledgements

Our code is built on top of PVT, Swin Transformer, CSwin Transformer and ToMeSD.

Citation

If you find this repo helpful, please consider citing us.

@article{han2023agent,
  title={Agent Attention: On the Integration of Softmax and Linear Attention},
  author={Han, Dongchen and Ye, Tianzhu and Han, Yizeng and Xia, Zhuofan and Song, Shiji and Huang, Gao},
  journal={arXiv preprint arXiv:2312.08874},
  year={2023}
}

Contact

If you have any questions, please feel free to contact the authors.

Dongchen Han: [email protected]

Tianzhu Ye: [email protected]
