Official Pytorch implementation of MixMo framework

Overview

MixMo: Mixing Multiple Inputs for Multiple Outputs via Deep Subnetworks

Official PyTorch implementation of the MixMo framework | paper | docs

Alexandre Ramé, Rémy Sun, Matthieu Cord

Citation

If you find this code useful for your research, please cite:

@article{rame2021ixmo,
    title={MixMo: Mixing Multiple Inputs for Multiple Outputs via Deep Subnetworks},
    author={Alexandre Rame and Remy Sun and Matthieu Cord},
    year={2021},
    journal={arXiv preprint arXiv:2103.06132}
}

Abstract

Recent strategies achieved ensembling “for free” by fitting concurrently diverse subnetworks inside a single base network. The main idea during training is that each subnetwork learns to classify only one of the multiple inputs simultaneously provided. However, the question of how to best mix these multiple inputs has not been studied so far.

In this paper, we introduce MixMo, a new generalized framework for learning multi-input multi-output deep subnetworks. Our key motivation is to replace the suboptimal summing operation hidden in previous approaches by a more appropriate mixing mechanism. For that purpose, we draw inspiration from successful mixed sample data augmentations. We show that binary mixing in features - particularly with rectangular patches from CutMix - enhances results by making subnetworks stronger and more diverse.

We improve state of the art for image classification on CIFAR-100 and Tiny ImageNet datasets. Our easy to implement models notably outperform data augmented deep ensembles, without the inference and memory overheads. As we operate in features and simply better leverage the expressiveness of large networks, we open a new line of research complementary to previous works.

Overview

Most important code sections

This repository provides a general wrapper over PyTorch to reproduce the main results from the paper. The code sections specific to MixMo can be found in:

  1. mixmo.loaders.dataset_wrapper.py and specifically MixMoDataset to create batches with multiple inputs and multiple outputs.
  2. mixmo.augmentations.mixing_blocks.py where we create the mixing masks, e.g. via linear summing (_mixup_mask) or via patch mixing (_cutmix_mask).
  3. mixmo.networks.resnet.py and mixmo.networks.wrn.py where we adapt the network structures to handle:
    • multiple inputs via multiple conv1s encoders (one for each input). The function mixmo.augmentations.mixing_blocks.mix_manifold is used to mix the extracted representations according to the masks provided in metadata from MixMoDataset.
    • multiple outputs via multiple predictions.

This translates to additional tensor management in mixmo.learners.learner.py.

Pseudo code

Our MixMoDataset wraps a PyTorch Dataset. The batch_repetition_sampler repeats the same index b times in each batch. Moreover, we provide SoftCrossEntropyLoss which handles soft-labels required by mixed sample data augmentations such as CutMix.

from mixmo.loaders import (dataset_wrapper, batch_repetition_sampler)
from mixmo.networks.wrn import WideResNetMixMo
from mixmo.core.loss import SoftCrossEntropyLoss as criterion

...

# cf mixmo.loaders.loader
train_dataset = dataset_wrapper.MixMoDataset(
        dataset=CIFAR100(os.path.join(dataplace, "cifar100-data")),
        num_members=2,  # we use M=2 subnetworks
        mixmo_mix_method="cutmix",  # patch mixing, linker to mixmo.augmentations.mixing_blocks._cutmix_mask
        mixmo_alpha=2,  # mixing ratio sampled from Beta distribution with concentration 2
        mixmo_weight_root=3  # root for reweighting of loss components 3
        )
network = WideResNetMixMo(depth=28, widen_factor=10, num_classes=100)

...

# cf mixmo.learners.learner and mixmo.learners.model_wrapper
for _ in range(num_epochs):
    for indexes_0, indexes_1 in batch_repetition_sampler(batch_size=64, b=4, max_index=len(train_dataset)):
        for (inputs_0, inputs_1, targets_0, targets_1, metadata_mixmo_masks) in train_dataset(indexes_0, indexes_1):
            outputs_0, outputs_1 = network([inputs_0, inputs_1], metadata_mixmo_masks)
            loss = criterion(outputs_0, targets_0) + criterion(outputs_1, targets_1)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

Configuration files

Our code heavily relies on yaml config files. In the mixmo-pytorch/config folder, we provide the configs to reproduce the main paper results.

For example, the state-of-the-art exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4 means that:

  • cifar100: dataset is CIFAR-100.
  • wrn2810-2: WideResNet-28-10 network architecture with M=2 subnetworks.
  • cutmixmo-p5: mixing block is patch mixing with probability p=0.5 else linear mixing.
  • msdacutmix: use CutMix mixed sample data augmentation.
  • bar4: batch repetition to b=4.

Results and available checkpoints

CIFAR-100 with WideResNet-28-10

Subnetwork method MSDA Top-1 Accuracy config file in mixmo-pytorch/config/cifar100
-- Vanilla 81.79 exp_cifar100_wrn2810_1net_standard_bar1.yaml
-- Mixup 83.43 exp_cifar100_wrn2810_1net_msdamixup_bar1.yaml
-- CutMix 83.95 exp_cifar100_wrn2810_1net_msdacutmix_bar1.yaml
MIMO -- 82.92 exp_cifar100_wrn2810-2_mimo_standard_bar4.yaml
Linear-MixMo -- 82.96 exp_cifar100_wrn2810-2_linearmixmo_standard_bar4.yaml
Cut-MixMo -- 85.52 - 85.59 exp_cifar100_wrn2810-2_cutmixmo-p5_standard_bar4.yaml
Linear-MixMo CutMix 85.36 - 85.57 exp_cifar100_wrn2810-2_linearmixmo_msdacutmix_bar4.yaml
Cut-MixMo CutMix 85.77 - 85.92 exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml

CIFAR-10 with WideResNet-28-10

Subnetwork method MSDA Top-1 Accuracy config file in mixmo-pytorch/config/cifar10
-- Vanilla 96.37 exp_cifar10_wrn2810_1net_standard_bar1.yaml
-- Mixup 97.07 exp_cifar10_wrn2810_1net_msdamixup_bar1.yaml
-- CutMix 97.28 exp_cifar10_wrn2810_1net_msdacutmix_bar1.yaml
MIMO -- 96.71 exp_cifar10_wrn2810-2_mimo_standard_bar4.yaml
Linear-MixMo -- 96.88 exp_cifar10_wrn2810-2_linearmixmo_standard_bar4.yaml
Cut-MixMo -- 97.52 exp_cifar10_wrn2810-2_cutmixmo-p5_standard_bar4.yaml
Linear-MixMo CutMix 97.73 exp_cifar10_wrn2810-2_linearmixmo_msdacutmix_bar4.yaml
Cut-MixMo CutMix 97.83 exp_cifar10_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml

Tiny ImageNet-200 with PreActResNet-18-width

Method Width Top-1 Accuracy config file in mixmo-pytorch/config/tiny
Vanilla 1 62.75 exp_tinyimagenet_res18_1net_standard_bar1.yaml
Linear-MixMo 1 62.91 exp_tinyimagenet_res18-2_linearmixmo_standard_bar4.yaml
Cut-MixMo 1 64.32 exp_tinyimagenet_res18-2_cutmixmo-p5_standard_bar4.yaml
Vanilla 2 64.91 exp_tinyimagenet_res182_1net_standard_bar1.yaml
Linear-MixMo 2 67.03 exp_tinyimagenet_res182-2_linearmixmo_standard_bar4.yaml
Cut-MixMo 2 69.12 exp_tinyimagenet_res182-2_cutmixmo-p5_standard_bar4.yaml
Vanilla 3 65.84 exp_tinyimagenet_res183_1net_standard_bar1.yaml
Linear-MixMo 3 68.36 exp_tinyimagenet_res183-2_linearmixmo_standard_bar4.yaml
Cut-MixMo 3 70.23 exp_tinyimagenet_res183-2_cutmixmo-p5_standard_bar4.yaml

Installation

Requirements overview

  • python >= 3.6
  • torch >= 1.4.0
  • torchsummary >= 1.5.1
  • torchvision >= 0.5.0
  • tensorboard >= 1.14.0

Procedure

  1. Clone the repo:
$ git clone https://github.com/alexrame/mixmo-pytorch.git
  1. Install this repository and the dependencies using pip:
$ conda create --name mixmo python=3.6.10
$ conda activate mixmo
$ cd mixmo-pytorch
$ pip install -r requirements.txt

With this, you can edit the MixMo code on the fly.

Datasets

We advise to first create a dedicated data folder dataplace, that will be provided as an argument in the subsequent scripts.

  • CIFAR

CIFAR-10 and CIFAR-100 datasets are managed by Pytorch dataloader. First time you run a script, the dataloader will download the dataset in your provided dataplace.

  • Tiny-ImageNet

Tiny-ImageNet dataset needs to be download beforehand. The following process is forked from manifold mixup.

  1. Download the zipped data from https://tiny-imagenet.herokuapp.com/.
  2. Extract the zipped data in folder dataplace.
  3. Run the following script (This will arange the validation data in the format required by the pytorch loader).
$ python scripts/script_load_tiny_data.py --dataplace $dataplace

Running the code

Training

Baseline

First, to train a baseline model, simply execute the following command:

$ python3 scripts/train.py --config_path config/cifar100/exp_cifar100_wrn2810_1net_standard_bar1.yaml --dataplace $dataplace --saveplace $saveplace

It will create an output folder exp_cifar100_wrn2810_1net_standard_bar1 located in parent folder saveplace. This folder includes model checkpoints, a copy of your config file, logs and tensorboard logs. By default, if the output folder already exists, training will load the last weights epoch and will continue. If you want to forcefully restart training, simply add --from_scratch as an argument.

MixMo

When training MixMo, you just need to select the appropriate config file. For example, to obtain state of the art results on CIFAR-100 by combining Cut-MixMo and CutMix, just execute:

$ python3 scripts/train.py --config_path config/cifar100/exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml --dataplace $dataplace --saveplace $saveplace

Evaluation

To evaluate the accuracy of a given strategy, you can train your own model, or just download our pretrained checkpoints:

$ python3 scripts/evaluate.py --config_path config/cifar100/exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml --dataplace $dataplace --checkpoint $checkpoint --tempscal
  • checkpoint can be either:
    • a path towards a checkpoint.
    • an int matching the training epoch you wish to evaluate. In that case, you need to provide --saveplace $saveplace.
    • the string best: we then automatically select the best training epoch. In that case, you need to provide --saveplace $saveplace.
  • --tempscal: indicates that you will apply temperature scaling

Results will be printed at the end of the script.

If you wish to test the models against common corruptions and perturbations, download the CIFAR-100-c dataset in your dataplace. Then use --robustness at evaluation.

Create your own configuration files and learning strategies

You can create new configs automatically via:

$ python3 scripts/templateutils_mixmo.py --template_path scripts/exp_mixmo_template.yaml --config_dir config/$your_config_dir --dataset $dataset

Acknowledgements and references

This repo provides the official code for TransBTS: Multimodal Brain Tumor Segmentation Using Transformer (https://arxiv.org/pdf/2103.04430.pdf).

TransBTS: Multimodal Brain Tumor Segmentation Using Transformer This repo is the official implementation for TransBTS: Multimodal Brain Tumor Segmenta

Raymond 247 Dec 28, 2022
Pun Detection and Location

Pun Detection and Location “The Boating Store Had Its Best Sail Ever”: Pronunciation-attentive Contextualized Pun Recognition Yichao Zhou, Jyun-yu Jia

lawson 3 May 13, 2022
FPSAutomaticAiming——基于YOLOV5的FPS类游戏自动瞄准AI

FPSAutomaticAiming——基于YOLOV5的FPS类游戏自动瞄准AI 声明: 本项目仅限于学习交流,不可用于非法用途,包括但不限于:用于游戏外挂等,使用本项目产生的任何后果与本人无关! 简介 本项目基于yolov5,实现了一款FPS类游戏(CF、CSGO等)的自瞄AI,本项目旨在使用现

Fabian 246 Dec 28, 2022
Autotype on websites that have copy-paste disabled like Moodle, HackerEarth contest etc.

Autotype A quick and small python script that helps you autotype on websites that have copy paste disabled like Moodle, HackerEarth contests etc as it

Tushar 32 Nov 03, 2022
Hunt down social media accounts by username across social networks

Hunt down social media accounts by username across social networks Installation | Usage | Docker Notes | Contributing Installation # clone the repo $

1 Dec 14, 2021
Ludwig is a toolbox that allows to train and evaluate deep learning models without the need to write code.

Translated in 🇰🇷 Korean/ Ludwig is a toolbox that allows users to train and test deep learning models without the need to write code. It is built on

Ludwig 8.7k Jan 05, 2023
RM Operation can equivalently convert ResNet to VGG, which is better for pruning; and can help RepVGG perform better when the depth is large.

RMNet: Equivalently Removing Residual Connection from Networks This repository is the official implementation of "RMNet: Equivalently Removing Residua

184 Jan 04, 2023
Request execution of Galaxy SARS-CoV-2 variation analysis workflows on input data you provide.

SARS-CoV-2 processing requests Request execution of Galaxy SARS-CoV-2 variation analysis workflows on input data you provide. Prerequisites This autom

useGalaxy.eu 17 Aug 13, 2022
An investigation project for SISR.

SISR-Survey An investigation project for SISR. This repository is an official project of the paper "From Beginner to Master: A Survey for Deep Learnin

Juncheng Li 79 Oct 20, 2022
基于YoloX目标检测+DeepSort算法实现多目标追踪Baseline

项目简介: 使用YOLOX+Deepsort实现车辆行人追踪和计数,代码封装成一个Detector类,更容易嵌入到自己的项目中。 代码地址(欢迎star): https://github.com/Sharpiless/yolox-deepsort/ 最终效果: 运行demo: python demo

114 Dec 30, 2022
GAT - Graph Attention Network (PyTorch) 💻 + graphs + 📣 = ❤️

GAT - Graph Attention Network (PyTorch) 💻 + graphs + 📣 = ❤️ This repo contains a PyTorch implementation of the original GAT paper ( 🔗 Veličković et

Aleksa Gordić 1.9k Jan 09, 2023
Code implementation from my Medium blog post: [Transformers from Scratch in PyTorch]

transformer-from-scratch Code for my Medium blog post: Transformers from Scratch in PyTorch Note: This Transformer code does not include masked attent

Frank Odom 27 Dec 21, 2022
The official codes of "Semi-supervised Models are Strong Unsupervised Domain Adaptation Learners".

SSL models are Strong UDA learners Introduction This is the official code of paper "Semi-supervised Models are Strong Unsupervised Domain Adaptation L

Yabin Zhang 26 Dec 26, 2022
The implementation code for "DAGAN: Deep De-Aliasing Generative Adversarial Networks for Fast Compressed Sensing MRI Reconstruction"

DAGAN This is the official implementation code for DAGAN: Deep De-Aliasing Generative Adversarial Networks for Fast Compressed Sensing MRI Reconstruct

TensorLayer Community 159 Nov 22, 2022
Code for STFT Transformer used in BirdCLEF 2021 competition.

STFT_Transformer Code for STFT Transformer used in BirdCLEF 2021 competition. The STFT Transformer is a new way to use Transformers similar to Vision

Jean-François Puget 69 Sep 29, 2022
Efficient Lottery Ticket Finding: Less Data is More

The lottery ticket hypothesis (LTH) reveals the existence of winning tickets (sparse but critical subnetworks) for dense networks, that can be trained in isolation from random initialization to match

VITA 20 Sep 04, 2022
Source code for paper "Document-Level Relation Extraction with Adaptive Thresholding and Localized Context Pooling", AAAI 2021

ATLOP Code for AAAI 2021 paper Document-Level Relation Extraction with Adaptive Thresholding and Localized Context Pooling. If you make use of this co

Wenxuan Zhou 146 Nov 29, 2022
Final Project for the CS238: Decision Making Under Uncertainty course at Stanford University in Autumn '21.

Final Project for the CS238: Decision Making Under Uncertainty course at Stanford University in Autumn '21. We optimized wind turbine placement in a wind farm, subject to wake effects, using Q-learni

Manasi Sharma 2 Sep 27, 2022
NEATEST: Evolving Neural Networks Through Augmenting Topologies with Evolution Strategy Training

NEATEST: Evolving Neural Networks Through Augmenting Topologies with Evolution Strategy Training

Göktuğ Karakaşlı 16 Dec 05, 2022
VGGFace2-HQ - A high resolution face dataset for face editing purpose

The first open source high resolution dataset for face swapping!!! A high resolution version of VGGFace2 for academic face editing purpose

Naiyuan Liu 232 Dec 29, 2022