What can linearized neural networks actually say about generalization?

This is the source code to reproduce the experiments of the NeurIPS 2021 paper "What can linearized neural networks actually say about generalization?" by Guillermo Ortiz-Jimenez, Seyed-Mohsen Moosavi-Dezfooli and Pascal Frossard.

Dependencies

To run the code, please install all its dependencies by running:

$ pip install -r requirements.txt

This assumes that you have access to a Linux machine with an NVIDIA GPU with CUDA>=11.1. Otherwise, please check the instructions to install JAX with your setup in the corresponding repository.

In general, all scripts are parameterized using hydra and their configuration files can be found in the config/ folder.
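
As a rough illustration of how command-line overrides such as model=lenet data.dataset=cifar10 map onto a nested configuration, here is a minimal pure-Python sketch. It is not Hydra's actual parser, and the default values shown are hypothetical; the real defaults live in the config/ folder.

```python
from copy import deepcopy


def parse_value(raw):
    """Convert a raw override string to bool, int or float when possible."""
    lowered = raw.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def apply_overrides(config, overrides):
    """Return a copy of `config` with dotted `key=value` overrides applied."""
    result = deepcopy(config)
    for item in overrides:
        key, _, raw = item.partition("=")
        *parents, leaf = key.split(".")
        node = result
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = parse_value(raw)
    return result


# Hypothetical defaults, for illustration only.
defaults = {
    "model": "mlp",
    "linearize": True,
    "label_idx": 0,
    "data": {"dataset": "mnist"},
}

cfg = apply_overrides(
    defaults,
    ["label_idx=100", "model=lenet", "data.dataset=cifar10", "linearize=False"],
)
```

Dotted keys address nested entries, which is why the dataset is selected with data.dataset rather than dataset.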

Experiments

The repository contains code to reproduce the following experiments:

Spectral decomposition of NTK

To generate our new benchmark, consisting of the eigenfunctions of the neural tangent kernel (NTK) at initialization, run the Python script compute_ntk.py, selecting a model (e.g., mlp, lenet or resnet18) and a supporting dataset (e.g., cifar10 or mnist):

$ python compute_ntk.py model=lenet data.dataset=cifar10

This script will save the eigenvalues, eigenfunctions and weights of the model under artifacts/eigenfunctions/{data.dataset}/{model}/.

For other configuration options, please consult the configuration file config/compute-ntk/config.yaml.
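
To make the computation concrete, the following is a small NumPy sketch of an empirical NTK eigendecomposition for a toy two-layer tanh network with a hand-written Jacobian. It is not the code in compute_ntk.py (which uses JAX and real architectures); the toy model, the sizes, and the sign-based binarization of an eigenvector are all assumptions made for illustration.

```python
import numpy as np


def jacobian(params, X):
    """Jacobian of f(x) = w2 . tanh(W1 x) / sqrt(m) w.r.t. all parameters."""
    W1, w2 = params
    m = w2.shape[0]
    H = np.tanh(X @ W1.T)                      # hidden activations, (n, m)
    d_w2 = H / np.sqrt(m)                      # df/dw2, (n, m)
    dH = (1.0 - H ** 2) * w2 / np.sqrt(m)      # df/d(pre-activation), (n, m)
    d_W1 = dH[:, :, None] * X[:, None, :]      # df/dW1, (n, m, d)
    return np.concatenate([d_W1.reshape(len(X), -1), d_w2], axis=1)


def empirical_ntk(params, X):
    """Gram matrix of parameter gradients: K[i, j] = <grad f(x_i), grad f(x_j)>."""
    J = jacobian(params, X)
    return J @ J.T


rng = np.random.default_rng(0)
n, d, m = 32, 5, 64                            # samples, input dim, hidden width
X = rng.normal(size=(n, d))
params = (rng.normal(size=(m, d)), rng.normal(size=m))

K = empirical_ntk(params, X)

# eigh returns ascending eigenvalues; reorder so index 0 is the largest.
eigvals, eigvecs = np.linalg.eigh(K)
eigvals, eigvecs = eigvals[::-1], eigvecs[:, ::-1]

# Each column of eigvecs is an eigenfunction evaluated on the n samples.
# Binarizing one column by its sign gives a +/-1 labeling of the data.
label_idx = 3
binary_labels = np.where(eigvecs[:, label_idx] >= 0, 1, -1)
```

The kernel matrix is n by n, so both memory and the cost of the decomposition grow quickly with the number of samples, in line with the warning below about large models.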

Warning

Note that, for large models, this computation can take a very long time. For example, it took us two days to compute the full eigenvalue decomposition of the NTK of one randomly initialized ResNet18 using 4 NVIDIA V100 GPUs. The estimation of eigenvectors for the MLP or the LeNet, on the other hand, takes a matter of minutes, depending on the number of GPUs available and the selected batch_size.

Training on binary eigenfunctions

Once you have estimated the eigenfunctions of the NTK, you can train on any of them. To that end, select the desired label_idx (i.e., eigenfunction index), model and dataset, and run:

$ python train_ntk.py label_idx=100 model=lenet data.dataset=cifar10 linearize=False

You can train either the original non-linear network or its linear approximation by setting the linearize flag. For non-linear models, this script also computes the alignment of the NTK at the end of training with the target function, which it stores under artifacts/eigenfunctions/{data.dataset}/{model}/alignment_plots/.

To see the different supported training options, please consult the configuration file config/train-ntk/config.yaml.
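
For readers unfamiliar with the linear approximation selected by linearize, the sketch below shows the underlying idea on a toy model: a first-order Taylor expansion of the network output in its parameters around the initialization w0. The one-layer tanh model and the function names are hypothetical; the repository linearizes full networks in JAX.

```python
import numpy as np


def f(w, X):
    """Toy non-linear model: one tanh unit applied to each row of X."""
    return np.tanh(X @ w)


def jac(w, X):
    """Jacobian of f with respect to w, shape (n_samples, n_params)."""
    return (1.0 - np.tanh(X @ w) ** 2)[:, None] * X


def make_linearized(w0, X):
    """Return f_lin(w) = f(w0) + J(w0) (w - w0), which is linear in w."""
    f0, J0 = f(w0, X), jac(w0, X)

    def f_lin(w):
        return f0 + J0 @ (w - w0)

    return f_lin


rng = np.random.default_rng(0)
X = rng.normal(size=(8, 4))
w0 = rng.normal(size=4)
f_lin = make_linearized(w0, X)

# Close to w0 the two models agree; far from w0 they drift apart.
delta = 1e-3 * rng.normal(size=4)
gap_small = np.max(np.abs(f(w0 + delta, X) - f_lin(w0 + delta)))
gap_large = np.max(np.abs(f(w0 + 100 * delta, X) - f_lin(w0 + 100 * delta)))
```

The linearized model keeps the features (the Jacobian at w0) fixed during training, whereas the non-linear network can change them.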

Estimation of NADs

We also provide code to compute the neural anisotropy directions (NADs) of a CNN architecture (e.g., lenet or resnet18) using the alignment with the NTK at initialization. To do so, please run:

$ python compute_nads.py model=lenet

This script will save the eigenvalues, NADs and weights of the model under artifacts/nads/{model}/.

For other configuration options, please consult the configuration file config/compute-nads/config.yaml.

Training on linearly separable datasets

Once you have estimated the NADs of a network, you can train on linearly separable datasets with a single NAD as the discriminative feature. To that end, select the desired label_idx (i.e., NAD index) and model, and run:

$ python train_nads.py label_idx=100 model=lenet linearize=False

You can train either the original non-linear network or its linear approximation by setting the linearize flag.

To see the different supported training options, please consult the configuration file config/train-nads/config.yaml.
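
The following NumPy sketch shows one way to build a linearly separable dataset whose only discriminative feature is a single direction. It assumes each sample is a label-dependent shift along that direction plus noise orthogonal to it; the function name, the margin and noise parameters, and the random stand-in for a NAD are hypothetical, and the construction in train_nads.py may differ.

```python
import numpy as np


def make_separable_dataset(direction, n_samples, margin=1.0, noise_std=1.0, seed=0):
    """Sample (X, y) where the labels are only recoverable along `direction`."""
    rng = np.random.default_rng(seed)
    v = direction / np.linalg.norm(direction)       # unit-norm direction
    y = rng.choice([-1.0, 1.0], size=n_samples)     # balanced binary labels
    noise = noise_std * rng.normal(size=(n_samples, v.size))
    noise -= np.outer(noise @ v, v)                 # remove the component along v
    X = margin * y[:, None] * v + noise             # shift each sample by +/- margin
    return X, y


# Random stand-in for a flattened NAD (a real one would be loaded from disk).
rng = np.random.default_rng(1)
nad = rng.normal(size=16)
X, y = make_separable_dataset(nad, n_samples=200, margin=0.5)

# Projecting onto the direction recovers the labels exactly.
v = nad / np.linalg.norm(nad)
projections = X @ v
```

Because the noise is orthogonal to the chosen direction, the linear classifier sign(x . v) separates the data perfectly, so any failure to generalize reflects the network rather than the data.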

Comparison of training dynamics with pretrained NTK

We also provide code to compare the training dynamics of the linearized network at initialization and after non-linear pretraining, when learning a particular eigenfunction of the NTK at initialization. To do this, please run:

$ python pretrained_ntk_comparison.py label_idx=100 model=lenet data.dataset=cifar10

To see the different supported training options, please consult the configuration file config/pretrained_ntk_comparison/config.yaml.

Training on CIFAR2

Finally, you can train a neural network and its linearized approximation on the binary version of CIFAR10, i.e., CIFAR2. To do this, please run:

$ python train_cifar.py model=lenet linearize=False

To see the different supported training options, please consult the configuration file config/binary-cifar/config.yaml.
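
A binary version of a ten-class dataset can be obtained by mapping each class index to one of two labels, as in the short sketch below. The source does not state how the ten CIFAR10 classes are grouped into two, so the split used here (classes 0 to 4 against classes 5 to 9) is purely an assumption, and the function name is hypothetical.

```python
def binarize_labels(labels, positive_classes=frozenset(range(5))):
    """Map multi-class labels to +1 / -1 (the class grouping is an assumption)."""
    return [1 if label in positive_classes else -1 for label in labels]


cifar10_labels = [0, 3, 4, 5, 9, 7, 1]
cifar2_labels = binarize_labels(cifar10_labels)
```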

Reference

If you use this code, please cite the following paper:

@InCollection{Ortiz-JimenezNeurIPS2021,
  title = {What can linearized neural networks actually say about generalization?},
  author = {{Ortiz-Jimenez}, Guillermo and {Moosavi-Dezfooli}, Seyed-Mohsen and Frossard, Pascal},
  booktitle = {Advances in Neural Information Processing Systems 34},
  month = Dec,
  year = {2021}
}