A Domain-Agnostic Benchmark for Self-Supervised Learning

DABS: A Domain Agnostic Benchmark for Self-Supervised Learning

This repository contains the code for DABS, a benchmark for domain-agnostic self-supervised learning algorithms. The basic components of the benchmark can be found in datasets, encoders, and algorithms. Training is implemented with the PyTorch Lightning framework, logging with Weights and Biases, and configuration management with Hydra.

Usage

We provide support for Python >= 3.7. Install requirements with

python -m pip install -r requirements.txt

For instructions on how to install PyTorch versions compatible with your CUDA versions, see pytorch.org.

Datasets

We provide a set of dataset implementations (in src/datasets) from image, text, speech, sensor, medical imaging, and image-text domains. Preprocessing on these datasets is minimal and hard-coded: simple resizing (for images) and truncation (for text and audio). These operations should not be changed, so that results remain comparable across users of the benchmark.

See conf/datasets/*.yaml for all dataset configs, including the loss, metrics, and batch size used for each dataset.

Almost all datasets will download automatically when the dataset class is instantiated. The exceptions are the CheXpert, ImageNet, and CU Birds datasets, where manual registration or download is required. See the respective dataset files for specific instructions.

| Pretraining Dataset (unlabeled) | Transfer Dataset (labeled) |
|---|---|
| CIFAR10 | Aircraft, CIFAR10, CU Birds, DTD, Traffic Sign, VGG Flower |
| PAMAP2 | PAMAP2 |
| MSCOCO | MSCOCO (mismatched detection), VQA (Binary classification) |
| Wikitext-103 | GLUE (10 Tasks) |
| mC4 | PAWS-X (7 Tasks) |
| CheXpert | CheXpert (atelectasis, cardiomegaly, consolidation, edema, and pleural effusion), ChestX-ray8 (atelectasis, cardiomegaly, effusion, infiltration, mass, nodule, pneumonia, pneumothorax) |
| LibriSpeech | Audio MNIST, Fluent Speech (Action, Object, Location), Google Speech Commands, LibriSpeech, VoxCeleb1 |

Pretraining

During the pretraining phase, self-supervised encoders are trained to learn good representations from unlabeled data. We currently support seven datasets for pretraining, one for each domain: MS COCO, ImageNet, CheXpert, PAMAP2, mC4, WikiText-103, and LibriSpeech. If the pretraining dataset has associated labels, an online linear evaluator is trained jointly with the encoder to provide a rough proxy for transfer performance.

Run pretraining with commands like

python pretrain.py exp.name=<experiment-name> dataset=<dataset> algorithm=<algorithm>

Each dataset and encoder has its own config file. For example, to train a Transformer on the CheXpert dataset with the e-Mix algorithm, run

python pretrain.py exp.name=emix-chexpert encoder=transformer dataset=chexpert algorithm=emix

See conf/pretrain.yaml for all pretraining configuration fields.
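When launching several pretraining runs, it can help to build the override list programmatically. The following is a minimal sketch, not part of the repository: the helper `pretrain_command` and the `<algorithm>-<dataset>` naming scheme are assumptions for illustration, and only the config names that appear in the command above (`chexpert`, `emix`, `transformer`) are used. Extend the lists with the names you find under conf/.

```python
import itertools


def pretrain_command(dataset, algorithm, encoder="transformer"):
    """Build the argv list for one pretraining run (hypothetical helper)."""
    overrides = {
        "exp.name": f"{algorithm}-{dataset}",  # naming scheme is an assumption
        "encoder": encoder,
        "dataset": dataset,
        "algorithm": algorithm,
    }
    return ["python", "pretrain.py"] + [f"{key}={value}" for key, value in overrides.items()]


# Only names confirmed by the example command above; add others from conf/.
datasets = ["chexpert"]
algorithms = ["emix"]

for dataset, algorithm in itertools.product(datasets, algorithms):
    print(" ".join(pretrain_command(dataset, algorithm)))
```

The returned list can be passed directly to `subprocess.run` if you prefer to launch the runs from Python rather than copy the printed commands.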

For more information on the datasets, encoders, and algorithms, see the following sections.

| Pretraining Dataset | Modality | Label type (unused) | Input Type |
|---|---|---|---|
| CIFAR10 | Natural images | Single label | 2d |
| PAMAP2 | Sensor | Single label | 2d |
| MSCOCO | Captioned images | Single label | 2d + tokens |
| WikiText-103 | English Text | No label | tokens |
| mC4 | Multilingual Text | No label | tokens |
| CheXpert | Medical images | Multi label | 2d |
| LibriSpeech | Speech | No label | 2d |

Transfer Learning

After pretraining, a small linear classifier is trained on top of the frozen encoder. Run transfer learning from a randomly initialized encoder with

python transfer.py exp.name=<experiment-name> dataset=<dataset> ckpt=null

To evaluate a pretrained encoder instead, replace null with the path to its checkpoint. See conf/transfer.yaml for all transfer learning configuration fields.
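The idea of training a linear classifier on a frozen encoder can be illustrated with a small, dependency-free sketch. This is not the repository's implementation (which uses PyTorch Lightning): the "encoder" here is a hypothetical fixed linear map with a ReLU, the toy data is made up, and the head is plain logistic regression trained by full-batch gradient descent. The point is only that the encoder weights are never touched while the classifier weights are learned.

```python
import math

# Hypothetical frozen "encoder": a fixed linear map followed by ReLU.
# In the benchmark this role is played by the (pretrained or random) Transformer.
ENCODER_WEIGHTS = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, -1.0]]


def encode(x):
    """Map an input vector to features; the encoder weights are never updated."""
    return [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in ENCODER_WEIGHTS]


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def train_linear_probe(features, labels, epochs=500, lr=0.5):
    """Fit a binary logistic-regression head with full-batch gradient descent."""
    n, dim = len(features), len(features[0])
    weights, bias = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for f, y in zip(features, labels):
            err = sigmoid(sum(w * v for w, v in zip(weights, f)) + bias) - y
            grad_b += err
            for j, v in enumerate(f):
                grad_w[j] += err * v
        weights = [w - lr * g / n for w, g in zip(weights, grad_w)]
        bias -= lr * grad_b / n
    return weights, bias


def predict(weights, bias, f):
    """Return the predicted class (0 or 1) for one feature vector."""
    return 1 if sum(w * v for w, v in zip(weights, f)) + bias > 0 else 0


# Toy labeled transfer set (made up for illustration).
inputs = [(-1.0, -1.0), (-1.5, -0.5), (-0.5, -1.5), (1.0, 1.0), (1.5, 0.5), (0.5, 1.5)]
labels = [0, 0, 0, 1, 1, 1]

features = [encode(x) for x in inputs]  # the encoder is used for inference only
weights, bias = train_linear_probe(features, labels)
accuracy = sum(predict(weights, bias, f) == y for f, y in zip(features, labels)) / len(labels)
print(f"train accuracy: {accuracy:.2f}")
```

Because the features are computed once and only `weights` and `bias` are updated, the quality of the result depends entirely on how linearly separable the encoder's representations are, which is what the transfer stage is meant to measure.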

| Dataset | Modality | Label type | Evaluation metric | Input Type |
|---|---|---|---|---|
| Aircraft | Natural images | Single label | Accuracy | 2d |
| CU Birds | Natural images | Single label | Accuracy | 2d |
| DTD | Natural images | Single label | Accuracy | 2d |
| Traffic Sign | Natural images | Single label | Accuracy | 2d |
| VGG Flower | Natural images | Single label | Accuracy | 2d |
| PAMAP2 | Sensor | Single label | Accuracy | 2d |
| MS COCO | Captioned images | Binary label | Accuracy | 2d + tokens |
| VQA | Captioned images | Binary label | Accuracy | 2d + tokens |
| CheXpert | Medical images | Multi label | AUROC | 2d |
| ChestX-ray8 | Medical images | Multi label | AUROC | 2d |
| PAWS-X | Multilingual Text | Binary label | Accuracy | tokens |
| COLA | English Text | Binary label | Pearson correlation | tokens |
| MNLI Matched | English Text | Single label | Accuracy | tokens |
| MNLI Mismatched | English Text | Single label | Accuracy | tokens |
| MRPC | English Text | Binary label | Accuracy | tokens |
| QNLI | English Text | Binary label | Accuracy | tokens |
| QQP | English Text | Binary label | Accuracy | tokens |
| RTE | English Text | Binary label | Accuracy | tokens |
| SST2 | English Text | Binary label | Accuracy | tokens |
| STSB | English Text | Regression | Spearman correlation | tokens |
| WNLI | English Text | Binary label | Accuracy | tokens |
| Audio MNIST | Speech | Single label | Accuracy | 2d |
| Fluent Speech | Speech | Single label | Accuracy | 2d |
| Google Speech Commands | Speech | Single label | Accuracy | 2d |
| LibriSpeech | Speech | Single label | Accuracy | 2d |
| VoxCeleb1 | Speech | Single label | Accuracy | 2d |

Encoders

A domain-agnostic SSL method should have an encoder which remains as constant as possible across domains. We provide a general transformer encoder baseline (in src/encoders). The transformer operates on a sequence of vectors that are produced by a small set of embedding modules (e.g. patch or token embeddings).
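To make the role of the embedding modules concrete, here is a minimal plain-Python sketch of how a 2d input and a token input can both be turned into a sequence of vectors of the same width, which a shared encoder could then consume. The function names, the patch size, and the toy weights are assumptions for illustration and do not mirror the code in src/encoders.

```python
def patchify(grid, patch):
    """Split a 2D grid (list of rows) into flattened, non-overlapping patch x patch blocks."""
    rows, cols = len(grid), len(grid[0])
    if rows % patch or cols % patch:
        raise ValueError("grid dimensions must be divisible by the patch size")
    patches = []
    for r in range(0, rows, patch):
        for c in range(0, cols, patch):
            patches.append([grid[r + dr][c + dc] for dr in range(patch) for dc in range(patch)])
    return patches


def embed_patches(grid, patch, weights):
    """Patch embedding: project each flattened patch with a shared linear map."""
    return [[sum(w * v for w, v in zip(row, p)) for row in weights] for p in patchify(grid, patch)]


def embed_tokens(token_ids, table):
    """Token embedding: look up one vector per token id."""
    return [table[t] for t in token_ids]


# Toy 4x4 "image" with values 0..15 and a projection from 4 patch values to 3 dimensions.
image = [[4 * r + c for c in range(4)] for r in range(4)]
projection = [[1, 0, 0, 0], [0, 1, 0, 0], [1, 1, 1, 1]]
image_sequence = embed_patches(image, patch=2, weights=projection)

# Toy vocabulary of 3 tokens, each mapped to a 3-dimensional vector.
table = [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
text_sequence = embed_tokens([2, 0, 1], table)

# Both inputs are now sequences of 3-dimensional vectors.
print(len(image_sequence), len(image_sequence[0]))
print(len(text_sequence), len(text_sequence[0]))
```

Only the embedding step differs between the two inputs; everything downstream of the sequence of vectors can stay the same across domains.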

Pretraining algorithms

The pretraining algorithm is the framework and objective that the encoder is trained with. Algorithms such as SimCLR, BYOL, and MoCo are domain-specific rather than domain-agnostic, because they depend on vision-specific augmentations. We provide our own domain-agnostic implementations of recent algorithms: e-Mix (a generalization of i-Mix) and Shuffled Embedding Detection (ShED; a generalization of ELECTRA). ShED randomly permutes a subset of the input embeddings and trains the model to identify which embeddings were permuted.
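The corruption step described for ShED can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not the repository's implementation: the function name, the 15% default fraction, and the use of a cyclic shift over the chosen positions (so that every chosen position receives an embedding from a different position) are all choices made for this example.

```python
import random


def shuffle_embeddings(embeddings, fraction=0.15, rng=None):
    """Permute a random subset of positions and return (corrupted, labels).

    labels[i] is 1 if position i now holds an embedding taken from another
    position, else 0. The fraction and the cyclic shift are assumptions.
    """
    rng = rng or random.Random()
    n = len(embeddings)
    corrupted = list(embeddings)
    labels = [0] * n
    if n < 2:
        return corrupted, labels  # nothing to permute
    k = min(n, max(2, round(fraction * n)))  # at least two positions are needed
    chosen = rng.sample(range(n), k)
    # Rotate the chosen positions by one: each source moves to the next chosen slot.
    for src, dst in zip(chosen, chosen[1:] + chosen[:1]):
        corrupted[dst] = embeddings[src]
        labels[dst] = 1
    return corrupted, labels


# Toy sequence of ten 1-dimensional "embeddings".
sequence = [[float(i)] for i in range(10)]
corrupted, labels = shuffle_embeddings(sequence, fraction=0.3, rng=random.Random(0))
print(labels)
```

The model would then be trained to predict `labels` from `corrupted`, for example with a per-position binary classification loss. Note that if two chosen positions happen to hold identical embeddings, a position is labeled as permuted even though its content did not change; a real implementation would need to decide how to treat that case.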

Results

Below are results for algorithms trained on each dataset in DABS. The baseline performance is obtained via a randomly initialized encoder.

| Pretrain Dataset | Transfer Dataset | Encoder | Baseline Performance | e-Mix Performance | ShED Performance |
|---|---|---|---|---|---|
| ImageNet | CIFAR10 | Transformer | 24.20% | 39.43% | 39.63% |
| ImageNet | CU Birds | Transformer | 1.62% | 3.86% | 2.95% |
| ImageNet | VGG Flowers | Transformer | 9.03% | 25.96% | 13.03% |
| ImageNet | DTD | Transformer | 7.39% | 8.83% | 18.35% |
| ImageNet | Traffic Sign | Transformer | 14.33% | 65.07% | 27.51% |
| ImageNet | Aircraft | Transformer | 2.70% | 10.15% | 5.60% |
| PAMAP2 | PAMAP2 | Transformer | 69.81% | 79.48% | 88.69% |
| MSCOCO | VQA | Transformer | 57.50% | 48.90% | 54.30% |
| CheXpert | CheXpert | Transformer | 68.14% | 72.40% | 72.40% |
| CheXpert | ChestX-ray8 | Transformer | 57.00% | 63.00% | 63.70% |
| Wikitext-103 | GLUE (average) | Transformer | 42.29% | 44.08% | 48.37% |
| mC4 | PAWS-X (average) | Transformer | 58.11% | 56.16% | 59.91% |
| LibriSpeech | Audio MNIST | Transformer | 33.13% | 80.35% | 67.33% |
| LibriSpeech | Fluent Locations | Transformer | 62.09% | 60.93% | 60.24% |
| LibriSpeech | Fluent Actions | Transformer | 26.15% | 29.87% | 30.53% |
| LibriSpeech | Fluent Objects | Transformer | 30.13% | 39.89% | 39.36% |
| LibriSpeech | Google Speech Commands | Transformer | 4.87% | 19.22% | 20.73% |
| LibriSpeech | LibriSpeech | Transformer | 17.12% | 60.18% | 34.77% |
| LibriSpeech | VoxCeleb1 | Transformer | 0.59% | 2.43% | 2.81% |
Owner
Alex Tamkin
PhD at @stanfordnlp