Image Captioning using CNN and Transformers

Overview

Image-Captioning

Keras/Tensorflow Image Captioning application using CNN and Transformer as encoder/decoder.
In particulary, the architecture consists of three models:

  1. A CNN: used to extract the image features. In this application, it used EfficientNetB0 pre-trained on imagenet.
  2. A TransformerEncoder: the extracted image features are then passed to a Transformer based encoder that generates a new representation of the inputs.
  3. A TransformerDecoder: this model takes the encoder output and the text data sequence as inputs and tries to learn to generate the caption.

Dataset

The model has been trained on 2014 Train/Val COCO dataset. You can download the dataset here. Note that test images are not required for this code to work.

Original dataset has 82783 train images and 40504 validation images; for each image there is a number of captions between 1 and 6. I have preprocessing the dataset per to keep only images that have exactly 5 captions. In fact, the model has been trained to ensure that 5 captions are assigned for each image. After this filtering, the final dataset has 68363 train images and 33432 validation images.
Finally, I serialized the dataset into two json files which you can find in:

COCO_dataset/captions_mapping_train.json
COCO_dataset/captions_mapping_valid.json

Each element in the captions_mapping_train.json file has such a structure :
"COCO_dataset/train2014/COCO_train2014_000000318556.jpg": ["caption1", "caption2", "caption3", "caption4", "caption5"], ...

In same way in the captions_mapping_valid.json :
"COCO_dataset/val2014/COCO_val2014_000000203564.jpg": ["caption1", "caption2", "caption3", "caption4", "caption5"], ...

Dependencies

I have used the following versions for code work:

  • python==3.8.8
  • tensorflow==2.4.1
  • tensorflow-gpu==2.4.1
  • numpy==1.19.1
  • h5py==2.10.0

Training

To train the model you need to follow the following steps :

  1. you have to make sure that the training set images are in the folder COCO_dataset/train2014/ and that validation set images are in COCO_dataset/val2014/.
  2. you have to enter all the parameters necessary for the training in the settings.py file.
  3. start the model training with python3 training.py

My settings

For my training session, I have get best results with this settings.py file :

# Desired image dimensions
IMAGE_SIZE = (299, 299)
# Max vocabulary size
MAX_VOCAB_SIZE = 2000000
# Fixed length allowed for any sequence
SEQ_LENGTH = 25
# Dimension for the image embeddings and token embeddings
EMBED_DIM = 512
# Number of self-attention heads
NUM_HEADS = 6
# Per-layer units in the feed-forward network
FF_DIM = 1024
# Shuffle dataset dim on tf.data.Dataset
SHUFFLE_DIM = 512
# Batch size
BATCH_SIZE = 64
# Numbers of training epochs
EPOCHS = 14

# Reduce Dataset
# If you want reduce number of train/valid images dataset, set 'REDUCE_DATASET=True'
# and set number of train/valid images that you want.
#### COCO dataset
# Max number train dataset images : 68363
# Max number valid dataset images : 33432
REDUCE_DATASET = False
# Number of train images -> it must be a value between [1, 68363]
NUM_TRAIN_IMG = None
# Number of valid images -> it must be a value between [1, 33432]
NUM_VALID_IMG = None
# Data augumention on train set
TRAIN_SET_AUG = True
# Data augmention on valid set
VALID_SET_AUG = False

# Load train_data.json pathfile
train_data_json_path = "COCO_dataset/captions_mapping_train.json"
# Load valid_data.json pathfile
valid_data_json_path = "COCO_dataset/captions_mapping_valid.json"
# Load text_data.json pathfile
text_data_json_path  = "COCO_dataset/text_data.json"

# Save training files directory
SAVE_DIR = "save_train_dir/"

I have training model on full dataset (68363 train images and 33432 valid images) but you can train the model on a smaller number of images by changing the NUM_TRAIN_IMG / NUM_VALID_IMG parameters to reduce the training time and hardware resources required.

Data augmention

I applied data augmentation on the training set during the training to reduce the generalization error, with this transformations (this code is write in dataset.py) :

trainAug = tf.keras.Sequential([
    	tf.keras.layers.experimental.preprocessing.RandomContrast(factor=(0.05, 0.15)),
    	tf.keras.layers.experimental.preprocessing.RandomTranslation(height_factor=(-0.10, 0.10), width_factor=(-0.10, 0.10)),
	tf.keras.layers.experimental.preprocessing.RandomZoom(height_factor=(-0.10, 0.10), width_factor=(-0.10, 0.10)),
	tf.keras.layers.experimental.preprocessing.RandomRotation(factor=(-0.10, 0.10))
])

You can customize your data augmentation by changing this code or disable data augmentation setting TRAIN_SET_AUG = False in setting.py.

My results

This is results of my best training :

Epoch 1/13
1069/1069 [==============================] - 1450s 1s/step - loss: 17.3777 - acc: 0.3511 - val_loss: 13.9711 - val_acc: 0.4819
Epoch 2/13
1069/1069 [==============================] - 1453s 1s/step - loss: 13.7338 - acc: 0.4850 - val_loss: 12.7821 - val_acc: 0.5133
Epoch 3/13
1069/1069 [==============================] - 1457s 1s/step - loss: 12.9772 - acc: 0.5069 - val_loss: 12.3980 - val_acc: 0.5229
Epoch 4/13
1069/1069 [==============================] - 1452s 1s/step - loss: 12.5683 - acc: 0.5179 - val_loss: 12.2659 - val_acc: 0.5284
Epoch 5/13
1069/1069 [==============================] - 1450s 1s/step - loss: 12.3292 - acc: 0.5247 - val_loss: 12.1828 - val_acc: 0.5316
Epoch 6/13
1069/1069 [==============================] - 1443s 1s/step - loss: 12.1614 - acc: 0.5307 - val_loss: 12.1410 - val_acc: 0.5341
Epoch 7/13
1069/1069 [==============================] - 1453s 1s/step - loss: 12.0461 - acc: 0.5355 - val_loss: 12.1234 - val_acc: 0.5354
Epoch 8/13
1069/1069 [==============================] - 1440s 1s/step - loss: 11.9533 - acc: 0.5407 - val_loss: 12.1086 - val_acc: 0.5367
Epoch 9/13
1069/1069 [==============================] - 1444s 1s/step - loss: 11.8838 - acc: 0.5427 - val_loss: 12.1235 - val_acc: 0.5373
Epoch 10/13
1069/1069 [==============================] - 1443s 1s/step - loss: 11.8114 - acc: 0.5460 - val_loss: 12.1574 - val_acc: 0.5367
Epoch 11/13
1069/1069 [==============================] - 1444s 1s/step - loss: 11.7543 - acc: 0.5486 - val_loss: 12.1518 - val_acc: 0.5371

These are good results considering that for each image given as input to the model during training, the error and the accuracy are averaged over 5 captions. However, I spent little time doing model selection and you can improve the results by trying better settings.
For example, you could :

  1. change CNN architecture.
  2. change SEQ_LENGTH, EMBED_DIM, NUM_HEADS, FF_DIM, BATCH_SIZE (etc...) parameters.
  3. change data augmentation transformations/parameters.
  4. etc...

N.B. I have saved my best training results files in the directory save_train_dir/.

Inference

After training and saving the model, you can restore it in a new session to inference captions on new images.
To generate a caption from a new image, you must :

  1. insert the parameters in the file settings_inference.py
  2. run python3 inference.py --image={image_path_file}

Results example

Examples of image output taken from the validation set.

a large passenger jet flying through the sky
a man in a white shirt and black shorts playing tennis
a person on a snowboard in the snow
a boy on a skateboard in the street
a black bear is walking through the grass
a train is on the tracks near a station
Owner
I love computer vision and NLP. I love artificial intelligence. Machine Learning and Big Data master's degree student.
Unofficial Implementation of MLP-Mixer, gMLP, resMLP, Vision Permutator, S2MLPv2, RaftMLP, ConvMLP, ConvMixer in Jittor and PyTorch.

Unofficial Implementation of MLP-Mixer, gMLP, resMLP, Vision Permutator, S2MLPv2, RaftMLP, ConvMLP, ConvMixer in Jittor and PyTorch! Now, Rearrange and Reduce in einops.layers.jittor are support!!

130 Jan 08, 2023
The source code of CVPR17 'Generative Face Completion'.

GenerativeFaceCompletion Matcaffe implementation of our CVPR17 paper on face completion. In each panel from left to right: original face, masked input

Yijun Li 313 Oct 18, 2022
JAXDL: JAX (Flax) Deep Learning Library

JAXDL: JAX (Flax) Deep Learning Library Simple and clean JAX/Flax deep learning algorithm implementations: Soft-Actor-Critic (arXiv:1812.05905) Transf

Patrick Hart 4 Nov 27, 2022
Multiview Neural Surface Reconstruction by Disentangling Geometry and Appearance

Multiview Neural Surface Reconstruction by Disentangling Geometry and Appearance Project Page | Paper | Data This repository contains an implementatio

Lior Yariv 521 Dec 30, 2022
In generative deep geometry learning, we often get many obj files remain to be rendered

a python prompt cli script for blender batch render In deep generative geometry learning, we always get many .obj files to be rendered. Our rendered i

Tian-yi Liang 1 Mar 20, 2022
Code for "OctField: Hierarchical Implicit Functions for 3D Modeling (NeurIPS 2021)"

OctField(Jittor): Hierarchical Implicit Functions for 3D Modeling Introduction This repository is code release for OctField: Hierarchical Implicit Fun

55 Dec 08, 2022
Code for the Lovász-Softmax loss (CVPR 2018)

The Lovász-Softmax loss: A tractable surrogate for the optimization of the intersection-over-union measure in neural networks Maxim Berman, Amal Ranne

Maxim Berman 1.3k Jan 04, 2023
SCALoss: Side and Corner Aligned Loss for Bounding Box Regression (AAAI2022).

SCALoss PyTorch implementation of the paper "SCALoss: Side and Corner Aligned Loss for Bounding Box Regression" (AAAI 2022). Introduction IoU-based lo

TuZheng 20 Sep 07, 2022
Make your AirPlay devices as TTS speakers

Apple AirPlayer Home Assistant integration component, make your AirPlay devices as TTS speakers. Before Use 2021.6.X or earlier Apple Airplayer compon

George Zhao 117 Dec 15, 2022
PyTorch implementation of Neural Combinatorial Optimization with Reinforcement Learning.

neural-combinatorial-rl-pytorch PyTorch implementation of Neural Combinatorial Optimization with Reinforcement Learning. I have implemented the basic

Patrick E. 454 Jan 06, 2023
This GitHub repo consists of Code and Some results of project- Diabetes Treatment using Gold nanoparticles. These Consist of ML Models used for prediction Diabetes and further the basic theory and working of Gold nanoparticles.

GoldNanoparticles This GitHub repo consists of Code and Some results of project- Diabetes Treatment using Gold nanoparticles. These Consist of ML Mode

1 Jan 30, 2022
noisy labels; missing labels; semi-supervised learning; entropy; uncertainty; robustness and generalisation.

ProSelfLC: CVPR 2021 ProSelfLC: Progressive Self Label Correction for Training Robust Deep Neural Networks For any specific discussion or potential fu

amos_xwang 57 Dec 04, 2022
Transfer Learning Remote Sensing

Transfer_Learning_Remote_Sensing Simulation R codes for data generation and visualizations are in the folder simulation. Experiment: California Housin

2 Jun 21, 2022
Repo 4 basic seminar §How to make human machine readable"

WORK IN PROGRESS... Notebooks from the Seminar: Human Machine Readable WS21/22 Introduction into programming Georg Trogemann, Christian Heck, Mattis

experimental-informatics 3 May 29, 2022
a basic code repository for basic task in CV(classification,detection,segmentation)

basic_cv a basic code repository for basic task in CV(classification,detection,segmentation,tracking) classification generate dataset train predict de

1 Oct 15, 2021
ADB-IP-ROTATION - Use your mobile phone to gain a temporary IP address using ADB and data tethering

ADB IP ROTATE This an Python script based on Android Debug Bridge (adb) shell sc

Dor Bismuth 2 Jul 12, 2022
Official repository of IMPROVING DEEP IMAGE MATTING VIA LOCAL SMOOTHNESS ASSUMPTION.

IMPROVING DEEP IMAGE MATTING VIA LOCAL SMOOTHNESS ASSUMPTION This is the official repository of IMPROVING DEEP IMAGE MATTING VIA LOCAL SMOOTHNESS ASSU

电线杆 14 Dec 15, 2022
A Neural Net Training Interface on TensorFlow, with focus on speed + flexibility

Tensorpack is a neural network training interface based on TensorFlow. Features: It's Yet Another TF high-level API, with speed, and flexibility built

Tensorpack 6.2k Jan 01, 2023
An auto discord account and token generator. Automatically verifies the phone number. Works without proxy. Bypasses captcha.

JOIN DISCORD SERVER https://discord.gg/uAc3agBY FREE HCAPTCHA SOLVING API Discord-Token-Gen An auto discord token generator. Auto verifies phone numbe

3kp 271 Jan 01, 2023