Cross-Modal Contrastive Learning for Text-to-Image Generation

Overview

Cross-Modal Contrastive Learning for Text-to-Image Generation

This repository hosts the open source JAX implementation of XMC-GAN.

Setup instructions

Environment

Set up virtualenv, and install required libraries:

virtualenv venv
source venv/bin/activate

Add the XMC-GAN library to PYTHONPATH:

export PYTHONPATH=$PYTHONPATH:/home/path/to/xmcgan/root/

JAX Installation

Note: Please follow the official JAX instructions for installing a GPU compatible version of JAX.

Other Dependencies

After installing JAX, install the remaining dependencies with:

pip install -r requirements.txt

Preprocess COCO-2014

To create the training and eval data, first start a directory. By default, the training scripts expect to save results in data/ in the base directory.

mkdir data/

The TFRecords required for training and validation on COCO-2014 can be created by running a preprocessing script over the TFDS coco_captions dataset:

python preprocess_data.py

This may take a while to complete, as it runs a pretrained BERT model over the captions and stores the embeddings. With a GPU, it runs in about 2.5 hours for train, and 1 hour for validation. Once it is done, the train and validation tfrecords files will be saved in the data/ directory. The train files require around 58G of disk space, and the validation requires 29G.

Note: If you run into an error related to TensorFlow gfile, one workaround is to edit site-packages/bert/tokenization.py and change tf.gfile.GFile to tf.io.gfile.GFile. For more details, refer to the following link.

If you run into a tensorflow.python.framework.errors_impl.ResourceExhaustedError about having too many open files, you may have to increase the machine's open file limits. To do so, open the limit configuration file for editing:

vi /etc/security/limits.conf

and append the following lines to the end of the file:

*         hard    nofile      500000
*         soft    nofile      500000
root      hard    nofile      500000
root      soft    nofile      500000

You may have to adjust the limit values depending on your machine. You will need to logout and login to your machine for these values to take effect.

Download Pretrained ResNet

To train XMC-GAN, we need a network pretrained on ImageNet to extract features. For our purposes, we train a ResNet-50 network for this. To download the weights, run:

gsutil cp gs://gresearch/xmcgan/resnet_pretrained.npy data/

If you would like to pretrain your own network on ImageNet, please refer to the official Flax ImageNet example.

Training

Start a training run, by first editing train.sh to specify an appropriate work directory. By default, the script assumes that 8 GPUs are available, and runs training on the first 7 GPUs, while test.sh assumes testing will run on the last GPU. After configuring the training job, start an experiment by running it on bash:

mkdir exp
bash train.sh exp_name &> train.txt

Checkpoints and Tensorboard logs will be saved in /path/to/exp/exp_name. By default, the configs/coco_xmc.py config is used, which runs an experiment for 128px images. This is able to accommodate a batch size of 8 on each GPU, and achieves an FID of around 10.5 - 11.0 with the EMA weights. To reproduce the full results on 256px images in our paper, the full model needs to be run using a 32-core Pod slice of Google Cloud TPU v3 devices.

Evaluation

To run an evaluation job, update test.sh with the correct settings used in the training script. Then, execute

bash test.sh exp_name &> eval.txt

to start an evaluation job. All checkpoints in workdir will be evaluated for FID and Inception Score. If you can spare the GPUs, you can also run train.sh and test.sh in parallel, which will continuously evaluate new checkpoints saved into the work directory. Scores will be written to Tensorboard and output to eval.txt.

Tensorboard

To start a Tensorboard for monitoring training progress, run:

tensorboard --logdir /path/to/exp/exp_name

Citation

If you find this work useful, please consider citing:

@inproceedings{zhang2021cross,
  title={Cross-Modal Contrastive Learning for Text-to-Image Generation},
  author={Zhang, Han and Koh, Jing Yu and Baldridge, Jason and Lee, Honglak and Yang, Yinfei},
  journal={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2021}
}

Disclaimer

Not an official Google product.

Owner
Google Research
Google Research
Multi-Task Learning as a Bargaining Game

Nash-MTL Official implementation of "Multi-Task Learning as a Bargaining Game". Setup environment conda create -n nashmtl python=3.9.7 conda activate

Aviv Navon 87 Dec 26, 2022
Python interface for SmartRF Sniffer 2 Firmware

#TI SmartRF Packet Sniffer 2 Python Interface TI Makes available a nice packet sniffer firmware, which interfaces to Wireshark. You can see this proje

Colin O'Flynn 3 May 18, 2021
Official implementation of "SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers"

SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers Figure 1: Performance of SegFormer-B0 to SegFormer-B5. Project page

NVIDIA Research Projects 1.4k Dec 31, 2022
Quasi-Dense Similarity Learning for Multiple Object Tracking, CVPR 2021 (Oral)

Quasi-Dense Tracking This is the offical implementation of paper Quasi-Dense Similarity Learning for Multiple Object Tracking. We present a trailer th

ETH VIS Research Group 327 Dec 27, 2022
Project page of the paper 'Analyzing Perception-Distortion Tradeoff using Enhanced Perceptual Super-resolution Network' (ECCVW 2018)

EPSR (Enhanced Perceptual Super-resolution Network) paper This repo provides the test code, pretrained models, and results on benchmark datasets of ou

Subeesh Vasu 78 Nov 19, 2022
PyTorch-Multi-Style-Transfer - Neural Style and MSG-Net

PyTorch-Style-Transfer This repo provides PyTorch Implementation of MSG-Net (ours) and Neural Style (Gatys et al. CVPR 2016), which has been included

Hang Zhang 906 Jan 04, 2023
Recurrent Variational Autoencoder that generates sequential data implemented with pytorch

Pytorch Recurrent Variational Autoencoder Model: This is the implementation of Samuel Bowman's Generating Sentences from a Continuous Space with Kim's

Daniil Gavrilov 347 Nov 14, 2022
Explainer for black box models that predict molecule properties

Explaining why that molecule exmol is a package to explain black-box predictions of molecules. The package uses model agnostic explanations to help us

White Laboratory 172 Dec 19, 2022
Resources for the Ki testnet challenge

Ki Testnet Challenge This repository hosts ki-testnet-challenge. A set of scripts and resources to be used for the Ki Testnet Challenge What is the te

Ki Foundation 23 Aug 08, 2022
The official repository for "Score Transformer: Generating Musical Scores from Note-level Representation" (MMAsia '21)

Score Transformer This is the official repository for "Score Transformer": Score Transformer: Generating Musical Scores from Note-level Representation

22 Dec 22, 2022
IJCAI2020 & IJCV 2020 :city_sunrise: Unsupervised Scene Adaptation with Memory Regularization in vivo

Seg_Uncertainty In this repo, we provide the code for the two papers, i.e., MRNet:Unsupervised Scene Adaptation with Memory Regularization in vivo, IJ

Zhedong Zheng 348 Jan 05, 2023
Answering Open-Domain Questions of Varying Reasoning Steps from Text

This repository contains the authors' implementation of the Iterative Retriever, Reader, and Reranker (IRRR) model in the EMNLP 2021 paper "Answering Open-Domain Questions of Varying Reasoning Steps

26 Dec 22, 2022
NitroFE is a Python feature engineering engine which provides a variety of modules designed to internally save past dependent values for providing continuous calculation.

NitroFE is a Python feature engineering engine which provides a variety of modules designed to internally save past dependent values for providing continuous calculation.

100 Sep 28, 2022
Everything you want about DP-Based Federated Learning, including Papers and Code. (Mechanism: Laplace or Gaussian, Dataset: femnist, shakespeare, mnist, cifar-10 and fashion-mnist. )

Differential Privacy (DP) Based Federated Learning (FL) Everything about DP-based FL you need is here. (所有你需要的DP-based FL的信息都在这里) Code Tip: the code o

wenzhu 83 Dec 24, 2022
Deep Compression for Dense Point Cloud Maps.

DEPOCO This repository implements the algorithms described in our paper Deep Compression for Dense Point Cloud Maps. How to get started (using Docker)

Photogrammetry & Robotics Bonn 67 Dec 06, 2022
Easy to use Python camera interface for NVIDIA Jetson

JetCam JetCam is an easy to use Python camera interface for NVIDIA Jetson. Works with various USB and CSI cameras using Jetson's Accelerated GStreamer

NVIDIA AI IOT 358 Jan 02, 2023
Pretrained models for Jax/Flax: StyleGAN2, GPT2, VGG, ResNet.

Pretrained models for Jax/Flax: StyleGAN2, GPT2, VGG, ResNet.

Matthias Wright 169 Dec 26, 2022
GNN-based Recommendation Benchma

GRecX A Fair Benchmark for GNN-based Recommendation Preliminary Comparison DiffNet-Yelp dataset (featureless) Algo 73 Oct 17, 2022

An OpenAI-Gym Package for Training and Testing Reinforcement Learning algorithms with OpenSim Models

Authors: Utkarsh A. Mishra and Dr. Dimitar Stanev Advisors: Dr. Dimitar Stanev and Prof. Auke Ijspeert, Biorobotics Laboratory (BioRob), EPFL Video Pl

Utkarsh Mishra 16 Dec 13, 2022
Reproduced Code for Image Forgery Detection papers.

Image Forgery Detection With over 4.5 billion active internet users, the amount of multimedia content being shared every day has surpassed everyone’s

Umar Masud 15 Dec 06, 2022