lucidrains / axial-attention

Implementation of Axial attention - attending to multi-dimensional data efficiently

License: MIT License

artificial-intelligence deep-learning attention-mechanism pytorch


Axial Attention


Implementation of Axial Attention in PyTorch. A simple but powerful technique for attending to multi-dimensional data efficiently. It has worked wonders for me and many other researchers.

Simply add some positional encoding to your data and pass it into this handy class, specifying which dimension is the embedding dimension and how many axial dimensions to rotate through. All the permuting and reshaping will be taken care of for you.

This paper was actually rejected on the basis of being too simple. And yet, it has since been used successfully in a number of applications, among them weather prediction and all-attention image segmentation. Just goes to show.

Install

$ pip install axial_attention

Usage

Image

import torch
from axial_attention import AxialAttention

img = torch.randn(1, 3, 256, 256)

attn = AxialAttention(
    dim = 3,               # embedding dimension
    dim_index = 1,         # where is the embedding dimension
    dim_heads = 32,        # dimension of each head. defaults to dim // heads if not supplied
    heads = 1,             # number of heads for multi-head attention
    num_dimensions = 2,    # number of axial dimensions (2 for images, 3 for video, and so on)
    sum_axial_out = True   # whether to sum the contributions of attention on each axis, or to run the input through them sequentially. defaults to true
)

attn(img) # (1, 3, 256, 256)

Channels-last image latents

import torch
from axial_attention import AxialAttention

img = torch.randn(1, 20, 20, 512)

attn = AxialAttention(
    dim = 512,           # embedding dimension
    dim_index = -1,      # where is the embedding dimension
    heads = 8,           # number of heads for multi-head attention
    num_dimensions = 2,  # number of axial dimensions (2 for images, 3 for video, and so on)
)

attn(img) # (1, 20, 20, 512)

Video

import torch
from axial_attention import AxialAttention

video = torch.randn(1, 5, 128, 256, 256)

attn = AxialAttention(
    dim = 128,           # embedding dimension
    dim_index = 2,       # where is the embedding dimension
    heads = 8,           # number of heads for multi-head attention
    num_dimensions = 3,  # number of axial dimensions (2 for images, 3 for video, and so on)
)

attn(video) # (1, 5, 128, 256, 256)

Image Transformer, with reversible network

import torch
from torch import nn
from axial_attention import AxialImageTransformer

conv1x1 = nn.Conv2d(3, 128, 1)

transformer = AxialImageTransformer(
    dim = 128,
    depth = 12,
    reversible = True
)

img = torch.randn(1, 3, 512, 512)

transformer(conv1x1(img)) # (1, 128, 512, 512)

With axial positional embedding

import torch
from axial_attention import AxialAttention, AxialPositionalEmbedding

img = torch.randn(1, 512, 20, 20)

attn = AxialAttention(
    dim = 512,
    heads = 8,
    dim_index = 1
)

pos_emb = AxialPositionalEmbedding(
    dim = 512,
    shape = (20, 20)
)

img = pos_emb(img)  # (1, 512, 20, 20)  - now positionally embedded
img = attn(img)     # (1, 512, 20, 20)

Citation

@misc{ho2019axial,
    title  = {Axial Attention in Multidimensional Transformers},
    author = {Jonathan Ho and Nal Kalchbrenner and Dirk Weissenborn and Tim Salimans},
    year   = {2019},
    eprint = {1912.12180},
    archivePrefix = {arXiv}
}
@misc{wang2020axialdeeplab,
    title   = {Axial-DeepLab: Stand-Alone Axial-Attention for Panoptic Segmentation},
    author  = {Huiyu Wang and Yukun Zhu and Bradley Green and Hartwig Adam and Alan Yuille and Liang-Chieh Chen},
    year    = {2020},
    eprint  = {2003.07853},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@inproceedings{huang2019ccnet,
    title   = {CCNet: Criss-Cross Attention for Semantic Segmentation},
    author  = {Huang, Zilong and Wang, Xinggang and Huang, Lichao and Huang, Chang and Wei, Yunchao and Liu, Wenyu},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision},
    pages   = {603--612},
    year    = {2019}
}


axial-attention's Issues

Positional embeddings for different image sizes

Hi,
once again, thanks for your great work! Since I want to use axial attention with positional embeddings for unknown image sizes (though I know the maximum size), I was wondering whether changing https://github.com/lucidrains/axial-attention/blob/master/axial_attention/axial_attention.py#L104 to

for cnt, param in enumerate(self.params):
    # crop each positional parameter to the input's actual size along its axis
    x = x + param[([slice(None)] * (cnt + 2) + [slice(x.shape[cnt + 2])])]

does the right thing. I can now do this:

v = AxialImageTransformer(64, depth = 1, axial_pos_emb_shape = (64,64), dim_index = 1)       
t1 = torch.randn(2, 64, 17, 16)
t2 = torch.randn(2, 64, 13, 18)
t3 = torch.randn(2, 64, 64, 64)
print(v(t1).shape)
print(v(t2).shape)
print(v(t3).shape)
Output:
torch.Size([2, 64, 17, 16])
torch.Size([2, 64, 13, 18])
torch.Size([2, 64, 64, 64])

I think that makes it easier to integrate into fully convolutional nets for multi-scale training.

Examples for image sequence/video

Hello,
Do you have examples of integrating this with image sequences?
I am trying to get rid of ConvLSTMs for encoding sequences of images, and AxialAttention may be a good starting point.
Do you have an example/notebook I could look at to integrate this with my type of data?
Thank you for this amazing work.
Thomas
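
For reference, a minimal sketch along the lines of the Video example in the README above; the shapes here are illustrative, not from the issue.

import torch
from axial_attention import AxialAttention

# a sequence of frames laid out as (batch, frames, channels, height, width);
# channels act as the embedding dimension, and attention rotates through the
# three axial dimensions (frames, height, width)
frames = torch.randn(1, 8, 64, 32, 32)

attn = AxialAttention(
    dim = 64,            # embedding (channel) dimension
    dim_index = 2,       # channels sit at index 2 in this layout
    heads = 4,
    num_dimensions = 3   # frames, height, and width
)

out = attn(frames)  # (1, 8, 64, 32, 32)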

Hi, I have a problem

import torch
from axial_attention import AxialAttention

img = torch.randn(1, 3, 256, 256)

attn = AxialAttention(
    dim = 3,               # embedding dimension
    dim_index = 1,         # where is the embedding dimension
    dim_heads = 32,        # dimension of each head. defaults to dim // heads if not supplied
    heads = 1,             # number of heads for multi-head attention
    num_dimensions = 2,    # number of axial dimensions (2 for images, 3 for video, and so on)
    sum_axial_out = True   # whether to sum the contributions of attention on each axis, or to run the input through them sequentially. defaults to true
)

attn(img) # (1, 3, 256, 256)

Thanks for your great project. I want to ask: if my image is a one-channel image, will that influence the num_dimensions value?
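
For what it's worth, num_dimensions counts the spatial axes rather than channels, so a one-channel image still uses num_dimensions = 2. A minimal sketch, where the 1x1 projection up to 32 channels is an illustrative choice rather than a requirement of the library:

import torch
from torch import nn
from axial_attention import AxialAttention

img = torch.randn(1, 1, 256, 256)  # one-channel image
proj = nn.Conv2d(1, 32, 1)         # lift the single channel to 32 for the heads

attn = AxialAttention(
    dim = 32,              # embedding dimension after the projection
    dim_index = 1,
    heads = 4,
    num_dimensions = 2     # still 2: height and width
)

out = attn(proj(img))  # (1, 32, 256, 256)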

AxialPositionalEmbedding

Would you be able to provide an example of how to add the positional encoding with the AxialPositionalEmbedding class or explain what the emb_dim, emb_dim_index, and dimensions arguments are specifically? Thanks for the repo!
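
A minimal sketch of what this looks like, following the With axial positional embedding example in the README above; dim is the embedding dimension, shape gives the size of each axial dimension, and emb_dim_index (visible in the constructor quoted in a later issue, defaulting to 1) locates the embedding dimension. Shapes here are illustrative.

import torch
from axial_attention import AxialPositionalEmbedding

x = torch.randn(1, 64, 32, 32)

pos_emb = AxialPositionalEmbedding(
    dim = 64,           # embedding dimension
    shape = (32, 32),   # size of each axial dimension
    emb_dim_index = 1   # where the embedding dimension sits (default 1)
)

x = pos_emb(x)  # (1, 64, 32, 32), with learned per-axis embeddings added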

User Warning: Mixed memory format inputs detected

At site-packages/axial_attention/axial_attention.py:176:
UserWarning: Mixed memory format inputs detected while calling the operator. The operator will output contiguous tensor even if some of the inputs are in channels_last format. (
Triggered internally at /opt/conda/conda-bld/pytorch_1595629427286/work/aten/src/ATen/native/TensorIterator.cpp:918.)
return sum(map(lambda axial_attn: axial_attn(x), self.axial_attentions))

I am using the latest axial_attention (v0.4) and PyTorch 1.6.0.

Code:

import torch
from axial_attention import AxialAttention

img = torch.randn(1, 24, 64, 64)

attn = AxialAttention(
    dim = 24,               # embedding dimension
    dim_index = 1,         # where is the embedding dimension
    dim_heads = 32,        # dimension of each head. defaults to dim // heads if not supplied
    heads = 8,             # number of heads for multi-head attention
    num_dimensions = 2,    # number of axial dimensions (2 for images, 3 for video, and so on)
    sum_axial_out = True   # whether to sum the contributions of attention on each axis, or to run the input through them sequentially. defaults to true
)

out = attn(img)

Will it affect trainings and inference?
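
To my reading, the warning quoted above concerns memory layout (channels_last vs. contiguous), not values, so it should not change training or inference results. A small sketch for inspecting and normalizing the layout, offered as a suggestion rather than something from this repo:

import torch

img = torch.randn(1, 24, 64, 64)
print(img.is_contiguous())                                     # True
print(img.is_contiguous(memory_format = torch.channels_last))  # False

out = attn(img).contiguous()  # force a standard contiguous layout on the result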

Extracting attention maps

Hi there,

Excellent project!

I'm using axial-attention with video (1, 5, 128, 256, 256) and sum_axial_out=True, and I wish to visualise the attention maps.

Essentially, given my video, and two frame indices frame_a_idx and frame_b_idx, I need to extract the attention map over frame_b to a chosen pixel (x, y) in frame_a (after the axial sum).

My understanding is that I should be able to reshape the dots (after softmax) according to the permutations in calculate_permutations, then sum these permuted dots together to form a final attention score tensor of an accessible shape, thus ready for visualisation.

I am slightly stuck due to the numerous axial permutations and shape mismatches. What I am doing is as follows:

In SelfAttention.forward():

dots_reshaped = dots.reshape(b, h, t, t)
return out, dots_reshaped

In PermuteToFrom.forward():

# attention
axial, dots = self.fn(axial, **kwargs)

# restore to original shape and permutation
axial = axial.reshape(*shape)
axial = axial.permute(*self.inv_permutation).contiguous()
dots = dots.reshape(*shape[:3], *dots.shape[1:])

However, I am unsure how to un-permute the dots appropriately such that all resulting "axes" (of different sizes) can be summed. If you have suggestions or code for doing so, it would be very much appreciated, thanks!
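
One possible direction, sketched for the 2D image case for simplicity. The dots shapes below are my assumption about the internals (the non-attended axis folded into the batch), so treat this as a starting point rather than working code for this repo:

import torch

# assumed shapes: row-attention dots (batch * H, heads, W, W),
# column-attention dots (batch * W, heads, H, H)
def attention_map_for_pixel(dots_row, dots_col, b, H, W, y, x):
    dr = dots_row.reshape(b, H, -1, W, W).mean(dim = 2)  # average heads -> (b, H, W, W)
    dc = dots_col.reshape(b, W, -1, H, H).mean(dim = 2)  # -> (b, W, H, H)

    amap = torch.zeros(b, H, W)
    amap[:, y, :] += dr[:, y, x, :]  # row attention reaches pixels in the same row
    amap[:, :, x] += dc[:, x, y, :]  # column attention reaches pixels in the same column
    return amap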

Axial-deeplab

Hello, is it possible to build Axial-DeepLab with the help of the provided axial-attention blocks? Any suggestions on how to do so?

Ask a question

I'm interested in your excellent work, but I'm new to PyTorch. Can I ask where the starting point in the code is, so that I can understand the whole project from there? Thanks for your reply.

Problem of ParameterList with nn.DataParallel

self.params = nn.ParameterList(parameters)

This line would lead to the following issue:
"UserWarning: nn.ParameterList is being used with DataParallel but this is not supported. This list will appear empty for the models replicated on each GPU except the original one."

It is a known issue here

The simple solution should be to store the Parameters directly on the Module.

import torch
from torch import nn

class AxialPositionalEmbedding(nn.Module):
    def __init__(self, dim, shape, emb_dim_index = 1):
        super().__init__()
        total_dimensions = len(shape) + 2
        ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]
        self.param_num = len(shape)

        for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
            # one broadcastable parameter per axial dimension, stored directly
            # on the module so DataParallel replicates it correctly
            param_shape = [1] * total_dimensions
            param_shape[emb_dim_index] = dim
            param_shape[axial_dim_index] = axial_dim
            setattr(self, f'param_{i}', nn.Parameter(torch.randn(*param_shape)))

    def forward(self, x):
        for i in range(self.param_num):
            x = x + getattr(self, f'param_{i}')
        return x
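
A quick sanity check of the workaround (shapes illustrative):

emb = AxialPositionalEmbedding(dim = 64, shape = (16, 16))
x = torch.randn(2, 64, 16, 16)
assert emb(x).shape == x.shape  # embeddings broadcast-add without changing the shape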
