PyTorch implementation of paper "StarEnhancer: Learning Real-Time and Style-Aware Image Enhancement" (ICCV 2021 Oral)

Overview

StarEnhancer

StarEnhancer: Learning Real-Time and Style-Aware Image Enhancement (ICCV 2021 Oral)

Abstract: Image enhancement is a subjective process whose targets vary with user preferences. In this paper, we propose a deep learning-based image enhancement method covering multiple tonal styles using only a single model dubbed StarEnhancer. It can transform an image from one tonal style to another, even if that style is unseen. With a simple one-time setting, users can customize the model to make the enhanced images more in line with their aesthetics. To make the method more practical, we propose a well-designed enhancer that can process a 4K-resolution image over 200 FPS but surpasses the contemporaneous single style image enhancement methods in terms of PSNR, SSIM, and LPIPS. Finally, our proposed enhancement method has good interactability, which allows the user to fine-tune the enhanced image using intuitive options.

StarEnhancer

Getting started

Install

We test the code on PyTorch 1.8.1 + CUDA 11.1 + cuDNN 8.0.5, and close versions also work fine.

pip install -r requirements.txt

We mainly train the model on RTX 2080Ti * 4, but a smaller mini batch size can also work.

Prepare

You can generate your own dataset, or download the one we generate.

The final file path should be the same as the following:

┬─ save_model
│   ├─ stylish.pth.tar
│   └─ ... (model & embedding)
└─ data
    ├─ train
    │   ├─ 01-Experts-A
    │   │   ├─ a0001.jpg
    │   │   └─ ... (id.jpg)
    │   └─ ... (style folder)
    ├─ valid
    │   └─ ... (style folder)
    └─ test
        └─ ... (style folder)

Download

Data and pretrained models are available on GoogleDrive.

Generate

  1. Download raw data from MIT-Adobe FiveK Dataset.
  2. Download the modified Lightroom database fivek.lrcat, and replace the original database with it.
  3. Generate dataset in JPEG format with quality 100, which can refer to this issue.
  4. Run generate_dataset.py in data folder to generate dataset.

Train

Firstly, train the style encoder:

python train_stylish.py

Secondly, fetch the style embedding for each sample in the train set:

python fetch_embedding.py

Lastly, train the curve encoder and mapping network:

python train_enhancer.py

Test

Just run:

python test.py

Testing LPIPS requires about 10 GB GPU memory, and if an OOM occurs, replace the following lines

lpips_val = loss_fn_alex(output * 2 - 1, target_img * 2 - 1).item()

with

lpips_val = 0

Notes

Due to agreements, we are unable to release part of the source code. This repository provides a pure python implementation for research use. There are some differences between the repository and the paper as follows:

  1. The repository uses a ResNet-18 w/o BN as the curve encoder's backbone, and the paper uses a more lightweight model.
  2. The paper uses CUDA to implement the color transform function, and the repository uses torch.gather to implement it.
  3. The repository removes some tricks used in training lightweight models.

Overall, this repository can achieve higher performance, but will be slightly slower.

Comments
  • Multi-style, unpaired setting

    Multi-style, unpaired setting

    您好,在多风格非配对图场景,能否交换source和target的位置,并将得到的output_A和output_B进一步经过enhancer,得到recover_A和recover_B。最后计算l1_loss(source, recover_A)和l1_loss(target, recover_B)及Triplet_loss(output_A,target, source) 和 Triplet_loss(output_B,source,target)

    def train(train_loader, mapping, enhancer, criterion, optimizer):
        losses = AverageMeter()
        criterionTriplet = torch.nn.TripletMarginLoss(margin=1.0, p=2)
        FEModel = Feature_Extract_Model().cuda()
    
        mapping.train()
        enhancer.train()
    
        for (source_img, source_center, target_img, target_center) in train_loader:
            source_img = source_img.cuda(non_blocking=True)
            source_center = source_center.cuda(non_blocking=True)
            target_img = target_img.cuda(non_blocking=True)
            target_center = target_center.cuda(non_blocking=True)
    
            style_A = mapping(source_center)
            style_B = mapping(target_center)
    
            output_A = enhancer(source_img, style_A, style_B)
            output_B = enhancer(target_img, style_B, style_A)
            recoverA = enhancer(output_A, style_B, style_A)
            recoverB = enhancer(output_B, style_A, style_B)
    
            source_img_feature = FEModel(source_img)
            target_img_feature = FEModel(target_img)
            output_A_feature = FEModel(output_A)
            output_B_feature = FEModel(output_B)
    
            loss_l1 = criterion(recoverA, source_img) + criterion(recoverB, target_img)
            loss_triplet = criterionTriplet(output_B_feature, source_img_feature, target_img_feature) + \
                           criterionTriplet(output_A_feature, target_img_feature, source_img_feature)
            loss = loss_l1 + loss_triplet
    
            losses.update(loss.item(), args.t_batch_size)
    
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
        return losses.avg
    
    opened by jxust01 4
  • Questions about dataset preparation

    Questions about dataset preparation

    您好,我想用您的工程跑一下自己的数据,现在有输入,输出一组数据对,训练数据里面A-E剩下的4种效果是怎样生成的呢,这些目标效果数据能否是非成对的呢?如果只有一种风格,能否A-E目标效果都拷贝成一样的数据呢,在train_enhancer.py所训练的单风格脚本是需要embeddings.npy文件,这个文件在单风格训练时是必须的吗

    opened by zener90818 4
  • Dataset processing

    Dataset processing

    你好,我在您提供的fivek.lrcat没找到 DeepUPE issue里的"(default) input with ExpertC"。请问单风格实验的输入是下图中的“InputAsShotZeroed”还是“(Q)InputZeroed with ExpertC WhiteBalance” image

    opened by madfff 2
  • Configure Renovate

    Configure Renovate

    WhiteSource Renovate

    Welcome to Renovate! This is an onboarding PR to help you understand and configure settings before regular Pull Requests begin.

    🚦 To activate Renovate, merge this Pull Request. To disable Renovate, simply close this Pull Request unmerged.


    Detected Package Files

    • requirements.txt (pip_requirements)

    Configuration Summary

    Based on the default config's presets, Renovate will:

    • Start dependency updates only once this onboarding PR is merged
    • Enable Renovate Dependency Dashboard creation
    • If semantic commits detected, use semantic commit type fix for dependencies and chore for all others
    • Ignore node_modules, bower_components, vendor and various test/tests directories
    • Autodetect whether to pin dependencies or maintain ranges
    • Rate limit PR creation to a maximum of two per hour
    • Limit to maximum 20 open PRs at any time
    • Group known monorepo packages together
    • Use curated list of recommended non-monorepo package groupings
    • Fix some problems with very old Maven commons versions
    • Ignore spring cloud 1.x releases
    • Ignore http4s digest-based 1.x milestones
    • Use node versioning for @types/node
    • Limit concurrent requests to reduce load on Repology servers until we can fix this properly, see issue 10133

    🔡 Would you like to change the way Renovate is upgrading your dependencies? Simply edit the renovate.json in this branch with your custom config and the list of Pull Requests in the "What to Expect" section below will be updated the next time Renovate runs.


    What to Expect

    With your current configuration, Renovate will create 1 Pull Request:

    Pin dependency torch to ==1.10.0
    • Schedule: ["at any time"]
    • Branch name: renovate/pin-dependencies
    • Merge into: main
    • Pin torch to ==1.10.0

    ❓ Got questions? Check out Renovate's Docs, particularly the Getting Started section. If you need any further assistance then you can also request help here.


    This PR has been generated by WhiteSource Renovate. View repository job log here.

    opened by renovate[bot] 1
  • The results are not the same as the paper

    The results are not the same as the paper

    I am the author.

    Some peers have emailed me asking about the performance of the open source model that does not agree with the results in the paper. As stated in the README, the model is not the model of the paper, but the performance is similar. The exact result should be: PSNR: 25.41, SSIM: 0.942, LPIPS: 0.085

    If you find that your result is not this, then it may be that the JPEG codec is different, which is related to the version of opencv and how it is installed.

    You can uninstall your opencv (either with pip or conda) and reinstall it using pip (it must be pip, because conda installs a different JPEG codec):

    pip install opencv-python==4.5.5.62​
    
    opened by IDKiro 0
Owner
IDKiro
Stroll in the abyss
IDKiro
TagLab: an image segmentation tool oriented to marine data analysis

TagLab: an image segmentation tool oriented to marine data analysis TagLab was created to support the activity of annotation and extraction of statist

Visual Computing Lab - ISTI - CNR 49 Dec 29, 2022
Uncertain natural language inference

Uncertain Natural Language Inference This repository hosts the code for the following paper: Tongfei Chen*, Zhengping Jiang*, Adam Poliak, Keisuke Sak

Tongfei Chen 14 Sep 01, 2022
Code for Private Recommender Systems: How Can Users Build Their Own Fair Recommender Systems without Log Data? (SDM 2022)

Private Recommender Systems: How Can Users Build Their Own Fair Recommender Systems without Log Data? (SDM 2022) We consider how a user of a web servi

joisino 20 Aug 21, 2022
PyTorch code for JEREX: Joint Entity-Level Relation Extractor

JEREX: "Joint Entity-Level Relation Extractor" PyTorch code for JEREX: "Joint Entity-Level Relation Extractor". For a description of the model and exp

LAVIS - NLP Working Group 50 Dec 01, 2022
Official code of the paper "Expanding Low-Density Latent Regions for Open-Set Object Detection" (CVPR 2022)

OpenDet Expanding Low-Density Latent Regions for Open-Set Object Detection (CVPR2022) Jiaming Han, Yuqiang Ren, Jian Ding, Xingjia Pan, Ke Yan, Gui-So

csuhan 64 Jan 07, 2023
TorchDistiller - a collection of the open source pytorch code for knowledge distillation, especially for the perception tasks, including semantic segmentation, depth estimation, object detection and instance segmentation.

This project is a collection of the open source pytorch code for knowledge distillation, especially for the perception tasks, including semantic segmentation, depth estimation, object detection and i

yifan liu 147 Dec 03, 2022
imbalanced-DL: Deep Imbalanced Learning in Python

imbalanced-DL: Deep Imbalanced Learning in Python Overview imbalanced-DL (imported as imbalanceddl) is a Python package designed to make deep imbalanc

NTUCSIE CLLab 19 Dec 28, 2022
Generative Models as a Data Source for Multiview Representation Learning

GenRep Project Page | Paper Generative Models as a Data Source for Multiview Representation Learning Ali Jahanian, Xavier Puig, Yonglong Tian, Phillip

Ali 81 Dec 03, 2022
Self-Supervised Monocular DepthEstimation with Internal Feature Fusion(arXiv), BMVC2021

DIFFNet This repo is for Self-Supervised Monocular Depth Estimation with Internal Feature Fusion(arXiv), BMVC2021 A new backbone for self-supervised d

Hang 94 Dec 25, 2022
ArtEmis: Affective Language for Art

ArtEmis: Affective Language for Art Created by Panos Achlioptas, Maks Ovsjanikov, Kilichbek Haydarov, Mohamed Elhoseiny, Leonidas J. Guibas Introducti

Panos 268 Dec 12, 2022
🦕 NanoSaur is a little tracked robot ROS2 enabled, made for an NVIDIA Jetson Nano

🦕 nanosaur NanoSaur is a little tracked robot ROS2 enabled, made for an NVIDIA Jetson Nano Website: nanosaur.ai Do you need an help? Discord For tech

NanoSaur 162 Dec 09, 2022
A Python Library for Graph Outlier Detection (Anomaly Detection)

PyGOD is a Python library for graph outlier detection (anomaly detection). This exciting yet challenging field has many key applications, e.g., detect

PyGOD Team 757 Jan 04, 2023
Start-to-finish tutorial for interactive music co-creation in PyTorch and Tensorflow.js

Start-to-finish tutorial for interactive music co-creation in PyTorch and Tensorflow.js

Chris Donahue 98 Dec 14, 2022
Romanian Automatic Speech Recognition from the ROBIN project

RobinASR This repository contains Robin's Automatic Speech Recognition (RobinASR) for the Romanian language based on the DeepSpeech2 architecture, tog

RACAI 10 Jan 01, 2023
Anchor-free Oriented Proposal Generator for Object Detection

Anchor-free Oriented Proposal Generator for Object Detection Gong Cheng, Jiabao Wang, Ke Li, Xingxing Xie, Chunbo Lang, Yanqing Yao, Junwei Han, Intro

jbwang1997 56 Nov 15, 2022
[NeurIPS 2021] Large Scale Learning on Non-Homophilous Graphs: New Benchmarks and Strong Simple Methods

Large Scale Learning on Non-Homophilous Graphs: New Benchmarks and Strong Simple Methods Large Scale Learning on Non-Homophilous Graphs: New Benchmark

60 Jan 03, 2023
Lightweight Salient Object Detection in Optical Remote Sensing Images via Feature Correlation

CorrNet This project provides the code and results for 'Lightweight Salient Object Detection in Optical Remote Sensing Images via Feature Correlation'

Gongyang Li 13 Nov 03, 2022
Deep Two-View Structure-from-Motion Revisited

Deep Two-View Structure-from-Motion Revisited This repository provides the code for our CVPR 2021 paper Deep Two-View Structure-from-Motion Revisited.

Jianyuan Wang 145 Jan 06, 2023
Code of TIP2021 Paper《SFace: Sigmoid-Constrained Hypersphere Loss for Robust Face Recognition》. We provide both MxNet and Pytorch versions.

SFace Code of TIP2021 Paper 《SFace: Sigmoid-Constrained Hypersphere Loss for Robust Face Recognition》. We provide both MxNet, PyTorch and Jittor versi

Zhong Yaoyao 47 Nov 25, 2022
An alarm clock coded in Python 3 with Tkinter

Tkinter-Alarm-Clock An alarm clock coded in Python 3 with Tkinter. Run python3 Tkinter Alarm Clock.py in a terminal if you have Python 3. NOTE: This p

CodeMaster7000 1 Dec 25, 2021