Implementing $SE(3)$-equivariant Tensor Field Networks[^1] (TFN) from scratch with Flux.jl.
This repository is not affiliated with the original paper; it is an independent implementation. The official, fully-featured library implementations of this work are written in PyTorch and JAX: e3nn and e3nn-jax.
Julia implementations of TFN and its derivatives are (as far as I know) non-existent.
This is a first step before I rewrite these equivariant layers in the style of GeometricFlux.jl, where point clouds will be represented by bona fide graphs rather than arrays as they are here.
For those unfamiliar with TFN, I motivate this architecture at the end of this README.
A disadvantage of TFN is that one must keep track of which rotation representation any given feature vector belongs to. As such, the format for storing feature vectors is inevitably messier than a multidimensional array.
The first choice I have made is to separate positions and feature vectors into a tuple `(rrs, Vss)`, since they are treated differently in the pipeline.
Raw Cartesian positions `rs` of points in the point cloud must first be transformed into pairwise separation vectors, then converted to spherical coordinates. This only needs to be done once, rather than with every forward pass of the network, so it is more efficient to calculate it outside the network.
Because positions and features are stored as arrays, all point clouds in a batch must have the same number of points.
For example, a valid array of positions could be generated by the following snippet:

```julia
rs = rand(Float32, (num_points, 3, batch_size))
rrs_cart = rs |> pairwise_rs # size(rrs_cart) = (num_points, num_points, 3, batch_size)
rrs_sph = rrs_cart |> cart_to_sph # convert [x, y, z] -> [r, θ, ϕ]
```
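For reference, the two preprocessing steps might look roughly like the sketch below. `pairwise_rs_sketch` and `cart_to_sph_sketch` are hypothetical stand-ins written for this README; the repository's actual `pairwise_rs` and `cart_to_sph` may use different sign and angle conventions.

```julia
# Hypothetical sketch of the preprocessing (not the repository's implementation).
# Pairwise separation vectors: rrs[i, j, :, b] points from particle i to particle j.
function pairwise_rs_sketch(rs)
    num_points, _, batch_size = size(rs)
    rrs = zeros(Float32, num_points, num_points, 3, batch_size)
    for b in 1:batch_size, j in 1:num_points, i in 1:num_points
        rrs[i, j, :, b] .= rs[j, :, b] .- rs[i, :, b]
    end
    return rrs
end

# Convert each separation vector [x, y, z] to spherical coordinates [r, θ, ϕ],
# with θ the polar angle and ϕ the azimuthal angle (conventions may differ).
function cart_to_sph_sketch(rrs)
    out = similar(rrs)
    for idx in CartesianIndices((size(rrs, 1), size(rrs, 2), size(rrs, 4)))
        i, j, b = Tuple(idx)
        x, y, z = rrs[i, j, 1, b], rrs[i, j, 2, b], rrs[i, j, 3, b]
        r = sqrt(x^2 + y^2 + z^2)
        θ = r == 0 ? 0f0 : acos(clamp(z / r, -1f0, 1f0))
        ϕ = atan(y, x)
        out[i, j, :, b] .= (r, θ, ϕ)
    end
    return out
end
```

Note that the diagonal `rrs[i, i, :, b]` is all zeros, which is why the zero-separation case needs an explicit guard before `acos`.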
Features have a more complicated structure: they are stored in a tuple of vectors of arrays, with one vector for each representation index `ℓ`. Each array `V` of `ℓ` features has size `(num_points, batch_size, 2ℓ+1)`.
Features compatible with the above positions could be generated by

```julia
V11 = ones(Float32, (num_points, batch_size, 3)) # Some ℓ = 1 features
V21 = ones(Float32, (num_points, batch_size, 5)) # Some ℓ = 2 features
Vss = (Vector{typeof(V11)}(undef, 0), [V11], [V21]) # Feature tuple, being careful to use a type-stable empty vector
```
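To make the bookkeeping concrete, here is a hypothetical helper (not part of the repository) that verifies the convention: the vector at tuple position `ℓ + 1` holds the channels for representation `ℓ`, each with trailing dimension `2ℓ + 1`.

```julia
# Hypothetical shape check for the `Vss` convention described above:
# Vss[ℓ + 1] is a vector of channels, each of size (num_points, batch_size, 2ℓ + 1).
function check_feature_shapes(Vss)
    for (i, Vs) in enumerate(Vss)
        ℓ = i - 1
        for V in Vs
            size(V, 3) == 2ℓ + 1 ||
                error("ℓ = $ℓ features need trailing dimension $(2ℓ + 1), got $(size(V, 3))")
        end
    end
    return true
end
```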
Some layers make no use of position information, so I have defined a separate gluing structure, similar to Flux.jl's `Parallel`, that holds multiple parallel layers but acts trivially on the first element of the `(rrs, Vss)` tuple.
Let's consider the architecture used for the shape classification example in the TFN paper, implemented in `/Shape_Classification.ipynb`.
Here, the aim is to classify a bunch of Tetris-like blocks, of which there are 8 distinct types.
The intrinsic rotational invariance of the network means that after being shown just one example of each block, the classifier can be equally confident in recognising the blocks even when they are arbitrarily orientated.
Below is a cartoon of the invariance of the output of the entire pipeline with respect to rotation of the input, a special case of equivariance with a trivial identity representation.
A diagram of the network architecture is shown below. We keep track of the rotation representations of the feature vectors with the index `ℓ`.
In this repository, this network is implemented with the code block below.
It should hopefully be clear which component corresponds to which.
The most important (and complicated) component is the `E3ConvLayer`. The remaining components are non-linear and self-interaction layers (interfaced through `NLWrapper` and `SIWrapper`), used to scale feature vectors pointwise and mix channels in an equivariant way.
The number of output and input channels for every layer must be correctly specified at the time of construction.
(Above, the number of channels with a particular representation `ℓ` is indicated in the diagram.)
```julia
# Define centers of radial basis functions
centers = range(0f0, 3.5f0; length=4) |> collect

# `Chain` is the `Flux.jl` constructor for sequential layers
classifier = Chain(
    SIWrapper([1 => 4]),
    E3ConvLayer([4], [[(0, 0) => [0], (0, 1) => [1]]], centers),
    SIWrapper([4 => 4, 4 => 4]),
    NLWrapper([4, 4]),
    E3ConvLayer([4, 4], [[(0, 0) => [0], (0, 1) => [1]],
                         [(1, 0) => [1], (1, 1) => [0, 1]]], centers),
    SIWrapper([8 => 4, 12 => 4]),
    NLWrapper([4, 4]),
    E3ConvLayer([4, 4], [[(0, 0) => [0]], [(1, 1) => [0]]], centers),
    SIWrapper([8 => 4]),
    NLWrapper([4]),
    PLayer(),
    Dense(4 => 8)
);
```
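The `(ℓ_in, ℓ_filter) => [ℓ_out]` pairs passed to `E3ConvLayer` above are not arbitrary: equivariant tensor products obey the usual angular-momentum (Clebsch–Gordan) selection rule, `|ℓ_in − ℓ_filter| ≤ ℓ_out ≤ ℓ_in + ℓ_filter`. A one-line helper (hypothetical, not part of the repository) makes this easy to check:

```julia
# Output representations reachable by combining an ℓ_in feature with an ℓ_filter
# spherical-harmonic filter, per the Clebsch–Gordan selection rule.
allowed_ℓ_out(ℓ_in, ℓ_filter) = collect(abs(ℓ_in - ℓ_filter):(ℓ_in + ℓ_filter))
```

For example, `allowed_ℓ_out(1, 1)` returns `[0, 1, 2]`, so the `(1, 1) => [0, 1]` entry in the second `E3ConvLayer` keeps only a subset of the allowed outputs.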
Rotational and translational symmetry are common in nature. It is therefore useful to have neural networks that can exploit this simplified structure of many physical problems, without needing to carry around redundant information. That is, it would be nice not to have to relearn the same thing over and over again in different coordinate systems. (For example, a naïve approach to approximating symmetric functions is to train networks on augmented data, namely data that has been bulked up with transformed copies. However, in that case no promises can be made about equivariance outside the training dataset.)

This is solved by making layers equivariant with respect to a symmetry group by construction (reviewed in detail in the Geometric Deep Learning textbook), meaning that the output transforms appropriately when the input is transformed. There are many ways to design such neural network architectures, often with a trade-off between expressivity and computational cost.
The Tensor Field Network (TFN) is built from matrix representations of rotations and acts on point clouds of features. Translation symmetry is trivially upheld by only ever considering the relative displacement between points. The advantage of TFN is that features can be complex physical quantities (and not just scalars), but this expressivity comes with additional cost compared to some alternatives, especially because it essentially acts on an "all-to-all" graph. Some follow-up works such as SE(3)-Transformers are more performant.
A neural network acts on feature vectors, which are sometimes physical quantities. These quantities can transform differently under rotation depending on their rotation representation, indexed by the non-negative integer `ℓ`.
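As a toy illustration (not repository code), consider how `ℓ = 0` and `ℓ = 1` features transform under a rotation about the z-axis: scalars are unchanged, while `ℓ = 1` features mix like the components of an ordinary 3-vector. (The repository's Wigner-matrix conventions may order the `ℓ = 1` components differently.)

```julia
# Rotation about the z-axis by angle α, acting on Cartesian 3-vectors.
rot_z(α) = [cos(α) -sin(α) 0; sin(α) cos(α) 0; 0 0 1]

s = 2.5f0            # an ℓ = 0 feature (scalar): D⁰(R) = 1, so it is invariant
v = [1.0, 0.0, 0.0]  # an ℓ = 1 feature in the Cartesian basis: D¹(R) = R

s_rot = s                 # unchanged by rotation
v_rot = rot_z(π / 2) * v  # rotated like a vector
```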
The README in the `test` folder outlines the conventions used in this repository for rotation representations.
Footnotes

[^1]: Despite its name, this is completely unrelated to the "Tensor Networks" used in condensed matter, for which one would use ITensors.jl or a related package.