[NeurIPS 2020] Official Implementation: "SMYRF: Efficient Attention using Asymmetric Clustering".

SMYRF: Efficient attention using asymmetric clustering

Get started:

Colab

Abstract

We propose a novel type of balanced clustering algorithm to approximate attention. Attention complexity is reduced from O(N^2) to O(N log N), where N is the sequence length. Our algorithm, SMYRF, uses Locality Sensitive Hashing (LSH) in a novel way by defining new asymmetric transformations and an adaptive scheme that produces balanced clusters. The biggest advantage of SMYRF is that it can be used as a drop-in replacement for dense attention layers without any retraining. In contrast, prior fast attention methods impose constraints (e.g. tied queries and keys) and require re-training from scratch. We apply our method to pre-trained state-of-the-art Natural Language Processing and Computer Vision models and report significant memory and speed benefits. Notably, SMYRF-BERT slightly outperforms BERT on GLUE while using 50% less memory. We also show that SMYRF can be used interchangeably with dense attention before and after training. Finally, we use SMYRF to train GANs with attention in high resolutions. Using a single TPU, we train BigGAN on CelebA-HQ, with attention at resolution 128x128 and 256x256, capable of generating realistic human faces.

Authors: Giannis Daras, Nikita Kitaev, Augustus Odena, Alexandros G. Dimakis
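
To make the abstract's "asymmetric transformations" and "balanced clusters" more concrete, here is a minimal pure-Python sketch. It is an illustration under stated assumptions, not the code in this repository: the function names are hypothetical, the transformation is one construction that maps queries and keys onto a common sphere so that squared Euclidean distance equals a constant minus twice the inner product, and the clustering step (sort by a single random projection, then cut into equal chunks) is a simplification of the adaptive scheme described in the paper.

```python
import math
import random


def norm_sq(v):
    """Squared Euclidean norm of a vector given as a list of floats."""
    return sum(x * x for x in v)


def transform_queries(queries, keys):
    """Hypothetical F(q) = [q, 0, sqrt(M - |q|^2)], M = max|q|^2 + max|k|^2."""
    m = max(map(norm_sq, queries)) + max(map(norm_sq, keys))
    return [q + [0.0, math.sqrt(m - norm_sq(q))] for q in queries]


def transform_keys(queries, keys):
    """Hypothetical G(k) = [k, sqrt(M - |k|^2), 0], with the same M as above."""
    m = max(map(norm_sq, queries)) + max(map(norm_sq, keys))
    return [k + [math.sqrt(m - norm_sq(k)), 0.0] for k in keys]


def balanced_clusters(vectors, direction, cluster_size):
    """Sort indices by a 1-D projection, then cut into equal-size chunks."""
    scores = [sum(a * b for a, b in zip(v, direction)) for v in vectors]
    order = sorted(range(len(vectors)), key=scores.__getitem__)
    return [order[i:i + cluster_size] for i in range(0, len(order), cluster_size)]


rng = random.Random(0)
dim, n, cluster_size = 4, 8, 4
queries = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(n)]
keys = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(n)]
fq = transform_queries(queries, keys)
gk = transform_keys(queries, keys)

# After the transformation, |F(q) - G(k)|^2 = 2M - 2 q.k, so a nearest
# neighbour in the transformed space is a maximum inner product neighbour.
m = max(map(norm_sq, queries)) + max(map(norm_sq, keys))
dist_sq = norm_sq([a - b for a, b in zip(fq[0], gk[0])])
dot = sum(a * b for a, b in zip(queries[0], keys[0]))
assert abs(dist_sq - (2 * m - 2 * dot)) < 1e-9

# One shared random direction hashes both sets; the i-th query cluster
# is paired with the i-th key cluster.
direction = [rng.gauss(0, 1) for _ in range(dim + 2)]
query_clusters = balanced_clusters(fq, direction, cluster_size)
key_clusters = balanced_clusters(gk, direction, cluster_size)
pairs = list(zip(query_clusters, key_clusters))
```

Because every cluster has exactly `cluster_size` members, the per-cluster attention blocks all have the same shape, which is what makes the approach batch-friendly.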

Results

Memory-quality trade-off

GLUE benchmark

| Model | Avg. | # | C | CoLA | MNLI-m/mm | MRPC | QNLI | QQP | RTE | SST-2 | STS-B |
|---|---|---|---|---|---|---|---|---|---|---|---|
| BERT128 | 82.69 | 1 | 1 | 57.83 | 84.43/84.68 | 88.41 | 91.31 | 89.70 | 65.70 | 93.46 | 88.73 |
| SMYRF-BERT2x32 | 82.98 | 2 | 32 | 58.79 | 83.76/84.27 | 87.69 | 91.14 | 89.72 | 68.59 | 93.23 | 89.65 |
| SMYRF-BERT2x16 | 81.74 | 2 | 16 | 58.90 | 82.86/83.49 | 85.72 | 89.53 | 89.33 | 64.98 | 93.12 | 87.75 |
| BERT64 | 81.57 | 1 | 64 | 58.80 | 82.34/82.47 | 87.02 | 90.48 | 89.69 | 61.73 | 93.00 | 88.64 |
| BERT32 | 73.56 | 1 | 32 | 56.40 | 64.51/63.41 | 77.89 | 79.81 | 88.59 | 55.23 | 92.66 | 83.53 |

Interchangeability of SMYRF and dense attention

Results on the IMDB dataset. For the SMYRF models, using dense attention at inference consistently improves accuracy, nearly matching the performance of the fully dense models.

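| Model | Memory | SMYRF Inference | Accuracy |
|---|---|---|---|
| RoBERTa | 100% | no | 94.96% |
| SMYRF-RoBERTa | 50% | yes | 93.72% |
| SMYRF-RoBERTa | 50% | no | 94.62% |
| BERT | 100% | no | 94.12% |
| SMYRF-BERT | 50% | yes | 92.64% |
| SMYRF-BERT | 50% | no | 93.54% |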
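
One way to see why the two attention modes can be swapped is that clustering changes only which query-key pairs are scored, not any learned parameter. The sketch below is a hedged illustration in pure Python, not the repository's layer: `attend`, `attention`, and the `clusters` format (a list of `(query_indices, key_indices)` pairs) are hypothetical names chosen for this example.

```python
import math


def attend(q, keys, values):
    """Softmax attention of one query over the given keys and values."""
    logits = [sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(x - top) for x in logits]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]


def attention(queries, keys, values, clusters=None):
    """Dense attention if clusters is None, else attention within clusters."""
    if clusters is None:
        # Dense attention is the special case of a single cluster.
        clusters = [(range(len(queries)), range(len(keys)))]
    out = [None] * len(queries)
    for q_idx, k_idx in clusters:
        ks = [keys[j] for j in k_idx]
        vs = [values[j] for j in k_idx]
        for i in q_idx:
            out[i] = attend(queries[i], ks, vs)
    return out


queries = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
keys = [[1.0, 0.0], [0.8, 0.2], [0.0, 1.0], [0.2, 0.8]]
values = [[1.0], [2.0], [3.0], [4.0]]

dense = attention(queries, keys, values)
clustered = attention(queries, keys, values,
                      clusters=[([0, 1], [0, 1]), ([2, 3], [2, 3])])
```

Both calls consume the same `queries`, `keys`, and `values`, so a model can be run with either mode at training or inference time; the outputs differ only by the attention mass that the clustered variant leaves out.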

Smyrf-BigGAN training on Celeba-HQ-128

Generated faces by a Smyrf-BigGAN trained on 128x128 resolution with attention at 128x128, using 50% of dense memory.

Results after 120k iterations:

| Model | Resolution | Attention | # | C | FID |
|---|---|---|---|---|---|
| BigGAN | 128x128 | 64x64 | 1 | 4096 | 26.06 |
| Smyrf-BigGAN | 128x128 | 128x128 | 4 | 2048 | 25.03 |

where # denotes number of hashes and C number of queries per cluster.
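
As a back-of-the-envelope check of the memory figures quoted in this section, the sketch below counts query-key scores. It rests on assumptions that are not stated in the README: that memory is dominated by the attention scores, that each cluster holds as many keys as queries, and that the BERT rows use a sequence length of 128. The function names are hypothetical.

```python
def attention_scores(seq_len, hashes=1, cluster_size=None):
    """Number of query-key scores computed.

    Dense attention (cluster_size=None) scores every pair: seq_len ** 2.
    Clustered attention scores cluster_size keys per query, once per hash.
    """
    if cluster_size is None:
        return seq_len ** 2
    return hashes * seq_len * cluster_size


def relative_memory(seq_len, hashes, cluster_size):
    """Clustered score count as a fraction of the dense score count."""
    return attention_scores(seq_len, hashes, cluster_size) / attention_scores(seq_len)


bert_2x32 = relative_memory(128, hashes=2, cluster_size=32)          # 0.5
bert_2x16 = relative_memory(128, hashes=2, cluster_size=16)          # 0.25
biggan_128 = relative_memory(128 * 128, hashes=4, cluster_size=2048)  # 0.5
```

Under these assumptions, 2 hashes with 32 queries per cluster and 4 hashes with 2048 queries per cluster at 128x128 attention both come out at half the dense score count, in line with the 50% figures reported above.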

What's here

The code hosted in this repository is the one we used to run all the experiments in the paper. Get started:

Colab

For a deeper dive, look at the examples/ folder where we have code for pre-training SMYRF-BigGAN, sampling from a pre-trained BigGAN with SMYRF, finetuning state-of-the-art NLP models with SMYRF and a lot more.

Acknowledgments

We would like to wholeheartedly thank the TensorFlow Research Cloud (TFRC) program that gave us access to Cloud TPUs and GCP credits to train our models.

The code for the NLP experiments is exclusively based on the HuggingFace transformers library. We are very grateful to the authors of the library for their work.

The code for the CV experiments is based on the PyTorch implementation of BigGAN available in this url. The code has been expanded to support training on TPUs. Again, we want to thank the author for open-sourcing this implementation.


Comments
  • Auto-regressive

    Hi Giannis!

    Thanks for the great paper! I am interested in your asymmetric LSH, as I think having separate query / key space (as opposed to shared QK as in Reformer) will bring performance improvements in LSH-based attention.

    I saw that you recommended to a previous user to use this form of clustering for the auto-regressive case, and just wanted to probe if you had considered the scenario where a bucket of queries do not get matched with any keys from the past at all. This was an issue I had with trying to make separate QK space work with routing transformer, but just wondering if you had identified and found a solution to this problem.

    Phil

    opened by lucidrains 2
  • Logging and scoring

    Currently logging and scoring is disabled for TPU BigGAN for maximum efficiency. We can probably re-write the logger and scorer to lower their performance bottleneck by converting most cpu materializations to XLA ops.

    bug example 
    opened by giannisdaras 0
  • Ema not working on TPU

    Exponential moving average on weights of G is not working on TPUs. The problem is related to the loading of the state dict: https://github.com/ajbrock/BigGAN-PyTorch/blob/master/utils.py#L614

    For now, we disable ema.

    bug example 
    opened by giannisdaras 0
Releases(1.0)
Owner
Giannis Daras
Machine Learning Researcher. Ph.D. student, UT Austin.