Neural network pruning for finding a sparse computational model for controlling a biological motor task.

Overview

MothPruning

Scientific Overview

Originally inspired by biological nervous systems, deep neural networks (DNNs) are powerful computational tools for modeling complex systems. DNNs are used in a diversity of domains and have helped solve some of the most intractable problems in physics, biology, and computer science. Despite their prevalence, the use of DNNs as a modeling tool comes with some major downsides. DNNs are highly overparameterized, which often results in them being difficult to generalize and interpret, as well as being incredibly computationally expensive. Unlike DNNs, which are often trained until they reach the highest accuracy possible, biological networks have to balance performance with robustness to a noisy and dynamic environment. Biological neural systems use a variety of mechanisms to promote specialized and efficient pathways capable of performing complex tasks in the presence of noise. One such mechanism, synaptic pruning, plays a significant role in refining task-specific behaviors. Synaptic pruning results in a more sparsely connected network that can still perform complex cognitive and motor tasks. Here, we draw inspiration from biology and use DNNs and the method of neural network pruning to find a sparse computational model for controlling a biological motor task.

In this work, we use the inertial dynamics model in [2] to simulate examples of M. sexta hovering flight. These data are used to train a DNN to learn the controllers for hovering. Drawing inspiration from pruning in biological neural systems, we sparsify the network using neural network pruning. Here, we prune weights based simply on their magnitudes, removing those weights closest to zero. Insects must maneuver through high-noise environments to accomplish controlled flight. It is often assumed that there is a trade-off between perfect flight control and robustness to noise and that the sensory data may be limited by the signal-to-noise ratio. Thus, the network need not train for the most accurate model, since in practice noise prevents high-fidelity models from exhibiting their underlying accuracy. Rather, we seek the sparsest model capable of performing the task in the noisy environment. We employ two methods for neural network pruning: manually setting weights to zero or using binary masking layers. Furthermore, the DNN is pruned sequentially: groups of weights are removed gradually from the network, with retraining between successive prunes, until a target sparsity is reached. Monte Carlo simulations are also used to quantify the statistical distribution of network weights during pruning given random initialization of network weights.

For more information, please see our paper [1].


Project Description

The deep, fully-connected neural network was constructed with ten input variables and seven output variables. The initial and final state space conditions are the inputs to the network: $\dot{x}_i$, $\dot{y}_i$, $\phi_i$, $\theta_i$, $\dot{\phi}_i$, $\dot{\theta}_i$, $x_f$, $y_f$, $\phi_f$, and $\theta_f$. The network predicts the control variables and the final derivatives of the state space in its output layer: $F_x$, $F_y$, $\tau$, $\dot{x}_f$, $\dot{y}_f$, $\dot{\phi}_f$, and $\dot{\theta}_f$.

After the fully-connected network is trained to a minimum error, we use neural network pruning to promote sparsity between the network layers. A target sparsity (the percentage of network weights to prune) is specified and the smallest-magnitude weights are forced to zero. The network is then retrained until a minimum error is reached. This process is repeated until most of the weights have been pruned from the network.
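The prune-then-retrain cycle described above can be sketched roughly as follows. This is a minimal NumPy illustration, not the repository's TensorFlow or Jax code: the function name `prune_by_magnitude`, the sparsity schedule, and the choice to prune a single weight matrix at a time are all assumptions made for the example.

```python
import numpy as np

def prune_by_magnitude(w, sparsity):
    """Zero the smallest-magnitude entries of w (hypothetical helper).

    sparsity is the fraction of entries to remove (0 to 1).
    Returns the pruned matrix and a binary mask (1 = kept, 0 = pruned).
    """
    n_prune = int(round(sparsity * w.size))
    flat_mask = np.ones(w.size, dtype=w.dtype)
    if n_prune > 0:
        # Indices of the n_prune entries with the smallest absolute value.
        idx = np.argpartition(np.abs(w).ravel(), n_prune - 1)[:n_prune]
        flat_mask[idx] = 0
    mask = flat_mask.reshape(w.shape)
    return w * mask, mask

rng = np.random.default_rng(0)
w = rng.normal(size=(10, 7))  # stand-in for one trained weight matrix

# Illustrative schedule only; the real schedule is set in the training scripts.
for target in (0.5, 0.8, 0.9):
    w, mask = prune_by_magnitude(w, target)
    # Retraining with the mask held fixed would happen here.
```

Because already-pruned weights have magnitude zero, each later pass keeps them pruned and removes additional small weights until the new target is met.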

The training and pruning protocols were developed using Keras with the TensorFlow backend. To scale up training for the statistical analysis of many networks, the training and pruning protocols were parallelized using the Jax framework.

To ensure weights remain pruned during retraining, we used the pruning functionality of the TensorFlow Model Optimization Toolkit. The toolkit contains functions for pruning deep neural networks. In the Model Optimization Toolkit, pruning is achieved through binary masking layers that are multiplied element-wise with each weight matrix in the network.

Jax, however, does not come with a pruning toolkit, so pruning with binary masking matrices was coded directly into the training loop of the parallelized Jax implementation.
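One way to picture masking inside a training loop is the sketch below. It uses plain NumPy and a plain gradient-descent update as stand-ins; the function `masked_sgd_step`, the learning rate, and the random gradient are assumptions for illustration and are not taken from the repository.

```python
import numpy as np

def masked_sgd_step(w, grad, mask, lr=0.01):
    """One gradient-descent update that keeps pruned weights at zero.

    The update is computed as usual, then multiplied element-wise by the
    binary mask so that entries with mask == 0 are reset to zero.
    """
    return (w - lr * grad) * mask

rng = np.random.default_rng(0)
w = rng.normal(size=(4, 3))
mask = (np.abs(w) > 0.5).astype(w.dtype)  # example mask: keep larger weights
w = w * mask

grad = rng.normal(size=w.shape)  # stand-in for a real loss gradient
w_new = masked_sgd_step(w, grad, mask)
```

Applying the mask after every update means the optimizer can never revive a pruned connection, which is the property the masking layers provide in the TensorFlow version.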

Installation

Create a new conda environment with the tools for generating data and training the network. Note that this environment requires a GPU and the correct NVIDIA drivers.

conda env create -f environment_ODE_DL.yml

Create kernelspec (so you can see this kernel in JupyterLab).

conda activate [environment name]
python -m ipykernel install --user --name [environment name]
conda deactivate

To install Jax and Flax, please follow the instructions on the Jax GitHub.

Data

The Jax version (multi-network train and prune) has data provided in this repository. To use the TensorFlow version of this code, you need to generate simulations of moth hovering for the data:

cd MothMachineLearning/Underactuated/GenerateData

and use 010_OneTorqueParallelSims.ipynb to generate the simulations.

How to use

The following guide walks through the process of training and pruning many networks in parallel using the Jax framework. However, the TensorFlow code is also provided for experimentation and visualization.

Step 1: Train networks

cd MothMachineLearning/Underactuated/TrainNetwork/multiNetPrune/

First, we train and prune the desired number of networks in parallel using the Jax framework. Choose the number of networks to train and prune in parallel by adjusting the numParallel parameter. You can also define the number of layers, units, and other hyperparameters. Use the command

python3 step1_train.py

to train and prune the networks in parallel.
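To give a feel for what training several networks in parallel involves, the sketch below stacks the weights of several small networks along a leading axis and evaluates them all in one call. It is written in NumPy rather than Jax, and the hidden width, the tanh activation, the single hidden layer, and the name `forward_all` are assumptions; only the ten inputs and seven outputs come from the project description.

```python
import numpy as np

num_parallel = 4                   # plays the role of numParallel
n_in, n_hidden, n_out = 10, 32, 7  # hidden width is an assumption

rng = np.random.default_rng(0)
# One weight tensor per layer, with a leading axis indexing the networks.
W1 = rng.normal(scale=0.1, size=(num_parallel, n_in, n_hidden))
W2 = rng.normal(scale=0.1, size=(num_parallel, n_hidden, n_out))

def forward_all(x, W1, W2):
    """Evaluate every network on the same batch x of shape (batch, n_in)."""
    h = np.tanh(np.einsum("bi,nih->nbh", x, W1))
    return np.einsum("nbh,nho->nbo", h, W2)

x = rng.normal(size=(5, n_in))
y = forward_all(x, W1, W2)  # shape: (num_parallel, batch, n_out)
```

Each network keeps its own weights (and, when pruning, its own masks), so the networks remain independent even though they are evaluated together.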

Step 2: Evaluate at prunes

Next, the networks need to be evaluated at each prune. Use the command

python3 step2_pruneEval.py

to evaluate the networks at each prune.

Step 3: Pre-process networks

This code prepares the networks for sparse network identification (explained in the next step). It essentially just reorganizes the data. Open and run step3_preprocess.ipynb to preprocess, making sure to change modeltimestamp and the file names to the correct ones for your run.

Step 4: Find sparse networks

This code finds the optimally sparse networks. For each network, the most pruned version whose loss is below a specified threshold (here 0.001) is kept. For example, the image below shows a single network that has gone through the sequential pruning process, with the red line marking the defined threshold. For this example, the optimally sparse network is the one pruned by 94% (i.e. 6% of the original weights remain).

[Figure: loss at each prune for a single network, with the loss threshold marked by a red line.]

The sparse networks are collected and saved to a file called sparseNetworks.pkl. Open and run step4_findSparse.ipynb, making sure to change modeltimestamp and the file names to the correct ones for your run.

Note that if a network does not have a single prune that is below the loss threshold, it will be skipped and not included in the list of sparseNetworks. For example, if you trained and pruned 10 networks and 3 did not have a prune below a loss of 0.001, the list sparseNetworks will have length 7.
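The selection rule can be sketched as follows. This is an illustrative pure-Python version, not the notebook's code: the function `find_sparse_networks`, the (sparsity, loss) data layout, and the loss values are made up for the example.

```python
def find_sparse_networks(loss_curves, threshold=0.001):
    """Pick the most pruned version of each network below the loss threshold.

    loss_curves: one list per network of (sparsity, loss) pairs, one pair
    per prune. Networks with no prune below the threshold are skipped.
    """
    sparse = []
    for curve in loss_curves:
        passing = [sparsity for sparsity, loss in curve if loss < threshold]
        if not passing:
            continue  # no prune met the threshold: leave this network out
        sparse.append(max(passing))
    return sparse

# Made-up loss curves for two networks.
curves = [
    [(0.0, 0.0002), (0.5, 0.0003), (0.94, 0.0008), (0.97, 0.004)],
    [(0.0, 0.002), (0.5, 0.003)],  # never drops below the threshold
]
sparse_networks = find_sparse_networks(curves)
```

Here the first network is kept at 94% sparsity and the second is skipped, so the result is shorter than the input list.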

References

[1] Zahn, O., Bustamante, Jr J., Switzer, C., Daniel, T., and Kutz, J. N. (2022). Pruning deep neural networks generates a sparse, bio-inspired nonlinear controller for insect flight.

[2] Bustamante, Jr J., Ahmed, M., Deora, T., Fabien, B., and Daniel, T. (2021). Abdominal movements in insect flight reshape the role of non-aerodynamic structures for flight maneuverability. J. Integrative and Comparative Biology. In revision.

Owner
Olivia Thomas
Physics graduate student at the University of Washington