Semi-Supervised Learning with Ladder Networks in Keras. Get 98% test accuracy on MNIST with just 100 labeled examples!

Overview

Semi-Supervised Learning with Ladder Networks in Keras

This is an implementation of the Ladder Network in Keras. The Ladder Network is a model for semi-supervised learning. For details, refer to the paper "Semi-Supervised Learning with Ladder Networks" by A. Rasmus, H. Valpola, M. Honkala, M. Berglund, and T. Raiko.

This implementation was used in the official code of our paper "Unsupervised Clustering using Pseudo-semi-supervised Learning". The code can be found here and the blog post can be found here.

The model achieves 98% test accuracy on MNIST with just 100 labeled examples.

The code only works with the TensorFlow backend.

Requirements

  • Python 2.7+/3.6+
  • TensorFlow (1.4.0)
  • numpy
  • Keras (2.1.4)
  • scikit-learn (only used to compute the test accuracy)

Other versions of TensorFlow/Keras should also work.

How to use

Load the dataset

from keras.datasets import mnist
import keras
import random

# get the dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(60000, 28*28).astype('float32')/255.0
x_test = x_test.reshape(10000, 28*28).astype('float32')/255.0

y_train = keras.utils.to_categorical( y_train )
y_test = keras.utils.to_categorical( y_test )

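# select only 100 training samples to act as the labeled set
# list() is needed on Python 3, where a range object cannot be shuffled in place
idxs_annot = list(range(x_train.shape[0]))
random.seed(0)
random.shuffle(idxs_annot)
idxs_annot = idxs_annot[:100]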

x_train_unlabeled = x_train
x_train_labeled = x_train[ idxs_annot ]
y_train_labeled = y_train[ idxs_annot  ]
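
The subset selection above can also be written with NumPy alone. The following is a minimal sketch on synthetic data rather than MNIST; the helper name `select_labeled_subset` is hypothetical and is not part of this repository.

```python
import numpy as np

def select_labeled_subset(x, y, n_labeled, seed=0):
    """Pick n_labeled random rows of x and y (hypothetical helper)."""
    rng = np.random.RandomState(seed)
    # permutation() returns a shuffled index array, so no list() cast is needed
    idxs = rng.permutation(x.shape[0])[:n_labeled]
    return x[idxs], y[idxs], idxs

# synthetic stand-in for MNIST: 1000 flattened "images" with one-hot labels
x_demo = np.random.RandomState(1).rand(1000, 28 * 28).astype("float32")
y_demo = np.eye(10, dtype="float32")[np.arange(1000) % 10]

x_lab, y_lab, idxs = select_labeled_subset(x_demo, y_demo, n_labeled=100)
print(x_lab.shape, y_lab.shape)
```

Note that a purely random draw like this does not guarantee the same number of labeled examples per class.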

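Repeat the labeled dataset to match the size of the unlabeled dataset

import numpy as np

# integer division keeps n_rep an int on Python 3 (60000 // 100 = 600)
n_rep = x_train_unlabeled.shape[0] // x_train_labeled.shape[0]
x_train_labeled_rep = np.concatenate([x_train_labeled] * n_rep)
y_train_labeled_rep = np.concatenate([y_train_labeled] * n_rep)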
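
The repetition step works out exactly here because 60000 is a multiple of 100. As a sketch of how the same idea could be handled when the sizes do not divide evenly, the hypothetical helper `repeat_to_match` below (not part of this repository) tiles the labeled array and then truncates it to the target length, shown on a tiny synthetic array.

```python
import numpy as np

def repeat_to_match(arr, n_target):
    """Tile arr along axis 0 until it has exactly n_target rows (hypothetical helper)."""
    n_rep = -(-n_target // arr.shape[0])  # ceiling division
    # repeat the whole array n_rep times, then drop the surplus rows
    return np.concatenate([arr] * n_rep)[:n_target]

x_lab = np.arange(6, dtype="float32").reshape(3, 2)  # 3 "labeled" rows
x_rep = repeat_to_match(x_lab, 7)                     # pretend there are 7 unlabeled rows
print(x_rep.shape)  # (7, 2)
```

The labels would need to be repeated with the same call so that each row of the repeated inputs still lines up with its label.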

Initialize the model

from ladder_net import get_ladder_network_fc

inp_size = 28*28  # number of pixels in a flattened MNIST image
n_classes = 10
model = get_ladder_network_fc(layer_sizes=[inp_size, 1000, 500, 250, 250, 250, n_classes])

Train the model

# inputs: [repeated labeled images, unlabeled images]; targets: repeated labels
model.fit([x_train_labeled_rep, x_train_unlabeled], y_train_labeled_rep, epochs=100)

Get the test accuracy

from sklearn.metrics import accuracy_score
y_test_pr = model.test_model.predict(x_test , batch_size=100 )

# print() with a single formatted string works on both Python 2 and Python 3
print("test accuracy: %.4f" % accuracy_score(y_test.argmax(-1), y_test_pr.argmax(-1)))
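
If scikit-learn is not available, the same accuracy can be computed with NumPy. This is a small sketch on hand-made scores; the function name `accuracy_from_scores` is hypothetical and is not part of this repository.

```python
import numpy as np

def accuracy_from_scores(y_true_one_hot, y_scores):
    """Fraction of rows whose highest-scoring class matches the one-hot label."""
    return float(np.mean(y_true_one_hot.argmax(-1) == y_scores.argmax(-1)))

y_true = np.eye(3)[[0, 1, 2, 1]]          # one-hot labels for classes 0, 1, 2, 1
y_scores = np.array([[0.8, 0.1, 0.1],
                     [0.2, 0.7, 0.1],
                     [0.6, 0.3, 0.1],     # wrong: predicts class 0, label is 2
                     [0.1, 0.5, 0.4]])
acc = accuracy_from_scores(y_true, y_scores)
print(acc)  # 0.75
```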
Owner
Divam Gupta
Graduate student at Carnegie Mellon University | Former Research Fellow at Microsoft Research