A best-practice architecture for TensorFlow project templates.

Overview

Tensorflow Project Template

A simple and well-designed structure is essential for any deep learning project. After a lot of practice and contributing to TensorFlow projects, here's a TensorFlow project template that combines simplicity, best practices for folder structure, and good OOP design. The main idea is that there is a lot of work you repeat every time you start a TensorFlow project. Wrapping all this shared work lets you change only the core idea each time you start a new one.

So here's a simple TensorFlow template that helps you get into your main project faster and focus on its core (model, training, etc.).


In a Nutshell

In a nutshell, here's how to use this template. For example, assume you want to implement a VGG model. You would do the following:

  • In the models folder, create a class named VGGModel that inherits from the "BaseModel" class.
    class VGGModel(BaseModel):
        def __init__(self, config):
            super(VGGModel, self).__init__(config)
            # call the build_model and init_saver functions.
            self.build_model()
            self.init_saver()
  • Override these two functions: "build_model", where you implement the VGG model, and "init_saver", where you define a TensorFlow saver. Then call them in the initializer.
    def build_model(self):
        # here you build the TensorFlow graph of the model you want and define the loss.
        pass

    def init_saver(self):
        # here you initialize the TensorFlow saver that will be used to save the checkpoints.
        self.saver = tf.train.Saver(max_to_keep=self.config.max_to_keep)
  • In the trainers folder, create a VGG trainer that inherits from the "BaseTrain" class.
    class VGGTrainer(BaseTrain):
        def __init__(self, sess, model, data, config, logger):
            super(VGGTrainer, self).__init__(sess, model, data, config, logger)
  • Override these two functions, "train_step" and "train_epoch", where you write the logic of the training process.
    def train_epoch(self):
        """
        Implement the logic of an epoch:
        - loop over the number of iterations in the config and call the train step
        - add any summaries you want using the logger
        """
        pass

    def train_step(self):
        """
        Implement the logic of a train step:
        - run the TensorFlow session
        - return any metrics you need to summarize
        """
        pass
  • In the main file, create the session and instances of the following objects: "Model", "Logger", "DataGenerator", "Trainer", and the config.
    sess = tf.Session()
    # create an instance of the model you want
    model = VGGModel(config)
    # create your data generator
    data = DataGenerator(config)
    # create the TensorBoard logger
    logger = Logger(sess, config)
  • Pass all these objects to the trainer object, and start training by calling "trainer.train()".
    trainer = VGGTrainer(sess, model, data, config, logger)

    # here you train your model
    trainer.train()

You will find a template file and a simple example in the models and trainers folders that show you how to try your first model.

In Details

Project architecture

Folder structure

├── base
│   ├── base_model.py   - this file contains the abstract class of the model.
│   └── base_train.py   - this file contains the abstract class of the trainer.
│
├── models              - this folder contains any model of your project.
│   └── example_model.py
│
├── trainers            - this folder contains the trainers of your project.
│   └── example_trainer.py
│
├── mains               - here's the main(s) of your project (you may need more than one main).
│   └── example_main.py  - here's an example of main that is responsible for the whole pipeline.
│
├── data_loader
│   └── data_generator.py  - here's the data_generator that is responsible for all data handling.
│
└── utils
     ├── logger.py
     └── any_other_utils_you_need
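As a starting point, the folder structure above could be created with a small script. This is a minimal sketch and is not part of the template. The `scaffold` function and the `LAYOUT` mapping are hypothetical names. The empty `__init__.py` files are an assumption added so that each folder becomes a regular Python package.

```python
from pathlib import Path

# Folders and files from the tree above. The empty __init__.py files are an
# assumption of this sketch: they turn each folder into a regular Python package.
LAYOUT = {
    "base": ["base_model.py", "base_train.py"],
    "models": ["example_model.py"],
    "trainers": ["example_trainer.py"],
    "mains": ["example_main.py"],
    "data_loader": ["data_generator.py"],
    "utils": ["logger.py"],
}


def scaffold(root):
    """Create the template folders and empty module files under `root`."""
    root = Path(root)
    for package, modules in LAYOUT.items():
        package_dir = root / package
        package_dir.mkdir(parents=True, exist_ok=True)
        (package_dir / "__init__.py").touch()
        for module in modules:
            # touch() creates missing files and leaves existing contents intact.
            (package_dir / module).touch()
    return root
```

Running a main as a module from the project root (for example `python -m mains.example_main`) puts the root on `sys.path`. That is one way to let sibling imports such as `from data_loader.data_generator import DataGenerator` resolve.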

Main Components

Models


  • Base model

    Base model is an abstract class that must be inherited by any model you create. The idea is that all models share a lot of code. The base model contains:

    • Save: a function to save a checkpoint to disk.
    • Load: a function to load a checkpoint from disk.
    • Cur_epoch, Global_step counters: variables that keep track of the current epoch and the global step.
    • Init_Saver: an abstract function to initialize the saver used for saving and loading checkpoints. Note: override this function in the model you implement.
    • Build_model: an abstract function to define the model. Note: override this function in the model you implement.
  • Your model

    Here's where you implement your model. You should:

    • Create your model class and inherit from the BaseModel class.
    • Override "build_model", where you write the TensorFlow model you want.
    • Override "init_saver", where you create a TensorFlow saver to save and load checkpoints.
    • Call "build_model" and "init_saver" in the initializer.
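The base-model contract can be sketched without TensorFlow. This is an illustration under stated assumptions, not the template's code:

- The counters are plain Python integers rather than TensorFlow variables.
- `ToySaver` and `ToyModel` are hypothetical stand-ins for `tf.train.Saver` and a real model.
- `abc.abstractmethod` replaces the template's `raise NotImplementedError` stubs.

```python
from abc import ABC, abstractmethod


class BaseModel(ABC):
    """Shared model plumbing; subclasses supply build_model and init_saver."""

    def __init__(self, config):
        self.config = config
        # Plain-Python stand-ins for the template's cur_epoch / global_step counters.
        self.cur_epoch = 0
        self.global_step = 0
        self.saver = None

    def save(self, checkpoint_dir):
        """Save a checkpoint through the saver created by init_saver()."""
        if self.saver is None:
            raise RuntimeError("init_saver() must run before save()")
        return self.saver.save(checkpoint_dir, self.global_step)

    @abstractmethod
    def build_model(self):
        """Define the model and its loss."""

    @abstractmethod
    def init_saver(self):
        """Create self.saver, used to save and load checkpoints."""


class ToySaver:
    """Hypothetical in-memory saver that keeps the newest max_to_keep paths."""

    def __init__(self, max_to_keep):
        self.max_to_keep = max_to_keep
        self.checkpoints = []

    def save(self, checkpoint_dir, step):
        path = f"{checkpoint_dir}/model-{step}"
        self.checkpoints = (self.checkpoints + [path])[-self.max_to_keep:]
        return path


class ToyModel(BaseModel):
    def __init__(self, config):
        super().__init__(config)
        # As in the template: build the model, then create the saver.
        self.build_model()
        self.init_saver()

    def build_model(self):
        self.layers = ["conv", "fc"]  # placeholder for real layer definitions

    def init_saver(self):
        self.saver = ToySaver(max_to_keep=self.config.max_to_keep)
```

With `abc`, forgetting to override `build_model` or `init_saver` raises a `TypeError` when the model is constructed, instead of an error at the first call.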

Trainer


  • Base trainer

    Base trainer is an abstract class that wraps the training process.

  • Your trainer

    Here's what you should implement in your trainer:

    1. Create your trainer class and inherit from the BaseTrain class.
    2. Override these two functions, "train_step" and "train_epoch", where you implement the training process of each step and each epoch.
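The division of labour between the base trainer and your trainer might look like the framework-free sketch below. It makes several assumptions:

- The TensorFlow session is left out.
- The model exposes plain `cur_epoch` and `global_step` attributes.
- `num_epochs` and `num_iter_per_epoch` are assumed config keys.
- `CountingTrainer`, with its shrinking stand-in loss, is purely hypothetical.

```python
from abc import ABC, abstractmethod


class BaseTrain(ABC):
    """Shared training loop; subclasses supply train_epoch and train_step."""

    def __init__(self, model, data, config, logger=None):
        self.model = model
        self.data = data
        self.config = config
        self.logger = logger

    def train(self):
        # Start from the model's epoch counter so a restored model resumes training.
        for epoch in range(self.model.cur_epoch, self.config.num_epochs):
            self.train_epoch()
            self.model.cur_epoch = epoch + 1

    @abstractmethod
    def train_epoch(self):
        """Loop over the iterations of one epoch and report summaries."""

    @abstractmethod
    def train_step(self):
        """Run one training step and return its metrics."""


class CountingTrainer(BaseTrain):
    """Hypothetical trainer with a stand-in loss that shrinks every step."""

    def train_epoch(self):
        losses = [self.train_step() for _ in range(self.config.num_iter_per_epoch)]
        mean_loss = sum(losses) / len(losses)
        if self.logger is not None:
            # Here the "logger" is just a list collecting (global_step, mean_loss).
            self.logger.append((self.model.global_step, mean_loss))

    def train_step(self):
        self.model.global_step += 1
        return 1.0 / self.model.global_step
```

Because `train()` starts from `model.cur_epoch`, a model restored from a checkpoint resumes where it stopped instead of starting over.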

Data Loader

This class is responsible for all data handling and processing. It provides an easy interface that the trainer can use.
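A data generator with that kind of interface could be as small as the sketch below. It rests on several assumptions:

- The data fits in memory as NumPy arrays.
- The arrays are passed in directly, whereas the template passes the config.
- The class name matches the one used in main, but `next_batch` and its sampling without replacement are choices of this sketch.

```python
import numpy as np


class DataGenerator:
    """Serves random mini-batches from in-memory input/target arrays."""

    def __init__(self, inputs, targets, seed=None):
        if len(inputs) != len(targets):
            raise ValueError("inputs and targets must have the same number of samples")
        self.inputs = np.asarray(inputs)
        self.targets = np.asarray(targets)
        self.rng = np.random.default_rng(seed)

    def next_batch(self, batch_size):
        """Return one (inputs, targets) batch sampled without replacement."""
        idx = self.rng.choice(len(self.inputs), size=batch_size, replace=False)
        # Indexing both arrays with the same idx keeps inputs and targets aligned.
        return self.inputs[idx], self.targets[idx]
```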

Logger

This class is responsible for the TensorBoard summaries. In your trainer, create a dictionary of all the values you want to summarize, then pass this dictionary to logger.summarize().

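The dictionary-based interface might be sketched as follows. To stay runnable without TensorFlow, this stand-in records scalars in memory instead of writing TensorBoard event files. The `history` attribute and the `summarizer` argument are assumptions of the sketch. The optional Comet.ml branch uses the `comet_ml` package's `Experiment` class and runs only when the config has a `comet_api_key`.

```python
from collections import defaultdict


class Logger:
    """In-memory stand-in for the TensorBoard logger; records (step, value) per tag."""

    def __init__(self, config):
        self.history = defaultdict(list)
        self.experiment = None
        # Optional Comet.ml reporting, enabled only when the config has an API key.
        api_key = getattr(config, "comet_api_key", None)
        if api_key:
            from comet_ml import Experiment  # requires the comet_ml package
            self.experiment = Experiment(
                api_key=api_key, project_name=getattr(config, "exp_name", None)
            )

    def summarize(self, step, summarizer="train", summaries_dict=None):
        """Record every value in summaries_dict under '<summarizer>/<tag>'."""
        summaries_dict = summaries_dict or {}
        for tag, value in summaries_dict.items():
            self.history[f"{summarizer}/{tag}"].append((step, float(value)))
        if self.experiment is not None:
            self.experiment.log_metrics(summaries_dict, step=step)
```

In a trainer, an epoch would end with something like `logger.summarize(step, summaries_dict={"loss": loss, "acc": acc})`.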

Comet.ml Integration

This template also supports reporting to Comet.ml, which lets you see all your hyperparameters, metrics, graphs, dependencies and more, including real-time metrics.

Add your API key in the configuration file:

For example: "comet_api_key": "your key here"


You can also link your GitHub repository to your Comet.ml project for full version control. A live page showing the example from this repo is also available on Comet.ml.

Configuration

I use JSON as the configuration method. Write all the configs you want in a JSON file, parse it using "process_config" from "utils/config.py", and pass the resulting configuration object to all other objects.
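A possible shape for the config utilities is sketched below. The template uses the third-party `bunch` package for attribute-style access; this sketch swaps in the standard library's `types.SimpleNamespace`. Only the name `process_config` comes from the text above. `get_config_from_json` and the `experiments/<exp_name>/...` output layout are assumptions.

```python
import json
import os
from types import SimpleNamespace


def get_config_from_json(json_file):
    """Read a JSON file and return (attribute-style config, raw dict)."""
    with open(json_file, "r") as f:
        config_dict = json.load(f)
    return SimpleNamespace(**config_dict), config_dict


def process_config(json_file):
    """Load the config and derive per-experiment output folders from exp_name."""
    config, _ = get_config_from_json(json_file)
    # Assumed layout: experiments/<exp_name>/summary/ and experiments/<exp_name>/checkpoint/
    config.summary_dir = os.path.join("experiments", config.exp_name, "summary", "")
    config.checkpoint_dir = os.path.join("experiments", config.exp_name, "checkpoint", "")
    return config
```

A config file for this sketch could be as simple as `{"exp_name": "vgg_baseline", "num_epochs": 10, "max_to_keep": 5}`. Every field is then available as an attribute, for example `config.max_to_keep`.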

Main

Here's where you combine all the previous parts:

  1. Parse the config file.
  2. Create a tensorflow session.
  3. Create instances of "Model", "DataGenerator" and "Logger", and pass the config to all of them.
  4. Create an instance of "Trainer" and pass all previous objects to it.
  5. Now you can train your model by calling "Trainer.train()"
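Steps 1 to 5 can be sketched without TensorFlow by passing the concrete classes in as factories. This is a hypothetical arrangement, not the template's main. `get_args`, `run_pipeline` and the `-c/--config` flag are assumptions, and the session is left out.

```python
import argparse


def get_args(argv=None):
    """Parse the command line; the only argument is the path to the JSON config."""
    parser = argparse.ArgumentParser(description="Train a model from a JSON config.")
    parser.add_argument("-c", "--config", required=True, help="path to the JSON configuration file")
    return parser.parse_args(argv)


def run_pipeline(config, make_model, make_data, make_logger, make_trainer):
    """Steps 3-5: build the components from the config, hand them to the trainer, train."""
    model = make_model(config)
    data = make_data(config)
    logger = make_logger(config)
    trainer = make_trainer(model, data, config, logger)
    trainer.train()
    return trainer
```

Making `--config` required means argparse itself prints a usage error when the flag is missing, rather than the program failing later with a generic message.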

Future Work

  • Replace the data loader part with the new TensorFlow Dataset API (tf.data).

Contributing

Any kind of enhancement or contribution is welcome.

Acknowledgments

Thanks to my colleague Mo'men Abdelrazek for contributing to this work, to Mohamed Zahran for the review, and to Jtoy for including the repo in Awesome Tensorflow.

Issues
  • Loaded variables reinitialization

    Hello! I have a little concern about the model.load() position in main(). Suppose we are in main(). First, you load the model (model.load()) and then initialize variables in base train (via "self.sess.run(self.init)" in "trainer = ExampleTrainer(...)"). But in this case all the variables from the loaded model will be reinitialized with default values. If I'm not mistaken about my concern, the solution is simple - move model.load() right after the trainer = ExampleTrainer(...) in main(). It helped me while implementing your magnificent template for my project. Thanks.

    opened by marinadec
  • logger for weights

    I'm trying to use the logger to get a histogram of weights on my TensorBoard. Any chance you'd be interested in adding that? I'll see if I can pull it off and open a PR, but if not... :)

    opened by mobbSF
  • why Trainers?

    My question is not about the implementation details. It is more of a strategic question. I want to know the advantage of using a trainer class instead of incorporating the training in the model class as a method. Thank you for this very clean template!

    opened by mouradyahia
  • validation cycles

    Thank you for providing this beautiful template.

    I was wondering where I'd ideally put validation cycles, say every 1000 iterations, in this structure. Any recommendations?

    opened by ferreirafabio
  • What is the point of "cur_epoch_tensor" and "global_step_tensor"?

    Why is keeping track of epoch and step counts implemented with TensorFlow Variables instead of plain Python variables?

    It seems that it is not very efficient to store them like that, considering that they will probably be placed on GPU with other tf variables.

    Is there any use case I'm missing?

    opened by WrathOfGrapes
  • how to make import works

    Hi, I have tried running example.py, but I get this error:

    Traceback (most recent call last):
      File "example.py", line 3, in <module>
        from data_loader.data_generator import DataGenerator
    ModuleNotFoundError: No module named 'data_loader'
    

    How do you guys make the sibling package imports work? Are you using IDE to help you do that?

    opened by zhedongzheng
  • Estimator

    Hi,

    I was looking for a nice and clean structure for TF projects, and yours comes as number 1. What I miss is the testing/evaluation/inference part. Why haven't you implemented it? Are you planning to put something up?

    Regards

    opened by mkravchik
  • Adding support for comet.ml

    Hi! This repo is great and has been super useful for us. This PR adds optional support for Comet.ml in the Logger class. This is how it looks on Comet.ml once an experiment is running:

    https://www.comet.ml/gidim/example/83e8f37f484b450d805cdcaabe0a3e4e

    Let me know if you want me to make any changes. We can add a screenshot like they did in this repo: https://github.com/Ahmkel/Keras-Project-Template

    opened by gidim
  • About the `init_saver` of base_model?

    Why not use the implementation in the comment directly?

        def init_saver(self):
            # just copy the following line in your child class
            # self.saver = tf.train.Saver(max_to_keep=self.config.max_to_keep)
            raise NotImplementedError
    
    opened by imhuay
  • from bunch import Bunch

    Hi Mr Gemy95, thanks for your tensorflow-template. In the sample "utils/config.py" there is "from bunch import Bunch". Where is bunch, and what does it do? Is there another folder named "bunch"? Waiting for your reply.

    opened by facereg
  • About BaseModel.load

    To load the model properly, I think we need:

        def load(self, sess):
            latest_checkpoint = tf.train.latest_checkpoint(self.config.checkpoint_dir)
            ...

    instead of the original:

        latest_checkpoint = tf.train.latest_checkpoint(os.path.join(self.config.checkpoint_dir, self.config.exp_name))

    opened by ulamaca
  • how to use it?

    When I execute example_trainer, nothing shows. When I execute mains/example.py, it shows: missing or invalid arguments

    Does it support TensorFlow Serving so that other apps can use it?

    opened by gclsoft
  • Fixed initialisation issue erasing loaded model.

    I found that in your example code, you loaded the model, then ran the variable initialisation, which effectively erased the loaded model. I changed the code to load the model after initialisation to avoid this problem. I hope you accept this change.

    I was running this on tensorflow 1.7, so it is possible that this behaviour didn't exist in previous versions.

    opened by AustinT
Owner
Mahmoud Gamal Salem
MSc. in AI at the University of Guelph and Vector Institute. AI intern @samsung