Deep Image Matting implementation in PyTorch

Overview

Deep Image Matting

Deep Image Matting paper implementation in PyTorch.

Differences

  1. "fc6" is dropped.
  2. Indices pooling.

"fc6" is clumpy, over 100 millions parameters, makes the model hard to converge. I guess it is the reason why the model (paper) has to be trained stagewisely.

Performance

  • The Composition-1k testing dataset.
  • Evaluate with whole image.
  • SAD normalized by 1000.
  • Input image is normalized with mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
  • Both erode and dialte to generate trimap.
Models SAD MSE Download
paper-stage0 59.6 0.019
paper-stage1 54.6 0.017
paper-stage3 50.4 0.014
my-stage0 66.8 0.024 Link

Dependencies

  • Python 3.5.2
  • PyTorch 1.1.0

Dataset

Adobe Deep Image Matting Dataset

Follow the instruction to contact author for the dataset.

MSCOCO

Go to MSCOCO to download:

PASCAL VOC

Go to PASCAL VOC to download:

Usage

Data Pre-processing

Extract training images:

$ python pre_process.py

Train

$ python train.py

If you want to visualize during training, run in your terminal:

$ tensorboard --logdir runs

Experimental results

The Composition-1k testing dataset

  1. Test:
$ python test.py

It prints out average SAD and MSE errors when finished.

The alphamatting.com dataset

  1. Download the evaluation datasets: Go to the Datasets page and download the evaluation datasets. Make sure you pick the low-resolution dataset.

  2. Extract evaluation images:

$ python extract.py
  1. Evaluate:
$ python eval.py

Click to view whole images:

Image Trimap1 Trimap2 Trimap3
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image
image image image image

Demo

Download pre-trained Deep Image Matting Link then run:

$ python demo.py
Image/Trimap Output/GT New BG/Compose
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image
image image image

小小的赞助~

Sample

若对您有帮助可给予小小的赞助~




Comments
  • the frozen model named BEST_checkpoint.tar cannot be uncompressed

    the frozen model named BEST_checkpoint.tar cannot be uncompressed

    when I try to uncompress the frozen model it shows

    tar: This does not look like a tar archive tar: Skipping to next header tar: Exiting with failure status due to previous errors

    this means the .tar file is not complete

    opened by banrenmasanxing 6
  • my own datasets are all full human body images

    my own datasets are all full human body images

    Hi,thanks for your excellent work.Now i prepare my own datasets.This datasets are consists of thounds of high resolution image(average 4000*4000).They are all full human body images.When i process these images,i meet a questions: When i crop the trimap(generated from alpha),often crop some places which are not include hair.Such as foot,leg.Is it ok to input these images into [email protected]

    opened by lfxx 5
  • run demo.py question!

    run demo.py question!

    File "demo.py", line 84, in new_bgs = random.sample(new_bgs, 10) File "C:\Users\15432\AppData\Local\conda\conda\envs\python34\lib\random.py", line 324, in sample raise ValueError("Sample larger than population") ValueError: Sample larger than population

    opened by kxcg99 5
  • Invalid BEST_checkpoint.tar ?

    Invalid BEST_checkpoint.tar ?

    Hi, thank you for the code. I tried to download the pretrained model and extract it but it dosnt work.

    tar xvf BEST_checkpoint.tar BEST_checkpoint
    

    results in

    tar: Ceci ne ressemble pas à une archive de type « tar »
    tar: On saute à l'en-tête suivant
    tar: BEST_checkpoint : non trouvé dans l'archive
    tar: Arrêt avec code d'échec à cause des erreurs précédentes
    

    anything i'm doing the wrong way ? or the provided tar is not valid ? kind reards

    opened by flocreate 4
  • How can i get the Trimaps of my pictures?

    How can i get the Trimaps of my pictures?

    Now, I got a model, I want to use it but I can't, because I have not the Trimaps of my pictures. Are there the script of code to build the Trimaps? How can i get the Trimaps of my pictures?

    opened by huangjunxiong11 3
  • can not unpack the 'BEST_checkpoint.tar'

    can not unpack the 'BEST_checkpoint.tar'

    When i download the file "BEST_checkpoint.tar" successfully, i can't unpack it. Actually, when i try to unpack 'BEST_checkpoint.tar', it make an error. Is it my fault , or, Is the file mistaken?

    opened by huangjunxiong11 3
  • Demo error

    Demo error

    /Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py:435: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes. warnings.warn(msg, SourceChangeWarning) /Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py:435: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes. warnings.warn(msg, SourceChangeWarning) Traceback (most recent call last): File "demo.py", line 69, in checkpoint = torch.load(checkpoint) File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 368, in load return _load(f, map_location, pickle_module) File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 542, in _load result = unpickler.load() File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 505, in persistent_load data_type(size), location) File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 114, in default_restore_location result = fn(storage, location) File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 95, in _cuda_deserialize device = validate_cuda_device(location) File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 79, in validate_cuda_device raise RuntimeError('Attempting to deserialize object on a CUDA ' RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location='cpu' to map your storages to the CPU.

    opened by Mlt123 3
  • Deep-Image-Matting-v2 implemetation on Android

    Deep-Image-Matting-v2 implemetation on Android

    Hi, Thanks for you work! its looking awesome output. I want to integrate your demo into android project. Is it possible to integrate model into android Project? If it possible, then How can i integrate this model into android project? Can you please give some suggestions? Thanks in advance.

    opened by charlizesmith 3
  • unable to start training using pretrained weigths

    unable to start training using pretrained weigths

    whenever pre-trained weights are used for training the model using own dataset, the following error is occurring.

    python3 train.py --batch-size 4 --checkpoint checkpoint/BEST_checkpoint.tar

    /usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes. warnings.warn(msg, SourceChangeWarning) /usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes. warnings.warn(msg, SourceChangeWarning) /usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.modules.batchnorm.BatchNorm2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes. warnings.warn(msg, SourceChangeWarning) /usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.modules.activation.ReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes. warnings.warn(msg, SourceChangeWarning) Traceback (most recent call last): File "train.py", line 180, in main() File "train.py", line 176, in main train_net(args) File "train.py", line 71, in train_net logger=logger) File "train.py", line 112, in train alpha_out = model(img) # [N, 3, 320, 320] File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 493, in call result = self.forward(*input, **kwargs) File "/usr/local/lib/python3.5/dist-packages/torch/nn/parallel/data_parallel.py", line 143, in forward if t.device != self.src_device_obj: File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 539, in getattr type(self).name, name)) AttributeError: 'DataParallel' object has no attribute 'src_device_obj'

    opened by dev-srikanth 3
  • v2 didn't performance well as v1?

    v2 didn't performance well as v1?

    Hi, thanks for your pretrained model! I test both your v1 pretrained model and v2 pretrained model , v2 is much faster than v1 , but I found it didn't performance well as v1. the image: WechatIMG226 the origin tri map: test7_tri the v1 output: WechatIMG225 the v2 output: test7_result

    do you know what's the problem?

    Thanks,

    opened by MarSaKi 3
  • Questions about the PyTorch version and an issue in training regarding to the batch size

    Questions about the PyTorch version and an issue in training regarding to the batch size

    Hi,

    Thank you for sharing your PyTorch version of reimplementation. Would you like to share the PyTorch version you used to development?

    I am using PyTorch 1.0.1, CUDA 9, two RTX 2080 Ti to run the 'train.py' since I see you use Data Parallel module to support multi-GPUs training. However, I encountered and the trackbacks are here:

    Traceback (most recent call last): File "train.py", line 171, in main() File "train.py", line 167, in main train_net(args) File "train.py", line 64, in train_net logger=logger) File "train.py", line 103, in train alpha_out = model(img) # [N, 3, 320, 320] File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in call result = self.forward(*input, **kwargs) File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 143, in forward outputs = self.parallel_apply(replicas, inputs, kwargs) File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 153, in parallel_apply return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply raise output File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker output = module(*input, **kwargs) File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in call result = self.forward(*input, **kwargs) File "/home/mingfu/Deep-Image-Matting-v2/models.py", line 127, in forward up4 = self.up4(up5, indices_4, unpool_shape4) File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in call result = self.forward(*input, **kwargs) File "/home/mingfu/Deep-Image-Matting-v2/models.py", line 87, in forward outputs = self.conv(outputs) File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in call result = self.forward(*input, **kwargs) File "/home/mingfu/Deep-Image-Matting-v2/models.py", line 43, in forward outputs = self.cbr_unit(inputs) File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in call result = self.forward(*input, **kwargs) File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward input = module(input) File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in call result = self.forward(*input, **kwargs) File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 320, in forward self.padding, self.dilation, self.groups) RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED

    I have tested the DATA PARALLELISM using the example here and it works well.

    opened by wuyujack 3
Owner
Yang Liu
Algorithm engineer
Yang Liu
Pyramid addon for OpenAPI3 validation of requests and responses.

Validate Pyramid views against an OpenAPI 3.0 document Peace of Mind The reason this package exists is to give you peace of mind when providing a REST

Pylons Project 79 Dec 30, 2022
Progressive Growing of GANs for Improved Quality, Stability, and Variation

Progressive Growing of GANs for Improved Quality, Stability, and Variation — Official TensorFlow implementation of the ICLR 2018 paper Tero Karras (NV

Tero Karras 5.9k Jan 05, 2023
UnFlow: Unsupervised Learning of Optical Flow with a Bidirectional Census Loss

UnFlow: Unsupervised Learning of Optical Flow with a Bidirectional Census Loss This repository contains the TensorFlow implementation of the paper UnF

Simon Meister 270 Nov 06, 2022
We will release the code of "ConTNet: Why not use convolution and transformer at the same time?" in this repo

ConTNet Introduction ConTNet (Convlution-Tranformer Network) is proposed mainly in response to the following two issues: (1) ConvNets lack a large rec

93 Nov 08, 2022
Hands-On Machine Learning for Algorithmic Trading, published by Packt

Hands-On Machine Learning for Algorithmic Trading Hands-On Machine Learning for Algorithmic Trading, published by Packt This is the code repository fo

Packt 981 Dec 29, 2022
AdamW optimizer and cosine learning rate annealing with restarts

AdamW optimizer and cosine learning rate annealing with restarts This repository contains an implementation of AdamW optimization algorithm and cosine

Maksym Pyrozhok 133 Dec 20, 2022
KwaiRec: A Fully-observed Dataset for Recommender Systems (Density: Almost 100%)

KuaiRec: A Fully-observed Dataset for Recommender Systems (Density: Almost 100%) KuaiRec is a real-world dataset collected from the recommendation log

Chongming GAO (高崇铭) 70 Dec 28, 2022
Learning High-Speed Flight in the Wild

Learning High-Speed Flight in the Wild This repo contains the code associated to the paper Learning Agile Flight in the Wild. For more information, pl

Robotics and Perception Group 391 Dec 29, 2022
Fermi Problems: A New Reasoning Challenge for AI

Fermi Problems: A New Reasoning Challenge for AI Fermi Problems are questions whose answer is a number that can only be reasonably estimated as a prec

AI2 15 May 28, 2022
Official repository of PanoAVQA: Grounded Audio-Visual Question Answering in 360° Videos (ICCV 2021)

Pano-AVQA Official repository of PanoAVQA: Grounded Audio-Visual Question Answering in 360° Videos (ICCV 2021) [Paper] [Poster] [Video] Getting Starte

Heeseung Yun 9 Dec 23, 2022
My implementation of Image Inpainting - A deep learning Inpainting model

Image Inpainting What is Image Inpainting Image inpainting is a restorative process that allows for the fixing or removal of unwanted parts within ima

Joshua V Evans 1 Dec 12, 2021
Evolutionary Population Curriculum for Scaling Multi-Agent Reinforcement Learning

Evolutionary Population Curriculum for Scaling Multi-Agent Reinforcement Learning This is the code for implementing the MADDPG algorithm presented in

97 Dec 21, 2022
Global Pooling, More than Meets the Eye: Position Information is Encoded Channel-Wise in CNNs, ICCV 2021

Global Pooling, More than Meets the Eye: Position Information is Encoded Channel-Wise in CNNs, ICCV 2021 Global Pooling, More than Meets the Eye: Posi

Md Amirul Islam 32 Apr 24, 2022
Dyalog-apl-docset - Dyalog APL Dash Docset Generator

Dyalog APL Dash Docset Generator o alasa e kili sona kepeken tenpo lili a A Dash

Maciej Goszczycki 1 Jan 10, 2022
ParaGen is a PyTorch deep learning framework for parallel sequence generation

ParaGen is a PyTorch deep learning framework for parallel sequence generation. Apart from sequence generation, ParaGen also enhances various NLP tasks, including sequence-level classification, extrac

Bytedance Inc. 169 Dec 22, 2022
Monify: an Expense tracker Program implemented in a Graphical User Interface that allows users to keep track of their expenses

💳 MONIFY (EXPENSE TRACKER PRO) 💳 Description Monify is an Expense tracker Program implemented in a Graphical User Interface allows users to add inco

Moyosore Weke 1 Dec 14, 2021
Cognition-aware Cognate Detection

Cognition-aware Cognate Detection The repository which contains our code for our EACL 2021 paper titled, "Cognition-aware Cognate Detection". This wor

Prashant K. Sharma 1 Feb 01, 2022
Implementation of Bidirectional Recurrent Independent Mechanisms (Learning to Combine Top-Down and Bottom-Up Signals in Recurrent Neural Networks with Attention over Modules)

BRIMs Bidirectional Recurrent Independent Mechanisms Implementation of the paper Learning to Combine Top-Down and Bottom-Up Signals in Recurrent Neura

Sarthak Mittal 26 May 26, 2022
This provides the R code and data to replicate results in "The USS Trustee’s risky strategy"

USSBriefs2021 This provides the R code and data to replicate results in "The USS Trustee’s risky strategy" by Neil M Davies, Jackie Grant and Chin Yan

1 Oct 30, 2021
This repository implements WGAN_GP.

Image_WGAN_GP This repository implements WGAN_GP. Image_WGAN_GP This repository uses wgan to generate mnist and fashionmnist pictures. Firstly, you ca

Lieon 6 Dec 10, 2021