[NeurIPS 2020] Official Implementation: "SMYRF: Efficient Attention using Asymmetric Clustering".

SMYRF: Efficient attention using asymmetric clustering

Get started:

Colab

Abstract

We propose a novel type of balanced clustering algorithm to approximate attention. Attention complexity is reduced from O(N^2) to O(N log N), where N is the sequence length. Our algorithm, SMYRF, uses Locality Sensitive Hashing (LSH) in a novel way by defining new asymmetric transformations and an adaptive scheme that produces balanced clusters. The biggest advantage of SMYRF is that it can be used as a drop-in replacement for dense attention layers without any retraining. In contrast, prior fast attention methods impose constraints (e.g. tied queries and keys) and require re-training from scratch. We apply our method to pre-trained state-of-the-art Natural Language Processing and Computer Vision models and report significant memory and speed benefits. Notably, SMYRF-BERT slightly outperforms BERT on GLUE while using 50% less memory. We also show that SMYRF can be used interchangeably with dense attention before and after training. Finally, we use SMYRF to train GANs with attention at high resolutions. Using a single TPU, we train BigGAN on CelebA-HQ, with attention at resolutions 128x128 and 256x256, capable of generating realistic human faces.
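
To make the idea of attention within balanced clusters concrete, here is a minimal NumPy sketch. It is not the code in this repository: it assumes a single hashing round, replaces the asymmetric transformations and LSH scheme with one random projection, and forms equal-sized clusters by sorting the projected values and cutting them into chunks. The function names and parameters are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def dense_attention(q, k, v):
    # Reference: every query attends to every key (N x N scores).
    return softmax(q @ k.T) @ v

def clustered_attention(q, k, v, cluster_size, rng):
    # Assumption: a single random projection stands in for the LSH hash.
    n, d = q.shape
    direction = rng.standard_normal(d)
    # Sorting by hash value and chunking yields equal-sized (balanced) clusters.
    q_order = np.argsort(q @ direction)
    k_order = np.argsort(k @ direction)
    out = np.empty_like(v)
    for start in range(0, n, cluster_size):
        qi = q_order[start:start + cluster_size]
        ki = k_order[start:start + cluster_size]
        # Queries in a cluster attend only to keys in the matching cluster.
        out[qi] = softmax(q[qi] @ k[ki].T) @ v[ki]
    return out

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((16, 8)) for _ in range(3))
approx = clustered_attention(q, k, v, cluster_size=4, rng=rng)
exact = dense_attention(q, k, v)
print(approx.shape, np.abs(approx - exact).mean())
```

When `cluster_size` equals the sequence length there is a single cluster and the sketch reduces to dense attention; smaller clusters trade accuracy for fewer attention scores.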

Authors: Giannis Daras, Nikita Kitaev, Augustus Odena, Alexandros G. Dimakis

Results

Memory-quality trade-off

GLUE benchmark

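| Model | Avg. | # | C | CoLA | MNLI-m/mm | MRPC | QNLI | QQP | RTE | SST-2 | STS-B |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| BERT (128) | 82.69 | 1 | 1 | 57.83 | 84.43/84.68 | 88.41 | 91.31 | 89.70 | 65.70 | 93.46 | 88.73 |
| SMYRF-BERT (2x32) | 82.98 | 2 | 32 | 58.79 | 83.76/84.27 | 87.69 | 91.14 | 89.72 | 68.59 | 93.23 | 89.65 |
| SMYRF-BERT (2x16) | 81.74 | 2 | 16 | 58.90 | 82.86/83.49 | 85.72 | 89.53 | 89.33 | 64.98 | 93.12 | 87.75 |
| BERT (64) | 81.57 | 1 | 64 | 58.80 | 82.34/82.47 | 87.02 | 90.48 | 89.69 | 61.73 | 93.00 | 88.64 |
| BERT (32) | 73.56 | 1 | 32 | 56.40 | 64.51/63.41 | 77.89 | 79.81 | 88.59 | 55.23 | 92.66 | 83.53 |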

Interchangeability of SMYRF and dense attention

Results on the IMDB dataset. For SMYRF-trained models, using dense attention at inference consistently improves accuracy, nearly matching the performance of the fully dense model.

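| Model | Memory | SMYRF Inference | Accuracy |
| --- | --- | --- | --- |
| RoBERTa | 100% | No | 94.96% |
| SMYRF-RoBERTa | 50% | Yes | 93.72% |
| SMYRF-RoBERTa | 50% | No | 94.62% |
| BERT | 100% | No | 94.12% |
| SMYRF-BERT | 50% | Yes | 92.64% |
| SMYRF-BERT | 50% | No | 93.54% |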

Smyrf-BigGAN training on CelebA-HQ-128

Faces generated by a Smyrf-BigGAN trained at 128x128 resolution with attention at 128x128, using 50% of the memory of dense attention.

Results after 120k iterations:

| Model | Resolution | Attention | # | C | FID |
| --- | --- | --- | --- | --- | --- |
| BigGAN | 128x128 | 64x64 | 1 | 4096 | 26.06 |
| Smyrf-BigGAN | 128x128 | 128x128 | 4 | 2048 | 25.03 |

where # denotes number of hashes and C number of queries per cluster.
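
As a rough check on how # and C relate to memory, the following sketch counts attention scores only (not total model memory). It assumes that each cluster pairs C queries with C keys and that every hashing round re-clusters the full sequence; the function name is hypothetical. Under those assumptions the fraction of dense attention scores is # * C / N, which agrees with the 50% figure quoted above for attention at 128x128.

```python
def attention_memory_fraction(seq_len: int, num_hashes: int, cluster_size: int) -> float:
    """Fraction of dense attention scores stored by clustered attention.

    Assumes each cluster pairs `cluster_size` queries with `cluster_size` keys
    and that each hashing round re-clusters the whole sequence.
    """
    num_clusters = seq_len / cluster_size
    clustered_scores = num_hashes * num_clusters * cluster_size ** 2
    dense_scores = seq_len ** 2
    return clustered_scores / dense_scores

# BigGAN row: dense attention at 64x64, i.e. one cluster holding all 4096 positions.
print(attention_memory_fraction(64 * 64, num_hashes=1, cluster_size=4096))    # 1.0
# Smyrf-BigGAN row: attention at 128x128 with 4 hashes and 2048 queries per cluster.
print(attention_memory_fraction(128 * 128, num_hashes=4, cluster_size=2048))  # 0.5
```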

What's here

This repository hosts the code we used to run all the experiments in the paper. Get started:

Colab

For a deeper dive, see the examples/ folder, which contains code for pre-training SMYRF-BigGAN, sampling from a pre-trained BigGAN with SMYRF, fine-tuning state-of-the-art NLP models with SMYRF, and more.

Acknowledgments

We would like to wholeheartedly thank the TensorFlow Research Cloud (TFRC) program that gave us access to Cloud TPUs and GCP credits to train our models.

The code for the NLP experiments is exclusively based on the HuggingFace transformers library. We are very grateful to the authors of the library for their work.

The code for the CV experiments is based on the PyTorch implementation of BigGAN available at https://github.com/ajbrock/BigGAN-PyTorch. The code has been extended to support training on TPUs. Again, we want to thank the author for open-sourcing this implementation.

Comments
  • Auto-regressive

    Hi Giannis!

    Thanks for the great paper! I am interested in your asymmetric LSH, as I think having separate query / key space (as opposed to shared QK as in Reformer) will bring performance improvements in LSH-based attention.

    I saw that you recommended this form of clustering to a previous user for the auto-regressive case, and just wanted to ask whether you had considered the scenario where a bucket of queries does not get matched with any keys from the past at all. This was an issue I had when trying to make a separate QK space work with the routing transformer, so I am wondering whether you had identified this problem and found a solution.

    Phil

    opened by lucidrains 2
  • Logging and scoring

    Currently, logging and scoring are disabled for TPU BigGAN for maximum efficiency. We can probably rewrite the logger and scorer to reduce their overhead by converting most CPU materializations to XLA ops.

    bug example 
    opened by giannisdaras 0
  • EMA not working on TPU

    The exponential moving average (EMA) of the weights of G is not working on TPUs. The problem is related to the loading of the state dict: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/utils.py#L614

    For now, we disable EMA.

    bug example 
    opened by giannisdaras 0
Releases(1.0)
Owner
Giannis Daras
Machine Learning Researcher. Ph.D. student, UT Austin.