Official PyTorch implementation of the MixMo framework

Overview

MixMo: Mixing Multiple Inputs for Multiple Outputs via Deep Subnetworks

Alexandre Ramé, Rémy Sun, Matthieu Cord

Citation

If you find this code useful for your research, please cite:

@article{rame2021ixmo,
    title={MixMo: Mixing Multiple Inputs for Multiple Outputs via Deep Subnetworks},
    author={Alexandre Rame and Remy Sun and Matthieu Cord},
    year={2021},
    journal={arXiv preprint arXiv:2103.06132}
}

Abstract

Recent strategies achieved ensembling “for free” by fitting concurrently diverse subnetworks inside a single base network. The main idea during training is that each subnetwork learns to classify only one of the multiple inputs simultaneously provided. However, the question of how to best mix these multiple inputs has not been studied so far.

In this paper, we introduce MixMo, a new generalized framework for learning multi-input multi-output deep subnetworks. Our key motivation is to replace the suboptimal summing operation hidden in previous approaches by a more appropriate mixing mechanism. For that purpose, we draw inspiration from successful mixed sample data augmentations. We show that binary mixing in features - particularly with rectangular patches from CutMix - enhances results by making subnetworks stronger and more diverse.

We improve state of the art for image classification on CIFAR-100 and Tiny ImageNet datasets. Our easy to implement models notably outperform data augmented deep ensembles, without the inference and memory overheads. As we operate in features and simply better leverage the expressiveness of large networks, we open a new line of research complementary to previous works.

Overview

Most important code sections

This repository provides a general wrapper over PyTorch to reproduce the main results from the paper. The code sections specific to MixMo can be found in:

  1. mixmo.loaders.dataset_wrapper.py and specifically MixMoDataset to create batches with multiple inputs and multiple outputs.
  2. mixmo.augmentations.mixing_blocks.py where we create the mixing masks, e.g. via linear summing (_mixup_mask) or via patch mixing (_cutmix_mask).
  3. mixmo.networks.resnet.py and mixmo.networks.wrn.py where we adapt the network structures to handle:
    • multiple inputs via multiple conv1s encoders (one for each input). The function mixmo.augmentations.mixing_blocks.mix_manifold is used to mix the extracted representations according to the masks provided in metadata from MixMoDataset.
    • multiple outputs via multiple predictions.

This translates to additional tensor management in mixmo.learners.learner.py.
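The two kinds of mixing mask mentioned above can be illustrated with a minimal, dependency-free sketch. It is not the repository's `_cutmix_mask`, `_mixup_mask` or `mix_manifold`: the function names below are hypothetical, the feature maps are plain nested lists rather than tensors, and the rectangle placement is a simplified variant of CutMix in which the patch always fits inside the map.

```python
import random


def cutmix_mask(height, width, lam, rng=random):
    """Binary mask whose rectangle of ones covers roughly `lam` of the area."""
    # Side lengths scale with sqrt(lam) so that the area scales with lam.
    cut_h = int(round(height * lam ** 0.5))
    cut_w = int(round(width * lam ** 0.5))
    top = rng.randint(0, height - cut_h)   # randint is inclusive on both ends
    left = rng.randint(0, width - cut_w)
    return [
        [1.0 if top <= i < top + cut_h and left <= j < left + cut_w else 0.0
         for j in range(width)]
        for i in range(height)
    ]


def mixup_mask(height, width, lam):
    """Constant mask: every position uses the same linear mixing weight."""
    return [[lam] * width for _ in range(height)]


def mix_features(feats_0, feats_1, mask):
    """Position-wise mix: mask * feats_0 + (1 - mask) * feats_1."""
    return [
        [m * a + (1.0 - m) * b for a, b, m in zip(row_0, row_1, row_m)]
        for row_0, row_1, row_m in zip(feats_0, feats_1, mask)
    ]


rng = random.Random(0)
lam = rng.betavariate(2.0, 2.0)  # mixing ratio drawn from Beta(2, 2)
feats_0 = [[1.0] * 8 for _ in range(8)]    # stand-in for the first encoder's features
feats_1 = [[-1.0] * 8 for _ in range(8)]   # stand-in for the second encoder's features
patch_mixed = mix_features(feats_0, feats_1, cutmix_mask(8, 8, lam, rng))
linear_mixed = mix_features(feats_0, feats_1, mixup_mask(8, 8, lam))
```

With a binary mask each position keeps the features of exactly one input, whereas the constant mask blends both inputs everywhere. Any rescaling of the mixed features applied by the actual implementation is deliberately left out of this sketch.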

Pseudo code

Our MixMoDataset wraps a PyTorch Dataset. The batch_repetition_sampler repeats the same index b times in each batch. Moreover, we provide SoftCrossEntropyLoss which handles soft-labels required by mixed sample data augmentations such as CutMix.

import os

from torchvision.datasets import CIFAR100

from mixmo.loaders import (dataset_wrapper, batch_repetition_sampler)
from mixmo.networks.wrn import WideResNetMixMo
from mixmo.core.loss import SoftCrossEntropyLoss as criterion

...

# cf mixmo.loaders.loader
train_dataset = dataset_wrapper.MixMoDataset(
        dataset=CIFAR100(os.path.join(dataplace, "cifar100-data")),
        num_members=2,  # we use M=2 subnetworks
        mixmo_mix_method="cutmix",  # patch mixing, linked to mixmo.augmentations.mixing_blocks._cutmix_mask
        mixmo_alpha=2,  # mixing ratio sampled from a Beta distribution with concentration 2
        mixmo_weight_root=3  # root used when reweighting the loss components
        )
network = WideResNetMixMo(depth=28, widen_factor=10, num_classes=100)

...

# cf mixmo.learners.learner and mixmo.learners.model_wrapper
for _ in range(num_epochs):
    for indexes_0, indexes_1 in batch_repetition_sampler(batch_size=64, b=4, max_index=len(train_dataset)):
        for (inputs_0, inputs_1, targets_0, targets_1, metadata_mixmo_masks) in train_dataset(indexes_0, indexes_1):
            outputs_0, outputs_1 = network([inputs_0, inputs_1], metadata_mixmo_masks)
            loss = criterion(outputs_0, targets_0) + criterion(outputs_1, targets_1)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
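Two pieces of the loop above, batch repetition and a cross-entropy that accepts soft labels, can be sketched in isolation. This is a plain-Python illustration under stated assumptions, not the repository's `batch_repetition_sampler` or `SoftCrossEntropyLoss`: the names are hypothetical, `batch_size` is assumed to count samples after repetition, and the two index lists are assumed to be independent shuffles of the same repeated batch.

```python
import math
import random


def batch_repetition_indices(batch_size, b, max_index, rng=random):
    """Yield (indexes_0, indexes_1); each sampled index appears b times in both."""
    if batch_size % b != 0:
        raise ValueError("batch_size must be a multiple of b")
    unique_per_batch = batch_size // b
    order = list(range(max_index))
    rng.shuffle(order)
    # Incomplete trailing batches are dropped.
    for start in range(0, max_index - unique_per_batch + 1, unique_per_batch):
        repeated = order[start:start + unique_per_batch] * b
        indexes_0, indexes_1 = repeated[:], repeated[:]
        rng.shuffle(indexes_0)  # one ordering per subnetwork input
        rng.shuffle(indexes_1)
        yield indexes_0, indexes_1


def soft_cross_entropy(logits, soft_targets):
    """Batch mean of -sum_k target_k * log_softmax(logits)_k."""
    total = 0.0
    for row, target in zip(logits, soft_targets):
        m = max(row)  # subtract the max for numerical stability
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total -= sum(t * (x - log_z) for x, t in zip(row, target))
    return total / len(logits)


batches = list(batch_repetition_indices(batch_size=8, b=4, max_index=10,
                                        rng=random.Random(0)))
loss = soft_cross_entropy([[2.0, 0.0, -1.0]], [[0.7, 0.3, 0.0]])
```

With a one-hot target the soft cross-entropy reduces to the usual cross-entropy, which is why the same criterion can serve both mixed and unmixed samples.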

Configuration files

Our code relies heavily on YAML config files. In the mixmo-pytorch/config folder, we provide the configs to reproduce the main paper results.

For example, the name of the state-of-the-art config exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4 means that:

  • cifar100: dataset is CIFAR-100.
  • wrn2810-2: WideResNet-28-10 network architecture with M=2 subnetworks.
  • cutmixmo-p5: mixing block is patch mixing with probability p=0.5, and linear mixing otherwise.
  • msdacutmix: use CutMix mixed sample data augmentation.
  • bar4: batch repetition to b=4.
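The naming convention above can be decoded mechanically. The helper below is a hypothetical illustration, not part of the repository: it assumes every config name has the six underscore-separated fields seen in the provided configs, and that a suffix such as `p5` stands for the decimal probability 0.5.

```python
def parse_config_name(name):
    """Split a config name such as exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4."""
    stem = name[:-len(".yaml")] if name.endswith(".yaml") else name
    parts = stem.split("_")
    if len(parts) != 6 or parts[0] != "exp":
        raise ValueError(f"unexpected config name: {name!r}")
    _, dataset, network, mixing, msda, bar = parts
    arch, _, members = network.partition("-")   # "wrn2810-2" -> ("wrn2810", "-", "2")
    method, _, prob = mixing.partition("-p")    # "cutmixmo-p5" -> ("cutmixmo", "-p", "5")
    return {
        "dataset": dataset,
        "architecture": arch,
        "num_members": int(members) if members else 1,
        "mixing_block": method,
        "patch_prob": float("0." + prob) if prob else None,
        "msda": msda[len("msda"):] if msda.startswith("msda") else None,
        "batch_repetition": int(bar[len("bar"):]),
    }


best = parse_config_name("exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml")
baseline = parse_config_name("exp_cifar100_wrn2810_1net_standard_bar1.yaml")
```

For baseline configs the fourth field (`1net`) is returned as-is in `mixing_block`, since no mixing block is involved.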

Results and available checkpoints

CIFAR-100 with WideResNet-28-10

| Subnetwork method | MSDA | Top-1 Accuracy | Config file in mixmo-pytorch/config/cifar100 |
|---|---|---|---|
| -- | Vanilla | 81.79 | exp_cifar100_wrn2810_1net_standard_bar1.yaml |
| -- | Mixup | 83.43 | exp_cifar100_wrn2810_1net_msdamixup_bar1.yaml |
| -- | CutMix | 83.95 | exp_cifar100_wrn2810_1net_msdacutmix_bar1.yaml |
| MIMO | -- | 82.92 | exp_cifar100_wrn2810-2_mimo_standard_bar4.yaml |
| Linear-MixMo | -- | 82.96 | exp_cifar100_wrn2810-2_linearmixmo_standard_bar4.yaml |
| Cut-MixMo | -- | 85.52 - 85.59 | exp_cifar100_wrn2810-2_cutmixmo-p5_standard_bar4.yaml |
| Linear-MixMo | CutMix | 85.36 - 85.57 | exp_cifar100_wrn2810-2_linearmixmo_msdacutmix_bar4.yaml |
| Cut-MixMo | CutMix | 85.77 - 85.92 | exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml |

CIFAR-10 with WideResNet-28-10

| Subnetwork method | MSDA | Top-1 Accuracy | Config file in mixmo-pytorch/config/cifar10 |
|---|---|---|---|
| -- | Vanilla | 96.37 | exp_cifar10_wrn2810_1net_standard_bar1.yaml |
| -- | Mixup | 97.07 | exp_cifar10_wrn2810_1net_msdamixup_bar1.yaml |
| -- | CutMix | 97.28 | exp_cifar10_wrn2810_1net_msdacutmix_bar1.yaml |
| MIMO | -- | 96.71 | exp_cifar10_wrn2810-2_mimo_standard_bar4.yaml |
| Linear-MixMo | -- | 96.88 | exp_cifar10_wrn2810-2_linearmixmo_standard_bar4.yaml |
| Cut-MixMo | -- | 97.52 | exp_cifar10_wrn2810-2_cutmixmo-p5_standard_bar4.yaml |
| Linear-MixMo | CutMix | 97.73 | exp_cifar10_wrn2810-2_linearmixmo_msdacutmix_bar4.yaml |
| Cut-MixMo | CutMix | 97.83 | exp_cifar10_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml |

Tiny ImageNet-200 with PreActResNet-18-width

| Method | Width | Top-1 Accuracy | Config file in mixmo-pytorch/config/tiny |
|---|---|---|---|
| Vanilla | 1 | 62.75 | exp_tinyimagenet_res18_1net_standard_bar1.yaml |
| Linear-MixMo | 1 | 62.91 | exp_tinyimagenet_res18-2_linearmixmo_standard_bar4.yaml |
| Cut-MixMo | 1 | 64.32 | exp_tinyimagenet_res18-2_cutmixmo-p5_standard_bar4.yaml |
| Vanilla | 2 | 64.91 | exp_tinyimagenet_res182_1net_standard_bar1.yaml |
| Linear-MixMo | 2 | 67.03 | exp_tinyimagenet_res182-2_linearmixmo_standard_bar4.yaml |
| Cut-MixMo | 2 | 69.12 | exp_tinyimagenet_res182-2_cutmixmo-p5_standard_bar4.yaml |
| Vanilla | 3 | 65.84 | exp_tinyimagenet_res183_1net_standard_bar1.yaml |
| Linear-MixMo | 3 | 68.36 | exp_tinyimagenet_res183-2_linearmixmo_standard_bar4.yaml |
| Cut-MixMo | 3 | 70.23 | exp_tinyimagenet_res183-2_cutmixmo-p5_standard_bar4.yaml |

Installation

Requirements overview

  • python >= 3.6
  • torch >= 1.4.0
  • torchsummary >= 1.5.1
  • torchvision >= 0.5.0
  • tensorboard >= 1.14.0

Procedure

  1. Clone the repo:
$ git clone https://github.com/alexrame/mixmo-pytorch.git
  2. Create a conda environment, then install this repository and its dependencies using pip:
$ conda create --name mixmo python=3.6.10
$ conda activate mixmo
$ cd mixmo-pytorch
$ pip install -r requirements.txt

With this, you can edit the MixMo code on the fly.

Datasets

We advise first creating a dedicated data folder dataplace, which will be passed as an argument to the subsequent scripts.

  • CIFAR

The CIFAR-10 and CIFAR-100 datasets are managed by the PyTorch dataloader. The first time you run a script, the dataloader will download the dataset into the dataplace you provided.

  • Tiny-ImageNet

The Tiny-ImageNet dataset needs to be downloaded beforehand. The following process is forked from manifold mixup.

  1. Download the zipped data from https://tiny-imagenet.herokuapp.com/.
  2. Extract the zipped data in folder dataplace.
  3. Run the following script (this will arrange the validation data in the format required by the PyTorch loader).
$ python scripts/script_load_tiny_data.py --dataplace $dataplace
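To give an idea of what such a rearrangement involves, here is a minimal sketch. It is not the repository's script_load_tiny_data.py. It assumes the standard Tiny ImageNet layout, where val/images holds all validation images and val/val_annotations.txt maps each filename to its class id in the first two tab-separated columns, and it assumes a one-folder-per-class target layout. The function name and the demo files are hypothetical.

```python
import os
import shutil
import tempfile


def arrange_val_by_class(val_dir):
    """Move val/images/<file> to val/<class_id>/<file>; return the number moved."""
    images_dir = os.path.join(val_dir, "images")
    moved = 0
    with open(os.path.join(val_dir, "val_annotations.txt")) as handle:
        for line in handle:
            fields = line.rstrip("\n").split("\t")
            if len(fields) < 2:
                continue  # skip blank or malformed lines
            filename, class_id = fields[0], fields[1]
            source = os.path.join(images_dir, filename)
            if not os.path.isfile(source):
                continue  # already moved or missing
            class_dir = os.path.join(val_dir, class_id)
            os.makedirs(class_dir, exist_ok=True)
            shutil.move(source, os.path.join(class_dir, filename))
            moved += 1
    if os.path.isdir(images_dir) and not os.listdir(images_dir):
        os.rmdir(images_dir)  # avoid leaving an empty folder that looks like a class
    return moved


# Tiny self-contained demo on a throwaway directory with two fake images.
with tempfile.TemporaryDirectory() as root:
    val_dir = os.path.join(root, "val")
    os.makedirs(os.path.join(val_dir, "images"))
    rows = [("val_0.JPEG", "n01443537"), ("val_1.JPEG", "n01629819")]
    with open(os.path.join(val_dir, "val_annotations.txt"), "w") as handle:
        for filename, class_id in rows:
            open(os.path.join(val_dir, "images", filename), "w").close()
            handle.write(f"{filename}\t{class_id}\t0\t0\t10\t10\n")
    moved = arrange_val_by_class(val_dir)
    layout = sorted(os.listdir(val_dir))
```

A folder-per-class layout is what loaders such as torchvision's ImageFolder expect, which is presumably why the validation split needs this step while the training split does not.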

Running the code

Training

Baseline

First, to train a baseline model, simply execute the following command:

$ python3 scripts/train.py --config_path config/cifar100/exp_cifar100_wrn2810_1net_standard_bar1.yaml --dataplace $dataplace --saveplace $saveplace

It will create an output folder exp_cifar100_wrn2810_1net_standard_bar1 inside the parent folder saveplace. This folder includes model checkpoints, a copy of your config file, logs and tensorboard logs. By default, if the output folder already exists, training will load the weights from the last saved epoch and resume from there. If you want to force training to restart, simply add --from_scratch as an argument.

MixMo

When training MixMo, you just need to select the appropriate config file. For example, to obtain state of the art results on CIFAR-100 by combining Cut-MixMo and CutMix, just execute:

$ python3 scripts/train.py --config_path config/cifar100/exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml --dataplace $dataplace --saveplace $saveplace

Evaluation

To evaluate the accuracy of a given strategy, you can train your own model, or just download our pretrained checkpoints:

$ python3 scripts/evaluate.py --config_path config/cifar100/exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml --dataplace $dataplace --checkpoint $checkpoint --tempscal
  • checkpoint can be one of:
    • a path to a checkpoint.
    • an int matching the training epoch you wish to evaluate. In that case, you need to provide --saveplace $saveplace.
    • the string best: the best training epoch is then selected automatically. In that case, you need to provide --saveplace $saveplace.
  • --tempscal: applies temperature scaling at evaluation.

Results will be printed at the end of the script.
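For readers unfamiliar with temperature scaling, the idea is to divide the logits by a single scalar T, chosen to minimize the negative log-likelihood on held-out data, before applying the softmax. The sketch below is a generic illustration and makes no claim about how --tempscal fits the temperature. The grid search, the function names and the toy logits are all assumptions.

```python
import math


def nll(logits, labels, temperature):
    """Mean negative log-likelihood of integer labels under softmax(logits / T)."""
    total = 0.0
    for row, label in zip(logits, labels):
        scaled = [x / temperature for x in row]
        m = max(scaled)  # subtract the max for numerical stability
        log_z = m + math.log(sum(math.exp(s - m) for s in scaled))
        total -= scaled[label] - log_z
    return total / len(labels)


def fit_temperature(logits, labels, candidates=None):
    """Pick the candidate temperature with the lowest held-out NLL."""
    if candidates is None:
        candidates = [0.5 + 0.05 * i for i in range(71)]  # grid from 0.5 to 4.0
    return min(candidates, key=lambda t: nll(logits, labels, t))


# Overconfident toy model: very peaked logits, but one prediction in four is wrong.
val_logits = [[4.0, 0.0], [4.0, 0.0], [0.0, 4.0], [4.0, 0.0]]
val_labels = [0, 0, 1, 1]
best_t = fit_temperature(val_logits, val_labels)
```

Because every logit is divided by the same positive scalar, the predicted class and therefore the accuracy are unchanged. Only the confidence of the predictions is recalibrated.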

If you wish to test the models against common corruptions and perturbations, download the CIFAR-100-C dataset into your dataplace. Then pass --robustness at evaluation.

Create your own configuration files and learning strategies

You can create new configs automatically via:

$ python3 scripts/templateutils_mixmo.py --template_path scripts/exp_mixmo_template.yaml --config_dir config/$your_config_dir --dataset $dataset

Acknowledgements and references
