
mfinzi / equivariant-mlp


A library for programmatically generating equivariant layers through constraint solving

License: MIT License

Python 39.99% Jupyter Notebook 60.01%
equivariance deep-learning

equivariant-mlp's Introduction


A Practical Method for Constructing Equivariant Multilayer Perceptrons for Arbitrary Matrix Groups


EMLP is a JAX library for the automated construction of equivariant layers in deep learning, based on the ICML 2021 paper A Practical Method for Constructing Equivariant Multilayer Perceptrons for Arbitrary Matrix Groups. You can read the documentation here.

What EMLP is great at doing

  • Computing equivariant linear layers between finite dimensional representations. You specify the symmetry group (discrete, continuous, non compact, complex) and the representations (tensors, irreducibles, induced representations, etc), and we will compute the basis of equivariant maps mapping from one to the other.

  • Automatic construction of full equivariant models for small data. E.g. if your inputs and outputs (and intended features) are a small collection of elements like scalars, vectors, tensors, or irreps with a total dimension less than 1000, then you will likely be able to use EMLP as a turnkey solution for making the model, or at least as a strong baseline.

  • As a tool for building larger models, but where EMLP is just one component in a larger system. For example, using EMLP as the convolution kernel in an equivariant PointConv network.

What EMLP is not great at doing

  • An efficient implementation of CNNs, Deep Sets, typical translation + rotation equivariant GCNNs, graph neural networks.

  • Handling large data like images, voxel grids, medium-large graphs, point clouds.

Given the current approach, EMLP can only ever be as fast as an MLP. So if flattening the inputs into a single vector would be too large to train with an MLP, then it will also be too large to train with EMLP.


Showcasing some examples of computing equivariant bases

We provide a type system for representations. The operators ρᵤ⊗ρᵥ, ρᵤ⊕ρᵥ, and ρ* are implemented as *, +, and .T, and you can combine them to build up different representations. The basic building blocks for representations are the base vector representation V and tensor representations T(p,q) = V**p*V.T**q.

For any given matrix group and representation formed in our type system, you can get the equivariant basis with rep.equivariant_basis() or a matrix which projects to that subspace with rep.equivariant_projector().
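To make these two calls concrete, here is a small NumPy sketch of the idea behind them (not EMLP's implementation): stack the constraints ρ(g) - I for a few sampled group elements, take their common nullspace with an SVD to get an orthonormal basis Q (stored as rows), and form the orthogonal projector QᵀQ onto that subspace. For SO(3) acting on T(2) = V⊗V, the only invariant is the identity matrix, so projecting any 3×3 matrix M gives (tr M / 3)·I. The helpers `random_rotation` and `invariant_projector` are hypothetical names, and sampling random rotations (rather than, say, using Lie algebra generators) is an assumption of this sketch.

```python
import numpy as np

def random_rotation(rng):
    # QR of a Gaussian matrix gives a random orthogonal matrix; flip one column if needed so det = +1
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))
    if np.linalg.det(q) < 0:
        q[:, 0] *= -1
    return q

def invariant_projector(rho, dim, n_samples=4, seed=0, tol=1e-8):
    """Orthogonal projector onto {v : rho(g) @ v == v} for randomly sampled rotations g."""
    rng = np.random.default_rng(seed)
    # Stack the constraints rho(g) - I; their common nullspace is the invariant subspace
    C = np.concatenate([rho(random_rotation(rng)) - np.eye(dim) for _ in range(n_samples)])
    _, s, vt = np.linalg.svd(C)
    Q = vt[s < tol]  # rows form an orthonormal basis of the nullspace
    return Q.T @ Q

# SO(3) acting on T(2) = V⊗V: a 3x3 matrix M transforms as R M R^T, i.e. kron(R, R) on M.ravel()
P = invariant_projector(lambda R: np.kron(R, R), dim=9)
M = np.arange(9.0).reshape(3, 3)  # trace = 0 + 4 + 8 = 12
print(P @ M.ravel())  # approximately [4, 0, 0, 0, 4, 0, 0, 0, 4], i.e. (12 / 3) * I flattened
```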

For example, to find all O(1,3) (Lorentz) equivariant linear maps from a 4-vector Xᶜ to a rank (2,1) tensor Mᵇᵈₐ, you can run

from emlp.reps import V,T
from emlp.groups import *

G = O13()
Q = (T(1,0)>>T(2,1))(G).equivariant_basis()

or how about equivariant maps from one Rubik's cube to another?

G = RubiksCube()

Q = (V(G)>>V(G)).equivariant_basis()

Using + and * you can put together composite representations (where multiple representations are concatenated together). For example, let's find all equivariant linear maps from 10 node features and 5 edge features to 3 global invariants and 1 edge feature of a graph of size n=5:

G=S(5)

repin = 10*T(1)+5*T(2)
repout = 3*T(0)+T(2)
Q = (repin(G)>>repout(G)).equivariant_basis()

There are many different but equivalent ways of writing a representation like 10*T(1)+5*T(2): 10*T(1)+5*T(2) = 10*V+5*V**2 = 5*V*(2+V).
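One way to see why these expressions describe the same representation: ⊕ builds block-diagonal matrices, ⊗ builds Kronecker products, and an integer multiple c*rep is a direct sum of c copies. The NumPy sketch below (helper names are illustrative, not EMLP's API) builds ρ(g) for 10*V+5*V**2 and for 5*V*(2+V) with G = SO(3) and checks that both are 75-dimensional representations with equal traces. The two matrices differ only by a fixed reordering of coordinates; the trace comparison at a random element is just a quick consistency check.

```python
import numpy as np

def random_rotation(rng):
    # Random element of SO(3) via QR, with the sign fixed so det = +1
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))
    if np.linalg.det(q) < 0:
        q[:, 0] *= -1
    return q

def direct_sum(*mats):
    # ⊕: place each block on the diagonal of a larger zero matrix
    n = sum(m.shape[0] for m in mats)
    out = np.zeros((n, n))
    i = 0
    for m in mats:
        k = m.shape[0]
        out[i:i + k, i:i + k] = m
        i += k
    return out

def copies(c, m):
    # c*rep: direct sum of c copies of the same representation
    return np.kron(np.eye(c), m)

def rep_a(R):
    # 10*V + 5*V**2 (equivalently 10*T(1) + 5*T(2))
    return direct_sum(copies(10, R), copies(5, np.kron(R, R)))

def rep_b(R):
    # 5*V*(2+V): (5 copies of V) tensored with (2 scalars ⊕ V)
    return np.kron(copies(5, R), direct_sum(np.eye(2), R))

rng = np.random.default_rng(0)
R1, R2 = random_rotation(rng), random_rotation(rng)
print(rep_a(R1).shape, rep_b(R1).shape)          # (75, 75) (75, 75)
print(np.trace(rep_a(R1)), np.trace(rep_b(R1)))  # equal up to rounding
```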

You can even mix and match representations from different groups, for example with the cyclic group ℤ₃, the permutation group 𝕊₄, and the orthogonal group O(3):

rep = 2*V(Z(3))*V(S(4))+V(O(3))**2
Q = (rep>>rep).equivariant_basis()

Outside of these tensor representations, our type system works with any finite dimensional linear representation and you can even build your own bespoke representations following the instructions here.

You can visualize these equivariant bases with vis(repin,repout), as with the three examples above.

Check out our documentation to see how to use our system and some worked examples.

Simple example of using EMLP as a full equivariant model

Suppose we want to construct a Lorentz equivariant model for particle physics data that takes in the input and output 4-momenta of two particles in a collision, as well as some metadata about these particles like their charge, and we want to classify the output as belonging to one of 3 distinct classes of collisions. Since the outputs are simple logits, they should be unchanged by Lorentz transformations, and the same holds for the charges.

import emlp
from emlp.reps import T
from emlp.groups import Lorentz
import numpy as np

repin = 4*T(1)+2*T(0) # 4 four vectors and 2 scalars for the charges
repout = 3*T(0) # 3 output logits for the 3 classes of collisions
group = Lorentz()
model = emlp.nn.EMLP(repin,repout,group=group,num_layers=3,ch=384)

x = np.random.randn(32,repin(group).size()) # Create a minibatch of data
y = model(x) # Outputs the 3 class logits

Here we have used the default Objax EMLP, but you can also use our PyTorch, Haiku, or Flax versions of the models. To see more examples, or how to use your own representations or symmetry groups, check out the documentation.
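Why can the logits be Lorentz invariant at all? Pairwise Minkowski inner products pᵢᵀηpⱼ of the four-momenta, with η = diag(1, -1, -1, -1), are one natural family of invariant features, because every Lorentz transformation Λ satisfies ΛᵀηΛ = η. The NumPy sketch below is independent of EMLP; `boost_x` and `minkowski_gram` are illustrative names, and the (+, -, -, -) signature is an assumption of the sketch. It checks the invariance for a boost along x.

```python
import numpy as np

eta = np.diag([1.0, -1.0, -1.0, -1.0])  # Minkowski metric, signature (+, -, -, -)

def boost_x(rapidity):
    """Lorentz boost along the x axis."""
    ch, sh = np.cosh(rapidity), np.sinh(rapidity)
    L = np.eye(4)
    L[0, 0] = L[1, 1] = ch
    L[0, 1] = L[1, 0] = sh
    return L

def minkowski_gram(momenta):
    """(n, 4) four-momenta -> (n, n) matrix of invariant products p_i^T eta p_j."""
    return momenta @ eta @ momenta.T

rng = np.random.default_rng(0)
momenta = rng.normal(size=(4, 4))  # 4 four-vectors, one per row, as in repin above
L = boost_x(0.7)
boosted = momenta @ L.T            # each row p becomes (L p)^T
print(np.allclose(L.T @ eta @ L, eta))                                  # True
print(np.allclose(minkowski_gram(boosted), minkowski_gram(momenta)))    # True
```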

Installation instructions

To install as a package, run

pip install emlp

To run the scripts you will instead need to clone the repo and install it locally which you can do with

git clone https://github.com/mfinzi/equivariant-MLP.git
cd equivariant-MLP
pip install -e .[EXPTS]

Experimental Results from Paper

Assuming you have installed the repo locally, you can run the experiments we described in the paper.

To train the regression models on one of the Inertia, O5Synthetic, or ParticleInteraction datasets found in emlp.datasets.py you can run the script experiments/train_regression.py with command line arguments specifying the dataset, network, and symmetry group. For example to train EMLP with SO(3) equivariance on the Inertia dataset, you can run

python experiments/train_regression.py --dataset Inertia --network EMLP --group "SO(3)"

or to train the MLP baseline you can run

python experiments/train_regression.py --dataset Inertia --network MLP

Other command line arguments, such as --aug=True for data augmentation or --ch=512 for the number of hidden units, are available, and you can browse the options and their defaults with python experiments/train_regression.py -h. If no group is specified, EMLP will automatically choose the one matched to the dataset, but you can also go crazy with any of the other groups implemented in groups.py, provided the dimensions match the data (e.g. for the 3D inertia dataset you could use --group="Z(3)" or --group="DkeR3(3)", but not "Sp(2)" or "SU(5)").

For the dynamical systems modeling experiments you can use the scripts experiments/neuralode.py to train (equivariant) Neural ODEs and experiments/hnn.py to train (equivariant) Hamiltonian Neural Networks.

For the dynamical system task, the Neural ODE and HNN models have special names: EMLPode and MLPode for the Neural ODEs in neuralode.py, and EMLPH and MLPH for the HNNs in hnn.py. For example,

python experiments/neuralode.py --network EMLPode --group="O2eR3()"

or

python experiments/hnn.py --network EMLPH --group="DkeR3(6)"

These models are trained to fit a double spring dynamical system. 30s rollouts of the dataset, along with rollout error on these trajectories, and conservation of angular momentum are shown below.

If you find our work helpful, please cite it with

@article{finzi2021emlp,
  title={A Practical Method for Constructing Equivariant Multilayer Perceptrons for Arbitrary Matrix Groups},
  author={Finzi, Marc and Welling, Max and Wilson, Andrew Gordon},
  journal={Arxiv},
  year={2021}
}

equivariant-mlp's People

Contributors

cjrd, deepsourcebot, gelevlove, mfinzi


equivariant-mlp's Issues

2V>>Scalar representation is all zero?

import jax
import jax.numpy as jnp
from emlp.reps import V,Scalar
from emlp.groups import SO
import numpy as np

W =V(SO(3))
rep = 2*W
P = (rep>>Scalar).equivariant_projector()
applyP = lambda v: P@v

P.to_dense()

DeviceArray([[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.]], dtype=float32)

I am running a toy example with two vectors as input and a scalar as output.
If I rotate the vectors, the output scalar should not change.

But the P matrix I get is all zero!
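The all-zero projector is what the math predicts here: SO(3) has no nonzero invariant vector in V, so the only linear map from two 3-vectors to a rotation-invariant scalar is the zero map. Rotation invariants such as the dot products xᵢ·xⱼ are quadratic in the input, which in the type system above corresponds to something like rep*rep rather than rep. The NumPy sketch below (independent of EMLP; `random_rotation` and `invariant_dim` are hypothetical helpers) counts invariant linear functionals by solving the stacked constraints (ρ(g)ᵀ - I)w = 0 for a few sampled rotations.

```python
import numpy as np

def random_rotation(rng):
    # Random element of SO(3) via QR, with the sign fixed so det = +1
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))
    if np.linalg.det(q) < 0:
        q[:, 0] *= -1
    return q

def invariant_dim(rho, n_samples=5, seed=0, tol=1e-8):
    """Number of independent w with w @ rho(g) == w for all sampled rotations g."""
    rng = np.random.default_rng(seed)
    mats = []
    for _ in range(n_samples):
        m = rho(random_rotation(rng))
        # w @ rho(g) == w  <=>  (rho(g).T - I) @ w == 0
        mats.append(m.T - np.eye(m.shape[0]))
    s = np.linalg.svd(np.concatenate(mats), compute_uv=False)
    return int(np.sum(s < tol))

def two_v(R):
    # rep = 2*V: the same rotation acting on each of the two vectors
    return np.kron(np.eye(2), R)

print(invariant_dim(lambda R: R))                             # V -> Scalar: 0
print(invariant_dim(two_v))                                   # 2*V -> Scalar: 0, matching the zero projector
print(invariant_dim(lambda R: np.kron(two_v(R), two_v(R))))   # (2*V)⊗(2*V) -> Scalar: 4, the dot products x_i·x_j
```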

Conceptual questions regarding the library

Hello! This is a wonderful library and I am very excited to start using it. I had a conceptual question though, and I want to make sure I’m thinking about the framework correctly before I start using it for experiments.

In the paper you talk about how this generalizes various equivariant models used in previous work, but you don’t go into detail about the relationship between the models (or if you did, I didn’t understand it). Are there conditions under which you can guarantee that the network is identical to the architecture proposed in, e.g., Cohen and Welling 2016? It seems intuitive to hope that this would be the case for regular representations, but I don’t know if it is.

The main reason I am asking about this is that I am interested in the training dynamics of equivariant models. Can results obtained on your architecture be assumed to hold for GCNNs? What about other types of equivariant models?

Point Clouds Example

I'm trying to work with molecules and would like to combine a permutation group with an SO(3) group. This is discussed a little bit in the docs, but I'm having trouble implementing it correctly. I would like to have 5 features per input point and 5 coordinates, and I would like to output a single coordinate. An example would be computing the dipole moment for a molecule with 5 atoms. I tried writing it like this:

import numpy as np
import emlp
from emlp.reps import V
from emlp.groups import SO, S
# make product group
G = SO(3) * S(5)
#     element + coordinates
Vin = V(S(5)) + V(G)
Vout = V(SO(3))
print(Vin.size(), Vout.size())
# make model
model = emlp.nn.EMLP(Vin, Vout, group=G)
input_point = np.random.randn(Vin.size())
model(input_point).shape

but the output shape from the model is (15,) instead of the (3,) I would expect. Thank you!
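The intended symmetry can be checked outside EMLP. With 5 atoms, the coordinates form a 5×3 array X that transforms as P X Rᵀ for a permutation matrix P from S(5) and a rotation R from SO(3), per-atom features transform as P q, and a dipole-like output Σᵢ qᵢ xᵢ is a single 3-vector that rotates with R and is unaffected by P. The NumPy sketch below assumes one scalar feature (a charge) per atom; `dipole` and `random_rotation` are illustrative names, not part of EMLP.

```python
import numpy as np

def random_rotation(rng):
    # Random element of SO(3) via QR, with the sign fixed so det = +1
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))
    if np.linalg.det(q) < 0:
        q[:, 0] *= -1
    return q

def dipole(X, q):
    # X: (n_atoms, 3) coordinates, q: (n_atoms,) charges -> (3,) vector sum_i q_i * x_i
    return q @ X

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))         # 5 atoms, 3D coordinates
q = rng.normal(size=5)              # one scalar feature per atom
R = random_rotation(rng)            # element of SO(3)
Pm = np.eye(5)[rng.permutation(5)]  # element of S(5) as a permutation matrix

out = dipole(Pm @ X @ R.T, Pm @ q)  # dipole of the transformed molecule
print(out.shape)                    # (3,)
print(np.allclose(out, R @ dipole(X, q)))  # True: rotates with R, ignores the permutation
```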

Parity for O(3) scalars

Hi! First I'd like to say it's a nice library!

Coming from e3nn, I noticed that the scalar rep of O(3) in EMLP always has even parity, i.e. its sign does not flip under reflection, whereas the e3nn library distinguishes order-0 representations with even and odd parity.

How can I have a scalar representation that distinguishes the two connected components of O(3), or is this not a valid question?
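One standard way to make the distinction concrete: an odd-parity scalar (pseudoscalar) transforms under the one-dimensional representation ρ(g) = det(g), which is +1 on the rotation component of O(3) and -1 on the reflection component, whereas an ordinary scalar uses ρ(g) = 1. The triple product a·(b×c) is the textbook pseudoscalar. The sketch below is plain NumPy, not EMLP's API (`random_orthogonal` and `triple_product` are illustrative names); in EMLP, a det(g) representation would presumably have to be supplied as a bespoke representation, as the README describes.

```python
import numpy as np

def random_orthogonal(rng, det_sign):
    # Random element of O(3) in the component with determinant det_sign (+1 or -1)
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))
    if np.sign(np.linalg.det(q)) != det_sign:
        q[:, 0] *= -1
    return q

def triple_product(a, b, c):
    # a · (b × c): picks up a factor det(Q) under a, b, c -> Q a, Q b, Q c
    return a @ np.cross(b, c)

rng = np.random.default_rng(0)
a, b, c = rng.normal(size=(3, 3))
for sign in (1, -1):
    Q = random_orthogonal(rng, sign)
    even = np.isclose((Q @ a) @ (Q @ b), a @ b)  # ordinary scalar: unchanged
    odd = np.isclose(triple_product(Q @ a, Q @ b, Q @ c), sign * triple_product(a, b, c))
    print(sign, even, odd)  # both checks hold in each component
```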

Saving and Loading Objax EMLPs yields slightly different predictions

It appears that loading an EMLP model saved with objax.io.save_var_collection yields slightly different predictions than the original model.

import emlp
from emlp.groups import SO
from emlp.reps import T,V
import numpy as np
import objax

net = emlp.nn.EMLP(T(1)+T(2),T(0),SO(3))
x=np.linspace(0,1,12)
print(net(x).T)

# Saving with file descriptor
with open('net.npz', 'wb') as f:
    objax.io.save_var_collection(f, net.vars())

Output: [0.4905533]

import emlp
from emlp.groups import SO
from emlp.reps import T,V
import numpy as np
import objax

net = emlp.nn.EMLP(T(1)+T(2),T(0),SO(3))
x=np.linspace(0,1,12)
# Loading with file descriptor
with open('net.npz', 'rb') as  f:
    objax.io.load_var_collection(f, net.vars())
    
print(net(x).T)

Output: [0.4904544]

Subdirect (pairing) product - Finding the equivariant basis for symmetry pairs

Hi,

I am building an equivariant architecture where the input and output spaces of one of the layers have symmetry groups G1 and G2 where G1, G2 ⊆ O(n) and |G1| == |G2|. The key property of my application is that equivariance is only desired between pairs of input/output group actions, rather than between a direct product of all input/output actions.

That is, for the linear mapping U1 -> U2 the symmetry constraints should be built only from unique pairs of elements of G1 and G2, where each group action is associated with only one pair. By avoiding the group direct product, we get fewer constraints and a larger number of resulting equivariant basis elements (trainable parameters).

I have played quite a bit with the internals of EMLP, but I haven't yet found a nice way of introducing this subdirect product. From Eq. 8 in the paper, I realize that this should not be difficult; however, I have not yet been able to familiarize myself with your code structure.

I will try to solve this in the coming days, and if desired I can submit a pull request. For the time being, if you have any advice or insight it would be more than helpful.
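The difference between the two constraint sets shows up on a toy case. Assume G1 = G2 = the cyclic group of shifts on Rⁿ. The paired (subdirect) version only requires S W S⁻¹ = W for the pair (s, s), whereas the direct product requires invariance under (s, e) and (e, s) separately. The NumPy sketch below (hypothetical helpers `shift_matrix` and `equivariant_dim`, not EMLP internals) counts the solutions of ρ₂(g₂) W ρ₁(g₁)⁻¹ = W in each case. Pairing leaves the n-dimensional space of circulant matrices, while the direct product leaves only multiples of the all-ones matrix.

```python
import numpy as np

def shift_matrix(n):
    # Cyclic shift permutation matrix on R^n
    return np.roll(np.eye(n), 1, axis=0)

def equivariant_dim(pairs, tol=1e-8):
    """Dimension of {W : rho_out @ W @ inv(rho_in) == W for every (rho_in, rho_out) in pairs}."""
    blocks = []
    for rho_in, rho_out in pairs:
        d = rho_out.shape[0] * rho_in.shape[0]
        # Row-major vec identity: vec(A @ W @ B.T) == kron(A, B) @ vec(W), with B.T = inv(rho_in)
        B = np.linalg.inv(rho_in).T
        blocks.append(np.kron(rho_out, B) - np.eye(d))
    s = np.linalg.svd(np.concatenate(blocks), compute_uv=False)
    return int(np.sum(s < tol))

n = 5
S, I = shift_matrix(n), np.eye(n)
print(equivariant_dim([(S, S)]))          # paired generator (s, s): 5 (circulant matrices)
print(equivariant_dim([(S, I), (I, S)]))  # direct product generators: 1 (multiples of all-ones)
```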

Linear projected weight changes device - EMLP in Pytorch

I have the issue that, after projection, the weight matrix changes device from cuda to cpu. Below are the main components of the model; can you spot anything I am doing wrong that causes this?

Imports and EMLP block in torch

from emlp.groups import S
from emlp.reps import V
import emlp
from torch.nn import Module
import emlp.nn.pytorch as emlp_torch

class EMLPBlock(Module):
    """ Basic building block of EMLP consisting of G-Linear, biLinear,
        and gated nonlinearity. """
    def __init__(self,rep_in,rep_out):
        super().__init__()
        rep_out_wgates = emlp_torch.gated(rep_out)
        self.linear = emlp_torch.Linear(rep_in,rep_out_wgates)
        self.bilinear = emlp_torch.BiLinear(rep_out_wgates,rep_out_wgates)
        self.nonlinearity = emlp_torch.GatedNonlinearity(rep_out)
    def __call__(self,x):
        print(f'linear weight device : {self.linear.weight.device}')
        print(f'linear weight proj device : {self.linear.proj_w(self.linear.weight).device}')
        lin = self.linear(x)
        preact =self.bilinear(lin)+lin
        return self.nonlinearity(preact)

In the model init

rin_2 = 10*V(S(2))**2
rout_2 = 20*V(S(2))**2
rin_3 = 10*V(S(3))**2
rout_3 = 20*V(S(3))**2
rin_4 = 10*V(S(4))**2
rout_4 = 20*V(S(4))**2
rin_5 = 10*V(S(5))**2
rout_5 = 20*V(S(5))**2

print(f'rep in layer 1 S(2) : {rin_2}')
print(f'rep out layer 1 S(2) : {rout_2}')
print(f'rep in layer 1 S(3) : {rin_3}')
print(f'rep out layer 1 S(3) : {rout_3}')
print(f'rep in layer 1 S(4) : {rin_4}')
print(f'rep out layer 1 S(4) : {rout_4}')
print(f'rep in layer 1 S(5) : {rin_5}')
print(f'rep out layer 1 S(5) : {rout_5}')

self.eqblock1_2 = EMLPBlock(rin_2,rout_2)
self.eqblock1_3 = EMLPBlock(rin_3,rout_3)
self.eqblock1_4 = EMLPBlock(rin_4,rout_4)
self.eqblock1_5 = EMLPBlock(rin_5,rout_5)

In the forward of the model

adj_2 = self.eqblock1_2(adj_2)

The input adj_2 is on device cuda, and I call model.to('cuda') before calling the forward pass of the model.

The print statements also show that the linear weight in the EMLP block is on device cuda before being projected, but on device cpu after the projection.

Do you know why the projection moves the weights from cuda to cpu?

Many thanks,
Josh

rotation equivariant MLP for 2d images

Hello Marc, thanks for your work!

Using your library, I want to implement a rotation equivariant MLP and apply it to roto-MNIST classification, where I rotate 2D digits and ravel them into a vector before feeding them into the network.

Can you give some advice on how to implement an EMLP model for such a problem?

Would it be:
model = nn.EMLP(repin,repout,G)
with repin = 28*28*Scalar ; repout = 10*Scalar ; G=SO(2)?

Thanks.
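One point worth checking with this setup: with repin = 28*28*Scalar, every pixel is declared a scalar, so the group acts trivially on the input and the equivariance constraint does not tie the weights to image rotations at all. On a pixel grid, the rotations that map pixels exactly onto pixels are multiples of 90°, which act on the raveled 784-vector as a permutation matrix of order 4 (a representation of the cyclic group C4, not of all of SO(2)). The NumPy sketch below builds that matrix; `rot90_permutation` is a hypothetical helper, not part of EMLP.

```python
import numpy as np

def rot90_permutation(h, w):
    """Permutation matrix P with P @ img.ravel() == np.rot90(img).ravel() for an h x w image."""
    idx = np.arange(h * w).reshape(h, w)
    # Rotating the grid of flat indices tells us which source pixel lands in each output slot
    src = np.rot90(idx).ravel()
    P = np.zeros((h * w, h * w))
    P[np.arange(h * w), src] = 1.0
    return P

rng = np.random.default_rng(0)
img = rng.normal(size=(28, 28))
P = rot90_permutation(28, 28)
print(np.allclose(P @ img.ravel(), np.rot90(img).ravel()))           # True
print(np.allclose(np.linalg.matrix_power(P, 4), np.eye(28 * 28)))     # True: P has order 4
```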

Regression in Example

I have an example in my book using emlp with the following syntax:

import numpy as np
import emlp
from emlp.reps import V
from emlp.groups import SO, S
# make product group
G = SO(3) * S(5)
# direct sum of element + coordinates
Vin = V(S(5)) + V(G)
Vout = V(G)
print(Vin.size(), Vout.size())
# make model
model = emlp.nn.EMLP(Vin, Vout, group=G)
input_point = np.random.randn(Vin.size())
model(input_point)

This previously worked and we had discussed it a bit for modifying output in #10. Now in version 1.0.3, this code no longer executes - it gives the following error:

TypeError: Sequential layer[0] <emlp.nn.objax.EMLPBlock object at 0x7ff44436e5e0> dot_general requires contracting dimensions to have the same shape, got [20] and [15].

I was wondering if I need to update the syntax or if there is a bug. Thanks!

question about figure 2 in the paper

Hi, I am trying to reproduce Figure 2 in the paper (for example, the convolution case in Section 4.3). My code is written in Mathematica:

rhoH = {{0, 0, 0, 1}, {1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}}              (* rho(h), permutation matrix; C itself is a protected symbol *)
{u, sigma, q} = SingularValueDecomposition[(rhoH - IdentityMatrix[4])]       (* SVD(rho(h) - I) *)
MatrixPlot[q]

The result is very different from the one in your paper. Am I making a mistake here? Thank you very much!
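A likely source of the difference, assuming the figure shows bases of equivariant linear maps W: V → V: the constraint for maps is ρ(h) W ρ(h)⁻¹ = W, which in vectorized form involves ρ(h) ⊗ ρ(h)⁻ᵀ rather than ρ(h) alone, and the basis consists only of the right singular vectors with zero singular value, not the full matrix of singular vectors that MatrixPlot[q] shows. The NumPy sketch below (the helper `equivariant_map_basis` is hypothetical, not EMLP's function) computes that nullspace for the same 4×4 cyclic permutation. Each of the 4 basis elements, reshaped to 4×4, is a circulant matrix that commutes with ρ(h).

```python
import numpy as np

# rho(h): the same cyclic permutation as in the Mathematica snippet above
rho_h = np.array([[0, 0, 0, 1],
                  [1, 0, 0, 0],
                  [0, 1, 0, 0],
                  [0, 0, 1, 0]], dtype=float)

def equivariant_map_basis(rho, tol=1e-8):
    """Basis of {W : rho @ W @ inv(rho) == W}, one flattened W per row."""
    n = rho.shape[0]
    # Row-major vec identity: vec(A @ W @ B.T) == kron(A, B) @ vec(W); here B.T = inv(rho)
    K = np.kron(rho, np.linalg.inv(rho).T) - np.eye(n * n)
    _, s, vt = np.linalg.svd(K)
    # Only right singular vectors with (numerically) zero singular value span the nullspace
    return vt[s < tol]

Q = equivariant_map_basis(rho_h)
print(Q.shape)  # (4, 16)
```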

Breaking equivariance when using the hk.Module class

Hello, it's a great project!
I tried to use EMLP with dm-haiku, and I wrote two versions of the code. The first uses emlp.nn.haiku directly, and the second wraps it in an hk.Module class. But I found that the second version breaks the equivariance of the neural network. In my view, the two versions have no difference in the architecture of the neural network. Could you tell me if something is wrong with the way I am using EMLP?

Here is the code. The first version is:

import emlp.nn.haiku as ehk
import haiku as hk
from emlp.reps import V
from emlp.groups import SO
from jax import random
import jax.numpy as jnp

n = 10
dim = 3

G = SO(dim)

rep_in = n*V(G)
rep_out = n*V(G)

model = ehk.EMLP(rep_in, rep_out, group=G, num_layers=2, ch=256)
net = hk.without_apply_rng(hk.transform(model))

key = random.PRNGKey(0)
x = random.normal(key, (n*dim,))
params = net.init(random.PRNGKey(42), x)

v = net.apply(params, x)

g = G.sample()
x_1 = rep_in.rho(g)@x
v_1 = net.apply(params, x_1)

v_2 = rep_out.rho(g)@v
print(f"v(𝜌(g)x) =\n{v_1}")
print(f"𝜌(g)v(x) =\n{v_2}")
assert jnp.allclose(v_1, v_2)

and the second is:

import emlp.nn.haiku as ehk
from emlp.reps import V
from emlp.groups import SO
import haiku as hk
from jax import random
import jax.numpy as jnp


class test_EMLP(hk.Module):
    def __init__(self, n, dim, group, num_layers, ch, name=None):
        super().__init__(name=name)
        self.n = n
        self.dim = dim
        self.group = group(dim)
        self.rep_in = self.n*V(self.group)
        self.rep_out = self.n*V(self.group)
        self.num_layers = num_layers
        self.ch = ch

        self.e_mlp = self.e_mlp()

    def e_mlp(self):
        return ehk.EMLP(self.rep_in,
                        self.rep_out,
                        group=self.group,
                        num_layers=self.num_layers,
                        ch=self.ch)

    def __call__(self, x):
        return self.e_mlp(x)


def forward_fn(x):
    model = test_EMLP(n=10, dim=3, group=SO, num_layers=2, ch=256)
    return model(x)

net = hk.without_apply_rng(hk.transform(forward_fn))

n = 10
dim = 3
G = SO(dim)
rep_in = n*V(G)
rep_out = n*V(G)


key = random.PRNGKey(1)
x = random.normal(key, (n*dim,))
params = net.init(random.PRNGKey(42), x)

v = net.apply(params, x)

g = G.sample()
x_1 = rep_in.rho(g)@x
v_1 = net.apply(params, x_1)

v_2 = rep_out.rho(g)@v
print(f"v(𝜌(g)x) =\n{v_1}")
print(f"𝜌(g)v(x) =\n{v_2}")
assert jnp.allclose(v_1, v_2)

olive oil dependencies

olive-oil-ml depends on pytorch, torchvision, and sklearn, which makes the emlp install download about 1GB of packages even though emlp itself is a small package. Is there any way to reduce the dependencies in olive-oil-ml, for example by making them optional dependencies? Thanks!

Bilinear layer randomness

First, thanks for the great work! This is really helpful in several ways.

While playing with your code, I encountered random behavior in emlp and figured out that it is caused by the bilinear layer. I wish I had checked issue #8 below, Saving and Loading Objax EMLPs yields slightly different predictions, before trying to identify it myself. Two things:

  1. It was suggested to use the same NumPy random seed as a workaround. But I'm checking whether there is another way to resolve this, such as saving and loading additional parameters from the bilinear layer.
  2. In fact, the only part of your paper and code that is unclear to me is the bilinear layer. I do not understand why there is randomness in the bilinear layer if it is presumably calculating something like x^T A x + b x + c with projections. It would be really helpful to understand what it is if the mathematical expression for your bilinear layer were provided. Thanks.
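The mechanism suggested by this thread (state drawn from NumPy's global random generator when a layer is built, and left out of the saved parameter collection) can be reproduced with a toy class. `ToyBilinear` below is hypothetical and is not EMLP's BiLinear layer; it only shows why reseeding before construction, the workaround mentioned in item 1, restores the predictions, and why saving the fixed array alongside the parameters would be the alternative.

```python
import numpy as np

class ToyBilinear:
    """Hypothetical layer: one trainable matrix plus a fixed random matrix drawn at construction."""

    def __init__(self, dim):
        self.weight = np.random.randn(dim, dim)  # trainable: included in params()
        self.fixed = np.random.randn(dim, dim)   # not a parameter: redrawn whenever the layer is built

    def params(self):
        return {"weight": self.weight.copy()}

    def load(self, params):
        self.weight = params["weight"].copy()

    def __call__(self, x):
        # Bilinear form in x: x^T W (F x)
        return x @ self.weight @ (self.fixed @ x)

x = np.linspace(0, 1, 4)

np.random.seed(0)
net = ToyBilinear(4)
y0 = net(x)
saved = net.params()

np.random.seed(1)  # different seed: 'fixed' differs even though 'weight' is reloaded
other = ToyBilinear(4)
other.load(saved)
y_other = other(x)

np.random.seed(0)  # same seed as the original: 'fixed' matches, so predictions match
same = ToyBilinear(4)
same.load(saved)
y_same = same(x)
```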

Paper results error

None of the experiments from the paper run for me; they fail with

2021-05-16 03:52:49.141806: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_blas.cc:226] failed to create cublas handle: CUBLAS_STATUS_NOT_INITIALIZED
2021-05-16 03:52:49.141858: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc:113] Check failed: stream->parent()->GetBlasGemmAlgorithms(&algorithms) 

I can write my own code using the package without a problem, but your experiments won't work.

Dataset

Hello,

Thanks a lot for the nice code and documents. This is probably a naive question. I am trying to use your code to generate some data of Hamiltonian systems, and I am a bit confused about the meaning of the arguments in

SHO(n_systems=300, chunk_len=10, dt=0.2, integration_time=300, regen=True).Zs

which gives an np.array of shape (300, 10, 2).
Does n_systems stand for the number of particles? What is chunk_len?
I expected one dimension (for time) to be of size 1500 (integration_time / dt = 300 / 0.2).
