Semi-Supervised Semantic Segmentation with Cross-Consistency Training (CCT)

Overview

Paper, Project Page

This repo contains the official implementation of the CVPR 2020 paper Semi-Supervised Semantic Segmentation with Cross-Consistency Training, which adapts the traditional consistency training framework of semi-supervised learning to semantic segmentation, with extensions to weakly-supervised learning and to learning on multiple domains.

Highlights

(1) Consistency training for semantic segmentation.
We observe that, due to the dense nature of semantic segmentation, the cluster assumption is more easily enforced over the hidden representations than over the inputs.

(2) Cross-Consistency Training.
We propose CCT (Cross-Consistency Training) for semi-supervised semantic segmentation: we define a number of novel perturbations and show the effectiveness of enforcing consistency over the encoder's outputs rather than over the inputs.

(3) Using weak labels and pixel-level labels from multiple domains.
The proposed method is simple and flexible, and can easily be extended to use image-level labels and pixel-level labels from multiple domains.
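The core idea of (1) and (2) can be illustrated with a minimal NumPy sketch of the unsupervised cross-consistency term. Under the assumptions made here, the encoder output z of an unlabeled image is decoded once by the main decoder, whose softmax output is treated as a fixed target, and by K auxiliary decoders that each see a perturbed copy of z; the loss averages the pixel-wise squared difference between class probabilities (the MSE option of un_loss). The 1x1 linear decoders, the helper names, and the use of an F-Noise-style perturbation (z + z * N with N ~ U(-r, r)) as the only perturbation are simplifications for illustration, not the repo's PyTorch modules.

```python
import numpy as np


def softmax(logits, axis=0):
    """Numerically stable softmax over the class axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def decode(z, weight):
    """Hypothetical 1x1 'decoder': maps C feature channels to per-pixel class logits.

    z has shape (C, H, W); weight has shape (n_classes, C).
    """
    return np.einsum("kc,chw->khw", weight, z)


def feature_noise(z, rng, uniform_range=0.3):
    """F-Noise-style perturbation (assumed form): z + z * N, N ~ U(-r, r), same shape as z."""
    noise = rng.uniform(-uniform_range, uniform_range, size=z.shape)
    return z + z * noise


def cross_consistency_loss(z, main_w, aux_ws, rng, uniform_range=0.3):
    """Mean over auxiliary decoders of the pixel-wise MSE between softmax outputs.

    The main decoder's prediction on the clean features is the fixed target;
    each auxiliary decoder is applied to its own perturbed copy of z.
    """
    target = softmax(decode(z, main_w))
    losses = []
    for w in aux_ws:
        pred = softmax(decode(feature_noise(z, rng, uniform_range), w))
        losses.append(np.mean((pred - target) ** 2))
    return float(np.mean(losses))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    C, H, W, n_classes, K = 8, 4, 4, 21, 3  # toy sizes; 21 = Pascal VOC classes
    z = rng.standard_normal((C, H, W))
    main_w = rng.standard_normal((n_classes, C))
    aux_ws = [rng.standard_normal((n_classes, C)) for _ in range(K)]
    print(cross_consistency_loss(z, main_w, aux_ws, rng))
```

Treating the main prediction as the target is what lets unlabeled images contribute: the auxiliary decoders are pushed to agree with it despite the perturbations, which is the consistency the paper enforces on the encoder's outputs.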

Requirements

This repo was tested with Ubuntu 18.04.3 LTS, Python 3.7, PyTorch 1.1.0, and CUDA 10.0, but it should also run with more recent PyTorch versions (>= 1.1.0).

The required packages are pytorch and torchvision, together with PIL and opencv for data preprocessing and tqdm for showing the training progress. Some additional modules, such as dominate, are used to save the results as HTML files. To set up the necessary modules, simply run:

pip install -r requirements.txt

Dataset

In this repo, we use Pascal VOC. To obtain it, first download the original dataset; after extracting the files, we end up with VOCtrainval_11-May-2012/VOCdevkit/VOC2012, which contains the image sets, the XML annotations for both object detection and segmentation, and the JPEG images.
The second step is to augment the dataset using the additional annotations provided by Semantic Contours from Inverse Detectors. Download the rest of the annotations, SegmentationClassAug, and add them to VOCtrainval_11-May-2012/VOCdevkit/VOC2012. Now we're set: for training, use the path to VOCtrainval_11-May-2012.
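Before launching a long run, it can help to check the folder layout described above. The sketch below only tests for the two sub-folders this README names (JPEGImages and SegmentationClassAug) under VOCdevkit/VOC2012; the helper name missing_voc_dirs is hypothetical and the check says nothing about file contents.

```python
from pathlib import Path

# Sub-folders this README expects under VOCdevkit/VOC2012:
# JPEGImages (original VOC images) and SegmentationClassAug (the SBD-based augmented masks).
REQUIRED_SUBDIRS = ("JPEGImages", "SegmentationClassAug")


def missing_voc_dirs(root):
    """Return the required sub-folders missing from <root>/VOCdevkit/VOC2012.

    `root` is the extracted VOCtrainval_11-May-2012 folder, i.e. the path used for training.
    """
    voc2012 = Path(root) / "VOCdevkit" / "VOC2012"
    return [name for name in REQUIRED_SUBDIRS if not (voc2012 / name).is_dir()]


if __name__ == "__main__":
    missing = missing_voc_dirs("VOCtrainval_11-May-2012")
    print("Dataset layout OK" if not missing else f"Missing: {missing}")
```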

Training

To train a model, first download PASCAL VOC as detailed above, then set data_dir to the dataset path in configs/config.json and set the rest of the parameters, such as the number of GPUs, crop size, and data augmentation. You can also change the CCT hyperparameters if you wish (more details below). Then simply run:

python train.py --config configs/config.json
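If you prefer to script the data_dir change, a minimal sketch is shown below. It assumes configs/config.json is plain JSON (without the // comments used in the annotated listing of this README) and that the loader sections are named train_supervised, train_unsupervised, and val_loader, as in the full config posted in the issues on this page; set_data_dir is a hypothetical helper, not part of the repo.

```python
import json
from pathlib import Path

# Loader sections assumed to carry a "data_dir" entry.
LOADER_SECTIONS = ("train_supervised", "train_unsupervised", "val_loader")


def set_data_dir(config_path, data_dir, out_path=None):
    """Point every loader section of a CCT-style JSON config at `data_dir`.

    Writes the result to `out_path` (or back to `config_path`) and returns the dict.
    Sections missing from the file are left absent.
    """
    config = json.loads(Path(config_path).read_text())
    for section in LOADER_SECTIONS:
        if section in config:
            config[section]["data_dir"] = data_dir
    Path(out_path or config_path).write_text(json.dumps(config, indent=4))
    return config
```

For example, `set_data_dir("configs/config.json", "/data/VOCtrainval_11-May-2012", "configs/my_config.json")` followed by `python train.py --config configs/my_config.json` leaves the original config untouched.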

The log files and the .pth checkpoints will be saved in saved/EXP_NAME. To monitor the training using tensorboard, run:

tensorboard --logdir saved

To resume training using a saved .pth model:

python train.py --config configs/config.json --resume saved/CCT/checkpoint.pth

Results: the validation results will be saved in saved/ as an HTML file named after the experim_name specified in configs/config.json.

Pseudo-labels

If you want to use image-level labels to train the auxiliary decoders, as explained in Section 3.3 of the paper, first generate the pseudo-labels using the code in pseudo_labels:

cd pseudo_labels
python run.py --voc12_root DATA_PATH

DATA_PATH must point to the folder containing JPEGImages in the Pascal VOC dataset. The results will be saved in pseudo_labels/result/pseudo_labels as PNG files. Then set the flag use_weak_labels to true in the config file and train the model as detailed above.
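For reference, the objective these settings control can be summarized as follows. This is a hedged summary of the paper's formulation (Section 3), not code from the repo: D_l and D_u are the labeled and unlabeled sets (|D_l| and |D_u| are their sizes), y_i is the ground truth, y_i^p the generated pseudo-label, f the full network (encoder followed by the main decoder g), z_i the encoder output, g_a^k the k-th of K auxiliary decoders (each applying its own perturbation to z_i), H the cross-entropy, and d the distance selected by un_loss. The weights ω_u and ω_w correspond to unsupervised_w (ramped up over training, see ramp_up) and weakly_loss_w; the L_w term only applies when use_weak_labels is true.

```latex
\begin{aligned}
\mathcal{L}_s &= \frac{1}{|\mathcal{D}_l|} \sum_{(x_i^l,\, y_i) \in \mathcal{D}_l} \mathbf{H}\big(y_i,\, f(x_i^l)\big) \\
\mathcal{L}_u &= \frac{1}{|\mathcal{D}_u|} \frac{1}{K} \sum_{x_i^u \in \mathcal{D}_u} \sum_{k=1}^{K} \mathbf{d}\big(g(z_i),\, g_a^k(z_i)\big) \\
\mathcal{L}_w &= \frac{1}{|\mathcal{D}_u|} \frac{1}{K} \sum_{x_i^u \in \mathcal{D}_u} \sum_{k=1}^{K} \mathbf{H}\big(y_i^p,\, g_a^k(z_i)\big) \\
\mathcal{L} &= \mathcal{L}_s + \omega_u \mathcal{L}_u + \omega_w \mathcal{L}_w
\end{aligned}
```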

Inference

For inference, we need a pretrained model, the JPG images we'd like to segment, and the config used in training (to load the correct model and other parameters):

python inference.py --config config.json --model best_model.pth --images images_folder

The predictions will be saved as .png images in outputs/. For Pascal VOC, the default VOC color palette is used.
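To interpret the colors in the saved masks, here is a minimal sketch of the standard Pascal VOC colormap, which spreads the bits of each label index over the R, G, and B channels. It assumes the repo's palette follows this usual scheme (its own palette helper may differ in detail); voc_palette and flat_palette are hypothetical names.

```python
def voc_palette(num_colors=256):
    """Standard Pascal VOC colormap as a list of (r, g, b) tuples.

    Bits 0, 1, 2 of the label set the most significant bit of R, G, B;
    bits 3, 4, 5 set the next bit down, and so on.
    """
    palette = []
    for label in range(num_colors):
        r = g = b = 0
        c = label
        for shift in range(7, -1, -1):
            r |= (c & 1) << shift
            g |= ((c >> 1) & 1) << shift
            b |= ((c >> 2) & 1) << shift
            c >>= 3
        palette.append((r, g, b))
    return palette


def flat_palette(num_colors=256):
    """Flatten to [r0, g0, b0, r1, ...], the layout Pillow's Image.putpalette expects."""
    return [v for rgb in voc_palette(num_colors) for v in rgb]
```

With this scheme, background (0) is black, class 1 (aeroplane) is (128, 0, 0), class 15 (person) is (192, 128, 128), and the ignore_index 255 maps to (224, 224, 192). To color a 2-D uint8 mask of class indices with Pillow: `img = Image.fromarray(mask); img.putpalette(flat_palette()); img.save("pred.png")`.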

Here are the flags available for inference:

--images       Folder containing the jpg images to segment.
--model        Path to the trained pth model.
--config       The config file used for training the model.

Pre-trained models

Pre-trained models can be downloaded here.

Citation ✏️ 📄

If you find this repo useful for your research, please consider citing the paper as follows:

@InProceedings{Ouali_2020_CVPR,
  author = {Ouali, Yassine and Hudelot, Celine and Tami, Myriam},
  title = {Semi-Supervised Semantic Segmentation With Cross-Consistency Training},
  booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  month = {June},
  year = {2020}
}

For any questions, please contact Yassine Ouali.

Config file details ⚙️

Below we detail the CCT parameters that can be controlled in configs/config.json; the rest of the parameters are self-explanatory.

{
    "name": "CCT",                              
    "experim_name": "CCT",                             // The name the results will take (html and the folder in /saved)
    "n_gpu": 1,                                             // Number of GPUs
    "n_labeled_examples": 1000,                             // Number of labeled examples (choices are 60, 100, 200, 
                                                            // 300, 500, 800, 1000, 1464, and the splits are in dataloaders/voc_splits)
    "diff_lrs": true,
    "ramp_up": 0.1,                                         // The unsupervised loss will be slowly scaled up in the first 10% of Training time
    "unsupervised_w": 30,                                   // Weighting of the unsupervised loss
    "ignore_index": 255,
    "lr_scheduler": "Poly",
    "use_weak_labels": false,                               // If the pseudo-labels were generated, we can use them to train the aux. decoders
    "weakly_loss_w": 0.4,                                   // Weighting of the weakly-supervised loss
    "pretrained": true,

    "model":{
        "supervised": true,                                  // Supervised setting (training only on the labeled examples)
        "semi": false,                                       // Semi-supervised setting
        "supervised_w": 1,                                   // Weighting of the supervised loss

        "sup_loss": "CE",                                    // supervised loss, choices are CE and ab-CE = ["CE", "ABCE"]
        "un_loss": "MSE",                                    // unsupervised loss, choices are MSE and KL-divergence = ["MSE", "KL"]

        "softmax_temp": 1,
        "aux_constraint": false,                             // Pair-wise loss (sup. mat.)
        "aux_constraint_w": 1,
        "confidence_masking": false,                         // Confidence masking (sup. mat.)
        "confidence_th": 0.5,

        "drop": 6,                                           // Number of DropOut decoders
        "drop_rate": 0.5,                                    // Dropout probability
        "spatial": true,
    
        "cutout": 6,                                         // Number of G-Cutout decoders
        "erase": 0.4,                                        // We drop 40% of the area
    
        "vat": 2,                                            // Number of I-VAT decoders
        "xi": 1e-6,                                          // VAT parameters
        "eps": 2.0,

        "context_masking": 2,                               // Number of Con-Msk decoders
        "object_masking": 2,                                // Number of Obj-Msk decoders
        "feature_drop": 6,                                  // Number of F-Drop decoders

        "feature_noise": 6,                                 // Number of F-Noise decoders
        "uniform_range": 0.3                                // The range of the noise
    },
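The ramp_up and unsupervised_w entries work together: the unsupervised weight starts small and grows to unsupervised_w over the first ramp_up fraction of training. Below is a minimal sketch assuming a sigmoid-shaped ramp-up, the exp(-5(1 - t)^2) schedule common in consistency-based semi-supervised learning, measured in iterations; the repo's actual ramp shape and time unit may differ, and the function names are hypothetical.

```python
import math


def sigmoid_rampup(current, rampup_length):
    """Sigmoid-shaped ramp from exp(-5) (about 0.0067) up to 1.0, then flat at 1.0."""
    if rampup_length <= 0:
        return 1.0
    t = min(max(current / rampup_length, 0.0), 1.0)
    return math.exp(-5.0 * (1.0 - t) ** 2)


def unsupervised_weight(iteration, total_iters, unsupervised_w=30.0, ramp_up=0.1):
    """Weight applied to the unsupervised loss at a given training iteration.

    ramp_up is the fraction of training (here: of total_iters) over which the weight grows.
    """
    return unsupervised_w * sigmoid_rampup(iteration, ramp_up * total_iters)
```

For example, with 10,000 iterations and the defaults above, the weight starts near 0.2, reaches 30 at iteration 1,000, and stays there. The slow start keeps the noisy early predictions of the main decoder from dominating training.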

Acknowledgements

  • Pseudo-labels generation is based on Jiwoon Ahn's implementation irn.
  • Code structure was based on Pytorch-Template
  • ResNet backbone was downloaded from torchcv
Comments
  • custom dataset with 4 classes

    Thank you so far for all your great help. I have an issue that I also found in the closed issues, but for me it isn't solved. I have my own custom dataset with 4 classes (background and 3 objects, labeled 0-3), so I changed num_classes = 4 in voc.py. With fully supervised training, one class gets an IoU of 0.0. I ran multiple tests using the semi- and weakly-supervised settings; the results are unpredictable and often show 0.0 for the object classes. Only the background has good results. Is there something I need to adjust in the code?

    opened by SuzannaLin 22
  • Training error!

    I want to train VOC2012, but get the error below:

    Traceback (most recent call last):
      File "train.py", line 98, in <module>
        main(config, args.resume)
      File "train.py", line 82, in main
        trainer.train()
      File "/home/byronnar/bigfile/projects/CCT/base/base_trainer.py", line 91, in train
        results = self._train_epoch(epoch)
      File "/home/byronnar/bigfile/projects/CCT/trainer.py", line 76, in _train_epoch
        total_loss, cur_losses, outputs = self.model(x_l=input_l, target_l=target_l, x_ul=input_ul, curr_iter=batch_idx, target_ul=target_ul, epoch=epoch-1)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 150, in forward
        return self.module(*inputs[0], **kwargs[0])
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/byronnar/bigfile/projects/CCT/models/model.py", line 93, in forward
        output_l = self.main_decoder(self.encoder(x_l))
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/byronnar/bigfile/projects/CCT/models/encoder.py", line 61, in forward
        x = self.psp(x)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/byronnar/bigfile/projects/CCT/models/encoder.py", line 36, in forward
        align_corners=False) for stage in self.stages])
      File "/home/byronnar/bigfile/projects/CCT/models/encoder.py", line 36, in <listcomp>
        align_corners=False) for stage in self.stages])
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward
        input = module(input)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py", line 81, in forward
        exponential_average_factor, self.eps)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/functional.py", line 1652, in batch_norm
        raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
    ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 512, 1, 1])
      0%|                                                                                                         | 0/9118 [00:02<?, ?it/s]
    

    What should I do? Thank you.

    opened by Byronnar 12
  • Poor mIoU when training on 1464 images in a supervised manner

    Hi, I trained the model on the 1464 images in a supervised manner. The highest mIoU on val is 67.00%, but the number reported in your paper is 69.4%. Here is my config.json file. Can you have a look and tell me which part is wrong?

     {  "name": "CCT",
        "experim_name": "CCT",
        "n_gpu": 1,
        "n_labeled_examples": 1464,
        "diff_lrs": true,
        "ramp_up": 0.1,
        "unsupervised_w": 30,
        "ignore_index": 255,
        "lr_scheduler": "Poly",
        "use_weak_lables":false,
        "weakly_loss_w": 0.4,
        "pretrained": true,
        "model":{
            "supervised": true,
            "semi": false,
            "supervised_w": 1,
    
            "sup_loss": "CE",
            "un_loss": "MSE",
    
            "softmax_temp": 1,
            "aux_constraint": false,
            "aux_constraint_w": 1,
            "confidence_masking": false,
            "confidence_th": 0.5,
    
            "drop": 6,
            "drop_rate": 0.5,
            "spatial": true,
        
            "cutout": 6,
            "erase": 0.4,
        
            "vat": 2,
            "xi": 1e-6,
            "eps": 2.0,
    
            "context_masking": 2,
            "object_masking": 2,
            "feature_drop": 6,
    
            "feature_noise": 6,
            "uniform_range": 0.3
        },
    
    
        "optimizer": {
            "type": "SGD",
            "args":{
                "lr": 1e-2,
                "weight_decay": 1e-4,
                "momentum": 0.9
            }
        },
    
    
        "train_supervised": {
            "data_dir": "../data/VOC2012",
            "batch_size": 10,
            "crop_size": 320,
            "shuffle": true,
            "base_size": 400,
            "scale": true,
            "augment": true,
            "flip": true,
            "rotate": false,
            "blur": false,
            "split": "train_supervised",
            "num_workers": 8
        },
    
        "train_unsupervised": {
            "data_dir": "VOCtrainval_11-May-2012",
            "weak_labels_output": "pseudo_labels/result/pseudo_labels",
            "batch_size": 10,
            "crop_size": 320,
            "shuffle": true,
            "base_size": 400,
            "scale": true,
            "augment": true,
            "flip": true,
            "rotate": false,
            "blur": false,
            "split": "train_unsupervised",
            "num_workers": 8
        },
    
        "val_loader": {
            "data_dir": "../data/VOC2012",
            "batch_size": 1,
            "val": true,
            "split": "val",
            "shuffle": false,
            "num_workers": 4
        },
    
        "trainer": {
            "epochs": 80,
            "save_dir": "saved/",
            "save_period": 5,
      
            "monitor": "max Mean_IoU",
            "early_stop": 10,
            
            "tensorboardX": true,
            "log_dir": "saved/",
            "log_per_iter": 20,
    
            "val": true,
            "val_per_epochs": 5
        }
    }
    
    opened by xiaomengyc 11
  • Failed to reproduce your paper's semi-supervised results

    I used the default config file to run the experiments, but I only got 68.9 mIoU without weak labels and 70.09 mIoU with weak labels, following your README. These results are far lower than yours. My environment is PyTorch 1.7.0 and Python 3.8.5. Could you provide some advice?

    opened by TyroneLi 8
  • checkerboard

    Hi Yassine, I am using the CCT model to train on a satellite dataset. The images are 128x128. For some reason, the predictions show a clear checkerboard pattern when compared with the ground truth. Do you have any idea what causes this and how to avoid it?

    opened by SuzannaLin 7
  • inference with 4-channel model

    Hi Yassine! I have managed to train a model with 4 channels, but the inference is not working. I get this error message:

    !python inference.py --config configs/config_70_30_sup_alti.json --model './saved/ABCE_70_30_sup_alti/best_model.pth' --output 'CCT_output/ABCE_70_30_sup_alti/Angers/' --images 'val/Angers/BDORTHO'

    Loading pretrained model:models/backbones/pretrained/3x3resnet50-imagenet.pth
    Traceback (most recent call last):
      File "inference.py", line 155, in <module>
        main()
      File "inference.py", line 102, in main
        conf=config['model'], testing=True, pretrained = True)
      File "/home/scuypers/CCT_4/models/model.py", line 55, in __init__
        self.encoder = Encoder(pretrained=pretrained)
      File "/home/scuypers/CCT_4/models/encoder.py", line 49, in __init__
        model = ResNetBackbone(backbone='deepbase_resnet50_dilated8', pretrained=pretrained)
      File "/home/scuypers/CCT_4/models/backbones/resnet_backbone.py", line 145, in ResNetBackbone
        orig_resnet = deepbase_resnet50(pretrained=pretrained)
      File "/home/scuypers/CCT_4/models/backbones/resnet_models.py", line 227, in deepbase_resnet50
        model = ModuleHelper.load_model(model, pretrained=pretrained)
      File "/home/scuypers/CCT_4/models/backbones/module_helper.py", line 109, in load_model
        model.load_state_dict(load_dict)
      File "/home/scuypers/.conda/envs/envCCT/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1483, in load_state_dict
        self.__class__.__name__, "\n\t".join(error_msgs)))
    RuntimeError: Error(s) in loading state_dict for ResNet:
        size mismatch for prefix.conv1.weight: copying a param with shape torch.Size([64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 4, 3, 3]).

    opened by SuzannaLin 7
  • Loss function

    Thank you for your contribution. I want to know what |Dl| and |Du| represent in the cross-entropy loss formula and the semi-supervised loss formula. Thank you for your answer.

    opened by yrcrcy 7
  • Reproducing Cross-domain Experiments

    @yassouali @SuzannaLin Firstly, thanks a lot for your great work! But I still have problems in reproducing the cross-domain experiments.

    I have seen the implementation from https://github.com/tarun005/USSS_ICCV19, however, I notice that there are still many alternatives for the training schemes.

    During training, in each iteration (i.e. each execution of 'optimizer.step()'), do you train the network by forwarding inputs from both two datasets? Or by forwarding inputs from only one dataset in the current iteration, and then executing "optimizer.step()", and then forwarding inputs from the other dataset in the next iteration, and so on?

    Also, are there any tricks to deal with the data imbalance situation, e.g. the CamVid dataset only contains 367 images while the Cityscapes dataset has 2975 training images? (Just like constructing a training batch with different ratios for two datasets or other sorts of things)

    Besides, can you give some hints on hyperparameters, e.g. the number of training iterations, batch size, learning rate, weight decay?

    Looking forward to your reply! Thanks a lot!

    opened by X-Lai 6
  • low performance for full supervised setting

    I modified the config file to set the code to 'supervised' mode, but the result seems to be very low: Epoch : 40 | Mean_IoU : 0.699999988079071 | PixelAcc : 0.933 | Val Loss : 0.26163 compared with 'semi' mode:

    Epoch : 40 | Mean_IoU : 0.7120000123977661 | PixelAcc : 0.931 | Val Loss : 0.31637

    Note that I have changed the supervised list to the 10k+ augmented list in the 'supervised' setting. Did I miss something here?

    opened by zhangyuygss 6
  • How to obtain figure 2(d)

    Hi, thank you for your nice work!

    I want to know how to produce figure 2(d). There are 2048 channels in the hidden representation; how do you visualize them? Thanks for your help!

    opened by reluuu 5
  • low performance in semi-supervised mode when employing weakly_loss with 2 gpus

    Thank you for your nice work!

    I tried training the model with 1464 labeled samples in semi-supervised mode, using 2 GPUs. I set the number of epochs to 80 and stopped training after 50 epochs. But the performance is poor, e.g., the mIoU at epoch 5 is 34.70% while at epoch 10 it is 11.40%.

    I set the 'use_weak_labels' as true, the 'drop_last' as false, and the rest are default.

    Have you ever met this situation?

    opened by wqhIris 5