Implementation of a Transformer, but completely in Triton

Overview

Transformer in Triton (wip)

Implementation of a Transformer, but completely in Triton. I'm completely new to lower-level neural net code, so this repository will mostly be a learning experience, with the end-goal being a vanilla transformer that is faster and more efficient to train.

Install

$ pip install triton-transformer

Usage

import torch
from triton_transformer import Transformer

model = Transformer(
    num_tokens = 256,
    max_seq_len = 1024,
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64
)

x = torch.randint(0, 256, (1, 1024))
mask = torch.ones(1, 1024).bool()

logits = model(x, mask = mask) # (1, 1024, 256)

Citations

@article{Tillet2019TritonAI,
    title   = {Triton: an intermediate language and compiler for tiled neural network computations},
    author  = {Philippe Tillet and H. Kung and D. Cox},
    journal = {Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages},
    year    = {2019}
}
@misc{vaswani2017attention,
    title   = {Attention Is All You Need}, 
    author  = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
    year    = {2017},
    eprint  = {1706.03762},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
You might also like...
A Pytorch implementation of CVPR 2021 paper "RSG: A Simple but Effective Module for Learning Imbalanced Datasets"

RSG: A Simple but Effective Module for Learning Imbalanced Datasets (CVPR 2021) A Pytorch implementation of our CVPR 2021 paper "RSG: A Simple but Eff

A concise but complete implementation of CLIP with various experimental improvements from recent papers
A concise but complete implementation of CLIP with various experimental improvements from recent papers

x-clip (wip) A concise but complete implementation of CLIP with various experimental improvements from recent papers Install $ pip install x-clip Usag

A concise but complete implementation of CLIP with various experimental improvements from recent papers
A concise but complete implementation of CLIP with various experimental improvements from recent papers

x-clip (wip) A concise but complete implementation of CLIP with various experimental improvements from recent papers Install $ pip install x-clip Usag

Implementation of a protein autoregressive language model, but with autoregressive infilling objective (editing subsequences capability)
Implementation of a protein autoregressive language model, but with autoregressive infilling objective (editing subsequences capability)

Protein GLM (wip) Implementation of a protein autoregressive language model, but with autoregressive infilling objective (editing subsequences capabil

Third party Pytorch implement of Image Processing Transformer (Pre-Trained Image Processing Transformer arXiv:2012.00364v2)

ImageProcessingTransformer Third party Pytorch implement of Image Processing Transformer (Pre-Trained Image Processing Transformer arXiv:2012.00364v2)

Episodic Transformer (E.T.) is a novel attention-based architecture for vision-and-language navigation. E.T. is based on a multimodal transformer that encodes language inputs and the full episode history of visual observations and actions. CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped
CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped

CSWin-Transformer This repo is the official implementation of "CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows". Th

nnFormer: Interleaved Transformer for Volumetric Segmentation Code for paper "nnFormer: Interleaved Transformer for Volumetric Segmentation "

nnFormer: Interleaved Transformer for Volumetric Segmentation Code for paper "nnFormer: Interleaved Transformer for Volumetric Segmentation ". Please

3D-Transformer: Molecular Representation with Transformer in 3D Space

3D-Transformer: Molecular Representation with Transformer in 3D Space

Comments
  • Question concerning PyTorch build

    Question concerning PyTorch build

    Hello. I find your project very interesting and I have seen your comparison between PyTorch and Triton implementations.

    However, I am curious whether your PyTorch environment is a source build optimized for your machine or a pip/conda install.

    Source building has faster runtimes and if a conda install is being used for comparison, the difference in speed may simply be due to Triton optimizing CUDA for the run environment.

    Thank you again for your interesting project.

    opened by veritas9872 13
  • _layernorm implementation forward result not equal F.layer_norm

    _layernorm implementation forward result not equal F.layer_norm

    I have a try on your triton-transformer and test the layernorm module alone. It's very weird that the forward result is different while the backward result is equal.

    code: from triton_transformer.layernorm import layernorm import torch import torch.nn as nn

    torch.manual_seed(0) x = torch.randn(2,5).cuda() x.requires_grad_(True) dy = .1*torch.randn_like(x).cuda() dim = 5 norm = nn.LayerNorm(dim).cuda()

    y1 = layernorm(x, norm.weight, norm.bias, use_triton = True) y2 = layernorm(x, norm.weight, norm.bias, use_triton = False) print(y1, y2) print(torch.allclose(y1, y2))

    y1.backward(dy, retain_graph=True) dx_y1 = x.grad.clone()

    x.grad = None

    y2.backward(dy, retain_graph=True) dx_y2 = x.grad.clone() print(dx_y1, dx_y2) print(torch.allclose(dx_y1, dx_y2))

    result: `tensor([[ 0.9492, -0.0021, -0.9797, 0.4449, -0.4123], [-0.7624, 0.4399, 0.7299, -0.3091, -0.0983]], device='cuda:0', grad_fn=<_layernormBackward>) tensor([[ 1.4217, -0.0031, -1.4674, 0.6663, -0.6175], [-1.4342, 0.8276, 1.3732, -0.5815, -0.1850]], device='cuda:0', grad_fn=) False

    tensor([[-0.0706, 0.0288, -0.0813, 0.0446, 0.0785], [ 0.0218, -0.0152, 0.0141, -0.0522, 0.0315]], device='cuda:0') tensor([[-0.0706, 0.0288, -0.0813, 0.0446, 0.0785], [ 0.0218, -0.0152, 0.0141, -0.0522, 0.0315]], device='cuda:0') True`

    opened by Tengxu-Sun 1
  • Current state of benchmarking & contributing?

    Current state of benchmarking & contributing?

    Hey @lucidrains - hope you're doing well! I have some time to hack the next couple weeks, just wanted to get a sense of:

    • Current state of benchmarking (what Triton kernels provide how much lift, aggregate lift over a "vanilla Transformer implementation"
    • If there's anything I could help with, especially as I learn Triton!
    opened by siddk 0
  • Official layer norm added

    Official layer norm added

    Hi @lucidrains , in Triton layer norm was just added in examples, https://github.com/openai/triton/commit/d4baad426db72b83c5222e1c83c929c1860cae54 I tested it, it's twice as fast as Torch, often faster then Apex.

    I'm looking forward for your implementation of attention, so far the Torch implementation is the fastest with 12.3 / 14.5 (forw / back) vs the other Triton implementation in DeepSpeed which is 17.3/ 23.0 on my data.

    opened by olegklimov 2
Releases(0.1.1)
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
Codes for AAAI22 paper "Learning to Solve Travelling Salesman Problem with Hardness-Adaptive Curriculum"

Paper For more details, please see our paper Learning to Solve Travelling Salesman Problem with Hardness-Adaptive Curriculum which has been accepted a

14 Sep 30, 2022
Architecture Patterns with Python (TDD, DDD, EDM)

architecture-traning Architecture Patterns with Python (TDD, DDD, EDM) Chapter 5. 높은 기어비와 낮은 기어비의 TDD 5.2 도메인 계층 테스트를 서비스 계층으로 옮겨야 하는가? 도메인 계층 테스트 def

minsung sim 2 Mar 04, 2022
A PyTorch implementation of Multi-digit Number Recognition from Street View Imagery using Deep Convolutional Neural Networks

SVHNClassifier-PyTorch A PyTorch implementation of Multi-digit Number Recognition from Street View Imagery using Deep Convolutional Neural Networks If

Potter Hsu 182 Jan 03, 2023
Two-stage CenterNet

Probabilistic two-stage detection Two-stage object detectors that use class-agnostic one-stage detectors as the proposal network. Probabilistic two-st

Xingyi Zhou 1.1k Jan 03, 2023
Official repository for "Intriguing Properties of Vision Transformers" (2021)

Intriguing Properties of Vision Transformers Muzammal Naseer, Kanchana Ranasinghe, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, & Ming-Hsuan Yang P

Muzammal Naseer 155 Dec 27, 2022
Stratified Transformer for 3D Point Cloud Segmentation (CVPR 2022)

Stratified Transformer for 3D Point Cloud Segmentation Xin Lai*, Jianhui Liu*, Li Jiang, Liwei Wang, Hengshuang Zhao, Shu Liu, Xiaojuan Qi, Jiaya Jia

DV Lab 195 Jan 01, 2023
PyTorch implementation of the Value Iteration Networks (VIN) (NIPS '16 best paper)

Value Iteration Networks in PyTorch Tamar, A., Wu, Y., Thomas, G., Levine, S., and Abbeel, P. Value Iteration Networks. Neural Information Processing

LEI TAI 75 Nov 24, 2022
IOT: Instance-wise Layer Reordering for Transformer Structures

Introduction This repository contains the code for Instance-wise Ordered Transformer (IOT), which is introduced in the ICLR2021 paper IOT: Instance-wi

IOT 19 Nov 15, 2022
GyroSPD: Vector-valued Distance and Gyrocalculus on the Space of Symmetric Positive Definite Matrices

GyroSPD Code for the paper "Vector-valued Distance and Gyrocalculus on the Space of Symmetric Positive Definite Matrices" accepted at NeurIPS 2021. Re

Federico Lopez 12 Dec 12, 2022
SPRING is a seq2seq model for Text-to-AMR and AMR-to-Text (AAAI2021).

SPRING This is the repo for SPRING (Symmetric ParsIng aNd Generation), a novel approach to semantic parsing and generation, presented at AAAI 2021. Wi

Sapienza NLP group 98 Dec 21, 2022
The source code of the paper "SHGNN: Structure-Aware Heterogeneous Graph Neural Network"

SHGNN: Structure-Aware Heterogeneous Graph Neural Network The source code and dataset of the paper: SHGNN: Structure-Aware Heterogeneous Graph Neural

Wentao Xu 7 Nov 13, 2022
DilatedNet in Keras for image segmentation

Keras implementation of DilatedNet for semantic segmentation A native Keras implementation of semantic segmentation according to Multi-Scale Context A

303 Mar 15, 2022
[ICLR 2022 Oral] F8Net: Fixed-Point 8-bit Only Multiplication for Network Quantization

F8Net Fixed-Point 8-bit Only Multiplication for Network Quantization (ICLR 2022 Oral) OpenReview | arXiv | PDF | Model Zoo | BibTex PyTorch implementa

Snap Research 76 Dec 13, 2022
Released code for Objects are Different: Flexible Monocular 3D Object Detection, CVPR21

MonoFlex Released code for Objects are Different: Flexible Monocular 3D Object Detection, CVPR21. Work in progress. Installation This repo is tested w

Yunpeng 169 Dec 06, 2022
PyTorch Implementation of "Light Field Image Super-Resolution with Transformers"

LFT PyTorch implementation of "Light Field Image Super-Resolution with Transformers", arXiv 2021. [pdf]. Contributions: We make the first attempt to a

Squidward 62 Nov 28, 2022
A colab notebook for training Stylegan2-ada on colab, transfer learning onto your own dataset.

Stylegan2-Ada-Google-Colab-Starter-Notebook A no thrills colab notebook for training Stylegan2-ada on colab. transfer learning onto your own dataset h

Harnick Khera 66 Dec 16, 2022
6D Grasping Policy for Point Clouds

GA-DDPG [website, paper] Installation git clone https://github.com/liruiw/GA-DDPG.git --recursive Setup: Ubuntu 16.04 or above, CUDA 10.0 or above, py

Lirui Wang 48 Dec 21, 2022
Computationally Efficient Optimization of Plackett-Luce Ranking Models for Relevance and Fairness

Computationally Efficient Optimization of Plackett-Luce Ranking Models for Relevance and Fairness This repository contains the code used for the exper

H.R. Oosterhuis 28 Nov 29, 2022
Robust Lane Detection via Expanded Self Attention (WACV 2022)

Robust Lane Detection via Expanded Self Attention (WACV 2022) Minhyeok Lee, Junhyeop Lee, Dogyoon Lee, Woojin Kim, Sangwon Hwang, Sangyoun Lee Overvie

Min Hyeok Lee 18 Nov 12, 2022
Implementation of the ALPHAMEPOL algorithm, presented in Unsupervised Reinforcement Learning in Multiple Environments.

ALPHAMEPOL This repository contains the implementation of the ALPHAMEPOL algorithm, presented in Unsupervised Reinforcement Learning in Multiple Envir

3 Dec 23, 2021