TensorFlow Similarity is a Python package focused on making similarity learning quick and easy.


TensorFlow Similarity: Metric Learning for Humans

TensorFlow Similarity is a TensorFlow library for similarity learning, also known as metric learning or contrastive learning.

TensorFlow Similarity is still in beta.


TensorFlow Similarity offers state-of-the-art algorithms for metric learning, along with all the components needed to research, train, evaluate, and serve similarity-based models.

Example of nearest-neighbor search performed on the embeddings generated by a similarity model trained on the Oxford IIIT Pet Dataset.

With TensorFlow Similarity you can train and serve models that find similar items (such as images) in a large corpus of examples. For example, as shown above, you can train a similarity model to find and cluster similar-looking images of cats and dogs from the Oxford IIIT Pet Dataset by training on only a few classes. To train your own similarity model, see this notebook.

Metric learning differs from traditional classification in its objective. Rather than predicting a label, the model learns to minimize the distance between similar examples and maximize the distance between dissimilar examples, in either a supervised or a self-supervised fashion. In both cases, TensorFlow Similarity provides the necessary losses, metrics, samplers, visualizers, and indexing sub-system to make this quick and easy.
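As a minimal illustration of that objective, the NumPy sketch below computes the classic pairwise contrastive loss. Similar pairs are penalized by their squared distance. Dissimilar pairs are penalized only when they are closer than a margin. The function name and the `margin` value are illustrative assumptions, and this is not one of the losses shipped by TensorFlow Similarity.

```python
import numpy as np

def contrastive_loss(a, b, same, margin=1.0):
    """Pairwise contrastive loss for embedding pairs (a[i], b[i]).

    same[i] is True when the pair is similar (e.g. same class).
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    same = np.asarray(same, dtype=bool)
    # Euclidean distance between the two embeddings of each pair
    d = np.linalg.norm(a - b, axis=1)
    # Similar pairs: pull them together by penalizing their distance
    pos = same * d ** 2
    # Dissimilar pairs: push them apart until they are at least `margin` away
    neg = (~same) * np.maximum(margin - d, 0.0) ** 2
    return float(np.mean(pos + neg))

# A similar pair that is already close, and a dissimilar pair that is too close
a = [[0.0, 0.0], [0.0, 0.0]]
b = [[0.1, 0.0], [0.2, 0.0]]
print(contrastive_loss(a, b, same=[True, False]))
```

Most of the loss here comes from the dissimilar pair. Training would push that pair apart.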

TensorFlow Similarity supports supervised training and, as of v0.15, self-supervised training. Semi-supervised training is planned for future releases.

To learn more about the benefits of using similarity training, you can check out the blog post.

What's new

For previous changes, see the release changelog.

Getting Started


Use pip to install the library:

pip install tensorflow_similarity


The detailed and narrated notebooks are a good way to get started with TensorFlow Similarity. There is likely to be one that is similar to your data or your problem (if not, let us know). You can start working with the examples immediately in Google Colab by clicking the Google Colab icon.

For more information about specific functions, check the API documentation.

To contribute to the project, please check out the contribution guidelines.

Minimal Example: MNIST similarity

Here is a bare-bones example demonstrating how to train a TensorFlow Similarity model on the MNIST data. It illustrates some of the main components provided by TensorFlow Similarity and how they fit together. Please refer to the hello_world notebook for a more detailed introduction.

Preparing data

TensorFlow Similarity provides data samplers for various dataset types. These samplers balance the batches to make training smoother. In this example, we use the multi-shot sampler, which integrates directly with the TensorFlow Datasets catalog.

from tensorflow_similarity.samplers import TFDatasetMultiShotMemorySampler

# Data sampler that generates balanced batches from MNIST dataset
sampler = TFDatasetMultiShotMemorySampler(dataset_name='mnist', classes_per_batch=10)
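Conceptually, a balanced sampler draws a fixed number of classes per batch and a fixed number of examples from each class, so every batch contains positive pairs. The NumPy sketch below illustrates only that idea. The function `balanced_batch` and its `examples_per_class` parameter are hypothetical, and the sketch does not reproduce the internals of `TFDatasetMultiShotMemorySampler`.

```python
import numpy as np

def balanced_batch(y, classes_per_batch, examples_per_class, rng):
    """Return indices of a batch containing `examples_per_class` items from each
    of `classes_per_batch` randomly chosen classes (hypothetical helper)."""
    y = np.asarray(y)
    classes = rng.choice(np.unique(y), size=classes_per_batch, replace=False)
    idx = []
    for c in classes:
        members = np.flatnonzero(y == c)
        # Sample with replacement only if the class has too few examples
        replace = len(members) < examples_per_class
        idx.extend(rng.choice(members, size=examples_per_class, replace=replace))
    return np.array(idx)

rng = np.random.default_rng(0)
labels = np.repeat(np.arange(10), 50)  # 10 classes, 50 examples each
batch = balanced_batch(labels, classes_per_batch=4, examples_per_class=2, rng=rng)
print(labels[batch])  # 4 distinct classes, 2 examples each
```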

Building a Similarity model

Building a TensorFlow Similarity model is similar to building a standard Keras model, with two differences. The output layer is usually a MetricEmbedding() layer, which enforces L2 normalization. The model is instantiated as SimilarityModel(), a specialized subclass that supports additional functionality.

from tensorflow.keras import layers
from tensorflow_similarity.layers import MetricEmbedding
from tensorflow_similarity.models import SimilarityModel

# Build a Similarity model using standard Keras layers
inputs = layers.Input(shape=(28, 28, 1))
x = layers.experimental.preprocessing.Rescaling(1/255)(inputs)
x = layers.Conv2D(64, 3, activation='relu')(x)
x = layers.Flatten()(x)
x = layers.Dense(64, activation='relu')(x)
outputs = MetricEmbedding(64)(x)

# Build a specialized Similarity model
model = SimilarityModel(inputs, outputs)
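Because `MetricEmbedding()` L2-normalizes its output, every embedding lies on the unit hypersphere. Cosine distance then reduces to one minus a dot product. The NumPy sketch below shows that relationship. It is an illustration, not the layer's actual implementation.

```python
import numpy as np

def l2_normalize(x, eps=1e-12):
    # Scale each row to unit L2 norm (eps guards against division by zero)
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(norms, eps)

def cosine_distance(a, b):
    # For unit-norm vectors, cosine distance is 1 - dot product
    return 1.0 - np.sum(a * b, axis=-1)

emb = l2_normalize(np.array([[3.0, 4.0], [4.0, 3.0], [-3.0, -4.0]]))
print(np.linalg.norm(emb, axis=1))      # each row has unit norm
print(cosine_distance(emb[0], emb[1]))  # small: similar directions
print(cosine_distance(emb[0], emb[2]))  # ≈ 2.0: opposite directions
```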

Training the model via contrastive learning

To output metric embeddings that are searchable via approximate nearest neighbor search, the model needs to be trained with a similarity loss. Here we use MultiSimilarityLoss(), which is one of the most efficient loss functions.

from tensorflow_similarity.losses import MultiSimilarityLoss

# Train Similarity model using contrastive loss
model.compile('adam', loss=MultiSimilarityLoss())
model.fit(sampler, epochs=5)
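For intuition, the following NumPy sketch computes a simplified, similarity-based form of the multi-similarity loss (Wang et al., CVPR 2019) on a single batch. Soft log-sum-exp terms penalize positives with low similarity and negatives with high similarity. The sketch omits the pair-mining step of the original formulation and differs from the library's distance-based implementation. The hyperparameters `alpha`, `beta`, and `lmda` mirror the paper's, but the default values here are illustrative assumptions.

```python
import numpy as np

def multi_similarity_loss(embeddings, labels, alpha=2.0, beta=40.0, lmda=0.5):
    """Simplified multi-similarity loss on L2-normalized embeddings (no pair mining)."""
    emb = np.asarray(embeddings, dtype=float)
    labels = np.asarray(labels)
    sim = emb @ emb.T  # cosine similarity, since the rows have unit norm
    same = labels[:, None] == labels[None, :]
    not_self = ~np.eye(len(labels), dtype=bool)
    losses = []
    for i in range(len(labels)):
        pos = sim[i][same[i] & not_self[i]]
        neg = sim[i][~same[i]]
        # Soft penalty on positives whose similarity falls below lmda
        pos_term = np.log1p(np.sum(np.exp(-alpha * (pos - lmda)))) / alpha
        # Soft penalty on negatives whose similarity rises above lmda
        neg_term = np.log1p(np.sum(np.exp(beta * (neg - lmda)))) / beta
        losses.append(pos_term + neg_term)
    return float(np.mean(losses))

labels = [0, 0, 1, 1]
separated = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
mixed = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
print(multi_similarity_loss(separated, labels))  # low: classes are well separated
print(multi_similarity_loss(mixed, labels))      # higher: positives far, negatives close
```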

Building images index and querying it

Once the model is trained, reference examples must be indexed with the model's index API to become searchable. After indexing, you can use the model's lookup API to search the index for the K most similar items.

from tensorflow_similarity.visualization import viz_neigbors_imgs

# Index 100 embedded MNIST examples to make them searchable
sx, sy = sampler.get_slice(0, 100)
model.index(x=sx, y=sy, data=sx)

# Find the top 5 most similar indexed MNIST examples for a given example
qx, qy = sampler.get_slice(3713, 1)
nns = model.single_lookup(qx[0], k=5)

# Visualize the query example and its top 5 neighbors
viz_neigbors_imgs(qx[0], qy[0], nns)
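TensorFlow Similarity searches its index with approximate nearest neighbor search. For a small index, an exact brute-force search conveys the same idea: embed, normalize, rank by cosine distance. The `BruteForceIndex` class below is a hypothetical illustration, not the library's index or lookup API.

```python
import numpy as np

class BruteForceIndex:
    """Exact cosine-distance nearest-neighbor index (hypothetical illustration)."""

    def __init__(self):
        self._vectors = []
        self._labels = []

    def add(self, embeddings, labels):
        emb = np.asarray(embeddings, dtype=float)
        # Normalize so that cosine distance = 1 - dot product
        emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)
        self._vectors.extend(emb)
        self._labels.extend(labels)

    def lookup(self, query, k=5):
        q = np.asarray(query, dtype=float)
        q = q / np.linalg.norm(q)
        matrix = np.stack(self._vectors)
        dists = 1.0 - matrix @ q
        # Indices of the k smallest distances, closest first
        nearest = np.argsort(dists)[:k]
        return [(self._labels[i], float(dists[i])) for i in nearest]

index = BruteForceIndex()
index.add([[1.0, 0.0], [0.0, 1.0], [0.9, 0.1]], labels=["a", "b", "c"])
print(index.lookup([1.0, 0.0], k=2))  # "a" first, then "c"
```

Brute force costs one dot product per indexed item for each query. That cost is why large indexes switch to approximate search.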

Supported Algorithms

Supervised Losses

  • Triplet Loss
  • PN Loss
  • Multi Sim Loss
  • Circle Loss
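Of these, the triplet loss is the simplest to state. For an anchor a, a positive p, and a negative n, it is max(d(a, p) - d(a, n) + margin, 0). The NumPy sketch below uses squared Euclidean distance and a fixed margin of 0.2. Both are assumptions chosen for illustration. The sketch also omits the triplet mining that practical implementations typically add.

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin=0.2):
    """Batch-averaged triplet loss using squared Euclidean distances."""
    anchor, positive, negative = (
        np.asarray(t, dtype=float) for t in (anchor, positive, negative)
    )
    d_ap = np.sum((anchor - positive) ** 2, axis=1)
    d_an = np.sum((anchor - negative) ** 2, axis=1)
    # Penalize triplets whose negative is not at least `margin` farther than the positive
    return float(np.mean(np.maximum(d_ap - d_an + margin, 0.0)))

anchor = [[0.0, 0.0]]
positive = [[0.0, 0.1]]
print(triplet_loss(anchor, positive, [[1.0, 0.0]]))  # easy negative: no loss
print(triplet_loss(anchor, positive, [[0.0, 0.2]]))  # hard negative: positive loss
```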


TensorFlow Similarity offers many of the most common metrics used for classification and retrieval evaluation, including:

Name          Type
Precision     Classification
Recall        Classification
F1 Score      Classification
Recall@K      Retrieval
Binary NDCG   Retrieval
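Retrieval metrics are computed from the labels of each query's nearest neighbors. The NumPy sketch below uses one common definition of Recall@K: a query counts as a hit if any of its top K neighbors shares its label. The function name is hypothetical, and TensorFlow Similarity's own metric may differ in details such as how distance thresholds are handled.

```python
import numpy as np

def recall_at_k(query_labels, neighbor_labels, k):
    """Fraction of queries with at least one correct label among their top-k neighbors.

    neighbor_labels[i] lists the labels of query i's neighbors, closest first.
    """
    top_k = np.asarray(neighbor_labels)[:, :k]
    hits = np.any(top_k == np.asarray(query_labels)[:, None], axis=1)
    return float(np.mean(hits))

queries = [0, 1, 2]
neighbors = [[0, 1, 1], [2, 2, 1], [1, 1, 1]]
print(recall_at_k(queries, neighbors, k=1))  # only the first query is correct at k=1
print(recall_at_k(queries, neighbors, k=3))  # the second query now also has a hit
```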


Please cite this reference if you use any part of TensorFlow Similarity in your research:

  title={TensorFlow Similarity: A Usable, High-Performance Metric Learning Library},
  author={Elie Bursztein, James Long, Shun Lin, Owen Vallis, Francois Chollet},


This is not an official Google product.

  • v0.16 (May 27, 2022)


    • Cross-batch memory (XBM). Thanks @chjort
    • VicReg Loss - Improvement of Barlow Twins. Thanks @dewball345
    • Add augmenter function for Barlow Twins. Thanks @dewball345


    • Simplified the MetricEmbedding layer. Function tracing and serialization are now better supported.
    • Refactored the image augmentation modules into separate utils modules to make them more composable. Thanks @dewball345
    • The default value of p in the GeneralizedMeanPooling layers is now 3.0, which better aligns with the value in the paper.
    • EvalCallback now supports a split validation callback. Thanks @abhisharsinha
    • Distance and losses refactor: the distances call signature now accepts query and key inputs instead of a single set of embeddings or labels. Thanks @chjort
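Generalized mean (GeM) pooling computes (mean(x^p))^(1/p) over the spatial positions of each channel. With p = 1 it reduces to average pooling, and larger values of p move it toward max pooling. The NumPy sketch below uses a fixed p and is an illustration only, not the GeneralizedMeanPooling layer's implementation. The function name is hypothetical.

```python
import numpy as np

def generalized_mean_pool(feature_map, p=3.0, eps=1e-6):
    """GeM pooling over the spatial axes of an (H, W, C) feature map."""
    # Clamp to positive values so that x ** p and the 1/p root are well defined
    x = np.clip(np.asarray(feature_map, dtype=float), eps, None)
    return np.mean(x ** p, axis=(0, 1)) ** (1.0 / p)

fm = np.array([1.0, 2.0, 3.0]).reshape(1, 3, 1)  # H=1, W=3, C=1
print(generalized_mean_pool(fm, p=1.0))    # average pooling
print(generalized_mean_pool(fm, p=3.0))    # between the mean and the max
print(generalized_mean_pool(fm, p=100.0))  # close to max pooling
```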


    • Fix TFRecordDatasetSampler to ensure the correct number of examples per class per batch. Deterministic is now set to True, and the docstring now calls out the requirements for the TFRecord files.
    • Removed unneeded tf.function and register_keras_serializable decorators.
    • Refactored the model index attribute to raise a more informative AttributeError if the index does not exist.
    • Freeze all BatchNormalization layers in architectures when loading weights.
    • Fix bug in losses.utils.LogSumExp(): tf.math.log(1 + x) should be tf.math.log(tf.math.exp(-my_max) + x). This is needed to properly account for subtracting the row-wise max before computing the logsumexp.
    • Fix multisim loss offsets. The tfsim version of multisim uses distances instead of the inner product. However, multisim requires that we "center" the pairwise distances around 0. Here we add a new center param, which we set to 1.0 for cosine distance. Additionally, we also flip the lambda (lmda) param to add the threshold to the values instead of subtracting it. These changes will help improve the pos and neg weighting in the log1psumexp.
    • Fix nmslib save and load. nmslib requires a string path and will only read and write to local files. In order to support writing to a remote file path, we first write to a local tmp dir and then write that to the user provided path using tf.io.gfile.GFile.
    • Fix serialization of SimCLR params in get_config()
    • Other fixes and improvements...
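The LogSumExp fix rests on an identity: log(1 + sum(exp(x_i))) = m + log(exp(-m) + sum(exp(x_i - m))). Once the maximum m is subtracted from the exponents, the constant 1 must also be scaled, to exp(-m). The NumPy sketch below contrasts the two forms. Clamping m at 0, so that exp(-m) cannot overflow, is an extra assumption of this sketch and not necessarily what the library does.

```python
import numpy as np

def log1p_sum_exp_buggy(x):
    # Subtracts the max but leaves the constant 1 unscaled (the pre-fix form)
    m = np.max(x)
    return m + np.log(1.0 + np.sum(np.exp(x - m)))

def log1p_sum_exp(x):
    # log(1 + sum(exp(x))) = m + log(exp(-m) + sum(exp(x - m)))
    # m is clamped at 0 (sketch assumption) so exp(-m) <= 1 never overflows
    m = max(np.max(x), 0.0)
    return m + np.log(np.exp(-m) + np.sum(np.exp(x - m)))

x = np.array([1.0, 2.0, 3.0])
reference = np.log(1.0 + np.sum(np.exp(x)))  # direct form, fine for small x
print(reference, log1p_sum_exp(x), log1p_sum_exp_buggy(x))
print(log1p_sum_exp(np.array([1000.0])))  # stable where exp(1000) would overflow
```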
  • v0.15 (Jan 21, 2022)

    This release adds self-supervised training support.


    • Refactored Augmenters to be a class
    • Added SimCLRAugmenter
    • Added SimSiamLoss()
    • Added SimCLR Loss
    • Added ContrastiveModel() for self-supervised training
    • Added encoder_dev() metric, as suggested in SimSiam, to detect representation collapse
    • Added visualize_views() to see views side by side
    • Added self-supervised hello world narrated notebook
    • Refactored Augmenter() as class.
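The SimSiam paper monitors collapse by measuring the standard deviation of the L2-normalized encoder output. It drops to zero when every input maps to the same vector, and it stays near 1/sqrt(d) for well-spread d-dimensional outputs. The NumPy sketch below illustrates that check. `output_std` is a hypothetical helper, not the `encoder_dev()` API listed above.

```python
import numpy as np

def output_std(embeddings):
    """Mean per-dimension std of L2-normalized embeddings (hypothetical helper)."""
    z = np.asarray(embeddings, dtype=float)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    return float(np.mean(np.std(z, axis=0)))

rng = np.random.default_rng(0)
dim = 64
healthy = rng.normal(size=(2000, dim))
collapsed = np.ones((2000, dim))  # every example maps to the same vector
print(output_std(healthy))    # close to 1 / sqrt(64) = 0.125
print(output_std(collapsed))  # zero: the representation has collapsed
```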


    • Remove augmentation argument from architectures as the augmentation arg could lead to issues when saving the model or training on TPU.
    • Removed RandAugment, which is not used directly by the package and causes issues with TF 2.8+
  • v0.14 (Oct 9, 2021)

  • v0.13 (Sep 13, 2021)

