Hierarchical Few-Shot Generative Models

Overview

Hierarchical Few-Shot Generative Models

Giorgio Giannone, Ole Winther

This repo contains code and experiments for the paper Hierarchical Few-Shot Generative Models.


Settings

Clone the repo:

git clone https://github.com/georgosgeorgos/hierarchical-few-shot-generative-models
cd hierarchical-few-shot-generative-models

Create and activate the conda env:

conda env create -f environment.yml
conda activate hfsgm

The code has been tested on Ubuntu 18.04, Python 3.6 and CUDA 11.3

We use wandb for visualization. The first time you run the code you will need to login.

Data

We provide preprocessed Omniglot dataset.

From the main folder, copy the data in data/omniglot_ns/:

wget https://github.com/georgosgeorgos/hierarchical-few-shot-generative-models/releases/download/Omniglot/omni_train_val_test.pkl

For CelebA you need to download the dataset from here.

Dataset

In dataset we provide utilities to process and augment datasets in the few-shot setting. Each dataset is a large collection of small sets. Sets can be created dynamically. The dataset/base.py file collects basic info about the datasets. For binary datasets (omniglot_ns.py) we augment using flipping and rotations. For RGB datasets (celeba.py) we use only flipping.

Experiment

In experiment we implement scripts for model evaluation, experiments and visualizations.

  • attention.py - visualize attention weights and heads for models with learnable aggregations (LAG).
  • cardinality.py - compute ELBOs for different input set size: [1, 2, 5, 10, 20].
  • classifier_mnist.py - few-shot classifiers on MNIST.
  • kl_layer.py - compute KL over z and c for each layer in latent space.
  • marginal.py - compute approximate log-marginal likelihood with 1K importance samples.
  • refine_vis.py - visualize refined samples.
  • sampling_rgb.py - reconstruction, conditional, refined, unconditional sampling for RGB datasets.
  • sampling_transfer.py - reconstruction, conditional, refined, unconditional sampling on transfer datasets.
  • sampling.py - reconstruction, conditional, refined, unconditional sampling for binary datasets.
  • transfer.py - compute ELBOs on MNIST, DoubleMNIST, TripleMNIST.

Model

In model we implement baselines and model variants.

  • base.py - base class for all the models.
  • vae.py - Variational Autoencoder (VAE).
  • ns.py - Neural Statistician (NS).
  • tns.py - NS with learnable aggregation (NS-LAG).
  • cns.py - NS with convolutional latent space (CNS).
  • ctns.py - CNS with learnable aggregation (CNS-LAG).
  • hfsgm.py - Hierarchical Few-Shot Generative Model (HFSGM).
  • thfsgm.py - HFSGM with learnable aggregation (HFSGM-LAG).
  • chfsgm.py - HFSGM with convolutional latent space (CHFSGM).
  • cthfsgm.py - CHFSGM with learnable aggregation (CHFSGM-LAG).

Script

Scripts used for training the models in the paper.

To run a CNS on Omniglot:

sh script/main_cns.sh GPU_NUMBER omniglot_ns

Train a model

To train a generic model run:

python main.py --name {VAE, NS, CNS, CTNS, CHFSGM, CTHFSGM} \
               --model {vae, ns, cns, ctns, chfsgm, cthfsgm} \
               --augment \
               --dataset omniglot_ns \
               --likelihood binary \
               --hidden-dim 128 \
               --c-dim 32 \
               --z-dim 32 \
               --output-dir /output \
               --alpha-step 0.98 \
               --alpha 2 \
               --adjust-lr \
               --scheduler plateau \
               --sample-size {2, 5, 10} \
               --sample-size-test {2, 5, 10} \
               --num-classes 1 \
               --learning-rate 1e-4 \
               --epochs 400 \
               --batch-size 100 \
               --tag (optional string)

If you do not want to save logs, use the flag --dry_run. This flag will call utils/trainer_dry.py instead of trainer.py.


Acknowledgments

A lot of code and ideas borrowed from:

You might also like...
Official PyTorch Implementation of Hypercorrelation Squeeze for Few-Shot Segmentation, arXiv 2021
Official PyTorch Implementation of Hypercorrelation Squeeze for Few-Shot Segmentation, arXiv 2021

Hypercorrelation Squeeze for Few-Shot Segmentation This is the implementation of the paper "Hypercorrelation Squeeze for Few-Shot Segmentation" by Juh

Implementation of Cross Transformer for spatially-aware few-shot transfer, in Pytorch
Implementation of Cross Transformer for spatially-aware few-shot transfer, in Pytorch

Cross Transformers - Pytorch (wip) Implementation of Cross Transformer for spatially-aware few-shot transfer, in Pytorch Install $ pip install cross-t

Official repository for Few-shot Image Generation via Cross-domain Correspondence (CVPR '21)
Official repository for Few-shot Image Generation via Cross-domain Correspondence (CVPR '21)

Few-shot Image Generation via Cross-domain Correspondence Utkarsh Ojha, Yijun Li, Jingwan Lu, Alexei A. Efros, Yong Jae Lee, Eli Shechtman, Richard Zh

[CVPR 2021] Few-shot 3D Point Cloud Semantic Segmentation
[CVPR 2021] Few-shot 3D Point Cloud Semantic Segmentation

Few-shot 3D Point Cloud Semantic Segmentation Created by Na Zhao from National University of Singapore Introduction This repository contains the PyTor

Few-Shot Graph Learning for Molecular Property Prediction

Few-shot Graph Learning for Molecular Property Prediction Introduction This is the source code and dataset for the following paper: Few-shot Graph Lea

Few-shot Relation Extraction via Bayesian Meta-learning on Relation Graphs

Few-shot Relation Extraction via Bayesian Meta-learning on Relation Graphs This is an implemetation of the paper Few-shot Relation Extraction via Baye

The implementation of PEMP in paper
The implementation of PEMP in paper "Prior-Enhanced Few-Shot Segmentation with Meta-Prototypes"

Prior-Enhanced network with Meta-Prototypes (PEMP) This is the PyTorch implementation of PEMP. Overview of PEMP Meta-Prototypes & Adaptive Prototypes

Code and data of the ACL 2021 paper: Few-Shot Text Ranking with Meta Adapted Synthetic Weak Supervision

MetaAdaptRank This repository provides the implementation of meta-learning to reweight synthetic weak supervision data described in the paper Few-Shot

Adaptive Prototype Learning and Allocation for Few-Shot Segmentation (CVPR 2021)
Adaptive Prototype Learning and Allocation for Few-Shot Segmentation (CVPR 2021)

ASGNet The code is for the paper "Adaptive Prototype Learning and Allocation for Few-Shot Segmentation" (accepted to CVPR 2021) [arxiv] Overview data/

Releases(Omniglot)
Owner
Giorgio Giannone
Science is built up with data, as a house is with stones. But a collection of data is no more a science than a heap of stones is a house. (J.H. Poincaré)
Giorgio Giannone
Simulated garment dataset for virtual try-on

Simulated garment dataset for virtual try-on This repository contains the dataset used in the following papers: Self-Supervised Collision Handling via

33 Dec 20, 2022
MEDS: Enhancing Memory Error Detection for Large-Scale Applications

MEDS: Enhancing Memory Error Detection for Large-Scale Applications Prerequisites cmake and clang Build MEDS supporting compiler $ make Build Using Do

Secomp Lab at Purdue University 34 Dec 14, 2022
Official repository for Jia, Raghunathan, Göksel, and Liang, "Certified Robustness to Adversarial Word Substitutions" (EMNLP 2019)

Certified Robustness to Adversarial Word Substitutions This is the official GitHub repository for the following paper: Certified Robustness to Adversa

Robin Jia 38 Oct 16, 2022
Attack classification models with transferability, black-box attack; unrestricted adversarial attacks on imagenet

Attack classification models with transferability, black-box attack; unrestricted adversarial attacks on imagenet, CVPR2021 安全AI挑战者计划第六期:ImageNet无限制对抗攻击 决赛第四名(team name: Advers)

51 Dec 01, 2022
A small fun project using python OpenCV, mediapipe, and pydirectinput

Here I tried a small fun project using python OpenCV, mediapipe, and pydirectinput. Here we can control moves car game when yellow color come to right box (press key 'd') left box (press key 'a') lef

Sameh Elisha 3 Nov 17, 2022
SwinIR: Image Restoration Using Swin Transformer

SwinIR: Image Restoration Using Swin Transformer This repository is the official PyTorch implementation of SwinIR: Image Restoration Using Shifted Win

Jingyun Liang 2.4k Jan 05, 2023
Synthetic LiDAR sequential point cloud dataset with point-wise annotations

SynLiDAR dataset: Learning From Synthetic LiDAR Sequential Point Cloud This is official repository of the SynLiDAR dataset. For technical details, ple

78 Dec 27, 2022
[SIGGRAPH Asia 2019] Artistic Glyph Image Synthesis via One-Stage Few-Shot Learning

AGIS-Net Introduction This is the official PyTorch implementation of the Artistic Glyph Image Synthesis via One-Stage Few-Shot Learning. paper | suppl

Yue Gao 102 Jan 02, 2023
Luminaire is a python package that provides ML driven solutions for monitoring time series data.

A hands-off Anomaly Detection Library Table of contents What is Luminaire Quick Start Time Series Outlier Detection Workflow Anomaly Detection for Hig

Zillow 670 Jan 02, 2023
Keras implementation of the GNM model in paper ’Graph-Based Semi-Supervised Learning with Nonignorable Nonresponses‘

Graph-based joint model with Nonignorable Missingness (GNM) This is a Keras implementation of the GNM model in paper ’Graph-Based Semi-Supervised Lear

Fan Zhou 2 Apr 17, 2022
General Vision Benchmark, a project from OpenGVLab

Introduction We build GV-B(General Vision Benchmark) on Classification, Detection, Segmentation and Depth Estimation including 26 datasets for model e

174 Dec 27, 2022
Extracts data from the database for a graph-node and stores it in parquet files

subgraph-extractor Extracts data from the database for a graph-node and stores it in parquet files Installation For developing, it's recommended to us

Cardstack 0 Jan 10, 2022
Towards Improving Embedding Based Models of Social Network Alignment via Pseudo Anchors

PSML paper: Towards Improving Embedding Based Models of Social Network Alignment via Pseudo Anchors PSML_IONE,PSML_ABNE,PSML_DEEPLINK,PSML_SNNA: numpy

13 Nov 27, 2022
IEEE-CIS Technical Challenge on Predict+Optimize for Renewable Energy Scheduling

IEEE-CIS Technical Challenge on Predict+Optimize for Renewable Energy Scheduling This is my code, data and approach for the IEEE-CIS Technical Challen

3 Sep 18, 2022
Node-level Graph Regression with Deep Gaussian Process Models

Node-level Graph Regression with Deep Gaussian Process Models Prerequests our implementation is mainly based on tensorflow 1.x and gpflow 1.x: python

1 Jan 16, 2022
Recursive Bayesian Networks

Recursive Bayesian Networks This repository contains the code to reproduce the results from the NeurIPS 2021 paper Lieck R, Rohrmeier M (2021) Recursi

Robert Lieck 11 Oct 18, 2022
A faster pytorch implementation of faster r-cnn

A Faster Pytorch Implementation of Faster R-CNN Write at the beginning [05/29/2020] This repo was initaited about two years ago, developed as the firs

Jianwei Yang 7.1k Jan 01, 2023
Includes PyTorch -> Keras model porting code for ConvNeXt family of models with fine-tuning and inference notebooks.

ConvNeXt-TF This repository provides TensorFlow / Keras implementations of different ConvNeXt [1] variants. It also provides the TensorFlow / Keras mo

Sayak Paul 87 Dec 06, 2022
Retrieve and analysis data from SDSS (Sloan Digital Sky Survey)

Author: Behrouz Safari License: MIT sdss A python package for retrieving and analysing data from SDSS (Sloan Digital Sky Survey) Installation Install

Behrouz 3 Oct 28, 2022
Boundary-aware Transformers for Skin Lesion Segmentation

Boundary-aware Transformers for Skin Lesion Segmentation Introduction This is an official release of the paper Boundary-aware Transformers for Skin Le

Jiacheng Wang 79 Dec 16, 2022