Authors' implementation of LieTransformer: Equivariant Self-Attention for Lie Groups

Overview

LieTransformer

This repository contains the implementation of the LieTransformer used for the experiments in the paper LieTransformer: Equivariant Self-Attention for Lie Groups.

Example tasks: pattern recognition (constellations), molecular property prediction (rotating molecule), and particle dynamics (particle trajectories).

Introduction

LieTransformer is an equivariant Transformer-like model built out of equivariant self-attention layers (LieSelfAttention). The model can be made equivariant to any Lie group simply by providing an implementation of the group of interest. A number of commonly used groups are already implemented, building on the work of LieConv. Switching group equivariance requires no change to the model architecture, only passing a different group to the model.
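To make the last point concrete, here is a toy sketch in Python. It is not the LieTransformer or LieConv API: the classes T2 and SE2, the method pair_invariant and the functions kernel, toy_model and rigid_motion are hypothetical, and each group is reduced to a single pairwise invariant instead of the full lifting machinery. The model body never changes; only the group object passed in does, and the output becomes invariant to the transformations in that group.

```python
import math

# Hypothetical, simplified stand-ins for group implementations. They are NOT the
# LieConv/LieTransformer group classes; each only exposes a pairwise invariant.
class T2:
    """Translations of the plane: the displacement v - u is translation-invariant."""
    def pair_invariant(self, u, v):
        return (v[0] - u[0], v[1] - u[1])


class SE2:
    """Rotations and translations: the distance |v - u| is invariant."""
    def pair_invariant(self, u, v):
        return (math.dist(u, v),)


def kernel(inv):
    # Arbitrary fixed scalar function of the invariant (stands in for learned weights).
    return math.exp(-sum((x - 0.5) ** 2 for x in inv))


def toy_model(points, group):
    """Same 'architecture' for every group: sum a kernel over all pairs of points."""
    return sum(kernel(group.pair_invariant(u, v)) for u in points for v in points)


def rigid_motion(points, theta, tx, ty):
    """Rotate every point by theta, then translate by (tx, ty)."""
    c, s = math.cos(theta), math.sin(theta)
    return [(c * x - s * y + tx, s * x + c * y + ty) for x, y in points]


pts = [(0.0, 0.0), (1.0, 0.2), (0.3, -0.7), (2.0, 1.5)]
# Equal up to floating-point rounding: translation leaves the T2 model unchanged.
print(toy_model(pts, T2()), toy_model(rigid_motion(pts, 0.0, 3.0, -1.0), T2()))
# Equal up to rounding: rotation plus translation leaves the SE2 model unchanged.
print(toy_model(pts, SE2()), toy_model(rigid_motion(pts, 0.9, 3.0, -1.0), SE2()))
```

Invariance is the simplest special case of equivariance; LieSelfAttention layers produce equivariant features rather than a single invariant number.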

Architecture

The overall architecture of the LieTransformer is similar to that of the original Transformer, interleaving a series of attention layers and pointwise MLPs in residual blocks. The architecture of the LieSelfAttention blocks differs, however, and is shown below. For more details, please see the paper.

model diagram
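For orientation, the generic residual pattern described above (attention, then a pointwise MLP, each wrapped in a residual connection) can be sketched in plain Python. This is a simplified stand-in, not LieSelfAttention: the query/key/value projections are identities, the MLP weights are fixed, and normalization layers are omitted. The functions self_attention, pointwise_mlp, residual_block and transformer are hypothetical names.

```python
import math


def self_attention(xs):
    """Dot-product self-attention with identity Q/K/V projections (a simplification)."""
    d = len(xs[0])
    out = []
    for q in xs:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in xs]
        m = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([sum(w * v[c] for w, v in zip(weights, xs)) / z for c in range(d)])
    return out


def pointwise_mlp(x):
    """Applied to each point independently (toy fixed weights: ReLU, then scale)."""
    return [0.5 * max(0.0, c) for c in x]


def residual_block(xs):
    # x <- x + Attention(x): mixes information across points
    attn = self_attention(xs)
    xs = [[a + b for a, b in zip(x, y)] for x, y in zip(xs, attn)]
    # x <- x + MLP(x): acts on each point separately
    return [[a + b for a, b in zip(x, pointwise_mlp(x))] for x in xs]


def transformer(xs, num_blocks=2):
    """Stack of residual blocks; the number of points and feature size are preserved."""
    for _ in range(num_blocks):
        xs = residual_block(xs)
    return xs


points = [[0.1, -0.2, 0.3], [1.0, 0.5, -0.5], [-0.3, 0.8, 0.0]]
print(transformer(points))
```

Because attention sums over all points and the MLP acts on each point independently, permuting the input points simply permutes the output. The LieSelfAttention blocks replace the plain attention step here with the group-aware version described in the paper.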

Installation

To reproduce the experiments in this repository, first clone the repo from https://github.com/anonymous-code-0/lie-transformer. To install the dependencies and create a virtual environment, execute setup_virtualenv.sh. Alternatively, you can install the library and its dependencies without creating a virtual environment by running pip install -e . from the repository root.

To install the library as a dependency for another project, use https://github.com/anonymous-code-0/lie-transformer.

Alternatively, you can install all the dependencies using pip install -r requirements.txt. If you do so, you will also need to install LieConv, Forge, and this repo itself (using the pip install -e command). Note that the version of LieConv used in this project is a slightly modified version of the original repo that fixes a bug with newer PyTorch versions.

Training a model

Example command to train a model (in this case the Set Transformer on the constellation dataset):

python3 scripts/train_constellation.py --data_config configs/constellation.py --model_config configs/set_transformer.py --run_name my_experiment --learning_rate=1e-4 --batch_size 128

The model and the dataset are chosen by specifying different config files. Flags for configuring the model and the dataset are available in the respective config files. The project uses Forge for configs and experiment management; please refer to the Forge examples for details.
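The example command above mixes two flag styles (--learning_rate=1e-4 and --batch_size 128). The sketch below uses Python's standard argparse module purely as an illustration, since Forge's own flag handling is not shown here; build_parser is a hypothetical helper that declares only the flags appearing in the example command. With argparse, both styles parse to the same value.

```python
import argparse


def build_parser():
    # Hypothetical parser covering only the flags in the example command.
    # The real training scripts define their flags via Forge and the config files.
    p = argparse.ArgumentParser(description="illustrative training flags")
    p.add_argument("--data_config", type=str)
    p.add_argument("--model_config", type=str)
    p.add_argument("--run_name", type=str)
    p.add_argument("--learning_rate", type=float)
    p.add_argument("--batch_size", type=int)
    return p


argv = [
    "--data_config", "configs/constellation.py",
    "--model_config", "configs/set_transformer.py",
    "--run_name", "my_experiment",
    "--learning_rate=1e-4",      # "--flag=value" style
    "--batch_size", "128",       # "--flag value" style
]
args = build_parser().parse_args(argv)
print(args.learning_rate, args.batch_size)
```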

Counting patterns in the constellation dataset

The first task implemented is counting patterns in the constellation dataset. We generate a fixed dataset of constellations, where each constellation consists of 0-8 patterns and each pattern consists of the corners of a shape. Currently available shapes are a triangle, a square, a pentagon, and an L. The task is to count the number of occurrences of each pattern. To save the constellation datasets to file, run the following before training:

python3 scripts/data_to_file.py

Otherwise, the constellation datasets are regenerated at the beginning of training.
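A minimal sketch of how one constellation and its count labels could be generated is shown below. The corner coordinates in SHAPES (especially the L), the random rotation, scale and translation ranges, and the 50% chance that each listed pattern is included are all hypothetical, not the repository's settings. What it does reflect from the description above is that each constellation contains 0-8 patterns, each pattern contributes the corners of a shape, and the target is the number of occurrences of each pattern.

```python
import math
import random


def _polygon(n):
    """Corners of a regular n-gon inscribed in the unit circle."""
    return [(math.cos(2 * math.pi * k / n), math.sin(2 * math.pi * k / n)) for k in range(n)]


SHAPES = {
    "triangle": _polygon(3),
    "square": _polygon(4),
    "pentagon": _polygon(5),
    "L": [(0.0, 0.0), (0.0, 1.0), (0.0, 2.0), (1.0, 0.0)],  # hypothetical L layout
}


def make_constellation(patterns, rng):
    """Return (shuffled corner points, per-shape counts) for one random constellation."""
    points = []
    counts = {name: 0 for name in SHAPES}
    for name in patterns:
        if rng.random() < 0.5:  # assumption: each listed pattern appears with prob. 1/2
            continue
        theta = rng.uniform(0.0, 2 * math.pi)
        scale = rng.uniform(0.5, 1.5)
        tx, ty = rng.uniform(-5.0, 5.0), rng.uniform(-5.0, 5.0)
        c, s = math.cos(theta), math.sin(theta)
        for x, y in SHAPES[name]:
            points.append((scale * (c * x - s * y) + tx, scale * (s * x + c * y) + ty))
        counts[name] += 1
    rng.shuffle(points)  # the model sees an unordered set of corners
    return points, counts


pats = ["square", "square", "triangle", "triangle", "pentagon", "pentagon", "L", "L"]
points, counts = make_constellation(pats, random.Random(0))
print(len(points), counts)
```

Shuffling the points matters: the model receives an unordered set of corners, so it cannot rely on the order in which the shapes were drawn.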

Dataset and model consistency

When changing the dataset parameters (e.g. number of patterns, types of patterns, etc.), make sure that the model parameters are adjusted accordingly. For example, patterns=square,square,triangle,triangle,pentagon,pentagon,L,L means that there can be four different patterns, each repeated two times. Counting therefore involves four three-way classification tasks (a count of 0, 1, or 2 for each pattern), so n_outputs and output_dim in classifier.py need to be set to 4 and 3, respectively. All of this can be set through command-line arguments.
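The rule above can be captured in a small helper, classifier_dims, which is hypothetical and not part of the repository. It derives both values from a patterns string: one classification task per distinct pattern, and one class per possible count from 0 up to the maximum number of repeats. Using the maximum when patterns are repeated unequally is an assumption.

```python
from collections import Counter


def classifier_dims(patterns):
    """Hypothetical helper: derive (n_outputs, output_dim) from a patterns string."""
    counts = Counter(p.strip() for p in patterns.split(","))
    n_outputs = len(counts)                 # one classification task per distinct pattern
    output_dim = max(counts.values()) + 1   # classes for counts 0 .. max repeats
    return n_outputs, output_dim


print(classifier_dims("square,square,triangle,triangle,pentagon,pentagon,L,L"))  # (4, 3)
```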

Results

Constellations results

QM9

This dataset consists of 133,885 small organic molecules, each described by the location and charge of each atom in the molecule, along with the bonding structure of the molecule. The dataset includes 19 properties of each molecule, such as various rotational constants, energies and enthalpies. We aim to predict 12 of these properties.

python scripts/train_molecule.py \
    --run_name "molecule_homo" \
    --model_config "configs/molecule/eqv_transformer_model.py" \
    --model_seed 0 \
    --data_seed 0 \
    --task homo

Configurable scripts for running the experiments in the paper are in the scripts folder: scripts/train_molecule_SE3transformer.sh and scripts/train_molecule_SE3lieconv.sh.

Results

QM9 results

Hamiltonian dynamics

In this experiment we aim to predict the trajectory of a number of particles connected together by a series of springs. This is done by learning the Hamiltonian of the system from observed trajectories.
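To make the setup concrete, the sketch below simulates a small spring system with a known Hamiltonian, the kind of ground-truth trajectory generator that a learned Hamiltonian is trained to match. The specific form (point masses in 2D connected by zero-rest-length springs, H = sum_i |p_i|^2 / (2 m_i) + 1/2 sum_{i<j} k_ij |q_i - q_j|^2) and the leapfrog integrator are assumptions for illustration; the repository's data generator and integrator may differ.

```python
import math


def hamiltonian(q, p, m, k):
    """H = kinetic + spring potential; q, p are lists of (x, y); k is symmetric."""
    kinetic = sum((px * px + py * py) / (2.0 * mi) for (px, py), mi in zip(p, m))
    potential = 0.0
    for i in range(len(q)):
        for j in range(i + 1, len(q)):
            dx, dy = q[i][0] - q[j][0], q[i][1] - q[j][1]
            potential += 0.5 * k[i][j] * (dx * dx + dy * dy)
    return kinetic + potential


def forces(q, k):
    """-dH/dq: each spring pulls particle i towards particle j."""
    f = [[0.0, 0.0] for _ in q]
    for i in range(len(q)):
        for j in range(len(q)):
            if i != j:
                f[i][0] -= k[i][j] * (q[i][0] - q[j][0])
                f[i][1] -= k[i][j] * (q[i][1] - q[j][1])
    return f


def leapfrog(q, p, m, k, dt, steps):
    """Velocity-Verlet (leapfrog) integration of Hamilton's equations."""
    q = [list(x) for x in q]
    p = [list(x) for x in p]
    f = forces(q, k)
    for _ in range(steps):
        for i in range(len(q)):
            for d in range(2):
                p[i][d] += 0.5 * dt * f[i][d]   # half kick
                q[i][d] += dt * p[i][d] / m[i]  # drift
        f = forces(q, k)
        for i in range(len(q)):
            for d in range(2):
                p[i][d] += 0.5 * dt * f[i][d]   # half kick
    return q, p


m = [1.0, 1.0, 1.0]
k = [[0.0, 1.0, 2.0], [1.0, 0.0, 1.5], [2.0, 1.5, 0.0]]
q0 = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
p0 = [(0.0, 0.5), (-0.3, 0.0), (0.3, -0.5)]
q1, p1 = leapfrog(q0, p0, m, k, dt=0.01, steps=1000)
print(hamiltonian(q0, p0, m, k), hamiltonian(q1, p1, m, k))  # nearly equal
```

Under this assumed form, H depends only on the momenta and on squared distances between particles, so it is invariant to translations and rotations of the whole system. That is why T(2) and SE(2) equivariance are natural priors for this task.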

Each of the following commands generates a dataset of trajectories and trains LieTransformer on it, with T(2) or SE(2) equivariance respectively:

T(2) default: python scripts/train_dynamics.py
SE(2) default: python scripts/train_dynamics.py --group 'SE(2)_canonical' --lift_samples 2 --num_layers 3 --dim_hidden 80
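The --lift_samples flag in the SE(2) command presumably sets how many group elements each input point is lifted to, following LieConv's lifting step. The sketch below illustrates the idea under simplifying assumptions: SE(2) elements are stored as (angle, tx, ty) tuples, and lift_se2, lift_t2, relative and act are hypothetical helpers, not the repository's group API. Under T(2) a point corresponds to exactly one translation, whereas under SE(2) any rotation angle is compatible with a point's position, so several angles are sampled per point. Pairwise quantities u^{-1}v are unchanged when every element is left-multiplied by the same g, which is what lets attention built on them be equivariant.

```python
import math
import random


def lift_t2(points):
    """T(2): each point corresponds to exactly one translation (angle fixed at 0)."""
    return [(0.0, x, y) for x, y in points]


def lift_se2(points, lift_samples, rng):
    """SE(2): each point is lifted to `lift_samples` elements with random angles."""
    return [(rng.uniform(0.0, 2 * math.pi), x, y)
            for x, y in points for _ in range(lift_samples)]


def relative(u, v):
    """u^{-1} v for SE(2) elements (angle, tx, ty): (b - a, R(-a)(t_v - t_u))."""
    a, ux, uy = u
    b, vx, vy = v
    dx, dy = vx - ux, vy - uy
    c, s = math.cos(-a), math.sin(-a)
    return ((b - a) % (2 * math.pi), c * dx - s * dy, s * dx + c * dy)


def act(g, u):
    """Left multiplication g * u (rotate and translate the element u by g)."""
    tg, gx, gy = g
    a, ux, uy = u
    c, s = math.cos(tg), math.sin(tg)
    return ((tg + a) % (2 * math.pi), c * ux - s * uy + gx, s * ux + c * uy + gy)


pts = [(0.0, 0.0), (1.0, 2.0), (-1.5, 0.5)]
lifted = lift_se2(pts, lift_samples=2, rng=random.Random(1))
g = (0.7, 3.0, -2.0)
moved = [act(g, u) for u in lifted]
print(relative(lifted[0], lifted[3]))
print(relative(moved[0], moved[3]))  # same up to rounding
```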

Results

Rollout MSE and example trajectories (figures).

Contributing

Contributions are best developed in separate branches. Once a change is ready, please submit a pull request with a description of the change. New model and data configs should go into the configs folder, and the rest of the code into the eqv_transformer folder.
