Robust Video Matting in PyTorch, TensorFlow, TensorFlow.js, ONNX, CoreML!

Overview

Robust Video Matting (RVM)



Official repository for the paper Robust High-Resolution Video Matting with Temporal Guidance. RVM is specifically designed for robust human video matting. Unlike existing neural models that process frames as independent images, RVM uses a recurrent neural network to process videos with temporal memory. RVM can perform matting in real time on any video without additional inputs. It achieves 76 FPS at 4K and 104 FPS at HD on an Nvidia GTX 1080 Ti GPU. The project was developed at ByteDance Inc.


News

  • [Aug 25 2021] Source code and pretrained models are published.
  • [Jul 27 2021] Paper is accepted by WACV 2022.

Showreel

Watch the showreel video (YouTube, Bilibili) to see the model's performance.

All footage in the video is available in Google Drive and Baidu Pan (code: tb3w).


Demo

  • Webcam Demo: Run the model live in your browser. Visualize recurrent states.
  • Colab Demo: Test our model on your own videos with free GPU.

Download

We recommend the MobileNetV3 models for most use cases. The ResNet50 models are larger variants with small performance improvements. Our model is available on various inference frameworks. See the inference documentation for more instructions.

Framework: PyTorch
  Download: rvm_mobilenetv3.pth, rvm_resnet50.pth
  Notes: Official weights for PyTorch. Doc

Framework: TorchHub
  Download: nothing to download.
  Notes: Easiest way to use our model in your PyTorch project. Doc

Framework: TorchScript
  Download: rvm_mobilenetv3_fp32.torchscript, rvm_mobilenetv3_fp16.torchscript, rvm_resnet50_fp32.torchscript, rvm_resnet50_fp16.torchscript
  Notes: For inference on mobile, consider exporting int8 quantized models yourself. Doc

Framework: ONNX
  Download: rvm_mobilenetv3_fp32.onnx, rvm_mobilenetv3_fp16.onnx, rvm_resnet50_fp32.onnx, rvm_resnet50_fp16.onnx
  Notes: Tested on ONNX Runtime with CPU and CUDA backends. Provided models use opset 12. Doc, Exporter

Framework: TensorFlow
  Download: rvm_mobilenetv3_tf.zip, rvm_resnet50_tf.zip
  Notes: TensorFlow 2 SavedModel. Doc

Framework: TensorFlow.js
  Download: rvm_mobilenetv3_tfjs_int8.zip
  Notes: Run the model on the web. Demo, Starter Code

Framework: CoreML
  Download: rvm_mobilenetv3_1280x720_s0.375_fp16.mlmodel, rvm_mobilenetv3_1280x720_s0.375_int8.mlmodel, rvm_mobilenetv3_1920x1080_s0.25_fp16.mlmodel, rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel
  Notes: CoreML does not support dynamic resolution; models for other resolutions can be exported yourself. Models require iOS 13+. s denotes downsample_ratio. Doc, Exporter

All models are available in Google Drive and Baidu Pan (code: gym7).
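
The TorchScript files can be loaded directly with torch.jit.load. A minimal sketch; that the scripted model accepts the same (src, *rec, downsample_ratio) inputs as the PyTorch model below is an assumption, so check the inference documentation:

import torch

# Load the exported TorchScript variant listed in the table above.
model = torch.jit.load('rvm_mobilenetv3_fp32.torchscript').eval().cuda()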


PyTorch Example

  1. Install dependencies:
pip install -r requirements_inference.txt
  2. Load the model:
import torch
from model import MattingNetwork

model = MattingNetwork('mobilenetv3').eval().cuda()  # or "resnet50"
model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))
  3. To convert videos, we provide a simple conversion API:
from inference import convert_video

convert_video(
    model,                           # The model, can be on any device (cpu or cuda).
    input_source='input.mp4',        # A video file or an image sequence directory.
    output_type='video',             # Choose "video" or "png_sequence"
    output_composition='output.mp4', # File path if video; directory path if png sequence.
    output_video_mbps=4,             # Output video mbps. Not needed for png sequence.
    downsample_ratio=None,           # A hyperparameter to adjust, or None for auto.
    seq_chunk=12,                    # Process n frames at once for better parallelism.
)
  4. Or write your own inference code:
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from inference_utils import VideoReader, VideoWriter

reader = VideoReader('input.mp4', transform=ToTensor())
writer = VideoWriter('output.mp4', frame_rate=30)

bgr = torch.tensor([.47, 1, .6]).view(3, 1, 1).cuda()  # Green background.
rec = [None] * 4                                       # Initial recurrent states.
downsample_ratio = 0.25                                # Adjust based on your video.

with torch.no_grad():
    for src in DataLoader(reader):                     # RGB tensor normalized to 0 ~ 1.
        fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio)  # Cycle the recurrent states.
        com = fgr * pha + bgr * (1 - pha)              # Composite to green background. 
        writer.write(com)                              # Write frame.
  5. The models and converter API are also available through TorchHub; a brief usage sketch follows the snippet below.
# Load the model.
model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # or "resnet50"

# Converter API.
convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
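
The hub-loaded model and converter can be used together just like the local API above. A minimal usage sketch; the file paths here are placeholders, not part of the repo:

# Run the TorchHub converter with the TorchHub model, mirroring the convert_video call shown earlier.
convert_video(
    model,                           # TorchHub model loaded above.
    input_source='input.mp4',        # Placeholder input path.
    output_type='video',
    output_composition='output.mp4', # Placeholder output path.
    downsample_ratio=None,           # Auto.
    seq_chunk=12,
)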

Please see the inference documentation for details on the downsample_ratio hyperparameter, more converter arguments, and more advanced usage.
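
For intuition, the automatic setting picks a ratio so that the downsampled frame lands near a fixed working resolution. The sketch below mimics that rule; the 512 px target is an assumption inferred from the speed table below (HD uses 0.25, 4K uses 0.125), so treat it as illustrative rather than the repo's exact rule:

def auto_downsample_ratio(h, w, target=512):
    # Illustrative heuristic: downsample so the longer side is roughly `target` px.
    return min(target / max(h, w), 1.0)

print(auto_downsample_ratio(1080, 1920))  # ~0.267, close to the recommended 0.25 for HD
print(auto_downsample_ratio(2160, 3840))  # ~0.133, close to the recommended 0.125 for 4K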


Training and Evaluation

Please refer to the training documentation to train and evaluate your own model.


Speed

Speed is measured with inference_speed_test.py for reference.

GPU              dType   HD (1920x1080)   4K (3840x2160)
RTX 3090         FP16    172 FPS          154 FPS
RTX 2060 Super   FP16    134 FPS          108 FPS
GTX 1080 Ti      FP32    104 FPS          74 FPS
  • Note 1: HD uses downsample_ratio=0.25, 4K uses downsample_ratio=0.125. All tests use batch size 1 and frame chunk 1.
  • Note 2: GPUs before the Turing architecture do not support FP16 inference, so the GTX 1080 Ti uses FP32.
  • Note 3: We only measure tensor throughput. The provided video conversion script in this repo is expected to be much slower, because it does not use hardware video encoding/decoding and does not overlap tensor transfers on parallel threads. If you are interested in implementing hardware video encoding/decoding in Python, please refer to PyNvCodec. A simplified throughput-timing sketch follows these notes.
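
As a rough illustration of what "tensor throughput" means here, the following sketch times repeated forward passes on a dummy HD frame. It is a simplified stand-in for inference_speed_test.py; the loop counts and shapes are assumptions:

import time
import torch
from model import MattingNetwork

model = MattingNetwork('mobilenetv3').eval().cuda()
src = torch.rand(1, 3, 1080, 1920, device='cuda')   # dummy HD frame
rec = [None] * 4                                     # initial recurrent states
downsample_ratio = 0.25

with torch.no_grad():
    for _ in range(10):                              # warm-up iterations
        fgr, pha, *rec = model(src, *rec, downsample_ratio)
    torch.cuda.synchronize()
    start, n = time.time(), 100
    for _ in range(n):                               # timed iterations
        fgr, pha, *rec = model(src, *rec, downsample_ratio)
    torch.cuda.synchronize()
    print(f'{n / (time.time() - start):.1f} FPS')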

Comments
  • [Questions] - Training Procedure

    Hi,

    I have some questions about the training procedure:

    1. In the paper, you mention training Stage 1 for 15 epochs, while in the code the instructions say 20 epochs. Is there a reason for such a change? Will the results be similar?
    2. I could not get access to Distinctions-646; I had no reply from the authors/maintainers of the dataset. Based on your indicated file structure, I've built a similar dataset, which adds uncertainty to the quality of my training, but it is a risk I am willing to take. To have a comparison baseline (stages 1-3 do not depend on this dataset), would you mind sharing your partial training weights for PyTorch (stage1/epoch19.pth, stage2/epoch21.pth, and stage3/epoch22.pth)?
    3. What is the min resolution you've used for the background images while training?

    For the 3rd time, thank you very much for your contribution to the field. It was brilliant work. Looking forward to your future work.

    opened by SamHSlva 39
  • hardsigmoid replacement

    I've been trying to export an onnx model replacing the hardsigmoid operator.

    I have modified the site-packages/torch/onnx/symbolic_opset9.py file this way:

    @parse_args("v")
    def hardswish(g, self):
        hardsigmoid = g.op('HardSigmoid', self, alpha_f=1 / 6)
        return g.op("Mul", self, hardsigmoid)

    @parse_args("v")
    def hardsigmoid(g, self):
        hardsigmoid = g.op('HardSigmoid', self, alpha_f=1 / 6)
        return g.op("Mul", self, hardsigmoid)

    But I am not sure at all if this is the right way to replace them with primitive ops.

    When I export the ONNX model with this change I still get the error "OnnxImportException: Unknown type HardSigmoid encountered while parsing layer 396" with the inference engine I am trying to use.
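
    For reference, hardsigmoid(x) = clip(x/6 + 1/2, 0, 1), so it can be written with more primitive ops such as Mul, Add, Relu and Min. The symbolic functions below are an untested sketch of that decomposition (meant for the same symbolic_opset9.py context the poster modified), not a confirmed fix from the maintainers:

    @parse_args("v")
    def hardsigmoid(g, self):
        # hardsigmoid(x) = min(relu(x/6 + 0.5), 1), built from primitive ONNX ops.
        scaled = g.op("Add",
                      g.op("Mul", self, g.op("Constant", value_t=torch.tensor(1.0 / 6.0))),
                      g.op("Constant", value_t=torch.tensor(0.5)))
        return g.op("Min", g.op("Relu", scaled), g.op("Constant", value_t=torch.tensor(1.0)))

    @parse_args("v")
    def hardswish(g, self):
        # hardswish(x) = x * hardsigmoid(x).
        return g.op("Mul", self, hardsigmoid(g, self))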

    opened by livingbeams 15
  • VideoMatte240K-HD

    If I'm going to train stage 3 and stage 4, the VideoMatte240K HD data will be used. Is it right to modify the following paths from VideoMatte240K_JPEG_SD to VideoMatte240K_JPEG_HD?

    'videomatte': {
        'train': '../matting-data/VideoMatte240K_JPEG_SD/train',
        'valid': '../matting-data/VideoMatte240K_JPEG_SD/valid',
    },

    opened by FengMu1995 7
  • Add Unity example to README?

    Hey there, I just ported RVM to Unity using NatML, an open-source machine learning runtime. I have questions:

    1. Can I make a PR to add a link into the README to a Unity example project demonstrating using RVM?
    2. I published the model under my account on NatML Hub. Would you be interested in signing up on Hub, so that I can transfer the model to you?

    Here's the model on NatML Hub:

    @natsuite/robust-video-matting

    opened by olokobayusuf 7
  • Some questions about training

    1. How can I eliminate or reduce the edge flickering problem? Can I increase --seq-length-lr to train on longer sequences, and does that help?
    2. I only have composited images and no foreground images. Is it possible to remove foreground training and the foreground loss, or is there a better way?
    3. How important is foreground prediction for matting?

    Looking forward to your reply

    opened by zhanghongyong123456 5
  • Not Issue 👉 Few questions

    First of all, thank you for working on this project! It looks much stronger than BMV2!

    1. Will it work with Anaconda and Windows 10 just like BMV2 does (not more complicated)?

    2. Will it support the same hardware, or does it need a much more powerful CPU / GPU compared to BMV2?

    3. Can you please tell us when you will release it again? I missed it the first time, and I can't test it because it's still offline. It would be very nice to have it this week, if possible of course.

    Thanks ahead for the answers, please keep up the good work! ❤

    opened by AlonDan 5
  • Synchronization issues between inferred mask and original video

    Hello! Thanks for the code. I have had some timing issues between the inferred mask output and the original video. I made this comparison by converting both my original video and the output mask video to frames. I obtained the same number of frames in both cases, so the difference may be caused by a bad configuration on my side. My original video is 30 fps and 1080x1920. If you have a suggestion I would appreciate it.

    opened by italosalgado14 4
  • Weird results when use Segmentation Pass for inference

    https://github.com/PeterL1n/RobustVideoMatting/blob/f8a26e27198a93a94bfd06e96b8d5a34d0660f81/inference.py#L127

    I changed this line to use Segmentation Pass. (use the pretrained weights rvm_mobilenetv3.pth)

    pha, *rec = model(src, *rec, segmentation_pass=True)
    fgr = src * pha
    

    But I got weird mask results, something like this, why?

    [attached image: seg_pass_alpha]

    opened by luuil 4
  • Beginner questions about the model results

    Hello, I have two questions to consult:

    1. Besides changing the downsample_ratio parameter to adjust the matting accuracy, which other parameters can be changed to improve the results?

    2. Does this project have higher requirements for the graphics card? Does the graphics card model affect the final results?

    I have my own project running the model, but the results are not very ideal yet. Thanks again!

    opened by yinjia823 4
  • FP16 is slower than FP32

    I use pre-trained ONNX model parameters for inference tests (in Python not C++), only onnxruntime, cv2 and numpy libraries, nothing extra. Parameters downloaded from https://github.com/PeterL1n/RobustVideoMatting/releases/: rvm_mobilenetv3_fp32.onnx and rvm_mobilenetv3_fp16.onnx

    Inference on a 1080x1920 video, downsample_ratio=0.25. As a result, FP32 takes about 170 ms per frame, while FP16 takes about 240 ms. Why is FP16 so slow?

    I have adjusted the inputs correctly: for src, r1i, r2i, r3i, r4i it is np.array([[[[]]]], dtype=np.float32 or np.float16), and downsample_ratio is always np.array([0.25], dtype=np.float32).

    I use a CPU (Intel i5) for inference. Is it so slow because the CPU does not support FP16 operations?
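
    For context, the recurrent feed pattern described above looks roughly like the sketch below when driven from onnxruntime. The zero-filled [1, 1, 1, 1] initial states and the output ordering (fgr, pha, then the four recurrent outputs) are assumptions based on the ONNX documentation, so verify them against your model:

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession('rvm_mobilenetv3_fp32.onnx', providers=['CPUExecutionProvider'])
    rec = [np.zeros([1, 1, 1, 1], dtype=np.float32)] * 4      # assumed initial recurrent states
    downsample_ratio = np.array([0.25], dtype=np.float32)

    def step(src, rec):
        # src: one frame as NCHW float32 in 0~1, e.g. shape [1, 3, 1080, 1920].
        fgr, pha, *rec = sess.run(None, {
            'src': src,
            'r1i': rec[0], 'r2i': rec[1], 'r3i': rec[2], 'r4i': rec[3],
            'downsample_ratio': downsample_ratio,
        })
        return fgr, pha, rec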

    opened by ZachL1 3
  • foreground prediction details

    Hello, a question about foreground prediction. In the official web demo, the foreground predicted by the model contains, besides the foreground (the person), some background details of the input image (non-person pixels). However, my own trained model (I did not modify any of the official details; the only difference is that I used background images I collected myself) predicts foregrounds that contain only the person, with no background details from the input image. At first I suspected the foreground loss covered all pixels (any alpha value, not only alpha > 0 as stated in the paper), but after checking the code I found no problem; it is consistent with the paper. What could cause this? Thank you.

    opened by li-wenquan 3
  • Add Replicate demo and

    Hey @PeterL1n ! 👋

    Thanks for this wonderful project!

    This pull request makes it possible to run your model inside a Docker environment, which makes it easier for other people to run it. We're using an open source tool called Cog to make this process easier.

    This also means we can make a web page where other people can run your model! View it here: https://replicate.com/arielreplicate/robust_video_matting. Replicate also has an API, so people can easily run your model from their code:

    import replicate
    model = replicate.models.get("arielreplicate/robust_video_matting")
    model.predict(...)
    

    If you'd like to modify the Replicate page, let me know and I can transfer ownership to your account.

    In case you're wondering who I am, I'm from Replicate, where we're trying to make machine learning reproducible. We got frustrated that we couldn't run all the really interesting ML work being done. So, we're going round implementing models we like. 😊

    opened by ArielReplicate 0
  • Support of NCNN version

    Hello, how are you? Thanks for contributing to this project. I tried to use an NCNN version of this model with C++ on Windows. There are several GitHub repos using an NCNN model in C++, but I cannot run them: there are some issues when extracting the output data from the model, and the models are old. Could you support an NCNN version here?

    opened by rose-jinyang 0
  • Problem with exporting alpha-mask on the Replicate/COG version

    I tried both the Replicate page and the local COG variant. When predicting with alpha-mask, this error always appears:

    FileNotFoundError: [Errno 2] No such file or directory: 'alpha-mask.mp4'

    Exporting with green-screen and foreground-mask works, however. Maybe this is an issue with mp4 not supporting alpha-transparent video, so it fails?

    opened by SomeOrdinaryDude 0
  • A question about src_sm in the model.py

    I see x = normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) in mobilenetv3.py, while model.py has:

    f1, f2, f3, f4 = self.backbone(src_sm)
    ...
    hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4)

    This means the input src_sm of the decoder has not been normalized. Is that your intention?

    opened by FengMu1995 0
  • Performance using grayscale images

    Hey there,

    thanks for this amazing tool! Does anybody know how the performance for grayscale images is? I want to use it in the dark and I have an infrared camera.

    If retraining is required, how many GPU hours do you think are necessary?

    :)

    opened by bytosaur 1
  • Slow inference and low GPU use.

    I'm running inference.py at ~4.2 it/s, and it barely loads my RTX 2060 (0-13% use). The inference_speed_test script gives me ~33.2 it/s with the same model and video settings. Changing --workers on the convert_video() function did nothing. Am I missing something? How can I run inference faster and use the full hardware potential?

    Thanks.

    opened by sharp-trickster 0