
Comments (3)

baraldilorenzo commented on July 30, 2024

Hi @fawazsammani,

thanks for your interest in this work! The lines of code you are citing implement Eq. 6 and 8 in the paper (i.e., the cross-attention in which each decoder layer attends to all encoder layers). Since this is a cross-attention, the queries always come from the decoder, while the keys and values come from the encoder (see Eq. 7).
Regarding the 1-to-1 experiment: that is not included in the code, but it could be implemented by setting enc_att = enc_attx (where x is the index of the current decoding layer) here and then removing all the unnecessary lines.
Finally, yes, we tested a plain sum but found that it leads to worse results than the weighted sum.
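
For concreteness, the difference between the gated sum of Eq. 6 and 8 and the plain-sum ablation can be sketched like this (illustrative PyTorch, not the repository's exact code; the function name and signature are made up for the example):

import math

import torch


def meshed_sum(query, enc_atts, fc_alphas, weighted=True):
    # query:     (batch, seq_len, d_model) decoder queries
    # enc_atts:  list of N cross-attention outputs, one per encoder layer,
    #            each of shape (batch, seq_len, d_model)
    # fc_alphas: list of N nn.Linear(2 * d_model, d_model) gate projections
    outs = []
    for enc_att, fc_alpha in zip(enc_atts, fc_alphas):
        if weighted:
            # Eq. 6/8: per-channel sigmoid gate conditioned on [query; enc_att]
            alpha = torch.sigmoid(fc_alpha(torch.cat([query, enc_att], dim=-1)))
            outs.append(alpha * enc_att)
        else:
            # Plain-sum ablation: every encoder layer contributes equally
            outs.append(enc_att)
    return sum(outs) / math.sqrt(len(outs))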

If you have other doubts, do not hesitate to reach out!

Best,
Lorenzo.


fawazsammani commented on July 30, 2024

Hi @baraldilorenzo. Thanks for your reply. I just want to ask about the meshed decoder, since it significantly improves the results; I think it should be a default step in every transformer. Are there other important details needed to make it work? I am using the harvardnlp transformer code: I added the meshed decoder part, but the performance is worse than the plain base transformer. Are there specific hyperparameters you set to make it perform as expected (other than 3 layers and a warmup of 10000 steps)? You can have a look at the snippet below:

import math

import torch
import torch.nn as nn


class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)
        # One gate per encoder layer; input is [query; enc_att], so 2 * size
        # (size == d_model in the harvardnlp code)
        self.fc_alpha1 = nn.Linear(size * 2, size)
        self.fc_alpha2 = nn.Linear(size * 2, size)
        self.fc_alpha3 = nn.Linear(size * 2, size)

    def forward(self, x, memory, src_mask, tgt_mask):
        # memory shape: (batch_size, num_layers, num_boxes, d_model)
        query = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))

        def meshed_attention(q):
            # Cross-attend to each encoder layer's output separately
            enc_att1 = self.src_attn(q, memory[:, 0], memory[:, 0], src_mask)
            enc_att2 = self.src_attn(q, memory[:, 1], memory[:, 1], src_mask)
            enc_att3 = self.src_attn(q, memory[:, 2], memory[:, 2], src_mask)
            # Sigmoid gates weight each layer's contribution per channel
            alpha1 = torch.sigmoid(self.fc_alpha1(torch.cat([q, enc_att1], -1)))
            alpha2 = torch.sigmoid(self.fc_alpha2(torch.cat([q, enc_att2], -1)))
            alpha3 = torch.sigmoid(self.fc_alpha3(torch.cat([q, enc_att3], -1)))
            return (enc_att1 * alpha1 + enc_att2 * alpha2 + enc_att3 * alpha3) / math.sqrt(3)

        # Residual + norm around the meshed cross-attention, then feed-forward
        out_sum = self.sublayer[1](query, meshed_attention)
        return self.sublayer[2](out_sum, self.feed_forward)

where the sublayer performs: self.norm(x + self.dropout(sublayer(x)))
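
For reference, here is a minimal sketch of the clones and SublayerConnection helpers assumed in the snippet. Note it uses the post-norm form given above, whereas (if I remember correctly) the original harvardnlp code applies the norm before the sublayer; this is an illustration of my setup, not the exact upstream source:

import copy

import torch.nn as nn


def clones(module, N):
    # N independent deep copies of a module, registered as submodules
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class SublayerConnection(nn.Module):
    # Residual connection followed by layer norm (post-norm, as described above)
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return self.norm(x + self.dropout(sublayer(x)))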

Thanks a lot for your help and your brilliant work


HN123-123 commented on July 30, 2024

Sorry to bother you, but when I tried to use the meshed decoder with the harvardnlp transformer code, I found that the memory's shape is (batch_size, num_boxes, d_model) and does not contain the num_layers dimension. I would like to know how you get the num_layers dimension in the memory.
Many thanks for your reply!!!
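
In case it helps clarify what I mean: my guess is that the encoder should stack every layer's output instead of returning only the last one, something like the sketch below (the class name MeshedEncoder and the shared final norm are my guesses, not the repository's code). Is this the right way to build the (batch_size, num_layers, num_boxes, d_model) memory?

import torch
import torch.nn as nn


class MeshedEncoder(nn.Module):
    # Encoder variant that keeps every layer's output so the meshed
    # decoder can attend to all of them (illustrative guess)
    def __init__(self, layer, N):
        super(MeshedEncoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = nn.LayerNorm(layer.size)

    def forward(self, x, mask):
        outs = []
        for layer in self.layers:
            x = layer(x, mask)
            outs.append(self.norm(x))
        # shape: (batch_size, num_layers, num_boxes, d_model)
        return torch.stack(outs, dim=1)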

