lucidrains / axial-attention
Implementation of Axial Attention - attending to multi-dimensional data efficiently
License: MIT License
Hi there,
Excellent project!
I'm using axial-attention with video of shape (1, 5, 128, 256, 256) and sum_axial_out=True, and I wish to visualise the attention maps. Essentially, given my video and two frame indices frame_a_idx and frame_b_idx, I need to extract the attention map over frame_b to a chosen pixel (x, y) in frame_a (after the axial sum).
My understanding is that I should be able to reshape the dots
(after softmax) according to the permutations in calculate_permutations
, then sum these permuted dots together to form a final attention score tensor of an accessible shape, thus ready for visualisation.
I am slightly stuck due to the numerous axial permutations and shape mismatches. What I am doing is as follows:
In SelfAttention.forward():
dots_reshaped = dots.reshape(b, h, t, t)
return out, dots_reshaped
In PermuteToFrom.forward():
# attention
axial, dots = self.fn(axial, **kwargs)
# restore to original shape and permutation
axial = axial.reshape(*shape)
axial = axial.permute(*self.inv_permutation).contiguous()
dots = dots.reshape(*shape[:3], *dots.shape[1:])
However, I am unsure how to un-permute the dots appropriately such that all the resulting "axes" (of different sizes) can be summed. If you have suggestions or code for doing so, it would be much appreciated, thanks!
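For what it's worth, the un-permute step itself reduces to applying the inverse permutation that PermuteToFrom already stores as inv_permutation. A minimal self-contained sketch (the helper name invert_permutation is my own, not the library's):

```python
import torch

# Hypothetical helper (not part of the library): build the inverse of a
# permutation, mirroring what calculate_permutations' inv_permutation encodes.
def invert_permutation(perm):
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i
    return inv

perm = (0, 3, 1, 2)                      # e.g. move the last axis next to batch
inv = invert_permutation(perm)           # [0, 2, 3, 1]
x = torch.randn(2, 4, 8, 16)
restored = x.permute(*perm).permute(*inv)
assert torch.equal(x, restored)          # round-trip recovers the original layout
```

Applying the matching inverse permutation to each axis's reshaped dots should bring them back to a common layout before summing.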
Hi,
once again, thanks for your great work! Since I want to use axial attention with positional embeddings for unknown image sizes (though I know the maximum size), I was wondering whether changing https://github.com/lucidrains/axial-attention/blob/master/axial_attention/axial_attention.py#L104 to
for cnt, param in enumerate(self.params):
    x = x + param[tuple([slice(None)] * (cnt + 2) + [slice(x.shape[cnt + 2])])]
does the right thing. I can now do this:
v = AxialImageTransformer(64, depth = 1, axial_pos_emb_shape = (64,64), dim_index = 1)
t1 = torch.randn(2, 64, 17, 16)
t2 = torch.randn(2, 64, 13, 18)
t3 = torch.randn(2, 64, 64, 64)
print(v(t1).shape)
print(v(t2).shape)
print(v(t3).shape)
Output:
torch.Size([2, 64, 17, 16])
torch.Size([2, 64, 13, 18])
torch.Size([2, 64, 64, 64])
I think that makes it easier to integrate into fully convolutional nets for multi-scale training.
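A self-contained sketch of the cropping trick above, with shapes chosen to mirror the example (param stands in for one axial positional parameter; everything here is illustrative):

```python
import torch

# One axial positional parameter with a maximum height of 64, broadcastable
# over batch and width (shape is illustrative, not the library's exact layout).
param = torch.randn(1, 64, 64, 1)
x = torch.randn(2, 64, 17, 16)           # input smaller than the max shape

# Crop the parameter along its axial dimension to the input's actual size,
# then let broadcasting handle the remaining axes.
idx = tuple([slice(None)] * 2 + [slice(x.shape[2])])
out = x + param[idx]
print(out.shape)  # torch.Size([2, 64, 17, 16])
```

Since the size-1 axes broadcast, only the axial dimension needs cropping, which is exactly what the slice expression above does.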
This line (using nn.ParameterList) would lead to the following issue:
"UserWarning: nn.ParameterList is being used with DataParallel but this is not supported. This list will appear empty for the models replicated on each GPU except the original one."
It is a known issue here
The simple solution should be to store the Parameters directly on the Module.
class AxialPositionalEmbedding(nn.Module):
    def __init__(self, dim, shape, emb_dim_index = 1):
        super().__init__()
        total_dimensions = len(shape) + 2
        ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]

        for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
            # one broadcastable parameter per axial dimension
            param_shape = [1] * total_dimensions
            param_shape[emb_dim_index] = dim
            param_shape[axial_dim_index] = axial_dim
            parameter = nn.Parameter(torch.randn(*param_shape))
            # register directly on the module so DataParallel replicates it
            setattr(self, f'param_{i}', parameter)

        self.param_num = len(shape)

    def forward(self, x):
        for i in range(self.param_num):
            x = x + getattr(self, f'param_{i}')
        return x
I'm interested in your excellent work, but I'm new to PyTorch. May I ask where in the code I should start reading in order to understand the whole project? Thanks for your reply.
At site-packages/axial_attention/axial_attention.py:176:
UserWarning: Mixed memory format inputs detected while calling the operator. The operator will output contiguous tensor even if some of the inputs are in channels_last format. (
Triggered internally at /opt/conda/conda-bld/pytorch_1595629427286/work/aten/src/ATen/native/TensorIterator.cpp:918.)
return sum(map(lambda axial_attn: axial_attn(x), self.axial_attentions))
I am using the latest axial_attention (v0.4) and PyTorch 1.6.0.
Code:
import torch
from axial_attention import AxialAttention
img = torch.randn(1, 24, 64, 64)
attn = AxialAttention(
    dim = 24,               # embedding dimension
    dim_index = 1,          # where is the embedding dimension
    dim_heads = 32,         # dimension of each head. defaults to dim // heads if not supplied
    heads = 8,              # number of heads for multi-head attention
    num_dimensions = 2,     # number of axial dimensions (images is 2, video is 3, or more)
    sum_axial_out = True    # whether to sum the contributions of attention on each axis, or to run the input through them sequentially. defaults to true
)

out = attn(img)
Will it affect training and inference?
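As far as I understand, the warning is informational: it only concerns the memory layout of the output, not the numerical result. A hedged workaround sketch, using only standard PyTorch memory-format calls, is to convert any channels_last input back to the default contiguous format before calling the attention module:

```python
import torch

# Hedged sketch: if an input happens to be in channels_last memory format,
# converting it back to the default contiguous format before the attention
# call sidesteps the mixed-memory-format code path that triggers the warning.
img = torch.randn(1, 24, 64, 64).to(memory_format=torch.channels_last)
img = img.contiguous()                   # back to the default NCHW-contiguous layout
assert img.is_contiguous()
print(img.shape)  # torch.Size([1, 24, 64, 64])
```

The values of the tensor are unchanged by the conversion; only its stride layout differs.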
Would you be able to provide an example of how to add the positional encoding with the AxialPositionalEmbedding class or explain what the emb_dim, emb_dim_index, and dimensions arguments are specifically? Thanks for the repo!
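Not the author, but here is a minimal self-contained sketch of how such a module can be used, mirroring the signature discussed elsewhere in this thread (dim is the embedding size, shape is the spatial extent per axial dimension, and emb_dim_index is the axis holding the embedding). The class body is my own illustration, not the library's exact code:

```python
import torch
import torch.nn as nn

# Illustrative stand-in for AxialPositionalEmbedding (my own sketch):
# one broadcastable learned parameter per axial dimension.
class AxialPosEmbSketch(nn.Module):
    def __init__(self, dim, shape, emb_dim_index = 1):
        super().__init__()
        total_dims = len(shape) + 2      # batch + embedding + spatial axes
        axial_indexes = [i for i in range(1, total_dims) if i != emb_dim_index]
        self.params = nn.ParameterList()
        for axial_dim, axial_index in zip(shape, axial_indexes):
            param_shape = [1] * total_dims
            param_shape[emb_dim_index] = dim
            param_shape[axial_index] = axial_dim
            self.params.append(nn.Parameter(torch.randn(*param_shape)))

    def forward(self, x):
        for param in self.params:
            x = x + param                # broadcasts over all other axes
        return x

pos_emb = AxialPosEmbSketch(dim = 24, shape = (64, 64), emb_dim_index = 1)
x = torch.randn(1, 24, 64, 64)
out = pos_emb(x)
print(out.shape)  # torch.Size([1, 24, 64, 64])
```

Applied before the AxialAttention block, this adds learned positional information per axis while leaving the input shape unchanged.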
Thanks for sharing.
This is great work!
As TensorFlow is another major framework widely used in production environments, is there a TensorFlow version of this work?
Hello,
Do you have examples of integrating this on image sequences?
I am trying to get rid of ConvLSTM's for encoding sequence of images and AxialAttention may be a good starting point.
Do you have an example/notebook that I could look at to integrate this with my type of data?
Thank you for this amazing work.
Thomas
Hello, Is it possible to build axial-deeplab with help of the provided axial-attention blocks? Any suggestion on how to do so?
import torch
from axial_attention import AxialAttention
img = torch.randn(1, 3, 256, 256)
attn = AxialAttention(
    dim = 3,                # embedding dimension
    dim_index = 1,          # where is the embedding dimension
    dim_heads = 32,         # dimension of each head. defaults to dim // heads if not supplied
    heads = 1,              # number of heads for multi-head attention
    num_dimensions = 2,     # number of axial dimensions (images is 2, video is 3, or more)
    sum_axial_out = True    # whether to sum the contributions of attention on each axis, or to run the input through them sequentially. defaults to true
)
attn(img) # (1, 3, 256, 256)
Thanks for your great project. I want to ask: if my image is a one-channel image, will that influence the num_dimensions value?
Hi,
this is a nice paper.
How can I use your shared code to reimplement the image modeling task on ImageNet 32x32?
Thanks.
Looking forward to your reply.
Any examples of sampling / training?