Pretty Tensor - Fluent Neural Networks in TensorFlow

Overview

Pretty Tensor provides a high level builder API for TensorFlow. It provides thin wrappers on Tensors so that you can easily build multi-layer neural networks.

Pretty Tensor provides a set of objects that behave like Tensors, but also support a chainable object syntax to quickly define neural networks and other layered architectures in TensorFlow.

result = (pretty_tensor.wrap(input_data, m)
          .flatten()
          .fully_connected(200, activation_fn=tf.nn.relu)
          .fully_connected(10, activation_fn=None)
          .softmax(labels, name=softmax_name))

For all available operations on the PrettyTensor object, see Available Operations, or check out the complete documentation.

See the tutorial directory for samples: tutorial/

Installation

The easiest installation is just to use pip:

  1. Follow the instructions at tensorflow.org
  2. pip install prettytensor

Note: Head is tested against the TensorFlow nightly builds, and the pip package is tested against the TensorFlow release.

Quick start

Imports

import prettytensor as pt
import tensorflow as tf

Set up your input

my_inputs = ...  # supply a numpy array of shape (BATCHES, BATCH_SIZE, DATA_SIZE)
my_labels = ...  # supply a numpy array of shape (BATCHES, BATCH_SIZE, CLASSES)
input_tensor = tf.placeholder(tf.float32, shape=(BATCH_SIZE, DATA_SIZE))
label_tensor = tf.placeholder(np.float32, shape=(BATCH_SIZE, CLASSES))
pretty_input = pt.wrap(input_tensor)

Define your model

softmax, loss = (pretty_input.
                 fully_connected(100).
                 softmax_classifier(CLASSES, labels=label_tensor))

Train and evaluate

accuracy = softmax.evaluate_classifier(label_tensor)

optimizer = tf.train.GradientDescentOptimizer(0.1)  # learning rate
train_op = pt.apply_optimizer(optimizer, losses=[loss])

init_op = tf.initialize_all_variables()

with tf.Session() as sess:
    sess.run(init_op)
    for inp, label in zip(my_inputs, my_labels):
        # Run train_op (not just loss) so the optimizer actually updates the variables.
        _, accuracy_value = sess.run([train_op, accuracy],
                                     {input_tensor: inp, label_tensor: label})
        print('Accuracy: %g' % accuracy_value)

Features

Thin

Full power of TensorFlow is easy to use

Pretty Tensors can be used (almost) everywhere that a tensor can. Just call pt.wrap to make a tensor pretty.

You can also add any existing TensorFlow function to the chain using apply. apply passes the current Tensor as the first argument and forwards all other arguments unchanged.

Note: because apply is so generic, Pretty Tensor doesn't try to wrap the world.
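
To illustrate the calling convention, here is a minimal pure-Python sketch of a chainable wrapper with an apply method. The Chain class and its value attribute are hypothetical stand-ins invented for this example; they hold plain numbers rather than tensors and are not Pretty Tensor's implementation.

```python
import operator


class Chain(object):
    """Hypothetical stand-in for a wrapped tensor; holds a plain Python value."""

    def __init__(self, value):
        self.value = value

    def apply(self, fn, *args, **kwargs):
        # The wrapped value becomes the first argument of fn;
        # every other argument is forwarded unchanged.
        return Chain(fn(self.value, *args, **kwargs))


# Any ordinary function can join the chain without a dedicated wrapper method.
result = Chain(3).apply(operator.mul, 5).apply(pow, 2)
print(result.value)  # 225
```

Because apply returns a new wrapper, calls keep chaining, which is why a generic apply can stand in for a large number of dedicated methods.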

Plays well with other libraries

It also uses standard TensorFlow idioms so that it plays well with other libraries. This means that you can use it in a small part of a model or throughout. Just make sure to run the update_ops on each training step (see with_update_ops).

Terse

You've already seen how a Pretty Tensor is chainable, and you may have noticed that it takes care of handling the input shape. One other feature worth noting is defaults. Using defaults, you can specify reused values in a single place without having to repeat yourself.

with pt.defaults_scope(activation_fn=tf.nn.relu):
  hidden_output2 = (pretty_images.flatten()
                   .fully_connected(100)
                   .fully_connected(100))

Check out the documentation to see all supported defaults.

Code matches model

Sequential mode lets you break model construction across lines and provides the subdivide syntactic sugar that makes it easy to define and understand complex structures like an inception module:

with pretty_tensor.defaults_scope(activation_fn=tf.nn.relu):
  seq = pretty_input.sequential()
  with seq.subdivide(4) as towers:
    towers[0].conv2d(1, 64)
    towers[1].conv2d(1, 112).conv2d(3, 224)
    towers[2].conv2d(1, 32).conv2d(5, 64)
    towers[3].max_pool(2, 3).conv2d(1, 32)

[Figure: Inception module showing branch and rejoin]
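
The branch-and-rejoin pattern can be sketched in plain Python with a context manager that hands out independent towers and joins them when the block exits. The Sequential class, its add method, and the string layer names below are hypothetical and only record structure; no real layers are built.

```python
import contextlib


class Sequential(object):
    """Hypothetical sequential builder that records layer names in order."""

    def __init__(self):
        self.layers = []

    def add(self, name):
        # Mutates in place, so construction can be split across statements.
        self.layers.append(name)
        return self

    @contextlib.contextmanager
    def subdivide(self, count):
        # Each tower starts empty and grows independently inside the block.
        towers = [Sequential() for _ in range(count)]
        yield towers
        # On exit, rejoin: record all towers as one branching layer.
        self.layers.append(tuple(tuple(t.layers) for t in towers))


seq = Sequential().add('input')
with seq.subdivide(2) as towers:
    towers[0].add('conv1x1')
    towers[1].add('conv1x1').add('conv3x3')
seq.add('output')
print(seq.layers)
# ['input', (('conv1x1',), ('conv1x1', 'conv3x3')), 'output']
```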

Templates provide guaranteed parameter reuse and make unrolling recurrent networks easy:

output, s = [], tf.zeros([BATCH, 256 * 2])

A = (pretty_tensor.template('x')
     .lstm_cell(num_units=256, state=pretty_tensor.UnboundVariable('state')))

for x in pretty_input_array:
  h, s = A.construct(x=x, state=s)
  output.append(h)

There are also some convenient shorthands for LSTMs and GRUs:

pretty_input_array.sequence_lstm(num_units=256)

[Figure: Unrolled RNN]

Extensible

You can call any existing operation by using apply and it will simply substitute the current tensor for the first argument.

pretty_input.apply(tf.mul, 5)

You can also create a new operation. There are two supported registration mechanisms for adding your own functions. @Register() allows you to create a method on PrettyTensor that operates on the Tensors and returns either a loss or a new value. Name scoping and variable scoping are handled by the framework.

The following method adds the leaky_relu method to every Pretty Tensor:

@pt.Register()
def leaky_relu(input_pt):
  return tf.select(tf.greater(input_pt, 0.0), input_pt, 0.01 * input_pt)

@RegisterCompoundOp() is like adding a macro; it is designed to group together common sets of operations.
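
The registration idea can be sketched in plain Python as a decorator factory that attaches a function to a wrapper class as a chainable method. The Wrapped class and this simplified Register are hypothetical and written for illustration only; the real decorator also handles name and variable scoping, which this sketch omits.

```python
class Wrapped(object):
    """Hypothetical stand-in for the wrapper class; holds a plain float."""

    def __init__(self, value):
        self.value = value


class Register(object):
    """Decorator factory: calling Register() returns the actual decorator."""

    def __init__(self, method_name=None):
        self.method_name = method_name

    def __call__(self, fn):
        name = self.method_name or fn.__name__

        def method(wrapped, *args, **kwargs):
            # Unwrap, run the user function, and re-wrap so chaining continues.
            return Wrapped(fn(wrapped.value, *args, **kwargs))

        setattr(Wrapped, name, method)
        return fn


@Register()
def leaky_relu(x, slope=0.01):
    return x if x > 0.0 else slope * x


print(Wrapped(-2.0).leaky_relu().value)  # -0.02
```

Note the parentheses in @Register(): the class is instantiated first, and the resulting instance is what decorates the function.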

Safe variable reuse

Within a graph, you can reuse variables by using templates. A template is just like a regular graph except that some variables are left unbound.

See more details in PrettyTensor class.
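
As a rough illustration of why templates give safe reuse, the sketch below creates one shared weight up front and binds the remaining inputs each time the template is constructed. The Unbound and Template classes and the toy recurrence are assumptions made for this example; they are not an LSTM and not Pretty Tensor's template machinery.

```python
class Unbound(object):
    """Marks a value that must be supplied when the template is constructed."""

    def __init__(self, key):
        self.key = key


class Template(object):
    """Hypothetical template: one shared weight, inputs bound at construct time."""

    def __init__(self, weight, x, state):
        self.weight = weight          # created once, reused by every construct()
        self.x, self.state = x, state

    def construct(self, **bindings):
        # Fail early if any unbound variable was left without a value.
        missing = [u.key for u in (self.x, self.state) if u.key not in bindings]
        if missing:
            raise ValueError('unbound variables: %s' % ', '.join(missing))
        x, state = bindings[self.x.key], bindings[self.state.key]
        new_state = self.weight * (x + state)   # toy recurrence, not an LSTM
        return new_state, new_state


cell = Template(weight=0.5, x=Unbound('x'), state=Unbound('state'))

outputs, s = [], 0.0
for x in [1.0, 2.0, 3.0]:
    h, s = cell.construct(x=x, state=s)   # same weight at every step
    outputs.append(h)
print(outputs)  # [0.5, 1.25, 2.125]
```

Every call to construct reads the same weight, so unrolling the loop cannot accidentally create a fresh parameter per step.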

Accessing Variables

Pretty Tensor uses the standard graph collections from TensorFlow to store variables. These can be accessed using tf.get_collection(key) with the following keys:

  • tf.GraphKeys.VARIABLES: all variables that should be saved (including some statistics).
  • tf.GraphKeys.TRAINABLE_VARIABLES: all variables that can be trained (including those before a stop_gradients call). These are what would typically be called parameters of the model in ML parlance.
  • pt.GraphKeys.TEST_VARIABLES: variables used to evaluate a model. These are typically not saved and are reset by the LocalRunner.evaluate method to get a fresh evaluation.

Authors

Eider Moore (eiderman)

with key contributions from:

  • Hubert Eichner
  • Oliver Lange
  • Sagar Jain (sagarjn)