Differentiable Optimizers with Perturbations in PyTorch

This repository contains a PyTorch implementation of Differentiable Optimizers with Perturbations in TensorFlow. All credit belongs to the original authors, who can be found in the references below. The source code, tests, and examples given below are a one-to-one port of the original work, but with pure PyTorch implementations.

Overview

We propose in this work a universal method to transform any optimizer into a differentiable approximation. We provide a PyTorch implementation, illustrated here on some examples.
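
As a rough sketch of the underlying idea (a minimal illustration of the forward pass only, using Gaussian noise for simplicity; this is not the library's actual implementation), a perturbed optimizer replaces a discrete solver with a Monte Carlo average of the solver applied to noisy copies of its input:

import torch

def perturbed_forward(optimizer_fn, theta, num_samples=1000, sigma=0.5):
    # Draw noisy copies of the input and average the optimizer outputs.
    # The library also supports Gumbel noise, and additionally provides
    # the gradient of this expectation in the backward pass.
    noise = torch.randn((num_samples,) + theta.shape, device=theta.device)
    perturbed_inputs = theta.unsqueeze(0) + sigma * noise
    return optimizer_fn(perturbed_inputs).mean(dim=0)

The examples below use the library's perturbations.perturbed wrapper, which implements both this forward pass and the corresponding gradient estimation.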

Perturbed argmax

We start from an original optimizer, an argmax function, computed on an example input theta.

import torch
import torch.nn.functional as F
import perturbations

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def argmax(x, axis=-1):
    # One-hot encoding of the largest entry along the given axis.
    return F.one_hot(torch.argmax(x, dim=axis), x.shape[axis]).float()

This function returns a one-hot vector corresponding to the largest input entry.

>>> argmax(torch.tensor([-0.6, 1.9, -0.2, 1.1, -1.0]))
tensor([0., 1., 0., 0., 0.])

It is possible to modify the function by creating a perturbed optimizer, using Gumbel noise.

pert_argmax = perturbations.perturbed(argmax,
                                      num_samples=1000000,
                                      sigma=0.5,
                                      noise='gumbel',
                                      batched=False,
                                      device=device)
>>> theta = torch.tensor([-0.6, 1.9, -0.2, 1.1, -1.0], device=device)
>>> pert_argmax(theta)
tensor([0.0055, 0.8150, 0.0122, 0.1648, 0.0025], device='cuda:0')

In this particular case, the perturbed argmax with Gumbel noise is equal (in expectation) to the usual softmax with exponential weights.

>>> sigma = 0.5
>>> F.softmax(theta/sigma, dim=-1)
tensor([0.0055, 0.8152, 0.0122, 0.1646, 0.0025], device='cuda:0')
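
As an illustrative sanity check (not part of the library), the sampled output can be compared to the softmax up to Monte Carlo error:

expected = F.softmax(theta / sigma, dim=-1)
sampled = pert_argmax(theta)
# With 10^6 samples the Monte Carlo error is small, so a loose tolerance suffices.
assert torch.allclose(sampled, expected, atol=1e-2)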

Batched version

The original function can accept a batch dimension, in which case it is applied to every element of the batch.

theta_batch = torch.tensor([[-0.6, 1.9, -0.2, 1.1, -1.0],
                            [-0.6, 1.0, -0.2, 1.8, -1.0]], device=device, requires_grad=True)
>>> argmax(theta_batch)
tensor([[0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0.]], device='cuda:0')

Likewise, if the argument batched is set to True (its default value), the perturbed optimizer can handle a batch of inputs.

pert_argmax = perturbations.perturbed(argmax,
                                      num_samples=1000000,
                                      sigma=0.5,
                                      noise='gumbel',
                                      batched=True,
                                      device=device)
>>> pert_argmax(theta_batch)
tensor([[0.0055, 0.8158, 0.0122, 0.1640, 0.0025],
        [0.0066, 0.1637, 0.0147, 0.8121, 0.0030]], device='cuda:0')

It can be compared to its deterministic version, the softmax.

>>> F.softmax(theta_batch/sigma, dim=-1)
tensor([[0.0055, 0.8152, 0.0122, 0.1646, 0.0025],
        [0.0067, 0.1639, 0.0149, 0.8116, 0.0030]], device='cuda:0')

Decorator version

It is also possible to use the perturbed function as a decorator.

@perturbations.perturbed(num_samples=1000000, sigma=0.5, noise='gumbel', batched=True, device=device)
def argmax(x, axis=-1):
    return F.one_hot(torch.argmax(x, dim=axis), x.shape[axis]).float()
>>> argmax(theta_batch)
tensor([[0.0054, 0.8148, 0.0121, 0.1652, 0.0024],
        [0.0067, 0.1639, 0.0148, 0.8116, 0.0029]], device='cuda:0')

Gradient computation

Perturbed optimizers are differentiable, and their gradients can be computed automatically via stochastic estimation. In this case, the result can be compared directly to the gradient of the softmax.

output = pert_argmax(theta_batch)
square_norm = torch.linalg.norm(output)
square_norm.backward(torch.ones_like(square_norm))
grad_pert = theta_batch.grad
>>> grad_pert
tensor([[-0.0072,  0.1708, -0.0132, -0.1476, -0.0033],
        [-0.0068, -0.1464, -0.0173,  0.1748, -0.0046]], device='cuda:0')

This can be compared to the same computation with a softmax.

theta_batch.grad = None  # clear the gradient accumulated by the previous backward pass
output = F.softmax(theta_batch/sigma, dim=-1)
square_norm = torch.linalg.norm(output)
square_norm.backward(torch.ones_like(square_norm))
grad_soft = theta_batch.grad
>>> grad_soft
tensor([[-0.0064,  0.1714, -0.0142, -0.1479, -0.0029],
        [-0.0077, -0.1457, -0.0170,  0.1739, -0.0035]], device='cuda:0')
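
For intuition, in the Gaussian case the Jacobian of the perturbed optimizer admits a simple Monte Carlo estimator: the gradient of E[y(theta + sigma Z)] with respect to theta is E[y(theta + sigma Z) Z^T] / sigma. A minimal sketch for a single (non-batched) input, assuming Gaussian noise rather than the Gumbel noise used above:

def perturbed_jacobian_estimate(optimizer_fn, theta, num_samples=100000, sigma=0.5):
    # Estimate the Jacobian of the Gaussian-perturbed optimizer as the
    # sample mean of y(theta + sigma * z) z^T, divided by sigma.
    z = torch.randn((num_samples,) + theta.shape, device=theta.device)
    y = optimizer_fn(theta.unsqueeze(0) + sigma * z)  # (num_samples, dim)
    return torch.einsum('ni,nj->ij', y, z) / (num_samples * sigma)

The library performs this estimation automatically in the backward pass, with the appropriate form for each supported noise distribution.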

Perturbed OR

The OR function over the signs of the inputs is another example of an optimizer, and offers an easily interpretable visualization.

def hard_or(x):
    # Map signs to booleans (positive -> True), OR along the last axis,
    # then map the boolean result back to {-1, 1}.
    s = ((torch.sign(x) + 1) / 2.0).type(torch.bool)
    result = torch.any(s, dim=-1)
    return result.type(torch.float) * 2.0 - 1

In the following batch of two inputs, both instances are evaluated as True (value 1).

theta = torch.tensor([[-5., 0.2],
                      [-5., 0.1]], device=device)
>>> hard_or(theta)
tensor([1., 1.])

Computing a perturbed OR operator over 1000 samples shows the difference in value for these two inputs.

pert_or = perturbations.perturbed(hard_or,
                                  num_samples=1000,
                                  sigma=0.1,
                                  noise='gumbel',
                                  batched=True,
                                  device=device)
>>> pert_or(theta)
tensor([1.0000, 0.8540], device='cuda:0')

This can be visualized more broadly, for values between -1 and 1, along with the evaluated values of the gradient.
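
A sketch of how such a sweep could be computed (the grid construction here is illustrative, and assumes the perturbed operator handles batched inputs and gradients as in the examples above):

# Evaluate the perturbed OR and its gradient on a grid over [-1, 1]^2.
xs = torch.linspace(-1.0, 1.0, steps=51, device=device)
grid = torch.cartesian_prod(xs, xs).requires_grad_(True)  # batch of 2D inputs
values = pert_or(grid)
values.backward(torch.ones_like(values))
grads = grid.grad  # gradient of the perturbed OR at every grid point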

Perturbed shortest path

This framework can also be applied easily to more complex optimizers, such as a black-box shortest-path solver (here, the function shortest_path). We consider a small example on 9 nodes, illustrated here with the shortest path between nodes 0 and 8 in bold, and edge cost labels.

We also consider a function of the perturbed solution: the weight of this solution on the edge between nodes 6 and 8.

A gradient of this function with respect to a vector of four edge costs (top-rightmost, between nodes 4, 5, 6, and 8) is automatically computed. This can be used to increase the weight on this edge of the solution by changing these four costs. This is challenging to do with first-order methods using only an original optimizer, as its gradient would be zero almost everywhere.

final_edges_costs = torch.tensor([0.4, 0.1, 0.1, 0.1], device=device, requires_grad=True)
weights = edge_costs_to_weights(final_edges_costs)

@perturbations.perturbed(num_samples=100000, sigma=0.05, batched=False, device=device)
def perturbed_shortest_path(weights):
    return shortest_path(weights, symmetric=False)

We obtain a perturbed solution to the shortest path problem on this graph, an average of solutions under perturbations on the weights.

>>> perturbed_shortest_path(weights)
tensor([[0.    0.    0.001 0.025 0.    0.    0.    0.    0.   ]
        [0.    0.    0.    0.    0.023 0.    0.    0.    0.   ]
        [0.679 0.    0.    0.119 0.    0.    0.    0.    0.   ]
        [0.304 0.    0.    0.    0.    0.    0.    0.    0.   ]
        [0.    0.023 0.    0.    0.    0.898 0.    0.    0.   ]
        [0.    0.    0.001 0.    0.    0.    0.896 0.    0.   ]
        [0.    0.    0.    0.    0.    0.001 0.    0.974 0.   ]
        [0.    0.    0.797 0.178 0.    0.    0.    0.    0.   ]
        [0.    0.    0.    0.    0.921 0.    0.079 0.    0.   ]])

For illustration, this solution can be represented with edge width proportional to the weight of the solution.

We consider an example of scalar function on this solution, here the weight of the perturbed solution on the edge from node 6 to 8 (of current value 0.079).

def i_to_j_weight_fn(i, j, paths):
    # Weight of the (batch of) solutions on the edge between nodes i and j.
    return paths[..., i, j]

weights = edge_costs_to_weights(final_edges_costs)
pert_paths = perturbed_shortest_path(weights)
i_to_j_weight = pert_paths[..., 8, 6]  # weight on the edge between nodes 6 and 8
i_to_j_weight.backward(torch.ones_like(i_to_j_weight))
grad = final_edges_costs.grad

This provides a direction in which to modify the vector of four edge costs in order to increase the weight of the solution on this edge, obtained thanks to our perturbed version of the optimizer.

>>> grad
tensor([-2.0993764,  2.076386 ,  2.042395 ,  2.0411625], device='cuda:0')

Running gradient ascent for 30 steps on this vector of four edge costs to increase the weight of the edge from 6 to 8 modifies the problem (see the sketch of the ascent loop below). Its new perturbed solution has a corresponding edge weight of 0.989, and the new problem and its perturbed solution can be visualized in the same manner.
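
A possible sketch of that ascent loop (the step size and loop structure are illustrative, not taken from the original):

costs = final_edges_costs.detach().clone().requires_grad_(True)
for _ in range(30):
    weights = edge_costs_to_weights(costs)
    pert_paths = perturbed_shortest_path(weights)
    edge_weight = pert_paths[..., 8, 6]  # weight on the edge between nodes 6 and 8
    edge_weight.backward(torch.ones_like(edge_weight))
    with torch.no_grad():
        costs += 0.1 * costs.grad  # gradient ascent step; the rate is illustrative
        costs.grad.zero_()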

References

Berthet Q., Blondel M., Teboul O., Cuturi M., Vert J.-P., Bach F., Learning with Differentiable Perturbed Optimizers, NeurIPS 2020

License

Please see the original repository for license details.
