Neural Turing Machines (NTM) - PyTorch Implementation

Overview

PyTorch Neural Turing Machine (NTM)

PyTorch implementation of Neural Turing Machines (NTM).

An NTM is a memory augumented neural network (attached to external memory) where the interactions with the external memory (address, read, write) are done using differentiable transformations. Overall, the network is end-to-end differentiable and thus trainable by a gradient based optimizer.

The NTM is processing input in sequences, much like an LSTM, but with additional benfits: (1) The external memory allows the network to learn algorithmic tasks easier (2) Having larger capacity, without increasing the network's trainable parameters.

The external memory allows the NTM to learn algorithmic tasks, that are much harder for LSTM to learn, and to maintain an internal state much longer than traditional LSTMs.

A PyTorch Implementation

This repository implements a vanilla NTM in a straight forward way. The following architecture is used:

NTM Architecture

Features

  • Batch learning support
  • Numerically stable
  • Flexible head configuration - use X read heads and Y write heads and specify the order of operation
  • copy and repeat-copy experiments agree with the paper

Copy Task

The Copy task tests the NTM's ability to store and recall a long sequence of arbitrary information. The input to the network is a random sequence of bits, ending with a delimiter. The sequence lengths are randomised between 1 to 20.

Training

Training convergence for the copy task using 4 different seeds (see the notebook for details)

NTM Convergence

The following plot shows the cost per sequence length during training. The network was trained with seed=10 and shows fast convergence. Other seeds may not perform as well but should converge in less than 30K iterations.

NTM Convergence

Evaluation

Here is an animated GIF that shows how the model generalize. The model was evaluated after every 500 training samples, using the target sequence shown in the upper part of the image. The bottom part shows the network output at any given training stage.

Copy Task

The following is the same, but with sequence length = 80. Note that the network was trained with sequences of lengths 1 to 20.

Copy Task


Repeat Copy Task

The Repeat Copy task tests whether the NTM can learn a simple nested function, and invoke it by learning to execute a for loop. The input to the network is a random sequence of bits, followed by a delimiter and a scalar value that represents the number of repetitions to output. The number of repetitions, was normalized to have zero mean and variance of one (as in the paper). Both the length of the sequence and the number of repetitions are randomised between 1 to 10.

Training

Training convergence for the repeat-copy task using 4 different seeds (see the notebook for details)

NTM Convergence

Evaluation

The following image shows the input presented to the network, a sequence of bits + delimiter + num-reps scalar. Specifically the sequence length here is eight and the number of repetitions is five.

Repeat Copy Task

And here's the output the network had predicted:

Repeat Copy Task

Here's an animated GIF that shows how the network learns to predict the targets. Specifically, the network was evaluated in each checkpoint saved during training with the same input sequence.

Repeat Copy Task

Installation

The NTM can be used as a reusable module, currently not packaged though.

  1. Clone repository
  2. Install PyTorch
  3. pip install -r requirements.txt

Usage

Execute ./train.py

usage: train.py [-h] [--seed SEED] [--task {copy,repeat-copy}] [-p PARAM]
                [--checkpoint-interval CHECKPOINT_INTERVAL]
                [--checkpoint-path CHECKPOINT_PATH]
                [--report-interval REPORT_INTERVAL]

optional arguments:
  -h, --help            show this help message and exit
  --seed SEED           Seed value for RNGs
  --task {copy,repeat-copy}
                        Choose the task to train (default: copy)
  -p PARAM, --param PARAM
                        Override model params. Example: "-pbatch_size=4
                        -pnum_heads=2"
  --checkpoint-interval CHECKPOINT_INTERVAL
                        Checkpoint interval (default: 1000). Use 0 to disable
                        checkpointing
  --checkpoint-path CHECKPOINT_PATH
                        Path for saving checkpoint data (default: './')
  --report-interval REPORT_INTERVAL
                        Reporting interval
Owner
Guy Zana
I make things, author of Curated Papers
Guy Zana
NeurIPS'21 Tractable Density Estimation on Learned Manifolds with Conformal Embedding Flows

NeurIPS'21 Tractable Density Estimation on Learned Manifolds with Conformal Embedding Flows This repo contains the code for the paper Tractable Densit

Layer6 Labs 4 Dec 12, 2022
A High-Performance Distributed Library for Large-Scale Bundle Adjustment

MegBA: A High-Performance and Distributed Library for Large-Scale Bundle Adjustment This repo contains an official implementation of MegBA. MegBA is a

旷视研究院 3D 组 336 Dec 27, 2022
Network Pruning That Matters: A Case Study on Retraining Variants (ICLR 2021)

Network Pruning That Matters: A Case Study on Retraining Variants (ICLR 2021)

Duong H. Le 18 Jun 13, 2022
Backend code to use MCPI's python API to make infinite worlds with custom generation

inf-mcpi Backend code to use MCPI's python API to make infinite worlds with custom generation Does not save player-placed blocks! Generation is still

5 Oct 04, 2022
A python comtrade load library accelerated by go

Comtrade-GRPC Code for python used is mainly from dparrini/python-comtrade. Just patch the code in BinaryDatReader.parse for parsing a little more eff

Bo 1 Dec 27, 2021
load .txt to train YOLOX, same as Yolo others

YOLOX train your data you need generate data.txt like follow format (per line- one image). prepare one data.txt like this: img_path1 x1,y1,x2,y2,clas

LiMingf 18 Aug 18, 2022
An efficient and easy-to-use deep learning model compression framework

TinyNeuralNetwork 简体中文 TinyNeuralNetwork is an efficient and easy-to-use deep learning model compression framework, which contains features like neura

Alibaba 441 Dec 25, 2022
YOLOV4运行在嵌入式设备上

在嵌入式设备上实现YOLO V4 tiny 在嵌入式设备上实现YOLO V4 tiny 目录结构 目录结构 |-- YOLO V4 tiny |-- .gitignore |-- LICENSE |-- README.md |-- test.txt |-- t

Liu-Wei 6 Sep 09, 2021
Medical Image Segmentation using Squeeze-and-Expansion Transformers

Medical Image Segmentation using Squeeze-and-Expansion Transformers Introduction This repository contains the code of the IJCAI'2021 paper 'Medical Im

askerlee 172 Dec 20, 2022
Code for "Long-tailed Distribution Adaptation"

Long-tailed Distribution Adaptation (Accepted in ACM MM2021) This project is built upon BBN. Installation pip install -r requirements.txt Usage Traini

Zhiliang Peng 10 May 18, 2022
Buffon’s needle: one of the oldest problems in geometric probability

Buffon-s-Needle Buffon’s needle is one of the oldest problems in geometric proba

3 Feb 18, 2022
An implementation of "MixHop: Higher-Order Graph Convolutional Architectures via Sparsified Neighborhood Mixing" (ICML 2019).

MixHop and N-GCN ⠀ A PyTorch implementation of "MixHop: Higher-Order Graph Convolutional Architectures via Sparsified Neighborhood Mixing" (ICML 2019)

Benedek Rozemberczki 393 Dec 13, 2022
Code for the paper "Balancing Training for Multilingual Neural Machine Translation, ACL 2020"

Balancing Training for Multilingual Neural Machine Translation Implementation of the paper Balancing Training for Multilingual Neural Machine Translat

Xinyi Wang 21 May 18, 2022
A port of muP to JAX/Haiku

MUP for Haiku This is a (very preliminary) port of Yang and Hu et al.'s μP repo to Haiku and JAX. It's not feature complete, and I'm very open to sugg

18 Dec 30, 2022
A PyTorch implementation of Learning to learn by gradient descent by gradient descent

Intro PyTorch implementation of Learning to learn by gradient descent by gradient descent. Run python main.py TODO Initial implementation Toy data LST

Ilya Kostrikov 300 Dec 11, 2022
On the Limits of Pseudo Ground Truth in Visual Camera Re-Localization

On the Limits of Pseudo Ground Truth in Visual Camera Re-Localization This repository contains the evaluation code and alternative pseudo ground truth

Torsten Sattler 36 Dec 22, 2022
Official repository for the paper F, B, Alpha Matting

FBA Matting Official repository for the paper F, B, Alpha Matting. This paper and project is under heavy revision for peer reviewed publication, and s

Marco Forte 404 Jan 05, 2023
Code for paper Adaptively Aligned Image Captioning via Adaptive Attention Time

Adaptively Aligned Image Captioning via Adaptive Attention Time This repository includes the implementation for Adaptively Aligned Image Captioning vi

Lun Huang 45 Aug 27, 2022
A torch.Tensor-like DataFrame library supporting multiple execution runtimes and Arrow as a common memory format

TorchArrow (Warning: Unstable Prototype) This is a prototype library currently under heavy development. It does not currently have stable releases, an

Facebook Research 536 Jan 06, 2023
Main repository for the HackBio'2021 Virtual Internship Experience for #Team-Greider ❤️

Hello 🤟 #Team-Greider The team of 20 people for HackBio'2021 Virtual Bioinformatics Internship 💝 🖨️ 👨‍💻 HackBio: https://thehackbio.com 💬 Ask us

Siddhant Sharma 7 Oct 20, 2022