Comments (3)
Hi @fawazsammani,
thanks for your interest in this work! The lines of code you are citing basically implement Eq. 6 and 8 of the paper (i.e. the cross-attention in which each decoder layer attends over all encoder layers). Since this is a cross-attention, queries always come from the decoder, while keys/values come from the encoder (see Eq. 7).
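In code, Eq. 6 and 8 boil down to gating each encoder layer's cross-attention output with a learned sigmoid weight and summing; a minimal, self-contained sketch with assumed shapes (this is an illustration, not the repository's actual implementation):

```python
import torch
import torch.nn as nn

class MeshedCrossAttention(nn.Module):
    """Weighted sum of cross-attentions over all encoder layers (cf. Eq. 6-8).

    Queries come from the decoder; keys/values come from each encoder layer.
    Illustrative sketch only: a single shared attention module is used for
    brevity, and shapes are assumed, not taken from the repository.
    """
    def __init__(self, d_model, n_heads, n_enc_layers):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        # One gate per encoder layer: alpha_i = sigmoid(W_i [query ; C(X_i)])
        self.gates = nn.ModuleList(
            nn.Linear(2 * d_model, d_model) for _ in range(n_enc_layers)
        )

    def forward(self, query, enc_outputs):
        # query: (batch, tgt_len, d_model)
        # enc_outputs: list of (batch, src_len, d_model), one per encoder layer
        out = 0.0
        for mem, gate in zip(enc_outputs, self.gates):
            c, _ = self.attn(query, mem, mem)              # cross-attention C(X_i)
            alpha = torch.sigmoid(gate(torch.cat([query, c], dim=-1)))
            out = out + alpha * c                          # element-wise gating
        return out / len(self.gates) ** 0.5                # normalize by sqrt(N)
```

The element-wise gates let the decoder decide, per position and per channel, how much each encoder layer contributes, which is what distinguishes the weighted sum from a plain sum.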
Regarding the 1-to-1 experiment: that's not included in the code, but it could be implemented by setting enc_att = enc_att_x (where x is the index of the current decoding layer) here, and then removing all the now-unnecessary lines.
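As a sketch of that 1-to-1 variant, where decoder layer x attends only to encoder layer x instead of to all of them (hypothetical helper, assuming memory stacks all encoder layer outputs):

```python
import torch

def one_to_one_cross_attention(src_attn, query, memory, layer_idx, src_mask=None):
    """1-to-1 variant: the decoder layer with index `layer_idx` attends only to
    the encoder layer with the same index.

    Hypothetical helper, not repository code. `memory` is assumed to stack the
    encoder layer outputs as (batch, num_layers, num_boxes, d_model).
    """
    mem = memory[:, layer_idx]  # keys/values from the matching encoder layer only
    return src_attn(query, mem, mem, src_mask)
```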
Finally, yes, we tested a plain sum but found that it leads to worse results than the weighted sum.
If you have other doubts, do not hesitate to contact us!
Best,
Lorenzo.
from meshed-memory-transformer.
Hi @baraldilorenzo. Thanks for your reply. I just want to ask about the meshed decoder, which significantly improves the results; I think it should be a default step in every transformer. Are there other important details needed to make it work? I am using the harvardnlp transformer code: I added the meshed decoder part, but the performance is worse than the plain base transformer. Are there specific hyperparameters you set to make it perform as expected (other than 3 layers and 10000 warmup steps)? You can have a look at the snippet below:
import math

import torch
import torch.nn as nn

# `clones` and `SublayerConnection` are taken from the harvardnlp transformer code.

class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)
        # One gate per encoder layer (size == d_model here)
        self.fc_alpha1 = nn.Linear(size * 2, size)
        self.fc_alpha2 = nn.Linear(size * 2, size)
        self.fc_alpha3 = nn.Linear(size * 2, size)

    def forward(self, x, memory, src_mask, tgt_mask):
        # memory shape: (batch_size, num_layers, num_boxes, d_model)
        query = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        # Cross-attention against each encoder layer's output
        enc_att1 = self.src_attn(query, memory[:, 0], memory[:, 0], src_mask)
        enc_att2 = self.src_attn(query, memory[:, 1], memory[:, 1], src_mask)
        enc_att3 = self.src_attn(query, memory[:, 2], memory[:, 2], src_mask)
        # Sigmoid gates conditioned on the query and each attention output
        alpha1 = torch.sigmoid(self.fc_alpha1(torch.cat([query, enc_att1], -1)))
        alpha2 = torch.sigmoid(self.fc_alpha2(torch.cat([query, enc_att2], -1)))
        alpha3 = torch.sigmoid(self.fc_alpha3(torch.cat([query, enc_att3], -1)))
        sum_src_attn = (enc_att1 * alpha1 + enc_att2 * alpha2
                        + enc_att3 * alpha3) / math.sqrt(3)
        # Residual + norm around the cross-attention; the sublayer expects a
        # callable, so wrap the precomputed result in a lambda
        out_sum = self.sublayer[1](query, lambda _: sum_src_attn)
        return self.sublayer[2](out_sum, self.feed_forward)
where each sublayer performs: self.norm(x + self.dropout(sublayer(x)))
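For completeness, that residual wrapper can be sketched as follows, under the assumption that the second argument is a callable applied to the input (illustrative post-norm sketch, not the harvardnlp original):

```python
import torch
import torch.nn as nn

class SublayerConnection(nn.Module):
    """Residual connection followed by layer norm (post-norm, as stated above).

    `sublayer` must be a callable mapping (batch, len, size) tensors to
    tensors of the same shape. Illustrative sketch only.
    """
    def __init__(self, size, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # norm(x + dropout(sublayer(x)))
        return self.norm(x + self.dropout(sublayer(x)))
```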
Thanks a lot for your help and your brilliant work!
Sorry to bother you, but when I tried to use the meshed decoder in the harvardnlp transformer code, I found that the memory's shape is (batch_size, num_boxes, d_model) and does not contain the num_layers dimension. How do you get the num_layers dimension in memory?
Many thanks for your reply!
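For reference, the extra num_layers dimension comes from stacking the output of every encoder layer instead of returning only the last one; a minimal sketch of such an encoder (hypothetical, not the repository's code):

```python
import torch
import torch.nn as nn

class StackingEncoder(nn.Module):
    """Encoder that returns the outputs of *all* layers stacked along dim 1,
    producing memory of shape (batch, num_layers, num_boxes, d_model).

    Hypothetical sketch: `layers` is any stack of modules mapping
    (batch, num_boxes, d_model) -> (batch, num_boxes, d_model).
    """
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        outs = []
        for layer in self.layers:
            x = layer(x)
            outs.append(x)                 # keep every layer's output
        return torch.stack(outs, dim=1)    # (batch, num_layers, num_boxes, d_model)
```

A standard encoder only returns the final `x`; keeping the intermediate outputs is the one change needed so the decoder can index `memory[:, i]` per encoder layer.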