PConv-Keras - Unofficial implementation of "Image Inpainting for Irregular Holes Using Partial Convolutions". Try at: www.fixmyphoto.ai

Overview

Partial Convolutions for Image Inpainting using Keras

Keras implementation of "Image Inpainting for Irregular Holes Using Partial Convolutions", https://arxiv.org/abs/1804.07723. A huge shoutout to the authors Guilin Liu, Fitsum A. Reda, Kevin J. Shih, Ting-Chun Wang, Andrew Tao and Bryan Catanzaro from NVIDIA Corporation for releasing this awesome paper. It's been a great learning experience for me to implement the architecture, the partial convolutional layer, and the loss functions.

Dependencies

  • Python 3.6
  • Keras 2.2.4
  • Tensorflow 1.12

How to use this repository

The easiest way to try a few predictions with this algorithm is to go to www.fixmyphoto.ai, where I've deployed it as a serverless React application with AWS Lambda functions handling inference.

If you want to dig into the code, the primary implementations of the new PConv2D Keras layer and of the UNet-like architecture built from these partial convolutional layers are in libs/pconv_layer.py and libs/pconv_model.py, respectively; this is where the bulk of the implementation lives. Beyond this I've set up a series of Jupyter notebooks, which detail the steps I went through while implementing the network, namely:

Step 1: Creating random irregular masks
Step 2: Implementing and testing the implementation of the PConv2D layer
Step 3: Implementing and testing the UNet architecture with PConv2D layers
Step 4: Training & testing the final architecture on ImageNet
Step 5: Simplistic attempt at predicting arbitrary image sizes through image chunking

Pre-trained weights

I've ported the VGG16 weights from PyTorch to keras; this means the 1/255. pixel scaling can be used for the VGG16 network similarly to PyTorch.
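For context, PyTorch-style VGG16 preprocessing usually means scaling 8-bit pixels by 1/255 and then normalizing each channel with the torchvision ImageNet mean and standard deviation. The sketch below shows that arithmetic for a single pixel. The constants and the `preprocess_pixel` helper are assumptions for illustration, not taken from this repository; check libs/pconv_model.py for what is actually applied.

```python
# Standard torchvision ImageNet statistics (assumed here, not confirmed
# by this README).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def preprocess_pixel(rgb):
    """Scale one 8-bit RGB pixel to [0, 1], then normalize per channel."""
    scaled = [channel / 255.0 for channel in rgb]
    return [(value - mean) / std
            for value, mean, std in zip(scaled, IMAGENET_MEAN, IMAGENET_STD)]


white = preprocess_pixel((255, 255, 255))
black = preprocess_pixel((0, 0, 0))
```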

Training on your own dataset

You can either go directly to the Step 4 notebook, or alternatively use the CLI (make sure to download the converted VGG16 weights first):

python main.py \
    --name MyDataset \
    --train TRAINING_PATH \
    --validation VALIDATION_PATH \
    --test TEST_PATH \
    --vgg_path './data/logs/pytorch_to_keras_vgg16.h5'

Implementation details

The paper itself describes the implementation in detail, but I'll summarize the key points here.

Mask Creation

In the paper they create random irregular masks with a technique based on occlusion/dis-occlusion between two consecutive video frames. Instead, I've opted for a simple mask-generator function that uses OpenCV to draw random irregular shapes, which I then use as masks. Plugging in a different mask generation technique later should not be a problem, and I think the end results are pretty decent with this method as well.
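To give a rough idea of what such a generator does, here is a minimal sketch that builds an irregular mask from random-walk strokes. It deliberately uses only the Python standard library instead of OpenCV drawing calls, so it is not the repository's generator; the function name `random_mask` and all of its parameters are hypothetical.

```python
import random


def random_mask(height, width, strokes=5, max_steps=40, thickness=1, seed=None):
    """Return a height x width mask as nested lists: 1 = valid, 0 = hole."""
    rng = random.Random(seed)
    mask = [[1] * width for _ in range(height)]
    for _ in range(strokes):
        # Each stroke starts at a random pixel and then wanders randomly.
        y, x = rng.randrange(height), rng.randrange(width)
        for _ in range(rng.randint(1, max_steps)):
            # Punch a small square hole around the current position.
            for dy in range(-thickness, thickness + 1):
                for dx in range(-thickness, thickness + 1):
                    yy, xx = y + dy, x + dx
                    if 0 <= yy < height and 0 <= xx < width:
                        mask[yy][xx] = 0
            # Take one step, clamped to the image bounds.
            y = min(max(y + rng.choice((-1, 0, 1)), 0), height - 1)
            x = min(max(x + rng.choice((-1, 0, 1)), 0), width - 1)
    return mask


mask = random_mask(32, 32, seed=0)
hole_ratio = sum(row.count(0) for row in mask) / (32 * 32)
```

Passing a fixed `seed` makes the mask reproducible, which is handy when comparing predictions across training runs.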

Partial Convolution Layer

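A key element in this implementation is the partial convolutional layer. Given the convolutional filter weights W and the corresponding bias b, the following partial convolution (following the formulation in the paper) is applied instead of a normal convolution:

$$x' = \begin{cases} W^T (X \odot M) \dfrac{\text{sum}(\mathbf{1})}{\text{sum}(M)} + b, & \text{if } \text{sum}(M) > 0 \\ 0, & \text{otherwise} \end{cases}$$

where X holds the pixel values in the current convolution window, ⊙ is element-wise multiplication, M is the corresponding binary mask of 0s and 1s (0 marks a hole), and **1** is an all-ones array of the same shape as M. Importantly, the mask is also updated after each partial convolution: if the convolution was able to condition its output on at least one valid input, the mask is removed at that location, i.e.

$$m' = \begin{cases} 1, & \text{if } \text{sum}(M) > 0 \\ 0, & \text{otherwise} \end{cases}$$

The result of this is that with a sufficiently deep network, the mask will eventually be all ones (i.e. disappear).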
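The arithmetic can be sketched for a single channel in plain Python. This is only an illustration, not the PConv2D layer from libs/pconv_layer.py: it assumes stride 1, no padding, and the paper's rescaling by the ratio of window size to number of valid pixels. The name `partial_conv2d` is hypothetical.

```python
def partial_conv2d(image, mask, kernel, bias=0.0):
    """Single-channel partial convolution, stride 1, no padding.

    image, mask: 2D lists of equal size (mask: 1 = valid, 0 = hole).
    kernel: 2D list of filter weights W.
    Returns (output, updated_mask).
    """
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    window_size = kh * kw  # sum(1) over the window
    output = [[0.0] * out_w for _ in range(out_h)]
    new_mask = [[0] * out_w for _ in range(out_h)]
    for i in range(out_h):
        for j in range(out_w):
            valid = 0   # sum(M) over the window
            acc = 0.0   # W^T (X * M) over the window
            for di in range(kh):
                for dj in range(kw):
                    m = mask[i + di][j + dj]
                    valid += m
                    acc += kernel[di][dj] * image[i + di][j + dj] * m
            if valid > 0:
                # Rescale to compensate for the missing (hole) inputs.
                output[i][j] = acc * window_size / valid + bias
                new_mask[i][j] = 1
    return output, new_mask


# Valid pixels hold 2.0; hole pixels hold junk (9.0) that must be ignored.
image = [[2.0, 9.0, 9.0],
         [9.0, 9.0, 9.0],
         [9.0, 9.0, 2.0]]
mask = [[1, 0, 0],
        [0, 0, 0],
        [0, 0, 1]]
mean_kernel = [[0.25, 0.25],
               [0.25, 0.25]]

output, new_mask = partial_conv2d(image, mask, mean_kernel)
# output   -> [[2.0, 0.0], [0.0, 2.0]]
# new_mask -> [[1, 0], [0, 1]]
```

With a mean filter, every window that sees at least one valid pixel returns the mean of its valid pixels only, and the updated mask marks exactly those windows as valid.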

UNet Architecture

Specific details of the architecture can be found in the paper, but essentially it's based on a UNet-like structure, where all normal convolutional layers are replaced with partial convolutional layers, such that in all cases the image is passed through the network alongside the mask. The following provides an overview of the architecture.

Loss Function(s)

The loss function used in the paper is fairly involved, and the paper covers it in full. In short it includes:

  • Per-pixel losses for both masked and un-masked regions
  • Perceptual loss based on ImageNet pre-trained VGG-16 (pool1, pool2 and pool3 layers)
  • Style loss on VGG-16 features for both the predicted image and the composited image (non-hole pixels set to ground truth)
  • Total variation loss for a 1-pixel dilation of the hole region

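The loss terms are weighted as follows (weights as reported in the paper):

$$L_{total} = L_{valid} + 6 L_{hole} + 0.05 L_{perceptual} + 120 (L_{style_{out}} + L_{style_{comp}}) + 0.1 L_{tv}$$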
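As a rough sketch of how the pieces combine, the snippet below computes the two per-pixel terms on a tiny single-channel example and applies the weights reported in the paper (1, 6, 0.05, 120, 0.1). The perceptual, style and total variation terms are passed in as plain numbers because they need VGG-16 features; the function names are hypothetical and this is not the repository's Keras loss.

```python
def masked_l1(pred, target, mask, hole):
    """L1 error over hole (mask == 0) or valid (mask == 1) pixels,
    normalized by the total number of pixels."""
    total = 0.0
    count = 0
    for p_row, t_row, m_row in zip(pred, target, mask):
        for p, t, m in zip(p_row, t_row, m_row):
            weight = (1 - m) if hole else m
            total += weight * abs(p - t)
            count += 1
    return total / count


def total_loss(l_valid, l_hole, l_perceptual, l_style_out, l_style_comp, l_tv):
    """Weighted sum of the loss terms, using the weights from the paper."""
    return (l_valid + 6.0 * l_hole + 0.05 * l_perceptual
            + 120.0 * (l_style_out + l_style_comp) + 0.1 * l_tv)


pred = [[1.0, 0.0], [0.5, 0.5]]
target = [[1.0, 1.0], [0.5, 0.0]]
mask = [[1, 0], [1, 1]]  # one hole pixel in the top-right corner

l_valid = masked_l1(pred, target, mask, hole=False)  # 0.125
l_hole = masked_l1(pred, target, mask, hole=True)    # 0.25
loss = total_loss(l_valid, l_hole, 0.0, 0.0, 0.0, 0.0)  # 1.625
```

Note how errors inside the hole count six times as much as errors on valid pixels, which pushes the network to focus on the region it actually has to fill in.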

Training Procedure

The network was trained on ImageNet with a batch size of 1, and each epoch was specified to be 10,000 batches long. Training was performed with the Adam optimizer in two stages, because batch normalization presents an issue for the masked convolutions (its mean and variance are also computed over hole pixels).

Stage 1: Learning rate of 0.0001 for 50 epochs, with batch normalization enabled in all layers.

Stage 2: Learning rate of 0.00005 for 50 epochs, with batch normalization disabled in all encoding layers.
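Written out as a small config, the two stages look like the sketch below. The `Stage` type and its field names are hypothetical and do not correspond to actual options in main.py; the numbers are the ones stated above.

```python
from typing import NamedTuple


class Stage(NamedTuple):
    name: str
    learning_rate: float
    epochs: int
    encoder_batch_norm: bool  # False = batch norm disabled in encoding layers


BATCHES_PER_EPOCH = 10_000

SCHEDULE = [
    Stage("stage 1", 1e-4, 50, True),
    Stage("stage 2", 5e-5, 50, False),
]

# Total number of optimizer steps across both stages.
total_batches = sum(stage.epochs for stage in SCHEDULE) * BATCHES_PER_EPOCH
```

At a batch size of 1 this adds up to one million optimizer steps.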

Training time for the shown images was absolutely crazy long, but that is likely because of my poor personal setup. The few tests I've tried on a 1080Ti (with a batch size of 4) indicate that training could take around 10 days, as specified in the paper.

Owner
Mathias Gruber
Chief Data Scientist