DiffStride: Learning strides in convolutional neural networks

Overview

DiffStride is a pooling layer with learnable strides. Unlike strided convolutions, average pooling, or max pooling, which require cross-validating the stride values at each layer, DiffStride can be initialized with an arbitrary value at each layer (e.g. (2, 2)), and its strides are then optimized during training for the task at hand.

We describe DiffStride in our ICLR 2022 paper Learning Strides in Convolutional Neural Networks. Compared to the experiments described in the paper, this implementation uses a Pre-Act ResNet and trains with Mixup.
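For intuition about how a stride can take non-integer values, here is a minimal NumPy sketch of fixed spectral pooling, which downsamples by keeping only the low frequencies of the Fourier transform. It is an illustration under stated assumptions, not the code of this repository: the function name `spectral_pool_2d` is hypothetical, it handles a single 2D channel, and the stride here is fixed rather than learned.

```python
import numpy as np


def spectral_pool_2d(x, stride):
    """Downsample a 2D array by cropping its centered Fourier spectrum.

    Hypothetical helper for illustration only (not the diffstride API).
    x: real array of shape (H, W); stride: (s_h, s_w), each >= 1.
    """
    h, w = x.shape
    # Output size for a possibly fractional stride.
    out_h = max(1, int(round(h / stride[0])))
    out_w = max(1, int(round(w / stride[1])))

    # Move the zero frequency to the center of the spectrum.
    spectrum = np.fft.fftshift(np.fft.fft2(x))

    # Crop a window that keeps the zero frequency at the center.
    top = h // 2 - out_h // 2
    left = w // 2 - out_w // 2
    cropped = spectrum[top:top + out_h, left:left + out_w]

    # Back to the spatial domain; rescale so the mean intensity is kept.
    pooled = np.fft.ifft2(np.fft.ifftshift(cropped)).real
    return pooled * (out_h * out_w) / (h * w)


image = np.full((8, 8), 3.0)
pooled = spectral_pool_2d(image, (1.5, 1.5))
print(pooled.shape)
```

Because the output size is set by the width of the crop rather than by skipping pixels, a stride such as 1.5 is meaningful here, which is what makes it possible to treat the stride as a continuous quantity.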

Installation

To install the diffstride library, first clone this repo:

git clone https://github.com/google-research/diffstride.git

Then cd into the root and run the command:

pip install -e .

Example training

To run an example training on CIFAR10 and save the result in TensorBoard:

python3 -m diffstride.examples.main \
  --gin_config=cifar10.gin \
  --gin_bindings="train.workdir = '/tmp/exp/diffstride/resnet18/'"

Using custom parameters

This implementation uses Gin to parametrize the model, data processing and training loop. To use custom parameters, one should edit examples/cifar10.gin.

For example, to train with SpectralPooling on cifar100:

data.load_datasets:
  name = 'cifar100'

resnet.Resnet:
  pooling_cls = @pooling.FixedSpectralPooling

Or to train with strided convolutions and without Mixup:

data.load_datasets:
  mixup_alpha = 0.0

resnet.Resnet:
  pooling_cls = None

Results

The current implementation gives the following accuracies on CIFAR-10 and CIFAR-100, averaged over three runs. To show the robustness of DiffStride to stride initialization, we run both with the standard strides of ResNet (resnet.resnet18.strides = '1, 1, 2, 2, 2') and with a 'poor' choice of strides (resnet.resnet18.strides = '1, 1, 3, 2, 3'). Unlike strided convolutions and fixed spectral pooling, whose accuracy drops with the poor initialization, DiffStride reaches essentially the same accuracy with either initialization.

CIFAR-10

| Pooling | Test Accuracy (%) w/ strides = (1, 1, 2, 2, 2) | Test Accuracy (%) w/ strides = (1, 1, 3, 2, 3) |
| --- | --- | --- |
| Strided Convolution (Baseline) | 91.06 ± 0.04 | 89.21 ± 0.27 |
| Spectral Pooling | 93.49 ± 0.05 | 92.00 ± 0.08 |
| DiffStride | 94.20 ± 0.06 | 94.19 ± 0.15 |

CIFAR-100

| Pooling | Test Accuracy (%) w/ strides = (1, 1, 2, 2, 2) | Test Accuracy (%) w/ strides = (1, 1, 3, 2, 3) |
| --- | --- | --- |
| Strided Convolution (Baseline) | 65.75 ± 0.39 | 60.82 ± 0.42 |
| Spectral Pooling | 72.86 ± 0.23 | 67.74 ± 0.43 |
| DiffStride | 76.08 ± 0.23 | 76.09 ± 0.06 |
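To see why the second stride setting counts as a 'poor' choice, it helps to multiply the per-stage strides. The sketch below assumes that the five values apply to successive stages and that each stage divides the spatial resolution by its stride; the helper names are hypothetical and not part of the diffstride API.

```python
import math


def parse_strides(spec):
    """Parse a comma-separated stride string such as '1, 1, 2, 2, 2'."""
    return [int(s) for s in spec.split(",")]


def total_downsampling(spec):
    """Overall reduction factor along one spatial axis (product of strides)."""
    return math.prod(parse_strides(spec))


standard = total_downsampling("1, 1, 2, 2, 2")
poor = total_downsampling("1, 1, 3, 2, 3")
print(standard, poor)  # 8 18
```

Under these assumptions, a 32 x 32 CIFAR image is reduced by a factor of 8 per axis with the standard strides and by a factor of 18 with the poor ones, which leaves fewer than two pixels per side for methods that cannot adjust their strides during training.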

CPU/GPU Warning

We rely on the TensorFlow FFT implementation, which requires the input data to be in the channels_first format. This is not the usual data format of most datasets (including CIFAR), and running with channels_first also prevents the use of convolutions on CPU. Therefore, even though we support the channels_last data format for CPU compatibility, we encourage users to run with the channels_first data format on GPU.
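The difference between the two layouts is only the position of the channel axis. As a minimal sketch, the NumPy snippet below converts a batch between channels_last (N, H, W, C) and channels_first (N, C, H, W); the helper names are hypothetical, and how the library converts data internally may differ.

```python
import numpy as np


def to_channels_first(batch):
    """Convert a batch from (N, H, W, C) to (N, C, H, W)."""
    return np.transpose(batch, (0, 3, 1, 2))


def to_channels_last(batch):
    """Convert a batch from (N, C, H, W) back to (N, H, W, C)."""
    return np.transpose(batch, (0, 2, 3, 1))


# A dummy CIFAR-sized batch: 8 RGB images of 32 x 32 pixels.
images = np.zeros((8, 32, 32, 3), dtype=np.float32)
nchw = to_channels_first(images)
print(nchw.shape)  # (8, 3, 32, 32)
```

In TensorFlow the same permutation can be applied with tf.transpose and the same axis order.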

Reference

If you use this repository, please consider citing:

@article{riad2022diffstride,
  title={Learning Strides in Convolutional Neural Networks},
  author={Riad, Rachid and Teboul, Olivier and Grangier, David and Zeghidour, Neil},
  journal={ICLR},
  year={2022}
}

Disclaimer

This is not an official Google product.
