mitkotak / e3nn.c

This project forked from teddykoker/e3nn.c

Pure C implementation of e3nn

License: MIT License

Python 0.61% C 99.32% Makefile 0.06%

e3nn.c

Pure C implementation of e3nn. This was mostly done for pedagogical reasons, but similar code could be used for C/C++ implementations of e3nn-based models for inference, or for CUDA kernels that speed up operations within Python libraries.

Currently the only operations implemented are the tensor product, spherical harmonics, and the linear (self-interaction) operation.

Figure: Single-thread CPU performance of the tensor product on an Intel i5 desktop processor.

Example

// example.c
#include <stdio.h>

#include "e3nn.h"

int main(void){

    // tensor product
    float input1[] = { 0, 1, 2, 3, 4 };
    float input2[] = { 0, 1, 2, 3, 4, 5 };
    float product[30] = { 0 };
    tensor_product("2x0e + 1x1o", input1, 
                   "1x0o + 1x2o", input2, 
                   "2x0o + 2x1e + 1x2e + 2x2o + 1x3e", product);

    printf("product ["); for (int i = 0; i < 30; i++){ printf("%.2f, ", product[i]); } printf("]\n");

    // spherical harmonics
    float sph[9] = { 0 };
    spherical_harmonics("1x0e + 1x1o + 1x2e", 1.0, 2.0, 3.0, sph);

    printf("sph ["); for (int i = 0; i < 9; i++) { printf("%.2f, ", sph[i]); } printf("]\n");

    // linear/self-interaction
    float input3[] = { 0, 1, 2, 3, 4, 5, 6, 7 };
    //                 [  2 x 3 weight  ][  2 x 3 weight  ]
    float weight[] = { 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5 };
    float output[12] = { 0 };
    linear("2x0e + 2x1o", input3, weight,
           "3x0e + 3x1o", output);

    printf("output ["); for (int i = 0; i < 12; i++) { printf("%.2f, ", output[i]); } printf("]\n");
    
    return 0;
}
$ make example && ./example
product [0.00, 0.00, 0.00, 0.00, 0.00, -1.90, 16.65, 14.83, 7.35, -12.57, 0.00, -0.66, 4.08, 0.00, 0.00, 0.00, 0.00, 0.00, 1.00, 2.00, 3.00, 4.00, 5.00, 9.90, 10.97, 9.27, -1.97, 12.34, 15.59, 12.73, ]
sph [1.00, 0.46, 0.93, 1.39, 0.83, 0.55, -0.16, 1.66, 1.11, ]
output [2.12, 2.83, 3.54, 10.61, 12.73, 14.85, 15.56, 19.09, 22.63, 20.51, 25.46, 30.41, ]

The example writes the same values to its output buffers as the following Python code using e3nn-jax:

import jax.numpy as jnp
import e3nn_jax as e3nn

# tensor product
input1 = e3nn.IrrepsArray("2x0e + 1x1o", jnp.arange(5))
input2 = e3nn.IrrepsArray("1x0o + 1x2o", jnp.arange(6))
product = e3nn.tensor_product(input1, input2)
print("product", product.array)

# spherical harmonics
sph = e3nn.spherical_harmonics("1x0e + 1x1o + 1x2e", jnp.array([1, 2, 3]), normalize=True, normalization="component")
print("sph", sph.array)

# linear/self-interaction
input3 = e3nn.IrrepsArray("2x0e + 2x1o", jnp.arange(8))
linear = e3nn.flax.Linear(
    irreps_in="2x0e + 2x1o",
    irreps_out="3x0e + 3x1o",
)
w = {"params": {
    "w[0,0] 2x0e,3x0e": jnp.arange(6, dtype=jnp.float32).reshape(2, 3),
    "w[1,1] 2x1o,3x1o": jnp.arange(6, dtype=jnp.float32).reshape(2, 3),
}}
output = linear.apply(w, input3)
print("output", output.array)

Usage

See the example above, which is also in example.c. Build and run it with:

make example
./example

Currently the output irreps must be specified manually. They could be computed on the fly at minimal computational cost; however, I am not sure what the best API for this would be. Additionally, only component normalization is currently implemented, and the tensor product will not function properly if the output irreps do not match the full, simplified output irreps (i.e., filtering is not supported); see Todo.

Benchmarking

python -m venv venv
source venv/bin/activate
pip install -r extra/requirements.txt

make benchmark

e3nn.c contains several tensor product implementations, each improving on the runtime of the previous one.

v1

tensor_product_v1 is a naive implementation that performs the full tensor product, summing over all Clebsch-Gordan coefficients:

$$(u \otimes v)^{(l)}_m = \sum_{m_1 = -l_1}^{l_1}\sum_{m_2 = -l_2}^{l_2} C^{(l, m)}_{(l_1, m_1)(l_2, m_2)} u^{(l_1)}_{m_1}v^{(l_2)}_{m_2}$$

To minimize the overhead of computing the Clebsch-Gordan coefficients, they are precomputed up to L_MAX and cached the first time the tensor product is called, at the price of a one-time startup cost.

v2

The tensor_product_v2 implementation leverages the fact that, even after conversion to the real basis, the Clebsch-Gordan coefficients are generally sparse, with many entries equal to 0. To take advantage of this, we precompute a data structure that stores only the non-zero entries of $C$ for each $l_1$, $l_2$, $l$, along with their corresponding indices $m_1$, $m_2$, $m$. This significantly improves performance by eliminating needless iterations over zero-valued coefficients. The just-in-time (JIT) compilers built into JAX and PyTorch are likely able to perform this optimization as well.

v3

tensor_product_v3 forgoes the computation of Clebsch-Gordan coefficients altogether, and instead generates C code that computes the partial tensor product for every $l_1$, $l_2$, $l$ combination up to L_MAX. This eliminates the need to iterate over any coefficients, allowing each value in the output to be written in a single step. Because the code is generated before compilation, the C compiler can also apply its own optimizations to keep these operations fast. See tp_codegen.py, which generates tp.c, containing all of the tensor product paths.

Todo:

  • Benchmark against e3nn and e3nn-jax
  • Sparse Clebsch-Gordan implementation
  • Implement Spherical Harmonics
  • Implement Linear/Self-interaction operation
  • Implement filter_ir_out and irrep_normalization="norm" for tensor product
  • Full Nequip, Allegro, or ChargE3Net implementation
  • Implement integral, norm, and no normalization for spherical harmonics
  • ...

See also

Contributors: teddykoker, mitkotak
