Implementing SYNTHESIZER: Rethinking Self-Attention in Transformer Models using PyTorch

Overview

A PyTorch implementation of SYNTHESIZER: Rethinking Self-Attention in Transformer Models.

Reference

  • Paper URL

  • Authors: Yi Tay, Dara Bahri, Donald Metzler, Da-Cheng Juan, Zhe Zhao, Che Zheng

  • Google Research

Method

(Figure: model)

1. Dense Synthesizer

2. Fixed Random Synthesizer

3. Random Synthesizer

4. Factorized Dense Synthesizer

5. Factorized Random Synthesizer

6. Mixture of Synthesizers
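
The first few variants in this list can be sketched roughly as follows. This is a minimal illustration based on the paper's description, not the code in this repository: the class names (`DenseSynthesizerSketch`, `RandomSynthesizerSketch`, `FactorizedRandomSynthesizerSketch`), the hidden width of the dense scorer, and the rank `k=8` are all assumptions made for the example. The common idea is that the attention weights are produced without query-key dot products: the dense variant predicts them from each token on its own, while the random variants learn (or freeze) an input-independent matrix.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DenseSynthesizerSketch(nn.Module):
    """Each token predicts its own row of attention logits (hypothetical sketch)."""

    def __init__(self, dim, seq_len):
        super().__init__()
        # Two-layer feed-forward scorer: dim -> seq_len -> seq_len.
        # The hidden width (seq_len) is an assumption of this sketch.
        self.score = nn.Sequential(
            nn.Linear(dim, seq_len), nn.ReLU(), nn.Linear(seq_len, seq_len)
        )
        self.value = nn.Linear(dim, dim)

    def forward(self, x):  # x: (batch, seq_len, dim)
        attn = F.softmax(self.score(x), dim=-1)  # (batch, seq_len, seq_len)
        return attn @ self.value(x), attn


class RandomSynthesizerSketch(nn.Module):
    """Attention logits are one random matrix shared by all inputs."""

    def __init__(self, dim, seq_len, fixed=False):
        super().__init__()
        # fixed=True freezes the random matrix (the "Fixed Random" variant).
        self.logits = nn.Parameter(
            torch.randn(1, seq_len, seq_len), requires_grad=not fixed
        )
        self.value = nn.Linear(dim, dim)

    def forward(self, x):
        attn = F.softmax(self.logits, dim=-1)  # (1, seq_len, seq_len)
        return attn @ self.value(x), attn  # matmul broadcasts over the batch


class FactorizedRandomSynthesizerSketch(nn.Module):
    """Low-rank version of the random variant: logits = R1 @ R2^T."""

    def __init__(self, dim, seq_len, k=8):
        super().__init__()
        self.r1 = nn.Parameter(torch.randn(1, seq_len, k))
        self.r2 = nn.Parameter(torch.randn(1, seq_len, k))
        self.value = nn.Linear(dim, dim)

    def forward(self, x):
        logits = self.r1 @ self.r2.transpose(-1, -2)  # (1, seq_len, seq_len)
        attn = F.softmax(logits, dim=-1)
        return attn @ self.value(x), attn


if __name__ == "__main__":
    x = torch.randn(2, 32, 1024)
    for layer in (
        DenseSynthesizerSketch(1024, 32),
        RandomSynthesizerSketch(1024, 32, fixed=True),
        FactorizedRandomSynthesizerSketch(1024, 32),
    ):
        out, attn = layer(x)
        print(type(layer).__name__, tuple(out.shape), tuple(attn.shape))
```

Because the random variants do not look at the input, their attention map has a leading dimension of 1 rather than the batch size; the dense variant produces one map per example.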

Usage

import torch

from synthesizer import (
    Transformer,
    SynthesizerDense,
    SynthesizerRandom,
    FactorizedSynthesizerDense,
    FactorizedSynthesizerRandom,
    MixtureSynthesizers,
    get_n_params,
    calculate_flops,
)


def main():
    batch_size, channel_dim, sentence_length = 2, 1024, 32
    x = torch.randn([batch_size, sentence_length, channel_dim])

    vanilla = Transformer(channel_dim)
    out, attention_map = vanilla(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(vanilla), calculate_flops(vanilla.children())
    print('vanilla, n_params: {}, flops: {}'.format(n_params, flops))

    dense_synthesizer = SynthesizerDense(channel_dim, sentence_length)
    out, attention_map = dense_synthesizer(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(dense_synthesizer), calculate_flops(dense_synthesizer.children())
    print('dense_synthesizer, n_params: {}, flops: {}'.format(n_params, flops))

    random_synthesizer = SynthesizerRandom(channel_dim, sentence_length)
    out, attention_map = random_synthesizer(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(random_synthesizer), calculate_flops(random_synthesizer.children())
    print('random_synthesizer, n_params: {}, flops: {}'.format(n_params, flops))

    random_synthesizer_fix = SynthesizerRandom(channel_dim, sentence_length, fixed=True)
    out, attention_map = random_synthesizer_fix(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(random_synthesizer_fix), calculate_flops(random_synthesizer_fix.children())
    print('random_synthesizer_fix, n_params: {}, flops: {}'.format(n_params, flops))

    factorized_synthesizer_random = FactorizedSynthesizerRandom(channel_dim)
    out, attention_map = factorized_synthesizer_random(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(factorized_synthesizer_random), calculate_flops(
        factorized_synthesizer_random.children())
    print('factorized_synthesizer_random, n_params: {}, flops: {}'.format(n_params, flops))

    factorized_synthesizer_dense = FactorizedSynthesizerDense(channel_dim, sentence_length)
    out, attention_map = factorized_synthesizer_dense(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(factorized_synthesizer_dense), calculate_flops(
        factorized_synthesizer_dense.children())
    print('factorized_synthesizer_dense, n_params: {}, flops: {}'.format(n_params, flops))

    mixture_synthesizer = MixtureSynthesizers(channel_dim, sentence_length)
    out, attention_map = mixture_synthesizer(x)
    print(out.size(), attention_map.size())
    n_params, flops = get_n_params(mixture_synthesizer), calculate_flops(mixture_synthesizer.children())
    print('mixture_synthesizer, n_params: {}, flops: {}'.format(n_params, flops))


if __name__ == '__main__':
    main()

Output

torch.Size([2, 32, 1024]) torch.Size([2, 32, 32])
vanilla, n_params: 3148800, flops: 3145729
torch.Size([2, 32, 1024]) torch.Size([2, 32, 32])
dense_synthesizer, n_params: 1083456, flops: 1082370
torch.Size([2, 32, 1024]) torch.Size([1, 32, 32])
random_synthesizer, n_params: 1050624, flops: 1048577
torch.Size([2, 32, 1024]) torch.Size([1, 32, 32])
random_synthesizer_fix, n_params: 1050624, flops: 1048577
torch.Size([2, 32, 1024]) torch.Size([2, 32, 32])
factorized_synthesizer_random, n_params: 1066000, flops: 1064961
torch.Size([2, 32, 1024]) torch.Size([2, 32, 32])
factorized_synthesizer_dense, n_params: 1061900, flops: 1060865
torch.Size([2, 32, 1024]) torch.Size([2, 32, 32])
mixture_synthesizer, n_params: 3149824, flops: 3145729
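
The parameter counts above can be checked by hand. The short script below shows one accounting that reproduces the reported numbers, with d = 1024 and l = 32. The decomposition is inferred from the numbers alone and is an assumption, not something read from the repository's source: it assumes every projection is a linear layer with bias, that the factorized dense variant splits the length as 4 × 8, and that the mixture combines the vanilla layer with one random matrix. For `factorized_synthesizer_random`, two d-to-8 linear layers happen to match the reported count, but other layouts could too.

```python
def linear_params(n_in, n_out, bias=True):
    """Weights plus optional bias of a fully connected layer."""
    return n_in * n_out + (n_out if bias else 0)


d, l = 1024, 32                # channel_dim and sentence_length from the usage example
value = linear_params(d, d)    # value projection shared by every variant

# Assumed decompositions (inferred from the reported totals, not from the source).
counts = {
    "vanilla": 3 * linear_params(d, d),                              # query, key, value
    "dense_synthesizer": value + linear_params(d, l) + linear_params(l, l),
    "random_synthesizer": value + l * l,                             # one l x l matrix
    "factorized_synthesizer_random": value + 2 * linear_params(d, 8),
    "factorized_synthesizer_dense": value + linear_params(d, 4) + linear_params(d, 8),
    "mixture_synthesizer": 3 * linear_params(d, d) + l * l,
}

# Numbers copied from the Output section.
reported = {
    "vanilla": 3148800,
    "dense_synthesizer": 1083456,
    "random_synthesizer": 1050624,
    "factorized_synthesizer_random": 1066000,
    "factorized_synthesizer_dense": 1061900,
    "mixture_synthesizer": 3149824,
}

for name, n in counts.items():
    print(f"{name}: computed {n}, reported {reported[name]}, match={n == reported[name]}")
```

Under this accounting the synthesizer variants drop the query and key projections (about 2d² parameters) and replace them with a much smaller scorer, which is where the roughly threefold reduction relative to the vanilla layer comes from. The fixed random variant reports the same count as the trainable one, which suggests that `get_n_params` also counts frozen tensors.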

Paper Performance

(Figure: evaluation results)

Owner
Myeongjun Kim