
Comments (2)

ak422 commented on July 19, 2024

And for backward, it should be node j <= i, so the code should be:

order_mask_backward = torch.einsum('ji, bqi, bpj->bqp',
    (1 - torch.triu(torch.ones(mask_size, mask_size, device=device), diagonal=1)),
    permutation_matrix_reverse, permutation_matrix_reverse)

from proteinmpnn.
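A quick sanity check of the proposed 'ji, bqi, bpj->bqp' form: with `diagonal=1`, entry [q, p] of the result should be 1 exactly when `decoding_order[q] <= decoding_order[p]`, with both axes indexed by residue position. This sketch builds the expected mask independently of the einsum and compares (the random decoding order and size here are illustrative, not from ProteinMPNN):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
mask_size = 6
# decoding_order[position] = that position's rank in the decoding order
decoding_order = torch.randperm(mask_size)
permutation_matrix_reverse = (
    F.one_hot(decoding_order, num_classes=mask_size).float().unsqueeze(0))

order_mask_backward = torch.einsum(
    'ji, bqi, bpj->bqp',
    1 - torch.triu(torch.ones(mask_size, mask_size), diagonal=1),
    permutation_matrix_reverse, permutation_matrix_reverse)

# With diagonal=1, entry [q, p] is 1 iff decoding_order[q] <= decoding_order[p]
# (so each residue also "sees" itself). Both q and p index residue positions.
expected = (decoding_order[:, None] <= decoding_order[None, :]).float()
assert torch.equal(order_mask_backward[0], expected)
```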

MaoSihong commented on July 19, 2024

Hello, I am studying your ProteinMPNN framework. While reading your script protein_mpnn_utils.py, I am confused about the variable order_mask_backward, which is defined at line 1085 and used at line 1086:

order_mask_backward = torch.einsum('ij, biq, bjp->bqp',(1-torch.triu(torch.ones(mask_size,mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)
mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1) 

As I understand it (based on line 1086, i.e., the second line above), order_mask_backward _should be_ a tensor that records, for each residue, which residues are decoded before it in the reversed order (a value of 1 meaning those residues are visible to it). In that case, the indices along both dim=-1 and dim=-2 represent residue positions, so order_mask_backward can be gathered with E_idx (E_idx is a tensor recording, for each residue, which residues are identified as its neighbors, with shape [num_batch, num_residues, num_neighbors]).

However, as I understand it, order_mask_backward as defined in the first line above instead records, for an order pair (q, p), whether there exists a corresponding residue pair (i, j) with i > j; the value is 1 if such a pair exists and 0 otherwise. Here, q and p are the indices of order_mask_backward along dim=-2 and dim=-1 respectively, and i and j are the residues' positions in the sequence.

To clarify, here is a simple example.

import torch
import torch.nn.functional as F

num = 4 # num of residues
a = torch.Tensor([2,3,0,1]).long() # random decoding order, i.e., a[position_of_residue] = value of decoding order

one_hot_a = F.one_hot(a, num_classes=num).float()
one_hot_a = one_hot_a.unsqueeze(0)
result = torch.einsum('ij, biq, bjp->bqp', (1-torch.triu(torch.ones(num, num))), one_hot_a, one_hot_a) #  given by line 1085
result
tensor([[[0., 0., 1., 1.],
         [1., 0., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 1., 0.]]])

For example, result[0][0][2] = 1 indicates that there exists a residue pair (i, j) with i > j whose decoding orders are (0, 2). Indeed, residue 2 and residue 0 form such a pair satisfying the condition above. This example suggests that my understanding of the variable order_mask_backward given at line 1085 may be correct. However, in that case, order_mask_backward cannot be gathered with E_idx at line 1086, because the indices along dim=-2 and dim=-1 do not represent residue positions.
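This reading of the line-1085 einsum can be checked mechanically: since result[q][p] should be 1 exactly when the position holding decoding order q comes after the position holding decoding order p, it should equal a pairwise comparison of argsort(a). A sketch using the same toy example:

```python
import torch
import torch.nn.functional as F

num = 4
a = torch.tensor([2, 3, 0, 1])  # decoding order per position, as above
one_hot_a = F.one_hot(a, num_classes=num).float().unsqueeze(0)

# The line-1085 form of the einsum.
result = torch.einsum(
    'ij, biq, bjp->bqp',
    1 - torch.triu(torch.ones(num, num)),
    one_hot_a, one_hot_a)

# pos[k] = the sequence position whose decoding order is k.
pos = torch.argsort(a)
# result[q][p] is 1 exactly when pos[q] > pos[p], i.e., the axes index
# decoding orders, not residue positions.
expected = (pos[:, None] > pos[None, :]).float()
assert torch.equal(result[0], expected)
```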

Modifying the equation in torch.einsum() as follows fixes the problem.

torch.einsum('ji, bqi, bpj->bqp', 1 - torch.triu(torch.ones(num, num)), one_hot_a, one_hot_a)
tensor([[[0., 1., 0., 0.],
         [0., 0., 0., 0.],
         [1., 1., 0., 1.],
         [1., 1., 0., 0.]]])

In the result above, result[0][0][1] = 1 means that residue 1's decoding order comes before residue 0's in the reversed (i.e., backward) decoding order (their decoding orders are 3 and 2, respectively), so residue 1 is visible when decoding residue 0.
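With both axes of the corrected mask indexing residue positions, the gather at line 1086 then has a natural reading: for each residue, it looks up the visibility of that residue's neighbors. A toy sketch, where the E_idx values are made up for illustration rather than computed from real coordinates:

```python
import torch
import torch.nn.functional as F

num = 4
a = torch.tensor([2, 3, 0, 1])  # decoding order from the example above
one_hot_a = F.one_hot(a, num_classes=num).float().unsqueeze(0)

# Corrected mask: entry [q, p] = 1 iff residue p precedes residue q
# in the reversed decoding order.
order_mask_backward = torch.einsum(
    'ji, bqi, bpj->bqp',
    1 - torch.triu(torch.ones(num, num)),
    one_hot_a, one_hot_a)

# Hypothetical neighbor lists: for each residue, its 2 nearest neighbors
# (invented values; in ProteinMPNN these come from backbone distances).
E_idx = torch.tensor([[[1, 2], [0, 3], [3, 0], [2, 1]]])  # [1, 4, 2]

# Gather along dim=2 picks, for each residue q, the visibility of each of
# its neighbors, giving mask_attend of shape [1, 4, 2].
mask_attend = torch.gather(order_mask_backward, 2, E_idx)
# mask_attend[0] == [[1, 0], [0, 0], [1, 1], [0, 1]]
```

For instance, residue 0's neighbors are (1, 2); residue 1 is visible (decoded earlier in the reversed order) while residue 2 is not, giving the row [1, 0].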

I am not sure whether my understanding is correct, or whether this affects the model's training results.

In my view, result[0][0], which equals tensor([0., 0., 1., 1.]), indicates that the residue types at indices 2 and 3 are known while decoding at position 0.

