Invariant Point Attention - Pytorch

Implementation of Invariant Point Attention as a standalone module, which was used in the structure module of Alphafold2 for coordinate refinement.

Todo

  • write up a test for invariance under rotation
  • enforce float32 for certain operations

Install

$ pip install invariant-point-attention

Usage

import torch
from einops import repeat
from invariant_point_attention import InvariantPointAttention

attn = InvariantPointAttention(
    dim = 64,                  # single (and pairwise) representation dimension
    heads = 8,                 # number of attention heads
    scalar_key_dim = 16,       # scalar query-key dimension
    scalar_value_dim = 16,     # scalar value dimension
    point_key_dim = 4,         # point query-key dimension
    point_value_dim = 4        # point value dimension
)

single_repr   = torch.randn(1, 256, 64)      # (batch x seq x dim)
pairwise_repr = torch.randn(1, 256, 256, 64) # (batch x seq x seq x dim)
mask          = torch.ones(1, 256).bool()    # (batch x seq)

rotations     = repeat(torch.eye(3), '... -> b n ...', b = 1, n = 256)  # (batch x seq x rot1 x rot2) - example is identity
translations  = torch.zeros(1, 256, 3) # translation, also identity for example

attn_out = attn(
    single_repr,
    pairwise_repr,
    rotations = rotations,
    translations = translations,
    mask = mask
)

attn_out.shape # (1, 256, 64)
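
Since the module is meant to be invariant to global rigid motions of the input frames, a quick sanity check could look like the sketch below, continuing from the example above. random_rotation is just a helper defined here, and the tolerance may need loosening in float32.

def random_rotation():
    q, _ = torch.linalg.qr(torch.randn(3, 3))  # orthogonalize a random matrix
    return q * torch.det(q)                    # flip the sign if improper, so det = +1

rotations    = repeat(random_rotation(), 'r1 r2 -> b n r1 r2', b = 1, n = 256)
translations = torch.randn(1, 256, 3)

R, t = random_rotation(), torch.randn(3)       # one global rotation and translation

out = attn(single_repr, pairwise_repr, rotations = rotations, translations = translations, mask = mask)

out_moved = attn(
    single_repr,
    pairwise_repr,
    rotations = R @ rotations,                 # R_i -> R R_i
    translations = translations @ R.t() + t,   # t_i -> R t_i + t
    mask = mask
)

assert torch.allclose(out, out_moved, atol = 1e-5)  # the output should not change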

You can also use this module without the pairwise representations, which are specific to the Alphafold2 architecture.

import torch
from einops import repeat
from invariant_point_attention import InvariantPointAttention

attn = InvariantPointAttention(
    dim = 64,
    heads = 8,
    require_pairwise_repr = False   # set this to False to use the module without pairwise representations
)

seq           = torch.randn(1, 256, 64)
mask          = torch.ones(1, 256).bool()

rotations     = repeat(torch.eye(3), '... -> b n ...', b = 1, n = 256)
translations  = torch.randn(1, 256, 3)

attn_out = attn(
    seq,
    rotations = rotations,
    translations = translations,
    mask = mask
)

attn_out.shape # (1, 256, 64)

You can also use a single IPA-based transformer block, which is IPA followed by a feedforward. By default it uses post-layernorm, as done in the official code, but you can also try pre-layernorm by setting post_norm = False.

import torch
from torch import nn
from einops import repeat
from invariant_point_attention import IPABlock

block = IPABlock(
    dim = 64,
    heads = 8,
    scalar_key_dim = 16,
    scalar_value_dim = 16,
    point_key_dim = 4,
    point_value_dim = 4
)

seq           = torch.randn(1, 256, 64)
pairwise_repr = torch.randn(1, 256, 256, 64)
mask          = torch.ones(1, 256).bool()

rotations     = repeat(torch.eye(3), 'r1 r2 -> b n r1 r2', b = 1, n = 256)
translations  = torch.randn(1, 256, 3)

block_out = block(
    seq,
    pairwise_repr = pairwise_repr,
    rotations = rotations,
    translations = translations,
    mask = mask
)

updates = nn.Linear(64, 6)(block_out)
quaternion_update, translation_update = updates.chunk(2, dim = -1) # (1, 256, 3), (1, 256, 3)

# apply updates to rotations and translations for the next iteration
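
One way those updates might be folded back into the frames is sketched below. This assumes the Alphafold2 convention of treating (1, a, b, c) as an unnormalized quaternion and composing the predicted frame update with the current frame; quaternion_to_matrix is only a helper for illustration and not part of this library.

import torch.nn.functional as F

def quaternion_to_matrix(q):
    # (..., 4) unit quaternion (w, x, y, z) -> (..., 3, 3) rotation matrix
    w, x, y, z = q.unbind(dim = -1)
    return torch.stack((
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z),     2 * (x * z + w * y),
        2 * (x * y + w * z),     1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y),     2 * (y * z + w * x),     1 - 2 * (x * x + y * y)
    ), dim = -1).reshape(*q.shape[:-1], 3, 3)

# prepend the real part and normalize to get a valid rotation
quat = F.normalize(torch.cat((torch.ones_like(quaternion_update[..., :1]), quaternion_update), dim = -1), dim = -1)
rot_update = quaternion_to_matrix(quat)  # (1, 256, 3, 3)

# compose with the current frames; the translation update is applied in the current local frame
translations = translations + torch.einsum('b n i j, b n j -> b n i', rotations, translation_update)
rotations    = rotations @ rot_update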

Citations

@Article{AlphaFold2021,
    author  = {Jumper, John and Evans, Richard and Pritzel, Alexander and Green, Tim and Figurnov, Michael and Ronneberger, Olaf and Tunyasuvunakool, Kathryn and Bates, Russ and {\v{Z}}{\'\i}dek, Augustin and Potapenko, Anna and Bridgland, Alex and Meyer, Clemens and Kohl, Simon A A and Ballard, Andrew J and Cowie, Andrew and Romera-Paredes, Bernardino and Nikolov, Stanislav and Jain, Rishub and Adler, Jonas and Back, Trevor and Petersen, Stig and Reiman, David and Clancy, Ellen and Zielinski, Michal and Steinegger, Martin and Pacholska, Michalina and Berghammer, Tamas and Bodenstein, Sebastian and Silver, David and Vinyals, Oriol and Senior, Andrew W and Kavukcuoglu, Koray and Kohli, Pushmeet and Hassabis, Demis},
    journal = {Nature},
    title   = {Highly accurate protein structure prediction with {AlphaFold}},
    year    = {2021},
    doi     = {10.1038/s41586-021-03819-2},
    note    = {(Accelerated article preview)},
}
Comments
  • Computing point dist - use cartesian dimension instead of hidden dimension

    https://github.com/lucidrains/invariant-point-attention/blob/2f1fb7ca003d9c94d4144d1f281f8cbc914c01c2/invariant_point_attention/invariant_point_attention.py#L130

    I think it should be dim=-1, thus using the cartesian (xyz) axis, rather than dim=-2, which uses the hidden dimension.

    opened by aced125 3
  • In-place rotation detach not allowed

    Hi, this is probably highly version-dependent (I have pytorch=1.11.0, pytorch3d=0.7.0 nightly), but I thought I'd report it. Torch doesn't like the in-place detach of the rotation tensor. Full stack trace (from denoise.py):

    Traceback (most recent call last):
      File "denoise.py", line 56, in <module>
        denoised_coords = net(
      File "/home/pi-user/miniconda3/envs/piai/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/pi-user/invariant-point-attention/invariant_point_attention/invariant_point_attention.py", line 336, in forward
        rotations.detach_()
    RuntimeError: Can't detach views in-place. Use detach() instead. If you are using DistributedDataParallel (DDP) for training, and gradient_as_bucket_view is set as True, gradients are views of DDP buckets, and hence detach_() cannot be called on these gradients. To fix this error, please refer to the Optimizer.zero_grad() function in torch/optim/optimizer.py as the solution.
    

    Switching to rotations = rotations.detach() seems to behave correctly (tested in denoise.py and my own code). I'm not totally sure if this allocates a separate tensor, or just creates a new node pointing to the same data.

    opened by sidnarayanan 1
  • Report a bug that causes instability in training

    Hi, I would like to report a bug in the rotation update that causes instability in training. https://github.com/lucidrains/invariant-point-attention/blob/de337568959eb7611ba56eace2f642ca41e26216/invariant_point_attention/invariant_point_attention.py#L322

    The IPA Transformer is similar to the structure module in AF2, where recycling is used. Note that we usually detach the gradient of the rotation, which otherwise may cause instability during training. The reason is that the gradient would update the rotation during backpropagation, which results in instability in our experiments. Therefore we usually detach the rotation to remove this effect of gradient descent. I have seen you do this in your alphafold2 repo (https://github.com/lucidrains/alphafold2).

    If you think this is a problem, please let me know. I am happy to submit a pr to fix that.

    Best, Zhangzhi Peng

    opened by pengzhangzhi 1
  • Subtle mistake in the implementation

    Hi. Thanks for your implementation. It is very helpful. However, I find that the dropout is missing in the IPAModule.

    https://github.com/lucidrains/invariant-point-attention/blob/de337568959eb7611ba56eace2f642ca41e26216/invariant_point_attention/invariant_point_attention.py#L239

    In the alphafold2 supplementary, the dropout is nested inside the layer norm, which also holds true for the layer norm in the transition layer (line 9 of the referenced algorithm figure, not reproduced here).

    If you think this is a problem, please let me know. I will submit a pr to fix it. Thanks again for sharing such an amazing repo.

    Best, Zhangzhi Peng

    opened by pengzhangzhi 1
  • change quaternions update as original alphafold2

    In the original alphafold2 IPA module, a pure-quaternion (without the real part) description is used for the quaternion update. This can be broken down into a residual-update-like formulation. But in this code a (1, a, b, c) style quaternion is used, so I believe the quaternion update should be done as a simple multiplicative update. As far as I have tested, the loss seems to go down more efficiently with this modification.

    opened by ShintaroMinami 1
  • #126 maybe omit the 'self.point_attn_logits_scale'?

    Hi luci:

    I read the original paper and compared it to your implementation, and found one place that might be a mistake:

    #126. attn_logits_points = -0.5 * (point_dist * point_weights).sum(dim = -1),

    I thought it should be attn_logits_points = -0.5 * (point_dist * point_weights * self.point_attn_logits_scale).sum(dim = -1)

    Thanks for your sharing!

    opened by CiaoHe 1
  • Application of Invariant point attention: preserve part of structure.

    Hi, lucidrains. First of all, thanks a lot for your work!

    I have a question: how can I change (denoise) the structure only in the region I want? How would I do that with denoise.py?

    opened by hw-protein 0
  • Equivariance test for IPA Transformer

    @lucidrains I would like to ask about the equivariance of the transformer (not IPA blocks). I wonder if you checked for the equivariance of the output when you allow the transformation of local points to global points using the updated quaternions and translations. I am not sure why this test fails in my case.

    opened by amrhamedp 1