lucidrains / egnn-pytorch

Implementation of E(n)-Equivariant Graph Neural Networks, in Pytorch

License: MIT License

Topics: artificial-intelligence, deep-learning, equivariance, graph-neural-network

egnn-pytorch's Introduction

** A bug has been discovered with the neighbor selection in the presence of masking. If you ran any experiments prior to 0.1.12 that used masking, please rerun them. 🙏 **

EGNN - Pytorch

Implementation of E(n)-Equivariant Graph Neural Networks, in Pytorch. May eventually be used for Alphafold2 replication. This technique went for simple invariant features, and ended up beating all previous methods (including SE3 Transformer and Lie Conv) in both accuracy and performance. SOTA in dynamical system models, molecular activity prediction tasks, etc.

Install

$ pip install egnn-pytorch

Usage

import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim = 512)
layer2 = EGNN(dim = 512)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)

feats, coors = layer1(feats, coors)
feats, coors = layer2(feats, coors) # (1, 16, 512), (1, 16, 3)

With edges

import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim = 512, edge_dim = 4)
layer2 = EGNN(dim = 512, edge_dim = 4)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)

feats, coors = layer1(feats, coors, edges)
feats, coors = layer2(feats, coors, edges) # (1, 16, 512), (1, 16, 3)

A full EGNN network

import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    num_positions = 1024,           # set this to the maximum sequence length, unless what you are passing in is an unordered set
    dim = 32,
    depth = 3,
    num_nearest_neighbors = 8,
    coor_weights_clamp_value = 2.   # absolute clamped value for the coordinate weights, needed if you increase the number of nearest neighbors
)

feats = torch.randint(0, 21, (1, 1024)) # (1, 1024)
coors = torch.randn(1, 1024, 3)         # (1, 1024, 3)
mask = torch.ones_like(feats).bool()    # (1, 1024)

feats_out, coors_out = net(feats, coors, mask = mask) # (1, 1024, 32), (1, 1024, 3)

Only attend to sparse neighbors, given to the network as an adjacency matrix.

import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    only_sparse_neighbors = True
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)

You can also have the network automatically determine the Nth-order neighbors, and pass in an adjacency embedding (depending on the order) to be used as edge features, with two extra keyword arguments.

import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    num_adj_degrees = 3,           # fetch up to 3rd degree neighbors
    adj_dim = 8,                   # pass an adjacency degree embedding to the EGNN layer, to be used in the edge MLP
    only_sparse_neighbors = True
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)

Edges

If you need to pass in continuous edges

import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    edge_dim = 4,
    num_nearest_neighbors = 3
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()

continuous_edges = torch.randn(1, 1024, 1024, 4)

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, edges = continuous_edges, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)

Stability

The initial architecture for EGNN suffered from instability when there was a high number of neighbors. Thankfully, there seem to be two solutions that largely mitigate this.

import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    num_nearest_neighbors = 32,
    norm_coors = True,              # normalize the relative coordinates
    coor_weights_clamp_value = 2.   # absolute clamped value for the coordinate weights, needed if you increase the number of nearest neighbors
)

feats = torch.randint(0, 21, (1, 1024)) # (1, 1024)
coors = torch.randn(1, 1024, 3)         # (1, 1024, 3)
mask = torch.ones_like(feats).bool()    # (1, 1024)

feats_out, coors_out = net(feats, coors, mask = mask) # (1, 1024, 32), (1, 1024, 3)

All parameters

import torch
from egnn_pytorch import EGNN

model = EGNN(
    dim = 512,                         # input dimension
    edge_dim = 0,                      # dimension of the edges, if present (must be > 0)
    m_dim = 16,                        # hidden model dimension
    fourier_features = 0,              # number of fourier features for encoding of relative distance - defaults to none as in paper
    num_nearest_neighbors = 0,         # cap the number of neighbors doing message passing by relative distance
    dropout = 0.0,                     # dropout
    norm_feats = False,                # whether to layernorm the features
    norm_coors = False,                # whether to normalize the coordinates, using a strategy from the SE(3) Transformers paper
    update_feats = True,               # whether to update features - you can build a layer that only updates one or the other
    update_coors = True,               # whether to update coordinates
    only_sparse_neighbors = False,     # using this would only allow message passing along adjacent neighbors, using the adjacency matrix passed in
    valid_radius = float('inf'),       # the valid radius each node considers for message passing
    m_pool_method = 'sum',             # whether to mean or sum pool for output node representation
    soft_edges = False,                # extra GLU on the edges, purportedly helps stabilize the network in updated version of the paper
    coor_weights_clamp_value = None    # clamping of the coordinate updates, again, for stabilization purposes
)

Examples

To run the protein backbone denoising example, first install sidechainnet

$ pip install sidechainnet

Then

$ python denoise_sparse.py

Tests

Make sure you have pytorch geometric installed locally

$ python setup.py test

Citations

@misc{satorras2021en,
    title         = {E(n) Equivariant Graph Neural Networks},
    author        = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
    year          = {2021},
    eprint        = {2102.09844},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG}
}

egnn-pytorch's People

Contributors

amorehead, brennanaba, hypnopump, jscant, lucidrains, souramoo

egnn-pytorch's Issues

Training model with pyg graphs

Hello, I am using this model to train on protein data where each graph has a different number of atoms. So far I see good performance by padding shorter sequences, however I'd like to avoid this as I scale the model to much larger sequences. Is there currently any support for passing in batches of differently sized graphs (as in PyG)?
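
For reference, a minimal sketch of the padding-plus-mask approach described above, which is what EGNN_Network supports out of the box (the per-graph sizes below are hypothetical; PyG-style batching of ragged graphs is not shown here):

    import torch
    from egnn_pytorch import EGNN_Network

    net = EGNN_Network(num_tokens = 21, dim = 32, depth = 3)

    sizes = [12, 16, 9]                  # hypothetical per-graph atom counts
    max_len = max(sizes)

    feats = torch.zeros(len(sizes), max_len).long()
    coors = torch.zeros(len(sizes), max_len, 3)
    mask  = torch.zeros(len(sizes), max_len).bool()

    for i, n in enumerate(sizes):
        feats[i, :n] = torch.randint(0, 21, (n,))
        coors[i, :n] = torch.randn(n, 3)
        mask[i, :n]  = True

    # padded positions are excluded from message passing via the mask
    feats_out, coors_out = net(feats, coors, mask = mask)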

Few queries on the implementation

Hi - fast work coding these things up, as usual! Looking at the paper and your code, you're not using squared distance for the edge weighting. Is that intentional? Also, it looks like you are adding the old feature vectors to the new ones rather than taking the new vectors directly from the fully connected net - is that also an intentional change from the paper?

Exploding Gradients With 4 Layers

I'm using EGNN with 4 layers (where I also do global attention after each layer), and I'm seeing exploding gradients after 90 epochs or so. I'm using techniques discussed earlier (sparse attention matrix, coor_weights_clamp_value, norm_coors), but I'm not sure if there's anything else I should be doing. I'm also not updating the coordinates, so the fix in the pull request doesn't apply.
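
For what it's worth, a minimal sketch of global-norm gradient clipping layered on top of the stabilization flags already mentioned (the loss below is a dummy placeholder, for illustration only):

    import torch
    from egnn_pytorch import EGNN_Network

    net = EGNN_Network(num_tokens = 21, dim = 32, depth = 4, norm_coors = True, coor_weights_clamp_value = 2.)
    opt = torch.optim.Adam(net.parameters(), lr = 1e-4)

    feats = torch.randint(0, 21, (1, 64))
    coors = torch.randn(1, 64, 3)

    feats_out, coors_out = net(feats, coors)
    loss = coors_out.square().mean()   # dummy loss, for illustration only
    loss.backward()

    # clip the global gradient norm before each optimizer step
    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm = 1.0)
    opt.step()
    opt.zero_grad()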

About aggregations in EGNN_sparse

Hi, thanks for your great work!

I have a question on how aggregations are computed for the node embedding and the coordinate embedding. In the paper, the aggregation for the node embedding is computed over its neighbors, while the aggregation for the coordinate embedding is computed over all others. However, in EGNN_sparse, I didn't notice such a difference in the aggregations.

I guess it is because computing all-pair messages for coordinate embedding makes 'sparse' meaningless, but I would like to double-check to see if I get this correctly. So anyway, did you do this intentionally? Or did I miss something?

My appreciation.

Pytorch-Geometric Version Attention code is buggy

First of all - this is a great repo and thank you for this. The pyg version however has some bugs with the attention.

Just a few that I have encountered:

  1. In the forward method, the attention layer is at index -1, not 0, and the EGNN layer is at index 0, not -1 (the opposite of the other implementation).
  2. The self.global_tokens init has an undefined var dim.
  3. It uses GlobalLinearAttention from the other implementation, although GlobalLinearAttention_Sparse is defined in the file (not sure if this is a bug or on purpose?).

I have refactored a lot of the code, but can try and do a PR in a few days

Edge features thrown out

Hi, thanks for this implementation!

I was wondering if the pytorch-geometric implementation of this architecture is throwing the edge features out by mistake, as seen here

    if edge_attr is None:
        edge_attr = torch.cat([edge_attr, rel_dist], dim=-1)
    else:
        edge_attr = rel_dist

Or maybe my understanding is wrong?
Cheers,
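
For reference, a sketch of what the presumably intended logic looks like, with the two branches swapped back so existing edge features are kept and the relative distance is appended rather than replacing them:

    import torch

    def combine_edge_attr(edge_attr, rel_dist):
        # presumed intent: append the relative distance to any existing
        # edge features instead of overwriting them
        if edge_attr is None:
            return rel_dist
        return torch.cat([edge_attr, rel_dist], dim = -1)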

weird problem with pytorch lightning

I tried a simple example of using EGNN for feature extraction, and it works fine. But when I put it together into network training using pytorch lightning, I got the following error:

    RuntimeError: CUDA error: an illegal memory access was encountered
    CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.

Can someone tell me where the problem is?
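
Not a fix, but since CUDA errors are reported asynchronously, forcing synchronous kernel launches usually reveals the actual failing operation:

    # run with synchronous kernel launches for an accurate stack trace, e.g.
    #   $ CUDA_LAUNCH_BLOCKING=1 python train.py
    import os
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"   # must be set before CUDA initializes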

Wrong edge_index size hint in class EGNN_Sparse of pyg version

Hi, I found there may be a little mistake. In the input hint of class EGNN_Sparse of pyg version, the size of edge_index is (n_edges, 2). However, it should be (2, n_edges). Otherwise, the distance calculation will be not correct.
""" Inputs: * x: (n_points, d) where d is pos_dims + feat_dims * edge_index: (n_edges, 2) * edge_attr: tensor (n_edges, n_feats) excluding basic distance feats. * batch: (n_points,) long tensor. specifies xloud belonging for each point * angle_data: list of tensors (levels, n_edges_i, n_length_path) long tensor. * size: None """

How to use EGNN to update edge properties based on node features and coordinates?

Hello, I am using your EGNN implementation for a project involving protein structure analysis, where I am representing residues as nodes. Currently, I have each residue type as a feature, and each residue also has corresponding coordinates.

What I am trying to achieve is to update all the edge properties based on the residue type and the coordinates. My goal is to adjust the interactions between the residues based on these properties, but I am not sure about the best way to go about this.

I've looked through the current documentation and examples but couldn't find an example or guideline about how to update edge properties in this way. I am not sure how to proceed and would appreciate any guidance you can provide.
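
The layers in this package update node features and coordinates, not edges. One common workaround is to derive edge features yourself from the endpoint node features and the pairwise distance, then feed them in as `edges`. A sketch, where EdgeUpdate is a hypothetical helper and not part of egnn-pytorch:

    import torch
    from torch import nn

    class EdgeUpdate(nn.Module):
        # hypothetical helper, not part of egnn-pytorch: builds edge features
        # from the two endpoint node features and their pairwise distance
        def __init__(self, node_dim, edge_dim):
            super().__init__()
            self.mlp = nn.Sequential(
                nn.Linear(node_dim * 2 + 1, edge_dim),
                nn.SiLU(),
                nn.Linear(edge_dim, edge_dim)
            )

        def forward(self, feats, coors):
            b, n, d = feats.shape
            rel_dist = torch.cdist(coors, coors)[..., None]           # (b, n, n, 1)
            fi = feats[:, :, None, :].expand(b, n, n, d)              # sender features
            fj = feats[:, None, :, :].expand(b, n, n, d)              # receiver features
            return self.mlp(torch.cat([fi, fj, rel_dist], dim = -1)) # (b, n, n, edge_dim)

    edge_update = EdgeUpdate(node_dim = 32, edge_dim = 4)
    feats = torch.randn(1, 16, 32)
    coors = torch.randn(1, 16, 3)
    edges = edge_update(feats, coors)   # can then be passed to the layers as `edges`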

Nan Values after stacking multiple layers

Hi Lucid!!

I find that when stacking multiple layers the output from the model rapidly goes to Nan. I suspect it may be related to the weights used for initialization.

Here is a minimal working example:

Make some data:

    import numpy as np
    import torch
    from egnn_pytorch import EGNN
    
    torch.set_default_dtype(torch.double)

    zline = np.arange(0, 2, 0.05)
    xline = np.sin(zline * 2 * np.pi) 
    yline = np.cos(zline * 2 * np.pi)
    points = np.array([xline, yline, zline])
    geom = torch.tensor(points.transpose())[None,:]
    feat = torch.randint(0, 20, (1, geom.shape[1],1))

Make a model:

    class ResEGNN(torch.nn.Module):
        def __init__(self, depth = 2, dims_in = 1):
            super().__init__()
            self.layers = torch.nn.ModuleList([EGNN(dim = dims_in) for i in range(depth)])
        
        def forward(self, geom, feat):
            for layer in self.layers:
                feat, geom = layer(feat, geom)
            return geom

Run model for varying depths:

    for i in range(10):
        model = ResEGNN(depth = i)
        pred = model(geom, feat)
        mean_absolute_value  = torch.abs(pred).mean()
        print("Order of predictions {:.2f}".format(np.log(mean_absolute_value.detach().numpy())))

Output:

    Order of predictions -0.29
    Order of predictions 0.05
    Order of predictions 6.65
    Order of predictions 21.38
    Order of predictions 78.25
    Order of predictions 302.71
    Order of predictions 277.38
    Order of predictions nan
    Order of predictions nan
    Order of predictions nan
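
For reference, the stability flags documented in the README above target exactly this failure mode; a minimal sketch of the same stack with them enabled:

    import torch
    from egnn_pytorch import EGNN

    # the same depth sweep, but with coordinate normalization and clamping enabled
    layers = torch.nn.ModuleList([
        EGNN(dim = 1, norm_coors = True, coor_weights_clamp_value = 2.)
        for _ in range(8)
    ])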

EGNN_sparse incorrect positional encoding output

Hi, many thanks for the implementation!

I was quickly checking the code for the pytorch geometric implementation of the EGNN_sparse layer, and I noticed that it expects the first 3 columns in the features to be the coordinates. However, in the update method, features and coordinates are passed in the wrong order.

    return self.update((hidden_out, coors_out), **update_kwargs)

This may cause problems during learning (think of concatenating several of these layers), as they expect coordinate and feature order to be consistent.

One can reproduce this behaviour in the following snippet:

    layer = EGNN_sparse(feats_dim=1, pos_dim=3, m_dim=16, fourier_features=0)

    R = rot(*torch.rand(3))  # rot builds a 3D rotation matrix from three angles
    T = torch.randn(1, 1, 3)

    feats = torch.randn(16, 1)
    coors = torch.randn(16, 3)
    x1 = torch.cat([coors, feats], dim=-1)
    x2 = torch.cat([(coors @ R + T).squeeze(), feats], dim=-1)
    edge_idxs = (torch.rand(2, 20) * 16).long()

    out1 = layer(x=x1, edge_index=edge_idxs)
    out2 = layer(x=x2, edge_index=edge_idxs)

After fixing the order of these arguments in the update method, the layer behaves as expected (output features are invariant, and coordinates are equivariant, under an SE(3) transformation).

Questions about the EGNN code

Recently, I've tried to read the EGNN paper and study your EGNN code.
Actually, I had a hard time understanding both the paper and the code, because my major is not computer science.
While studying your code, I realized that the shape of hidden_out and the shape of kwargs["x"] must be the same to perform the add operation (because of the residual connection) in the EGNN_sparse forward method.
How can I increase or decrease the hidden dimension size of x?

I would like to get some advice.

Thanks for your consideration in this regard.
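
For reference, a minimal sketch of the usual workaround: the EGNN layers keep the feature dimension fixed (the residual add requires matching shapes), so project with Linear layers around the EGNN stack. Shown here with the dense EGNN layer; the same idea applies to the sparse variant.

    import torch
    from egnn_pytorch import EGNN

    proj_in  = torch.nn.Linear(32, 64)   # widen features before the EGNN layer
    layer    = EGNN(dim = 64)
    proj_out = torch.nn.Linear(64, 32)   # project back down afterwards

    feats = torch.randn(1, 16, 32)
    coors = torch.randn(1, 16, 3)

    feats, coors = layer(proj_in(feats), coors)
    feats = proj_out(feats)              # (1, 16, 32)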

training batch size

Dear authors,

thanks for your great work! I saw your example, which is easy to understand. But I noticed that during training, each iteration seems to support a batch size > 1 only when all the graphs share the same adj_mat. Do you have a better solution for that? Thanks
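
For what it's worth, a sketch under the assumption that the forward pass also accepts a per-graph adjacency matrix of shape (b, n, n), with a 2D adj_mat broadcast across the batch; worth verifying against the code:

    import torch
    from egnn_pytorch import EGNN_Network

    net = EGNN_Network(num_tokens = 21, dim = 32, depth = 3, only_sparse_neighbors = True)

    b, n = 4, 64
    feats = torch.randint(0, 21, (b, n))
    coors = torch.randn(b, n, 3)
    mask  = torch.ones(b, n).bool()

    # a different random adjacency per graph, stacked along the batch dim - (b, n, n)
    adj_mat = torch.rand(b, n, n) < 0.1
    adj_mat = adj_mat | adj_mat.transpose(-1, -2)   # symmetrize

    feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat)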
