GithubHelp home page GithubHelp logo

vict0rsch / phast Goto Github PK

View Code? Open in Web Editor NEW
1.0 3.0 1.0 2.75 MB

PyTorch implementation for PhAST: Physics-Aware, Scalable and Task-specific GNNs for Accelerated Catalyst Design

Home Page: https://phast.readthedocs.io/en/latest/

License: MIT License

Python 100.00%

phast's Introduction

💻  Code   •   Docs  📑

Python Documentation Status


PhAST: Physics-Aware, Scalable, and Task-specific GNNs for Accelerated Catalyst Design

This repository contains implementations for 2 of the PhAST components presented in the paper:

  • PhysEmbedding that allows one to create an embedding vector from atomic numbers that is the concatenation of:
    • A learned embedding for the atom's group
    • A learned embedding for the atom's period
    • A fixed or learned embedding from a set of known physical properties, as reported by mendeleev
    • In the case of the OC20 dataset, a learned embedding for the atom's tag (adsorbate, catalyst surface or catalyst sub-surface)
  • Tag-based graph rewiring strategies for the OC20 dataset:
    • remove_tag0_nodes deletes all nodes in the graph associated with a tag 0 and recomputes edges

    • one_supernode_per_graph replaces all tag 0 atoms with a single new atom

    • one_supernode_per_atom_type replaces all tag 0 atoms of a given element with its own super node

Also: https://github.com/vict0rsch/faenet

Installation

pip install phast

⚠️ The above installation does not include torch_geometric which is a complex and very variable dependency you have to install yourself if you want to use the graph re-wiring functions of phast.

☮️ Ignore torch_geometric if you only care about the PhysEmbeddings.

Getting started

Physical embeddings

Embedding illustration

import torch
from phast.embedding import PhysEmbedding

z = torch.randint(1, 85, (3, 12)) # batch of 3 graphs with 12 atoms each
phys_embedding = PhysEmbedding(
    z_emb_size=32, # default
    period_emb_size=32, # default
    group_emb_size=32, # default
    properties_proj_size=32, # default is 0 -> no learned projection
    n_elements=85, # default
)
h = phys_embedding(z) # h.shape = (3, 12, 128)

tags = torch.randint(0, 3, (3, 12))
phys_embedding = PhysEmbedding(
    tag_emb_size=32, # default is 0, this is OC20-specific
    final_proj_size=64, # default is 0, no projection, just the concat. of embeds.
)

h = phys_embedding(z, tags) # h.shape = (3, 12, 64)

# Assuming torch_geometric is installed:
data = torch.load("examples/data/is2re_bs3.pt")
h = phys_embedding(data.atomic_numbers.long(), data.tags) # h.shape = (261, 64)

Graph rewiring

Rewiring illustration

from copy import deepcopy
import torch
from phast.graph_rewiring import (
    remove_tag0_nodes,
    one_supernode_per_graph,
    one_supernode_per_atom_type,
)

data = torch.load("./examples/data/is2re_bs3.pt")  # 3 batched OC20 IS2RE data samples
print(
    "Data initially contains {} graphs, a total of {} atoms and {} edges".format(
        len(data.natoms), data.ptr[-1], len(data.cell_offsets)
    )
)
rewired_data = remove_tag0_nodes(deepcopy(data))
print(
    "Data without tag-0 nodes contains {} graphs, a total of {} atoms and {} edges".format(
        len(rewired_data.natoms), rewired_data.ptr[-1], len(rewired_data.cell_offsets)
    )
)
rewired_data = one_supernode_per_graph(deepcopy(data))
print(
    "Data with one super node per graph contains a total of {} atoms and {} edges".format(
        rewired_data.ptr[-1], len(rewired_data.cell_offsets)
    )
)
rewired_data = one_supernode_per_atom_type(deepcopy(data))
print(
    "Data with one super node per atom type contains a total of {} atoms and {} edges".format(
        rewired_data.ptr[-1], len(rewired_data.cell_offsets)
    )
)
Data initially contains 3 graphs, a total of 261 atoms and 11596 edges
Data without tag-0 nodes contains 3 graphs, a total of 64 atoms and 1236 edges
Data with one super node per graph contains a total of 67 atoms and 1311 edges
Data with one super node per atom type contains a total of 71 atoms and 1421 edges

Tests

This requires poetry. Make sure to have torch and torch_geometric installed in your environment before you can run the tests. Unfortunately because of CUDA/torch compatibilities, neither torch nor torch_geometric are part of the explicit dependencies and must be installed independently.

git clone [email protected]:vict0rsch/phast.git
poetry install --with dev
pytest --cov=phast --cov-report term-missing

Testing on Macs you may encounter a Library Not Loaded Error

Requires Python <3.12 because

mendeleev (0.14.0) requires Python >=3.8.1,<3.12

phast's People

Contributors

vict0rsch avatar

Stargazers

Xiao Jin avatar

Watchers

 avatar Kostas Georgiou avatar  avatar

Forkers

alexduvalinho

phast's Issues

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.