The code for Expectation-Maximization Attention Networks for Semantic Segmentation (ICCV'2019 Oral)

EMANet

News

  • The bug in loading the pretrained model is now fixed. I have updated the .pth file; to use it, download it again.
  • EMANet-101 gets 80.99 on the PASCAL VOC dataset (thanks to SenseTime's servers). So, with a classic backbone (ResNet) instead of newer ones (WideResNet, HRNet), EMANet still achieves top performance.
  • EMANet-101 (OHEM) gets 81.14 mIoU on Cityscapes val using single-scale inference, and 81.9 on the test server with multi-scale inference.

Background

This repository is for Expectation-Maximization Attention Networks for Semantic Segmentation (to appear in ICCV 2019, Oral presentation),

by Xia Li, Zhisheng Zhong, Jianlong Wu, Yibo Yang, Zhouchen Lin and Hong Liu from Peking University.

The source code is now available!

Citation

If you find EMANet useful in your research, please consider citing:

@inproceedings{li19,
    author={Xia Li and Zhisheng Zhong and Jianlong Wu and Yibo Yang and Zhouchen Lin and Hong Liu},
    title={Expectation-Maximization Attention Networks for Semantic Segmentation},
    booktitle={International Conference on Computer Vision},   
    year={2019},   
}

Introduction

The self-attention mechanism has been widely used for various tasks. It computes the representation of each position as a weighted sum of the features at all positions, so it can capture long-range relations in computer vision tasks. However, it is computationally expensive, since the attention maps are computed with respect to all other positions. In this paper, we formulate the attention mechanism in an expectation-maximization manner and iteratively estimate a much more compact set of bases upon which the attention maps are computed. Through a weighted summation over these bases, the resulting representation is low-rank and suppresses noisy information from the input. The proposed Expectation-Maximization Attention (EMA) module is robust to the variance of the input and is also friendly in memory and computation. Moreover, we set up bases maintenance and normalization methods to stabilize its training procedure. We conduct extensive experiments on popular semantic segmentation benchmarks including PASCAL VOC, PASCAL Context, and COCO Stuff, on which we set new records.
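As a rough illustration of the iteration described above, the following NumPy sketch alternates a softmax assignment step (E) with a weighted-mean update of the bases (M), and then reconstructs the features from the bases. It is not the code in this repo: the function name `ema_attention`, the number of bases and iterations, the random initialization, and the L2 normalization of the bases are all assumptions chosen for illustration.

```python
import numpy as np

def ema_attention(x, num_bases=4, num_iters=3, seed=0):
    """Illustrative EM attention on features x of shape (N, C).

    Returns the low-rank reconstruction (N, C), the attention maps (N, K)
    and the bases (K, C). All hyperparameters here are assumptions.
    """
    rng = np.random.default_rng(seed)
    _, c = x.shape
    mu = rng.standard_normal((num_bases, c))
    mu /= np.linalg.norm(mu, axis=1, keepdims=True)
    for _ in range(num_iters):
        # E-step: soft-assign every position to the K bases (softmax over bases).
        logits = x @ mu.T
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        z = np.exp(logits)
        z /= z.sum(axis=1, keepdims=True)
        # M-step: each basis becomes the responsibility-weighted mean of the features.
        mu = (z.T @ x) / (z.sum(axis=0)[:, None] + 1e-6)
        # Assumed normalization: keep every basis at unit L2 norm.
        mu /= np.linalg.norm(mu, axis=1, keepdims=True) + 1e-6
    # Re-estimate the features from K bases, so the result has rank at most K.
    x_hat = z @ mu
    return x_hat, z, mu

features = np.random.default_rng(1).standard_normal((64, 16))
x_hat, z, mu = ema_attention(features)
print(x_hat.shape, z.shape, mu.shape)
```

Attention here costs on the order of N×K per iteration instead of N×N, which is where the memory and computation savings over full self-attention come from when K is much smaller than N.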

Design

As so many peers have starred this repo, I feel great pressure and have tried to release code of high quality. That's why I didn't release it until today (Aug. 22, 2018). Designing the code structure is not easy, and different designs suit different uses. Here, I aim to make research on semantic segmentation, especially on PASCAL VOC, easier. So I removed unnecessary encapsulation as much as possible, leaving fewer than 10 Python files. To be honest, the global variables in settings are not a good design for a large project, but for research they offer great flexibility. I hope you can understand this choice.

For research, I recommend separating each experiment into its own folder. Each folder contains the whole project and should be named after the experiment settings, such as 'EMANet101.moving_avg.l2norm.3stages'. This way, you can keep track of all the experiments and find their differences with just the 'diff' command.
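The same comparison can be scripted when the 'diff' command is not at hand. The sketch below uses Python's standard `difflib` on two settings files passed in as strings; the helper name `diff_settings` and the setting names `STAGE_NUM` and `NORM` are hypothetical and only stand in for whatever differs between two experiment folders.

```python
import difflib

def diff_settings(text_a, text_b, name_a="a/settings.py", name_b="b/settings.py"):
    """Return unified-diff lines between two settings files given as strings."""
    return list(difflib.unified_diff(
        text_a.splitlines(), text_b.splitlines(),
        fromfile=name_a, tofile=name_b, lineterm=""))

# Hypothetical settings from two experiment folders (names are illustrative).
baseline = "STAGE_NUM = 3\nNORM = 'l2norm'\n"
variant = "STAGE_NUM = 5\nNORM = 'l2norm'\n"

for line in diff_settings(baseline, variant):
    print(line)
```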

Usage

  1. Install the libraries listed in 'requirements.txt'.
  2. Download the images and labels of PASCAL VOC and SBD, and decompress them together.
  3. Download the pretrained ResNet50 and ResNet101, unzip them, and put them into the 'models' folder.
  4. Change 'DATA_ROOT' in settings.py to where you placed the dataset.
  5. Run sh clean.sh to clear the models and logs from the last experiment.
  6. Run python train.py for training and sh tensorboard.sh for visualization in your browser.
  7. Or you can download the pretrained model, put it into the 'models' folder, and skip step 6.
  8. Run python eval.py for validation.
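Before training, it can help to verify that the layout from the steps above is in place. The following is a small, hypothetical pre-flight check, not part of this repo: `check_setup` and the placeholder path are assumptions, and it only tests that 'requirements.txt', a non-empty 'models' folder, and the dataset directory exist, without checking any specific weight file names.

```python
import os

def check_setup(project_dir, data_root):
    """Return a list of human-readable problems; empty if the layout looks ready."""
    problems = []
    if not os.path.isfile(os.path.join(project_dir, "requirements.txt")):
        problems.append("requirements.txt not found")
    models_dir = os.path.join(project_dir, "models")
    if not os.path.isdir(models_dir) or not os.listdir(models_dir):
        problems.append("'models' folder is missing or empty")
    if not os.path.isdir(data_root):
        problems.append("DATA_ROOT does not exist: " + data_root)
    return problems

if __name__ == "__main__":
    # "/path/to/VOC" is a placeholder; use the value of DATA_ROOT from settings.py.
    for problem in check_setup(".", "/path/to/VOC"):
        print("WARNING:", problem)
```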

Ablation Studies

The following results are taken from the paper. With this repo, it would not be surprising to get even higher performance. If you do, I'd like you to share it in the issues. For now, this repo only provides SS inference. I may release the code for MS and Flip later.

Tab 1. Detailed comparisons with Deeplabs. All results are achieved with the backbone ResNet-101 and output stride 8. The FLOPs and memory are computed with the input size 513×513. SS: Single-scale input during test. MS: Multi-scale input. Flip: Adding left-right flipped input. EMANet (256) and EMANet (512) represent EMANet with the number of input channels for EMA as 256 and 512, respectively.

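| Method      | SS    | MS+Flip | FLOPs  | Memory | Params |
|-------------|-------|---------|--------|--------|--------|
| ResNet-101  | -     | -       | 190.6G | 2.603G | 42.6M  |
| DeeplabV3   | 78.51 | 79.77   | +63.4G | +66.0M | +15.5M |
| DeeplabV3+  | 79.35 | 80.57   | +84.1G | +99.3M | +16.3M |
| PSANet      | 78.51 | 79.77   | +56.3G | +59.4M | +18.5M |
| EMANet(256) | 79.73 | 80.94   | +21.1G | +12.3M | +4.87M |
| EMANet(512) | 80.05 | 81.32   | +43.1G | +22.1M | +10.0M |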
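The Flip setting in Tab 1 refers to a standard test-time augmentation: predict on the image and on its left-right mirror, mirror the second prediction back, and average. Since this repo only provides SS inference, the sketch below is a generic illustration rather than the author's code; `predict_with_flip` and the toy `toy_predict` stand-in for a network are hypothetical, and multi-scale (MS) inference is not shown.

```python
import numpy as np

def predict_with_flip(predict, image):
    """Average class scores over an image and its left-right mirror.

    predict: callable mapping an (H, W, 3) array to (H, W, num_classes) scores.
    """
    scores = predict(image)
    flipped_scores = predict(image[:, ::-1])
    # Mirror the second prediction back so that pixels line up again.
    return 0.5 * (scores + flipped_scores[:, ::-1])

# Toy stand-in for a segmentation network (hypothetical): two "classes"
# scored directly from the color channels.
def toy_predict(image):
    return np.stack([image[..., 0], image[..., 1] + image[..., 2]], axis=-1)

image = np.random.default_rng(0).random((4, 6, 3))
labels = predict_with_flip(toy_predict, image).argmax(axis=-1)
print(labels.shape)  # (4, 6)
```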

Note that most of the overhead of EMANets comes from the 3x3 convs before and after the EMA Module. As for the EMA Module itself, its computation is only 1/3 of that of a 3x3 conv, and its parameter count is even smaller than that of a 1x1 conv.

Comparisons with SOTAs

Note that for validation on the 'val' set, you only have to train for 30k iterations on the 'trainaug' set. For testing on the evaluation server, however, you should first pretrain on COCO, then train for 30k iterations on 'trainaug', and for another 30k on the 'trainval' set.

Tab 2. Comparisons on the PASCAL VOC test dataset.

| Method      | Backbone      | mIoU(%) |
|-------------|---------------|---------|
| GCN         | ResNet-152    | 83.6    |
| RefineNet   | ResNet-152    | 84.2    |
| Wide ResNet | WideResNet-38 | 84.9    |
| PSPNet      | ResNet-101    | 85.4    |
| DeeplabV3   | ResNet-101    | 85.7    |
| PSANet      | ResNet-101    | 85.7    |
| EncNet      | ResNet-101    | 85.9    |
| DFN         | ResNet-101    | 86.2    |
| Exfuse      | ResNet-101    | 86.2    |
| IDW-CNN     | ResNet-101    | 86.3    |
| SDN         | DenseNet-161  | 86.6    |
| DIS         | ResNet-101    | 86.8    |
| EMANet101   | ResNet-101    | 87.7    |
| DeeplabV3+  | Xception-65   | 87.8    |
| Exfuse      | ResNeXt-131   | 87.9    |
| MSCI        | ResNet-152    | 88.0    |
| EMANet152   | ResNet-152    | 88.2    |
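For readers unfamiliar with the metric in Tab 2, mIoU is the per-class intersection-over-union averaged over classes. The sketch below computes it from a confusion matrix for a single prediction; it is a generic illustration, not this repo's eval.py. The name `mean_iou` is hypothetical, and treating label 255 as ignored pixels is an assumption based on the usual PASCAL VOC convention. Benchmarks typically accumulate the confusion matrix over the whole dataset before taking the ratio.

```python
import numpy as np

def mean_iou(pred, target, num_classes, ignore_index=255):
    """Mean IoU over the classes that occur in pred or target.

    pred, target: integer label arrays of the same shape, with pred values
    in [0, num_classes). Pixels whose target equals ignore_index are skipped.
    """
    valid = target != ignore_index
    # Encode each (target, pred) pair as one index into a flattened confusion matrix.
    idx = target[valid].astype(np.int64) * num_classes + pred[valid].astype(np.int64)
    conf = np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)
    intersection = np.diag(conf)
    union = conf.sum(axis=0) + conf.sum(axis=1) - intersection
    present = union > 0  # skip classes absent from both pred and target
    return float((intersection[present] / union[present]).mean())

pred = np.array([[0, 0], [1, 1]])
target = np.array([[0, 1], [1, 255]])
print(mean_iou(pred, target, num_classes=2))  # 0.5
```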

Code Borrowed From

RESCAN

Pytorch-Encoding

Synchronized-BN
