chao1224 / moleculesde

A Group Symmetric Stochastic Differential Equation Model for Molecule Multi-modal Pretraining, ICML'23

Home Page: https://chao1224.github.io/MoleculeSDE

License: MIT License

Python 100.00%
diffusion generation geometry molecule pretraining representation stochastic-differential-equation conformation group-equivariant-neural-network sde

moleculesde's Introduction

A Group Symmetric Stochastic Differential Equation Model for Molecule Multi-modal Pretraining

ICML 2023

Shengchao Liu+, Weitao Du+, Zhiming Ma, Hongyu Guo, Jian Tang

+ Equal contribution

[Project Page] [Paper] [ArXiv] [Checkpoints on HuggingFace]

  • MoleculeSDE is GraphMVPv2, the follow-up to GraphMVP.
  • It includes two components:
    • Contrastive learning
    • Generative learning:
      • A 2D->3D diffusion model: frame-based, SE(3)-equivariant, and reflection anti-symmetric.
      • A 3D->2D diffusion model: SE(3)-invariant.
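To make the symmetry terms above concrete, here is a toy numpy check of what SE(3)-invariance and SE(3)-equivariance mean for functions of 3D coordinates. It is not MoleculeSDE's model: `random_rotation`, `invariant_feature`, and `equivariant_feature` are hypothetical illustrations. An invariant output is unchanged when the molecule is rotated and translated; an equivariant vector output rotates along with it.

```python
import numpy as np


def random_rotation(rng):
    """Sample a proper rotation matrix (det = +1) via QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))  # fix column signs so the factorization is unique
    if np.linalg.det(q) < 0:     # turn a reflection into a proper rotation
        q[:, 0] = -q[:, 0]
    return q


def invariant_feature(pos):
    """Toy SE(3)-invariant output: all pairwise distances, sorted."""
    diff = pos[:, None, :] - pos[None, :, :]
    return np.sort(np.linalg.norm(diff, axis=-1).ravel())


def equivariant_feature(pos):
    """Toy SE(3)-equivariant output: each atom's offset from the centroid,
    scaled by an invariant scalar (its distance to the centroid)."""
    rel = pos - pos.mean(axis=0)
    weight = np.linalg.norm(rel, axis=-1, keepdims=True)
    return weight * rel


rng = np.random.default_rng(0)
pos = rng.normal(size=(5, 3))            # 5 hypothetical atom positions
R, t = random_rotation(rng), rng.normal(size=3)
moved = pos @ R.T + t                    # rotate, then translate
print(np.allclose(invariant_feature(moved), invariant_feature(pos)))
print(np.allclose(equivariant_feature(moved), equivariant_feature(pos) @ R.T))
```

Both checks print `True`: translation cancels in the centroid offsets and the distances, and rotation only acts on the vector-valued output.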

All pretrained checkpoints are available on HuggingFace. The detailed mapping between checkpoints and tables is given in README_checkpoints.md.

Environments

conda create -n Geom3D python=3.7
conda activate Geom3D
conda install -y -c rdkit rdkit
conda install -y numpy networkx scikit-learn
conda install -y -c conda-forge -c pytorch pytorch=1.9.1
conda install -y -c pyg -c conda-forge pyg=2.0.2
pip install ogb==1.2.1

pip install sympy

pip install ase  # for SchNet

pip install -e .
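After installing, a quick way to confirm the environment is to check that each dependency is importable. This is a minimal sketch; the module names are assumptions mapped from the install commands above (for example, the conda `pyg` package is imported as `torch_geometric`, and scikit-learn as `sklearn`), and `missing_modules` is a hypothetical helper, not part of the repository.

```python
import importlib.util

# Assumed import names for the packages installed above.
REQUIRED_MODULES = ["rdkit", "numpy", "networkx", "sklearn", "torch",
                    "torch_geometric", "ogb", "sympy", "ase"]


def missing_modules(names=REQUIRED_MODULES):
    """Return the module names that cannot be found, without importing them."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    print("all dependencies found" if not missing else f"missing: {missing}")
```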

Datasets

  • For PCQM4Mv2 (pretraining) dataset
    • Download the dataset from the PCQM4Mv2 website into the folder data/PCQM4Mv2/raw:
        .
      ├── data
      │   └── PCQM4Mv2
      │       └── raw
      │           ├── data.csv
      │           ├── data.csv.gz
      │           ├── pcqm4m-v2-train.sdf
      │           └── pcqm4m-v2-train.sdf.tar.gz
      
    • Then run examples/generate_PCQM4Mv2.py.
  • QM9 is downloaded automatically by the PyG-based dataset class. The default path is data/molecule_datasets/QM9.
  • MD17 is downloaded automatically by the PyG-based dataset class. The default path is data/MD17.
  • For MoleculeNet, please follow the GraphMVP instructions. The dataset structure is:
      .
    ├── data
    │   ├── molecule_datasets
    │   │   ├── bace
    │   │   │   ├── BACE_README
    │   │   │   └── raw
    │   │   │       └── bace.csv
    │   │   ├── bbbp
    ...............
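The PCQM4Mv2 layout above lists both the archives and their extracted files. The sketch below unpacks the archives in place when the extracted files are missing, then verifies the layout before running examples/generate_PCQM4Mv2.py. It assumes that data.csv.gz decompresses to data.csv and that pcqm4m-v2-train.sdf.tar.gz holds pcqm4m-v2-train.sdf at its top level; `prepare_pcqm4mv2_raw` is a hypothetical helper, not part of the repository.

```python
import gzip
import shutil
import tarfile
from pathlib import Path


def prepare_pcqm4mv2_raw(raw_dir="data/PCQM4Mv2/raw"):
    """Extract the PCQM4Mv2 raw archives if needed and return the two file paths.

    Assumption: data.csv.gz -> data.csv, and the tarball contains
    pcqm4m-v2-train.sdf at its top level.
    """
    raw = Path(raw_dir)
    csv_path, sdf_path = raw / "data.csv", raw / "pcqm4m-v2-train.sdf"

    if not csv_path.exists():
        with gzip.open(raw / "data.csv.gz", "rb") as src, open(csv_path, "wb") as dst:
            shutil.copyfileobj(src, dst)

    if not sdf_path.exists():
        with tarfile.open(raw / "pcqm4m-v2-train.sdf.tar.gz", "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):  # Python versions with extraction filters
                tar.extractall(raw, filter="data")
            else:
                tar.extractall(raw)

    missing = [p.name for p in (csv_path, sdf_path) if not p.exists()]
    if missing:
        raise FileNotFoundError(f"still missing under {raw}: {missing}")
    return csv_path, sdf_path
```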
    

Pretraining

A quick pretraining demo:

cd examples

python pretrain_MoleculeSDE.py \
--verbose --input_data_dir=../data --dataset=PCQM4Mv2 \
--model_3d=SchNet \
--lr=1e-4 --num_workers=0 --batch_size=256 --SSL_masking_ratio=0 --gnn_3d_lr_scale=0.1 --dropout_ratio=0 --graph_pooling=mean --emb_dim=300 --epochs=1 \
--SDE_coeff_contrastive=1 --CL_similarity_metric=EBM_node_dot_prod --T=0.1 --normalize --SDE_coeff_contrastive_skip_epochs=0 \
--SDE_coeff_generative_2Dto3D=1 --SDE_2Dto3D_model=SDEModel2Dto3D_02 --SDE_type_2Dto3D=VE --use_extend_graph \
--SDE_coeff_generative_3Dto2D=1 --SDE_3Dto2D_model=SDEModel3Dto2D_node_adj_dense --SDE_type_3Dto2D=VE --noise_on_one_hot \
--output_model_dir=[MODEL_DIR]

Here, [MODEL_DIR] is the directory where your models/checkpoints will be saved.
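The contrastive flags in the demo command (`--T=0.1 --normalize`) suggest a temperature-scaled similarity between normalized 2D and 3D embeddings. The sketch below is a generic symmetric InfoNCE loss under that assumption; the repository's `--CL_similarity_metric=EBM_node_dot_prod` option may compute something different, and `info_nce` is a hypothetical function, not the repository's implementation.

```python
import numpy as np


def info_nce(z_2d, z_3d, T=0.1, normalize=True):
    """Symmetric InfoNCE between paired 2D and 3D embeddings of shape [B, D].

    Row i of z_2d and row i of z_3d form the positive pair; every other row
    in the batch serves as a negative.
    """
    if normalize:
        z_2d = z_2d / np.linalg.norm(z_2d, axis=1, keepdims=True)
        z_3d = z_3d / np.linalg.norm(z_3d, axis=1, keepdims=True)
    logits = z_2d @ z_3d.T / T  # [B, B] temperature-scaled dot products

    def cross_entropy_on_diagonal(lg):
        lg = lg - lg.max(axis=1, keepdims=True)  # numerical stability
        log_probs = lg - np.log(np.exp(lg).sum(axis=1, keepdims=True))
        return -np.mean(np.diag(log_probs))

    # Average the 2D->3D and 3D->2D directions.
    return 0.5 * (cross_entropy_on_diagonal(logits) + cross_entropy_on_diagonal(logits.T))
```

Correctly paired embeddings yield a lower loss than shuffled pairs. Judging by their names, the `--SDE_coeff_*` flags presumably weight this contrastive term and the two generative SDE losses in the total objective; this has not been verified against the code.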

Downstream

The downstream scripts can be found under the examples folder. Below are a few simple examples.

  • finetune_MoleculeNet.py:
    python finetune_MoleculeNet.py \
    --dataset=tox21 \
    --input_model_file=[MODEL_DIR]/model_complete.pth
    
  • finetune_QM9.py:
    python finetune_QM9.py \
    --dataset=QM9 --task=gap \
    --model_3d=SchNet \
    --input_model_file=[MODEL_DIR]/model_complete.pth
    
  • finetune_MD17.py:
    python finetune_MD17.py \
    --dataset=MD17 --task=aspirin \
    --model_3d=SchNet \
    --input_model_file=[MODEL_DIR]/model_complete.pth
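MoleculeNet classification tasks such as tox21 are commonly reported as ROC-AUC averaged over tasks, skipping missing labels. The sketch below shows that computation in plain Python. The label encoding (1 positive, 0 negative, None missing) is an assumption, and finetune_MoleculeNet.py may encode labels and compute the metric differently; `roc_auc` and `mean_task_auc` are hypothetical helpers.

```python
def roc_auc(labels, scores):
    """ROC-AUC as the probability that a random positive outranks a random
    negative (ties count 1/2). Returns None if only one class is present."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        return None
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def mean_task_auc(label_matrix, score_matrix):
    """Average ROC-AUC over tasks (columns), skipping missing labels (None)
    and tasks whose remaining labels contain only one class."""
    aucs = []
    for t in range(len(label_matrix[0])):
        pairs = [(ys[t], ss[t]) for ys, ss in zip(label_matrix, score_matrix)
                 if ys[t] is not None]
        auc = roc_auc([y for y, _ in pairs], [s for _, s in pairs])
        if auc is not None:
            aucs.append(auc)
    if not aucs:
        raise ValueError("no task has both positive and negative labels")
    return sum(aucs) / len(aucs)
```

The pairwise count is quadratic in the number of molecules, which is fine for illustration; library implementations use a sort-based formula instead.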
    

Cite Us

Feel free to cite this work if you find it useful!

@inproceedings{liu2023group,
  title={A group symmetric stochastic differential equation model for molecule multi-modal pretraining},
  author={Liu, Shengchao and Du, Weitao and Ma, Zhi-Ming and Guo, Hongyu and Tang, Jian},
  booktitle={International Conference on Machine Learning},
  pages={21497--21526},
  year={2023},
  organization={PMLR}
}


moleculesde's Issues

Question regarding molecule property scores

Hi, thanks for the nice paper and code. Judging from your paper, I think you missed many SOTA papers on molecular property prediction in your comparison. I can see that you're comparing pretrained methods with GIN. However, if you look at recent work, for instance https://github.com/HIM-AIM/BatmanNet, the scores are much higher than those in your paper.

For instance, BBBP is 0.946 for BatmanNet, and pretrained SMILES-BERT exceeds 0.959 ROC-AUC, while your results show about 0.8 at most.

Maybe I misunderstood how you compared the models; can you help me understand why there is such a huge gap between the scores? It is clear that 2D models perform poorly on geometric tasks such as QM9. But it's hard for me to understand how a representation that considers both 2D and 3D scores lower than 2D alone in predicting molecular properties. Maybe it is due to the dataset size (PCQM4Mv2)?

Looking forward to your help.

Question

Hi Shengchao,

Thank you very much for your great work!

In your paper, you mentioned that "Yet, after confirming with the authors, certain mismatches exist between the 2D topologies and 3D conformations in Molecule3D."

Could you please explain this more? Do you mean that 2D topologies and 3D conformations can not match for every sample in Molecule3D?

Since I am using Molecule3D, I am really looking forward to your help!

Best,

About the loss of SDE_3Dto2D

I observed that in SDEModel3Dto2D_node_adj_dense_02, the loss is calculated as:
losses_x = torch.square(score_x + z_x) # [B, max_num_nodes, num_class_X] or [B, max_num_nodes, 1]
losses_adj = torch.square(score_adj + z_adj) # [B, max_num_nodes, max_num_nodes]

But in SDE_model_2d_to_3d, the loss is calculated as:
loss_pos = torch.sum((scores - pos_noise) ** 2, -1) # (num_node)

I'm confused about why the 3D_to_2D code isn't score - x.
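One standard way to read the sign difference comes from denoising score matching under a VE SDE (the `--SDE_type_*=VE` setting in the pretraining command). With the perturbation x_t = x_0 + sigma(t) z, the kernel's score is -z / sigma. If a network outputs the sigma-scaled score, the sigma-weighted loss is (output + z)^2; if it is instead trained to predict the noise, the loss is (output - z)^2. The two are equivalent with noise = -sigma * score. The sketch below illustrates this; the sigma_min/sigma_max values and all function names are hypothetical, and it has not been checked which convention each MoleculeSDE module actually uses (that depends on how `scores` and `pos_noise` are defined there).

```python
import numpy as np


def ve_sigma(t, sigma_min=0.01, sigma_max=50.0):
    """VE SDE noise scale sigma(t) = sigma_min * (sigma_max / sigma_min) ** t (hypothetical defaults)."""
    return sigma_min * (sigma_max / sigma_min) ** t


def perturb(x0, t, rng):
    """Sample x_t = x_0 + sigma(t) * z from the VE perturbation kernel."""
    z = rng.normal(size=x0.shape)
    return x0 + ve_sigma(t) * z, z


def dsm_loss_score_param(scaled_score, z):
    """Output approximates sigma * score, whose target is -z: loss (output + z)^2."""
    return np.mean(np.sum((scaled_score + z) ** 2, axis=-1))


def dsm_loss_noise_param(pred_noise, z):
    """Output approximates the noise z itself: loss (output - z)^2."""
    return np.mean(np.sum((pred_noise - z) ** 2, axis=-1))


rng = np.random.default_rng(0)
x0 = rng.normal(size=(6, 3))  # hypothetical 3D coordinates
t = 0.5
xt, z = perturb(x0, t, rng)
sigma = ve_sigma(t)
exact_scaled_score = sigma * (-(xt - x0) / sigma**2)  # sigma * grad log p(x_t | x_0) = -z
print(dsm_loss_score_param(exact_scaled_score, z))     # ~0: the exact score attains the minimum
```

Under this reading, `score + z` and `score - noise` are the same objective expressed for a score-predicting versus a noise-predicting output.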
