Cross-Modal Contrastive Learning for Text-to-Image Generation

Overview


This repository hosts the open-source JAX implementation of XMC-GAN.

Setup instructions

Environment

Create and activate a virtualenv:

virtualenv venv
source venv/bin/activate

Add the XMC-GAN library to PYTHONPATH:

export PYTHONPATH=$PYTHONPATH:/home/path/to/xmcgan/root/
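To confirm that the export took effect, a Python process started from the same shell should list the XMC-GAN root on sys.path, because Python adds PYTHONPATH entries there at startup. The sketch below checks this. The helper name is our own, and the path is the placeholder from the command above, so replace it with your checkout location.

```python
import os
import sys


def is_on_sys_path(root: str) -> bool:
    """Return True if `root` (after resolving symlinks and trailing slashes) is on sys.path."""
    target = os.path.realpath(os.path.expanduser(root))
    # sys.path may contain "" for the current directory, so map it to os.curdir first.
    return any(os.path.realpath(entry or os.curdir) == target for entry in sys.path)


if __name__ == "__main__":
    # Placeholder path copied from the export command; adjust to your own checkout.
    print(is_on_sys_path("/home/path/to/xmcgan/root/"))
```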

JAX Installation

Note: Please follow the official JAX instructions for installing a GPU-compatible version of JAX.
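Once JAX is installed, it is worth checking that it actually picked up the GPU build before starting a long job. The following is a small sketch using `jax.default_backend()`; the wrapper function is hypothetical and simply returns None when JAX is not installed yet.

```python
import importlib.util


def jax_backend():
    """Return JAX's default backend name (e.g. "cpu" or "gpu"), or None if JAX is not installed."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax  # imported lazily so the check also works before JAX is installed

    return jax.default_backend()


if __name__ == "__main__":
    backend = jax_backend()
    if backend is None:
        print("JAX is not installed in this environment.")
    elif backend != "gpu":
        print(f"JAX is using the {backend!r} backend; a GPU build is expected for training.")
    else:
        print("JAX found a GPU backend.")
```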

Other Dependencies

After installing JAX, install the remaining dependencies with:

pip install -r requirements.txt

Preprocess COCO-2014

To create the training and eval data, first create a data directory. By default, the training scripts expect the data to be saved in data/ under the base directory.

mkdir data/

The TFRecords required for training and validation on COCO-2014 can be created by running a preprocessing script over the TFDS coco_captions dataset:

python preprocess_data.py

This may take a while to complete, as the script runs a pretrained BERT model over the captions and stores the embeddings. With a GPU, it takes about 2.5 hours for the training split and 1 hour for the validation split. When it finishes, the train and validation TFRecord files are saved in the data/ directory. The train files require around 58G of disk space, and the validation files require 29G.
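Since the two splits together need roughly 58G + 29G = 87G, a quick free-space check before launching the preprocessing can save a failed multi-hour run. This is a sketch using the standard library's `shutil.disk_usage`. Reading "G" as 10^9 bytes and the 5G safety margin are our assumptions, not values from the scripts.

```python
import shutil

# Approximate output sizes quoted above; "G" is read as 10**9 bytes (an assumption).
TRAIN_GB = 58
VAL_GB = 29


def enough_space_for_tfrecords(path: str = ".", margin_gb: float = 5.0) -> bool:
    """Check that the filesystem holding `path` can fit both splits plus a safety margin."""
    required_bytes = (TRAIN_GB + VAL_GB + margin_gb) * 10**9
    return shutil.disk_usage(path).free >= required_bytes


if __name__ == "__main__":
    print("Enough free space for the TFRecords:", enough_space_for_tfrecords("."))
```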

Note: If you run into an error related to TensorFlow gfile, one workaround is to edit site-packages/bert/tokenization.py and change tf.gfile.GFile to tf.io.gfile.GFile. For more details, refer to the following link.
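The gfile workaround is a plain text substitution, so it can also be scripted. The sketch below locates bert/tokenization.py through `importlib.util.find_spec` (which does not import the package) and rewrites `tf.gfile.GFile` to `tf.io.gfile.GFile`. Both helper names are hypothetical, and the sketch assumes the installed `bert` package is the one the note refers to and that you have write access to site-packages. The replacement is idempotent, because the new text does not contain the old pattern.

```python
import importlib.util
from pathlib import Path

OLD = "tf.gfile.GFile"
NEW = "tf.io.gfile.GFile"


def find_bert_tokenization():
    """Return the path of bert/tokenization.py, or None if the `bert` package is not installed."""
    spec = importlib.util.find_spec("bert")
    if spec is None or not spec.submodule_search_locations:
        return None
    candidate = Path(list(spec.submodule_search_locations)[0]) / "tokenization.py"
    return candidate if candidate.exists() else None


def patch_gfile(path) -> int:
    """Replace tf.gfile.GFile with tf.io.gfile.GFile in `path`; return the number of replacements."""
    path = Path(path)
    text = path.read_text()
    count = text.count(OLD)
    if count:
        path.write_text(text.replace(OLD, NEW))
    return count
```

For example, `path = find_bert_tokenization()` followed by `patch_gfile(path)` when `path` is not None. Consider backing up the file first.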

If you run into a tensorflow.python.framework.errors_impl.ResourceExhaustedError about having too many open files, you may have to increase the machine's open file limits. To do so, open the limit configuration file for editing:

vi /etc/security/limits.conf

and append the following lines to the end of the file:

*         hard    nofile      500000
*         soft    nofile      500000
root      hard    nofile      500000
root      soft    nofile      500000

You may have to adjust the limit values depending on your machine. You will need to log out and log back in for these values to take effect.
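After logging back in, you can confirm the new limits from Python with the standard `resource` module (Unix only). The sketch also raises the current process's soft limit up to its hard limit. This needs no root access but only helps if the hard limit is already high enough. It is an optional in-process complement to editing limits.conf, not something the repository's scripts do, and the helper names are our own.

```python
import resource  # Unix-only standard-library module


def nofile_limits():
    """Return the (soft, hard) limits on open file descriptors for this process."""
    return resource.getrlimit(resource.RLIMIT_NOFILE)


def raise_soft_nofile_limit():
    """Try to raise the soft open-file limit to the hard limit; return the resulting limits."""
    soft, hard = nofile_limits()
    if soft != hard:
        try:
            resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
        except (ValueError, OSError):
            pass  # some platforms (e.g. macOS) reject very large soft limits; keep the old value
    return nofile_limits()


if __name__ == "__main__":
    print("open-file limits (soft, hard):", nofile_limits())
```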

Download Pretrained ResNet

To train XMC-GAN, we need a network pretrained on ImageNet to extract image features. For this, we trained a ResNet-50 network. To download its weights, run:

gsutil cp gs://gresearch/xmcgan/resnet_pretrained.npy data/
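To sanity-check the download, you can load the file with NumPy. We have not verified the file's internal layout. The sketch assumes it is either a plain array or a single pickled object, such as a nested dict of arrays, stored as a 0-d object array (which is what `np.save` produces for a dict). The helper names are hypothetical. Loading pickles requires `allow_pickle=True`, so only do this with files you trust.

```python
from collections.abc import Mapping

import numpy as np


def load_weights(path):
    """Load a .npy file; unwrap it if it holds a single pickled object (e.g. a dict of arrays)."""
    arr = np.load(path, allow_pickle=True)  # object arrays need allow_pickle=True
    if arr.dtype == object and arr.ndim == 0:
        return arr.item()
    return arr


def count_params(tree) -> int:
    """Total number of array elements in a nested mapping/sequence of arrays."""
    if isinstance(tree, Mapping):
        return sum(count_params(v) for v in tree.values())
    if isinstance(tree, (list, tuple)):
        return sum(count_params(v) for v in tree)
    return int(np.size(tree))


def print_shapes(tree, prefix=""):
    """Print the key path and shape of every leaf in a nested mapping."""
    if isinstance(tree, Mapping):
        for key, value in tree.items():
            print_shapes(value, f"{prefix}/{key}")
    else:
        print(prefix or "/", np.shape(tree))
```

For example: `weights = load_weights("data/resnet_pretrained.npy")`, then `print_shapes(weights)` and `count_params(weights)`.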

If you would like to pretrain your own network on ImageNet, please refer to the official Flax ImageNet example.

Training

Start a training run by first editing train.sh to specify an appropriate work directory. By default, the script assumes that 8 GPUs are available and runs training on the first 7 of them, while test.sh assumes that evaluation runs on the last GPU. After configuring the training job, start an experiment by running it with bash:

mkdir exp
bash train.sh exp_name &> train.txt

Checkpoints and TensorBoard logs are saved in /path/to/exp/exp_name. By default, the configs/coco_xmc.py config is used, which trains on 128px images. This configuration fits a batch size of 8 on each GPU and achieves an FID of around 10.5 to 11.0 with the EMA weights. To reproduce the full 256px results from our paper, the full model needs to be trained on a 32-core Pod slice of Google Cloud TPU v3 devices.
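The default device layout can be written down as a small sketch. It assumes the split is done through the `CUDA_VISIBLE_DEVICES` environment variable, a common convention; we have not checked that train.sh and test.sh use it. It also assumes data parallelism with one replica per training GPU, under which a per-GPU batch of 8 on 7 GPUs gives an effective batch of 56. Both helper names are hypothetical.

```python
def split_gpus(num_gpus: int = 8, num_eval: int = 1):
    """Return (train, eval) CUDA_VISIBLE_DEVICES strings: first GPUs train, last ones evaluate."""
    if not 0 < num_eval < num_gpus:
        raise ValueError("need at least one training GPU and one evaluation GPU")
    ids = [str(i) for i in range(num_gpus)]
    return ",".join(ids[: num_gpus - num_eval]), ",".join(ids[num_gpus - num_eval :])


def global_batch_size(per_gpu: int = 8, num_train_gpus: int = 7) -> int:
    """Effective batch size under data parallelism (one replica per training GPU)."""
    return per_gpu * num_train_gpus


if __name__ == "__main__":
    train_devices, eval_devices = split_gpus()
    print(f"train: CUDA_VISIBLE_DEVICES={train_devices}")  # 0,1,2,3,4,5,6
    print(f"eval:  CUDA_VISIBLE_DEVICES={eval_devices}")   # 7
    print("global batch size:", global_batch_size())       # 56
```

If you have fewer GPUs, the same arithmetic shows how the effective batch size shrinks, for example `global_batch_size(8, 3)` gives 24.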

Evaluation

To run an evaluation job, update test.sh so that its settings match those used in train.sh, then run:

bash test.sh exp_name &> eval.txt

Every checkpoint in the work directory is evaluated for FID and Inception Score. If you can spare the GPUs, you can also run train.sh and test.sh in parallel, so that new checkpoints are evaluated continuously as they are saved into the work directory. Scores are written to TensorBoard and to eval.txt.
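For reference, the two metrics have standard closed forms. FID is the Fréchet distance between Gaussians fitted to two feature sets, ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2)). Inception Score is exp(E_x KL(p(y|x) || p(y))). The NumPy sketch below applies these formulas to precomputed statistics. It is not this repository's evaluation code. In practice the means and covariances come from Inception-v3 pool features and the probabilities from the Inception classifier. Many implementations compute the matrix square root with `scipy.linalg.sqrtm`, whereas this sketch uses an equivalent symmetric eigendecomposition.

```python
import numpy as np


def _sqrtm_psd(mat):
    """Symmetric square root of a positive semi-definite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(mat)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T


def frechet_distance(mu1, sigma1, mu2, sigma2) -> float:
    """Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2), i.e. the FID formula."""
    mu1, mu2 = np.asarray(mu1, float), np.asarray(mu2, float)
    sigma1, sigma2 = np.asarray(sigma1, float), np.asarray(sigma2, float)
    sqrt1 = _sqrtm_psd(sigma1)
    # Tr((S1 S2)^(1/2)) equals Tr((S1^(1/2) S2 S1^(1/2))^(1/2)); the latter matrix is symmetric PSD.
    inner = sqrt1 @ sigma2 @ sqrt1
    inner = (inner + inner.T) / 2.0  # remove tiny asymmetries from floating-point error
    tr_covmean = np.sum(np.sqrt(np.clip(np.linalg.eigvalsh(inner), 0.0, None)))
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


def inception_score(probs, eps: float = 1e-12) -> float:
    """exp(mean KL(p(y|x) || p(y))) for an (N, K) array of per-image class probabilities."""
    probs = np.asarray(probs, float)
    marginal = probs.mean(axis=0, keepdims=True)  # p(y), averaged over images
    kl = np.sum(probs * (np.log(probs + eps) - np.log(marginal + eps)), axis=1)
    return float(np.exp(kl.mean()))
```

Two quick sanity checks: shifting the mean by (3, 4) with identity covariances gives an FID of 25, and N confident predictions spread evenly over N classes give an Inception Score of N.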

TensorBoard

To start TensorBoard for monitoring training progress, run:

tensorboard --logdir /path/to/exp/exp_name

Citation

If you find this work useful, please consider citing:

@inproceedings{zhang2021cross,
  title={Cross-Modal Contrastive Learning for Text-to-Image Generation},
  author={Zhang, Han and Koh, Jing Yu and Baldridge, Jason and Lee, Honglak and Yang, Yinfei},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2021}
}

Disclaimer

Not an official Google product.

Owner
Google Research