Deep Image Matting implementation in PyTorch

Overview

Deep Image Matting

Deep Image Matting paper implementation in PyTorch.

Differences

  1. "fc6" is dropped.
  2. Indices pooling.
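"Indices pooling" appears to mean SegNet-style unpooling. Each encoder max-pooling layer also returns the positions of its maxima, and the matching decoder layer uses those positions to put values back where they came from. The `models.py` frames in one of the tracebacks in the comments (`self.up4(up5, indices_4, unpool_shape4)`) fit this reading. Below is a minimal sketch using PyTorch's `nn.MaxPool2d(return_indices=True)` and `nn.MaxUnpool2d`. `DownBlock`, `UpBlock` and the channel sizes are hypothetical and are not the repository's actual layers.

```python
import torch
import torch.nn as nn


class DownBlock(nn.Module):
    """Hypothetical encoder stage: conv + ReLU, then a max-pool that also returns indices."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True))
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

    def forward(self, x):
        x = self.conv(x)
        unpool_shape = x.size()  # remembered so unpooling can restore odd spatial sizes
        x, indices = self.pool(x)
        return x, indices, unpool_shape


class UpBlock(nn.Module):
    """Hypothetical decoder stage: unpool with the saved indices, then conv + ReLU."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
        self.conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True))

    def forward(self, x, indices, output_shape):
        # Each pooled value goes back to the position its maximum came from; the rest are zero.
        x = self.unpool(x, indices, output_size=output_shape)
        return self.conv(x)


if __name__ == "__main__":
    # 4 input channels: RGB image + trimap, as in Deep Image Matting.
    x = torch.randn(1, 4, 33, 33)
    down, up = DownBlock(4, 8), UpBlock(8, 1)
    pooled, idx, shape = down(x)
    out = up(pooled, idx, shape)
    print(pooled.shape, out.shape)  # torch.Size([1, 8, 16, 16]) torch.Size([1, 1, 33, 33])
```

This approach does not learn any upsampling weights, because the decoder reuses the encoder's own argmax positions. That helps keep sharp boundaries in the predicted alpha.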

"fc6" is bulky. It has over 100 million parameters, which makes the model hard to converge. I suspect this is why the model in the paper has to be trained stage-wise.

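Here is where the "over 100 million parameters" figure comes from. This assumes fc6 follows VGG-16, where it maps a 7×7×512 feature map to 4096 units, which is the same as a 7×7 convolution from 512 to 4096 channels. The count below is a sketch under that assumption. The comparison layer is VGG-16's largest 3×3 convolution.

```python
def conv_params(in_ch: int, out_ch: int, k: int, bias: bool = True) -> int:
    """Number of learnable parameters in a k x k 2-D convolution."""
    return in_ch * out_ch * k * k + (out_ch if bias else 0)


# VGG-16 fc6 viewed as a 7x7 conv from 512 to 4096 channels (assumed layout).
fc6 = conv_params(512, 4096, 7)
# For comparison: a 3x3 conv with 512 -> 512 channels, the largest conv in VGG-16.
conv5 = conv_params(512, 512, 3)
print(f"fc6: {fc6:,}  conv5_x: {conv5:,}")  # fc6: 102,764,544  conv5_x: 2,359,808
```

Under this assumption, fc6 alone holds roughly 43 times as many parameters as the largest convolutional layer.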
Performance

  • The Composition-1k testing dataset.
  • Evaluated on whole images.
  • SAD normalized by 1000.
  • Input image is normalized with mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
  • Trimaps are generated using both erosion and dilation.
Models        SAD    MSE    Download
paper-stage0  59.6   0.019
paper-stage1  54.6   0.017
paper-stage3  50.4   0.014
my-stage0     66.8   0.024  Link
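The README says only that trimaps are produced by eroding and dilating the alpha matte. It does not give kernel sizes. The sketch below uses NumPy only and assumes a square structuring element with a hypothetical `radius`. In practice OpenCV's `cv2.erode`/`cv2.dilate` are the usual choice. Pixels that survive erosion of the fully opaque region become foreground (255). Pixels outside the dilated non-zero region become background (0). The band in between is unknown (128).

```python
import numpy as np


def _square_dilate(mask: np.ndarray, radius: int) -> np.ndarray:
    """Binary dilation of a 2-D boolean mask with a (2r+1) x (2r+1) square kernel."""
    h, w = mask.shape
    padded = np.pad(mask, radius, mode="constant", constant_values=False)
    out = np.zeros_like(mask, dtype=bool)
    for dy in range(2 * radius + 1):
        for dx in range(2 * radius + 1):
            out |= padded[dy:dy + h, dx:dx + w]  # OR over every shift inside the window
    return out


def _square_erode(mask: np.ndarray, radius: int) -> np.ndarray:
    """Binary erosion, computed as the complement of dilating the complement."""
    return ~_square_dilate(~mask, radius)


def make_trimap(alpha: np.ndarray, radius: int = 5) -> np.ndarray:
    """alpha: 2-D uint8 matte in [0, 255]. Returns 0 = bg, 128 = unknown, 255 = fg."""
    sure_fg = _square_erode(alpha == 255, radius)   # shrink the fully opaque region
    maybe_fg = _square_dilate(alpha > 0, radius)    # grow everything that is not pure bg
    trimap = np.zeros_like(alpha, dtype=np.uint8)
    trimap[maybe_fg] = 128
    trimap[sure_fg] = 255
    return trimap
```

A larger `radius` widens the unknown band. A network trained on randomly sized bands tends to be more robust to user-drawn trimaps.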

Dependencies

  • Python 3.5.2
  • PyTorch 1.1.0

Dataset

Adobe Deep Image Matting Dataset

Follow the instructions to contact the author for the dataset.

MSCOCO

Go to MSCOCO to download:

PASCAL VOC

Go to PASCAL VOC to download:

Usage

Data Pre-processing

Extract training images:

$ python pre_process.py

Train

$ python train.py

If you want to visualize during training, run in your terminal:

$ tensorboard --logdir runs

Experimental results

The Composition-1k testing dataset

  1. Test:
$ python test.py

It prints out average SAD and MSE errors when finished.
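`test.py` itself is not reproduced here. The sketch below shows what the two numbers mean. It assumes predicted and ground-truth alphas are floats in [0, 1] and that the trimap marks unknown pixels with 128. SAD sums absolute errors over the whole image and divides by 1000, as stated under Performance. MSE averages squared errors over the unknown region, a common Composition-1k convention. The repository's exact computation may differ.

```python
import numpy as np


def compute_sad(pred: np.ndarray, alpha: np.ndarray) -> float:
    """Sum of absolute differences over the whole image, divided by 1000."""
    return float(np.abs(pred - alpha).sum() / 1000.0)


def compute_mse(pred: np.ndarray, alpha: np.ndarray, trimap: np.ndarray,
                unknown_value: int = 128) -> float:
    """Mean squared error over the trimap's unknown pixels (assumed convention)."""
    unknown = trimap == unknown_value
    n = int(unknown.sum())
    if n == 0:
        return 0.0  # no unknown pixels: nothing to score
    diff = pred[unknown] - alpha[unknown]
    return float(np.mean(diff ** 2))
```

The averages printed at the end of a test run would then be the means of these per-image values.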

The alphamatting.com dataset

  1. Download the evaluation datasets: Go to the Datasets page and download the evaluation datasets. Make sure you pick the low-resolution dataset.

  2. Extract evaluation images:

$ python extract.py
  3. Evaluate:
$ python eval.py

Results on the alphamatting.com dataset (columns: Image, Trimap1, Trimap2, Trimap3).

Demo

Download the pre-trained Deep Image Matting model (Link), then run:

$ python demo.py
Demo results (columns: Image/Trimap, Output/GT, New BG/Compose).
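The "New BG/Compose" column shows the foreground placed on a new background using the predicted alpha. Deep Image Matting uses the standard compositing equation I = αF + (1 − α)B. Below is a minimal NumPy sketch. It assumes float images in [0, 1] of shape (H, W, 3) and an alpha of shape (H, W). Using the input photo itself as F is a simplification and not necessarily what demo.py does.

```python
import numpy as np


def composite(fg: np.ndarray, bg: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    """Alpha-blend fg over bg: I = alpha * F + (1 - alpha) * B.

    fg, bg: float arrays (H, W, 3) in [0, 1]; alpha: float array (H, W) in [0, 1].
    """
    if fg.shape != bg.shape:
        raise ValueError("foreground and background must have the same shape")
    a = alpha[..., None]  # broadcast alpha over the colour channels
    return a * fg + (1.0 - a) * bg
```

The background has to be resized or cropped to the foreground's size before blending.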

A small sponsorship~

If this has helped you, a small sponsorship would be appreciated~




Comments
  • the frozen model named BEST_checkpoint.tar cannot be uncompressed

    When I try to uncompress the frozen model, it shows:

    tar: This does not look like a tar archive
    tar: Skipping to next header
    tar: Exiting with failure status due to previous errors

    This means the .tar file is incomplete.

    opened by banrenmasanxing 6
  • my own datasets are all full human body images

    Hi, thanks for your excellent work. I am now preparing my own dataset. It consists of thousands of high-resolution images (about 4000×4000 on average), all of them full human body images. While processing them I ran into a question: when I crop around the trimap (generated from the alpha), the crop often covers regions that contain no hair, such as the feet or legs. Is it OK to input these images into the model?

    opened by lfxx 5
  • run demo.py question!

    Traceback (most recent call last):
      File "demo.py", line 84, in <module>
        new_bgs = random.sample(new_bgs, 10)
      File "C:\Users\15432\AppData\Local\conda\conda\envs\python34\lib\random.py", line 324, in sample
        raise ValueError("Sample larger than population")
    ValueError: Sample larger than population

    opened by kxcg99 5
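`random.sample(population, k)` raises `ValueError: Sample larger than population` whenever fewer than `k` items are available. This error therefore suggests that demo.py found fewer than 10 background images. One possible fix for that line is to cap the sample size. The helper name below is hypothetical.

```python
import random


def pick_backgrounds(new_bgs, k=10, seed=None):
    """Return up to k distinct backgrounds; never asks for more than exist."""
    rng = random.Random(seed)
    return rng.sample(list(new_bgs), min(k, len(new_bgs)))
```

For example, in demo.py, `new_bgs = pick_backgrounds(new_bgs, 10)` would replace the failing call. Supplying at least 10 background images also avoids the error.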
  • Invalid BEST_checkpoint.tar ?

    Hi, thank you for the code. I tried to download the pretrained model and extract it, but it doesn't work.

    tar xvf BEST_checkpoint.tar BEST_checkpoint
    

    results in

    tar: This does not look like a tar archive
    tar: Skipping to next header
    tar: BEST_checkpoint: Not found in archive
    tar: Exiting with failure status due to previous errors
    

    Am I doing something wrong, or is the provided tar not valid? Kind regards

    opened by flocreate 4
  • How can I get the trimaps of my pictures?

    Now I have a model and want to use it, but I can't because I don't have trimaps for my pictures. Is there a script to build the trimaps? How can I get the trimaps of my pictures?

    opened by huangjunxiong11 3
  • can not unpack the 'BEST_checkpoint.tar'

    I downloaded the file "BEST_checkpoint.tar" successfully, but I can't unpack it: trying to unpack 'BEST_checkpoint.tar' produces an error. Is it my fault, or is the file broken?

    opened by huangjunxiong11 3
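The traceback in the "Demo error" report shows demo.py calling `torch.load(checkpoint)` on this file. That suggests BEST_checkpoint.tar is a PyTorch checkpoint meant to be loaded directly, not a tar archive to unpack. The `.tar` suffix is a common PyTorch naming convention for checkpoints. A small standard-library sketch that guesses a file's actual container format from its bytes is shown below. The function name is hypothetical.

```python
import tarfile
import zipfile


def sniff_checkpoint(path: str) -> str:
    """Guess a checkpoint's container format from its contents, not its file name."""
    if zipfile.is_zipfile(path):
        return "zip"     # torch.save's default format since PyTorch 1.6
    if tarfile.is_tarfile(path):
        return "tar"     # a real tar archive (e.g. very old torch.save output)
    with open(path, "rb") as f:
        head = f.read(1)
    if head == b"\x80":
        return "pickle"  # legacy torch.save output starts with the pickle PROTO opcode
    return "unknown"
```

If this reports `zip` or `pickle`, pass the file straight to `torch.load` instead of running `tar` on it.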
  • Demo error

    /Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py:435: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
      warnings.warn(msg, SourceChangeWarning)
    /Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py:435: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
      warnings.warn(msg, SourceChangeWarning)
    Traceback (most recent call last):
      File "demo.py", line 69, in <module>
        checkpoint = torch.load(checkpoint)
      File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 368, in load
        return _load(f, map_location, pickle_module)
      File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 542, in _load
        result = unpickler.load()
      File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 505, in persistent_load
        data_type(size), location)
      File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 114, in default_restore_location
        result = fn(storage, location)
      File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 95, in _cuda_deserialize
        device = validate_cuda_device(location)
      File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 79, in validate_cuda_device
        raise RuntimeError('Attempting to deserialize object on a CUDA '
    RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location='cpu' to map your storages to the CPU.

    opened by Mlt123 3
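As the error message says, a CPU-only machine has to pass `map_location` to `torch.load`. Below is a minimal sketch of a replacement for the call at line 69 of demo.py. The function name is hypothetical, and the sketch assumes nothing else in the script hard-codes CUDA.

```python
import torch


def load_checkpoint(path):
    """Load a checkpoint onto the GPU if one is available, otherwise onto the CPU."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location remaps storages that were saved on CUDA devices onto `device`.
    # On PyTorch >= 2.6, checkpoints that pickle whole nn.Module objects also need
    # weights_only=False (only for files you trust); older versions lack that argument.
    return torch.load(path, map_location=device)
```

Any later `.cuda()` calls in the script would also need to become `.to(device)` for it to run without a GPU.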
  • Deep-Image-Matting-v2 implementation on Android

    Hi, thanks for your work! The output looks awesome. I want to integrate your demo into an Android project. Is it possible to integrate the model into an Android project? If so, how can I do it? Could you please give some suggestions? Thanks in advance.

    opened by charlizesmith 3
  • Unable to start training using pretrained weights

    Whenever I use the pre-trained weights to train the model on my own dataset, the following error occurs:

    python3 train.py --batch-size 4 --checkpoint checkpoint/BEST_checkpoint.tar

    /usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
      warnings.warn(msg, SourceChangeWarning)
    /usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
      warnings.warn(msg, SourceChangeWarning)
    /usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.modules.batchnorm.BatchNorm2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
      warnings.warn(msg, SourceChangeWarning)
    /usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.modules.activation.ReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
      warnings.warn(msg, SourceChangeWarning)
    Traceback (most recent call last):
      File "train.py", line 180, in <module>
        main()
      File "train.py", line 176, in main
        train_net(args)
      File "train.py", line 71, in train_net
        logger=logger)
      File "train.py", line 112, in train
        alpha_out = model(img)  # [N, 3, 320, 320]
      File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 493, in __call__
        result = self.forward(*input, **kwargs)
      File "/usr/local/lib/python3.5/dist-packages/torch/nn/parallel/data_parallel.py", line 143, in forward
        if t.device != self.src_device_obj:
      File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 539, in __getattr__
        type(self).__name__, name))
    AttributeError: 'DataParallel' object has no attribute 'src_device_obj'

    opened by dev-srikanth 3
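The `SourceChangeWarning`s show that the checkpoint pickles whole `DataParallel`/`Conv2d` module objects. The `AttributeError` suggests the pickled `DataParallel` object was created by a PyTorch version that did not yet set `src_device_obj`, which the installed version's `forward` reads. A common workaround, sketched below under that assumption, has three steps: unwrap `.module`, save only its `state_dict`, and load those weights into a freshly constructed model. The helper names are hypothetical.

```python
import torch
import torch.nn as nn


def unwrap(model: nn.Module) -> nn.Module:
    """Return the underlying module if the model is wrapped in DataParallel."""
    return model.module if isinstance(model, nn.DataParallel) else model


def export_weights(model: nn.Module, path: str) -> None:
    """Save only parameters and buffers, not the pickled module object."""
    torch.save(unwrap(model).state_dict(), path)


def import_weights(fresh_model: nn.Module, path: str) -> nn.Module:
    """Load saved weights into a newly constructed model of the same architecture."""
    state = torch.load(path, map_location="cpu")
    unwrap(fresh_model).load_state_dict(state)  # keys carry no "module." prefix
    return fresh_model
```

Applied to this case, load the old checkpoint once and pass the `DataParallel` object it contains to `export_weights`. Then build a new model from the current `models.py`, optionally wrap it in `nn.DataParallel`, and call `import_weights`. The key that holds the model inside `BEST_checkpoint.tar` is not stated here, so inspect the loaded dict's keys first.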
  • v2 doesn't perform as well as v1?

    Hi, thanks for your pretrained models! I tested both your v1 and v2 pretrained models. v2 is much faster than v1, but I found that it doesn't perform as well as v1. Attached: the image (WechatIMG226), the original trimap (test7_tri), the v1 output (WechatIMG225) and the v2 output (test7_result).

    Do you know what the problem is?

    Thanks,

    opened by MarSaKi 3
  • Questions about the PyTorch version and an issue in training regarding the batch size

    Hi,

    Thank you for sharing your PyTorch reimplementation. Could you share which PyTorch version you used for development?

    I am using PyTorch 1.0.1, CUDA 9 and two RTX 2080 Ti GPUs to run 'train.py', since I see you use the DataParallel module to support multi-GPU training. However, I encountered an error; the traceback is here:

    Traceback (most recent call last):
      File "train.py", line 171, in <module>
        main()
      File "train.py", line 167, in main
        train_net(args)
      File "train.py", line 64, in train_net
        logger=logger)
      File "train.py", line 103, in train
        alpha_out = model(img)  # [N, 3, 320, 320]
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 143, in forward
        outputs = self.parallel_apply(replicas, inputs, kwargs)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 153, in parallel_apply
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
        raise output
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
        output = module(*input, **kwargs)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/mingfu/Deep-Image-Matting-v2/models.py", line 127, in forward
        up4 = self.up4(up5, indices_4, unpool_shape4)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/mingfu/Deep-Image-Matting-v2/models.py", line 87, in forward
        outputs = self.conv(outputs)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/mingfu/Deep-Image-Matting-v2/models.py", line 43, in forward
        outputs = self.cbr_unit(inputs)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward
        input = module(input)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 320, in forward
        self.padding, self.dilation, self.groups)
    RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED

    I have tested data parallelism using the Data Parallelism example here, and it works well.

    opened by wuyujack 3
Owner
Yang Liu
Algorithm engineer