
m-e-r-c-u-r-y / pytorch-transformers


Collection of different types of transformers for learning purposes

License: MIT

Language: Jupyter Notebook 100.00%
Topics: pytorch, transformers, multi-head-attention, multi-query-attention, einsum-notation

pytorch-transformers's People

Contributors

m-e-r-c-u-r-y


pytorch-transformers's Issues

question about encoder mask

I have a small question: why is the attention returned by the encoder only masked along the columns (the key positions), and not along the rows as well?
Here is my simplified code:

import torch

class Encoder(torch.nn.Module):
    
    def __init__(self, input_dim, hid_dim, n_layers, n_heads, dropout, device, max_length):
        
        super(Encoder, self).__init__()
        
        self.device = device
        self.tok_embedding = torch.nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = torch.nn.Embedding(max_length, hid_dim)
        self.layers = torch.nn.ModuleList([EncoderLayer(hid_dim, n_heads, dropout, device) for _ in range(n_layers)])
        self.dropout = torch.nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim]))
        
    def forward(self, src, src_mask):
        
        # src = [batch size, src len]
        # src_mask = [batch size, src len]
        
        batch_size = src.shape[0]
        src_len = src.shape[1]
        pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
        
        # pos = [batch size, src len]
        
        src = self.dropout((self.tok_embedding(src) * self.scale.to(self.device)) + self.pos_embedding(pos))
        
        # src = [batch size, src len, hid dim]
        
        for layer in self.layers:
            src, attention = layer(src, src_mask)
        
        # src = [batch size, src len, hid dim]
        
        return src, attention

class EncoderLayer(torch.nn.Module):
    
    def __init__(self, hid_dim, n_heads, dropout, device):
        
        super(EncoderLayer, self).__init__()
        
        self.self_attn_layer_norm = torch.nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.dropout = torch.nn.Dropout(dropout)
    
    def forward(self, src, src_mask):
        
        # src = [batch size, src len, hid dim]
        # src_mask = [batch size, src len]
                
        # self attention
        _src, attention = self.self_attention(src, src, src, src_mask)
        
        # dropout, residual connection and layer norm
        src = self.self_attn_layer_norm(src + self.dropout(_src))
        
        return src, attention

class MultiHeadAttentionLayer(torch.nn.Module):
    
    def __init__(self, hid_dim, n_heads, dropout, device):
        
        super(MultiHeadAttentionLayer, self).__init__()
        
        assert hid_dim % n_heads == 0
        
        self.device = device
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = torch.nn.Linear(hid_dim, hid_dim)
        self.fc_k = torch.nn.Linear(hid_dim, hid_dim)
        self.fc_v = torch.nn.Linear(hid_dim, hid_dim)
        self.fc_o = torch.nn.Linear(hid_dim, hid_dim)
        self.dropout = torch.nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim]))
        
    def forward(self, query, key, value, mask = None):
        
        batch_size = query.shape[0]
        
        # query = [batch size, query len, hid dim]
        # key = [batch size, key len, hid dim]
        # value = [batch size, value len, hid dim]
        
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        
        # Q = [batch size, query len, hid dim]
        # K = [batch size, key len, hid dim]
        # V = [batch size, value len, hid dim]
        
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        
        # Q = [batch size, n heads, query len, head dim]
        # K = [batch size, n heads, key len, head dim]
        # V = [batch size, n heads, value len, head dim]
        
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale.to(self.device)
        
        # energy = [batch size, n heads, query len, key len]
        
        if mask is not None:
            energy.masked_fill_(mask == 0, -1e10)
        
        attention = torch.softmax(energy, dim = -1)
        
        # attention = [batch size, n heads, query len, key len]
        
        x = torch.matmul(self.dropout(attention), V)
        
        # x = [batch size, n heads, query len, head dim]
        
        x = x.permute(0, 2, 1, 3).contiguous()
        
        # x = [batch size, query len, n heads, head dim]
        
        x = x.view(batch_size, -1, self.hid_dim)
        
        # x = [batch size, query len, hid dim]
        
        x = self.fc_o(x)
        
        # x = [batch size, query len, hid dim]
        
        return x, attention

class TransformerModel(torch.nn.Module):
    
    def __init__(self, encoder, src_pad_idx, device):
        
        super(TransformerModel, self).__init__()
        
        self.encoder = encoder
        self.src_pad_idx = src_pad_idx
        self.device = device
        
    def make_src_mask(self, src):
        
        # src = [batch size, src len]
        
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

        # src_mask = [batch size, 1, 1, src len]
        
        return src_mask

    def forward(self, src):
        
        # src = [batch size, src len]
        
        src_mask = self.make_src_mask(src.to(self.device))
        
        # src_mask = [batch size, 1, 1, src len]
        
        enc_src, attention = self.encoder(src, src_mask)
        
        return attention



model = TransformerModel(encoder = Encoder(input_dim = 7,
                                           hid_dim = 64,
                                           n_layers = 1,
                                           n_heads = 2,
                                           dropout = 0.1,
                                           device = 'cpu',
                                           max_length = 5),
                         src_pad_idx = 0,
                         device = 'cpu').to('cpu')
a = torch.tensor([[0,2,3,5,6],[1,5,4,0,0]])
model(a)

The output is:

tensor([[[[0.0000e+00, 3.7932e-07, 3.8124e-06, 1.0000e+00, 1.3747e-13],
          [0.0000e+00, 5.6091e-01, 7.3318e-11, 2.5095e-07, 4.3909e-01],
          [0.0000e+00, 6.7948e-17, 7.8028e-05, 6.2536e-01, 3.7456e-01],
          [0.0000e+00, 4.3038e-12, 2.1656e-12, 1.0000e+00, 7.0719e-12],
          [0.0000e+00, 2.6063e-07, 1.5446e-15, 9.9998e-01, 2.1456e-05]],

         [[0.0000e+00, 4.9789e-08, 7.9109e-01, 2.0891e-01, 1.2142e-07],
          [0.0000e+00, 1.6303e-07, 4.0151e-17, 1.0000e+00, 2.5143e-24],
          [0.0000e+00, 3.7450e-13, 1.9566e-04, 1.3678e-06, 9.9980e-01],
          [0.0000e+00, 2.0732e-01, 7.9268e-01, 3.0749e-12, 4.6080e-20],
          [0.0000e+00, 2.9789e-15, 7.2438e-16, 1.0000e+00, 4.2854e-06]]],


        [[[1.0000e+00, 2.2995e-07, 1.5590e-12, 0.0000e+00, 0.0000e+00],
          [7.3918e-05, 9.9993e-01, 4.2441e-11, 0.0000e+00, 0.0000e+00],
          [9.9459e-01, 5.2655e-03, 1.4139e-04, 0.0000e+00, 0.0000e+00],
          [5.3461e-21, 3.1996e-03, 9.9680e-01, 0.0000e+00, 0.0000e+00],
          [5.7396e-22, 3.0495e-07, 1.0000e+00, 0.0000e+00, 0.0000e+00]],

         [[1.0733e-06, 1.0000e+00, 1.8439e-27, 0.0000e+00, 0.0000e+00],
          [1.0000e+00, 1.2462e-19, 3.5610e-20, 0.0000e+00, 0.0000e+00],
          [5.8192e-05, 9.9994e-01, 4.8438e-15, 0.0000e+00, 0.0000e+00],
          [9.1263e-04, 1.6909e-01, 8.2999e-01, 0.0000e+00, 0.0000e+00],
          [9.8007e-01, 2.1476e-06, 1.9926e-02, 0.0000e+00, 0.0000e+00]]]],
       grad_fn=<SoftmaxBackward>)
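
For reference, the mask that make_src_mask builds here has shape [batch size, 1, 1, src len], so when it is broadcast against energy ([batch size, n heads, query len, key len]) it only covers the key (column) dimension:

# quick check of the padding mask for the input above (src_pad_idx = 0)
print(model.make_src_mask(a))
# -> [[[[False,  True,  True,  True,  True]]],
#     [[[ True,  True,  True, False, False]]]]    shape: [2, 1, 1, 5]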

I assume padding_idx = 0. I think we should mask the padding positions in both the rows and the columns, not only the columns.
I think the correct output would look like this (a sketch of the masking I have in mind follows the expected output below):

tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 5.6091e-01, 7.3318e-11, 2.5095e-07, 4.3909e-01],
          [0.0000e+00, 6.7948e-17, 7.8028e-05, 6.2536e-01, 3.7456e-01],
          [0.0000e+00, 4.3038e-12, 2.1656e-12, 1.0000e+00, 7.0719e-12],
          [0.0000e+00, 2.6063e-07, 1.5446e-15, 9.9998e-01, 2.1456e-05]],

         [[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 1.6303e-07, 4.0151e-17, 1.0000e+00, 2.5143e-24],
          [0.0000e+00, 3.7450e-13, 1.9566e-04, 1.3678e-06, 9.9980e-01],
          [0.0000e+00, 2.0732e-01, 7.9268e-01, 3.0749e-12, 4.6080e-20],
          [0.0000e+00, 2.9789e-15, 7.2438e-16, 1.0000e+00, 4.2854e-06]]],


        [[[1.0000e+00, 2.2995e-07, 1.5590e-12, 0.0000e+00, 0.0000e+00],
          [7.3918e-05, 9.9993e-01, 4.2441e-11, 0.0000e+00, 0.0000e+00],
          [9.9459e-01, 5.2655e-03, 1.4139e-04, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],

         [[1.0733e-06, 1.0000e+00, 1.8439e-27, 0.0000e+00, 0.0000e+00],
          [1.0000e+00, 1.2462e-19, 3.5610e-20, 0.0000e+00, 0.0000e+00],
          [5.8192e-05, 9.9994e-01, 4.8438e-15, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]]]],
       grad_fn=<SoftmaxBackward>)
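
To make the idea concrete, here is a minimal sketch of how such a mask could be built (this is only my assumption of how it might look, not the repository's implementation):

# hypothetical sketch: mask the padded positions in both rows and columns
# (assumes the same src_pad_idx = 0 convention as above)
pad_mask = (a != 0)                              # [batch size, src len]
row_mask = pad_mask.unsqueeze(1).unsqueeze(3)    # [batch size, 1, src len, 1]
col_mask = pad_mask.unsqueeze(1).unsqueeze(2)    # [batch size, 1, 1, src len]
full_mask = row_mask & col_mask                  # [batch size, 1, src len, src len]
# passing full_mask instead of src_mask would apply the mask to the padded rows
# as well as the padded columns before the softmax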

Thank you very much!
