PyTorch implementation of the paper "StarEnhancer: Learning Real-Time and Style-Aware Image Enhancement" (ICCV 2021 Oral)

Overview

StarEnhancer

StarEnhancer: Learning Real-Time and Style-Aware Image Enhancement (ICCV 2021 Oral)

Abstract: Image enhancement is a subjective process whose targets vary with user preferences. In this paper, we propose a deep learning-based image enhancement method, dubbed StarEnhancer, that covers multiple tonal styles using only a single model. It can transform an image from one tonal style to another, even if that style is unseen. With a simple one-time setting, users can customize the model so that the enhanced images better match their aesthetics. To make the method more practical, we propose a well-designed enhancer that can process a 4K-resolution image at over 200 FPS while surpassing contemporaneous single-style image enhancement methods in terms of PSNR, SSIM, and LPIPS. Finally, our proposed enhancement method has good interactivity, which allows the user to fine-tune the enhanced image using intuitive options.

Getting started

Install

We tested the code with PyTorch 1.8.1 + CUDA 11.1 + cuDNN 8.0.5; nearby versions should also work.

pip install -r requirements.txt

We mainly train the model on four RTX 2080 Ti GPUs, but training with a smaller mini-batch size also works.

Prepare

You can generate your own dataset, or download the one we generate.

The final directory layout should look like this:

┬─ save_model
│   ├─ stylish.pth.tar
│   └─ ... (model & embedding)
└─ data
    ├─ train
    │   ├─ 01-Experts-A
    │   │   ├─ a0001.jpg
    │   │   └─ ... (id.jpg)
    │   └─ ... (style folder)
    ├─ valid
    │   └─ ... (style folder)
    └─ test
        └─ ... (style folder)
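A small sketch for sanity-checking this layout before training: for one split, it lists the image ids found in each style folder and the ids shared by all styles. The helper names (index_split, common_ids) are hypothetical and not part of the repository; the sketch assumes every image is a .jpg named by its id, as in the tree above.

```python
from pathlib import Path


def index_split(root, split):
    """Map each style folder under root/split to the sorted image ids it contains.

    Hypothetical helper; assumes the layout data/<split>/<style folder>/<id>.jpg.
    """
    split_dir = Path(root) / split
    index = {}
    for style_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
        # The file stem (e.g. "a0001") is treated as the image id
        index[style_dir.name] = sorted(f.stem for f in style_dir.glob("*.jpg"))
    return index


def common_ids(index):
    """Return the ids present in every style folder (hypothetical helper)."""
    id_sets = [set(ids) for ids in index.values()]
    return sorted(set.intersection(*id_sets)) if id_sets else []
```

For example, index_split("data", "train") should return one entry per style folder, such as 01-Experts-A; an id missing from common_ids points to an incomplete export.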

Download

Data and pretrained models are available on Google Drive.

Generate

  1. Download raw data from MIT-Adobe FiveK Dataset.
  2. Download the modified Lightroom database fivek.lrcat, and replace the original database with it.
  3. Export the dataset as JPEG with quality 100; see this issue for reference.
  4. Run generate_dataset.py in the data folder to generate the dataset.

Train

Firstly, train the style encoder:

python train_stylish.py

Secondly, fetch the style embedding for each sample in the train set:

python fetch_embedding.py

Lastly, train the curve encoder and mapping network:

python train_enhancer.py
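Each stage consumes the output of the previous one (the embeddings are fetched with the trained style encoder, and the enhancer is trained on those embeddings), so the three commands must run in order. A hypothetical wrapper, assuming it is launched from the repository root and that each script runs with its default arguments, might look like this:

```python
import subprocess
import sys

# The three training stages described above, in the order they must run
STAGES = ("train_stylish.py", "fetch_embedding.py", "train_enhancer.py")


def run_pipeline(stages=STAGES, dry_run=False):
    """Run each stage with the current interpreter, stopping at the first failure.

    Returns the list of commands (executed, or only planned when dry_run=True).
    """
    commands = [[sys.executable, script] for script in stages]
    for cmd in commands:
        if dry_run:
            print("would run:", " ".join(cmd))
            continue
        # check=True raises CalledProcessError, so a failed stage halts the pipeline
        subprocess.run(cmd, check=True)
    return commands
```

Calling run_pipeline(dry_run=True) first prints the planned commands without starting any training.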

Test

Just run:

python test.py
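PSNR is one of the metrics named in the abstract. As a rough reference for what that number measures, here is a minimal NumPy sketch of PSNR for images scaled to [0, 1]. It is not the code in test.py, which may differ in details such as color space or 8-bit rounding.

```python
import numpy as np


def psnr(output, target, max_val=1.0):
    """Peak signal-to-noise ratio in dB between two arrays of the same shape.

    Assumes pixel values lie in [0, max_val]; identical inputs give infinity.
    """
    output = np.asarray(output, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    mse = np.mean((output - target) ** 2)
    if mse == 0:
        return float("inf")
    return float(10.0 * np.log10(max_val ** 2 / mse))
```

A uniform error of 0.1 on every pixel gives an MSE of 0.01 and therefore a PSNR of 20 dB; higher is better.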

Computing LPIPS requires about 10 GB of GPU memory. If you run out of GPU memory (OOM), replace the following line

lpips_val = loss_fn_alex(output * 2 - 1, target_img * 2 - 1).item()

with

lpips_val = 0
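Instead of editing the line by hand, the LPIPS call could be wrapped so that it falls back to 0 only when an out-of-memory error actually occurs. This is a sketch; the helper name metric_or_fallback is hypothetical. As with the manual edit, a fallback of 0 makes the averaged LPIPS meaningless, so treat it as "not computed".

```python
def metric_or_fallback(compute, fallback=0.0):
    """Call compute(); if it fails with a CUDA out-of-memory error, return fallback.

    Hypothetical helper. Usage in the test loop might be:
        lpips_val = metric_or_fallback(
            lambda: loss_fn_alex(output * 2 - 1, target_img * 2 - 1).item())
    """
    try:
        return compute()
    except RuntimeError as err:
        # PyTorch reports CUDA OOM as a RuntimeError whose message contains
        # "out of memory" (newer releases raise torch.cuda.OutOfMemoryError,
        # a RuntimeError subclass). Any other error is re-raised unchanged.
        if "out of memory" not in str(err):
            raise
        return fallback
```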

Notes

Due to agreements, we are unable to release part of the source code. This repository provides a pure Python implementation for research use. It differs from the paper in the following ways:

  1. The repository uses a ResNet-18 without batch normalization (BN) as the curve encoder's backbone, whereas the paper uses a more lightweight model.
  2. The paper implements the color transform function in CUDA, whereas the repository implements it with torch.gather.
  3. The repository removes some tricks used in training lightweight models.
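To illustrate the gather-based approach in item 2, here is a sketch that assumes the color transform is a set of per-channel tone curves, each sampled at K evenly spaced input levels; the exact curve parameterization in the repository is not described here and may differ. It uses NumPy's take_along_axis, which for 2-D arrays along axis 1 behaves like torch.gather(curves, 1, idx). The function name apply_curves is hypothetical.

```python
import numpy as np


def apply_curves(image, curves):
    """Apply one tone curve per channel via table lookup with linear interpolation.

    image:  float array of shape (C, H, W) with values in [0, 1]
    curves: float array of shape (C, K), K >= 2, curve values at K evenly spaced inputs
    """
    c, h, w = image.shape
    k = curves.shape[1]
    if k < 2:
        raise ValueError("each curve needs at least two samples")
    pos = np.clip(image, 0.0, 1.0).reshape(c, -1) * (k - 1)  # continuous index into each curve
    lo = np.clip(np.floor(pos).astype(np.int64), 0, k - 2)    # left sample index per pixel
    frac = pos - lo                                            # interpolation weight in [0, 1]
    # Gather curves[ch, idx[ch, n]] for every pixel n (the torch.gather pattern)
    v_lo = np.take_along_axis(curves, lo, axis=1)
    v_hi = np.take_along_axis(curves, lo + 1, axis=1)
    return (v_lo + frac * (v_hi - v_lo)).reshape(c, h, w)
```

With identity curves (np.linspace(0, 1, K) for every channel) the output equals the input, and with 1 - np.linspace(0, 1, K) the image is inverted. A gather is memory-bound and launches one generic kernel per lookup, which is one plausible reason a dedicated CUDA kernel, as in the paper, is faster.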

Overall, this repository can achieve higher performance but runs slightly slower.

Comments
  • Multi-style, unpaired setting

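    Hello! In the multi-style, unpaired setting, could we swap the positions of source and target, pass the resulting output_A and output_B through the enhancer again to get recover_A and recover_B, and finally compute l1_loss(source, recover_A) and l1_loss(target, recover_B), as well as Triplet_loss(output_A, target, source) and Triplet_loss(output_B, source, target)?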

    import torch


    def train(train_loader, mapping, enhancer, criterion, optimizer, fe_model):
        # fe_model: the commenter's feature extractor (Feature_Extract_Model().cuda()),
        # built once outside this function; its parameters are frozen so only
        # mapping and enhancer are trained.
        # AverageMeter: running-average helper, assumed imported as in the training script.
        losses = AverageMeter()
        criterionTriplet = torch.nn.TripletMarginLoss(margin=1.0, p=2)
        fe_model.eval().requires_grad_(False)

        mapping.train()
        enhancer.train()

        for (source_img, source_center, target_img, target_center) in train_loader:
            source_img = source_img.cuda(non_blocking=True)
            source_center = source_center.cuda(non_blocking=True)
            target_img = target_img.cuda(non_blocking=True)
            target_center = target_center.cuda(non_blocking=True)

            style_A = mapping(source_center)
            style_B = mapping(target_center)

            # Translate each image into the other style (A -> B and B -> A)
            output_A = enhancer(source_img, style_A, style_B)
            output_B = enhancer(target_img, style_B, style_A)
            # Translate back to the original style for the cycle-consistency (L1) term
            recover_A = enhancer(output_A, style_B, style_A)
            recover_B = enhancer(output_B, style_A, style_B)

            source_img_feature = fe_model(source_img)
            target_img_feature = fe_model(target_img)
            output_A_feature = fe_model(output_A)
            output_B_feature = fe_model(output_B)

            loss_l1 = criterion(recover_A, source_img) + criterion(recover_B, target_img)
            # TripletMarginLoss(anchor, positive, negative): pull each output toward
            # images of its new style and away from its own input style
            loss_triplet = criterionTriplet(output_B_feature, source_img_feature, target_img_feature) + \
                           criterionTriplet(output_A_feature, target_img_feature, source_img_feature)
            loss = loss_l1 + loss_triplet

            # Use the actual batch size (the last batch may be smaller than configured)
            losses.update(loss.item(), source_img.size(0))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return losses.avg
    
    opened by jxust01 4
  • Questions about dataset preparation

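    Hello, I'd like to run your project on my own data. I currently have one set of input/output pairs. How were the remaining four effects among A-E in the training data generated? Can this target-effect data be unpaired? If there is only one style, can I copy the same data into all of the A-E targets? The single-style training in train_enhancer.py needs an embeddings.npy file; is this file required for single-style training?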

    opened by zener90818 4
  • Dataset processing

    Hello, I couldn't find the "(default) input with ExpertC" mentioned in the DeepUPE issue in the fivek.lrcat you provided. For the single-style experiments, is the input "InputAsShotZeroed" or "(Q)InputZeroed with ExpertC WhiteBalance"?

    opened by madfff 2
  • Configure Renovate

    Welcome to Renovate! This is an onboarding PR to help you understand and configure settings before regular Pull Requests begin.

    🚦 To activate Renovate, merge this Pull Request. To disable Renovate, simply close this Pull Request unmerged.


    Detected Package Files

    • requirements.txt (pip_requirements)

    Configuration Summary

    Based on the default config's presets, Renovate will:

    • Start dependency updates only once this onboarding PR is merged
    • Enable Renovate Dependency Dashboard creation
    • If semantic commits detected, use semantic commit type fix for dependencies and chore for all others
    • Ignore node_modules, bower_components, vendor and various test/tests directories
    • Autodetect whether to pin dependencies or maintain ranges
    • Rate limit PR creation to a maximum of two per hour
    • Limit to maximum 20 open PRs at any time
    • Group known monorepo packages together
    • Use curated list of recommended non-monorepo package groupings
    • Fix some problems with very old Maven commons versions
    • Ignore spring cloud 1.x releases
    • Ignore http4s digest-based 1.x milestones
    • Use node versioning for @types/node
    • Limit concurrent requests to reduce load on Repology servers until we can fix this properly, see issue 10133

    🔡 Would you like to change the way Renovate is upgrading your dependencies? Simply edit the renovate.json in this branch with your custom config and the list of Pull Requests in the "What to Expect" section below will be updated the next time Renovate runs.


    What to Expect

    With your current configuration, Renovate will create 1 Pull Request:

    Pin dependency torch to ==1.10.0
    • Schedule: ["at any time"]
    • Branch name: renovate/pin-dependencies
    • Merge into: main
    • Pin torch to ==1.10.0

    ❓ Got questions? Check out Renovate's Docs, particularly the Getting Started section. If you need any further assistance then you can also request help here.


    This PR has been generated by WhiteSource Renovate. View repository job log here.

    opened by renovate[bot] 1
  • The results are not the same as the paper

    I am the author.

    Some peers have emailed me because the performance of the open-source model does not agree with the results in the paper. As stated in the README, the released model is not the model from the paper, but its performance is similar. The exact results should be PSNR 25.41, SSIM 0.942, and LPIPS 0.085.

    If your results differ, your JPEG codec may be different, which depends on the OpenCV version and how it was installed.

    You can uninstall OpenCV (whether it was installed with pip or conda) and reinstall it with pip (it must be pip, because conda installs a different JPEG codec):

    pip install opencv-python==4.5.5.62
    
    opened by IDKiro 0
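Before comparing results against the numbers in the author's note, it can help to confirm which OpenCV build pip has recorded. A small sketch using the standard library's importlib.metadata (Python 3.8+); the function name check_opencv is hypothetical. It only inspects the opencv-python distribution, so a conda build may not be registered under that name, in which case it reports None.

```python
from importlib.metadata import PackageNotFoundError, version


def check_opencv(expected="4.5.5.62", dist="opencv-python"):
    """Return (matches_expected, installed_version) for a pip-style distribution.

    installed_version is None when no metadata for `dist` is found.
    """
    try:
        installed = version(dist)
    except PackageNotFoundError:
        return False, None
    return installed == expected, installed
```

If this returns (False, None) while cv2 still imports, OpenCV was most likely installed another way (for example via conda), which is the situation the note above warns about.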
Owner
IDKiro