Optimized code based on M2 for faster image captioning training

Overview

Transformer Captioning

This repository contains the code for Transformer-based image captioning. Based on meshed-memory-transformer, we further optimize the code for FASTER training without any accuracy decline.

Specifically, we optimize following aspects:

  • vocab: we pre-tokenize the dataset so there are no ' '(space token) in vocab or generated sentences.
  • Dataloader: we optimize speed of dataloader and achieve 2x~6x speed-up.
  • BeamSearch:
    • Make ops parallel in beam_search.py (e.g. loop gather -> parallel gather)
    • Use cheaper ops (e.g. torch.sort -> torch.topk)
    • Use faster and specialized functions instead of general ones
  • Self-critical Training
    • Compute Cider by index instead of raw text
    • Cache tf-idf vector of gts instead of computing it again and again
    • drop on-the-fly tokenization since it is too SLOW.
  • contiguous model parameter
  • other details...

speed-up result (1 GeForce 1080Ti GPU, num_workers=8, batch_size=50(XE)/100(SCST))

Training its/s Original Optimized Accelerate
XE 7.5 10.3 138%
SCST 0.6 1.3 204%
Dataloader its/s Original XE Optimized XE Accelerate Original SCST Optimized SCST Accelerate
batch size=50 12.5 52.5 320% 29.3 90.7 209%
batch size=100 5.5 33.5 510% 22.3 88.5 297%
batch size=150 3.7 25.4 580% 13.4 71.8 435%
batch size=200 2.7 20.1 650% 11.4 54.1 376%

Things I have tried but not useful

  • TorchText n-gram counter: slower than the original one.
  • nn.Module.MultiHeadAttention: slightly faster than original one.
  • GPU cider: very slow
  • BeamableMM: slower than the original

Environment setup

Clone the repository and create the m2release conda environment using the environment.yml file:

conda env create -f environment.yml
conda activate m2release

Then download spacy data by executing the following command:

python -m spacy download en

Note: Python 3.6 is required to run our code.

Data preparation

To run the code, annotations and detection features for the COCO dataset are needed. Please download the annotations file annotations.zip and extract it.

Detection features are computed with the code provided by [1]. To reproduce our result, please download the COCO features file coco_detections.hdf5 (~53.5 GB), in which detections of each image are stored under the <image_id>_features key. <image_id> is the id of each COCO image, without leading zeros (e.g. the <image_id> for COCO_val2014_000000037209.jpg is 37209), and each value should be a (N, 2048) tensor, where N is the number of detections.

REMEMBER to do pre-tokenize

python pre_tokenize.py

Evaluation

Run python test.py using the following arguments:

Argument Possible values
--batch_size Batch size (default: 10)
--workers Number of workers (default: 0)
--features_path Path to detection features file
--annotation_folder Path to folder with COCO annotations

Training procedure

Run python train.py using the following arguments:

Argument Possible values
--exp_name Experiment name
--batch_size Batch size (default: 10)
--workers Number of workers (default: 0)
--head Number of heads (default: 8)
--resume_last If used, the training will be resumed from the last checkpoint.
--resume_best If used, the training will be resumed from the best checkpoint.
--features_path Path to detection features file
--annotation_folder Path to folder with COCO annotations
--logs_folder Path folder for tensorboard logs (default: "tensorboard_logs")

For example, to train our model with the parameters used in our experiments, use

We recommend to use batch size=100 during SCST stage. Since it will accelerate convergence without obvious accuracy decline

python train.py --exp_name test --batch_size 50 --head 8 --features_path ~/datassd/coco_detections.hdf5 --annotation_folder annotation --workers 8 --rl_batch_size 100 --image_field FasterImageDetectionsField --model transformer --seed 118

References

Owner
lyricpoem
lyricpoem
Code for Subgraph Federated Learning with Missing Neighbor Generation (NeurIPS 2021)

To run the code Unzip the package to your local directory; Run 'pip install -r requirements.txt' to download required packages; Open file ~/nips_code/

32 Dec 26, 2022
Plover-tapey-tape: an alternative to Plover’s built-in paper tape

plover-tapey-tape plover-tapey-tape is an alternative to Plover’s built-in paper

7 May 29, 2022
Global Pooling, More than Meets the Eye: Position Information is Encoded Channel-Wise in CNNs, ICCV 2021

Global Pooling, More than Meets the Eye: Position Information is Encoded Channel-Wise in CNNs, ICCV 2021 Global Pooling, More than Meets the Eye: Posi

Md Amirul Islam 32 Apr 24, 2022
An implementation of Deep Graph Infomax (DGI) in PyTorch

DGI Deep Graph Infomax (Veličković et al., ICLR 2019): https://arxiv.org/abs/1809.10341 Overview Here we provide an implementation of Deep Graph Infom

Petar Veličković 491 Jan 03, 2023
A Broad Study on the Transferability of Visual Representations with Contrastive Learning

A Broad Study on the Transferability of Visual Representations with Contrastive Learning This repository contains code for the paper: A Broad Study on

Ashraful Islam 29 Nov 09, 2022
Improved Fitness Optimization Landscapes for Sequence Design

ReLSO Improved Fitness Optimization Landscapes for Sequence Design Description Citation How to run Training models Original data source Description In

Krishnaswamy Lab 44 Dec 20, 2022
CBREN: Convolutional Neural Networks for Constant Bit Rate Video Quality Enhancement

CBREN This is the Pytorch implementation for our IEEE TCSVT paper : CBREN: Convolutional Neural Networks for Constant Bit Rate Video Quality Enhanceme

Zhao Hengrun 3 Nov 04, 2022
BMW TechOffice MUNICH 148 Dec 21, 2022
Pretraining on Dynamic Graph Neural Networks

Pretraining on Dynamic Graph Neural Networks Our article is PT-DGNN and the code is modified based on GPT-GNN Requirements python 3.6 Ubuntu 18.04.5 L

7 Dec 17, 2022
Implement slightly different caffe-segnet in tensorflow

Tensorflow-SegNet Implement slightly different (see below for detail) SegNet in tensorflow, successfully trained segnet-basic in CamVid dataset. Due t

Tseng Kuan Lun 364 Oct 27, 2022
Code for "The Box Size Confidence Bias Harms Your Object Detector"

The Box Size Confidence Bias Harms Your Object Detector - Code Disclaimer: This repository is for research purposes only. It is designed to maintain r

Johannes G. 24 Dec 07, 2022
Systematic generalisation with group invariant predictions

Requirements are Python 3, TensorFlow v1.14, Numpy, Scipy, Scikit-Learn, Matplotlib, Pillow, Scikit-Image, h5py, tqdm. Experiments were run on V100 GPUs (16 and 32GB).

Faruk Ahmed 30 Dec 01, 2022
A quick recipe to learn all about Transformers

Transformers have accelerated the development of new techniques and models for natural language processing (NLP) tasks.

DAIR.AI 772 Dec 31, 2022
Multispectral Object Detection with Yolov5

Multispectral-Object-Detection Intro Official Code for Cross-Modality Fusion Transformer for Multispectral Object Detection. Multispectral Object Dete

Richard Fang 121 Jan 01, 2023
Split your patch similarly to `git add -p` but supporting multiple buckets

split-patch.py This is git add -p on steroids for patches. Given a my.patch you can run ./split-patch.py my.patch You can choose in which bucket to p

102 Oct 06, 2022
Official Pytorch Implementation of Relational Self-Attention: What's Missing in Attention for Video Understanding

Relational Self-Attention: What's Missing in Attention for Video Understanding This repository is the official implementation of "Relational Self-Atte

mandos 43 Dec 07, 2022
Detectron2-FC a fast construction platform of neural network algorithm based on detectron2

What is Detectron2-FC Detectron2-FC a fast construction platform of neural network algorithm based on detectron2. We have been working hard in two dir

董晋宗 9 Jun 06, 2022
Semantic Segmentation with SegFormer on Drone Dataset.

SegFormer_Segmentation Semantic Segmentation with SegFormer on Drone Dataset. You can check out the blog on Medium You can also try out the model with

Praneet 8 Oct 20, 2022
Source codes of CenterTrack++ in 2021 ICME Workshop on Big Surveillance Data Processing and Analysis

MOT Tracked object bounding box association (CenterTrack++) New association method based on CenterTrack. Two new branches (Tracked Size and IOU) are a

36 Oct 04, 2022
Implements VQGAN+CLIP for image and video generation, and style transfers, based on text and image prompts. Emphasis on ease-of-use, documentation, and smooth video creation.

VQGAN-CLIP-GENERATOR Overview This is a package (with available notebook) for running VQGAN+CLIP locally, with a focus on ease of use, good documentat

Ryan Hamilton 98 Dec 30, 2022