
hamidriasat/BASNet



Boundary-Aware Segmentation Network for Mobile and Web Applications

This repository contains an implementation of BASNet in TensorFlow/Keras.

Note: We are looking for collaborators with enough compute power to help us train the full model.

If you are interested, please contact us at hamidriasat@gmail.com.

Code Structure

  • images: Model architecture images
  • sample_data: Sample images for visualization
  • weights: Model checkpoint directory
  • basnet.py: Contains model implementation
  • basnet_prediction.ipynb: Notebook to visualize model output
  • basnet_training.ipynb: Notebook to train model
  • dataloader.py: Dataloader to efficiently load data into memory
  • loss.py: Implementation of BASNet hybrid loss function
  • utils.py: Generic utility functions
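The BASNet paper defines its hybrid loss as the sum of three terms: binary cross-entropy (pixel level), SSIM (patch level) and IoU (map level). The sketch below illustrates that sum in plain Python on flattened probability maps so it runs without TensorFlow. It is not the code in loss.py: the function names are hypothetical, and ssim_loss is a simplification that treats the whole map as a single window instead of the local windows SSIM normally uses.

```python
import math

EPS = 1e-7  # small constant to avoid log(0) and division by zero


def bce_loss(pred, target):
    """Mean binary cross-entropy between predicted probabilities and a 0/1 mask."""
    total = 0.0
    for p, g in zip(pred, target):
        p = min(max(p, EPS), 1.0 - EPS)  # clip probabilities away from 0 and 1
        total += -(g * math.log(p) + (1.0 - g) * math.log(1.0 - p))
    return total / len(pred)


def iou_loss(pred, target):
    """Soft IoU loss: 1 - sum(S*G) / sum(S + G - S*G).

    Note: for an empty mask with an empty prediction this returns ~1.
    """
    inter = sum(p * g for p, g in zip(pred, target))
    union = sum(p + g - p * g for p, g in zip(pred, target))
    return 1.0 - inter / (union + EPS)


def ssim_loss(pred, target, c1=0.01 ** 2, c2=0.03 ** 2):
    """1 - SSIM, simplified: statistics are computed over the whole map at once."""
    n = len(pred)
    mu_x = sum(pred) / n
    mu_y = sum(target) / n
    var_x = sum((p - mu_x) ** 2 for p in pred) / n
    var_y = sum((g - mu_y) ** 2 for g in target) / n
    cov = sum((p - mu_x) * (g - mu_y) for p, g in zip(pred, target)) / n
    ssim = ((2 * mu_x * mu_y + c1) * (2 * cov + c2)) / (
        (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    )
    return 1.0 - ssim


def hybrid_loss(pred, target):
    """BCE + SSIM + IoU on one predicted map (hypothetical helper)."""
    if len(pred) != len(target):
        raise ValueError("prediction and target must have the same number of pixels")
    return bce_loss(pred, target) + ssim_loss(pred, target) + iou_loss(pred, target)
```

A perfect prediction drives all three terms toward zero, while a flat 0.5 map is penalized by every term; in a real TensorFlow pipeline the SSIM term would typically use a windowed implementation such as tf.image.ssim.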

Training Data

As in the paper, we use the DUTS-TR dataset for training. It contains 10,553 images. The commands to download the data are in the training notebook.
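Before building an input pipeline, each training image has to be matched with its ground-truth mask. The sketch below pairs files by name stem. It assumes the extracted dataset uses DUTS-TR-Image/*.jpg for images and DUTS-TR-Mask/*.png for masks; that layout, and both function names, are assumptions for illustration and are not taken from dataloader.py.

```python
from pathlib import Path


def pair_by_stem(image_paths, mask_paths):
    """Match image and mask paths by file stem; fail loudly on any orphan file."""
    images = {Path(p).stem: Path(p) for p in image_paths}
    masks = {Path(p).stem: Path(p) for p in mask_paths}
    unmatched = images.keys() ^ masks.keys()  # stems present on only one side
    if unmatched:
        raise ValueError(
            f"{len(unmatched)} files have no counterpart, e.g. {sorted(unmatched)[:3]}"
        )
    return [(images[s], masks[s]) for s in sorted(images)]


def list_duts_tr_pairs(root):
    """Hypothetical helper: list (image, mask) pairs under an assumed DUTS-TR layout."""
    root = Path(root)
    return pair_by_stem(
        (root / "DUTS-TR-Image").glob("*.jpg"),
        (root / "DUTS-TR-Mask").glob("*.png"),
    )
```

If the layout assumption holds, len(list_duts_tr_pairs("DUTS-TR")) should match the 10,553 training images; sorting by stem keeps the pairing deterministic across runs.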

Training Settings

We trained the model on a Google Colab Pro+ plan using an A100 (40 GB) GPU for 100 epochs (~120k iterations) with a batch size of 8, which took almost 24 hours. The authors of the paper trained for 400k iterations. That is why our results are not as good as theirs, but they are sufficient to demonstrate that the model learns.

The training code is in basnet_training.ipynb, which can be opened in Google Colab.

Inference Demo

Model output can be visualized with basnet_prediction.ipynb, which can also be opened in Google Colab.

Pretrained weights are available at this Google Drive link. The commands to download them are in the prediction notebook.
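The prediction notebook is the reference for how outputs are displayed. As a hedged illustration of one common post-processing step for saliency maps, the sketch below min-max normalizes a predicted map to [0, 1] and thresholds it into a binary mask. The function name and the 0.5 threshold are assumptions, and the map is a plain nested list so the example runs without TensorFlow.

```python
def normalize_and_threshold(saliency, threshold=0.5):
    """Min-max normalize a 2-D saliency map, then binarize it at `threshold`.

    Returns (normalized_map, binary_mask). A constant map normalizes to zeros.
    """
    flat = [v for row in saliency for v in row]
    lo, hi = min(flat), max(flat)
    span = hi - lo
    if span == 0:
        norm = [[0.0 for _ in row] for row in saliency]
    else:
        norm = [[(v - lo) / span for v in row] for row in saliency]
    mask = [[1 if v >= threshold else 0 for v in row] for row in norm]
    return norm, mask
```

Normalizing first makes the threshold independent of the model's overall confidence level; with a tensor output, the same steps map directly onto element-wise min, max and comparison operations.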

Model Architecture

[Figure: BASNet model architecture]
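As described in the BASNet paper, the network follows a predict-refine design: an encoder-decoder predict module produces a coarse saliency map, and a residual refinement module learns the residual between that coarse map and the ground truth, so the final output is

$$
S_{\text{refined}} = S_{\text{coarse}} + S_{\text{residual}}
$$

Training uses deep supervision: the hybrid loss $\ell^{(k)}$ is applied to each of the $K$ side outputs and the total loss is their weighted sum,

$$
\mathcal{L} = \sum_{k=1}^{K} \alpha_k \, \ell^{(k)}
$$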

Dependencies:

  • KerasCV 0.5.0
  • TensorFlow 2.12.0 (any KerasCV-compatible version should work)

Licensed under the MIT License

