PyTorch implementation of Algorithm 1 of "On the Anatomy of MCMC-Based Maximum Likelihood Learning of Energy-Based Models"

Overview

Code for On the Anatomy of MCMC-Based Maximum Likelihood Learning of Energy-Based Models

This repository will reproduce the main results from our paper:

On the Anatomy of MCMC-Based Maximum Likelihood Learning of Energy-Based Models
Erik Nijkamp*, Mitch Hill*, Tian Han, Song-Chun Zhu, and Ying Nian Wu (*equal contributions)
https://arxiv.org/abs/1903.12370
AAAI 2020.

The files train_data.py and train_toy.py are PyTorch-based implementations of Algorithm 1 for image datasets and toy 2D distributions respectively. Both files will measure and plot the diagnostic values $d_{s_t}$ and $r_t$ described in Section 3 during training. The file eval.py will sample from a saved checkpoint using either unadjusted Langevin dynamics or Metropolis-Hastings adjusted Langevin dynamics. We provide an appendix ebm-anatomy-appendix.pdf that contains further practical considerations and empirical observations.

Config Files

The folder config_locker has several JSON files that reproduce different convergent and non-convergent learning outcomes for image datasets and toy distributions. Config files for evaluation of pre-trained networks are also included. The files data_config.json, toy_config.json, and eval_config.json fully explain the parameters for train_data.py, train_toy.py, and eval.py respectively.

Executable Files

To run an experiment with train_data.py, train_toy.py, or eval.py, just specify a name for the experiment folder and the location of the JSON config file:

# directory for experiment results
EXP_DIR = './name_of/new_folder/'
# json file with experiment config
CONFIG_FILE = './path_to/config.json'

before execution.

Other Files

Network structures are located in nets.py. A download function for Oxford Flowers 102 data, plotting functions, and a toy dataset class can be found in utils.py.

Diagnostics

Energy Difference and Langevin Gradient Magnitude: Both image and toy experiments will plot $d_{s_t}$ and $r_t$ (see Section 3) over training along with correlation plots as in Figure 4 (with ACF rather than PACF).

Landscape Plots: Toy experiments will plot the density and log-density (negative energy) for ground-truth, learned energy, and short-run models. Kernel density estimation is used to obtain the short-run density.

Short-Run MCMC Samples: Image data experiments will periodically visualize the short-run MCMC samples. A batch of persistent MCMC samples will also be saved for implementations that use persistent initialization for short-run sampling.

Long-Run MCMC Samples: Image data experiments have the option to obtain long-run MCMC samples during training. When log_longrun is set to true in a data config file, the training implementation will generate long-run MCMC samples at a frequency determined by log_longrun_freq. The appearance of long-run MCMC samples indicates whether the energy function assigns probability mass in realistic regions of the image space.

Pre-trained Networks

A convergent pre-trained network and non-convergent pre-trained network for the Oxford Flowers 102 dataset are available in the Releases section of the repository. The config files eval_flowers_convergent.json and eval_flowers_convergent_mh.json are set up to evaluate flowers_convergent_net.pth. The config file eval_flowers_nonconvergent.json is set up to evaluate flowers_nonconvergent_net.pth.

Contact

Please contact Mitch Hill ([email protected]) or Erik Nijkamp ([email protected]) for any questions.

You might also like...
Re-implementation of the Noise Contrastive Estimation algorithm for pyTorch, following "Noise-contrastive estimation: A new estimation principle for unnormalized statistical models." (Gutmann and Hyvarinen, AISTATS 2010)

Noise Contrastive Estimation for pyTorch Overview This repository contains a re-implementation of the Noise Contrastive Estimation algorithm, implemen

ppo_pytorch_cpp - an implementation of the proximal policy optimization algorithm for the C++ API of Pytorch
ppo_pytorch_cpp - an implementation of the proximal policy optimization algorithm for the C++ API of Pytorch

PPO Pytorch C++ This is an implementation of the proximal policy optimization algorithm for the C++ API of Pytorch. It uses a simple TestEnvironment t

PyTorch implementation of DreamerV2 model-based RL algorithm

PyDreamer Reimplementation of DreamerV2 model-based RL algorithm in PyTorch. The official DreamerV2 implementation can be found here. Features ... Run

PyTorch implementation of the implicit Q-learning algorithm (IQL)
PyTorch implementation of the implicit Q-learning algorithm (IQL)

Implicit-Q-Learning (IQL) PyTorch implementation of the implicit Q-learning algorithm IQL (Paper) Currently only implemented for online learning. Offl

PyTorch Implementation of the SuRP algorithm by the authors of the AISTATS 2022 paper "An Information-Theoretic Justification for Model Pruning"

PyTorch Implementation of the SuRP algorithm by the authors of the AISTATS 2022 paper "An Information-Theoretic Justification for Model Pruning".

A pytorch reprelication of the model-based reinforcement learning algorithm MBPO
A pytorch reprelication of the model-based reinforcement learning algorithm MBPO

Overview This is a re-implementation of the model-based RL algorithm MBPO in pytorch as described in the following paper: When to Trust Your Model: Mo

An algorithm that handles large-scale aerial photo co-registration, based on SURF, RANSAC and PyTorch autograd.
An algorithm that handles large-scale aerial photo co-registration, based on SURF, RANSAC and PyTorch autograd.

An algorithm that handles large-scale aerial photo co-registration, based on SURF, RANSAC and PyTorch autograd.

Implements pytorch code for the Accelerated SGD algorithm.

AccSGD This is the code associated with Accelerated SGD algorithm used in the paper On the insufficiency of existing momentum schemes for Stochastic O

PyGAD, a Python 3 library for building the genetic algorithm and training machine learning algorithms (Keras & PyTorch).
PyGAD, a Python 3 library for building the genetic algorithm and training machine learning algorithms (Keras & PyTorch).

PyGAD: Genetic Algorithm in Python PyGAD is an open-source easy-to-use Python 3 library for building the genetic algorithm and optimizing machine lear

Comments
  • Step size in Langevin Dynamics

    Step size in Langevin Dynamics

    Hi, in your code, when you do the langevin dynamics, you run x_s_t.data += - f_prime + config['epsilon'] * t.randn_like(x_s_t) However, does this mean that the step size for the gradient f_prim is 1? Should we run x_s_t.data += - 0.5*config['epsilon']**2*f_prime + config['epsilon'] * t.randn_like(x_s_t) instead?

    opened by XavierXiao 1
Releases(v1.0)
Owner
Mitch Hill
Assistant Professor of Statistics and Data Science at UCF
Mitch Hill
More than a hundred strange attractors

dysts Analyze more than a hundred chaotic systems. Basic Usage Import a model and run a simulation with default initial conditions and parameter value

William Gilpin 185 Dec 23, 2022
GoodNews Everyone! Context driven entity aware captioning for news images

This is the code for a CVPR 2019 paper, called GoodNews Everyone! Context driven entity aware captioning for news images. Enjoy! Model preview: Huge T

117 Dec 19, 2022
Lighting the Darkness in the Deep Learning Era: A Survey, An Online Platform, A New Dataset

Lighting the Darkness in the Deep Learning Era: A Survey, An Online Platform, A New Dataset This repository provides a unified online platform, LoLi-P

Chongyi Li 457 Jan 03, 2023
Cerberus Transformer: Joint Semantic, Affordance and Attribute Parsing

Cerberus Transformer: Joint Semantic, Affordance and Attribute Parsing Paper Introduction Multi-task indoor scene understanding is widely considered a

62 Dec 05, 2022
Text Extraction Formulation + Feedback Loop for state-of-the-art WSD (EMNLP 2021)

ConSeC is a novel approach to Word Sense Disambiguation (WSD), accepted at EMNLP 2021. It frames WSD as a text extraction task and features a feedback loop strategy that allows the disambiguation of

Sapienza NLP group 36 Dec 13, 2022
Keras implementation of Deeplab v3+ with pretrained weights

Keras implementation of Deeplabv3+ This repo is not longer maintained. I won't respond to issues but will merge PR DeepLab is a state-of-art deep lear

1.3k Dec 07, 2022
Revisiting Global Statistics Aggregation for Improving Image Restoration

Revisiting Global Statistics Aggregation for Improving Image Restoration Xiaojie Chu, Liangyu Chen, Chengpeng Chen, Xin Lu Paper: https://arxiv.org/pd

MEGVII Research 128 Dec 24, 2022
GradAttack is a Python library for easy evaluation of privacy risks in public gradients in Federated Learning

GradAttack is a Python library for easy evaluation of privacy risks in public gradients in Federated Learning, as well as corresponding mitigation strategies.

129 Dec 30, 2022
Vision Transformer for 3D medical image registration (Pytorch).

ViT-V-Net: Vision Transformer for Volumetric Medical Image Registration keywords: vision transformer, convolutional neural networks, image registratio

Junyu Chen 192 Dec 20, 2022
Official code for "EagerMOT: 3D Multi-Object Tracking via Sensor Fusion" [ICRA 2021]

EagerMOT: 3D Multi-Object Tracking via Sensor Fusion Read our ICRA 2021 paper here. Check out the 3 minute video for the quick intro or the full prese

Aleksandr Kim 276 Dec 30, 2022
🐦 Quickly annotate data from the comfort of your Jupyter notebook

🐦 pigeon - Quickly annotate data on Jupyter Pigeon is a simple widget that lets you quickly annotate a dataset of unlabeled examples from the comfort

Anastasis Germanidis 647 Jan 05, 2023
Official PyTorch Implementation of Learning Architectures for Binary Networks

Learning Architectures for Binary Networks An Pytorch Implementation of the paper Learning Architectures for Binary Networks (BNAS) (ECCV 2020) If you

Computer Vision Lab. @ GIST 25 Jun 09, 2022
Python scripts for performing lane detection using the LSTR model in ONNX

ONNX LSTR Lane Detection Python scripts for performing lane detection using the Lane Shape Prediction with Transformers (LSTR) model in ONNX. Requirem

Ibai Gorordo 29 Aug 30, 2022
A Python Reconnection Tool for alt:V

altv-reconnect What? It invokes a reconnect in the altV Client Dev Console. You get to determine when your local client should reconnect when developi

8 Jun 30, 2022
FlowTorch is a PyTorch library for learning and sampling from complex probability distributions using a class of methods called Normalizing Flows

FlowTorch is a PyTorch library for learning and sampling from complex probability distributions using a class of methods called Normalizing Flows.

Meta Incubator 272 Jan 02, 2023
DataCLUE: 国内首个以数据为中心的AI测评(含模型分析报告)

DataCLUE: A Benchmark Suite for Data-centric NLP You can get the english version of README. 以数据为中心的AI测评(DataCLUE) 内容导引 章节 描述 简介 介绍以数据为中心的AI测评(DataCLUE

CLUE benchmark 135 Dec 22, 2022
Towards Boosting the Accuracy of Non-Latin Scene Text Recognition

Convolutional Recurrent Neural Network + CTCLoss | STAR-Net Code for paper "Towards Boosting the Accuracy of Non-Latin Scene Text Recognition" Depende

Sanjana Gunna 7 Aug 07, 2022
Flask101 - FullStack Web Development with Python & JS - From TAQWA

Task: Create a CLI Calculator Step 0: Creating Virtual Environment $ python -m

Hossain Foysal 1 May 31, 2022
This is project is the implementation of the DeepShift: Towards Multiplication-Less Neural Networks paper

DeepShift This is project is the implementation of the DeepShift: Towards Multiplication-Less Neural Networks paper, that aims to replace multiplicati

Mostafa Elhoushi 88 Dec 23, 2022
A TensorFlow implementation of FCN-8s

FCN-8s implementation in TensorFlow Contents Overview Examples and demo video Dependencies How to use it Download pre-trained VGG-16 Overview This is

Pierluigi Ferrari 50 Aug 08, 2022