PyTorch implementation of the Deep SLDA method from our CVPRW-2020 paper "Lifelong Machine Learning with Deep Streaming Linear Discriminant Analysis"

Overview

Lifelong Machine Learning with Deep Streaming Linear Discriminant Analysis

This is a PyTorch implementation of the Deep Streaming Linear Discriminant Analysis (SLDA) algorithm from our CVPRW-2020 paper. An arXiv pre-print of our paper is available, as well as the published paper.

Deep SLDA combines a feature extractor with LDA to perform streaming image classification and can be thought of as a way to train the output layer of a neural network. Deep SLDA only requires the storage of a single shared covariance matrix beyond its feature extraction CNN, making its memory requirements very low, e.g., 0.001 GB for our experiments with ResNet-18. Further, once initialized, Deep SLDA is able to train incrementally on the ImageNet dataset in roughly 30 minutes on a Titan X GPU. This is remarkable as methods like iCaRL require 3.011 GB of storage beyond the CNN and require 62 hours to train on the same hardware.

An additional Deep SLDA implementation directly using the CORe50 dataset and scenarios defined in the original CORe50 paper is located here

Dependences

  • Tested with Python 3.6 and PyTorch 1.1.0, or Python 3.7 and PyTorch 1.3.1, NumPy, NVIDIA GPU
  • Dataset:
    • ImageNet-1K (ILSVRC2012) -- Download the ImageNet-1K dataset and move validation images to labeled sub-folders. See link.

Usage

To replicate the SLDA experiments on ImageNet-1K, change necessary paths and run from terminal:

  • slda_imagenet.sh

Alternatively, setup appropriate parameters and run directly in python:

  • python experiment.py

Implementation Notes

When run, the script will save out network probabilities (torch files), accuracies (json files), and the SLDA means and covariance weights (torch files) after every 100 classes in a directory called ./streaming_experiments/*expt_name*.

We have included all necessary files to replicate our ImageNet-1K experiments. Note that the checkpoint file provided in image_files has only been trained on the base 100 classes. However, for other datasets you may want a checkpoint trained on the entire ImageNet-1K dataset, e.g., our CORe50 experiments. Simply change line 196 of experiment.py to feature_extraction_model = get_feature_extraction_model(None, imagenet_pretrained=True).eval() to use ImageNet-1K pre-trained weights from PyTorch.

Other datasets can be used by implementing a PyTorch dataloader for them.

If you would like to start streaming from scratch without a base initialization phase, simply leave out the call to fit_base.

Results on ImageNet ILSVRC-2012

Deep_SLDA

Citation

If using this code, please cite our paper.

@InProceedings{Hayes_2020_CVPR_Workshops,
    author = {Hayes, Tyler L. and Kanan, Christopher},
    title = {Lifelong Machine Learning With Deep Streaming Linear Discriminant Analysis},
    booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
    month = {June},
    year = {2020}
}
Owner
Tyler Hayes
I am a PhD candidate at the Rochester Institute of Technology (RIT). My current research is on lifelong machine learning.
Tyler Hayes
BoxInst: High-Performance Instance Segmentation with Box Annotations

Introduction This repository is the code that needs to be submitted for OpenMMLab Algorithm Ecological Challenge, the paper is BoxInst: High-Performan

88 Dec 21, 2022
Ratatoskr: Worcester Tech's conference scheduling system

Ratatoskr: Worcester Tech's conference scheduling system In Norse mythology, Ratatoskr is a squirrel who runs up and down the world tree Yggdrasil to

4 Dec 22, 2022
Python Jupyter kernel using Poetry for reproducible notebooks

Poetry Kernel Use per-directory Poetry environments to run Jupyter kernels. No need to install a Jupyter kernel per Python virtual environment! The id

Pathbird 204 Jan 04, 2023
Official PyTorch implementation of PICCOLO: Point-Cloud Centric Omnidirectional Localization (ICCV 2021)

Official PyTorch implementation of PICCOLO: Point-Cloud Centric Omnidirectional Localization (ICCV 2021)

16 Nov 19, 2022
Recognize Handwritten Digits using Deep Learning on the browser itself.

MNIST on the Web An attempt to predict MNIST handwritten digits from my PyTorch model from the browser (client-side) and not from the server, with the

Harjyot Bagga 7 May 28, 2022
A Tensorflow implementation of the Text Conditioned Auxiliary Classifier Generative Adversarial Network for Generating Images from text descriptions

A Tensorflow implementation of the Text Conditioned Auxiliary Classifier Generative Adversarial Network for Generating Images from text descriptions

Ayushman Dash 93 Aug 04, 2022
TextureGAN in Pytorch

TextureGAN This code is our PyTorch implementation of TextureGAN [Project] [Arxiv] TextureGAN is a generative adversarial network conditioned on sketc

Patsorn 147 Dec 14, 2022
Video Frame Interpolation without Temporal Priors (a general method for blurry video interpolation)

Video Frame Interpolation without Temporal Priors (NeurIPS2020) [Paper] [video] How to run Prerequisites NVIDIA GPU + CUDA 9.0 + CuDNN 7.6.5 Pytorch 1

YoujianZhang 31 Sep 04, 2022
Pytorch implementation of "Get To The Point: Summarization with Pointer-Generator Networks"

About this repository This repo contains an Pytorch implementation for the ACL 2017 paper Get To The Point: Summarization with Pointer-Generator Netwo

wxDai 7 Oct 14, 2022
Using a Seq2Seq RNN architecture via TensorFlow to predict future Bitcoin prices

Recurrent Bitcoin Network A Data Science Thesis Project About This repository contains the source code for implementing Bitcoin price prediciton using

Frizu 6 Sep 08, 2022
Point Cloud Denoising input segmentation output raw point-cloud valid/clear fog rain de-noised Abstract Lidar sensors are frequently used in environme

Point Cloud Denoising input segmentation output raw point-cloud valid/clear fog rain de-noised Abstract Lidar sensors are frequently used in environme

75 Nov 24, 2022
This repository contains tutorials for the py4DSTEM Python package

py4DSTEM Tutorials This repository contains tutorials for the py4DSTEM Python package. For more information about py4DSTEM, including installation ins

11 Dec 23, 2022
Official implementation for "Image Quality Assessment using Contrastive Learning"

Image Quality Assessment using Contrastive Learning Pavan C. Madhusudana, Neil Birkbeck, Yilin Wang, Balu Adsumilli and Alan C. Bovik This is the offi

Pavan Chennagiri 67 Dec 30, 2022
Code for the paper Task Agnostic Morphology Evolution.

Task-Agnostic Morphology Optimization This repository contains code for the paper Task-Agnostic Morphology Evolution by Donald (Joey) Hejna, Pieter Ab

Joey Hejna 18 Aug 04, 2022
Official PyTorch implementation of Learning Intra-Batch Connections for Deep Metric Learning (ICML 2021) published at International Conference on Machine Learning

About This repository the official PyTorch implementation of Learning Intra-Batch Connections for Deep Metric Learning. The config files contain the s

Dynamic Vision and Learning Group 41 Dec 10, 2022
Fully Convlutional Neural Networks for state-of-the-art time series classification

Deep Learning for Time Series Classification As the simplest type of time series data, univariate time series provides a reasonably good starting poin

Stephen 572 Dec 23, 2022
A Pytorch loader for MVTecAD dataset.

MVTecAD A Pytorch loader for MVTecAD dataset. It strictly follows the code style of common Pytorch datasets, such as torchvision.datasets.CIFAR10. The

Jiyuan 1 Dec 27, 2021
Combining Automatic Labelers and Expert Annotations for Accurate Radiology Report Labeling Using BERT

CheXbert: Combining Automatic Labelers and Expert Annotations for Accurate Radiology Report Labeling Using BERT CheXbert is an accurate, automated dee

Stanford Machine Learning Group 51 Dec 08, 2022
Official implementation of "CrossPoint: Self-Supervised Cross-Modal Contrastive Learning for 3D Point Cloud Understanding" (CVPR, 2022)

CrossPoint: Self-Supervised Cross-Modal Contrastive Learning for 3D Point Cloud Understanding (CVPR'22) Paper Link | Project Page Abstract : Manual an

Mohamed Afham 152 Dec 23, 2022
A toolkit for making real world machine learning and data analysis applications in C++

dlib C++ library Dlib is a modern C++ toolkit containing machine learning algorithms and tools for creating complex software in C++ to solve real worl

Davis E. King 11.6k Jan 01, 2023