landskape-ai / triplet-attention
Official PyTorch Implementation for "Rotate to Attend: Convolutional Triplet Attention Module." [WACV 2021]
License: MIT License
line 293, in parse_model
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
TypeError: __init__() takes from 1 to 2 positional arguments but 3 were given
Hello, I tried to add this attention module to yolov5 and got this error. How should I fix it?
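The message suggests that the m(*args) call in the traceback passes two positional arguments to a constructor that accepts only one optional argument. The sketch below reproduces that situation with plain-Python stand-ins and shows one possible adapter. GateLike, ChannelTolerantGate and build are hypothetical names, and the assumption that the extra arguments are channel counts is not confirmed by the thread.

```python
class GateLike:
    """Hypothetical stand-in for an attention module whose constructor
    accepts a single optional argument (like no_spatial)."""

    def __init__(self, no_spatial=False):
        self.no_spatial = no_spatial


class ChannelTolerantGate(GateLike):
    """Hypothetical adapter: accepts and discards leading positional
    arguments (assumed to be channel counts) that the gate does not need."""

    def __init__(self, *channel_args, no_spatial=False):
        super().__init__(no_spatial=no_spatial)
        self.ignored_args = channel_args


def build(module_cls, args):
    # Mirrors the m(*args) call shown in the traceback.
    return module_cls(*args)


try:
    build(GateLike, [64, 128])  # two positionals for a one-argument constructor
    failed = False
except TypeError:
    failed = True

adapted = build(ChannelTolerantGate, [64, 128])
print(failed, adapted.ignored_args, adapted.no_spatial)
```

The same idea would apply to an nn.Module subclass: either wrap the module so that it tolerates the extra arguments, or change the model config so that none are passed.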
triplet-attention.py
Hi, I ran this file and got inconsistent dimensions.
I wonder if there is a problem here:
torch.cat( (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1 )
Should it be torch.cat( (torch.max(x, 1).unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1 )?
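On this question: when torch.max is called with a dim argument, it returns a (values, indices) pair rather than a single tensor, so the [0] index is needed before calling unsqueeze. The check below is a small sketch that assumes PyTorch is installed; the tensor shape is arbitrary.

```python
import torch

# Arbitrary (N, C, H, W) input for illustration.
x = torch.arange(24, dtype=torch.float32).reshape(1, 2, 3, 4)

result = torch.max(x, 1)   # reduce over the channel dimension
values, indices = result   # a (values, indices) pair, not a tensor

# The pair itself has no unsqueeze method, so indexing with [0] is required.
print(hasattr(result, "unsqueeze"))

max_map = values.unsqueeze(1)              # same as torch.max(x, 1)[0].unsqueeze(1)
mean_map = torch.mean(x, 1).unsqueeze(1)
pooled = torch.cat((max_map, mean_map), dim=1)
print(tuple(pooled.shape))
```

If the dimensions still disagree with the original line, the cause is probably elsewhere, for example in the shape of the input tensor.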
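I am using res2net, and I see that you don't provide pretrained weights for res2net + triplet_attention. Will that have an impact?
Triplet Attention (k = 7): what does the parameter 'k' mean?
Which performs better: the triplet_attention.py in the root directory or the triplet_attention.py under the model folder?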
Thanks for your code!
When I add this attention mechanism to the resnet backbone and train a detection network, the following error appears:
File "anaconda3/envs/maskrcnn/lib/python3.7/site-packages/torch/utils/data/_utils/signal_handling.py", line 63, in handler
_error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 13791) is killed by signal: Aborted.
The batch size and num_workers are both set to 1, and it still fails.
Could you give some suggestions?
Thank you for sharing your work. The interaction of information between different dimensions is very interesting and very helpful to me.
When I tried to reproduce Triplet Attention on ResNet-20 and ResNet-32, I ran into a small problem: the accuracy of both my baseline model and the model with Triplet Attention was lower than the results reported in your paper. Could you share your pretrained weights and experimental setup?
Thanks a lot!
In triplet_attention.py:
class TripletAttention(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(TripletAttention, self).__init__()
        self.ChannelGateH = SpatialGate()
        self.ChannelGateW = SpatialGate()
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_perm1 = x.permute(0, 2, 1, 3).contiguous()
        x_out1 = self.ChannelGateH(x_perm1)
        x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()
        x_perm2 = x.permute(0, 3, 2, 1).contiguous()
        x_out2 = self.ChannelGateW(x_perm2)
        x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()
        if not self.no_spatial:
            x_out = self.SpatialGate(x)
            x_out = (1 / 3) * (x_out + x_out11 + x_out21)
        else:
            x_out = (1 / 2) * (x_out11 + x_out21)
        return x_out
It seems that gate_channels, reduction_ratio, and pool_types are never used?
When the operation order was changed from cw > hc > hw to hw > cw > hc, performance improved in certain models, even though the three branches are computed independently of each other. Do you know why this is?
Thank you.
Thank you for sharing your work, which is very helpful to me.
I have a small question after studying it:
Is the attention layer added after each block? Is there a theoretical difference between this and adding it only at the last layer?
Hi,
No matter which definition of triplet attention I use, I can't load the pretrained weights.
Could you please let me know which triplet-attention version matches the ImageNet pretrained weights?
model = ResidualNet('ImageNet', 18, 1000, 'TripletAttention')
state_dict = torch.load("model_best.pth.tar")['state_dict']  # ResNet-18 + Triplet Attention (k = 3)
model.load_state_dict(state_dict)
Unexpected key(s) in state_dict: "module.conv1.weight", "module.bn1.weight", "module.bn1.bias", "module.bn1.running_mean", "module.bn1.running_var", "module.bn1.num
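The module. prefix on every unexpected key is what torch.nn.DataParallel adds when a wrapped model is saved. One common workaround is to strip that prefix before loading. The sketch below uses a plain dict so that it runs without a checkpoint; strip_prefix is a hypothetical helper, and the thread does not confirm that this alone resolves the mismatch for these weights.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Toy stand-in for a checkpoint's state_dict; the values would be tensors in practice.
saved = OrderedDict([
    ("module.conv1.weight", 0),
    ("module.bn1.weight", 1),
    ("module.bn1.bias", 2),
])

cleaned = strip_prefix(saved)
print(list(cleaned))
```

With a real checkpoint this would be used as model.load_state_dict(strip_prefix(state_dict)). Wrapping the model in torch.nn.DataParallel before loading is the other usual option.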
Another question:
does k=3 mean that the SpatialGate module's kernel size is 3?
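If k is the kernel size of the convolution inside the gate, then the padding of (kernel_size-1) // 2 that appears in the BasicConv call quoted elsewhere on this page keeps the spatial size unchanged for odd k. The arithmetic can be checked with the standard convolution output-size formula. This is a sketch: same_padding and conv_output_size are hypothetical helper names, and the input size of 56 is arbitrary.

```python
def same_padding(kernel_size):
    """Padding from the quoted BasicConv call: (kernel_size - 1) // 2."""
    return (kernel_size - 1) // 2


def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Standard output-size formula for a convolution with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


for k in (3, 7):
    p = same_padding(k)
    # Prints the kernel size, its padding, and the resulting output size.
    print(k, p, conv_output_size(56, k, padding=p))
```

Under this reading, k = 3 and k = 7 differ only in the receptive field of the gate's convolution, not in the shape of the attention map.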
Hello, I applied triplet attention to Faster R-CNN using mmdetection. The problem is that inference becomes very slow, about twice as slow. Why does this happen? Also, where is the code in which you applied triplet attention to Faster R-CNN using mmdetection? I'd like to learn from it. Thanks~
Hi, I'd like to ask how the AP value is calculated. I can't find it in the code. I also don't know how the Recall value is calculated.
Thank you very much.
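For orientation, the generic definitions can be sketched in plain Python: recall is the fraction of ground-truth objects found so far, precision is the fraction of detections so far that are correct, and AP is the area under the resulting precision-recall curve. This is not the evaluation code used for the reported results; detection benchmarks normally rely on their own tooling, and the function names and toy data below are hypothetical.

```python
def precision_recall_curve(scores, is_true_positive, num_ground_truth):
    """Precision and recall after each detection, in descending score order."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    true_positives = 0
    precisions, recalls = [], []
    for rank, i in enumerate(order, start=1):
        if is_true_positive[i]:
            true_positives += 1
        precisions.append(true_positives / rank)
        recalls.append(true_positives / num_ground_truth)
    return precisions, recalls


def average_precision(precisions, recalls):
    """Area under the precision-recall curve (all-point interpolation)."""
    # Make precision non-increasing as recall grows.
    envelope = list(precisions)
    for i in range(len(envelope) - 2, -1, -1):
        envelope[i] = max(envelope[i], envelope[i + 1])
    area, previous_recall = 0.0, 0.0
    for precision, recall in zip(envelope, recalls):
        area += (recall - previous_recall) * precision
        previous_recall = recall
    return area


# Toy data: four detections, two ground-truth objects.
precisions, recalls = precision_recall_curve(
    scores=[0.9, 0.8, 0.7, 0.6],
    is_true_positive=[True, False, True, False],
    num_ground_truth=2,
)
ap = average_precision(precisions, recalls)
print(recalls[-1], round(ap, 4))
```

Whether a detection counts as a true positive depends on an IoU threshold against the ground-truth boxes, which this sketch takes as given.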
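As the title says, is there a way to use this as a direct drop-in replacement for SE and CBAM?
Unlike SE, the triplet_attention module is not initialized with a channels argument.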
# SE attention
class SE(nn.Module):
    def __init__(self, in_channels, channels, se_ratio=12):
        super(SE, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, channels // se_ratio, kernel_size=1, padding=0),
            nn.BatchNorm2d(channels // se_ratio),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // se_ratio, channels, kernel_size=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.fc(y)
        return x * y
The names used in the model definition do not match the names in the pretrained model's state_dict, so the weights cannot be loaded.
Specifically, the original definitions should use the following names:
1. In class SpatialGate():
self.conv = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
2. In class TripletAttention():
self.cw = SpatialGate()
self.hc = SpatialGate()
self.hw = SpatialGate()
3. In the resnet block definition:
self.triplet = TripletAttention(planes, 16)
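An alternative to editing the model definition is to rename the checkpoint keys so that they match the attribute names in the code being used. The sketch below renames individual dotted components of each key. rename_keys is a hypothetical helper, the toy key is made up, and the pairing of cw, hc, hw with ChannelGateH, ChannelGateW, SpatialGate is an assumption based only on the order in which they are declared, so it should be checked against the actual forward pass before use.

```python
def rename_keys(state_dict, renames):
    """Return a copy of state_dict with dotted key components replaced per `renames`."""
    renamed = {}
    for key, value in state_dict.items():
        parts = [renames.get(part, part) for part in key.split(".")]
        renamed[".".join(parts)] = value
    return renamed


# ASSUMED pairing (checkpoint name -> name in the model code); verify before use.
ASSUMED_RENAMES = {
    "cw": "ChannelGateH",
    "hc": "ChannelGateW",
    "hw": "SpatialGate",
}

# Made-up key for illustration; real checkpoint keys may be structured differently.
toy_state_dict = {"layer1.0.triplet.cw.conv.conv.weight": 0}

converted = rename_keys(toy_state_dict, ASSUMED_RENAMES)
print(list(converted))
```

Only whole components are replaced, so a key that merely contains "hw" inside a longer name is left alone.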