lucidrains / en-transformer

Implementation of E(n)-Transformer, which incorporates attention mechanisms into Welling's E(n)-Equivariant Graph Neural Network

License: MIT

Topics: artificial-intelligence, deep-learning, equivariance, transformer, attention-mechanism

en-transformer's Introduction

E(n)-Equivariant Transformer

Implementation of E(n)-Equivariant Transformer, which extends the ideas from Welling's E(n)-Equivariant Graph Neural Network with attention mechanisms and ideas from the transformer architecture.

Update: Used for designing CDR loops in antibodies!

Install

$ pip install En-transformer

Usage

import torch
from en_transformer import EnTransformer

model = EnTransformer(
    dim = 512,
    depth = 4,                       # depth
    dim_head = 64,                   # dimension per head
    heads = 8,                       # number of heads
    edge_dim = 4,                    # dimension of edge feature
    neighbors = 64,                  # only do attention between each coordinate's N nearest neighbors - set to 0 to turn off
    talking_heads = True,            # use Shazeer's talking heads https://arxiv.org/abs/2003.02436
    checkpoint = True,               # use checkpointing so one can increase depth at little memory cost (and increase neighbors attended to)
    use_cross_product = True,        # use cross product vectors (idea by @MattMcPartlon)
    num_global_linear_attn_heads = 4 # if your number of neighbors above is low, you can assign a certain number of attention heads to weakly attend globally to all other nodes through linear attention (https://arxiv.org/abs/1812.01243)
)

feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
edges = torch.randn(1, 1024, 1024, 4)

mask = torch.ones(1, 1024).bool()  # optional boolean mask over nodes - True marks real nodes, False marks padding

feats, coors = model(feats, coors, edges, mask = mask)  # (1, 1024, 512), (1, 1024, 3)

Letting the network take care of both atomic and bond type embeddings

import torch
from en_transformer import EnTransformer

model = EnTransformer(
    num_tokens = 10,       # number of unique nodes, say atoms
    rel_pos_emb = True,    # set this to true if your nodes form an ordered sequence rather than an unordered set - it will accelerate convergence
    num_edge_tokens = 5,   # number of unique edges, say bond types
    dim = 128,
    edge_dim = 16,
    depth = 3,
    heads = 4,
    dim_head = 32,
    neighbors = 8
)

atoms = torch.randint(0, 10, (1, 16))    # 10 different types of atoms
bonds = torch.randint(0, 5, (1, 16, 16)) # 5 different types of bonds (n x n)
coors = torch.randn(1, 16, 3)            # atomic spatial coordinates

feats_out, coors_out = model(atoms, coors, edges = bonds) # (1, 16, 128), (1, 16, 3)

If you would like to only attend to sparse neighbors, as defined by an adjacency matrix (say for atoms), you have to set one more flag and then pass in the N x N adjacency matrix.

import torch
from en_transformer import EnTransformer

model = EnTransformer(
    num_tokens = 10,
    dim = 512,
    depth = 1,
    heads = 4,
    dim_head = 32,
    neighbors = 0,
    only_sparse_neighbors = True,    # must be set to true
    num_adj_degrees = 2,             # the number of adjacency degrees to derive from the 1st-degree neighbors passed in
    adj_dim = 8                      # dimension of the edge embedding derived from the adjacency degree information
)

atoms = torch.randint(0, 10, (1, 16))
coors = torch.randn(1, 16, 3)

# naively assume a single chain of atoms
i = torch.arange(atoms.shape[1])
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))

# adjacency matrix must be passed in
feats_out, coors_out = model(atoms, coors, adj_mat = adj_mat) # (1, 16, 512), (1, 16, 3)
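
For intuition, num_adj_degrees = 2 expands the passed-in 1st-degree adjacency so that 2nd-degree (neighbor-of-neighbor) connections are also treated as neighbors. A rough illustration of that expansion (the library does this internally; this snippet is only for intuition):

import torch

# first-degree adjacency for a chain of 16 atoms, same construction as above
i = torch.arange(16)
adj_1 = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))

# since adj_1 includes the diagonal, one boolean matrix product marks every node
# reachable within two hops - roughly the neighborhood num_adj_degrees = 2 expands to
adj_2 = (adj_1.float() @ adj_1.float()) > 0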

Edges

If you need to pass in continuous edge features:

import torch
from en_transformer import EnTransformer
from en_transformer.utils import rot

model = EnTransformer(
    dim = 512,
    depth = 1,
    heads = 4,
    dim_head = 32,
    edge_dim = 4,
    neighbors = 0,
    only_sparse_neighbors = True
)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)

i = torch.arange(feats.shape[1])
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))

feats1, coors1 = model(feats, coors, adj_mat = adj_mat, edges = edges)
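
The rot import above lends itself to an equivariance sanity check. A minimal sketch, assuming rot(alpha, beta, gamma) from en_transformer.utils returns a 3 x 3 rotation matrix: rotating and translating the input coordinates should rotate and translate the output coordinates the same way, while leaving the output features unchanged.

# hypothetical equivariance check, not part of the original README
R = rot(0.1, 0.2, 0.3)     # assumed signature: three angles -> 3 x 3 rotation matrix
t = torch.randn(1, 1, 3)   # random translation

feats2, coors2 = model(feats, coors @ R + t, adj_mat = adj_mat, edges = edges)

# features should be invariant, coordinates equivariant (up to numerical tolerance)
assert torch.allclose(feats1, feats2, atol = 1e-5)
assert torch.allclose(coors1 @ R + t, coors2, atol = 1e-5)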

Example

To run a protein backbone coordinate denoising toy task, first install sidechainnet

$ pip install sidechainnet

Then

$ python denoise.py
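
The script trains on sidechainnet backbone coordinates; below is a minimal sketch of the kind of denoising objective it runs (illustrative only, with random stand-in data rather than the actual sidechainnet batches):

import torch
import torch.nn.functional as F
from en_transformer import EnTransformer

model = EnTransformer(dim = 32, depth = 2, heads = 2, dim_head = 16, neighbors = 8)
opt = torch.optim.Adam(model.parameters(), lr = 1e-3)

feats = torch.randn(1, 64, 32)   # stand-in node features
coors = torch.randn(1, 64, 3)    # stand-in clean coordinates

for _ in range(10):
    noised = coors + torch.randn_like(coors) * 0.1   # corrupt the coordinates
    _, denoised = model(feats, noised)
    loss = F.mse_loss(denoised, coors)               # learn to recover the clean coordinates
    opt.zero_grad()
    loss.backward()
    opt.step()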

Todo

Citations

@misc{satorras2021en,
    title   = {E(n) Equivariant Graph Neural Networks},
    author  = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
    year    = {2021},
    eprint  = {2102.09844},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{shazeer2020talkingheads,
    title   = {Talking-Heads Attention}, 
    author  = {Noam Shazeer and Zhenzhong Lan and Youlong Cheng and Nan Ding and Le Hou},
    year    = {2020},
    eprint  = {2003.02436},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@inproceedings{Kim2020TheLC,
    title   = {The Lipschitz Constant of Self-Attention},
    author  = {Hyunjik Kim and George Papamakarios and Andriy Mnih},
    booktitle = {International Conference on Machine Learning},
    year    = {2020},
    url     = {https://api.semanticscholar.org/CorpusID:219530837}
}
@article {Mahajan2023.07.15.549154,
    author  = {Sai Pooja Mahajan and Jeffrey A. Ruffolo and Jeffrey J. Gray},
    title   = {Contextual protein and antibody encodings from equivariant graph transformers},
    elocation-id = {2023.07.15.549154},
    year    = {2023},
    doi     = {10.1101/2023.07.15.549154},
    publisher = {Cold Spring Harbor Laboratory},
    URL     = {https://www.biorxiv.org/content/early/2023/07/29/2023.07.15.549154},
    eprint  = {https://www.biorxiv.org/content/early/2023/07/29/2023.07.15.549154.full.pdf},
    journal = {bioRxiv}
}
@article{Bondarenko2023QuantizableTR,
    title   = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
    author  = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2306.12929},
    url     = {https://api.semanticscholar.org/CorpusID:259224568}
}
@inproceedings{Arora2023ZoologyMA,
    title   = {Zoology: Measuring and Improving Recall in Efficient Language Models},
    author  = {Simran Arora and Sabri Eyuboglu and Aman Timalsina and Isys Johnson and Michael Poli and James Zou and Atri Rudra and Christopher R'e},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:266149332}
}


en-transformer's Issues

varying number of nodes

@lucidrains Thank you for your efficient implementation. I was wondering how to use this implementation for datasets where the number of nodes in each graph is not the same, for example datasets of small molecules?
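
(A minimal sketch of one way to handle this, padding each graph to the largest node count in the batch and marking real nodes with the mask argument shown in the README; illustrative only:)

import torch
from en_transformer import EnTransformer

model = EnTransformer(dim = 64, depth = 2, heads = 4, dim_head = 16, neighbors = 8)

# two graphs with 12 and 16 nodes, padded to the larger size
feats = torch.zeros(2, 16, 64)
coors = torch.zeros(2, 16, 3)
feats[0, :12] = torch.randn(12, 64)
coors[0, :12] = torch.randn(12, 3)
feats[1] = torch.randn(16, 64)
coors[1] = torch.randn(16, 3)

# boolean mask: True marks real nodes, False marks padding
mask = torch.zeros(2, 16).bool()
mask[0, :12] = True
mask[1, :] = True

feats_out, coors_out = model(feats, coors, mask = mask)  # (2, 16, 64), (2, 16, 3)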

question about the data format

Hi @lucidrains,

I found that the input format of the En-Transformer differs from the EGNN in your released repository: EGNN-pytorch uses the torch_geometric lib to handle nodes and edges, encoding nodes as [bs_node_number_sum, dim] and edges as [2, bs_edge_number_sum]. So I want to know how to put this type of data into the En-Transformer?

Thanks!
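
(One possible route, sketched under the assumption that the data is in torch_geometric's COO format and that the model accepts a per-batch adjacency matrix: torch_geometric's to_dense_batch and to_dense_adj utilities convert the sparse node and edge tensors into the dense, batched tensors this repository expects.)

import torch
from torch_geometric.utils import to_dense_batch, to_dense_adj
from en_transformer import EnTransformer

model = EnTransformer(dim = 64, depth = 2, heads = 4, dim_head = 16,
                      neighbors = 0, only_sparse_neighbors = True)

# toy sparse batch: 5 nodes belonging to 2 graphs, with a few undirected edges
x = torch.randn(5, 64)                           # node features [bs_node_number_sum, dim]
pos = torch.randn(5, 3)                          # coordinates  [bs_node_number_sum, 3]
batch = torch.tensor([0, 0, 0, 1, 1])            # graph assignment per node
edge_index = torch.tensor([[0, 1, 1, 2, 3, 4],
                           [1, 0, 2, 1, 4, 3]])  # COO edges [2, bs_edge_number_sum]

feats, mask = to_dense_batch(x, batch)           # (2, 3, 64), (2, 3)
coors, _ = to_dense_batch(pos, batch)            # (2, 3, 3)
adj_mat = to_dense_adj(edge_index, batch).bool() # (2, 3, 3)

feats_out, coors_out = model(feats, coors, mask = mask, adj_mat = adj_mat)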

E(n)-transformer with dgl

Thank you for your awesome implementation of graph transformer + EGNN!
Recently I've studied the EGNN implementation in DGL. I was wondering whether it is possible to implement the E(n)-Equivariant Transformer with DGL.
Thank you!

Incorrect output shape

# taken from https://github.com/lucidrains/En-transformer/blob/main/README.md
feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
edges = torch.randn(1, 1024, 1024, 4)

feats, coors = model(feats, coors, edges, mask = mask) # (1, 16, 512), (1, 16, 3)

As we can see, the input shape of the node features (feats) and coordinates (coors) is [batch_size, num_nodes, num_features], with num_nodes equal to 1024; however, the commented output num_nodes is equal to 16.

My assumption is that the output shape was by mistake taken from here https://github.com/lucidrains/egnn-pytorch/blob/main/README.md

The varying num_nodes is confusing; can you explain whether this is just a mistake or whether num_nodes really changes, and why?

Modularity features: (inter)dependence of feats, edges, coors

I am curious about whether it is possible to make the model more modular. I haven't gone through the code and don't know what is possible to do without major changes in the architecture, but I wanted to start the discussion.

Current status:

  • Training on atoms and coors; edges are optional
  • Outputting atoms and coors (no edges)

Proposed modularity features:

  • Possibility to output edges
  • Possibility to train on any pair of inputs: a) atoms and coors (already implemented), b) atoms and edges c) coors and edges
  • Possibility to train on any single input

Reasons why & Use cases:

  • Outputting edges might be super-useful for predicting bonds, protein-ligand binding sites etc.
  • Any pair: atoms-edges makes sense when we don't have 3D structures and we are training on "2D" molecular or protein graphs with only atoms and connectivity info. Might be used for de novo prediction of coordinates too.
  • Single input: a) On atoms that would collapse into simple transformer b) on coordinates it could be used for training on 3D point-clouds (and benchmarked with PointNet etc.) c) on edges it could be used for training on graphs.

Edge model/rep

Hi,

Thank you for providing this version of the EnGNN model. This is not really an issue, just a query. The original model as implemented here (https://github.com/vgsatorras/egnn) has 3 main steps per layer:
edge_feat = self.edge_model(h[row], h[col], radial, edge_attr)
coord = self.coord_model(coord, edge_index, coord_diff, edge_feat)
h, agg = self.node_model(h, edge_index, edge_feat, node_attr)
I am interested in the edge_feat and was wondering what would be an equivalent edge representation in your implementation. Line 335 in EnTransformer.py: qk = self.edge_mlp(qk) seems like the best candidate.
Thanks,
Pooja

efficient implementation

Hi,
I wonder if the relative distances and coordinates can be handled more efficiently using memory-efficient attention, as in "Self-attention Does Not Need O(n^2) Memory". It is straightforward for the scalar part.

On rotary embeddings

Hi @lucidrains, thank you for your amazing work; big fan! I had a quick question on the usage of this repository.

Based on my understanding, rotary embeddings are a drop-in replacement for the original sinusoidal or learnt PEs in Transformers for sequential data, as in NLP or other temporal applications. If my application is not on sequential data, is there a reason why I should still use rotary embeddings?

E.g. for molecular datasets such as QM9 (from the En-GNNs paper), would it make sense to have rotary embeddings?

Performance drop with checkpointing update

I see a drop in performance (higher loss) when I update checkpointing from checkpoint_sequential(self.layers, 1, inp) to checkpoint_sequential(self.layers, len(self.layers), inp). Is this expected?
