landskape-ai / triplet-attention

Official PyTorch Implementation for "Rotate to Attend: Convolutional Triplet Attention Module." [WACV 2021]

Home Page: https://openaccess.thecvf.com/content/WACV2021/html/Misra_Rotate_to_Attend_Convolutional_Triplet_Attention_Module_WACV_2021_paper.html

License: MIT License

Python 5.18% Shell 0.24% Jupyter Notebook 94.55% TeX 0.03%
arxiv attention-mechanism attention-mechanisms computer-vision convolutional-neural-networks deep-learning detection gradcam imagenet paper triplet-attention

triplet-attention's People

Contributors

digantamisra98, iyaja, trikaynalamada

triplet-attention's Issues

error

line 293, in parse_model
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
TypeError: __init__() takes from 1 to 2 positional arguments but 3 were given

Hello, I tried to add this attention module to yolov5 and got this error. How should I fix it?
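
The traceback suggests that the model builder called the constructor with two positional arguments, while the `TripletAttention` definition in use accepts at most one (presumably `no_spatial`). One possible workaround, sketched here under that assumption, is a thin subclass that accepts and discards the extra arguments. `TripletAttention` below is a plain-Python stand-in so the snippet runs on its own; in a real project it would be the `nn.Module` from this repository. `TripletAttentionCompat` is a hypothetical name.

```python
class TripletAttention:
    """Stand-in for the repository's module (assumed signature)."""

    def __init__(self, no_spatial=False):
        self.no_spatial = no_spatial


class TripletAttentionCompat(TripletAttention):
    """Hypothetical adapter that tolerates channel-style arguments."""

    def __init__(self, *channel_args, no_spatial=False):
        # Assumption: the attention gates take no channel arguments, so any
        # positional values supplied by a model builder can be dropped.
        super().__init__(no_spatial=no_spatial)


# Mimic a builder that calls m(*args) with two values from a config file.
args = [256, 16]
module = TripletAttentionCompat(*args)
print(type(module).__name__, module.no_spatial)
```

Another option that may avoid the wrapper entirely is to give this layer an empty argument list in the model configuration, so that the builder calls the constructor with no positional arguments.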

Pretrained weights

I am using Res2Net, and I see that you do not provide pretrained weights for Res2Net + triplet attention. Will that affect the results?

triplet_attention.py

Which one works better: the triplet_attention.py in the root directory or the triplet_attention.py in the model folder?

data loader issue in the training stage

Thanks for your code!
When I try to add this attention mechanism into the resnet backbone and train a detection network, the error appears as:

File "anaconda3/envs/maskrcnn/lib/python3.7/site-packages/torch/utils/data/_utils/signal_handling.py", line 63, in handler
_error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 13791) is killed by signal: Aborted.

The batch size and num_workers are both set to 1, and it still fails.
Could you give some suggestions?

Pretrained weights and experimental setup for resnet20,32

Thank you for sharing your work. The interaction of information between different dimensions is very interesting, and it is very helpful for me.

When I tried to reproduce Triplet attention on resnet20,32, I ran into a small problem: the accuracy of both my baseline model and the model with Triplet attention was lower than the results reported in your paper. Could you share your pretrained weights and experimental setup?
Thanks a lot!

seems like the params are not used in `TripletAttention`'s definition?

in triplet_attention.py:

class TripletAttention(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(TripletAttention, self).__init__()
        self.ChannelGateH = SpatialGate()
        self.ChannelGateW = SpatialGate()
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()
    
    def forward(self, x):
        x_perm1 = x.permute(0, 2, 1, 3).contiguous()
        x_out1 = self.ChannelGateH(x_perm1)
        x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()
        x_perm2 = x.permute(0, 3, 2, 1).contiguous()
        x_out2 = self.ChannelGateW(x_perm2)
        x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()
        if not self.no_spatial:
            x_out = self.SpatialGate(x)
            x_out = (1 / 3) * (x_out + x_out11 + x_out21)
        else:
            x_out = (1 / 2) * (x_out11 + x_out21)
        return x_out

seems like gate_channels, reduction_ratio, pool_types are never used?

No multiprocessing option

Hi,
In the usage instructions, I could not find a way to enable multi-GPU training. I have attached a screenshot.

image

I have a question

When the operation order was changed from cw > hc > hw to hw > cw > hc, performance improved in certain models. Calculations are performed independently of each other. Do you know why this is?

thank you.

A little question about adding attention to the network

Thank you for sharing your work, which is very helpful to me.

I have a small question after studying it.

Is the attention layer added after every block? Is there a theoretical difference between that and adding it only at the last layer?

Pre-training weights loading

Hi,

No matter which triplet-attention definition I use, I can't load the pretrained weights.

Could you please let me know which triplet-attention version matches the ImageNet pretrained weights?
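
When weights refuse to load, a first diagnostic step is to compare the parameter names the model expects with the names stored in the checkpoint. The following is a minimal sketch; `compare_keys` is a hypothetical helper and the key names are illustrative, not taken from the released checkpoints. With PyTorch, the two lists would typically come from `model.state_dict().keys()` and from the keys of the loaded checkpoint's `state_dict`.

```python
def compare_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, both sorted.

    missing:    names the model defines but the checkpoint lacks
    unexpected: names the checkpoint holds but the model does not define
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Illustrative names only; real keys depend on the model definition.
model_keys = ["conv1.weight", "layer1.0.triplet.ChannelGateH.conv.conv.weight"]
checkpoint_keys = ["conv1.weight", "layer1.0.triplet.cw.conv.conv.weight"]

missing, unexpected = compare_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

If the two lists differ only in a consistent prefix or attribute name, the mismatch is most likely a naming difference between code versions rather than a different architecture.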

load error

model = ResidualNet( 'ImageNet', 18, 1000, 'TripletAttention')
state_dict = torch.load("model_best.pth.tar")['state_dict']  # ResNet-18 + Triplet Attention (k=3)
model.load_state_dict(state_dict)

Unexpected key(s) in state_dict: "module.conv1.weight", "module.bn1.weight", "module.bn1.bias", "module.bn1.running_mean", "module.bn1.running_var", "module.bn1.num
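
Keys that start with `module.` usually indicate that the checkpoint was saved from a model wrapped in `torch.nn.DataParallel` (or `DistributedDataParallel`), while the model being loaded is not wrapped. A common workaround, sketched here on a plain dictionary so it runs without PyTorch, is to strip that prefix before loading. `strip_prefix` is a hypothetical helper name and the values are placeholders for tensors.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from keys."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Placeholder values stand in for the real weight tensors.
checkpoint = OrderedDict([("module.conv1.weight", 0), ("module.bn1.weight", 1)])
cleaned = strip_prefix(checkpoint)
print(list(cleaned))  # ['conv1.weight', 'bn1.weight']
```

In the snippet from this issue, that would amount to calling `model.load_state_dict(strip_prefix(state_dict))`. Wrapping the model in `torch.nn.DataParallel` before loading is an alternative that leaves the checkpoint untouched.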

Another question:
does k=3 mean that the SpatialGate module's kernel size is 3?

wandb error

Hi,
I ran into the following error. Has anyone had a similar experience?
Thanks.
image

mmdetection_fasterrcnn

Hello, I applied triplet attention to Faster R-CNN using mmdetection. The problem is that inference becomes very slow, roughly twice as slow. Why does this happen? Also, where is the code you used to apply triplet attention to Faster R-CNN with mmdetection? I'd like to learn from it. Thanks!

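How to modify it for one-step replacement of SE and CBAM

As the title says, is there a way to directly swap it in for SE and CBAM in one step?

Unlike SE, the triplet_attention module is not initialized with a channels argument.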

# SE attention
class SE(nn.Module):
    def __init__(self, in_channels, channels, se_ratio=12):
        super(SE, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, channels // se_ratio, kernel_size=1, padding=0),
            nn.BatchNorm2d(channels // se_ratio),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // se_ratio, channels, kernel_size=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.fc(y)
        return x * y

Improvement of Triple Attention

  1. First of all, thank you very much for triplet attention. I verified it on a ReID task and it works really well. The drawback is that it runs slowly.
  2. I also tried applying the triplet approach to SE, and the results were okay as well.
  3. For this feature map, the mean is taken along the H dimension. Looking at a single channel on its own, as shown in the figure, the outermost channel contains a person, so the mean is computed over the person's features from head to toe. That seems unreasonable to me. Could you describe the inspiration behind this design, or explain it? I am confused by it.

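Loading the pretrained model

The names in the model definition do not match the names in the pretrained model's state_dict, so the weights cannot be loaded.
Specifically, the original definitions should use the following names:
1. In class SpatialGate():
self.conv = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
2. In class TripletAttention():
self.cw = SpatialGate()
self.hc = SpatialGate()
self.hw = SpatialGate()
3. In the resnet block definition:
self.triplet = TripletAttention(planes, 16)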
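
If editing the model definition is not convenient, the checkpoint keys can be renamed instead. The sketch below renames individual dot-separated segments of each key. The pairing in `ASSUMED_MAPPING` (`cw` to `ChannelGateH`, `hc` to `ChannelGateW`, `hw` to `SpatialGate`) is an assumption based only on the order in which the gates are declared in the two definitions quoted on this page; it should be checked against the forward pass before use, because all three gates have identical shapes and a wrong pairing would load without any error. The helper name and the example keys are hypothetical.

```python
def rename_segments(state_dict, mapping):
    """Rename dot-separated key segments according to mapping."""
    renamed = {}
    for key, value in state_dict.items():
        parts = [mapping.get(part, part) for part in key.split(".")]
        renamed[".".join(parts)] = value
    return renamed


# Assumed pairing of checkpoint names to local attribute names (unverified).
ASSUMED_MAPPING = {"cw": "ChannelGateH", "hc": "ChannelGateW", "hw": "SpatialGate"}

# Illustrative keys with placeholder values instead of tensors.
checkpoint = {
    "layer1.0.triplet.cw.conv.conv.weight": "w0",
    "layer1.0.triplet.hw.conv.bn.bias": "b0",
}
converted = rename_segments(checkpoint, ASSUMED_MAPPING)
for name in converted:
    print(name)
```

Matching whole segments rather than substrings avoids accidentally rewriting unrelated names that happen to contain `cw`, `hc`, or `hw`.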
