Compare neural networks by their feature similarity

Overview

PyTorch Model Compare

A tiny package to compare two neural networks in PyTorch. There are many ways to compare two neural networks, but one robust and scalable way is the Centered Kernel Alignment (CKA) metric, which compares the features of the networks.

Centered Kernel Alignment

Centered Kernel Alignment (CKA) is a representation similarity metric that is widely used for understanding the representations learned by neural networks. Specifically, CKA takes two feature maps / representations X and Y as input and computes their normalized similarity (in terms of the Hilbert-Schmidt Independence Criterion (HSIC)) as

$$\mathrm{CKA}(K, L) = \frac{\mathrm{HSIC}(K, L)}{\sqrt{\mathrm{HSIC}(K, K)\,\mathrm{HSIC}(L, L)}}$$

Here, K and L are the similarity matrices of X and Y respectively. However, the above formula does not scale to deep architectures and large datasets. Therefore, a minibatch version can be constructed that uses an unbiased estimator of the HSIC as

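$$\mathrm{CKA}_{\text{minibatch}} = \frac{\frac{1}{k}\sum_{i=1}^{k}\mathrm{HSIC}_1(K_i, L_i)}{\sqrt{\frac{1}{k}\sum_{i=1}^{k}\mathrm{HSIC}_1(K_i, K_i)}\,\sqrt{\frac{1}{k}\sum_{i=1}^{k}\mathrm{HSIC}_1(L_i, L_i)}}$$

$$\mathrm{HSIC}_1(K, L) = \frac{1}{n(n-3)}\left(\operatorname{tr}(\tilde{K}\tilde{L}) + \frac{\mathbf{1}^\top\tilde{K}\mathbf{1}\,\mathbf{1}^\top\tilde{L}\mathbf{1}}{(n-1)(n-2)} - \frac{2}{n-2}\,\mathbf{1}^\top\tilde{K}\tilde{L}\mathbf{1}\right)$$

Here k is the number of minibatches, n is the minibatch size, K_i and L_i are the similarity matrices computed on the i-th minibatch, and \tilde{K} and \tilde{L} are K and L with their diagonal entries set to zero.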

The above form of CKA is from the 2021 ICLR paper by Nguyen T., Raghu M., and Kornblith S.
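The minibatch estimator described above can be sketched in a few lines of NumPy. This is a minimal illustration written from the published form of the unbiased HSIC estimator, not the package's own code: the function names (`linear_gram`, `unbiased_hsic`, `minibatch_cka`) are hypothetical, and a linear kernel is assumed for the similarity matrices.

```python
import numpy as np


def linear_gram(x: np.ndarray) -> np.ndarray:
    """Similarity (Gram) matrix of one minibatch of features, shape (n, n).

    Assumes a linear kernel; features are flattened per example first.
    """
    x = x.reshape(x.shape[0], -1)
    return x @ x.T


def unbiased_hsic(k: np.ndarray, l: np.ndarray) -> float:
    """Unbiased HSIC estimate for one minibatch of size n (needs n > 3)."""
    n = k.shape[0]
    if n <= 3:
        raise ValueError("the unbiased estimator needs more than 3 examples")
    # Work on copies with the diagonals zeroed out.
    k = k.copy()
    l = l.copy()
    np.fill_diagonal(k, 0.0)
    np.fill_diagonal(l, 0.0)
    trace_term = np.trace(k @ l)
    sum_term = k.sum() * l.sum() / ((n - 1) * (n - 2))
    # 1^T K L 1 equals (column sums of K) dot (row sums of L).
    cross_term = 2.0 / (n - 2) * (k.sum(axis=0) @ l.sum(axis=1))
    return float((trace_term + sum_term - cross_term) / (n * (n - 3)))


def minibatch_cka(batches_x, batches_y) -> float:
    """CKA from HSIC values accumulated over paired minibatches.

    The 1/k averaging factors cancel between numerator and denominator,
    so plain sums are accumulated here.
    """
    hsic_kl = hsic_kk = hsic_ll = 0.0
    for x, y in zip(batches_x, batches_y):
        k, l = linear_gram(x), linear_gram(y)
        hsic_kl += unbiased_hsic(k, l)
        hsic_kk += unbiased_hsic(k, k)
        hsic_ll += unbiased_hsic(l, l)
    return hsic_kl / (np.sqrt(hsic_kk) * np.sqrt(hsic_ll))


# Toy check: features compared with themselves, over 4 minibatches of 16.
rng = np.random.default_rng(0)
features = [rng.normal(size=(16, 8)) for _ in range(4)]
self_similarity = minibatch_cka(features, features)
```

Comparing a set of features with itself (or with a uniformly rescaled copy) should give a value of 1 up to floating-point error, which is a convenient sanity check for any implementation.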

Getting Started

Installation

pip install torch_cka

Usage

from torch.utils.data import DataLoader
from torchvision.models import resnet18, resnet34
from torch_cka import CKA

model1 = resnet18(pretrained=True)  # Or any neural network of your choice
model2 = resnet34(pretrained=True)

dataloader = DataLoader(your_dataset, 
                        batch_size=batch_size, # according to your device memory
                        shuffle=False)  # Don't forget to seed your dataloader

cka = CKA(model1, model2,
          model1_name="ResNet18",   # good idea to provide names to avoid confusion
          model2_name="ResNet34",   
          model1_layers=layer_names_resnet18, # List of layers to extract features from
          model2_layers=layer_names_resnet34, # extracts all layer features by default
          device='cuda')

cka.compare(dataloader) # secondary dataloader is optional

results = cka.export()  # returns a dict that contains model names, layer names
                        # and the CKA matrix
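The `layer_names_resnet18` and `layer_names_resnet34` lists in the snippet above are left undefined. One way to build such a list is to collect the names that `nn.Module.named_modules()` reports for the leaf modules. This sketch assumes the package matches layers by those names, which is worth checking against the package source; the helper `leaf_layer_names` and the toy model are hypothetical stand-ins.

```python
import torch.nn as nn


def leaf_layer_names(model: nn.Module) -> list:
    """Names of all modules that have no children, in registration order."""
    return [
        name
        for name, module in model.named_modules()
        if name and len(list(module.children())) == 0
    ]


# Tiny stand-in model; a torchvision ResNet could be passed in the same way.
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.ReLU(),
    nn.Sequential(nn.Flatten(), nn.Linear(8, 10)),
)

names = leaf_layer_names(toy)  # ['0', '1', '2.0', '2.1']
```

For large models, slicing or filtering this list (for example, keeping only convolution layers) is a simple way to limit how many feature maps are extracted at once.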

Examples

torch_cka can be used with any PyTorch model (a subclass of nn.Module), including pretrained models from popular sources such as Torch Hub, timm, and Hugging Face. Some examples of where this package can come in handy are illustrated below.

Comparing the effect of Depth

A simple experiment is to analyse the features learned by two architectures of the same family (ResNets) but of different depths. Taking two ResNets, ResNet18 and ResNet34, pre-trained on the ImageNet dataset, we can analyse the features they produce on, say, CIFAR10 for simplicity. This comparison is shown as a heatmap below.

[Figure: CKA heatmap comparing the layers of ResNet18 and ResNet34]

We see a high degree of similarity between the two models in the lower layers, as both learn similar representations from the data. At higher layers, however, the similarity drops as the deeper model (ResNet34) learns higher-order features that elude the shallower model (ResNet18). Even so, the two models show some similarity in their last fc layer, which acts as the feature classifier.

Comparing Two Similar Architectures

Another use of CKA is in ablation studies. Rather than looking only at the resulting performance, we can use CKA to study the internal representations. Case in point: ResNet50 and WideResNet50 (k=2). WideResNet50 has the same architecture as ResNet50 except for wider residual bottleneck layers (by a factor of 2 in this case).

[Figure: CKA heatmap comparing the layers of ResNet50 and WideResNet50]

We clearly notice that the learned features are indeed different after the first few layers. The width has a more pronounced effect in deeper layers as compared to the earlier layers as both networks seem to learn similar features in the initial layers.

As a bonus, here is a comparison between ViT and the latest SOTA model Swin Transformer pretrained on ImageNet22k.

[Figure: CKA heatmap comparing ViT and Swin Transformer]

Comparing quite different architectures

CNNs have been analysed a lot over the past decade since AlexNet. We somewhat know what sort of features they learn across their layers (through visualizations) and we have put them to good use. One interesting approach is to compare these understandable features with those of newer models that don't permit easy visualization (like recent vision transformer architectures). This has indeed been a hot research topic (see Raghu et al., 2021).

[Figure: CKA heatmap comparing two quite different architectures]

Comparing Datasets

Yet another application is to compare two datasets - preferably two versions of the data. This is especially useful in production where data drift is a known issue. If you have an updated version of a dataset, you can study how your model will perform on it by comparing the representations of the datasets. This can be more telling about actual performance than simply comparing the datasets directly.

This can also be quite useful in studying the performance of a model on downstream tasks and fine-tuning. For instance, if the CKA score is high for some features on different datasets, then those layers can be frozen during fine-tuning. As an example, the following figure compares the features of a pretrained ResNet50 on the ImageNet test data and the VOC dataset. Clearly, the pretrained features have little correlation with the VOC dataset. Therefore, we have to resort to fine-tuning to get at least satisfactory results.

[Figure: CKA heatmap of pretrained ResNet50 features on the ImageNet test data versus the VOC dataset]
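The idea of freezing layers whose features stay similar across datasets can be sketched as a simple selection step. This assumes you already have one CKA score per layer (for example, the diagonal of the CKA matrix when the same model and layers are used on both datasets). The function name `layers_to_freeze` and the threshold of 0.9 are hypothetical choices for illustration, not values recommended by the package.

```python
def layers_to_freeze(layer_names, cka_scores, threshold=0.9):
    """Return the layers whose cross-dataset CKA score meets the threshold.

    layer_names and cka_scores must have the same length; cka_scores[i] is
    the similarity of layer_names[i]'s features on the two datasets.
    """
    if len(layer_names) != len(cka_scores):
        raise ValueError("need exactly one CKA score per layer")
    return [
        name
        for name, score in zip(layer_names, cka_scores)
        if score >= threshold
    ]


# Made-up scores purely to show the call; they are not measured results.
frozen = layers_to_freeze(["conv1", "layer1", "fc"], [0.95, 0.90, 0.40])
# frozen == ['conv1', 'layer1']
```

The selected modules could then be frozen in PyTorch by calling `requires_grad_(False)` on each of them before fine-tuning.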

Tips

  • If your model is large (lots of layers or large feature maps), try to extract from select layers. This is to avoid out-of-memory issues.
  • If you still want to compare the entire feature map, you can run it multiple times with a few layers at each iteration and export your data using cka.export(). The exported data can then be concatenated to produce the full CKA matrix.
  • Give proper model names to avoid confusion when interpreting the results. The code automatically extracts the model name for you by default, but it is good practice to label the models according to your use case.
  • When providing your dataloader(s) to the compare() function, it is important that they are seeded properly for reproducibility.
  • When comparing datasets, be sure to set drop_last=True when building the dataloader. This resolves shape mismatch issues - especially in differently sized datasets.
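The concatenation step from the second tip can be sketched with NumPy. Each entry of the CKA matrix depends only on one layer of each model, so partial matrices from separate runs can be placed side by side. This sketch assumes each run's exported CKA matrix has already been converted to a NumPy array; the helper `stitch_cka_blocks` is hypothetical and the values are placeholders.

```python
import numpy as np


def stitch_cka_blocks(blocks):
    """Assemble a full CKA matrix from a grid of partial matrices.

    blocks[i][j] holds the CKA values for the i-th chunk of model 1 layers
    (rows) against the j-th chunk of model 2 layers (columns).
    """
    return np.block([[np.asarray(b) for b in row] for row in blocks])


# Stand-in for a full 3 x 4 CKA matrix, used only to check the stitching.
full = np.arange(12, dtype=float).reshape(3, 4) / 12.0

# Pretend four separate runs produced these partial matrices.
blocks = [
    [full[:2, :2], full[:2, 2:]],
    [full[2:, :2], full[2:, 2:]],
]
stitched = stitch_cka_blocks(blocks)
```

The chunks must be listed in the same layer order used for the runs, otherwise the rows and columns of the stitched matrix will not line up with the layer names.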

Citation

If you use this repo in your project or research, please cite it as:

@software{subramanian2021torch_cka,
    author={Anand Subramanian},
    title={torch_cka},
    url={https://github.com/AntixK/PyTorch-Model-Compare},
    year={2021}
}