This project forked from deepgraphlearning/gearnet

GearNet and Geometric Pretraining Methods for Protein Structure Representation Learning, ICLR'2023 (https://arxiv.org/abs/2203.06125)

License: MIT License

GearNet: Geometry-Aware Relational Graph Neural Network

This is the official codebase of the paper

Protein Representation Learning by Geometric Structure Pretraining, ICLR'2023

[ArXiv] [OpenReview]

Zuobai Zhang, Minghao Xu, Arian Jamasb, Vijil Chenthamarakshan, Aurelie Lozano, Payel Das, Jian Tang

and the paper

Enhancing Protein Language Models with Structure-based Encoder and Pre-training, ICLR'2023 MLDD Workshop

[ArXiv] [OpenReview]

Zuobai Zhang, Minghao Xu, Vijil Chenthamarakshan, Aurelie Lozano, Payel Das, Jian Tang

News

  • [2023/10/17] Please check the latest version of the ESM-GearNet paper and code implementation!!

  • [2023/03/14] The code for ESM_GearNet has been released with our latest paper.

  • [2023/02/25] The code for GearNet_Edge_IEConv & Fold3D dataset has been released.

  • [2023/02/01] Our paper has been accepted by ICLR'2023! We have released the pretrained model weights here.

  • [2022/11/20] We added a scheduler to downstream.py and provided a config file for training GearNet-Edge on EC with a single GPU. You can now reproduce the results in the paper.

Overview

GeomEtry-Aware Relational Graph Neural Network (GearNet) is a simple yet effective structure-based protein encoder. It encodes spatial information by adding different types of sequential or structural edges and then performs relational message passing on protein residue graphs, which can be further enhanced by an edge message passing mechanism. Though conceptually simple, GearNet augmented with edge message passing can achieve very strong performance on several benchmarks in a supervised setting.
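
To make the edge construction concrete, here is a minimal, dependency-free sketch of how a residue graph with typed edges could be assembled from alpha-carbon coordinates. It is an illustration only, not the TorchDrug implementation: the helper name build_relational_edges, the sequential window of 2, the 10 angstrom radius, and k = 10 nearest neighbors are all assumptions made for this example.

```python
import math
from typing import List, Sequence, Tuple

Edge = Tuple[int, int, int]  # (residue i, residue j, relation id)


def euclidean(p: Sequence[float], q: Sequence[float]) -> float:
    """Euclidean distance between two 3D points."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(p, q)))


def build_relational_edges(
    coords: Sequence[Sequence[float]],
    seq_window: int = 2,   # assumed: connect residues with |i - j| <= 2
    radius: float = 10.0,  # assumed spatial cutoff in angstroms
    k: int = 10,           # assumed number of nearest neighbors
) -> List[Edge]:
    """Build typed edges for a residue graph (hypothetical helper)."""
    n = len(coords)
    edges = set()

    # Sequential edges: each relative offset gets its own relation id.
    for i in range(n):
        for offset in range(-seq_window, seq_window + 1):
            j = i + offset
            if 0 <= j < n:
                edges.add((i, j, offset + seq_window))

    radius_relation = 2 * seq_window + 1
    knn_relation = radius_relation + 1
    for i in range(n):
        # Neighbors of residue i, sorted by distance (ties broken by index).
        neighbors = sorted(
            (euclidean(coords[i], coords[j]), j) for j in range(n) if j != i
        )
        # Radius edges: every residue closer than the cutoff.
        for dist, j in neighbors:
            if dist < radius:
                edges.add((i, j, radius_relation))
        # kNN edges: the k closest residues, regardless of distance.
        for _, j in neighbors[:k]:
            edges.add((i, j, knn_relation))

    return sorted(edges)


# Toy input: four residues on a line, 4 angstroms apart.
toy_coords = [(0.0, 0.0, 0.0), (4.0, 0.0, 0.0), (8.0, 0.0, 0.0), (12.0, 0.0, 0.0)]
toy_edges = build_relational_edges(toy_coords, k=1)
print(len(toy_edges), "edges,", len({r for _, _, r in toy_edges}), "relation types")
```

With these assumed defaults the graph has seven relation types: five sequential offsets, one radius relation, and one kNN relation. Relational message passing would then typically aggregate messages separately per relation type.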

[Figure: GearNet]

Five different geometric self-supervised learning methods based on protein structures are further proposed to pretrain the encoder: Multiview Contrast, Residue Type Prediction, Distance Prediction, Angle Prediction, and Dihedral Prediction. By extensively benchmarking these pretraining techniques on diverse downstream tasks, we set up a solid starting point for pretraining protein structure representations.
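
The three geometric objectives named above are built on quantities that can be computed directly from atom or residue coordinates. The sketch below shows one way to compute a distance, an angle, and a dihedral, and to discretize them into class labels. It is a plain-Python illustration, not the code used in this repository: the helper names and the choice to bin values uniformly (to_bin and its bin counts) are assumptions.

```python
import math
from typing import List, Sequence

Point = Sequence[float]


def _sub(a: Point, b: Point) -> List[float]:
    return [x - y for x, y in zip(a, b)]


def _dot(a: Point, b: Point) -> float:
    return sum(x * y for x, y in zip(a, b))


def _cross(a: Point, b: Point) -> List[float]:
    return [
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    ]


def _norm(a: Point) -> float:
    return math.sqrt(_dot(a, a))


def distance(a: Point, b: Point) -> float:
    """Euclidean distance between two points."""
    return _norm(_sub(a, b))


def angle(a: Point, b: Point, c: Point) -> float:
    """Angle at vertex b, in radians, in the range [0, pi]."""
    u, v = _sub(a, b), _sub(c, b)
    cosine = _dot(u, v) / (_norm(u) * _norm(v))
    return math.acos(max(-1.0, min(1.0, cosine)))  # clamp rounding error


def dihedral(a: Point, b: Point, c: Point, d: Point) -> float:
    """Signed dihedral angle around the b-c axis, in radians."""
    b1, b2, b3 = _sub(b, a), _sub(c, b), _sub(d, c)
    n1, n2 = _cross(b1, b2), _cross(b2, b3)
    return math.atan2(_norm(b2) * _dot(b1, n2), _dot(n1, n2))


def to_bin(value: float, low: float, high: float, num_bins: int) -> int:
    """Map a value to a uniform bin index (assumed discretization)."""
    fraction = (value - low) / (high - low)
    return min(num_bins - 1, max(0, int(fraction * num_bins)))


# Four points forming a right-angle turn followed by a 90 degree twist.
p0, p1, p2, p3 = (1.0, 0.0, 0.0), (0.0, 0.0, 0.0), (0.0, 0.0, 1.0), (0.0, 1.0, 1.0)
print("distance:", distance(p0, p1))
print("angle (deg):", math.degrees(angle(p0, p1, p2)))
print("dihedral (deg):", math.degrees(dihedral(p0, p1, p2, p3)))
print("angle bin:", to_bin(angle(p0, p1, p2), 0.0, math.pi, 8))
```

Turning a continuous target into a bin index lets a prediction head be trained as a classifier. Whether and how the released configs discretize these targets should be checked against the config files and the TorchDrug docs.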

[Figure: self-supervised learning (SSL) methods]

This codebase is based on PyTorch and TorchDrug (TorchProtein). It supports training and inference with multiple GPUs. The documentation and implementation of our methods can be found in the TorchDrug docs. To adapt our model to your setting, you can follow the step-by-step tutorials in TorchProtein.

Installation

You may install the dependencies via either conda or pip. Generally, GearNet works with Python 3.7/3.8 and PyTorch version >= 1.8.0.
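
The version range stated above can be checked programmatically once an environment exists. The following is a small sketch using only the standard library; the helper names parse_version and within_stated_range are made up for this example, and a False result only means the environment is outside the range named here, not that GearNet necessarily fails on it.

```python
import re
from typing import Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    """Turn a string such as '1.8.0+cu111' into (1, 8, 0)."""
    public = text.split("+", 1)[0]  # drop local build tags such as '+cu111'
    return tuple(int(part) for part in re.findall(r"\d+", public)[:3])


def within_stated_range(python_version: str, torch_version: str) -> bool:
    """Check Python 3.7/3.8 and PyTorch >= 1.8.0 (hypothetical helper)."""
    python_ok = parse_version(python_version)[:2] in {(3, 7), (3, 8)}
    torch_ok = parse_version(torch_version) >= (1, 8, 0)
    return python_ok and torch_ok


print(within_stated_range("3.8.10", "1.8.0+cu111"))
print(within_stated_range("3.8.10", "1.7.1"))
```

In a real environment the two arguments could come from platform.python_version() and torch.__version__.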

From Conda

conda install torchdrug pytorch=1.8.0 cudatoolkit=11.1 -c milagraph -c pytorch-lts -c pyg -c conda-forge
conda install easydict pyyaml -c conda-forge

From Pip

pip install torch==1.8.0+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
pip install torchdrug
pip install easydict pyyaml

Using Docker

First, make sure Docker is set up with GPU support (guide).

Next, build the Docker image.

# Docker image names must be lowercase
docker build . -t gearnet

Once the image is built, you can run training commands inside a container started with the following command.

docker run -it -v /path/to/dataset/directory/on/disk:/root/scratch/ --gpus all gearnet bash

Reproduction

Training From Scratch

To reproduce the results of GearNet, use the following command. Alternatively, you may use --gpus null to run GearNet on a CPU. All datasets are downloaded automatically by the code. The first run takes longer because the dataset has to be preprocessed.

# Run GearNet on the Enzyme Commission dataset with 1 GPU
python script/downstream.py -c config/downstream/EC/gearnet.yaml --gpus [0]

We provide the hyperparameters for each experiment in configuration files. All configuration files are YAML files under the config/ directory (for example, config/downstream/EC/gearnet.yaml).

To run GearNet with multiple GPUs, use the following commands.

# Run GearNet on the Enzyme Commission dataset with 4 GPUs
python -m torch.distributed.launch --nproc_per_node=4 script/downstream.py -c config/downstream/EC/gearnet.yaml --gpus [0,1,2,3]

# Run ESM_GearNet on the Enzyme Commission dataset with 4 GPUs
python -m torch.distributed.launch --nproc_per_node=4 script/downstream.py -c config/downstream/EC/ESM_gearnet.yaml --gpus [0,1,2,3]

# Run GearNet_Edge_IEConv on the Fold3D dataset with 4 gpus
# You need to first install the latest version of torchdrug from source. See https://github.com/DeepGraphLearning/torchdrug.
python -m torch.distributed.launch --nproc_per_node=4 script/downstream.py -c config/downstream/Fold3D/gearnet_edge_ieconv.yaml --gpus [0,1,2,3]
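
The single-GPU, CPU, and multi-GPU commands above differ only in the launcher prefix and the value passed to --gpus. The sketch below assembles the argument list for each case so it can be passed to subprocess.run. The helper build_command is hypothetical and not part of this repository; it only reproduces the command patterns shown in this section.

```python
from typing import List, Optional, Sequence


def build_command(
    config: str,
    gpus: Optional[Sequence[int]],
    script: str = "script/downstream.py",
) -> List[str]:
    """Assemble a training command as an argument list (hypothetical helper)."""
    if not gpus:
        gpu_arg = "null"  # CPU run, as described in the text
    else:
        gpu_arg = "[" + ",".join(str(g) for g in gpus) + "]"

    command = ["python"]
    if gpus and len(gpus) > 1:
        # One process per GPU, matching the multi-GPU examples above.
        command += [
            "-m",
            "torch.distributed.launch",
            "--nproc_per_node={}".format(len(gpus)),
        ]
    command += [script, "-c", config, "--gpus", gpu_arg]
    return command


ec_config = "config/downstream/EC/gearnet.yaml"
print(" ".join(build_command(ec_config, [0])))
print(" ".join(build_command(ec_config, [0, 1, 2, 3])))
print(" ".join(build_command(ec_config, None)))
```

Passing the list form to subprocess.run avoids shell quoting issues with the square brackets in the --gpus value.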

Pretraining and Finetuning

By default, we use the AlphaFold Database for pretraining. To pretrain GearNet-Edge with Multiview Contrast, use the following command. Similarly, all datasets are downloaded automatically and preprocessed the first time you run the code.

# Pretrain GearNet-Edge with Multiview Contrast
python script/pretrain.py -c config/pretrain/mc_gearnet_edge.yaml --gpus [0]

# Pretrain ESM_GearNet with Multiview Contrast
python script/pretrain.py -c config/pretrain/mc_esm_gearnet.yaml --gpus [0]

After pretraining, you can load the model weights from the saved checkpoint via the --ckpt argument and then finetune the model on downstream tasks.

# Finetune GearNet-Edge on the Enzyme Commission dataset
python script/downstream.py -c config/downstream/EC/gearnet_edge.yaml --gpus [0] --ckpt <path_to_your_model>

You can find the pretrained model weights here, including those pretrained with Multiview Contrast, Residue Type Prediction, Distance Prediction, Angle Prediction and Dihedral Prediction.

Results

Here are the results of GearNet with and without pretraining on standard benchmark datasets. All results were obtained with 4 A100 GPUs (40GB). Note that results may differ slightly if the model is trained with 1 GPU and/or a smaller batch size. For EC and GO, the provided config files assume 4 GPUs with a batch size of 2 on each. If you run the model on 1 GPU, you should set the batch size to 8. More detailed results are listed in the paper.

Method                            EC      GO-BP   GO-MF   GO-CC
GearNet                           0.730   0.356   0.503   0.414
GearNet-Edge                      0.810   0.403   0.580   0.450
Multiview Contrast                0.874   0.490   0.654   0.488
Residue Type Prediction           0.843   0.430   0.604   0.465
Distance Prediction               0.839   0.448   0.616   0.464
Angle Prediction                  0.853   0.458   0.625   0.473
Dihedral Prediction               0.859   0.458   0.626   0.465
ESM_GearNet                       0.883   0.491   0.677   0.501
ESM_GearNet (Multiview Contrast)  0.894   0.516   0.684   0.5016
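
The batch-size guidance for EC and GO is simple arithmetic: 4 GPUs with a batch size of 2 each give an effective batch size of 8, which is why a single-GPU run uses 8. A minimal sketch of that calculation follows; the function name per_gpu_batch_size is an assumption for illustration and does not appear in the codebase.

```python
REFERENCE_GPUS = 4           # from the provided EC and GO configs
REFERENCE_PER_GPU_BATCH = 2  # batch size on each of those GPUs
EFFECTIVE_BATCH = REFERENCE_GPUS * REFERENCE_PER_GPU_BATCH


def per_gpu_batch_size(num_gpus: int, effective_batch: int = EFFECTIVE_BATCH) -> int:
    """Per-GPU batch size that keeps the effective batch size unchanged."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    if effective_batch % num_gpus != 0:
        raise ValueError("effective batch size is not divisible by num_gpus")
    return effective_batch // num_gpus


for gpus in (1, 2, 4):
    print(gpus, "GPU(s): batch size", per_gpu_batch_size(gpus), "per GPU")
```

Keeping the effective batch size fixed is the usual way to stay close to the reported numbers when the GPU count changes, although small differences can still occur.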

Citation

If you find this codebase useful in your research, please cite the following papers.

@inproceedings{zhang2022protein,
  title={Protein representation learning by geometric structure pretraining},
  author={Zhang, Zuobai and Xu, Minghao and Jamasb, Arian and Chenthamarakshan, Vijil and Lozano, Aurelie and Das, Payel and Tang, Jian},
  booktitle={International Conference on Learning Representations},
  year={2023}
}
@article{zhang2023enhancing,
  title={A Systematic Study of Joint Representation Learning on Protein Sequences and Structures},
  author={Zhang, Zuobai and Wang, Chuanrui and Xu, Minghao and Chenthamarakshan, Vijil and Lozano, Aurelie and Das, Payel and Tang, Jian},
  journal={arXiv preprint arXiv:2303.06275},
  year={2023}
}
