Simple-Image-Classification - Simple Image Classification Code (PyTorch)

Overview

Yechan Kim

This repository contains:

  • Python3 / PyTorch code for multi-class image classification

Prerequisites

  • See requirements.txt for details.
torch
torchvision
matplotlib
scikit-learn
tqdm            # not mandatory but recommended
tensorboard     # not mandatory but recommended

How to use

  1. Organize your dataset in the following directory structure. (You can use our toy example: unzip cifar10_dummy.zip.)
|—— 📁 your_own_dataset
	|—— 📁 train
		|—— 📁 class_1
			|—— 🖼️ 1.jpg
			|—— ...
		|—— 📁 class_2
			|—— 🖼️ ...
	|—— 📁 valid
		|—— 📁 class_1
		|—— 📁 ...
	|—— 📁 test
		|—— 📁 class_1
		|—— 📁 ...
  2. Check __init__.py. You might need to modify variables and add components (transformations, optimizers, lr_scheduler, ...). 💡 Tip: You can add your own loss function as follows:
...
def get_loss_function(loss_function_name, device):
    ... 
    elif loss_function_name == 'your_own_function_name':  # add +
        return your_own_function()
    ...
...
  3. Run train.py for training. An example is shown below. See src/my_utils/parser.py for details. 💡 Tip: --loss_function='CE' selects softmax cross-entropy (the default) as the loss.
python train.py --network_name='resnet34_for_tiny' --dataset_dir='./cifar10_dummy' \
--batch_size=256 --epochs=5  \
--lr=0.1 --lr_step='[60, 120, 160]' --lr_step_gamma=0.5 --lr_warmup_epochs=5 \
--auto_mean_std --store_weights --store_loss_acc_log --store_logits --store_confusion_matrix \
--loss_function='your_own_function_name' --transform_list_name='CIFAR' --tag='train-001'
  4. Run test.py for testing. An example is shown below. See src/my_utils/parser.py for details.
python test.py --network_name='resnet34_for_tiny' --dataset_dir='./cifar10_dummy' \
--auto_mean_std --store_logits --store_confusion_matrix \
--checkpoint='pretrained_model_weights.pt'
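
Before training, it can help to confirm that your dataset follows the layout from the first step. The sketch below is not part of this repository: the function name summarize_dataset and the IMAGE_SUFFIXES set are assumptions made for illustration. It counts images per class in each split and checks that train, valid, and test contain the same class folders.

```python
from pathlib import Path

# Assumed set of image extensions; adjust to match your data.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}
SPLITS = ("train", "valid", "test")


def summarize_dataset(dataset_dir):
    """Return {split: {class_name: image_count}} for the expected layout.

    Raises FileNotFoundError if a split folder is missing and ValueError
    if the splits do not contain the same class folders.
    """
    root = Path(dataset_dir)
    summary = {}
    for split in SPLITS:
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split folder: {split_dir}")
        # One entry per class folder, counting only files with image suffixes.
        summary[split] = {
            class_dir.name: sum(
                1
                for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
            )
            for class_dir in sorted(split_dir.iterdir())
            if class_dir.is_dir()
        }
    train_classes = set(summary["train"])
    for split in SPLITS[1:]:
        if set(summary[split]) != train_classes:
            raise ValueError(f"class folders in '{split}' differ from 'train'")
    return summary
```

For example, summarize_dataset('./cifar10_dummy') should return one dictionary per split once the toy dataset has been unzipped.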

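The learning-rate flags in the training example (--lr, --lr_step, --lr_step_gamma, --lr_warmup_epochs) can be easier to reason about with a small calculation. The sketch below is only an approximation under stated assumptions: it assumes a linear warmup that reaches the base rate at the last warmup epoch, followed by a multi-step decay with milestones given in absolute epochs. The actual behavior is defined in this repository's code and may differ; lr_at_epoch is a hypothetical helper, not a function from the repository.

```python
import ast
from bisect import bisect_right


def lr_at_epoch(epoch, base_lr, milestones, gamma, warmup_epochs=0):
    """Learning rate for a 0-indexed epoch under the assumed schedule."""
    if epoch < warmup_epochs:
        # Assumed linear warmup: reaches base_lr at the last warmup epoch.
        return base_lr * (epoch + 1) / warmup_epochs
    # Multi-step decay: multiply by gamma once per milestone already reached.
    return base_lr * gamma ** bisect_right(milestones, epoch)


# Values taken from the training example; the list is parsed from its string form.
milestones = ast.literal_eval("[60, 120, 160]")
for epoch in (0, 4, 59, 60, 120, 160):
    print(epoch, lr_at_epoch(epoch, 0.1, milestones, 0.5, warmup_epochs=5))
```

Under these assumptions, a run with --epochs=5 and --lr_warmup_epochs=5 would spend every epoch in the warmup phase, so the milestones in --lr_step would only take effect in longer runs.
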
Trailer

  1. If you install tqdm, you can check the progress of training.

  2. If you install tensorboard, you can check the acc/loss changes and confusion matrices during training.
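
To make the link between stored logits and a confusion matrix concrete, here is a minimal pure-Python sketch. It is not the repository's implementation, and the on-disk format of the stored logits is not assumed here: the function simply takes one list of class scores per sample plus the true labels. Rows index the true class and columns the predicted class, the same convention scikit-learn uses.

```python
def confusion_matrix_from_logits(logits, labels, num_classes):
    """Build a confusion matrix: rows are true classes, columns are predictions."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for row, true_class in zip(logits, labels):
        # The predicted class is the index of the largest logit (argmax).
        predicted = max(range(num_classes), key=lambda c: row[c])
        matrix[true_class][predicted] += 1
    return matrix


def accuracy(matrix):
    """Fraction of samples on the diagonal (correct predictions)."""
    total = sum(sum(row) for row in matrix)
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    return correct / total if total else 0.0


# Toy values for illustration only.
toy_logits = [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.9, 0.2, 0.4], [0.1, 0.2, 3.0]]
toy_labels = [0, 1, 2, 2]
toy_matrix = confusion_matrix_from_logits(toy_logits, toy_labels, num_classes=3)
print(toy_matrix, accuracy(toy_matrix))
```

Off-diagonal entries show which classes are confused with each other; in the toy data, one sample of class 2 is predicted as class 0.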

Contribution

🐛 If you find any bugs or have suggestions for further improvements, feel free to contact me ([email protected]). All contributions are welcome.

Reference

  1. https://github.com/weiaicunzai/pytorch-cifar100
  2. https://medium.com/@djin31/how-to-plot-wholesome-confusion-matrix-40134fd402a8 (Confusion Matrix)
  3. https://pytorch.org/ignite/generated/ignite.handlers.param_scheduler.create_lr_scheduler_with_warmup.html
Owner
Yechan Kim
GIST, Machine Learning and Vision Lab.