Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions


torch-imle

Concise and self-contained PyTorch library implementing the I-MLE gradient estimator proposed in our NeurIPS 2021 paper Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions.

This repository contains a library for transforming any combinatorial black-box solver into a differentiable layer. All code for reproducing the experiments in the NeurIPS paper is available in the official NEC Laboratories Europe repository.

Overview

Implicit MLE (I-MLE) makes it possible to include discrete combinatorial optimization algorithms, such as Dijkstra's algorithm or integer linear program (ILP) solvers, in standard deep learning architectures. The core idea of I-MLE is that it defines an implicit maximum likelihood objective whose gradients are used to update upstream parameters of the model. Every instance of I-MLE requires two ingredients:

  1. A method to approximately sample from a complex and intractable distribution. For this we use Perturb-and-MAP (aka the Gumbel-max trick) and propose a novel family of noise perturbations tailored to the problem at hand.
  2. A method to compute a surrogate empirical distribution. Vanilla MLE minimises the KL divergence between the empirical distribution and the current model distribution. Since, in our setting, we do not have access to an empirical distribution, we have to design surrogate empirical distributions. Here we propose two families of surrogate distributions which are widely applicable and work well in practice.
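To make the first ingredient concrete, here is a minimal sketch of the Gumbel-max trick, the simplest instance of Perturb-and-MAP: perturbing logits theta with i.i.d. Gumbel(0, 1) noise and taking the argmax (the MAP state) yields exact samples from softmax(theta). The helper names perturb_and_map and softmax, and the use of NumPy, are our own illustrative choices and not part of torch-imle.

```python
import numpy as np

def perturb_and_map(theta: np.ndarray, nb_samples: int, rng: np.random.Generator,
                    temperature: float = 1.0) -> np.ndarray:
    """Draw nb_samples categorical samples via the Gumbel-max trick.

    Each sample is argmax(theta + temperature * g) with g ~ Gumbel(0, 1) i.i.d.;
    here the argmax plays the role of the MAP solver for a trivial one-hot problem.
    """
    noise = rng.gumbel(loc=0.0, scale=1.0, size=(nb_samples, theta.size))
    return np.argmax(theta + temperature * noise, axis=1)

def softmax(theta: np.ndarray) -> np.ndarray:
    e = np.exp(theta - theta.max())  # subtract the max for numerical stability
    return e / e.sum()

rng = np.random.default_rng(0)
theta = np.array([1.0, 0.0, -1.0, 2.0])
samples = perturb_and_map(theta, nb_samples=100_000, rng=rng)
empirical = np.bincount(samples, minlength=theta.size) / samples.size
print(np.round(empirical, 3))       # empirical frequencies of each state
print(np.round(softmax(theta), 3))  # exact softmax probabilities
```

With a temperature tau > 0 the samples follow softmax(theta / tau). For a combinatorial problem, the argmax over one-hot vectors is replaced by a call to the black-box solver.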

Example

For example, let's consider a map from a simple game where the task is to find the shortest path from the top-left to the bottom-right corner. Black areas have the highest and white areas the lowest cost. In the centre, you can see what happens when we use the proposed sum-of-gamma noise distribution to sample paths. On the right, you can see the resulting marginal probabilities for every tile (the probability of each tile being part of a sampled path).
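The sum-of-gamma noise can be sketched as follows. As we understand the construction in the paper, each entry is eps = (1/k) * (sum_{i=1..s} Gamma(1/k, k/i) - log s), with shape 1/k and scale k/i, so that the sum of k independent entries approaches a standard Gumbel variable as s grows; this suits problems whose solutions (such as paths) have roughly k non-zero entries. The names k and nb_iterations mirror the SumOfGammaNoiseDistribution parameters used below, but this NumPy function is our own illustration, not the library's implementation.

```python
import numpy as np

def sum_of_gamma_noise(shape, k: float, nb_iterations: int,
                       rng: np.random.Generator) -> np.ndarray:
    """Sum-of-Gamma noise: (sum_{i=1..s} Gamma(shape=1/k, scale=k/i) - log s) / k."""
    total = np.zeros(shape)
    for i in range(1, nb_iterations + 1):
        total += rng.gamma(shape=1.0 / k, scale=k / i, size=shape)
    return (total - np.log(nb_iterations)) / k

rng = np.random.default_rng(0)
k, s = 5, 100
# Each row holds k independent noise entries; their sum should be close to Gumbel(0, 1)
noise = sum_of_gamma_noise((50_000, k), k=k, nb_iterations=s, rng=rng)
row_sums = noise.sum(axis=1)
print(row_sums.mean(), row_sums.var())  # Gumbel(0, 1): mean ~0.577, variance ~1.645
```

For finite s, the sum of k entries has mean H_s - log s (H_s being the s-th harmonic number) and variance sum_{i=1..s} 1/i^2, both of which converge to the Gumbel(0, 1) values as s grows.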

Gradients and Learning

Let us assume that the optimal shortest path is the one on the left. Starting from random weights, the model can learn, via gradient descent, to produce the weights that result in the optimal shortest path, by minimising the Hamming loss between the produced path and the gold path. Here we show the paths being produced during training (middle), and the corresponding map weights (right).
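The learning loop above can be sketched on a toy problem. Following the paper's perturbation-based target distribution, the I-MLE estimate of the gradient is MAP(theta + eps) - MAP(theta' + eps), with theta' = theta - lambda * dL/dz, up to a scaling factor. To keep the sketch small and self-contained, a hypothetical top-k selection solver stands in for the shortest-path solver; top_k_solver, imle_step, lambda = 10, and the assumption that one noise draw is shared by both solver calls (scaled by the input and target temperatures) are all our own choices, not torch-imle's API.

```python
import numpy as np

def top_k_solver(theta: np.ndarray, k: int) -> np.ndarray:
    """Hypothetical black-box MAP solver: 0/1 mask of the k largest entries of theta."""
    z = np.zeros_like(theta)
    z[np.argsort(-theta)[:k]] = 1.0
    return z

def hamming_loss(z: np.ndarray, z_gold: np.ndarray) -> float:
    return float(np.sum(z * (1 - z_gold) + (1 - z) * z_gold))

def imle_step(theta, z_gold, k, lam=10.0, input_temp=0.0, target_temp=0.0, rng=None):
    """One I-MLE forward/backward pass; returns the sample z and an estimate of dL/dtheta."""
    rng = rng if rng is not None else np.random.default_rng()
    noise = rng.gumbel(size=theta.shape)  # assumption: the same noise draw is reused below
    z = top_k_solver(theta + input_temp * noise, k)  # forward: perturb-and-MAP sample
    dl_dz = 1.0 - 2.0 * z_gold  # gradient of the Hamming loss with respect to z
    theta_target = theta - lam * dl_dz  # parameters of the target distribution
    z_target = top_k_solver(theta_target + target_temp * noise, k)  # MAP state of the target
    return z, z - z_target  # I-MLE gradient estimate (up to scaling)

rng = np.random.default_rng(0)
n, k = 8, 3
z_gold = np.zeros(n)
z_gold[[1, 4, 6]] = 1.0
theta = rng.normal(size=n)  # random initial weights
for step in range(200):
    z, grad = imle_step(theta, z_gold, k, rng=rng)
    theta -= 0.1 * grad  # plain gradient descent on theta
print(hamming_loss(top_k_solver(theta, k), z_gold))
```

With both temperatures at 0.0 the estimate reduces to z - z_gold once the target MAP state matches the gold solution, so the loop moves weights of missed gold items up and weights of wrongly selected items down. Setting input_temp and target_temp above zero makes the updates stochastic, in the spirit of the animations described next.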

Input noise temperature set to 0.0, and target noise temperature set to 0.0:

Input noise temperature set to 1.0, and target noise temperature set to 1.0:

Input noise temperature set to 2.0, and target noise temperature set to 2.0:

Input noise temperature set to 5.0, and target noise temperature set to 5.0:

Input noise temperature set to 5.0, and target noise temperature set to 0.0:

All animations were generated by this script.
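To give a rough feel for what the noise temperature controls, the sketch below estimates, by Monte Carlo sampling, the marginal probability that each item is part of a perturbed MAP solution, as in the per-tile marginals mentioned earlier. It uses Gumbel noise and a hypothetical top-k solver rather than sum-of-gamma noise and a path solver, and it only illustrates the input side; estimate_marginals is our own helper name.

```python
import numpy as np

def top_k_solver(theta: np.ndarray, k: int) -> np.ndarray:
    """Hypothetical black-box MAP solver: 0/1 mask of the k largest entries of theta."""
    z = np.zeros_like(theta)
    z[np.argsort(-theta)[:k]] = 1.0
    return z

def estimate_marginals(theta, k, temperature, nb_samples, rng):
    """Monte Carlo estimate of P(z_i = 1) under perturb-and-MAP with Gumbel noise."""
    total = np.zeros_like(theta)
    for _ in range(nb_samples):
        noise = rng.gumbel(size=theta.shape)
        total += top_k_solver(theta + temperature * noise, k)
    return total / nb_samples

rng = np.random.default_rng(0)
theta = np.array([3.0, 2.0, 1.0, 0.0, -1.0])
for temperature in (0.0, 1.0, 5.0):
    marginals = estimate_marginals(theta, k=2, temperature=temperature, nb_samples=5000, rng=rng)
    print(temperature, np.round(marginals, 2))
```

At temperature 0.0 every sample is the unperturbed MAP state, so the marginals are exactly 0 or 1; as the temperature grows, they spread out towards the uniform value k / n, while always summing to k.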

Code

Using this library is simple; see this example as a reference. Assuming we have a function solver that implements a black-box combinatorial solver, such as Dijkstra's algorithm, for a single problem instance, we first wrap it so that it operates on batches of PyTorch tensors:

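import numpy as np

import torch
from torch import Tensor

# solver is your own black-box combinatorial solver: it takes the weights of a
# single instance as a NumPy array and returns its solution (e.g. a 0/1 path mask).
def torch_solver(weights_batch: Tensor) -> Tensor:
    # The solver is not differentiable, so detach and move the weights to NumPy
    weights_batch = weights_batch.detach().cpu().numpy()
    # Solve every instance in the batch independently
    y_batch = np.asarray([solver(w) for w in list(weights_batch)])
    return torch.tensor(y_batch, requires_grad=False)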
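The wrapper above assumes a solver function that maps one weight matrix to one solution. As a hedged illustration of such a function, the sketch below runs Dijkstra's algorithm (via heapq) from the top-left to the bottom-right tile of a grid, paying the cost of every tile it visits. The 4-connected moves, the tile-cost convention and the 0/1 output mask are our assumptions and may differ from the setup used in the paper's experiments; Dijkstra also requires non-negative costs.

```python
import heapq

import numpy as np

def solver(weights: np.ndarray) -> np.ndarray:
    """Hypothetical grid solver: cheapest top-left to bottom-right path (4-connected).

    A path pays weights[r, c] for every tile it visits. Returns a 0/1 mask of path tiles.
    """
    n_rows, n_cols = weights.shape
    dist = np.full(weights.shape, np.inf)
    prev = {}
    dist[0, 0] = weights[0, 0]
    heap = [(weights[0, 0], 0, 0)]
    while heap:
        d, r, c = heapq.heappop(heap)
        if d > dist[r, c]:
            continue  # stale queue entry
        if (r, c) == (n_rows - 1, n_cols - 1):
            break  # target reached with its final distance
        for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            nr, nc = r + dr, c + dc
            if 0 <= nr < n_rows and 0 <= nc < n_cols:
                nd = d + weights[nr, nc]
                if nd < dist[nr, nc]:
                    dist[nr, nc] = nd
                    prev[(nr, nc)] = (r, c)
                    heapq.heappush(heap, (nd, nr, nc))
    # Walk back from the target to mark the tiles on the path
    path = np.zeros(weights.shape)
    node = (n_rows - 1, n_cols - 1)
    path[node] = 1.0
    while node != (0, 0):
        node = prev[node]
        path[node] = 1.0
    return path

weights = np.array([[1, 9, 1], [1, 9, 1], [1, 1, 1]], dtype=float)
print(solver(weights))
```

Because it returns one mask per weight matrix, this function can be dropped into the list comprehension inside torch_solver unchanged.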

We can obtain the corresponding distribution and gradients in this way:

from imle.wrapper import imle
from imle.target import TargetDistribution
from imle.noise import SumOfGammaNoiseDistribution

# Placeholder hyperparameters: tune them for your problem
k = 10  # parameter of the Sum-of-Gamma noise (see the paper)
input_noise_temperature = 1.0
target_noise_temperature = 1.0

target_distribution = TargetDistribution(alpha=0.0, beta=10.0)
noise_distribution = SumOfGammaNoiseDistribution(k=k, nb_iterations=100)

# torch_solver is the batched wrapper defined above
imle_solver = imle(torch_solver,
                   target_distribution=target_distribution,
                   noise_distribution=noise_distribution,
                   nb_samples=10,
                   input_noise_temperature=input_noise_temperature,
                   target_noise_temperature=target_noise_temperature)

Alternatively, the same solver can be defined using a function decorator:

@imle(target_distribution=target_distribution,
      noise_distribution=noise_distribution,
      nb_samples=10,
      input_noise_temperature=input_noise_temperature,
      target_noise_temperature=target_noise_temperature)
def imle_solver(weights_batch: Tensor) -> Tensor:
    return torch_solver(weights_batch)

Papers using I-MLE

Reference

@inproceedings{niepert21imle,
  author    = {Mathias Niepert and
               Pasquale Minervini and
               Luca Franceschi},
  title     = {Implicit {MLE:} Backpropagating Through Discrete Exponential Family
               Distributions},
  booktitle = {NeurIPS},
  year      = {2021}
}
Owner
UCL Natural Language Processing