It is a fork of NVIDIA's SE(3)-Transformer implementation, with some minor modifications:
- removal of `torch.cuda.nvtx.nvtx_range` calls
- addition of a `nonlinearity` argument to `NormSE3`, `SE3Transformer`, and related modules
- addition of some basic network implementations using the SE(3)-Transformer
- Install the DGL library with CUDA support

  ```bash
  # This is an example with cudatoolkit=11.3.
  # Set a proper cudatoolkit version that is compatible with your CUDA driver and the DGL library.
  conda install dgl -c dglteam/label/cu113
  # or
  pip install dgl -f https://data.dgl.ai/wheels/cu113/repo.html
  ```

- Install this package

  ```bash
  pip install git+http://github.com/huhlim/SE3Transformer
  ```
- `se3_transformer.LinearModule`: `LinearSE3` and `NormSE3` (`SE3Transformer/se3_transformer/snippets.py`, lines 14 to 64 in b74f707)
- `se3_transformer.InteractionModule`: a wrapper of `SE3Transformer` (`SE3Transformer/se3_transformer/snippets.py`, lines 67 to 118 in b74f707)
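The wrapper pattern behind such a module can be sketched in plain Python (the class names, the degree-keyed feature dictionary, and the stub core below are illustrative assumptions, not the repo's actual API):

```python
# Hypothetical sketch: a wrapper adapts plain inputs into the
# degree-keyed feature dictionaries an SE(3)-Transformer core consumes.
class StubCore:
    """Stands in for the real SE3Transformer; just echoes the degrees seen."""
    def forward(self, node_feats, edge_feats=None):
        return sorted(node_feats)

class InteractionWrapperSketch:
    def __init__(self, core):
        self.core = core

    def forward(self, scalars, vectors):
        # pack features by degree: "0" for scalars, "1" for vectors
        node_feats = {"0": scalars, "1": vectors}
        return self.core.forward(node_feats, edge_feats=None)

wrapper = InteractionWrapperSketch(StubCore())
```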
- `LinearModule` + `InteractionModule` (`SE3Transformer/example/example.py`, lines 1 to 84 in b74f707)
  - A fully connected graph is created with random coordinates
  - Input features: 8 scalars and 4 vectors
  - Output features: 2 scalars and 1 vector
  - `LinearModule`: two `LinearSE3` layers with `NormSE3`; returns 16 scalars and 8 vectors
  - `InteractionModule`: two layers of attention blocks with two heads; takes the output of the `LinearModule` as `node_feats` and no `edge_feats`
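To make the feature counts and graph construction concrete, here is a small sketch under the standard SE(3)-equivariant convention that a degree-d feature channel has 2d + 1 components (the helper names are illustrative, not the repo's API):

```python
# Degree-d features have 2*d + 1 components: scalars (d=0) have 1,
# vectors (d=1) have 3.
def feature_shape(num_nodes, num_channels, degree):
    return (num_nodes, num_channels, 2 * degree + 1)

# Edge list of a fully connected graph, as used in the example:
# every ordered pair of distinct nodes.
def fully_connected_edges(num_nodes):
    return [(i, j) for i in range(num_nodes) for j in range(num_nodes) if i != j]

num_nodes = 10
inputs = {0: feature_shape(num_nodes, 8, 0),   # 8 scalars  -> (10, 8, 1)
          1: feature_shape(num_nodes, 4, 1)}   # 4 vectors  -> (10, 4, 3)
outputs = {0: feature_shape(num_nodes, 2, 0),  # 2 scalars  -> (10, 2, 1)
           1: feature_shape(num_nodes, 1, 1)}  # 1 vector   -> (10, 1, 3)
```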