Pointer networks Tensorflow2

Overview

Pointer networks Tensorflow2

原文:https://arxiv.org/abs/1506.03134
仅供参考与学习,内含代码备注

环境

tensorflow==2.6.0
tqdm
matplotlib
numpy

《pointer networks》阅读笔记

应用场景:

文本摘要,凸包问题,Roundelay 三角剖分,旅行商问题

其中包括一些Latex,github无法渲染,所以建议clone下来用Typora查看。

abstract

本文提出一种新的网络结构:输出序列的元素是与输入序列中的位置相对应的离散标记。

an output sequence with elements that are discrete tokens corresponding to positions in an input sequence.

这种问题目前可以被一些现有的方法解决:sequence-to-sequence, neural turing machines。但是这些方法不是特别适用。

本文解决的问题是sorting variable sized sequences,以及各种组合优化问题。本模型使用attention机制来解决变化尺寸的输出。

intro

RNN模型的输出维度是固定的,sequence-to-sequence模型移除了这一个限制,通过用一个RNN把输入映射为一个embedding,又用一个RNN把embedding映射到输出序列。

但是这些sequence-to-sequence 方法都是固定大小的词汇表。

例如词汇表中只存在A,B,C。那么输入

1,2,3 ----> A,B,C

1,2,3,4 ----> A,B,C,A

本文提出的框架适用于输出的词汇表大小取决于输入问题的大小

image-20211105133740833

image-20211105134312635

左图:seq-2-seq

蓝色RNN,输出一个向量。

紫色RNN,利用概率的链式法则,输出一个固定维度。

本文的贡献如下:

  1. 提出一种新的结构,称为指针网路。简单且高效
  2. 良好的泛化性能
  3. 一个TSP近似求解器

Models

sequence-to-sequence 模型

训练数据为: $$ (P,C^P) $$ 其中,$\mathcal{P}=\left{P_{1}, \ldots, P_{n}\right}$,是n个向量。$\mathcal{C}^{\mathcal{P}}=\left{C_{1}, \ldots, C_{m(\mathcal{P})}\right}$ ,n个对应的结果,$m(\mathcal{P})\in [1,n]$ 。传统的sequence-to-sequence的$\mathcal{C}^{\mathcal{P}}$是固定大小的,但是要提前给定。本文的$\mathcal{C}^{\mathcal{P}}$为n,根据输入改变。

如果模型的参数记为$\theta$,神经网络模型表达为: $$ p(C^P|P,\theta) $$ 使用链式法则,写为: $$ p\left(\mathcal{C}^{\mathcal{P}} \mid \mathcal{P} ; \theta\right)=\prod_{i=1}^{m(\mathcal{P})} p_{\theta}\left(C_{i} \mid C_{1}, \ldots, C_{i-1}, \mathcal{P} ; \theta\right) $$ 训练阶段,最大似然概率: $$ \theta^{*}=\underset{\theta}{\arg \max } \sum_{\mathcal{P}, \mathcal{C}^{\mathcal{P}}} \log p\left(\mathcal{C}^{\mathcal{P}} \mid \mathcal{P} ; \theta\right) $$ input sequence的末端加一个$\Rightarrow$,代表进入生成阶段,$\Leftarrow$代表结束生成阶段。

推断: $$ \hat{\mathcal{C}}^{\mathcal{P}}=\underset{\mathcal{C}^{\mathcal{P}}}{\arg \max } p\left(\mathcal{C}^{\mathcal{P}} \mid \mathcal{P} ; \theta^{*}\right) $$

content based input attention

对于attention机制,请查看《Neural Machine Translation By Jointly Learning To Align And Translate》阅读笔记。

对于LSTM RNN $$ \begin{aligned} u_{j}^{i} &=v^{T} \tanh \left(W_{1} e_{j}+W_{2} d_{i}\right) & j \in(1, \ldots, n) \ a_{j}^{i} &=\operatorname{softmax}\left(u_{j}^{i}\right) & j \in(1, \ldots, n) \ d_{i}^{\prime} &=\sum_{j=1}^{n} a_{j}^{i} e_{j} & \end{aligned} $$ 对于这个传统的attention机制,可以看到$u^{i}$, 是一个长度为$n$的向量。

这样的话,在解码器的每一个时间步迭代都会得到一个 n 长度的向量,可以作为指针,用于指向之前的 n 长度的序列。

Ptr-Net

所以Ptr-Net计算公式写为: $$ \begin{aligned} u_{j}^{i} &=v^{T} \tanh \left(W_{1} e_{j}+W_{2} d_{i}\right) \quad j \in(1, \ldots, n) \ p\left(C_{i} \mid C_{1}, \ldots, C_{i-1}, \mathcal{P}\right) &=\operatorname{softmax}\left(u^{i}\right) \end{aligned} $$ image-20211111103159924

image-20211111110334755

数据以 [Batch, time_steps, feature] 的形式进入编码器LSTM(绿色部分),在时间步上迭代$n$次以后,得到:

  • n 个 e [batch, units], 可以合并写为 [batch, n, units]

  • 最后一个时间步输出的 c [batch, units]

进入到解码器LSTM(蓝色部分),输入为:

  • 上次得到解码得到的的pointer,如果是第一次则为initial pointer
  • 上次的状态d,c

pointer 如何得到?计算公式如下: $$ \begin{aligned} u_{j}^{i} &=v^{T} \tanh \left(W_{1} e_{j}+W_{2} d_{i}\right) \quad j \in(1, \ldots, n) \ p\left(C_{i} \mid C_{1}, \ldots, C_{i-1}, \mathcal{P}\right) &=\operatorname{softmax}\left(u^{i}\right) \end{aligned} $$

motivation and datasets structure

文章是为了解决三种问题,凸包,Delaunay Triangulation,旅行商问题。在此只对旅行商问题进行探讨。

travelling salesman problem

给定一个城市列表,我们希望找到一条最短的路线,每个城市只访问一次,然后返回起点。此外,假设两个城市之间的距离在正反方向上是相同的。这是一个NP难问题,测试模型的能力和局限性。

数据生成:

卡迪尔坐标系(二维),$[0,1] \times[0,1]$

使用 Held-Karp algorithm 得到准确解,n最多为20。

A1,A2,A3为三种其他算法。A1,A2时间复杂度为$O\left(n^{2}\right)$,A3时间复杂度为$O\left(n^{3}\right)$。A3,Christofides algorithm 算法保证在距离最佳长度1.5倍的范围内找到解,详细信息查看原文参考文献。生成1M个数据进行训练。

image-20211111111416012

分析表格:

  1. n=5的时候,性能都很好
  2. n=10,ptr-net的性能比A1好
  3. n=50的时候,无法超过数据集性能(因为ptr-net使用不准确的答案进行训练的)
  4. 只用n少的训练,推广到大n情况,性能不太好。

对于n=30的情况,Ptr-net算法复杂度为$O(n \log n)$,远低于A1,A2,A3。却有相似的性能,说明可发展空间还是很大的。

You might also like...
Complex-Valued Neural Networks (CVNN)Complex-Valued Neural Networks (CVNN)

Complex-Valued Neural Networks (CVNN) Done by @NEGU93 - J. Agustin Barrachina Using this library, the only difference with a Tensorflow code is that y

A framework that constructs deep neural networks, autoencoders, logistic regressors, and linear networks

A framework that constructs deep neural networks, autoencoders, logistic regressors, and linear networks without the use of any outside machine learning libraries - all from scratch.

Tensors and Dynamic neural networks in Python with strong GPU acceleration
Tensors and Dynamic neural networks in Python with strong GPU acceleration

PyTorch is a Python package that provides two high-level features: Tensor computation (like NumPy) with strong GPU acceleration Deep neural networks b

Lightweight library to build and train neural networks in Theano

Lasagne Lasagne is a lightweight library to build and train neural networks in Theano. Its main features are: Supports feed-forward networks such as C

A flexible framework of neural networks for deep learning
A flexible framework of neural networks for deep learning

Chainer: A deep learning framework Website | Docs | Install Guide | Tutorials (ja) | Examples (Official, External) | Concepts | ChainerX Forum (en, ja

Fast, flexible and fun neural networks.

Brainstorm Discontinuation Notice Brainstorm is no longer being maintained, so we recommend using one of the many other,available frameworks, such as

Image-to-Image Translation with Conditional Adversarial Networks (Pix2pix) implementation in keras

pix2pix-keras Pix2pix implementation in keras. Original paper: Image-to-Image Translation with Conditional Adversarial Networks (pix2pix) Paper Author

Code samples for my book "Neural Networks and Deep Learning"

Code samples for "Neural Networks and Deep Learning" This repository contains code samples for my book on "Neural Networks and Deep Learning". The cod

Python Library for learning (Structure and Parameter) and inference (Statistical and Causal) in Bayesian Networks.

pgmpy pgmpy is a python library for working with Probabilistic Graphical Models. Documentation and list of algorithms supported is at our official sit

Releases(v0)
Owner
HUANG HAO
Program = Algorithm + Data structure
HUANG HAO
The official homepage of the (outdated) COCO-Stuff 10K dataset.

COCO-Stuff 10K dataset v1.1 (outdated) Holger Caesar, Jasper Uijlings, Vittorio Ferrari Overview Welcome to official homepage of the COCO-Stuff [1] da

Holger Caesar 263 Dec 11, 2022
An introduction to satellite image analysis using Python + OpenCV and JavaScript + Google Earth Engine

A Gentle Introduction to Satellite Image Processing Welcome to this introductory course on Satellite Image Analysis! Satellite imagery has become a pr

Edward Oughton 32 Jan 03, 2023
A highly efficient, fast, powerful and light-weight anime downloader and streamer for your favorite anime.

AnimDL - Download & Stream Your Favorite Anime AnimDL is an incredibly powerful tool for downloading and streaming anime. Core features Abuses the dev

KR 759 Jan 08, 2023
Improving Query Representations for DenseRetrieval with Pseudo Relevance Feedback:A Reproducibility Study.

APR The repo for the paper Improving Query Representations for DenseRetrieval with Pseudo Relevance Feedback:A Reproducibility Study. Environment setu

ielab 8 Nov 26, 2022
Face recognize and crop them

Face Recognize Cropping Module Source 아이디어 Face Alignment with OpenCV and Python Requirement 필요 라이브러리 imutil dlib python-opence (cv2) Usage 사용 방법 open

Cho Moon Gi 1 Feb 15, 2022
The official implementation of NeMo: Neural Mesh Models of Contrastive Features for Robust 3D Pose Estimation [ICLR-2021]. https://arxiv.org/pdf/2101.12378.pdf

NeMo: Neural Mesh Models of Contrastive Features for Robust 3D Pose Estimation [ICLR-2021] Release Notes The offical PyTorch implementation of NeMo, p

Angtian Wang 76 Nov 23, 2022
Code for testing various M1 Chip benchmarks with TensorFlow.

M1, M1 Pro, M1 Max Machine Learning Speed Test Comparison This repo contains some sample code to benchmark the new M1 MacBooks (M1 Pro and M1 Max) aga

Daniel Bourke 348 Jan 04, 2023
[CVPR'21] MonoRUn: Monocular 3D Object Detection by Reconstruction and Uncertainty Propagation

MonoRUn MonoRUn: Monocular 3D Object Detection by Reconstruction and Uncertainty Propagation. CVPR 2021. [paper] Hansheng Chen, Yuyao Huang, Wei Tian*

同济大学智能汽车研究所综合感知研究组 ( Comprehensive Perception Research Group under Institute of Intelligent Vehicles, School of Automotive Studies, Tongji University) 96 Dec 10, 2022
This is an early in-development version of training CLIP models with hivemind.

A transformer that does not hog your GPU memory This is an early in-development codebase: if you want a stable and documented hivemind codebase, look

<a href=[email protected]"> 4 Nov 06, 2022
Image-to-Image Translation in PyTorch

CycleGAN and pix2pix in PyTorch New: Please check out contrastive-unpaired-translation (CUT), our new unpaired image-to-image translation model that e

Jun-Yan Zhu 19k Jan 07, 2023
The code for 'Deep Residual Fourier Transformation for Single Image Deblurring'

Deep Residual Fourier Transformation for Single Image Deblurring Xintian Mao, Yiming Liu, Wei Shen, Qingli Li and Yan Wang code will be released soon

145 Dec 13, 2022
The official repo for CVPR2021——ViPNAS: Efficient Video Pose Estimation via Neural Architecture Search.

ViPNAS: Efficient Video Pose Estimation via Neural Architecture Search [paper] Introduction This is the official implementation of ViPNAS: Efficient V

Lumin 42 Sep 26, 2022
A tool to analyze leveraged liquidity mining and find optimal option combination for hedging.

LP-Option-Hedging Description A Python program to analyze leveraged liquidity farming/mining and find the optimal option combination for hedging imper

Aureliano 18 Dec 19, 2022
PSTR: End-to-End One-Step Person Search With Transformers (CVPR2022)

PSTR (CVPR2022) This code is an official implementation of "PSTR: End-to-End One-Step Person Search With Transformers (CVPR2022)". End-to-end one-step

Jiale Cao 28 Dec 13, 2022
Little Ball of Fur - A graph sampling extension library for NetworKit and NetworkX (CIKM 2020)

Little Ball of Fur is a graph sampling extension library for Python. Please look at the Documentation, relevant Paper, Promo video and External Resour

Benedek Rozemberczki 619 Dec 14, 2022
[ ICCV 2021 Oral ] Our method can estimate camera poses and neural radiance fields jointly when the cameras are initialized at random poses in complex scenarios (outside-in scenes, even with less texture or intense noise )

GNeRF This repository contains official code for the ICCV 2021 paper: GNeRF: GAN-based Neural Radiance Field without Posed Camera. This implementation

Quan Meng 191 Dec 26, 2022
Robust fine-tuning of zero-shot models

Robust fine-tuning of zero-shot models This repository contains code for the paper Robust fine-tuning of zero-shot models by Mitchell Wortsman*, Gabri

224 Dec 29, 2022
A script written in Python that returns a consensus string and profile matrix of a given DNA string(s) in FASTA format.

A script written in Python that returns a consensus string and profile matrix of a given DNA string(s) in FASTA format.

Zain 1 Feb 01, 2022
PyTorch implementation of VAGAN: Visual Feature Attribution Using Wasserstein GANs

Prototypical Networks for Few shot Learning in PyTorch Simple alternative Implementation of Prototypical Networks for Few Shot Learning (paper, code)

Orobix 93 Aug 17, 2022
Implementation of BI-RADS-BERT & The Advantages of Section Tokenization.

BI-RADS BERT Implementation of BI-RADS-BERT & The Advantages of Section Tokenization. This implementation could be used on other radiology in house co

1 May 17, 2022