giovcandido/prototypical-networks

Prototypical Networks

This repository contains implementations of both Prototypical Networks, proposed by Snell et al. in 2017, and Prototypical Networks with Random Weights, proposed by the owner of this repository in 2021. Weighted Prototypical Networks are now also available, and you can test them with your own weights.
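As background, Snell et al. classify a query by comparing its embedding with one prototype per class: each prototype is the mean of that class's embedded support examples, and the class probabilities are a softmax over negative squared Euclidean distances to the prototypes. The sketch below illustrates that rule in plain Python on pre-computed embeddings. The encoder network that produces the embeddings is omitted, and the function names are hypothetical, not this repository's API.

```python
import math
from typing import Dict, List, Sequence

Vector = List[float]


def mean_vector(vectors: Sequence[Vector]) -> Vector:
    """Element-wise mean of equally sized vectors."""
    n = len(vectors)
    return [sum(components) / n for components in zip(*vectors)]


def squared_euclidean(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def compute_prototypes(support: Dict[str, Sequence[Vector]]) -> Dict[str, Vector]:
    """One prototype per class: the mean of that class's embedded support examples."""
    return {label: mean_vector(embeddings) for label, embeddings in support.items()}


def class_probabilities(query: Vector, prototypes: Dict[str, Vector]) -> Dict[str, float]:
    """Softmax over negative squared Euclidean distances from the query to each prototype."""
    logits = {label: -squared_euclidean(query, proto) for label, proto in prototypes.items()}
    max_logit = max(logits.values())  # subtract the max logit for numerical stability
    exps = {label: math.exp(value - max_logit) for label, value in logits.items()}
    total = sum(exps.values())
    return {label: value / total for label, value in exps.items()}
```

For example, with support = {"a": [[0.0, 0.0], [0.0, 2.0]], "b": [[4.0, 4.0], [6.0, 4.0]]}, the prototypes are [0.0, 1.0] and [5.0, 4.0], and the query [0.5, 1.0] is assigned to class "a" because it is much closer to that prototype.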

The repository also has scripts to train these models for the task of few-shot image classification on Omniglot and mini-ImageNet.

It is written in Python 3 and has been tested on Linux.

Installation

Clone the repository or download the compressed source code. If you chose the latter, extract the archive to a directory of your choice.

In either case, open the project directory in your terminal.

Next, install the requirements by running:

pip3 install -r requirements.txt

If you can't install the requirements as a regular user, run the following instead:

sudo pip3 install -r requirements.txt

You also need to install the protonets package (in editable mode) with:

pip3 install -e .

You may need to install it with sudo:

sudo pip3 install -e .

After installing the requirements and the package, you're ready to go.
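To confirm that the package is visible to the interpreter you will use, a quick standard-library check is enough. This assumes the package is importable under the name protonets, matching the name used in the installation step; the helper name is hypothetical.

```python
import importlib.util
import sys


def check_protonets_installed() -> bool:
    """Return True if a module named 'protonets' can be found by this interpreter."""
    return importlib.util.find_spec("protonets") is not None


if __name__ == "__main__":
    if check_protonets_installed():
        print("protonets is installed for", sys.executable)
    else:
        print("protonets not found; run 'pip3 install -e .' from the project directory")
```

Run it with the same python3 you will use for the scripts, since each interpreter has its own set of installed packages.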

Usage

You can train two models:

  • Prototypical Networks;
  • Prototypical Networks with Random Weights.

And there are two available datasets:

  • Omniglot;
  • mini-ImageNet.

First, go to the scripts directory.

Once you're there, download the datasets.

The dataset_downloader.py script takes a -d/--dataset argument. If you try to execute it without passing the required argument, you should expect to see the following message:

usage: dataset_downloader.py [-h] -d {all,omniglot,mini_imagenet}
dataset_downloader.py: error: the following arguments are required: -d/--dataset

From the output above, we can see that there are three possible choices: all, omniglot and mini_imagenet.

As an example, let's suppose we only want to download omniglot:

python3 dataset_downloader.py -d omniglot

After the download is complete, we can train a model on omniglot.

The training.py script takes two arguments: -m/--model and -d/--dataset. If you run it without passing the required arguments, you should expect to see the following message:

usage: training.py [-h] -m {vanilla,random_weights} -d {omniglot,mini_imagenet}
training.py: error: the following arguments are required: -m/--model, -d/--dataset

From the output above, we can see that each argument accepts two values: vanilla and random_weights for -m/--model, and omniglot and mini_imagenet for -d/--dataset.
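The two command-line interfaces described above can be reproduced with Python's argparse module. This is a hypothetical reconstruction based only on the usage messages; the repository's scripts may define their arguments differently. When a required argument is missing or a value is not among the choices, argparse prints an error and exits with status 2.

```python
import argparse


def build_downloader_parser() -> argparse.ArgumentParser:
    """Parser matching the usage line of dataset_downloader.py."""
    parser = argparse.ArgumentParser(prog="dataset_downloader.py")
    parser.add_argument("-d", "--dataset", required=True,
                        choices=["all", "omniglot", "mini_imagenet"])
    return parser


def build_training_parser() -> argparse.ArgumentParser:
    """Parser matching the usage line of training.py."""
    parser = argparse.ArgumentParser(prog="training.py")
    parser.add_argument("-m", "--model", required=True,
                        choices=["vanilla", "random_weights"])
    parser.add_argument("-d", "--dataset", required=True,
                        choices=["omniglot", "mini_imagenet"])
    return parser
```

For example, build_training_parser().parse_args(["-m", "vanilla", "-d", "omniglot"]) returns a namespace with model='vanilla' and dataset='omniglot'. Note that all is only a valid dataset for the downloader, not for training.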

Since we have downloaded omniglot, let's run:

python3 training.py -m vanilla -d omniglot

After the training is complete, we can retrain by running:

python3 retraining.py

And after retraining, we can evaluate our model with:

python3 evaluation.py

The results are stored in a directory called results.

Bear in mind that you have to rename or delete the results directory before training another model.

The retraining and evaluation scripts work with the model obtained from the initial run of the training script.
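Because the results directory must be renamed or deleted before the next training run, a small driver can chain the three scripts and move results aside afterwards. This is a sketch under stated assumptions: it is run from the scripts directory, results is created in that same directory, retraining.py and evaluation.py take no arguments (as in the commands above), and both the function names and the results_<model>_<dataset> naming scheme are hypothetical.

```python
import subprocess
import sys
from pathlib import Path
from typing import List

MODELS = ("vanilla", "random_weights")
DATASETS = ("omniglot", "mini_imagenet")


def pipeline_commands(model: str, dataset: str) -> List[List[str]]:
    """Commands for one full run, in the order described above."""
    if model not in MODELS or dataset not in DATASETS:
        raise ValueError(f"unsupported combination: {model!r}, {dataset!r}")
    return [
        [sys.executable, "training.py", "-m", model, "-d", dataset],
        [sys.executable, "retraining.py"],
        [sys.executable, "evaluation.py"],
    ]


def run_pipeline(model: str, dataset: str, scripts_dir: Path = Path(".")) -> Path:
    """Run training, retraining and evaluation, then rename 'results' so the next run starts clean."""
    for command in pipeline_commands(model, dataset):
        subprocess.run(command, cwd=scripts_dir, check=True)  # stop at the first failing step
    results = scripts_dir / "results"
    archived = scripts_dir / f"results_{model}_{dataset}"  # hypothetical naming scheme
    results.rename(archived)  # fails if a non-empty directory with that name already exists
    return archived
```

For example, run_pipeline("vanilla", "omniglot") would train, retrain and evaluate on Omniglot and leave the outputs in results_vanilla_omniglot, provided the dataset has already been downloaded.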

Few-Shot Setup

You can find the few-shot setup and other parameters in the config directory.

The splits and the implementation follow the procedure described in Prototypical Networks for Few-shot Learning (Snell et al., 2017).
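Training and evaluation in Snell et al. are episodic: each episode picks N classes, then K support examples and a number of query examples from each of those classes. A minimal sketch of that sampling follows, assuming the dataset is a dict from class label to its examples. The actual values of N, K and the query count are set in the config directory and are not reproduced here, and sample_episode is a hypothetical name.

```python
import random
from typing import Dict, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def sample_episode(
    dataset: Dict[str, Sequence[T]],
    n_way: int,
    k_shot: int,
    n_query: int,
    rng: random.Random,
) -> Tuple[Dict[str, List[T]], Dict[str, List[T]]]:
    """Sample an N-way, K-shot episode whose support and query sets share the same classes."""
    classes = rng.sample(sorted(dataset), n_way)  # sorted keys keep runs reproducible for a seed
    support: Dict[str, List[T]] = {}
    query: Dict[str, List[T]] = {}
    for label in classes:
        # Draw without replacement so no example appears in both sets;
        # raises ValueError if a class has fewer than k_shot + n_query examples.
        examples = rng.sample(list(dataset[label]), k_shot + n_query)
        support[label] = examples[:k_shot]
        query[label] = examples[k_shot:]
    return support, query
```

The support set is used to build the prototypes, and the query set is classified against them to compute the episode's loss or accuracy.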

Results

The results obtained with this implementation are comparable to those obtained with the original implementation.

You can check my execution logs and trained models here.

Acknowledgements

This project was based on:

The idea of Prototypical Networks (PNs) was originally introduced in Prototypical Networks for Few-shot Learning.

Using weights to compute the prototypes is an idea that can be found in the paper Improved Prototypical Networks for Few-Shot Learning.
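A weighted prototype replaces the plain mean with a weighted average of a class's support embeddings. The sketch below takes one user-supplied weight per support example and normalizes the weights to sum to one; that normalization is an assumption, and how this repository or the Improved Prototypical Networks paper obtains and scales the weights is not shown here. The function name is hypothetical.

```python
from typing import List, Sequence

Vector = List[float]


def weighted_prototype(embeddings: Sequence[Vector], weights: Sequence[float]) -> Vector:
    """Weighted average of one class's support embeddings, with weights normalized to sum to 1."""
    if len(embeddings) != len(weights):
        raise ValueError("need exactly one weight per support example")
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must have a positive sum")
    normalized = [w / total for w in weights]
    # For each embedding dimension, combine the support values using the normalized weights.
    return [sum(w * x for w, x in zip(normalized, components)) for components in zip(*embeddings)]
```

With equal weights this reduces to the ordinary mean used by vanilla Prototypical Networks: weighted_prototype([[0.0, 0.0], [4.0, 8.0]], [1.0, 1.0]) gives [2.0, 4.0], while weights [3.0, 1.0] pull the prototype toward the first example, giving [1.0, 2.0].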
