PyTorch implementation of the Flow Gaussian Mixture Model (FlowGMM) model from our paper


Flow Gaussian Mixture Model (FlowGMM)

This repository contains a PyTorch implementation of the Flow Gaussian Mixture Model (FlowGMM) from our paper

Semi-Supervised Learning with Normalizing Flows

by Pavel Izmailov, Polina Kirichenko, Marc Finzi and Andrew Gordon Wilson.

Introduction

Normalizing flows transform a latent distribution through an invertible neural network for a flexible and pleasingly simple approach to generative modelling, while preserving an exact likelihood. In this paper, we introduce FlowGMM (Flow Gaussian Mixture Model), an approach to semi-supervised learning with normalizing flows, by modelling the density in the latent space as a Gaussian mixture, with each mixture component corresponding to a class represented in the labelled data. FlowGMM is distinct in its simplicity, unified treatment of labelled and unlabelled data with an exact likelihood, interpretability, and broad applicability beyond image data.
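The pure-Python sketch below makes the idea concrete under stated assumptions. It shows how a flow gives an exact likelihood through the change-of-variables formula, and how a Gaussian mixture in the latent space turns into a classifier through Bayes' rule. The elementwise affine "flow", the isotropic covariance with a shared variance, the uniform class prior, and all function names are illustrative assumptions and not the repository's code. FlowGMM itself uses deep invertible networks.

```python
import math


def affine_flow(x, a, b):
    # Hypothetical toy flow: elementwise z = a * x + b, invertible when every a != 0.
    # FlowGMM uses a deep invertible network here; only the bookkeeping is the same.
    z = [ai * xi + bi for xi, ai, bi in zip(x, a, b)]
    log_det = sum(math.log(abs(ai)) for ai in a)  # log |det dz/dx| of a diagonal Jacobian
    return z, log_det


def log_normal(z, mean, var):
    # Log-density of an isotropic Gaussian N(mean, var * I).
    d = len(z)
    sq = sum((zi - mi) ** 2 for zi, mi in zip(z, mean))
    return -0.5 * (d * math.log(2 * math.pi * var) + sq / var)


def logsumexp(vals):
    m = max(vals)
    return m + math.log(sum(math.exp(v - m) for v in vals))


def flowgmm(x, a, b, means, var=1.0):
    """Return log p(x) and the class posteriors p(y | x), assuming equal class weights."""
    z, log_det = affine_flow(x, a, b)
    num_classes = len(means)
    # log p(z, y=k) = log N(z; mu_k, var * I) + log(1 / K)
    joint = [log_normal(z, m, var) - math.log(num_classes) for m in means]
    log_pz = logsumexp(joint)
    log_px = log_pz + log_det  # change of variables gives the exact likelihood
    posteriors = [math.exp(j - log_pz) for j in joint]  # Bayes' rule in the latent space
    return log_px, posteriors


# Two classes whose latent Gaussians sit at (-2, 0) and (2, 0).
log_px, post = flowgmm([1.5, 0.1], a=[1.0, 1.0], b=[0.0, 0.0],
                       means=[[-2.0, 0.0], [2.0, 0.0]])
print(post)  # the second class receives almost all of the posterior mass
```

Because the labels only select which mixture component a latent point belongs to, the same likelihood covers labeled points (use the known component) and unlabeled points (marginalize over components).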

We show promising results on a wide range of semi-supervised classification problems, including AG-News and Yahoo Answers text data, UCI tabular data, and image datasets (MNIST, CIFAR-10 and SVHN).


Please cite our work if you find it useful:

@article{izmailov2019semi,
  title={Semi-Supervised Learning with Normalizing Flows},
  author={Izmailov, Pavel and Kirichenko, Polina and Finzi, Marc and Wilson, Andrew Gordon},
  journal={arXiv preprint arXiv:1912.13025},
  year={2019}
}

Installation

To run the scripts you will need to clone the repo and install it locally. You can use the commands below.

git clone https://github.com/izmailovpavel/flowgmm.git
cd flowgmm
pip install -e .

Dependencies

The following dependencies must be installed before installing FlowGMM.

We provide the scripts and example commands to reproduce the experiments from the paper.

Synthetic Datasets

The experiments on synthetic data are implemented in this ipython notebook. We additionally provide another ipython notebook that applies FlowGMM to labeled data only.

Tabular Datasets

The tabular datasets will be downloaded and preprocessed automatically the first time they are needed. Using the commands below, you can reproduce the accuracies (%) reported in the table.

Method     AG-News  Yahoo Answers  HEPMASS  MINIBOONE
MLP        77.5     55.7           82.2     80.4
Pi Model   80.2     56.3           87.9     80.8
FlowGMM    82.1     57.9           88.5     81.9

Text Classification (Updated)

Train FlowGMM on AG-News (200 labeled examples):

python experiments/train_flows/flowgmm_tabular_new.py --trainer_config "{'unlab_weight':.6}" --net_config "{'k':1024,'coupling_layers':7,'nperlayer':1}" --network RealNVPTabularWPrior --trainer SemiFlow --num_epochs 100 --dataset AG_News --lr 3e-4 --train 200

Train FlowGMM on YAHOO Answers (800 labeled examples):

python experiments/train_flows/flowgmm_tabular_new.py --trainer_config "{'unlab_weight':.2}" --net_config "{'k':1024,'coupling_layers':7,'nperlayer':1}" --network RealNVPTabularWPrior --trainer SemiFlow --num_epochs 200 --dataset YAHOO --lr 3e-4 --train 800
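The unlab_weight entry in --trainer_config presumably balances the labeled and unlabeled parts of the FlowGMM objective. The sketch below assumes the SemiFlow trainer minimizes the labeled negative log-likelihood plus unlab_weight times the unlabeled negative log-likelihood. That form is our assumption, and the actual trainer may normalize or weight the terms differently. The inputs are assumed to be precomputed per-class joint log-densities, and the function names are hypothetical.

```python
import math


def logsumexp(vals):
    m = max(vals)
    return m + math.log(sum(math.exp(v - m) for v in vals))


def semisup_loss(labeled_joint, labels, unlabeled_joint, unlab_weight):
    """Weighted negative log-likelihood over labeled and unlabeled examples.

    Each row holds log p(x, y=k) for every class k, i.e.
    log N(f(x); mu_k, Sigma_k) + log p(y=k) + log|det J_f(x)| (assumed precomputed).
    """
    # Labeled examples: the known class picks out one mixture component.
    lab_nll = -sum(row[y] for row, y in zip(labeled_joint, labels)) / len(labels)
    # Unlabeled examples: marginalize the class, log p(x) = logsumexp_k log p(x, y=k).
    unlab_nll = -sum(logsumexp(row) for row in unlabeled_joint) / len(unlabeled_joint)
    return lab_nll + unlab_weight * unlab_nll


labeled = [[-1.0, -3.0]]
unlabeled = [[math.log(0.25), math.log(0.25)]]
print(semisup_loss(labeled, [0], unlabeled, unlab_weight=0.6))
```

With very few labels (200 for AG-News, 800 for Yahoo Answers), the unlabeled term dominates the data, so its weight is a natural knob to tune per dataset, as the differing values in the two commands suggest.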

UCI Data

Train FlowGMM on MINIBOONE (20 labeled examples):

python experiments/train_flows/flowgmm_tabular_new.py --trainer_config "{'unlab_weight':3.}" \
 --net_config "{'k':256,'coupling_layers':10,'nperlayer':1}" --network RealNVPTabularWPrior \
 --trainer SemiFlow --num_epochs 300 --dataset MINIBOONE --lr 3e-4

Train FlowGMM on HEPMASS (20 labeled examples):

python experiments/train_flows/flowgmm_tabular_new.py --trainer_config "{'unlab_weight':10}" \
 --net_config "{'k':256,'coupling_layers':10,'nperlayer':1}" \
 --network RealNVPTabularWPrior --trainer SemiFlow --num_epochs 15 --dataset HEPMASS
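The --network RealNVPTabularWPrior option and the --net_config keys suggest a RealNVP-style flow built from affine coupling layers. We read coupling_layers as the number of such layers and k as the hidden width of the network inside each layer, but that is our interpretation of the names, not something stated above. The pure-Python sketch below shows one affine coupling layer, its exact inverse, its log-determinant, and a stack that reverses coordinates between layers. The linear scale/shift "network" and all function names are hypothetical.

```python
import math


def scale_shift(x1, params):
    # Hypothetical stand-in for the coupling network (an MLP in RealNVP): one linear map
    # for the log-scales (squashed by tanh to keep them bounded) and one for the shifts.
    w_s, w_t = params
    s = [math.tanh(sum(w * xi for w, xi in zip(row, x1))) for row in w_s]
    t = [sum(w * xi for w, xi in zip(row, x1)) for row in w_t]
    return s, t


def coupling_forward(x, params):
    d = len(x) // 2
    x1, x2 = x[:d], x[d:]
    s, t = scale_shift(x1, params)  # depends only on the untouched half
    y2 = [xi * math.exp(si) + ti for xi, si, ti in zip(x2, s, t)]
    log_det = sum(s)  # triangular Jacobian with diagonal exp(s) on the x2 block
    return x1 + y2, log_det


def coupling_inverse(y, params):
    d = len(y) // 2
    y1, y2 = y[:d], y[d:]
    s, t = scale_shift(y1, params)  # y1 == x1, so s and t are recomputed exactly
    x2 = [(yi - ti) * math.exp(-si) for yi, si, ti in zip(y2, s, t)]
    return y1 + x2


def flow_forward(x, layers):
    # Stack coupling layers, reversing the coordinates in between so that each
    # layer transforms the half the previous one left unchanged (|det| = 1).
    total_log_det = 0.0
    for params in layers:
        x, log_det = coupling_forward(x, params)
        total_log_det += log_det
        x = x[::-1]
    return x, total_log_det


params = ([[0.5, -0.2], [0.1, 0.3]], [[1.0, 0.0], [0.0, -1.0]])
x = [0.3, -1.2, 2.0, 0.7]
y, log_det = coupling_forward(x, params)
print(coupling_inverse(y, params))  # recovers x up to floating-point error
```

The design keeps both directions cheap: the scale/shift network is never inverted, and the log-determinant is a plain sum, which is what makes the exact likelihood tractable.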

Note that on the low-dimensional tabular data the FlowGMM models are quite sensitive to initialization. You may want to run the script a couple of times in case the model does not recover from a bad initialization.

The training script for the UCI datasets automatically downloads the MINIBOONE or HEPMASS data and unpacks it into ~/datasets/UCI/; for reference, the datasets come from here and here. We follow the preprocessing (where sensible) from Masked Autoregressive Flow for Density Estimation.

Baselines

Training the 3-layer NN + Dropout baseline on:

YAHOO Answers: python experiments/train_flows/flowgmm_tabular_new.py --lr=1e-3 --dataset YAHOO --num_epochs 1000 --train 800

AG-NEWS: python experiments/train_flows/flowgmm_tabular_new.py --lr 1e-4 --dataset AG_News --num_epochs 1000 --train 200

MINIBOONE: python experiments/train_flows/flowgmm_tabular_new.py --lr 1e-4 --dataset MINIBOONE --num_epochs 500

HEPMASS: python experiments/train_flows/flowgmm_tabular_new.py --lr 1e-4 --dataset HEPMASS --num_epochs 500

Training the Pi Model on:

YAHOO Answers: python experiments/train_flows/flowgmm_tabular_new.py --lr=1e-3 --dataset YAHOO --num_epochs 300 --train 800 --trainer PiModel --trainer_config "{'cons_weight':.3}"

AG-NEWS: python experiments/train_flows/flowgmm_tabular_new.py --lr 1e-3 --dataset AG_News --num_epochs 100 --train 200 --trainer PiModel --trainer_config "{'cons_weight':30}"

MINIBOONE: python experiments/train_flows/flowgmm_tabular_new.py --lr 3e-4 --dataset MINIBOONE --trainer PiModel --trainer_config "{'cons_weight':30}" --num_epochs 10

HEPMASS: python experiments/train_flows/flowgmm_tabular_new.py --trainer PiModel --num_epochs 10 --dataset HEPMASS --trainer_config "{'cons_weight':3}" --lr 1e-4
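For the Pi Model baseline, cons_weight presumably scales the consistency term. The sketch below follows the standard Pi-model formulation of Laine and Aila: supervised cross-entropy plus a weighted mean squared difference between the softmax outputs of two stochastic forward passes on the same input. The noisy toy classifier and the function names are hypothetical, and the repository's PiModel trainer may differ in details such as the noise source or a ramp-up schedule.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def noisy_linear(x, rng, noise=0.1):
    # Hypothetical toy classifier: Gaussian input noise plays the role of dropout/augmentation.
    xn = [xi + rng.gauss(0.0, noise) for xi in x]
    return [sum(xn), -sum(xn)]  # two-class logits


def consistency(model, x, rng):
    # Two stochastic passes over the same input; penalize disagreement of the predictions.
    p1, p2 = softmax(model(x, rng)), softmax(model(x, rng))
    return sum((a - b) ** 2 for a, b in zip(p1, p2)) / len(p1)


def pi_model_loss(model, labeled, unlabeled, cons_weight, rng):
    ce = -sum(math.log(softmax(model(x, rng))[y]) for x, y in labeled) / len(labeled)
    inputs = [x for x, _ in labeled] + unlabeled  # consistency needs no labels
    cons = sum(consistency(model, x, rng) for x in inputs) / len(inputs)
    return ce + cons_weight * cons


rng = random.Random(0)
labeled = [([0.5, 0.2], 0), ([-0.4, -0.1], 1)]
unlabeled = [[0.3, 0.3], [-0.2, 0.1]]
print(pi_model_loss(noisy_linear, labeled, unlabeled, cons_weight=30.0, rng=rng))
```

Unlike FlowGMM, this baseline uses unlabeled data only through the consistency penalty, not through a likelihood.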

The notebook here can be used to run the kNN, Logistic Regression, and Label Spreading baselines once the data has already been downloaded by the previous scripts or if it was downloaded manually.

Image Classification

To run experiments with FlowGMM on image classification problems you first need to download and prepare the data. To do so, run the following scripts:

./data/bin/prepare_cifar10.sh
./data/bin/prepare_mnist.sh
./data/bin/prepare_svhn.sh

To run FlowGMM, you can use the following script

python3 experiments/train_flows/train_semisup_cons.py \
  --dataset=<DATASET> \
  --data_path=<DATAPATH> \
  --label_path=<LABELPATH> \
  --logdir=<LOGDIR> \
  --ckptdir=<CKPTDIR> \
  --save_freq=<SAVEFREQ> \
  --num_epochs=<EPOCHS> \
  --label_weight=<LABELWEIGHT> \
  --consistency_weight=<CONSISTENCYWEIGHT> \
  --consistency_rampup=<CONSISTENCYRAMPUP> \
  --lr=<LR> \
  --eval_freq=<EVALFREQ>

Parameters:

  • DATASET — dataset name [MNIST/CIFAR10/SVHN]
  • DATAPATH — path to the directory containing data; if you used the data preparation scripts, you can use e.g. data/images/mnist as DATAPATH
  • LABELPATH — path to the label split generated by the data preparation scripts; this can be e.g. data/labels/mnist/1000_balanced_labels/10.npz or data/labels/cifar10/1000_balanced_labels/10.txt.
  • LOGDIR — directory where tensorboard logs will be stored
  • CKPTDIR — directory where checkpoints will be stored
  • SAVEFREQ — frequency of saving checkpoints in epochs
  • EPOCHS — number of training epochs (passes through labeled data)
  • LABELWEIGHT — weight of cross-entropy loss term (default: 1.)
  • CONSISTENCYWEIGHT — weight of consistency loss term (default: 1.)
  • CONSISTENCYRAMPUP — length of the consistency ramp-up period in epochs (default: 1); the consistency weight increases linearly from 0 to CONSISTENCYWEIGHT over the first CONSISTENCYRAMPUP epochs of training
  • LR — learning rate (default: 1e-3)
  • EVALFREQ — number of epochs between evaluations (default: 1)
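The ramp-up described for CONSISTENCYRAMPUP can be written as a small function. The sketch also shows one plausible way the weights could combine the loss terms: an unweighted flow negative log-likelihood, LABELWEIGHT times the cross-entropy, and the ramped consistency weight times the consistency loss. That combination, and whether the ramp is evaluated per epoch or per iteration, are assumptions rather than a description of train_semisup_cons.py; the function names are hypothetical.

```python
def consistency_weight(epoch, max_weight, rampup_epochs):
    """Linear ramp from 0 at epoch 0 to max_weight at rampup_epochs, constant afterwards."""
    if rampup_epochs <= 0:
        return max_weight
    return max_weight * min(epoch / rampup_epochs, 1.0)


def total_loss(nll, cross_entropy, consistency, epoch,
               label_weight=1.0, max_cons_weight=1.0, rampup_epochs=1):
    # Assumed combination of the terms; the argument names mirror the command-line flags.
    cons_w = consistency_weight(epoch, max_cons_weight, rampup_epochs)
    return nll + label_weight * cross_entropy + cons_w * consistency


# MNIST example values: --consistency_weight=1. --consistency_rampup=1000
for epoch in (0, 250, 1000, 5000):
    print(epoch, consistency_weight(epoch, 1.0, 1000))
```

Starting the consistency term at zero lets the model first fit the labeled data, so that early, unreliable predictions are not pushed toward agreeing with each other.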

Examples:

# MNIST, 100 labeled datapoints
python3 experiments/train_flows/train_semisup_cons.py --dataset=MNIST --data_path=data/images/mnist/ \
  --label_path=data/labels/mnist/100_balanced_labels/10.npz --logdir=<LOGDIR> --ckptdir=<CKPTDIR> \
  --save_freq=5000 --num_epochs=30001 --label_weight=3 --consistency_weight=1. --consistency_rampup=1000 \
  --lr=1e-5 --eval_freq=100 
  
# CIFAR-10, 4000 labeled datapoints
python3 experiments/train_flows/train_semisup_cons.py --dataset=CIFAR10 --data_path=data/images/cifar/cifar10/by-image/ \
  --label_path=data/labels/cifar10/4000_balanced_labels/10.txt --logdir=<LOGDIR> --ckptdir=<CKPTDIR> \
  --save_freq=500 --num_epochs=1501 --label_weight=3 --consistency_weight=1. --consistency_rampup=100 \
  --lr=1e-4 --eval_freq=50
