PyTorch implementation of Neural Combinatorial Optimization with Reinforcement Learning.

Overview

neural-combinatorial-rl-pytorch

I have implemented the basic RL pretraining model with greedy decoding from the paper. An implementation of the supervised learning baseline model is available here. The TSP results below were obtained with an exponential moving average critic instead of a critic network; the critic network is currently commented out in the code. Correspondence with a few others indicated that the exponential moving average critic significantly improves results.
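
To make the exponential moving average critic concrete, here is a minimal sketch in plain Python. It is an illustration under stated assumptions, not the repository's code: the class name `EMABaseline`, the default decay `beta=0.8`, and initializing the baseline from the first batch mean are all hypothetical choices.

```python
class EMABaseline:
    """Hypothetical helper: exponential moving average of batch rewards."""

    def __init__(self, beta=0.8):
        self.beta = beta    # assumed decay rate, not taken from the repository
        self.value = None   # no baseline exists until the first batch is seen

    def update(self, rewards):
        """Fold the mean reward of one batch into the running baseline."""
        batch_mean = sum(rewards) / len(rewards)
        if self.value is None:
            # Assumption: start the average at the first batch mean.
            self.value = batch_mean
        else:
            self.value = self.beta * self.value + (1.0 - self.beta) * batch_mean
        return self.value

    def advantages(self, rewards):
        """Reward minus baseline; in REINFORCE this term scales the log-probability gradient."""
        return [r - self.value for r in rewards]


baseline = EMABaseline(beta=0.5)
baseline.update([4.0, 6.0])                # baseline starts at 5.0
baseline.update([1.0, 3.0])                # 0.5 * 5.0 + 0.5 * 2.0 = 3.5
print(baseline.advantages([1.0, 3.0]))     # [-2.5, -0.5]
```

Unlike a learned critic network, this baseline has no parameters to train; it only smooths the batch mean over time.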

During training, the pointer network uses a stochastic decoding policy, implemented with PyTorch's torch.multinomial(). At test time, it decodes with beam search, which is not yet finished and currently supports only 1 beam (i.e., greedy decoding).
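
The difference between stochastic decoding during training and greedy decoding at test time can be sketched as follows. This is not the repository's code: it uses the standard library's `random.choices` as a stand-in for torch.multinomial() so that it runs without PyTorch, the function names `masked_sample`, `greedy_pick`, and `decode` are hypothetical, and restricting each choice to unvisited indices is an assumption made so that the output is a permutation. For simplicity it reuses one fixed probability vector at every step, whereas a pointer network would recompute the distribution after each selection.

```python
import random


def masked_sample(probs, visited, rng):
    """Sample one unvisited index, weighted by its probability."""
    candidates = [i for i in range(len(probs)) if not visited[i]]
    weights = [probs[i] for i in candidates]
    # Stand-in for zeroing visited entries and calling torch.multinomial(probs, 1).
    return rng.choices(candidates, weights=weights, k=1)[0]


def greedy_pick(probs, visited):
    """Greedy counterpart: the most probable unvisited index."""
    candidates = [i for i in range(len(probs)) if not visited[i]]
    return max(candidates, key=lambda i: probs[i])


def decode(probs, stochastic, seed=0):
    """Decode a full permutation of range(len(probs)), one index per step."""
    rng = random.Random(seed)
    visited = [False] * len(probs)
    order = []
    for _ in range(len(probs)):
        if stochastic:
            idx = masked_sample(probs, visited, rng)
        else:
            idx = greedy_pick(probs, visited)
        visited[idx] = True
        order.append(idx)
    return order


print(decode([0.1, 0.5, 0.4], stochastic=False))  # [1, 2, 0]
print(decode([0.1, 0.5, 0.4], stochastic=True))   # some permutation of 0, 1, 2
```

Sampling lets the policy explore different solutions during training, while greedy decoding returns a single deterministic solution for evaluation.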

Currently, there is support for a sorting task and the planar symmetric Euclidean TSP.
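
For the Euclidean TSP, the quantity being optimized is the length of the closed tour through all cities. The sketch below computes it in plain Python; the function name `tour_length` and the list-of-tuples input format are assumptions for illustration and do not reflect how tsp_task.py represents its data.

```python
import math


def tour_length(points, order):
    """Length of the closed tour visiting points in the given order.

    points: list of (x, y) coordinates; order: a permutation of their indices.
    The tour returns from the last city to the first.
    """
    total = 0.0
    for i in range(len(order)):
        x1, y1 = points[order[i]]
        x2, y2 = points[order[(i + 1) % len(order)]]  # wrap around to close the tour
        total += math.hypot(x2 - x1, y2 - y1)
    return total


square = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
print(tour_length(square, [0, 1, 2, 3]))  # 4.0, the perimeter of the unit square
print(tour_length(square, [0, 2, 1, 3]))  # longer: this order crosses both diagonals
```

Because the problem is symmetric, traversing the same tour in reverse gives the same length.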

See main.sh for an example of how to run the code.

Use the --load_path $LOAD_PATH and --is_train False flags to load a saved model.

To load a saved model and view the pointer network's attention layer, also use the --plot_attention True flag.

Please feel free to notify me if you encounter any errors or would like to submit a pull request to improve this implementation.

Adding other tasks

This implementation can be extended to support other combinatorial optimization problems. See sorting_task.py and tsp_task.py for examples of how to add a task. The key requirement is to provide a dataset class and a reward function that takes a sample solution, selected by the pointer network from the input, and returns a scalar reward. For the sorting task, the agent receives a reward proportional to the length of the longest strictly increasing subsequence in the decoded output (e.g., [1, 3, 5, 2, 4] -> 3/5 = 0.6).
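
The sorting reward described above can be sketched in plain Python as follows. The function names are hypothetical and this is not the code from sorting_task.py; it only reproduces the stated definition, using the standard `bisect` module to find the longest strictly increasing subsequence in O(n log n) time.

```python
from bisect import bisect_left


def longest_increasing_subsequence_length(seq):
    """Length of the longest strictly increasing subsequence of seq."""
    tails = []  # tails[k] = smallest possible last value of an increasing run of length k + 1
    for x in seq:
        pos = bisect_left(tails, x)  # bisect_left keeps the subsequence strictly increasing
        if pos == len(tails):
            tails.append(x)
        else:
            tails[pos] = x
    return len(tails)


def sorting_reward(decoded):
    """Scalar reward in [0, 1]; 1.0 means the output is perfectly sorted."""
    if not decoded:
        return 0.0  # assumption: an empty output earns no reward
    return longest_increasing_subsequence_length(decoded) / len(decoded)


print(sorting_reward([1, 3, 5, 2, 4]))  # 0.6
print(sorting_reward(list(range(10))))  # 1.0
```

A reward for a new task would follow the same shape: take one decoded solution and return a single number.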

Dependencies

  • Python=3.6 (should be OK with v >= 3.4)
  • PyTorch=0.2 and 0.3
  • tqdm
  • matplotlib
  • tensorboard_logger

PyTorch 0.4 compatibility is available on branch pytorch-0.4.

TSP Results

Results for 1 random seed over 50 epochs (each epoch is 10,000 batches of size 128). After each epoch, I validated performance on 1,000 held-out graphs. I used the same hyperparameters as the paper, as can be seen in main.sh. The dashed line shows the value reported in Table 2 of Bello et al. for comparison. The training reward is plotted with a log-scale x-axis to show how quickly the tour length drops early on.

[Figures: TSP 20 Train, TSP 20 Val, TSP 50 Train, TSP 50 Val]

Sort Results

I trained a model on sort10 for 4 epochs of 1,000,000 randomly generated samples and tested it on a dataset of size 10,000. I then tested the same model on sort15 and sort20 to evaluate how well it generalizes to longer inputs.

Test results on 10,000 samples (A reward of 1.0 means the network perfectly sorted the input):

task     average reward   variance
sort10   0.9966           0.0005
sort15   0.7484           0.0177
sort20   0.5586           0.0060

Example prediction on sort10:

input: [4, 7, 5, 0, 3, 2, 6, 8, 9, 1]
output: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

Attention visualization

To plot the pointer network's attention layer, pass the argument --plot_attention True.

TODO

  • Add RL pretraining-Sampling
  • Add RL pretraining-Active Search
  • Active Search
  • Asynchronous training a la A3C
  • Refactor USE_CUDA variable
  • Finish implementing beam search decoding to support > 1 beam
  • Add support for variable length inputs

Acknowledgements

Special thanks to the repos devsisters/neural-combinatorial-rl-tensorflow and MaximumEntropy/Seq2Seq-PyTorch for getting me started, and to @ricgama for figuring out that weird bug with clone().

Owner
Patrick E.
Machine Learning PhD Candidate at Univ. of Florida. Deep generative models | object-centric representation learning | RL | transportation