Self-Attention Between Datapoints: Going Beyond Individual Input-Output Pairs in Deep Learning

Overview

Hi, good to see you here! 👋

Thanks for checking out the code for Non-Parametric Transformers (NPTs).

This codebase will allow you to reproduce experiments from the paper as well as use NPTs for your own research.

Abstract

We challenge a common assumption underlying most supervised deep learning: that a model makes a prediction depending only on its parameters and the features of a single input. To this end, we introduce a general-purpose deep learning architecture that takes as input the entire dataset instead of processing one datapoint at a time. Our approach uses self-attention to reason about relationships between datapoints explicitly, which can be seen as realizing non-parametric models using parametric attention mechanisms. However, unlike conventional non-parametric models, we let the model learn end-to-end from the data how to make use of other datapoints for prediction. Empirically, our models solve cross-datapoint lookup and complex reasoning tasks unsolvable by traditional deep learning models. We show highly competitive results on tabular data, early results on CIFAR-10, and give insight into how the model makes use of the interactions between points.
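
The core idea of the abstract is self-attention applied across the datapoint axis rather than within a single input. It can be illustrated in a few lines of NumPy. This is a simplified sketch under stated assumptions: a single attention head, random (untrained) projection matrices, and a plain feature matrix as input. The function name attention_between_datapoints is hypothetical, and the NPT model in this codebase contains further components not shown here.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max along the axis for numerical stability.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def attention_between_datapoints(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention across datapoints (sketch).

    X has shape (n, d), one row per datapoint. The attention matrix A has
    shape (n, n), so each output row mixes information from every row of X.
    """
    Q, K, V = X @ W_q, X @ W_k, X @ W_v          # each (n, h)
    scores = Q @ K.T / np.sqrt(K.shape[1])       # (n, n) datapoint-to-datapoint
    A = softmax(scores, axis=1)                  # each row of A sums to 1
    return A @ V, A                              # (n, h), (n, n)


rng = np.random.default_rng(0)
n, d, h = 5, 4, 3                                # datapoints, features, head dim
X = rng.normal(size=(n, d))
W_q, W_k, W_v = (rng.normal(size=(d, h)) for _ in range(3))  # untrained weights

out, A = attention_between_datapoints(X, W_q, W_k, W_v)
print(out.shape, A.shape)  # (5, 3) (5, 5)
```

Because the softmax weights are computed between rows, permuting the datapoints permutes the outputs in the same way, and changing one datapoint can change the representation of every other datapoint.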

Installation

Set up and activate the Python environment by running:

conda env create -f environment.yml
conda activate npt

For now, we recommend installing CUDA <= 10.2, as there is a known issue with CUDA >= 11.0.

If you are running this on a system without a GPU, run the commands above with environment_no_gpu.yml in place of environment.yml.

Examples

We now give some basic examples of running NPT.

NPT downloads all supported datasets automatically, so you don't need to worry about that.

We use wandb to log experimental results. Wandb allows us to conveniently track run progress online. If you do not want wandb enabled, you can run wandb off in the shell where you execute NPT.

For example, run this to explore NPT with the default configuration on the Breast Cancer dataset:

python run.py --data_set breast-cancer

Another example: a run on the poker-hand dataset may look like this:

python run.py --data_set poker-hand \
--exp_batch_size 4096 \
--exp_print_every_nth_forward 100
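
If you want to script several runs (for instance, one per dataset), the command lines above can be assembled programmatically. This is a convenience sketch, not part of the NPT codebase: build_npt_command and launch are hypothetical helpers, and flag names are forwarded unchecked, so consult NPT/configs.py for valid arguments.

```python
import shlex
import subprocess


def build_npt_command(data_set, **flags):
    """Assemble an argv list for run.py (hypothetical helper, not part of NPT).

    Each keyword argument becomes a "--name value" pair, mirroring the flags
    used in the examples above. Flag names are passed through unchecked.
    """
    argv = ["python", "run.py", "--data_set", data_set]
    for name, value in flags.items():
        argv += [f"--{name}", str(value)]
    return argv


def launch(argv, dry_run=True):
    """Print the command (dry run) or execute it from the repository root."""
    if dry_run:
        print(shlex.join(argv))
        return None
    # check=True raises CalledProcessError if the run exits with an error.
    return subprocess.run(argv, check=True)


cmd = build_npt_command(
    "poker-hand", exp_batch_size=4096, exp_print_every_nth_forward=100
)
launch(cmd)
# prints: python run.py --data_set poker-hand --exp_batch_size 4096 --exp_print_every_nth_forward 100
```

Passing dry_run=False executes the command via subprocess.run; keeping the arguments as a list avoids shell quoting issues.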

You can find all config arguments and their descriptions in NPT/configs.py, or by running python run.py --help.

In scripts/, we provide the runs and hyperparameter configurations used for the experiments presented in the paper.

We hope you enjoy using the code, and please feel free to reach out with any questions 😊

Citation

If you find this code helpful for your work, please cite our paper as:

@article{kossen2021self,
  title={Self-Attention Between Datapoints: Going Beyond Individual Input-Output Pairs in Deep Learning},
  author={Kossen, Jannik and Band, Neil and Gomez, Aidan N. and Lyle, Clare and Rainforth, Tom and Gal, Yarin},
  journal={arXiv:2106.02584},
  year={2021}
}

Owner: OATML (Oxford Applied and Theoretical Machine Learning Group)