A Tensorflow implementation of CapsNet based on Geoffrey Hinton's paper Dynamic Routing Between Capsules


CapsNet-Tensorflow


[Figure: capsule vs. traditional neuron]

Notes:

  1. The current version supports the MNIST and Fashion-MNIST datasets. The current test accuracy is 99.64% on MNIST and 90.60% on Fashion-MNIST; see the Results section for details.
  2. See dist_version for multi-GPU support.
  3. An article on Zhihu (知乎) explains my understanding of the paper; it may help in understanding the code.

Important:

If you need to apply the CapsNet model to your own datasets, or to build a new model from the basic building blocks of CapsNet, please follow my new project CapsLayer. It is an advanced library for capsule theory that aims to integrate capsule-related technologies, provide analysis tools, develop application examples, and promote the development of capsule theory. For example, you can easily use capsule layers in your code through the APIs capsLayer.layers.fully_connected and capsLayer.layers.conv2d.

Requirements

  • Python
  • NumPy
  • Tensorflow>=1.3
  • tqdm (for displaying training progress info)
  • scipy (for saving images)

Usage

Step 1. Download this repository with git or click the download ZIP button.

$ git clone https://github.com/naturomics/CapsNet-Tensorflow.git
$ cd CapsNet-Tensorflow

Step 2. Download the MNIST or Fashion-MNIST dataset. You have two options:

  • a) Download automatically with the download_data.py script:
$ # for the MNIST dataset
$ python download_data.py
$ # for the Fashion-MNIST dataset
$ python download_data.py --dataset fashion-mnist --save_to data/fashion-mnist
  • b) Download manually with wget or another tool, then move and extract the dataset into the data/mnist or data/fashion-mnist directory. For example:
$ mkdir -p data/mnist
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
$ gunzip data/mnist/*.gz
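Whichever option you use, you can sanity-check the extracted files before training. MNIST and Fashion-MNIST both use the IDX format: a big-endian header whose magic number is 2051 for image files (followed by the image count, rows, and columns) or 2049 for label files (followed by the label count), then the raw unsigned bytes. The minimal NumPy sketch below parses such a file; `read_idx` is a hypothetical helper for illustration, not part of this repository.

```python
import struct

import numpy as np


def read_idx(path):
    """Read an uncompressed MNIST-style IDX file into a NumPy array.

    Hypothetical helper: handles only the two layouts used by (Fashion-)MNIST.
    """
    with open(path, "rb") as f:
        (magic,) = struct.unpack(">I", f.read(4))  # big-endian unsigned int
        if magic == 2051:  # image file: count, rows, cols
            n, rows, cols = struct.unpack(">III", f.read(12))
            shape = (n, rows, cols)
        elif magic == 2049:  # label file: count
            (n,) = struct.unpack(">I", f.read(4))
            shape = (n,)
        else:
            raise ValueError(f"unexpected magic number {magic} in {path}")
        data = np.frombuffer(f.read(), dtype=np.uint8)
    expected = int(np.prod(shape))
    if data.size != expected:
        raise ValueError(f"{path}: expected {expected} values, got {data.size}")
    return data.reshape(shape)
```

For example, `read_idx("data/mnist/train-images-idx3-ubyte")` should return an array of shape (60000, 28, 28), and the matching label file an array of shape (60000,). A size mismatch usually means a truncated download.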

Step 3. Start training (the MNIST dataset is used by default):

$ python main.py
$ # or train on the Fashion-MNIST dataset
$ python main.py --dataset fashion-mnist
$ # to monitor training, start TensorBoard
$ tensorboard --logdir=logdir
$ # or use the `tail` command on Linux
$ tail -f results/val_acc.csv
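As a scripted alternative to `tail -f`, the sketch below reads the validation log and reports the best value seen so far. It assumes results/val_acc.csv has a header row followed by `step,accuracy` rows; check the file your run actually produces and adjust the column handling if it differs. `best_val_acc` is a hypothetical helper, not part of the repository.

```python
import csv


def best_val_acc(path="results/val_acc.csv"):
    """Return (step, accuracy) for the highest accuracy in a step,accuracy CSV.

    Assumes a header row followed by numeric rows; blank lines are skipped.
    Returns None if the file contains no data rows.
    """
    best = None
    with open(path, newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row (assumed to exist)
        for row in reader:
            if len(row) < 2:  # blank or malformed line
                continue
            step, acc = int(float(row[0])), float(row[1])
            if best is None or acc > best[1]:
                best = (step, acc)
    return best
```

Running `print(best_val_acc())` periodically during training gives the best checkpoint step without scrolling through the whole log.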

Step 4. Calculate the test accuracy:

$ python main.py --is_training=False
$ # for the Fashion-MNIST dataset
$ python main.py --dataset fashion-mnist --is_training=False
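In CapsNet, the predicted class is the digit capsule whose output vector is longest. The sketch below shows how a test accuracy follows from that rule, using NumPy. The input layout (an array of shape (batch, 10, 16) of DigitCaps vectors, as in the paper) and the function name are assumptions for illustration; main.py may compute the metric differently.

```python
import numpy as np


def capsule_accuracy(digit_caps, labels):
    """Accuracy of predictions taken as the argmax over capsule lengths.

    digit_caps: float array of shape (batch, num_classes, capsule_dim)
    labels: int array of shape (batch,) holding class indices
    """
    lengths = np.linalg.norm(digit_caps, axis=-1)  # (batch, num_classes)
    predictions = np.argmax(lengths, axis=1)       # longest capsule wins
    return float(np.mean(predictions == labels))
```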

Note: The default batch size is 128 and the default number of epochs is 50. You may need to modify config.py or pass command-line parameters to suit your case; for example, to set the batch size to 64 and run a test summary every 200 steps: python main.py --test_sum_freq=200 --batch_size=64
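When choosing these values, it helps to relate steps to epochs. The sketch below does that arithmetic. The training-set size is a parameter because it depends on how the data are split (for example, if 5,000 of the 60,000 MNIST training images were held out for validation, it would be 55,000), and the sketch assumes the last partial batch of each epoch is dropped. `schedule` is purely illustrative and not part of config.py.

```python
def schedule(num_train, test_sum_freq, batch_size=128, epochs=50):
    """Return (steps_per_epoch, total_steps, num_test_summaries).

    Assumes partial final batches are dropped (floor division).
    """
    steps_per_epoch = num_train // batch_size
    total_steps = steps_per_epoch * epochs
    num_test_summaries = total_steps // test_sum_freq
    return steps_per_epoch, total_steps, num_test_summaries
```

With 55,000 training images, `schedule(55000, test_sum_freq=200)` gives 429 steps per epoch, 21,450 steps in total, and 107 test summaries; halving the batch size to 64 roughly doubles all three.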

Results

The figures in this section were plotted with TensorBoard and my tool plot_acc.R.

  • training loss

[Figures: total loss, margin loss, reconstruction loss]
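The three curves correspond to the loss terms defined in the paper: a margin loss on the digit-capsule lengths, a reconstruction loss from the decoder, and their weighted sum. The NumPy sketch below writes them out with the paper's constants (m+ = 0.9, m- = 0.1, lambda = 0.5, and a 0.0005 weight on a sum of squared pixel errors). The function names and array shapes are assumptions, and config.py in this repository may scale the terms differently.

```python
import numpy as np

M_PLUS, M_MINUS, LAMBDA = 0.9, 0.1, 0.5  # margin-loss constants from the paper
RECON_WEIGHT = 0.0005                     # reconstruction-loss scale from the paper


def margin_loss(lengths, labels, num_classes=10):
    """lengths: (batch, num_classes) capsule lengths; labels: (batch,) ints."""
    t = np.eye(num_classes)[labels]  # one-hot targets T_k
    present = t * np.maximum(0.0, M_PLUS - lengths) ** 2
    absent = LAMBDA * (1 - t) * np.maximum(0.0, lengths - M_MINUS) ** 2
    return float(np.mean(np.sum(present + absent, axis=1)))


def reconstruction_loss(reconstructed, images):
    """Sum of squared pixel differences per image, averaged over the batch."""
    diff = (reconstructed - images).reshape(len(images), -1)
    return float(np.mean(np.sum(diff ** 2, axis=1)))


def total_loss(lengths, labels, reconstructed, images):
    """Margin loss plus the down-weighted reconstruction loss."""
    return margin_loss(lengths, labels) + RECON_WEIGHT * reconstruction_loss(
        reconstructed, images
    )
```

The small reconstruction weight keeps the decoder from dominating training; it acts as a regularizer on the digit capsules rather than as the main objective.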

The models I trained, my talk, and other materials are available here:

Baidu Netdisk (password: ahjs)

  • Best validation error (with reconstruction):

Routing iterations | 1    | 3    | 4
Val error (%)      | 0.36 | 0.36 | 0.41
Paper (%)          | 0.29 | 0.25 | -

[Figure: test accuracy]

My brief comments on capsules

  1. A new kind of neural unit (vector in, vector out, rather than scalar in, scalar out)
  2. The routing algorithm is similar to an attention mechanism
  3. Overall, it is promising work with a lot to build upon
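To make comments 1 and 2 concrete, here is a minimal NumPy sketch of routing-by-agreement as described in the paper: coupling coefficients are a softmax over routing logits (much like attention weights), each output capsule is the squashed weighted sum of prediction vectors, and the logits grow with the agreement (dot product) between predictions and outputs. It assumes the prediction vectors u_hat have already been computed, and it is an illustration, not this repository's TensorFlow code.

```python
import numpy as np


def squash(s, axis=-1, eps=1e-9):
    """Vector nonlinearity: keeps direction, maps length into [0, 1)."""
    sq_norm = np.sum(s ** 2, axis=axis, keepdims=True)
    return (sq_norm / (1.0 + sq_norm)) * s / np.sqrt(sq_norm + eps)


def softmax(x, axis=-1):
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)


def dynamic_routing(u_hat, num_iterations=3):
    """u_hat: (num_in, num_out, dim) prediction vectors u_hat_{j|i}.

    Returns v: (num_out, dim) output capsule vectors.
    """
    num_in, num_out, _ = u_hat.shape
    b = np.zeros((num_in, num_out))  # routing logits b_ij
    for r in range(num_iterations):
        c = softmax(b, axis=1)                 # couplings of each input over outputs
        s = np.einsum("ij,ijd->jd", c, u_hat)  # weighted sum per output capsule j
        v = squash(s)                          # vector-in, vector-out activation
        if r < num_iterations - 1:             # last update would not change v
            b = b + np.einsum("ijd,jd->ij", u_hat, v)  # agreement update
    return v
```

With `num_iterations=1` every coupling coefficient is uniform, so each output is just the squashed average prediction; additional iterations shift weight toward inputs whose predictions agree, which is what the routing-iteration column in the results table varies.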

My WeChat: [Image: WeChat contact]


Owner: Huadong Liao (Explore Nature from an Omics Perspective)