
shape mismatch about gearnet (closed)

deepgraphlearning commented on September 17, 2024
shape mismatch


Comments (4)

Oxer11 commented on September 17, 2024

Hi, could you provide more context about the error, including the command you're running and the dataset and model you're using?


pearl-rabbit commented on September 17, 2024

Error information:

Traceback (most recent call last):
  File "", line 67, in <module>
    output = gearnet_edge(protein, protein.node_feature.float(), all_loss=None, metric=None)
  File "/home/admin/anaconda3/envs/test/lib/python3.7/site-packages/torch/nn/modules/", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/admin/anaconda3/envs/test/lib/python3.7/site-packages/torchdrug-0.2.0-py3.7.egg/torchdrug/models/", line 99, in forward
    edge_hidden = self.edge_layers[i](line_graph, edge_input)
  File "/home/admin/anaconda3/envs/test/lib/python3.7/site-packages/torch/nn/modules/", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/admin/anaconda3/envs/test/lib/python3.7/site-packages/torchdrug-0.2.0-py3.7.egg/torchdrug/layers/", line 91, in forward
    update = self.message_and_aggregate(graph, input)
  File "/home/admin/anaconda3/envs/test/lib/python3.7/site-packages/torchdrug-0.2.0-py3.7.egg/torchdrug/layers/", line 813, in message_and_aggregate
    return update.view(graph.num_node, self.num_relation * self.input_dim)
RuntimeError: shape '[19, 472]' is invalid for input of size 6080
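The numbers in this error can be checked by hand. As far as the traceback shows, the failing `view` receives the aggregated line-graph messages (assumed here to be one row per edge and angle bin, each row as wide as the incoming edge feature) and tries to reshape them to `[num_edges, num_angle_bin * edge_input_dim]`. The sketch below is plain arithmetic under that assumption; the helper names `expected_numel` and `inferred_feature_dim` are hypothetical.

```python
def expected_numel(num_edges: int, num_angle_bin: int, edge_input_dim: int) -> int:
    """Element count that view(num_edges, num_angle_bin * edge_input_dim) requires."""
    return num_edges * num_angle_bin * edge_input_dim


def inferred_feature_dim(numel: int, num_edges: int, num_angle_bin: int) -> int:
    """Recover the per-edge feature width from the flat element count.

    Assumes the aggregated tensor has num_edges * num_angle_bin rows.
    """
    rows = num_edges * num_angle_bin
    if numel % rows != 0:
        raise ValueError(f"{numel} is not a multiple of {rows}")
    return numel // rows


if __name__ == "__main__":
    # Values from this issue: 19 edges, num_angle_bin=8, edge_input_dim=59,
    # and "input of size 6080" from the RuntimeError.
    print(expected_numel(19, 8, 59))          # 8968, what view(19, 472) needs
    print(inferred_feature_dim(6080, 19, 8))  # 40, the edge-feature width actually received
```

With 19 edges and 8 angle bins, the 6080 elements correspond to 40-dimensional edge features, while a model built with `edge_input_dim=59` expects 59.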

I loaded only one protein, and below is a partial excerpt of 1A0G.pdb. The number of residues loaded does not seem to affect the result.
PackedProtein(batch_size=1, num_atoms=[5], num_bonds=[19], num_residues=[5]) torch.Size([5, 3])

ATOM      1  N   GLY A   1      62.683  18.043  31.832  1.00 27.29           N  
ATOM      2  CA  GLY A   1      62.540  19.333  31.113  1.00 26.20           C  
ATOM      3  C   GLY A   1      61.709  20.294  31.930  1.00 26.29           C  
ATOM      4  O   GLY A   1      61.503  20.069  33.122  1.00 26.53           O  
ATOM      5  H1  GLY A   1      63.547  17.860  31.944  1.00 27.29           H  
ATOM      6  H2  GLY A   1      62.287  18.100  32.627  1.00 27.29           H  
ATOM      7  H3  GLY A   1      62.302  17.394  31.357  1.00 27.29           H  
ATOM      8  HA2 GLY A   1      63.415  19.715  30.943  1.00 26.20           H  
ATOM      9  HA3 GLY A   1      62.122  19.186  30.250  1.00 26.20           H  
ATOM     10  N   TYR A   2      61.235  21.352  31.285  1.00 25.57           N  
ATOM     11  CA  TYR A   2      60.403  22.375  31.908  1.00 25.80           C  
ATOM     12  C   TYR A   2      59.040  22.446  31.212  1.00 24.57           C  
ATOM     13  O   TYR A   2      58.920  22.239  29.996  1.00 23.52           O  
ATOM     14  CB  TYR A   2      61.066  23.748  31.808  1.00 27.46           C  
ATOM     15  CG  TYR A   2      62.320  23.894  32.630  1.00 30.74           C  
ATOM     16  CD1 TYR A   2      63.564  23.548  32.104  1.00 31.97           C  
ATOM     17  CD2 TYR A   2      62.265  24.368  33.941  1.00 32.03           C  
ATOM     18  CE1 TYR A   2      64.730  23.662  32.861  1.00 33.56           C  
ATOM     19  CE2 TYR A   2      63.429  24.490  34.713  1.00 34.25           C  
ATOM     20  CZ  TYR A   2      64.659  24.131  34.162  1.00 34.25           C  
ATOM     21  OH  TYR A   2      65.812  24.229  34.910  1.00 36.90           O  
ATOM     22  H   TYR A   2      61.392  21.500  30.452  1.00 25.57           H  
ATOM     23  HA  TYR A   2      60.290  22.135  32.841  1.00 25.80           H  
ATOM     24  HB2 TYR A   2      61.279  23.925  30.878  1.00 27.46           H  
ATOM     25  HB3 TYR A   2      60.429  24.424  32.087  1.00 27.46           H  
ATOM     26  HD1 TYR A   2      63.617  23.235  31.230  1.00 31.97           H  
ATOM     27  HD2 TYR A   2      61.444  24.606  34.308  1.00 32.03           H  
ATOM     28  HE1 TYR A   2      65.551  23.424  32.494  1.00 33.56           H  
ATOM     29  HE2 TYR A   2      63.381  24.808  35.586  1.00 34.25           H  
ATOM     30  HH  TYR A   2      65.626  24.526  35.674  1.00 36.90           H  
ATOM     31  N   THR A   3      58.029  22.784  31.994  1.00 22.76           N  
ATOM     32  CA  THR A   3      56.674  22.908  31.512  1.00 20.04           C  
ATOM     33  C   THR A   3      56.145  24.288  31.854  1.00 20.22           C  
ATOM     34  O   THR A   3      56.566  24.902  32.840  1.00 21.09           O  
ATOM     35  CB  THR A   3      55.813  21.835  32.187  1.00 19.90           C  
ATOM     36  OG1 THR A   3      56.348  20.551  31.868  1.00 18.58           O  
ATOM     37  CG2 THR A   3      54.358  21.891  31.725  1.00 19.23           C  
ATOM     38  H   THR A   3      58.116  22.950  32.833  1.00 22.76           H  
ATOM     39  HA  THR A   3      56.647  22.789  30.550  1.00 20.04           H  
ATOM     40  HB  THR A   3      55.829  21.996  33.143  1.00 19.90           H  
ATOM     41  HG1 THR A   3      57.091  20.455  32.247  1.00 18.58           H  
ATOM     42 HG21 THR A   3      53.849  21.198  32.174  1.00 19.23           H  
ATOM     43 HG22 THR A   3      53.983  22.759  31.941  1.00 19.23           H  
ATOM     44 HG23 THR A   3      54.317  21.751  30.766  1.00 19.23           H  
ATOM     45  N   LEU A   4      55.313  24.822  30.972  1.00 18.82           N  
ATOM     46  CA  LEU A   4      54.661  26.099  31.184  1.00 18.80           C  
ATOM     47  C   LEU A   4      53.361  25.744  31.916  1.00 18.45           C  
ATOM     48  O   LEU A   4      52.540  24.988  31.412  1.00 18.08           O  
ATOM     49  CB  LEU A   4      54.363  26.774  29.843  1.00 19.09           C  
ATOM     50  CG  LEU A   4      53.376  27.937  29.779  1.00 19.90           C  
ATOM     51  CD1 LEU A   4      53.899  29.172  30.510  1.00 19.94           C  
ATOM     52  CD2 LEU A   4      53.136  28.257  28.336  1.00 20.91           C  
ATOM     53  H   LEU A   4      55.110  24.447  30.225  1.00 18.82           H  
ATOM     54  HA  LEU A   4      55.209  26.720  31.688  1.00 18.80           H  
ATOM     55  HB2 LEU A   4      55.207  27.091  29.485  1.00 19.09           H  
ATOM     56  HB3 LEU A   4      54.040  26.087  29.239  1.00 19.09           H  
ATOM     57  HG  LEU A   4      52.552  27.678  30.220  1.00 19.90           H  
ATOM     58 HD11 LEU A   4      53.246  29.886  30.447  1.00 19.94           H  
ATOM     59 HD12 LEU A   4      54.052  28.956  31.443  1.00 19.94           H  
ATOM     60 HD13 LEU A   4      54.732  29.459  30.105  1.00 19.94           H  
ATOM     61 HD21 LEU A   4      52.510  28.995  28.268  1.00 20.91           H  
ATOM     62 HD22 LEU A   4      53.974  28.504  27.915  1.00 20.91           H  
ATOM     63 HD23 LEU A   4      52.768  27.479  27.889  1.00 20.91           H  
ATOM     64  N   TRP A   5      53.244  26.216  33.147  1.00 18.81           N  
ATOM     65  CA  TRP A   5      52.090  25.958  33.974  1.00 19.95           C  
ATOM     66  C   TRP A   5      51.552  27.327  34.334  1.00 20.86           C  
ATOM     67  O   TRP A   5      52.060  27.978  35.250  1.00 19.16           O  
ATOM     68  CB  TRP A   5      52.518  25.197  35.224  1.00 21.36           C  
ATOM     69  CG  TRP A   5      51.379  24.766  36.083  1.00 23.63           C  
ATOM     70  CD1 TRP A   5      50.043  24.813  35.774  1.00 24.43           C  
ATOM     71  CD2 TRP A   5      51.468  24.189  37.391  1.00 25.64           C  
ATOM     72  NE1 TRP A   5      49.305  24.293  36.805  1.00 25.53           N  
ATOM     73  CE2 TRP A   5      50.148  23.904  37.810  1.00 25.61           C  
ATOM     74  CE3 TRP A   5      52.535  23.882  38.250  1.00 25.32           C  
ATOM     75  CZ2 TRP A   5      49.866  23.330  39.050  1.00 27.27           C  
ATOM     76  CZ3 TRP A   5      52.254  23.312  39.483  1.00 27.22           C  
ATOM     77  CH2 TRP A   5      50.928  23.040  39.872  1.00 28.11           C  
ATOM     78  H   TRP A   5      53.844  26.702  33.527  1.00 18.81           H  
ATOM     79  HA  TRP A   5      51.420  25.418  33.526  1.00 19.95           H  
ATOM     80  HB2 TRP A   5      53.026  24.415  34.958  1.00 21.36           H  
ATOM     81  HB3 TRP A   5      53.113  25.758  35.747  1.00 21.36           H  
ATOM     82  HD1 TRP A   5      49.690  25.148  34.982  1.00 24.43           H  
ATOM     83  HE1 TRP A   5      48.448  24.222  36.818  1.00 25.53           H  
ATOM     84  HE3 TRP A   5      53.413  24.057  37.997  1.00 25.32           H  
ATOM     85  HZ2 TRP A   5      48.992  23.150  39.311  1.00 27.27           H  
ATOM     86  HZ3 TRP A   5      52.952  23.106  40.062  1.00 27.22           H  
ATOM     87  HH2 TRP A   5      50.768  22.656  40.704  1.00 28.11           H

Model definition:

import torch
from torchdrug import data, models, transforms

pdb_file = "1A0G.pdb"  # the partial PDB file shown above

# protein
protein = data.Protein.from_pdb(pdb_file, atom_feature="position", bond_feature="length", residue_feature="symbol")
_protein = data.Protein.pack([protein])
# graph_construction_model is defined elsewhere in the script (not shown in this report)
protein = graph_construction_model(_protein)

# model
gearnet_edge = models.GearNet(input_dim=21, hidden_dims=[512, 512, 512, 512, 512, 512],
                              num_relation=7, edge_input_dim=59, num_angle_bin=8,
                              batch_norm=True, concat_hidden=True, short_cut=True, readout="sum")
pthfile = 'models/angle_gearnet_edge.pth'
# torch.load only returns the checkpoint; assuming it is a state dict, load it into the model
gearnet_edge.load_state_dict(torch.load(pthfile, map_location="cpu"))
gearnet_edge.eval()  # use running batch-norm statistics for inference

# written according to the documentation
truncate_transform = transforms.TruncateProtein(max_length=350, random=False)
protein_view_transform = transforms.ProteinView(view="residue")
transform = transforms.Compose([truncate_transform, protein_view_transform])
item = {"graph": protein}
if transform:
    item = transform(item)
protein = item["graph"]

with torch.no_grad():
    output = gearnet_edge(protein, protein.node_feature.float(), all_loss=None, metric=None)


Oxer11 commented on September 17, 2024

It seems that the edge features in your protein have shape 19*40. Have you tried the following graph_construction_model?

graph_construction_model = layers.GraphConstruction(node_layers=[geometry.AlphaCarbonNode()],
                                                    edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 geometry.KNNEdge(k=10, min_distance=5),
                                                                 # (the rest of this list was cut off in the original comment)
                                                                 ],
                                                    edge_feature="gearnet")

Note that you need to set the edge feature in graph_construction_model to "gearnet" (edge_feature="gearnet"), as GearNet expects, to get an edge feature of shape 19*59.
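To see where 59 comes from, here is one reading of how TorchDrug's "gearnet" edge feature is assembled: one-hot residue types of both endpoints, a one-hot relation, a one-hot sequential distance, and the spatial distance. The component sizes (20 residue types, 11 sequential-distance bins) and the relation count of a `SequentialEdge(max_distance=2)` layer (2*d+1) are assumptions based on our understanding of the TorchDrug source, not statements from this thread; the function names are hypothetical.

```python
NUM_RESIDUE_TYPES = 20  # assumed: len(data.Protein.residue2id) in TorchDrug
SEQ_DIST_BINS = 11      # assumed: sequential distance one-hot, clamped at 10


def num_relations(seq_max_distance: int) -> int:
    """Relations from SpatialEdge (1) + KNNEdge (1) + SequentialEdge (assumed 2*d+1)."""
    return 1 + 1 + (2 * seq_max_distance + 1)


def gearnet_edge_feature_dim(num_relation: int,
                             num_residue_types: int = NUM_RESIDUE_TYPES,
                             seq_dist_bins: int = SEQ_DIST_BINS) -> int:
    """Width of one "gearnet" edge feature vector under the assumptions above."""
    return (num_residue_types    # one-hot residue type of the source node
            + num_residue_types  # one-hot residue type of the target node
            + num_relation       # one-hot edge relation
            + seq_dist_bins      # one-hot sequential distance
            + 1)                 # scalar spatial distance


if __name__ == "__main__":
    r = num_relations(seq_max_distance=2)
    print(r, gearnet_edge_feature_dim(r))
```

Under these assumptions the totals match the constructor in this issue (num_relation=7, edge_input_dim=59). It also shows why changing the edge layers changes the expected edge_input_dim: the relation one-hot is part of every edge feature.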


pearl-rabbit commented on September 17, 2024

Thank you for patiently answering my question. It has been resolved.

