- Python3 / PyTorch code for multi-class image classification
- You can obtain the logits for all images (--store_logits) and the confusion matrix (--store_confusion_matrix) of your current model.
- You can manage your experiments with this code. It can store loss/accuracy changes during training (--store_loss_acc_log), and it can save the weight files of the best model found during validation (--store_weights).
- You can easily change the loss function used to train your model. (See the tips below!)
- You can monitor training progress with tensorboard.
- You do not need to compute the mean/std of the training images yourself: use --auto_mean_std.
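For readers curious about what --auto_mean_std saves you from doing, here is a minimal sketch of computing per-channel mean and standard deviation over training images in a single streaming pass (running sums of x and x², so not all images must be in memory). It assumes images arrive as H x W x 3 arrays scaled to [0, 1]; the helper name channel_mean_std is hypothetical, and the repository's own implementation may differ (for example, it may compute statistics after resizing).

```python
import numpy as np


def channel_mean_std(images):
    """Per-channel mean/std over an iterable of H x W x 3 arrays in [0, 1]."""
    total = np.zeros(3)
    total_sq = np.zeros(3)
    num_pixels = 0
    for img in images:
        pixels = np.asarray(img, dtype=np.float64).reshape(-1, 3)  # (H*W, 3)
        total += pixels.sum(axis=0)
        total_sq += (pixels ** 2).sum(axis=0)
        num_pixels += pixels.shape[0]
    mean = total / num_pixels
    # Var = E[x^2] - E[x]^2; clip tiny negative values caused by rounding.
    var = np.maximum(total_sq / num_pixels - mean ** 2, 0.0)
    return mean, np.sqrt(var)


images = [np.zeros((2, 2, 3)), np.ones((2, 2, 3))]
mean, std = channel_mean_std(images)
print(mean, std)
```

The resulting values are the kind of numbers one would pass to torchvision.transforms.Normalize(mean=..., std=...).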
- See requirements.txt for details.
torch
torchvision
matplotlib
scikit-learn
tqdm # not mandatory but recommended
tensorboard # not mandatory but recommended
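The usual route is pip install -r requirements.txt. As an optional sanity check, the sketch below reports which of the listed packages cannot be imported. Note that scikit-learn is imported under the module name sklearn. The script and its names (REQUIRED, OPTIONAL, missing_packages) are hypothetical helpers, not part of the repository.

```python
import importlib.util

# pip package name -> importable module name
REQUIRED = {"torch": "torch", "torchvision": "torchvision",
            "matplotlib": "matplotlib", "scikit-learn": "sklearn"}
OPTIONAL = {"tqdm": "tqdm", "tensorboard": "tensorboard"}


def missing_packages(packages):
    """Return the pip names whose modules cannot be found on this system."""
    return [pip_name for pip_name, module in packages.items()
            if importlib.util.find_spec(module) is None]


if __name__ == "__main__":
    print("missing required:", missing_packages(REQUIRED) or "none")
    print("missing optional:", missing_packages(OPTIONAL) or "none")
```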
- Your dataset directory should be structured as follows. (You can use our toy example: unzip cifar10_dummy.zip.)
|—— 📁 your_own_dataset
    |—— 📁 train
        |—— 📁 class_1
            |—— 🖼️ 1.jpg (Available file extensions: *.jpeg, *.jpg, *.png, *.bmp)
            |—— ...
        |—— 📁 class_2
            |—— 🖼️ ...
    |—— 📁 valid
        |—— 📁 class_1
        |—— 📁 ...
    |—— 📁 test
        |—— 📁 class_1
        |—— 📁 ...
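Before training, it can help to confirm that a dataset follows the train/valid/test layout with one folder per class. The sketch below (a hypothetical helper, not part of this repository) counts the images per class in each split using the listed file extensions. It matches extensions case-sensitively, which is an assumption about how names like 1.JPG are treated.

```python
from pathlib import Path

# Extensions listed in the README (matched case-sensitively here).
IMAGE_EXTENSIONS = {".jpeg", ".jpg", ".png", ".bmp"}


def count_images(dataset_dir):
    """Return {split: {class_name: number_of_images}} for train/valid/test."""
    counts = {}
    for split in ("train", "valid", "test"):
        split_dir = Path(dataset_dir) / split
        if not split_dir.is_dir():
            continue  # a missing split simply does not appear in the result
        counts[split] = {
            class_dir.name: sum(1 for f in class_dir.iterdir()
                                if f.is_file() and f.suffix in IMAGE_EXTENSIONS)
            for class_dir in sorted(split_dir.iterdir())
            if class_dir.is_dir()
        }
    return counts


if __name__ == "__main__":
    for split, per_class in count_images("./cifar10_dummy").items():
        print(split, per_class)
```

An empty class folder or a missing split shows up immediately in the printed counts.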
- Check __init__.py. You might need to modify some variables and add things (transformations, optimizers, lr_scheduler, ...). 💁Tip: You can add your own loss function as follows:
...
def get_loss_function(loss_function_name, device):
    ...
    elif loss_function_name == 'your_own_function_name':  # add this branch
        return your_own_function()
    ...
...
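As a concrete illustration of the tip above, the sketch below defines focal loss (a swapped-in example, not a loss this README says the repository ships) as an nn.Module and wires it into a simplified, hypothetical stand-in for get_loss_function. The name 'focal' and the gamma=2.0 default are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Multi-class focal loss: mean over the batch of -(1 - p_t)^gamma * log(p_t)."""

    def __init__(self, gamma: float = 2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        log_probs = F.log_softmax(logits, dim=1)                       # (N, C)
        log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # (N,) log-prob of true class
        pt = log_pt.exp()
        loss = -((1.0 - pt) ** self.gamma) * log_pt  # down-weights easy examples
        return loss.mean()


def get_loss_function(loss_function_name, device):
    """Simplified, hypothetical stand-in for the repository's function."""
    if loss_function_name == 'CE':
        return nn.CrossEntropyLoss().to(device)
    elif loss_function_name == 'focal':  # the added branch
        return FocalLoss(gamma=2.0).to(device)
    raise ValueError(f'unknown loss function: {loss_function_name}')
```

With gamma=0 the expression reduces to ordinary softmax cross-entropy, which is a handy check that a custom loss is wired up correctly before passing its name via --loss_function.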
- Run train.py for training. An example is shown below. See src/my_utils/parser.py for details. 💁Tip: --loss_function='CE' selects softmax cross-entropy (the default) as your loss.
python train.py --network_name='resnet34_for_tiny' --dataset_dir='./cifar10_dummy' \
--batch_size=256 --epochs=5 \
--lr=0.1 --lr_step='[60, 120, 160]' --lr_step_gamma=0.5 --lr_warmup_epochs=5 \
--auto_mean_std --store_weights --store_loss_acc_log --store_logits --store_confusion_matrix \
--loss_function='your_own_function_name' --transform_list_name='CIFAR' --tag='train-001'
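The example combines --lr, --lr_step, --lr_step_gamma and --lr_warmup_epochs. The sketch below shows one common way such flags are interpreted: a linear warmup over the first lr_warmup_epochs epochs, then multiplying the rate by lr_step_gamma at each milestone epoch listed in lr_step. This is an assumption, not the repository's confirmed behavior; check src/my_utils/parser.py and the training code for the actual schedule. The function name lr_at_epoch is hypothetical, and epochs are counted from 0.

```python
import ast


def lr_at_epoch(epoch, base_lr=0.1, lr_step="[60, 120, 160]",
                lr_step_gamma=0.5, lr_warmup_epochs=5):
    """Hypothetical warmup + step-decay schedule (0-indexed epochs)."""
    # The CLI passes milestones as a string such as '[60, 120, 160]'.
    milestones = ast.literal_eval(lr_step)
    if epoch < lr_warmup_epochs:
        # Linear warmup: base_lr / W, 2 * base_lr / W, ..., base_lr.
        return base_lr * (epoch + 1) / lr_warmup_epochs
    # Multiply by lr_step_gamma once for every milestone already reached.
    num_decays = sum(1 for m in milestones if epoch >= m)
    return base_lr * lr_step_gamma ** num_decays


for epoch in (0, 4, 5, 60, 120, 160):
    print(epoch, lr_at_epoch(epoch))
```

Under these assumptions, the example run with --epochs=5 and --lr_warmup_epochs=5 spends all of its epochs in warmup and never reaches the milestones; longer runs are needed for the step decay to take effect.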
- Run test.py for testing. An example is shown below. See src/my_utils/parser.py for details.
python test.py --network_name='resnet34_for_tiny' --dataset_dir='./cifar10_dummy' \
--batch_size=256 --auto_mean_std --store_logits --store_confusion_matrix \
--checkpoint='your_pretrained_model_weights.pt'
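--store_logits and --store_confusion_matrix save these outputs for you; the file format they use is not described here. The sketch below simply assumes the logits are available as an N x C array (one row per image) alongside integer labels, and shows how predictions, a confusion matrix, and accuracy follow from them. The helper name confusion_from_logits is hypothetical. sklearn.metrics.confusion_matrix(labels, preds) from scikit-learn (already in the requirements) uses the same convention: rows are true classes, columns are predicted classes.

```python
import numpy as np


def confusion_from_logits(logits, labels, num_classes):
    """Rows are true classes, columns are predicted classes."""
    preds = np.asarray(logits).argmax(axis=1)       # predicted class per image
    labels = np.asarray(labels, dtype=np.int64)
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (labels, preds), 1)               # count each (true, pred) pair
    return cm


logits = [[2.0, 1.0, 0.0],
          [0.0, 3.0, 1.0],
          [1.0, 0.0, 0.5],
          [0.0, 0.0, 5.0]]
labels = [0, 1, 2, 2]
cm = confusion_from_logits(logits, labels, num_classes=3)
print(cm)
print("accuracy:", np.trace(cm) / cm.sum())
```

Correct predictions sit on the diagonal, so accuracy is the trace divided by the total count.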
- You can add your own classification models. See src/my_models and src/model.py for details.
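How src/model.py maps --network_name to a model is not shown here, so the following is only a hypothetical sketch: a small nn.Module classifier for 3-channel images plus a name-to-constructor dictionary of the kind such a lookup could use. TinyCNN, MODELS and build_model are illustrative names, not the repository's API.

```python
import torch
from torch import nn


class TinyCNN(nn.Module):
    """Hypothetical minimal classifier for 3-channel images of any size."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),  # global average pooling -> (N, 64, 1, 1)
        )
        self.classifier = nn.Linear(64, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        return self.classifier(torch.flatten(x, 1))  # logits of shape (N, num_classes)


# Hypothetical registry: network name -> constructor.
MODELS = {"tiny_cnn": TinyCNN}


def build_model(network_name, num_classes):
    if network_name not in MODELS:
        raise ValueError(f"unknown network_name: {network_name}")
    return MODELS[network_name](num_classes=num_classes)
```

Because of the adaptive pooling layer, the same model accepts CIFAR-sized 32 x 32 inputs as well as larger images.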
- If you install tqdm, you can monitor training progress with a progress bar.
- If you install tensorboard, you can view plots of accuracy/loss changes and confusion matrices during training. (Run tensorboard --logdir='./runs' in your command shell.)
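The repository already writes these plots for you. For readers curious how such logs are produced, here is a hedged sketch using PyTorch's torch.utils.tensorboard.SummaryWriter. The helper log_epoch, the tag names, and the dummy metric values are assumptions and need not match what this code actually logs.

```python
import os
import tempfile

from torch.utils.tensorboard import SummaryWriter


def log_epoch(writer, epoch, train_loss, train_acc, valid_loss, valid_acc):
    """Write one epoch of scalars; the tag names are illustrative."""
    writer.add_scalar("loss/train", train_loss, epoch)
    writer.add_scalar("loss/valid", valid_loss, epoch)
    writer.add_scalar("acc/train", train_acc, epoch)
    writer.add_scalar("acc/valid", valid_acc, epoch)


# In practice the log directory would sit under ./runs so that
# `tensorboard --logdir='./runs'` picks it up; a temp dir keeps this demo clean.
log_dir = os.path.join(tempfile.mkdtemp(), "runs", "demo")
writer = SummaryWriter(log_dir=log_dir)
for epoch in range(3):
    log_epoch(writer, epoch,
              train_loss=1.0 / (epoch + 1), train_acc=0.5 + 0.1 * epoch,
              valid_loss=1.2 / (epoch + 1), valid_acc=0.45 + 0.1 * epoch)
writer.close()  # flush events to disk
```

Grouping tags with a shared prefix such as "loss/" makes TensorBoard display the train and valid curves in the same section.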
If you use this code for your research, please cite the following paper:
@article{kim2021imbalanced,
title={Imbalanced image classification with complement cross entropy},
author={Kim, Yechan and Lee, Younkwan and Jeon, Moongu},
journal={Pattern Recognition Letters},
volume={151},
pages={33--40},
year={2021},
publisher={Elsevier}
}
🐛 If you find any bugs or have suggestions for improvement, feel free to contact me (yechankim@gm.gist.ac.kr). All contributions are welcome.