Implementation of our paper "DMT: Dynamic Mutual Training for Semi-Supervised Learning"

Overview

DMT: Dynamic Mutual Training for Semi-Supervised Learning

This repository contains the code for our paper DMT: Dynamic Mutual Training for Semi-Supervised Learning, a concise and effective method for semi-supervised semantic segmentation & image classification.

Some might know it by its previous version, DST-CBC (Semi-Supervised Semantic Segmentation via Dynamic Self-Training and Class-Balanced Curriculum). If you want the old code, check out the dst-cbc branch.

For users of older PyTorch versions (<1.6.0), or to reproduce the exact environment that produced the paper's results, refer to 53853f6.

News

2021.6.7

Multi-GPU training support (based on Accelerate) is added, and the whole project is upgraded to PyTorch 1.6. Thanks to @jinhuan-hit for code and testing, and to @lorenmt and @TiankaiHang for discussions.

2021.2.10

A slight backbone architecture difference in the segmentation task has been identified and is described in Acknowledgements.

2021.1.1

DMT is released. Happy new year! 😉

2020.12.7

The bug fix for DST-CBC (not fully tested) is released at the scale branch.

2020.11.9

Stay tuned for Dynamic Mutual Training (DMT), an updated version of DST-CBC with overall better and more stable performance, which will be released later.

Also, thanks to @lorenmt, a data augmentation bug fix will be released along with the next version. With it, PASCAL VOC performance is boosted by 1~2% overall, and Cityscapes could also perform better. The gap to the oracle, however, will probably remain similar.

Setup

First, you'll need an environment with CUDA 10 and Python 3 (best on Linux).

1. Set up PyTorch & TorchVision:

pip install torch==1.6.0 torchvision==0.7.0

2. Install the other Python packages you may require:

pip install packaging accelerate future matplotlib tensorboard tqdm
pip install git+https://github.com/ildoonet/pytorch-randaugment

3. Download the code and prepare the scripts:

git clone https://github.com/voldemortX/DST-CBC.git
cd DST-CBC
chmod 777 segmentation/*.sh
chmod 777 classification/*.sh
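
After installation, a quick sanity check can confirm that the packages are importable. The script below is only a minimal sketch, not part of this repository: the helper name `find_missing` is hypothetical, and the module list assumes that each pip package above is imported under the same name it is installed with.

```python
import importlib.util
import sys

# Import names assumed to match the pip packages from the setup steps.
REQUIRED_MODULES = [
    "torch", "torchvision", "packaging", "accelerate",
    "future", "matplotlib", "tensorboard", "tqdm",
]


def find_missing(modules):
    """Return the top-level module names that cannot be located for import."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print("Python", sys.version.split()[0])
    missing = find_missing(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All listed modules were found.")
```

The check only locates modules without importing them, so it stays fast and does not initialize CUDA.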

Getting started

Get started with SEGMENTATION.md for semantic segmentation.

Get started with CLASSIFICATION.md for image classification.

Understand the code

We refer interested readers to this repository's wiki. It is not updated for DMT yet.

Notes

It's best to use a GPU with tensor cores, such as the Volta and Turing architectures, since computation is much faster with mixed precision on them. For instance, RTX 2080 Ti (which is what we used), Tesla V100, or other RTX 20/30 series cards.
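
To check whether a given machine fits this recommendation, the CUDA compute capability can serve as a rough hint. The sketch below is an illustration under assumptions, not code from this repository: the function names are hypothetical, and the rule "compute capability 7.0 or newer" is only a heuristic, since some Turing cards (for example the GTX 16 series) report 7.5 but ship without tensor cores.

```python
def likely_has_tensor_cores(capability):
    """Heuristic: compute capability 7.0 (Volta) or newer suggests tensor cores."""
    major, _minor = capability
    return major >= 7


def describe_current_gpu():
    """Describe GPU 0 if PyTorch and a CUDA device are available."""
    try:
        import torch
    except ImportError:
        return "PyTorch is not installed"
    if not torch.cuda.is_available():
        return "no CUDA device visible"
    capability = torch.cuda.get_device_capability(0)
    name = torch.cuda.get_device_name(0)
    verdict = "likely" if likely_has_tensor_cores(capability) else "unlikely"
    return "{} (compute capability {}.{}): tensor cores {}".format(
        name, capability[0], capability[1], verdict)


if __name__ == "__main__":
    print(describe_current_gpu())
```

For reference, a Tesla V100 reports 7.0 and an RTX 2080 Ti reports 7.5, so both pass the heuristic.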

Our implementation is fast and memory efficient. A whole run (training 2 models with DMT on PASCAL VOC 2012) takes about 8 hours on a single RTX 2080 Ti using up to 6GB of GPU memory, including on-the-fly evaluations and training baselines. The Cityscapes experiments are even faster.

Contact

Issues and PRs are most welcome.

If you have any questions that are not answerable with Google, feel free to contact us through [email protected].

Citation

@article{feng2020dmt,
  title={DMT: Dynamic Mutual Training for Semi-Supervised Learning},
  author={Feng, Zhengyang and Zhou, Qianyu and Gu, Qiqi and Tan, Xin and Cheng, Guangliang and Lu, Xuequan and Shi, Jianping and Ma, Lizhuang},
  journal={arXiv preprint arXiv:2004.08514},
  year={2020}
}

Acknowledgements

The DeepLabV2 network architecture and COCO pre-trained weights are faithfully re-implemented from AdvSemiSeg. The only difference is that we use the so-called ResNetV1.5 implementation for the ResNet-101 backbone (same as torchvision); for the difference between ResNetV1 and V1.5, refer to this issue. However, the difference is reported to bring only a 0-0.5% gain on ImageNet. Considering that we use the V1 COCO pre-trained weights, which mismatch with V1.5, the overall performance should remain similar to V1. The better fully-supervised performance mainly comes from a better training schedule. Besides, we base comparisons on performance relative to the Oracle, not on absolute performance.

The CBC part of the older version DST-CBC is adapted from CRST.

The overall implementation is based on TorchVision and PyTorch.

The people who've helped to make the method & code better: lorenmt, jinhuan-hit, TiankaiHang, etc.

Comments
  • miou problem in segmentation

    Thanks for sharing the good work! I have a question. When I train on Cityscapes using 1/8 labeled data, the two models (initialized from COCO and ImageNet) reach nearly 59 mIoU on the val set, close to the 59.65 presented in the paper. However, after 5 iterations, the metric drops to 53 (COCO) and 22 (ImageNet). I checked the pseudo labels produced by the 59 mIoU model and they are not particularly good. I don't know if that affected the results.

    question possible bug 
    opened by jinhuan-hit 26
  • Visualize the final experimental results

    Hello, your paper and code are great, thank you for your efforts. I have a question. I ran experiments on my own data and have obtained results. How can I use these weights to evaluate on the test set? In addition, I used the dmt-voc-20-1__p5--i weights and tested with the trained model, but the results are very poor, and I don't know whether my testing method is correct.

    question 
    opened by JayeShen1996 16
  • Nan values in confusion matrix

    Hello, I'm using a segmentation dataset with two classes and grayscale images. I'm duplicating the channels of the image with:

    elif pic.mode == 'L':
        img = torch.from_numpy(np.array(pic, np.uint8, copy=False)).expand([3, 224, 224]).reshape(-1)

    While training the baseline without DMT, I only get accuracy values for one of the classes, with nan for the other: average row correct: ['99.52', 'nan'].

    Do you have any idea what I might have done wrong or missed? Thanks in advance!

    question 
    opened by dervirvel 10
  • Question about label mapping for cityscapes dataset

    While using the part of your code for the Cityscapes benchmark, I ran into the following error:

    IndexError: Caught IndexError in DataLoader worker process 0.
    Original Traceback (most recent call last):
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 185, in _worker_loop
        data = fetcher.fetch(index)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "../utils/datasets.py", line 155, in __getitem__
        img1, target1 = self.transforms(img, target)
      File "../utils/transforms.py", line 27, in __call__
        image, target = t(image, target)
      File "../utils/transforms.py", line 216, in __call__
        target = target if type(target) == str else self.label_id_map[target]
    IndexError: index 255 is out of bounds for dimension 0 with size 34

    It seems that LabelMap(label_id_map_city) didn't work correctly. This is my first time using this benchmark, so I don't know how to deal with this problem. Could you please give me some hints?

    question 
    opened by revaeb 6
  • How

    First of all, thank you very much for your previous help. I can now train on my own dataset, but I have another problem. I want to convert the output of the network into a mask file like the given labels. How should I do this? Can I apply softmax to the output and then set a threshold to generate the final mask?

    question 
    opened by userhr2333 4
  • About using a better model

    I would like to ask whether you have experimented with a better model, such as DeepLabV3+. Would a better model bring better accuracy?

    opened by wing212 3
  • What's the meaning of splits?

    Thanks for your hard work!

    I am new to this topic. Can you explain the meaning of splits in generate_splits.py, such as setting [2, 4, 8, 20, 29.75] for Cityscapes? I only know that it relates to the ratio of labeled data to unlabeled data, and I don't know why you chose those values. Furthermore, if I want to train on my own data, how should I set this variable according to the ratio of my labeled and unlabeled data?

    Thank you for your help.

    question 
    opened by czb2133 3
  • Sudden drop in accuracy

    Hello, I want to ask why the accuracy suddenly dropped. The accuracy I reproduce is also much lower than that reported in the paper. I use a single 3090 Ti graphics card for training.

    question 
    opened by wing212 18
  • When I run segmentation code with my own dataset, it occurs the error...

    Hello! When I set up my dataset to match Cityscapes, it fails in the model initialization phase. RuntimeError: Error(s) in loading state_dict for DeepLab: size mismatch for classifier.0.convs.0.weight: copying a param with shape torch.Size([19, 2048, 3, 3]) from checkpoint, the shape in current model is torch.Size([4, 2048, 3, 3]).

    My dataset contains only 5% labeled images. The image size is 2048*1024, the same as Cityscapes. Could you help me find the problem?

    Thank you very much!

    question 
    opened by grbcwq123 4
  • A warning appears during the running of the program, will this affect the accuracy?

    Warning: multi_tensor_applier fused unscale kernel is unavailable, possibly because apex was installed without --cuda_ext --cpp_ext. Using Python fallback. Original ImportError was: ModuleNotFoundError("No module named 'amp_C'",)

    I have PyTorch 1.2.0, torchvision 0.4.0, and apex 0.1 installed.

    question 
    opened by userhr2333 2
  • [Kept for Feedback] Multi-GPU & New models

    Thanks for your nice work and congratulations on your good results!

    I have several questions.

    • Will your model be extended to parallel training (distributed data-parallel) in the future?
    • Why don't you try DeepLabV3+? Would it lead to a better result?

    Best.

    question fixed 
    opened by TiankaiHang 21
Releases(v1.2)
Owner
Zhengyang Feng