Second-Order Neural ODE Optimizer, NeurIPS 2021 spotlight

✔️ faster convergence in wall-clock time | ✔️ O(1) memory cost |
✔️ better test-time performance | ✔️ architecture co-optimization

This repo provides PyTorch code for the Second-order Neural ODE Optimizer (SNOpt), a second-order optimizer for training Neural ODEs that retains O(1) memory cost while improving convergence and test-time performance.

Installation

This code was developed with Python 3. It requires PyTorch >= 1.7 (we recommend 1.8.1) and torchdiffeq >= 0.2.0.

  1. Install the dependencies with Anaconda and activate the environment snopt with
    conda env create --file requirements.yaml python=3
    conda activate snopt
  2. [Optional] This repo provides a modification (with 15 lines!) of torchdiffeq that allows SNOpt to collect 2nd-order information during adjoint-based training. If you wish to run torchdiffeq on another commit, copy that folder into this directory, then apply the provided snopt_integration.patch.
    cp -r <path_to_your_torchdiffeq_folder> .
    git apply snopt_integration.patch
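
After installation, a quick check that the environment meets the version requirements stated above can save debugging time. The sketch below is not part of this repo; the helper names (parse_version, meets_minimum, report) are hypothetical, and it only reads package metadata through the standard library.

```python
import re
from importlib import metadata

# Minimum versions taken from the requirements stated in this README.
MINIMUMS = {"torch": (1, 7), "torchdiffeq": (0, 2, 0)}


def parse_version(text):
    """Return the leading numeric components, e.g. '1.8.1+cu111' -> (1, 8, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    return tuple(int(part) for part in match.group(0).split(".")) if match else ()


def meets_minimum(installed, minimum):
    """Compare versions component-wise, padding the shorter one with zeros."""
    have, need = parse_version(installed), tuple(minimum)
    width = max(len(have), len(need))
    have += (0,) * (width - len(have))
    need += (0,) * (width - len(need))
    return have >= need


def report():
    """Print one status line per required package."""
    for package, minimum in MINIMUMS.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            print(f"{package}: not installed")
            continue
        status = "ok" if meets_minimum(installed, minimum) else "too old"
        print(f"{package} {installed}: {status}")


if __name__ == "__main__":
    report()
```

Run it inside the activated snopt environment; a "too old" or "not installed" line points to a dependency that needs attention.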

Run the code

We provide example code for 8 datasets across image classification (main_img_clf.py), time-series prediction (main_time_series.py), and continuous normalizing flows (main_cnf.py). The commands for generating results similar to those in our paper are detailed in the scripts folder. Datasets are downloaded automatically to the data folder on the first call, and all results are saved to the result folder.

bash scripts/run_img_clf.sh     <dataset> # dataset can be {mnist, svhn, cifar10}
bash scripts/run_time_series.sh <dataset> # dataset can be {char-traj, art-wr, spo-ad}
bash scripts/run_cnf.sh         <dataset> # dataset can be {miniboone, gas}

For architecture (specifically integration time) co-optimization, run

bash scripts/run_img_clf.sh cifar10-t1-optimize
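
To sweep all eight datasets in one go, the commands above can be generated programmatically. The following is a minimal sketch, not a script shipped with this repo: build_commands and run_all are hypothetical helpers, and the script paths and dataset names are copied from the commands listed above. By default it only prints the commands (dry run).

```python
import subprocess

# Script -> dataset names, as listed in the commands above.
SCRIPTS = {
    "scripts/run_img_clf.sh": ["mnist", "svhn", "cifar10"],
    "scripts/run_time_series.sh": ["char-traj", "art-wr", "spo-ad"],
    "scripts/run_cnf.sh": ["miniboone", "gas"],
}


def build_commands(include_t1_optimize=False):
    """Return one argument list per (script, dataset) pair."""
    commands = [
        ["bash", script, dataset]
        for script, datasets in SCRIPTS.items()
        for dataset in datasets
    ]
    if include_t1_optimize:
        # Integration-time co-optimization run mentioned above.
        commands.append(["bash", "scripts/run_img_clf.sh", "cifar10-t1-optimize"])
    return commands


def run_all(dry_run=True, include_t1_optimize=False):
    """Print every command; execute it as well when dry_run is False."""
    for command in build_commands(include_t1_optimize):
        print(" ".join(command))
        if not dry_run:
            subprocess.run(command, check=True)


if __name__ == "__main__":
    run_all(dry_run=True)
```

Call run_all(dry_run=False) from the repo root to actually launch the runs sequentially.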

Integration with your workflow

snopt integrates easily with an existing training workflow. Below we provide a checklist and pseudo-code to help with integration. For more complex examples, please refer to main_*.py in this repo.

  • Import a torchdiffeq that has been patched with the snopt integration, or simply use the torchdiffeq in this repo.
  • Subclass snopt.ODEFuncBase for your vector field; implement the forward pass in F rather than forward.
  • Create the Neural ODE with ODE layer(s) using snopt.ODEBlock; implement the properties odes and ode_mods.
  • Initialize snopt.SNOpt as a preconditioner; call train_itr_setup() before the standard optim.zero_grad() and step() before optim.step() (see the code below).
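  • That's it 🤓 ! Enjoy your second-order training 🚂 🚅 !

import torch
from torchdiffeq import odeint_adjoint as odesolve
from snopt import SNOpt, ODEFuncBase, ODEBlock
from easydict import EasyDict as dict  # note: shadows the builtin dict in this script

input_dim = 2  # placeholder feature dimension; set this to match your data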

class ODEFunc(ODEFuncBase):
    def __init__(self, opt):
        super(ODEFunc, self).__init__(opt)
        self.linear = torch.nn.Linear(input_dim, input_dim)

    def F(self, t, z):
        return self.linear(z)

class NeuralODE(torch.nn.Module):
    def __init__(self, ode):
        super(NeuralODE, self).__init__()
        self.ode = ode

    def forward(self, z):
        return self.ode(z)

    @property
    def odes(self): # in case we have multiple odes, collect them in a list
        return [self.ode]

    @property
    def ode_mods(self): # modules of all ode(s)
        return [mod for mod in self.ode.odefunc.modules()]

# Create Neural ODE
opt = dict(
    optimizer='SNOpt', tol=1e-3, ode_solver='dopri5',
    use_adaptive_t1=False, snopt_step_size=0.01)
odefunc = ODEFunc(opt)
integration_time = torch.tensor([0.0, 1.0]).float()
ode = ODEBlock(opt, odefunc, odesolve, integration_time)
net = NeuralODE(ode)

# Create SNOpt optimizer
precond = SNOpt(net, eps=0.05, update_freq=100)
optim = torch.optim.SGD(net.parameters(), lr=0.001)

# Training loop (training_loader and loss_function are placeholders
# for your own DataLoader and loss)
for x, y in training_loader:
    precond.train_itr_setup() # <--- additional step for precond
    optim.zero_grad()

    loss = loss_function(net(x), y)
    loss.backward()

    # Run SNOpt optimizer
    precond.step()            # <--- additional step for precond
    optim.step()
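
To see why the preconditioner step comes before optim.step() in the loop above, the following toy sketch reproduces the same call order without torch or snopt. Everything in it (ToyParam, ToyPreconditioner, ToySGD, the quadratic objective, the chosen learning rates) is a hypothetical illustration and not the SNOpt implementation: the preconditioner rewrites the gradients in place, and the first-order optimizer then consumes the rescaled gradients.

```python
class ToyParam:
    """A scalar parameter with a gradient slot."""
    def __init__(self, value):
        self.value = value
        self.grad = 0.0


class ToyPreconditioner:
    """Hypothetical stand-in for a preconditioner: rescales gradients in place."""
    def __init__(self, params, curvature, eps):
        self.params, self.curvature, self.eps = params, curvature, eps

    def step(self):
        for p, h in zip(self.params, self.curvature):
            p.grad = p.grad / (h + self.eps)  # damped inverse-curvature scaling


class ToySGD:
    """Plain gradient descent on ToyParam objects."""
    def __init__(self, params, lr):
        self.params, self.lr = params, lr

    def zero_grad(self):
        for p in self.params:
            p.grad = 0.0

    def step(self):
        for p in self.params:
            p.value -= self.lr * p.grad


def train(use_precond, lr, steps=50, curvature=(100.0, 1.0)):
    """Minimize f(w) = 0.5 * sum(h_i * w_i**2); return the final |w_i|."""
    params = [ToyParam(1.0) for _ in curvature]
    precond = ToyPreconditioner(params, curvature, eps=0.05)
    optim = ToySGD(params, lr=lr)
    for _ in range(steps):
        optim.zero_grad()
        for p, h in zip(params, curvature):
            p.grad = h * p.value  # analytic gradient, in place of loss.backward()
        if use_precond:
            precond.step()        # rewrite the gradients first ...
        optim.step()              # ... then take the usual optimizer step
    return [abs(p.value) for p in params]


if __name__ == "__main__":
    print("plain SGD     :", train(use_precond=False, lr=0.015))
    print("preconditioned:", train(use_precond=True, lr=0.5))
```

On this ill-conditioned quadratic, plain gradient descent must keep its learning rate below 0.02 to stay stable along the stiff direction, so the flat direction converges slowly. After rescaling, both directions tolerate a much larger step and converge at a similar rate.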

What the library actually contains

This snopt library implements the following objects for efficient 2nd-order adjoint-based training of Neural ODEs.

  • ODEFuncBase: Defines the vector field (inherits torch.nn.Module) of a Neural ODE.
  • CNFFuncBase: Serves the same purpose as ODEFuncBase, but for CNF applications.
  • ODEBlock: A Neural-ODE module (torch.nn.Module) that solves the initial value problem (given the vector field, integration time, and an ODE solver) and handles integration time co-optimization with a feedback policy.
  • SNOpt: Our primary 2nd-order optimizer (torch.optim.Optimizer), implemented as a "preconditioner" (see example code above). It takes the following arguments.
    • net is the Neural ODE. Note that the entire network (rather than net.parameters()) is required.
    • eps is the regularization that stabilizes preconditioning. We recommend a value in [0.05, 0.1].
    • update_freq is the frequency at which the 2nd-order information is refreshed. We recommend a value of 100~200.
    • alpha controls the running average of the eigenvalues. We recommend fixing the value to 0.75.
    • full_precond determines whether layers outside the Neural ODEs are also preconditioned.
  • SNOptAdjointCollector: A helper to collect information from torchdiffeq to construct 2nd-order matrices.
  • IntegrationTimeOptimizer: Our 2nd-order method that co-optimizes the integration time (i.e., t1). This is done by calling t1_train_itr_setup(train_it) and update_t1() together with optim.zero_grad() and optim.step() (see trainer.py).
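
The roles of eps, alpha, and update_freq can be illustrated with a small diagonal toy. This is a sketch under stated assumptions and not the library's code: it assumes the gradient is already expressed in the eigenbasis of the curvature matrix, that the running average is an exponential moving average weighted by alpha on the old value, and that eps is added to each eigenvalue before inversion. The function names are hypothetical.

```python
def ema_update(old, fresh, alpha=0.75):
    """Running average of eigenvalues (assumed form: alpha * old + (1 - alpha) * fresh)."""
    return [alpha * o + (1.0 - alpha) * f for o, f in zip(old, fresh)]


def precondition(grad, eigenvalues, eps=0.05):
    """Scale each gradient component by 1 / (eigenvalue + eps).

    eps keeps the scaling bounded when an eigenvalue is close to zero.
    """
    return [g / (lam + eps) for g, lam in zip(grad, eigenvalues)]


def should_refresh(iteration, update_freq=100):
    """Refresh the second-order statistics only every update_freq iterations."""
    return iteration % update_freq == 0


if __name__ == "__main__":
    eigenvalues = ema_update([4.0, 1.0], [2.0, 3.0], alpha=0.75)
    print(eigenvalues)
    print(precondition([1.0, 1.0], eigenvalues, eps=0.5))
    print([it for it in range(301) if should_refresh(it)])
```

A larger eps damps the preconditioner toward plain gradient scaling, while a larger update_freq lowers the overhead of refreshing the curvature statistics at the price of staler estimates.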

The options are passed in as opt, which contains the following fields (see options.py for full descriptions).

  • optimizer is the training method. Use "SNOpt" to enable our method.
  • ode_solver specifies the ODE solver (default is "dopri5") with the absolute/relative tolerance tol.
  • For CNF applications, use divergence_type to specify how divergence should be computed.
  • snopt_step_size determines the step size at which SNOpt samples along the integration to compute 2nd-order matrices. We recommend a value of 0.01 for integration time [0,1], which yields around 100 sampled points.
  • For integration time (t1) co-optimization, enable the flag use_adaptive_t1 and set up the following options.
    • adaptive_t1 specifies the t1 optimization method. Choices are "baseline" and "feedback" (ours).
    • t1_lr is the learning rate. We recommend a value in [0.05, 0.1].
    • t1_reg is the coefficient of the quadratic penalty imposed on t1. The performance is quite sensitive to this value. We recommend a value in [1e-4, 1e-3].
    • t1_update_freq is the frequency at which t1 is updated. We recommend a value of 50~100.
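
The relationship between snopt_step_size and the number of sampled points is simple arithmetic, sketched below. The helper sample_times is hypothetical, and whether the library includes both endpoints of the interval is an assumption made here for illustration.

```python
def sample_times(t0, t1, step_size):
    """Evenly spaced sample times on [t0, t1], spaced roughly step_size apart.

    Both endpoints are included here, which is an assumption of this sketch.
    """
    n_intervals = max(1, round((t1 - t0) / step_size))
    return [t0 + i * (t1 - t0) / n_intervals for i in range(n_intervals + 1)]


if __name__ == "__main__":
    times = sample_times(0.0, 1.0, 0.01)
    print(len(times), times[0], times[-1])
```

With the recommended step size of 0.01 on [0,1] this gives 100 intervals. If t1 is co-optimized and grows or shrinks, the number of sampled points at a fixed step size changes proportionally.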

Remarks & Citation

The current library only supports adjoint-based training, but it can be extended to the standard odeint method (stay tuned!). The pre-processing of the tabular and UEA datasets is adopted from ffjord and NeuralCDE, and the eigenvalue-regularized preconditioning is adopted from EKFAC-pytorch.

If you find this library useful, please cite ⬇️ . Contact me ([email protected]) if you have any questions!

@inproceedings{liu2021second,
  title={Second-order Neural ODE Optimizer},
  author={Liu, Guan-Horng and Chen, Tianrong and Theodorou, Evangelos A},
  booktitle={Advances in Neural Information Processing Systems},
  year={2021},
}