PonderNet - TensorFlow

TensorFlow 2.X reimplementation of PonderNet: Learning to Ponder, Andrea Banino, Jan Balaguer, Charles Blundell.

Abstract

In standard neural networks the amount of computation used grows with the size of the inputs, but not with the complexity of the problem being learnt. To overcome this limitation we introduce PonderNet, a new algorithm that learns to adapt the amount of computation based on the complexity of the problem at hand. PonderNet learns end-to-end the number of computational steps to achieve an effective compromise between training prediction accuracy, computational cost and generalization. On a complex synthetic problem, PonderNet dramatically improves performance over previous adaptive computation methods and additionally succeeds at extrapolation tests where traditional neural networks fail. Also, our method matched the current state of the art results on a real world question and answering dataset, but using less compute. Finally, PonderNet reached state of the art results on a complex task designed to test the reasoning capabilities of neural networks.
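To make "learning the number of computational steps" more concrete: in the paper, the network emits at each step n a conditional halting probability λ_n, and the probability of halting exactly at step n is p_n = λ_n · ∏_{j<n} (1 − λ_j). The following is a minimal pure-Python sketch of that formula, not the code in this repository. The function names `halting_distribution` and `expected_steps` and the λ values are hypothetical, and setting the last λ to 1.0 so that the probabilities sum to one is an assumption of this sketch.

```python
def halting_distribution(lambdas):
    """Convert per-step conditional halting probabilities into p_n.

    p_n = lambda_n * prod_{j<n} (1 - lambda_j)
    """
    probs = []
    not_halted = 1.0  # probability of not having halted before the current step
    for lam in lambdas:
        probs.append(not_halted * lam)
        not_halted *= 1.0 - lam
    return probs


def expected_steps(probs):
    """Expected number of ponder steps (steps counted from 1)."""
    total = sum(probs)
    return sum(n * p for n, p in enumerate(probs, start=1)) / total


# Made-up lambdas for illustration; the final 1.0 forces a halt at the last step.
lambdas = [0.1, 0.3, 0.5, 1.0]
probs = halting_distribution(lambdas)
print(probs, expected_steps(probs))
```

A larger λ early in the sequence shifts probability mass toward fewer steps, which is how a model of this kind can spend less computation on easier inputs.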

Experiment on Parity Task

The input to the parity task is a vector of 0s, 1s, and −1s. The output is the parity of the 1s: one if the vector contains an odd number of 1s, and zero otherwise. Each input is generated by setting a random number of the vector's elements to either 1 or −1 and leaving the rest at 0.
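The data generation described above can be sketched in a few lines of plain Python. This is an illustration only and is not taken from the repository: the helper name `make_parity_example`, the vector length of 8, and the choice to always set at least one element are assumptions.

```python
import random


def make_parity_example(length, rng):
    """Return (vector, label) for one parity-task example."""
    # Pick how many positions are non-zero (assumed: between 1 and length).
    n_nonzero = rng.randint(1, length)
    # Each non-zero position is either 1 or -1; the rest stay 0.
    vector = [rng.choice((1, -1)) for _ in range(n_nonzero)]
    vector += [0] * (length - n_nonzero)
    rng.shuffle(vector)
    # Label is 1 if the number of 1s is odd, 0 otherwise (-1s are ignored).
    label = vector.count(1) % 2
    return vector, label


rng = random.Random(0)  # seeded for reproducibility
vector, label = make_parity_example(8, rng)
print(vector, label)
```

Note that the −1 entries act as distractors: they change the input but never the label.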

Performance on the parity task. a) Interpolation. Top: accuracy for both PonderNet (blue) and ACT (orange). Bottom: number of ponder steps at evaluation time. Error bars are calculated over 10 random seeds. b) Extrapolation. Top: accuracy for both PonderNet (blue) and ACT (orange). Bottom: number of ponder steps at evaluation time. Error bars are calculated over 10 random seeds. c) Total number of compute steps calculated as the number of actual forward passes performed by each network. Blue is PonderNet, Green is ACT and Orange is an RNN without adaptive compute.

Installation

Clone the repo and install the required packages:

git clone https://github.com/EMalagoli92/PonderNet-TensorFlow.git
cd PonderNet-TensorFlow
pip install -r requirements.txt

Tested on Ubuntu 20.04.4 LTS x86_64, Python 3.9.7.

Usage

Train PonderNet on the parity task:

python __main__.py

Acknowledgement

PonderNet (Official PyTorch Implementation)

Citations

@misc{banino2021pondernet,
      title={PonderNet: Learning to Ponder}, 
      author={Andrea Banino and Jan Balaguer and Charles Blundell},
      year={2021},
      eprint={2107.05407},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

License

This work is made available under the MIT License.
