Implementation of Rotary Embeddings, from the Roformer paper, in Pytorch

Overview

Rotary Embeddings - Pytorch

A standalone library for adding rotary embeddings to transformers in Pytorch, following their success as a relative positional encoding. Specifically, it makes rotating information into any axis of a tensor easy and efficient, whether the rotations are fixed positional or learned. This library aims to give you state of the art results for positional embedding, at little cost.

My gut also tells me there is something more to rotations that can be exploited in artificial neural networks.

Install

$ pip install rotary-embedding-torch

Usage

import torch
from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding

# instantiate the positional embedding in your transformer and pass to all your attention layers

pos_emb = RotaryEmbedding(dim = 32)

# generate the rotations

freqs = pos_emb(torch.arange(1024), cache_key = 1024) # cache with a key that is the sequence length, so that it does not need to recompute

# mock queries and keys

q = torch.randn(1, 1024, 64) # queries - (batch, seq len, dimension of head)
k = torch.randn(1, 1024, 64) # keys

# apply the rotations to your queries and keys after the heads have been split out, but prior to the dot product and subsequent softmax (attention)

freqs = freqs[None, ...] # unsqueeze for batch dimension
q = apply_rotary_emb(freqs, q)
k = apply_rotary_emb(freqs, k)

# then do your attention with your queries (q) and keys (k)

If you do all the steps above correctly, you should see a dramatic improvement during training.
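
To make the "relative" part of the steps above concrete, here is a minimal, dependency-free sketch of the rotation itself. It is not the library's code: the helper names (`rotary_angles`, `rotate`, `dot`) are hypothetical, and it assumes adjacent feature pairs and the base of 10000 from the RoFormer paper. It shows that the dot product of a rotated query and key depends only on the offset between their positions.

```python
import math

def rotary_angles(pos, dim, theta=10000.0):
    # One angle per feature pair: pos * theta^(-2i/dim) (theta is an assumed base)
    return [pos * theta ** (-2.0 * i / dim) for i in range(dim // 2)]

def rotate(vec, pos):
    # Rotate each adjacent pair (vec[2i], vec[2i+1]) by the angle for position `pos`
    out = []
    for i, angle in enumerate(rotary_angles(pos, len(vec))):
        a, b = vec[2 * i], vec[2 * i + 1]
        c, s = math.cos(angle), math.sin(angle)
        out += [a * c - b * s, a * s + b * c]
    return out

def dot(u, v):
    return sum(x * y for x, y in zip(u, v))

# Arbitrary illustrative query and key vectors (8 features each)
q = [0.3, -1.2, 0.7, 0.5, -0.4, 0.9, 1.1, -0.6]
k = [0.8, 0.1, -0.5, 1.3, 0.2, -0.7, 0.4, 0.6]

# Same relative offset (3), very different absolute positions
score_near = dot(rotate(q, 5), rotate(k, 2))
score_far = dot(rotate(q, 105), rotate(k, 102))
print(abs(score_near - score_far) < 1e-9)
```

Because each pair is only rotated, the norm of the vector is unchanged as well, which is why the rotation can be applied to queries and keys right before the dot product.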

Axial Rotary Embeddings

For easy use of 2D axial relative positional embeddings, e.g. in vision transformers.

import torch
from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding, broadcat

pos_emb = RotaryEmbedding(
    dim = 32,
    freqs_for = 'pixel'
)

# queries and keys for frequencies to be rotated into

q = torch.randn(1, 256, 256, 64)
k = torch.randn(1, 256, 256, 64)

# get frequencies for each axis
# positions in the range -1 to 1 have been shown to be a good choice for images and audio

freqs_h = pos_emb(torch.linspace(-1, 1, steps = 256), cache_key = 256)
freqs_w = pos_emb(torch.linspace(-1, 1, steps = 256), cache_key = 256)

# concat the frequencies of the two axes along the last dimension
# the broadcat function makes this easy without a bunch of expands

freqs = broadcat((freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim = -1)

# rotate in frequencies

q = apply_rotary_emb(freqs, q)
k = apply_rotary_emb(freqs, k)

Learned Rotations

For injecting learned rotations into a network. Experiments pending.

Update: doesn't seem to do anything -_-, will keep trying...

import torch
from torch import nn
from rotary_embedding_torch import apply_learned_rotations

x = torch.randn(1, 1024, 512)

# you can only rotate in (dim // 2) values
# ex. for 512, you can only rotate in 256 values

# say you have two sets of learned rotations of 128 values each

rots1 = nn.Linear(512, 128)(x)
rots2 = nn.Linear(512, 128)(x)

# you rotate in 256 (128 x 2) at first

x = apply_learned_rotations(rots1, x, start_index = 0)

# then you start at index 256 and rotate in the last (128 x 2)

x = apply_learned_rotations(rots2, x, start_index = 256)

# you could also concat the rotations together and pass it in all at once

rots = torch.cat((rots1, rots2), dim = -1)

x = apply_learned_rotations(rots, x)
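
The "only rotate in (dim // 2) values" constraint follows from each rotation angle acting on a pair of features. The sketch below illustrates this in plain Python, assuming the learned values are used directly as angles on adjacent feature pairs; `rotate_pairs` is a hypothetical helper, not part of the library.

```python
import math

def rotate_pairs(x, angles, start_index=0):
    # Each angle rotates one adjacent feature pair, so len(angles) angles
    # touch 2 * len(angles) features starting at `start_index`
    out = list(x)
    for i, angle in enumerate(angles):
        j = start_index + 2 * i
        a, b = x[j], x[j + 1]
        c, s = math.cos(angle), math.sin(angle)
        out[j], out[j + 1] = a * c - b * s, a * s + b * c
    return out

x = [1.0, 0.0, 0.0, 2.0, 3.0, 4.0, 5.0, 6.0]   # 8 features
angles = [math.pi / 2, math.pi]                # 2 angles rotate features 0..3

y = rotate_pairs(x, angles, start_index=0)

# The remaining features (index 4 onward) are left untouched
print(y[4:])
```

A second set of two angles with `start_index = 4` would cover the rest of this vector, mirroring the `rots1` / `rots2` split above.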

Citations

@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding}, 
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

Comments
  • Custom position offset when rotating queries or keys

    This library seems to assume that queries and keys are left-aligned position-wise e.g.

    q = [p_0, p_1, p_2]
    k = [p_0, p_1, p_2, p_3, p_4]
    

    where p_i are the corresponding positions. This is enforced by always starting the sequence of positions from 0 with torch.arange(seq_len). Applications like Perceiver AR, however, require position-wise right-alignment, e.g.

    q =           [p_2, p_3, p_4]
    k = [p_0, p_1, p_2, p_3, p_4]
    

    This pull request allows specifying a start position for queries and/or keys to enable alignments other than left-alignment. For example:

    import torch
    from rotary_embedding_torch.rotary_embedding_torch import RotaryEmbedding
    
    rot = RotaryEmbedding(dim=32)
    
    q = torch.ones(1, 8, 4, 32)
    k = torch.ones(1, 8, 6, 32)
    
    q = q / torch.norm(q, dim=-1, keepdim=True)
    k = k / torch.norm(k, dim=-1, keepdim=True)
    
    q_rot = rot.rotate_queries_or_keys(q, start_pos=k.shape[2] - q.shape[2])
    k_rot = rot.rotate_queries_or_keys(k)
    
    attn = torch.einsum("b h i c, b h j c -> b h i j", q_rot, k_rot)
    print(attn[0, 0])
    

    prints the following relative position embedding

    tensor([[0.8581, 0.9571, 1.0000, 0.9571, 0.8581, 0.7670],
            [0.7670, 0.8581, 0.9571, 1.0000, 0.9571, 0.8581],
            [0.7288, 0.7670, 0.8581, 0.9571, 1.0000, 0.9571],
            [0.7361, 0.7288, 0.7670, 0.8581, 0.9571, 1.0000]])
    

    (diagonal of 1s right-aligned) whereas the default behavior

    ...
    
    q_rot = rot.rotate_queries_or_keys(q)
    k_rot = rot.rotate_queries_or_keys(k)
    
    attn = torch.einsum("b h i c, b h j c -> b h i j", q_rot, k_rot)
    print(attn[0, 0])
    

    would print

    tensor([[1.0000, 0.9571, 0.8581, 0.7670, 0.7288, 0.7361],
            [0.9571, 1.0000, 0.9571, 0.8581, 0.7670, 0.7288],
            [0.8581, 0.9571, 1.0000, 0.9571, 0.8581, 0.7670],
            [0.7670, 0.8581, 0.9571, 1.0000, 0.9571, 0.8581]])
    

    (diagonal of 1s left-aligned).

    opened by krasserm 1
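
The alignment effect described in this comment can be checked without the library. The following is a rough sketch under stated assumptions (rotary dimension 32, base 10000, adjacent feature pairs, normalized all-ones vectors as in the snippet above); `rotated_unit_ones` and `attention_scores` are hypothetical helpers and `q_start` stands in for the proposed `start_pos`.

```python
import math

def rotated_unit_ones(pos, dim=32, theta=10000.0):
    # Rotary-rotate a normalized all-ones vector placed at position `pos`
    a = 1.0 / math.sqrt(dim)
    out = []
    for i in range(dim // 2):
        angle = pos * theta ** (-2.0 * i / dim)
        c, s = math.cos(angle), math.sin(angle)
        out += [a * c - a * s, a * s + a * c]
    return out

def attention_scores(n_q, n_k, q_start=0):
    # Queries occupy positions q_start .. q_start + n_q - 1; keys occupy 0 .. n_k - 1
    qs = [rotated_unit_ones(q_start + i) for i in range(n_q)]
    ks = [rotated_unit_ones(j) for j in range(n_k)]
    return [[sum(x * y for x, y in zip(q, k)) for k in ks] for q in qs]

right = attention_scores(4, 6, q_start=2)   # right-aligned queries
left = attention_scores(4, 6)               # default left alignment

for row in right:
    print([round(v, 4) for v in row])
```

With the offset, the largest score in query row i sits at key column i + 2; without it, at column i.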
  • about axial rotary embeddings

    Hi, thank you for sharing this code with us. However, I am confused by the axial rotary embeddings in the rotary_embedding_torch.py file:

    elif freqs_for == 'pixel':
        freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi

    Where does this formula come from? What is the max_freq parameter? Why are the freqs not 1/(10000^(2i/d))?

    Thank you again.

    opened by raindrop313 0
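
For readers comparing the two frequency schedules mentioned in this question, here is a small dependency-free sketch that evaluates both. It does not explain where the 'pixel' formula comes from; it only computes the two expressions quoted above. The values `theta = 10000` and `max_freq = 10` are assumptions chosen for illustration, and `lang_freqs` / `pixel_freqs` are hypothetical names.

```python
import math

def lang_freqs(dim, theta=10000.0):
    # 1 / theta^(2i/d) for i = 0 .. d/2 - 1, the schedule quoted in the question
    return [theta ** (-2.0 * i / dim) for i in range(dim // 2)]

def pixel_freqs(dim, max_freq=10.0):
    # Pure-Python equivalent of torch.linspace(1., max_freq / 2, dim // 2) * pi
    # (assumes dim >= 4 so that there are at least two points)
    n = dim // 2
    stop = max_freq / 2
    return [(1.0 + (stop - 1.0) * i / (n - 1)) * math.pi for i in range(n)]

lang = lang_freqs(32)
pixel = pixel_freqs(32)

print("lang :", [round(f, 4) for f in lang[:4]], "...", round(lang[-1], 6))
print("pixel:", [round(f, 4) for f in pixel[:4]], "...", round(pixel[-1], 4))
```

Under these assumptions the first schedule decays geometrically from 1, while the second is linearly spaced between pi and (max_freq / 2) * pi.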
Owner
Phil Wang