Implementation of CaiT models in TensorFlow and ImageNet-1k checkpoints. Includes code for inference and fine-tuning.

Overview

CaiT-TF (Going deeper with Image Transformers)

TensorFlow 2.8 HugginFace badge Models on TF-Hub

This repository provides TensorFlow / Keras implementations of different CaiT [1] variants from Touvron et al. It also provides the TensorFlow / Keras models that have been populated with the original CaiT pre-trained params available from [2]. These models are not blackbox SavedModels i.e., they can be fully expanded into tf.keras.Model objects and one can call all the utility functions on them (example: .summary()).

As of today, all the TensorFlow / Keras variants of the CaiT models listed here are available in this repository.

Refer to the "Using the models" section to get started.

Table of contents

Conversion

TensorFlow / Keras implementations are available in cait/models.py. Conversion utilities are in convert.py.

Models

Find the models on TF-Hub here: https://tfhub.dev/sayakpaul/collections/cait/1. You can fully inspect the architecture of the TF-Hub models like so:

import tensorflow as tf

model_gcs_path = "gs://tfhub-modules/sayakpaul/cait_xxs24_224/1/uncompressed"
model = tf.keras.models.load_model(model_gcs_path)

dummy_inputs = tf.ones((2, 224, 224, 3))
_ = model(dummy_inputs)
print(model.summary(expand_nested=True))

Results

Results are on ImageNet-1k validation set (top-1 and top-5 accuracies).

model_name top1_acc(%) top5_acc(%)
cait_s24_224 83.368 96.576
cait_xxs24_224 78.524 94.212
cait_xxs36_224 79.76 94.876
cait_xxs36_384 81.976 96.064
cait_xxs24_384 80.648 95.516
cait_xs24_384 83.738 96.756
cait_s24_384 84.944 97.212
cait_s36_384 85.192 97.372
cait_m36_384 85.924 97.598
cait_m48_448 86.066 97.590

Results can be verified with the code in i1k_eval. Results are in line with [1]. Slight differences in the results stemmed from the fact that I used a different set of augmentation transformations. Original transformations suggested by the authors can be found here.

Using the models

Pre-trained models:

These models also output attention weights from each of the Transformer blocks. Refer to this notebook for more details. Additionally, the notebook shows how to visualize the attention maps for a given image (following figures 6 and 7 of the original paper).

Original Image Class Attention Maps Class Saliency Map
cropped image cls attn saliency

For the best quality, refer to the assets directory. You can also generate these plots using the following interactive demos on Hugging Face Spaces:

Randomly initialized models:

from cait.model_configs import base_config
from cait.models import CaiT
import tensorflow as tf
 
config = base_config.get_config(
    model_name="cait_xxs24_224"
)
cait_xxs24_224 = CaiT(config)

dummy_inputs = tf.ones((2, 224, 224, 3))
_ = cait_xxs24_224(dummy_inputs)
print(cait_xxs24_224.summary(expand_nested=True))

To initialize a network with say, 5 classes, do:

config = base_config.get_config(
    model_name="cait_xxs24_224"
)
with config.unlocked():
    config.num_classes = 5
cait_xxs24_224 = CaiT(config)

To view different model configurations, refer to convert_all_models.py.

References

[1] CaiT paper: https://arxiv.org/abs/2103.17239

[2] Official CaiT code: https://github.com/facebookresearch/deit

Acknowledgements

Owner
Sayak Paul
ML Engineer at @carted | One PR at a time
Sayak Paul
Benchmark spaces - Benchmarks of how well different two dimensional spaces work for clustering algorithms

benchmark_spaces Benchmarks of how well different two dimensional spaces work fo

Bram Cohen 6 May 07, 2022
This repository contains the exercises and its solution contained in the book "An Introduction to Statistical Learning" in python.

An-Introduction-to-Statistical-Learning This repository contains the exercises and its solution contained in the book An Introduction to Statistical L

2.1k Jan 02, 2023
Voice of Pajlada with model and weights.

Pajlada TTS Stripped down version of ForwardTacotron (https://github.com/as-ideas/ForwardTacotron) with pretrained weights for Pajlada's (https://gith

6 Sep 03, 2021
The codes and related files to reproduce the results for Image Similarity Challenge Track 2.

The codes and related files to reproduce the results for Image Similarity Challenge Track 2.

Wenhao Wang 89 Jan 02, 2023
Plaything for Autistic Children (demo for PaddlePaddle/Wechaty/Mixlab project)

星星的孩子 - 一款为孤独症孩子设计的聊天机器人游戏 孤独症儿童是目前常常被忽视的一类群体。他们有着类似性格内向的特征,实际却受着广泛性发育障碍的折磨。 项目背景 这类儿童在与人交往时存在着沟通障碍,其特点表现在: 社交交流差,互动障碍明显 认知能力有限,被动认知 兴趣狭窄,重复刻板,缺乏变化和想象

Tianyi Pan 35 Nov 24, 2022
Statistical-Rethinking-with-Python-and-PyMC3 - Python/PyMC3 port of the examples in " Statistical Rethinking A Bayesian Course with Examples in R and Stan" by Richard McElreath

Statistical Rethinking with Python and PyMC3 This repository has been deprecated in favour of this one, please check that repository for updates, for

Osvaldo Martin 786 Dec 29, 2022
DuBE: Duple-balanced Ensemble Learning from Skewed Data

DuBE: Duple-balanced Ensemble Learning from Skewed Data "Towards Inter-class and Intra-class Imbalance in Class-imbalanced Learning" (IEEE ICDE 2022 S

6 Nov 12, 2022
Unsupervised Representation Learning by Invariance Propagation

Unsupervised Learning by Invariance Propagation This repository is the official implementation of Unsupervised Learning by Invariance Propagation. Pre

FengWang 15 Jul 06, 2022
Physics-informed convolutional-recurrent neural networks for solving spatiotemporal PDEs

PhyCRNet Physics-informed convolutional-recurrent neural networks for solving spatiotemporal PDEs Paper link: [ArXiv] By: Pu Ren, Chengping Rao, Yang

Pu Ren 11 Aug 23, 2022
Official code release for "GRAF: Generative Radiance Fields for 3D-Aware Image Synthesis"

GRAF This repository contains official code for the paper GRAF: Generative Radiance Fields for 3D-Aware Image Synthesis. You can find detailed usage i

349 Dec 29, 2022
Gems & Holiday Package Prediction

Predictive_Modelling Gems & Holiday Package Prediction This project is based on 2 cases studies : Gems Price Prediction and Holiday Package prediction

Avnika Mehta 1 Jan 27, 2022
Research Artifact of USENIX Security 2022 Paper: Automated Side Channel Analysis of Media Software with Manifold Learning

Automated Side Channel Analysis of Media Software with Manifold Learning Official implementation of USENIX Security 2022 paper: Automated Side Channel

Yuanyuan Yuan 175 Jan 07, 2023
Geometry-Free View Synthesis: Transformers and no 3D Priors

Geometry-Free View Synthesis: Transformers and no 3D Priors Geometry-Free View Synthesis: Transformers and no 3D Priors Robin Rombach*, Patrick Esser*

CompVis Heidelberg 293 Dec 22, 2022
190 Jan 03, 2023
Adversarial-autoencoders - Tensorflow implementation of Adversarial Autoencoders

Adversarial Autoencoders (AAE) Tensorflow implementation of Adversarial Autoencoders (ICLR 2016) Similar to variational autoencoder (VAE), AAE imposes

Qian Ge 236 Nov 13, 2022
A FAIR dataset of TCV experimental results for validating edge/divertor turbulence models.

TCV-X21 validation for divertor turbulence simulations Quick links Intro Welcome to TCV-X21. We're glad you've found us! This repository is designed t

0 Dec 18, 2021
Pytorch implementation of Decoupled Spatial-Temporal Transformer for Video Inpainting

Decoupled Spatial-Temporal Transformer for Video Inpainting By Rui Liu, Hanming Deng, Yangyi Huang, Xiaoyu Shi, Lewei Lu, Wenxiu Sun, Xiaogang Wang, J

51 Dec 13, 2022
This is the implementation of the paper LiST: Lite Self-training Makes Efficient Few-shot Learners.

LiST (Lite Self-Training) This is the implementation of the paper LiST: Lite Self-training Makes Efficient Few-shot Learners. LiST is short for Lite S

Microsoft 28 Dec 07, 2022
Source codes for the paper "Local Additivity Based Data Augmentation for Semi-supervised NER"

LADA This repo contains codes for the following paper: Jiaao Chen*, Zhenghui Wang*, Ran Tian, Zichao Yang, Diyi Yang: Local Additivity Based Data Augm

GT-SALT 36 Dec 02, 2022
Build a small, 3 domain internet using Github pages and Wikipedia and construct a crawler to crawl, render, and index.

TechSEO Crawler Build a small, 3 domain internet using Github pages and Wikipedia and construct a crawler to crawl, render, and index. Play with the r

JR Oakes 57 Nov 24, 2022