LieTransformer: Equivariant Self-Attention for Lie Groups

Overview

LieTransformer

This repository contains the implementation of the LieTransformer used for experiments in the paper

LieTransformer: Equivariant Self-Attention for Lie Groups

by Michael Hutchinson*, Charline Le Lan*, Sheheryar Zaidi*, Emilien Dupont, Yee Whye Teh and Hyunjik Kim

* Equal contribution.

Figures: pattern recognition (constellations), molecular property prediction (rotating molecule), particle dynamics (particle trajectories).

Introduction

LieTransformer is an equivariant Transformer-like model, built out of equivariant self-attention layers (LieSelfAttention). The model can be made equivariant to any Lie group simply by providing an implementation of the group of interest. A number of commonly used groups are already implemented, building off the work of LieConv. Switching group equivariance requires no change to the model architecture, only passing a different group to the model.
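To make the symmetry idea concrete, here is a toy check for the translation group T(2). It is not the repository's LieSelfAttention layer: the function names and the squared-distance score are assumptions chosen for illustration. It only shows that attention weights computed from relative positions do not change when every input point is shifted by the same vector.

```python
import math

def attention_weights(points):
    """Softmax attention weights computed only from pairwise differences x_i - x_j.

    Illustrative scoring rule (negative squared distance), not the paper's.
    """
    weights = []
    for xi in points:
        # The score depends on relative position only, so a global shift cancels out.
        scores = [-((xi[0] - xj[0]) ** 2 + (xi[1] - xj[1]) ** 2) for xj in points]
        m = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

def translate(points, shift):
    """Apply the same 2D translation to every point."""
    return [(x + shift[0], y + shift[1]) for x, y in points]

points = [(0.0, 0.0), (1.0, 0.5), (-0.5, 2.0)]
w_original = attention_weights(points)
w_shifted = attention_weights(translate(points, (3.0, -7.0)))

# Largest entrywise difference between the two weight matrices.
max_diff = max(
    abs(a - b)
    for row_a, row_b in zip(w_original, w_shifted)
    for a, b in zip(row_a, row_b)
)
print("weights unchanged under translation:", max_diff < 1e-9)
```

The same kind of check, with the shift replaced by the action of another group, is a quick sanity test when adding a new group.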

Architecture

The overall architecture of the LieTransformer is similar to the architecture of the original Transformer, interleaving series of attention layers and pointwise MLPs in residual blocks. The architecture of the LieSelfAttention blocks differs however, and can be seen below. For more details, please see the paper.

model diagram

Installation

To reproduce the experiments in this library, first clone the repo via git clone git@github.com:oxcsml/eqv_transformer.git. To install the dependencies and create a virtual environment, execute setup_virtualenv.sh. Alternatively, you can install the library and its dependencies without creating a virtual environment by running pip install -e . from the repository root.

To install the library as a dependency for another project use pip install git+https://github.com/oxcsml/eqv_transformer.

Training a model

Example command to train a model (in this case the Set Transformer on the constellation dataset):

python3 scripts/train.py --data_config configs/constellation.py --model_config configs/set_transformer.py --run_name my_experiment --learning_rate=1e-4 --batch_size 128

The model and the dataset can be chosen by specifying different config files. Flags for configuring the model and the dataset are available in the respective config files. The project uses forge for configs and experiment management. Please refer to the forge description and examples for details.

Counting patterns in the constellation dataset

The first task implemented is counting patterns in the constellation dataset. We generate a fixed dataset of constellations, where each constellation consists of 0-8 patterns; each pattern consists of the corners of a shape. Currently available shapes are triangle, square, pentagon and an L. The task is to count the number of occurrences of each pattern. To save the constellation datasets to file, run the following before training:

python3 scripts/data_to_file.py

Otherwise, the constellation datasets are regenerated at the beginning of training.
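As a rough picture of what one training example looks like, the sketch below samples a single constellation and its count labels. It is not the code in scripts/data_to_file.py: the shape templates (including the four-point L), the keep probability, the placement range and the function names are all assumptions made for illustration.

```python
import math
import random

def regular_polygon(n):
    """Corners of a unit regular n-gon centred at the origin."""
    return [(math.cos(2 * math.pi * k / n), math.sin(2 * math.pi * k / n)) for k in range(n)]

# Hypothetical templates; the repository's actual shapes may differ.
TEMPLATES = {
    "triangle": regular_polygon(3),
    "square": regular_polygon(4),
    "pentagon": regular_polygon(5),
    "L": [(0.0, 0.0), (0.0, 1.0), (0.0, 2.0), (1.0, 0.0)],
}

def sample_constellation(patterns, rng, keep_prob=0.5, spread=10.0):
    """Return (points, counts) for one constellation.

    Each entry of `patterns` is a slot that is included with probability
    `keep_prob`, so 8 slots give between 0 and 8 patterns.
    """
    points = []
    counts = {name: 0 for name in TEMPLATES}
    for name in patterns:
        if rng.random() >= keep_prob:
            continue  # this slot stays empty
        cx, cy = rng.uniform(-spread, spread), rng.uniform(-spread, spread)
        points.extend((cx + x, cy + y) for x, y in TEMPLATES[name])
        counts[name] += 1
    rng.shuffle(points)  # the model sees an unordered set of corners
    return points, counts

rng = random.Random(0)
patterns = ["square", "square", "triangle", "triangle", "pentagon", "pentagon", "L", "L"]
points, counts = sample_constellation(patterns, rng)
print(len(points), counts)
```

The input to the model is the shuffled point set, and the target is the per-shape count.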

Dataset and model consistency

When changing the dataset parameters (e.g. number of patterns, types of patterns, etc.), make sure that the model parameters are adjusted accordingly. For example, patterns=square,square,triangle,triangle,pentagon,pentagon,L,L means that there can be four different patterns, each appearing up to two times. Counting therefore involves four three-way classification tasks (each count is 0, 1 or 2), so n_outputs and output_dim in classifier.py need to be set to 4 and 3, respectively. All of this can be set through command-line arguments.
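The arithmetic above can be checked with a small helper. This is a hypothetical convenience function, not part of the repository; it assumes the patterns flag is a comma-separated string and that every count from 0 up to the largest number of repeats is a possible class.

```python
from collections import Counter

def classifier_dims(patterns):
    """Derive (n_outputs, output_dim) from a comma-separated patterns string.

    n_outputs  = number of distinct pattern types (one classification task each)
    output_dim = largest number of repeats + 1 (the counts 0..max are the classes)
    """
    counts = Counter(name.strip() for name in patterns.split(","))
    n_outputs = len(counts)
    output_dim = max(counts.values()) + 1
    return n_outputs, output_dim

print(classifier_dims("square,square,triangle,triangle,pentagon,pentagon,L,L"))  # (4, 3)
```

If pattern types are repeated an unequal number of times, this helper sizes output_dim for the most frequent one.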

Results

Constellations results

QM9

This dataset consists of 133,885 small organic molecules described by the location and charge of each atom in the molecule, along with the bonding structure of the molecule. The dataset includes 19 properties of each molecule, such as various rotational constants, energies and enthalpies. We aim to predict 12 of these properties.

python scripts/train_molecule.py \
    --run_name "molecule_homo" \
    --model_config "configs/molecule/eqv_transformer_model.py" \
    --model_seed 0 \
    --data_seed 0 \
    --task homo

Results

QM9 results

Hamiltonian dynamics

In this experiment, we aim to predict the trajectory of a number of particles connected together by a series of springs. This is done by learning the Hamiltonian of the system from observed trajectories.
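For intuition about the quantity being learned, the sketch below writes down a hand-coded Hamiltonian for point masses joined by springs and rolls a trajectory forward with a leapfrog integrator. In the experiment the Hamiltonian is a learned model; here it is analytic, and the zero rest length, unit masses, spring constants and function names are assumptions made for illustration.

```python
def hamiltonian(q, p, masses, springs):
    """H = sum_i |p_i|^2 / (2 m_i) + sum_(i,j,k) 0.5 * k * |q_i - q_j|^2."""
    kinetic = sum((px * px + py * py) / (2.0 * m) for (px, py), m in zip(p, masses))
    potential = sum(
        0.5 * k * ((q[i][0] - q[j][0]) ** 2 + (q[i][1] - q[j][1]) ** 2)
        for i, j, k in springs
    )
    return kinetic + potential

def forces(q, springs):
    """Force on each particle, i.e. -dH/dq."""
    f = [[0.0, 0.0] for _ in q]
    for i, j, k in springs:
        for d in range(2):
            diff = q[i][d] - q[j][d]
            f[i][d] -= k * diff
            f[j][d] += k * diff
    return f

def leapfrog(q, p, masses, springs, dt, steps):
    """Integrate Hamilton's equations with the leapfrog (velocity Verlet) scheme."""
    q = [list(x) for x in q]
    p = [list(x) for x in p]
    n = len(q)
    for _ in range(steps):
        f = forces(q, springs)
        for i in range(n):
            for d in range(2):
                p[i][d] += 0.5 * dt * f[i][d]    # half step in momentum
        for i in range(n):
            for d in range(2):
                q[i][d] += dt * p[i][d] / masses[i]  # full step in position
        f = forces(q, springs)
        for i in range(n):
            for d in range(2):
                p[i][d] += 0.5 * dt * f[i][d]    # second half step in momentum
    return q, p

masses = [1.0, 1.0, 1.0]
springs = [(0, 1, 1.0), (1, 2, 1.0), (0, 2, 1.0)]  # (i, j, spring constant)
q0 = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
p0 = [(0.0, 0.5), (0.0, -0.5), (0.0, 0.0)]

energy_start = hamiltonian(q0, p0, masses, springs)
q1, p1 = leapfrog(q0, p0, masses, springs, dt=0.01, steps=1000)
energy_end = hamiltonian(q1, p1, masses, springs)
print("energy drift:", abs(energy_end - energy_start))
```

A learned Hamiltonian can be plugged into the same rollout by replacing the analytic forces with the negative gradient of the model's output with respect to q.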

The following commands generate a dataset of trajectories and train LieTransformer on it. Data generation occurs in the first run and can take some time.

T(2) default: python scripts/train_dynamics.py
SE(2) default: python scripts/train_dynamics.py --group 'SE(2)_canonical' --lift_samples 2 --num_layers 3 --dim_hidden 80

Results

Figures: rollout MSE (dynamics data efficiency) and example trajectories.

Contributing

Contributions are best developed in separate branches. Once a change is ready, please submit a pull request with a description of the change. New model and data configs should go into the configs folder, and the rest of the code should go into the eqv_transformer folder.
