Official code for HH-VAEM

Overview

HH-VAEM

This repository contains the official PyTorch implementation of the Hierarchical Hamiltonian VAE for Mixed-type Data (HH-VAEM) model and of the sampling-based feature acquisition technique presented in the paper Missing Data Imputation and Acquisition with Deep Hierarchical Models and Hamiltonian Monte Carlo. HH-VAEM is a hierarchical VAE for mixed-type incomplete data that uses Hamiltonian Monte Carlo (HMC) with automatic hyper-parameter tuning to improve approximate inference. The repository also contains the code for the experiments presented in the paper.
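For readers unfamiliar with HMC, the following is a minimal, generic sketch of a single HMC transition (leapfrog integration followed by a Metropolis-Hastings correction) on a user-supplied log-density. It is not the repository's implementation: HH-VAEM tunes the sampler's hyper-parameters automatically, whereas this sketch takes them as fixed arguments (`step_size` and `n_leapfrog` are illustrative names).

```python
import math
import random
from typing import Callable, List, Optional, Tuple

Vector = List[float]


def leapfrog(q: Vector, p: Vector, grad_log_prob: Callable[[Vector], Vector],
             step_size: float, n_steps: int) -> Tuple[Vector, Vector]:
    """Simulate Hamiltonian dynamics with the leapfrog integrator.

    The potential energy is U(q) = -log_prob(q), so dp/dt = grad log_prob(q).
    """
    if n_steps < 1:
        raise ValueError("n_steps must be >= 1")
    q, p = list(q), list(p)
    # Initial half step for the momentum
    p = [pi + 0.5 * step_size * gi for pi, gi in zip(p, grad_log_prob(q))]
    for step in range(n_steps):
        # Full step for the position
        q = [qi + step_size * pi for qi, pi in zip(q, p)]
        # Full momentum step, except a final half step at the end of the trajectory
        scale = step_size if step < n_steps - 1 else 0.5 * step_size
        p = [pi + scale * gi for pi, gi in zip(p, grad_log_prob(q))]
    return q, p


def hmc_step(q: Vector, log_prob: Callable[[Vector], float],
             grad_log_prob: Callable[[Vector], Vector],
             step_size: float = 0.1, n_leapfrog: int = 10,
             rng: Optional[random.Random] = None) -> Tuple[Vector, bool]:
    """One HMC transition with standard Gaussian momentum; returns (new_q, accepted)."""
    rng = rng or random.Random()
    p0 = [rng.gauss(0.0, 1.0) for _ in q]
    q_new, p_new = leapfrog(q, p0, grad_log_prob, step_size, n_leapfrog)
    # Hamiltonian H(q, p) = -log_prob(q) + 0.5 * ||p||^2
    h_old = -log_prob(q) + 0.5 * sum(pi * pi for pi in p0)
    h_new = -log_prob(q_new) + 0.5 * sum(pi * pi for pi in p_new)
    delta = h_old - h_new
    # Metropolis-Hastings correction; numerically diverged (NaN) trajectories are rejected
    if not math.isnan(delta) and rng.random() < math.exp(min(0.0, delta)):
        return q_new, True
    return list(q), False
```

For example, repeatedly calling `hmc_step` with `log_prob = lambda q: -0.5 * sum(x * x for x in q)` and `grad_log_prob = lambda q: [-x for x in q]` produces a Markov chain whose stationary distribution is a standard Gaussian.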

If you use this code, please cite the preprint:

@article{peis2022missing,
  title={Missing Data Imputation and Acquisition with Deep Hierarchical Models and Hamiltonian Monte Carlo},
  author={Peis, Ignacio and Ma, Chao and Hern{\'a}ndez-Lobato, Jos{\'e} Miguel},
  journal={arXiv preprint arXiv:2202.04599},
  year={2022}
}

Installation

Installation is straightforward: the following command creates a conda virtual environment named HH-VAEM from the provided environment.yml file:

conda env create -f environment.yml

Usage

Training

The project is built on the PyTorch Lightning research framework. The HH-VAEM model is implemented as a LightningModule and trained with a Lightning Trainer. To train a model, run:

# Example for training HH-VAEM on the Boston dataset
python train.py --model HHVAEM --dataset boston --split 0

This automatically downloads the boston dataset, splits it into 10 train/test splits and trains HH-VAEM on training split 0. Two folders are created: data/, which stores the datasets, and logs/, which stores model checkpoints and TensorBoard logs. The directory where these folders are created is set by the variable LOGDIR in src/configs.py; changing it can be useful, for example, to avoid overloading network file systems.
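Because each dataset is split into 10 train/test splits, a full run trains one model per split. The sketch below is a hypothetical helper (not part of the repository) that builds and runs the train.py command for every split, using only the --model, --dataset and --split flags documented above; it assumes it is launched from the repository root with the HH-VAEM environment active.

```python
import subprocess
import sys
from typing import List

N_SPLITS = 10  # each dataset is split into 10 train/test splits


def train_commands(model: str, dataset: str, n_splits: int = N_SPLITS) -> List[List[str]]:
    """Build one `python train.py ...` command per split (hypothetical helper)."""
    return [
        [sys.executable, "train.py", "--model", model,
         "--dataset", dataset, "--split", str(split)]
        for split in range(n_splits)
    ]


def train_all_splits(model: str, dataset: str, n_splits: int = N_SPLITS) -> None:
    """Run the commands sequentially, stopping at the first failing split."""
    for cmd in train_commands(model, dataset, n_splits):
        print("Running:", " ".join(cmd))
        subprocess.run(cmd, check=True)
```

For example, `train_all_splits("HHVAEM", "boston")` trains HH-VAEM on the ten Boston splits one after another. Running them sequentially keeps GPU usage predictable; the commands could also be dispatched in parallel to a cluster.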

The following datasets are available:

  • 10 UCI datasets: avocado, boston, energy, wine, diabetes, concrete, naval, yatch, bank and insurance.
  • The MNIST datasets: mnist and fashion_mnist.
  • More datasets can be easily added to src/datasets.py.

For each dataset, including newly added ones, the corresponding parameter configuration must be defined in src/configs.py.

The following models are also available (implemented in src/models/):

  • HHVAEM: the proposed model in the paper.
  • VAEM: the VAEM model presented in (Ma et al., 2020) with a Gaussian encoder (without the Partial VAE).
  • HVAEM: a hierarchical VAEM with two layers of latent variables and a Gaussian encoder.
  • HMCVAEM: a VAEM that includes a tuned HMC sampler targeting the true posterior.
  • For the MNIST datasets (non-heterogeneous data), use the corresponding HHVAE, VAE, HVAE and HMCVAE models.
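To illustrate what "two layers of latent variables" means, the following hypothetical sketch draws ancestral samples from a toy two-layer hierarchical generative model: z2 ~ N(0, 1), z1 | z2 ~ N(a·z2, s1²), x | z1 ~ N(b·z1, s2²). The scalar linear maps and the parameter names a, b, s1, s2 are illustrative stand-ins; the models in src/models/ use neural networks and mixed-type likelihoods instead.

```python
import random
from typing import List, Optional, Tuple


def sample_hierarchical(n: int, a: float = 2.0, s1: float = 0.5,
                        b: float = 1.5, s2: float = 0.1,
                        rng: Optional[random.Random] = None) -> List[Tuple[float, float, float]]:
    """Ancestral sampling from a toy two-layer latent model (hypothetical parameters).

    z2 ~ N(0, 1)               top latent layer
    z1 | z2 ~ N(a * z2, s1^2)  bottom latent layer
    x  | z1 ~ N(b * z1, s2^2)  observed variable
    """
    rng = rng or random.Random()
    samples = []
    for _ in range(n):
        z2 = rng.gauss(0.0, 1.0)   # sample the top layer from its prior
        z1 = rng.gauss(a * z2, s1)  # condition the bottom layer on z2
        x = rng.gauss(b * z1, s2)   # decode the observation from z1
        samples.append((z2, z1, x))
    return samples
```

Because every layer in this toy model is linear-Gaussian, the marginal variance of x is b²(a² + s1²) + s2², which gives a quick sanity check on the sampler.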

By default, the test stage is executed at the end of training. To skip it, pass --test 0; the test can then be run manually with:

# Example for testing HH-VAEM on the Boston dataset
python test.py --model HHVAEM --dataset boston --split 0

This loads the trained model and evaluates it on test split 0 of the boston dataset. Once all the splits have been tested, the average results can be obtained with the following script in the run/ folder:

# Example for obtaining the average test results of HH-VAEM on the Boston dataset
python test_splits.py --model HHVAEM --dataset boston
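Conceptually, averaging over splits amounts to summarizing each test metric across the 10 splits. The sketch below is a hypothetical illustration of that aggregation, not the logic of test_splits.py (whose exact metrics and output format may differ); it assumes the per-split results are already available as dictionaries mapping metric names to values.

```python
import statistics
from typing import Dict, List, Tuple


def aggregate_splits(results: List[Dict[str, float]]) -> Dict[str, Tuple[float, float]]:
    """Return {metric: (mean, sample standard deviation)} across splits."""
    if not results:
        raise ValueError("no split results to aggregate")
    summary = {}
    for name in results[0]:
        values = [r[name] for r in results]
        # The sample standard deviation needs at least two splits
        std = statistics.stdev(values) if len(values) > 1 else 0.0
        summary[name] = (statistics.mean(values), std)
    return summary
```

Reporting the spread alongside the mean makes it easier to judge whether differences between models exceed the variability across splits.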

Experiments

The experiments in the paper can be executed using:

# Example for running the SAIA experiment with HH-VAEM on the Boston dataset
python active_learning.py --model HHVAEM --dataset boston --method mi --split 0

# Example for running the OoD experiment with HH-VAEM trained on MNIST, using Fashion-MNIST as OoD data:
python ood.py --model HHVAEM --dataset mnist --dataset_ood fashion_mnist --split 0
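In the SAIA (sequential active information acquisition) experiment, features are acquired one at a time, each time choosing the unobserved feature deemed most informative. The sketch below shows only that generic greedy loop. `score_fn` is a hypothetical placeholder for the model's acquisition criterion (presumably the sampling-based mutual-information estimate selected with --method mi), which this sketch does not implement.

```python
from typing import Callable, List, Set


def greedy_acquisition(
    features: List[str],
    score_fn: Callable[[str, Set[str]], float],
    budget: int,
) -> List[str]:
    """Acquire up to `budget` features one at a time, greedily by score.

    score_fn(candidate, observed) is a hypothetical stand-in for the model's
    estimate of how informative `candidate` is given the already observed set.
    """
    observed: Set[str] = set()
    order: List[str] = []
    for _ in range(min(budget, len(features))):
        candidates = [f for f in features if f not in observed]
        # Re-score every remaining candidate given the current observed set
        best = max(candidates, key=lambda f: score_fn(f, observed))
        observed.add(best)
        order.append(best)
    return order
```

Re-scoring after every acquisition matters: a feature that looked useful at first may become redundant once a correlated feature has been observed.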

Once this is executed on all the splits, you can plot the SAIA error curves or obtain the average OoD metrics using the scripts in the run/ folder:

# Example for plotting the SAIA error curves of VAEM and HH-VAEM on the Boston dataset
python active_learning_plots.py --models VAEM HHVAEM --dataset boston

# Example for obtaining the average OoD metrics of HH-VAEM trained on MNIST, using Fashion-MNIST as OoD data:
python ood_splits.py --model HHVAEM --dataset mnist --dataset_ood fashion_mnist
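OoD detection results are commonly summarized with the area under the ROC curve (AUROC). Assuming that is among the reported metrics (the exact metrics computed by ood_splits.py are not documented here), the sketch below computes AUROC from per-sample OoD scores through its rank-statistic interpretation: the probability that a randomly chosen OoD sample scores higher than a randomly chosen in-distribution sample, with ties counted as 1/2. The function name and score convention are illustrative assumptions.

```python
from typing import Sequence


def auroc(in_dist_scores: Sequence[float], ood_scores: Sequence[float]) -> float:
    """P(score_ood > score_in) + 0.5 * P(tie), the Mann-Whitney form of AUROC.

    Higher scores are assumed to mean "more likely out-of-distribution".
    """
    if len(in_dist_scores) == 0 or len(ood_scores) == 0:
        raise ValueError("both score lists must be non-empty")
    wins = 0.0
    for o in ood_scores:
        for i in in_dist_scores:
            if o > i:
                wins += 1.0
            elif o == i:
                wins += 0.5  # ties contribute half a win
    return wins / (len(in_dist_scores) * len(ood_scores))
```

This pairwise version is quadratic in the number of samples, which is fine for illustration; for full test sets, a sort-based rank computation gives the same value in O(n log n).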


Help

Use the --help option for documentation on the usage of any of the mentioned scripts.

Contributors

Ignacio Peis
Chao Ma
José Miguel Hernández-Lobato

Contact

For further information: [email protected]

Owner
Ignacio Peis
PhD student at UC3M; visitor at the Machine Learning Group, CBL, University of Cambridge