SpyTorch

A tutorial on surrogate gradient learning in spiking neural networks

Version: 0.4

This repository contains tutorial files to get you started with the basic ideas of surrogate gradient learning in spiking neural networks using PyTorch.

You can find a brief introductory video accompanying these notebooks here: https://youtu.be/xPYiAjceAqU

Feedback and contributions are welcome.

For more information on surrogate gradient learning please refer to:

Neftci, E.O., Mostafa, H., and Zenke, F. (2019). Surrogate Gradient Learning in Spiking Neural Networks: Bringing the Power of Gradient-Based Optimization to Spiking Neural Networks. IEEE Signal Processing Magazine 36, 51–63. https://ieeexplore.ieee.org/document/8891809 Preprint: https://arxiv.org/abs/1901.09948

Also see https://github.com/surrogate-gradient-learning

Copyright and license

Copyright 2019-2020 Friedemann Zenke, https://fzenke.net

This work is licensed under a Creative Commons Attribution 4.0 International License. http://creativecommons.org/licenses/by/4.0/

Comments
  • Resetting with "out" instead of "rst"?

    This is a comment, not an issue.

    Hi Friedemann, First off, thanks a lot for these great tutorials. I've really enjoyed playing with them, and I've learned a lot :-) One question: in the run_snn function, why do you bother constructing the "rst" tensor? Why don't you subtract the "out" tensor, which also contains the output spikes? I've tried it, and it seems to work. Just curious. Best,

    Tim

    question 
    opened by tmasquelier 8
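One way to probe the question above is to compare forward values and gradients for the two reset choices. The sketch below is not taken from the tutorial notebooks: `SurrGradSpike`, `step`, and `next_mem_and_grad` are hypothetical names, the surrogate derivative 1/(scale*|x|+1)^2 with scale 100 is an assumption, and the update is reduced to a single neuron without synaptic current. Under these assumptions, subtracting "out" directly leaves the forward pass unchanged but opens an extra gradient path through the reset term, which a detached reset tensor avoids.

```python
import torch


class SurrGradSpike(torch.autograd.Function):
    """Heaviside step in the forward pass, smooth surrogate in the backward pass."""

    scale = 100.0  # assumed steepness of the surrogate derivative

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return (x > 0).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Derivative of a fast sigmoid, used in place of the Dirac delta
        return grad_output / (SurrGradSpike.scale * x.abs() + 1.0) ** 2


spike_fn = SurrGradSpike.apply


def step(mem, beta=0.9, detach_reset=True):
    """One simplified membrane update with a multiplicative reset."""
    out = spike_fn(mem - 1.0)
    # Same values either way; only the autograd graph differs
    rst = out.detach() if detach_reset else out
    new_mem = beta * mem * (1.0 - rst)
    return out, new_mem


def next_mem_and_grad(mem_value, detach_reset):
    """Return the next membrane value and d(new_mem)/d(mem)."""
    mem = torch.tensor([mem_value], requires_grad=True)
    _, new_mem = step(mem, detach_reset=detach_reset)
    new_mem.sum().backward()
    return new_mem.detach(), mem.grad


val_detached, grad_detached = next_mem_and_grad(0.9, detach_reset=True)
val_attached, grad_attached = next_mem_and_grad(0.9, detach_reset=False)
print("same forward value:", torch.equal(val_detached, val_attached))
print("gradients:", grad_detached.item(), grad_attached.item())
```

Whether the extra gradient path matters in practice depends on the task; the sketch only shows that the two variants are not equivalent for the backward pass.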
  • Problem in SpyTorchTutorial2

    Hello,

    It was a very nice and interesting tutorial, thank you for preparing it...

    Tutorial 1 ran without any problems, but in tutorial 2 some dtype problems occurred. After I fixed them, training was very slow on a GTX 980 (I have run some very deep models on this configuration). Could you please describe your configuration, along with the training time and response time?

    opened by ghost 6
  • Spike times shifted

    I have the impression that the spike recordings are shifted one time step in all tutorials. Could you maybe check if this is indeed the case?

    From my understanding, time step 0 is recorded twice for the spikes, once during initialisation

      mem = torch.zeros((batch_size, nb_hidden), device=device, dtype=dtype)
      spk_rec = [mem]
    

    and once within the simulation of time step 0:

      for t in range(nb_steps):
          mthr = mem-1.0
          out = spike_fn(mthr)
          ...
          spk_rec.append(out)
    

    As a result, the indices appear shifted when comparing

      print(torch.nonzero((mem_rec-1.0) > 0.0))
      print(torch.nonzero(spk_rec))
    

    Thanks, Simon

    opened by smonsays 4
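The effect described above can be reproduced with a toy single-neuron loop. This is a minimal sketch, not the notebook code: `simulate`, `event_steps`, the `seed_records` switch, and the input values are hypothetical, and the dynamics are reduced to a leaky membrane with a multiplicative reset. With seeded record lists and appends after the state update, threshold crossings and recorded spikes end up one index apart; recording the membrane and the spike at the point where the spike is computed keeps them aligned.

```python
import torch


def simulate(inputs, beta=0.5, seed_records=True):
    """Run a toy neuron and return stacked membrane and spike records."""
    mem = torch.zeros(1)
    if seed_records:
        # Variant from the report: both lists start with the initial state
        mem_rec, spk_rec = [mem], [mem]
    else:
        mem_rec, spk_rec = [], []
    for x in inputs:
        out = (mem - 1.0 > 0).to(mem.dtype)
        if not seed_records:
            # Aligned variant: record the membrane the spike was computed from
            mem_rec.append(mem)
            spk_rec.append(out)
        mem = (beta * mem + x) * (1.0 - out)
        if seed_records:
            # Seeded variant: record after the update
            mem_rec.append(mem)
            spk_rec.append(out)
    return torch.stack(mem_rec), torch.stack(spk_rec)


def event_steps(rec):
    """Indices along the time axis where the record is positive."""
    return torch.nonzero(rec.flatten() > 0).flatten().tolist()


inputs = [0.0, 1.5, 0.0, 0.0]
for seeded in (True, False):
    mem_rec, spk_rec = simulate(inputs, seed_records=seeded)
    print("seeded" if seeded else "aligned",
          "crossings:", event_steps(mem_rec - 1.0),
          "spikes:", event_steps(spk_rec))
```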
  • Software/Machine description available?

    Hey Friedemann,

    thanks for making the examples available, they look very helpful. However, to make them fully reproducible I think that some additional information regarding the "technical dependencies" is needed.

    In particular, the list of used software packages (incl. version and build variant information) plus some specification about the machine hardware (CPU arch, GPUs).

    Preferably, the former could be expressed as a recipe for constructing a container (Dockerfile, or for better HPC-compatibility, a Singularity recipe), maybe even using an explicitly versioning package manager like spack.

    Cheers, Eric

    opened by muffgaga 3
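As a lightweight complement to a container recipe, the software side of such a description could be collected from within Python. The following is only a sketch using the standard library: `describe_environment` is a hypothetical helper, and the default package list (torch, numpy, matplotlib) is an assumption about what the notebooks import, not a list confirmed by the repository. GPU details are not covered here.

```python
import platform
import sys
from importlib import metadata


def describe_environment(packages=("torch", "numpy", "matplotlib")):
    """Collect interpreter, platform, and installed package versions."""
    info = {
        "python": sys.version.split()[0],
        "platform": platform.platform(),
        "machine": platform.machine(),
    }
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Report missing packages instead of failing
            info[name] = "not installed"
    return info


for key, value in describe_environment().items():
    print(f"{key}: {value}")
```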
  • Dataset never decompressed

    Hello,

    I believe I ran into a possible issue here. Because of line 37, the check in line 38 will always be false if one hasn't already got the uncompressed dataset.

    https://github.com/fzenke/spytorch/blob/9e91eceaf53f17be9e95a3743164224bdbb086bb/notebooks/utils.py#L35-L42

    If I change line 37 to

      hdf5_file_path = gz_file_path[:-3]

    it works for me.

    Best, Aaron

    opened by AaronSpieler 1
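The suggested fix amounts to deriving the target path by stripping the ".gz" suffix and decompressing only when that target is missing. The helper below is a standalone sketch of that idea and not the code from utils.py; `ensure_decompressed` is a hypothetical name, and only the names `gz_file_path` and `hdf5_file_path` come from the report.

```python
import gzip
import os
import shutil


def ensure_decompressed(gz_file_path):
    """Decompress a .gz file next to itself unless the result already exists."""
    if not gz_file_path.endswith(".gz"):
        raise ValueError("expected a path ending in .gz")
    # Strip the three characters of the ".gz" suffix to get the target path
    hdf5_file_path = gz_file_path[:-3]
    if not os.path.isfile(hdf5_file_path):
        with gzip.open(gz_file_path, "rb") as f_in, open(hdf5_file_path, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)
    return hdf5_file_path
```

Calling the helper a second time returns the same path without decompressing again.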
  • propagation delay

    Hi Zenke, I have a question about the SNN model. If I feed a spike image to an SNN with L layers at time step n, the output of the last layer will only be affected by that input at time step n + L - 1. In deep networks this delay should be taken into account, because it increases the total number of time steps.

    opened by yizx6 1
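The n + L - 1 figure can be checked with a small bookkeeping simulation. The sketch below assumes a particular update scheme that is not confirmed by the tutorials: the first layer reacts in the same step in which the input arrives, and every later layer reads the spikes its predecessor emitted in the previous step. `first_response_steps` is a hypothetical helper that only tracks which layer is active when, without any neuron dynamics.

```python
def first_response_steps(nb_layers, input_step):
    """Return, per layer, the first time step at which it is affected by the input."""
    prev_active = [False] * nb_layers
    first = [None] * nb_layers
    t = input_step
    while first[-1] is None:
        active = [False] * nb_layers
        # Assumption: layer 0 reacts in the step where the input is presented
        active[0] = (t == input_step)
        for layer in range(1, nb_layers):
            # Assumption: each layer sees the previous layer's spikes one step late
            active[layer] = prev_active[layer - 1]
        for layer in range(nb_layers):
            if active[layer] and first[layer] is None:
                first[layer] = t
        prev_active = active
        t += 1
    return first


print(first_response_steps(nb_layers=4, input_step=3))  # [3, 4, 5, 6]
```

Under these assumptions, presenting an input sequence of T steps requires T + L - 1 simulation steps for the last layer to see all of it.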
  • Compute recurrent contribution from spikes

    Hey Friedemann,

    thank you for the very comprehensive tutorial! I have a question on the way the recurrence is computed in tutorial 4. If I understand the equation for the dynamics of the current correctly, the recurrence should be computed with the spiking neuron state:

    mthr = mem-1.0
    out = spike_fn(mthr)
    h1 = h1_from_input[:,t] + torch.einsum("ab,bc->ac", (out, v1))
    

    Instead, tutorial 4 keeps a separate hidden state that ignores the spike function:

    h1 = h1_from_input[:,t] + torch.einsum("ab,bc->ac", (h1, v1))
    

    Is this done deliberately? Judging from simulating a few epochs, the two versions seem to perform similarly.

    Thank you,

    Simon

    opened by smonsays 1
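For illustration, the spike-based variant proposed above can be written as a self-contained forward pass. This is a sketch under several assumptions and not the notebook code: `run_recurrent_layer` is a hypothetical name, the decay constants `alpha` and `beta` are arbitrary, a plain Heaviside step stands in for `spike_fn` (so no surrogate gradient is involved), and the recurrent input at step t uses the spikes emitted at step t-1.

```python
import torch


def run_recurrent_layer(h1_from_input, v1, alpha=0.9, beta=0.8):
    """Forward pass of one recurrent spiking layer; returns spikes (batch, time, hidden)."""
    batch_size, nb_steps, nb_hidden = h1_from_input.shape
    syn = torch.zeros(batch_size, nb_hidden)
    mem = torch.zeros(batch_size, nb_hidden)
    out = torch.zeros(batch_size, nb_hidden)
    spk_rec = []
    for t in range(nb_steps):
        # Recurrent contribution computed from the previous step's spikes
        h1 = h1_from_input[:, t] + torch.einsum("ab,bc->ac", out, v1)
        out = (mem - 1.0 > 0).to(mem.dtype)
        new_syn = alpha * syn + h1
        new_mem = (beta * mem + syn) * (1.0 - out)
        syn, mem = new_syn, new_mem
        spk_rec.append(out)
    return torch.stack(spk_rec, dim=1)


# Neuron 0 receives one input pulse; neuron 1 is driven only through v1[0, 1]
h = torch.zeros(1, 6, 2)
h[0, 0, 0] = 3.0
v1 = torch.zeros(2, 2)
v1[0, 1] = 3.0
print(run_recurrent_layer(h, v1)[0])
```

In this toy setup neuron 1 can only spike if the recurrent term is driven by spikes from neuron 0, which makes the role of `out` in the recurrence easy to inspect.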
  • maybe simplification

    I don't understand why the 'rst' variable exists. It always seems to be equal to 'out'. Changing it to rst = out yields the same results...

    def spike_fn(x):
        out = torch.zeros_like(x)
        out[x > 0] = 1.0
        return out
    ...
    # Here we loop over time
    for t in range(nb_steps):
        mthr = mem-1.0
        out = spike_fn(mthr) 
        rst = torch.zeros_like(mem)
        c = (mthr > 0)
        rst[c] = torch.ones_like(mem)[c] 
    
    opened by colinator 1
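The observation that the two tensors hold the same values can be checked directly. The sketch below reuses the `spike_fn` shown in the comment and wraps the mask construction in a hypothetical helper, `reset_mask`; the random membrane values are an arbitrary test input. It compares forward values only and says nothing about how gradients flow when `spike_fn` is replaced by a surrogate-gradient version.

```python
import torch


def spike_fn(x):
    """Heaviside step as written in the comment above."""
    out = torch.zeros_like(x)
    out[x > 0] = 1.0
    return out


def reset_mask(mem):
    """Build the reset tensor the way the quoted loop does."""
    mthr = mem - 1.0
    rst = torch.zeros_like(mem)
    c = (mthr > 0)
    rst[c] = torch.ones_like(mem)[c]
    return rst


torch.manual_seed(0)
mem = 2.0 * torch.rand(4, 5)  # values in [0, 2), some above threshold
same = torch.equal(spike_fn(mem - 1.0), reset_mask(mem))
print("rst equals out:", same)
```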
  • Issue in running Tutorial-4

    When I am running the following piece of code in Tutorial-4:

    loss_hist = train(x_train, y_train, lr=2e-4, nb_epochs=nb_epochs)

    I am getting the following error: [screenshot "pic3"]

    Can you please suggest how to resolve this issue?

    opened by paglabhola 0
Releases(v0.3)
Owner
Friedemann Zenke