Transformer model implemented with PyTorch

Overview

transformer-pytorch

Transformer model implemented with PyTorch

Attention Is All You Need - [Paper]

Architecture

Transformer


Self-Attention

Attention

self_attention.py

[N, len, heads, head_dim] values = values.reshape(N, value_len, self.heads, self.head_dim) keys = keys.reshape(N, key_len, self.heads, self.head_dim) queries = queries.reshape(N, query_len, self.heads, self.head_dim) # Einsum does matrix mult. for query*keys for each training example # with every other training example, don't be confused by einsum # it's just how I like doing matrix multiplication & bmm energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) # queries shape: (N, query_len, heads, heads_dim), # keys shape: (N, key_len, heads, heads_dim) # energy: (N, heads, query_len, key_len) # Mask padded indices so their weights become 0 if mask is not None: energy = energy.masked_fill(mask == 0, float("-1e20")) # Normalize energy values similarly to seq2seq + attention # so that they sum to 1. Also divide by scaling factor for # better stability attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3) # attention shape: (N, heads, query_len, key_len) out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape( N, query_len, self.heads * self.head_dim ) # attention shape: (N, heads, query_len, key_len) # values shape: (N, value_len, heads, heads_dim) # out after matrix multiply: (N, query_len, heads, head_dim), then # we reshape and flatten the last two dimensions. out = self.fc_out(out) # Linear layer doesn't modify the shape, final shape will be # (N, query_len, embed_size) return out ">
 class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads      = heads
        self.head_dim   = embed_size // heads

        assert (
                self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values  = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys    = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.fc_out  = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        # Get number of training examples
        N = query.shape[0]

        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        values  = self.values(values)
        keys    = self.keys(keys)
        queries = self.queries(query)
        
        # Split the embedding into self.heads different pieces
        # Multi head
        # [N, len, embed_size] --> [N, len, heads, head_dim]
        values    = values.reshape(N, value_len, self.heads, self.head_dim)
        keys      = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries   = queries.reshape(N, query_len, self.heads, self.head_dim)

        # Einsum does matrix mult. for query*keys for each training example
        # with every other training example, don't be confused by einsum
        # it's just how I like doing matrix multiplication & bmm
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, query_len, heads, heads_dim),
        # keys shape: (N, key_len, heads, heads_dim)
        # energy: (N, heads, query_len, key_len)

        # Mask padded indices so their weights become 0
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Normalize energy values similarly to seq2seq + attention
        # so that they sum to 1. Also divide by scaling factor for
        # better stability
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
        # attention shape: (N, heads, query_len, key_len)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, heads_dim)
        # out after matrix multiply: (N, query_len, heads, head_dim), then
        # we reshape and flatten the last two dimensions.

        out = self.fc_out(out)
        # Linear layer doesn't modify the shape, final shape will be
        # (N, query_len, embed_size)

        return out

Encoder Block

Encoder

encoder_block.py

class EncoderBlock(nn.Module):
    """Attention sub-layer followed by a position-wise feed-forward sub-layer.

    Each sub-layer output is added to its input (residual connection),
    layer-normalised and then passed through dropout.

    Args:
        embed_size: Feature size of the inputs and of the output.
        heads: Number of attention heads handed to ``SelfAttention``.
        dropout: Drop probability of the shared ``nn.Dropout``.
        forward_expansion: The hidden width of the feed-forward network is
            ``forward_expansion * embed_size``.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        hidden_size = forward_expansion * embed_size
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Run one block; the result has the same shape as ``query``."""
        # Sub-layer 1: attention + residual on the query, then norm and dropout.
        attended = self.attention(value, key, query, mask)
        x = self.norm1(attended + query)
        x = self.dropout(x)

        # Sub-layer 2: feed-forward + residual, then norm and dropout.
        expanded = self.feed_forward(x)
        normed = self.norm2(expanded + x)
        return self.dropout(normed)

Encoder

Encoder

encoder.py

class Encoder(nn.Module):
    """Token + learned position embeddings followed by a stack of encoder blocks.

    Args:
        src_vocab_size: Number of rows in the token embedding table.
        embed_size: Embedding / model width.
        num_layers: Number of ``EncoderBlock`` modules in the stack.
        heads: Number of attention heads per block.
        device: Device the position indices are moved to in ``forward``.
        forward_expansion: Feed-forward expansion factor of each block.
        dropout: Drop probability used on the embeddings and inside blocks.
        max_length: Number of rows in the position embedding table, i.e. the
            longest sequence that can be embedded.
    """

    def __init__(
        self,
        src_vocab_size,
        embed_size,
        num_layers,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length,
    ):
        super().__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            EncoderBlock(
                embed_size,
                heads,
                dropout=dropout,
                forward_expansion=forward_expansion,
            )
            for _ in range(num_layers)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode a batch of token ids.

        Args:
            x: Integer tensor of shape ``(N, seq_length)``.
            mask: Attention mask forwarded unchanged to every block.

        Returns:
            Tensor of shape ``(N, seq_length, embed_size)``.
        """
        batch_size, seq_length = x.shape

        # One row of 0..seq_length-1 per example, placed on the model device.
        positions = torch.arange(0, seq_length).expand(batch_size, seq_length)
        positions = positions.to(self.device)

        embedded = self.word_embedding(x) + self.position_embedding(positions)
        out = self.dropout(embedded)

        # Pure self-attention: value, key and query are all the running
        # hidden state.
        for block in self.layers:
            out = block(out, out, out, mask)

        return out

Decoder Block

DecoderBlock

decoder_block.py

class DecoderBlock(nn.Module):
    """Masked self-attention followed by an encoder-style attention block.

    The target sequence first attends to itself under ``trg_mask``. The
    normalised result then serves as the query of an ``EncoderBlock`` whose
    values and keys are supplied by the caller.

    Args:
        embed_size: Feature size of all inputs and of the output.
        heads: Number of attention heads.
        forward_expansion: Feed-forward expansion factor of the inner block.
        dropout: Drop probability.
        device: Accepted for signature compatibility; not used by this class.
    """

    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super().__init__()
        self.norm = nn.LayerNorm(embed_size)
        self.attention = SelfAttention(embed_size, heads=heads)
        self.transformer_block = EncoderBlock(
            embed_size,
            heads,
            dropout,
            forward_expansion,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        """Run one decoder block; the result has the same shape as ``x``."""
        # Self-attention over the target, residual, norm, dropout.
        self_attended = self.attention(x, x, x, trg_mask)
        query = self.norm(self_attended + x)
        query = self.dropout(query)

        # Attend from the target query over the supplied value/key tensors.
        return self.transformer_block(value, key, query, src_mask)

Decoder

Decoder

decoder.py

class Decoder(nn.Module):
    """Token + learned position embeddings, decoder blocks and a vocabulary head.

    Args:
        trg_vocab_size: Size of the target vocabulary (embedding rows and
            width of the output logits).
        embed_size: Embedding / model width.
        num_layers: Number of ``DecoderBlock`` modules in the stack.
        heads: Number of attention heads per block.
        forward_expansion: Feed-forward expansion factor of each block.
        dropout: Drop probability used on the embeddings and inside blocks.
        device: Device the position indices are moved to in ``forward``.
        max_length: Number of rows in the position embedding table.
    """

    def __init__(
        self,
        trg_vocab_size,
        embed_size,
        num_layers,
        heads,
        forward_expansion,
        dropout,
        device,
        max_length,
    ):
        super().__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
            for _ in range(num_layers)
        )

        self.dropout = nn.Dropout(dropout)
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)

    def forward(self, x, enc_out, src_mask, trg_mask):
        """Decode a batch of target token ids against the encoder output.

        Args:
            x: Integer tensor of shape ``(N, seq_length)``.
            enc_out: Encoder output, used as both value and key in each block.
            src_mask: Mask applied when attending over ``enc_out``.
            trg_mask: Mask applied in the target self-attention.

        Returns:
            Tensor of shape ``(N, seq_length, trg_vocab_size)`` (unnormalised
            scores; no softmax is applied here).
        """
        batch_size, seq_length = x.shape

        # One row of 0..seq_length-1 per example, placed on the model device.
        positions = torch.arange(0, seq_length).expand(batch_size, seq_length)
        positions = positions.to(self.device)

        embedded = self.word_embedding(x) + self.position_embedding(positions)
        hidden = self.dropout(embedded)

        for block in self.layers:
            hidden = block(hidden, enc_out, enc_out, src_mask, trg_mask)

        return self.fc_out(hidden)

Transformer

transformer.py

class Transformer(nn.Module):
    """Encoder-decoder model built from ``Encoder`` and ``Decoder``.

    Args:
        src_vocab_size: Source vocabulary size.
        trg_vocab_size: Target vocabulary size.
        src_pad_idx: Token id treated as padding when building the source mask.
        trg_pad_idx: Stored on the instance; not used by the mask helpers.
        embed_size: Embedding / model width.
        num_layers: Number of blocks in both encoder and decoder.
        forward_expansion: Feed-forward expansion factor.
        heads: Number of attention heads.
        dropout: Drop probability.
        device: Device the masks and position indices are moved to.
        max_length: Longest sequence supported by the position embeddings.
    """

    def __init__(
        self,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        embed_size=512,
        num_layers=6,
        forward_expansion=4,
        heads=8,
        dropout=0,
        device="cpu",
        max_length=100,
    ):
        super().__init__()

        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length,
        )

        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length,
        )

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        """Return a boolean mask that is False at source padding positions.

        For ``src`` of shape ``(N, src_len)`` the result has shape
        ``(N, 1, 1, src_len)``.
        """
        not_padding = src != self.src_pad_idx
        # Insert two singleton axes after the batch axis.
        broadcastable = not_padding.unsqueeze(1).unsqueeze(2)
        return broadcastable.to(self.device)

    def make_trg_mask(self, trg):
        """Return a lower-triangular (causal) mask of ones and zeros.

        For ``trg`` of shape ``(N, trg_len)`` the result has shape
        ``(N, 1, trg_len, trg_len)``; entry ``[i, j]`` is 1 when ``j <= i``.
        """
        batch_size, trg_len = trg.shape
        causal = torch.tril(torch.ones((trg_len, trg_len)))
        per_example = causal.expand(batch_size, 1, trg_len, trg_len)
        return per_example.to(self.device)

    def forward(self, src, trg):
        """Encode ``src``, decode ``trg`` against it and return the decoder output."""
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        memory = self.encoder(src, src_mask)
        return self.decoder(trg, memory, src_mask, trg_mask)

Authors

Owner
Mingu Kang
SW Engineering / ML / DL / Blockchain Dept. of Software Engineering, Jeonbuk National University
Mingu Kang
PyTorch implementation of paper A Fast Knowledge Distillation Framework for Visual Recognition.

FKD: A Fast Knowledge Distillation Framework for Visual Recognition Official PyTorch implementation of paper A Fast Knowledge Distillation Framework f

Zhiqiang Shen 129 Dec 24, 2022
Diverse graph algorithms implemented using JGraphT library.

# 1. Installing Maven & Pandas First, please install Java (JDK11) and Python 3 if they are not already. Next, make sure that Maven (for importing J

See Woo Lee 3 Dec 17, 2022
The story of Chicken for Club Bing

Chicken Story tl;dr: The time when Microsoft banned my entire country for cheating at Club Bing. (A lot of the details are from memory so I've recreat

Eyal 142 May 16, 2022
Code for our paper "Interactive Analysis of CNN Robustness"

Perturber Code for our paper "Interactive Analysis of CNN Robustness" Datasets Feature visualizations: Google Drive Fine-tuning checkpoints as saved m

Stefan Sietzen 0 Aug 17, 2021
Task Transformer Network for Joint MRI Reconstruction and Super-Resolution (MICCAI 2021)

T2Net Task Transformer Network for Joint MRI Reconstruction and Super-Resolution (MICCAI 2021) [Paper][Code] Dependencies numpy==1.18.5 scikit_image==

64 Nov 23, 2022
Codes to calculate solar-sensor zenith and azimuth angles directly from hyperspectral images collected by UAV. Works only for UAVs that have high resolution GNSS/IMU unit.

UAV Solar-Sensor Angle Calculation Table of Contents About The Project Built With Getting Started Prerequisites Installation Datasets Contributing Lic

Sourav Bhadra 1 Jan 15, 2022
Vikrant Deshpande 1 Nov 17, 2022
A 10000+ hours dataset for Chinese speech recognition

WenetSpeech Official website | Paper A 10000+ Hours Multi-domain Chinese Corpus for Speech Recognition Download Please visit the official website, rea

310 Jan 03, 2023
PyTorch implementation of Progressive Growing of GANs for Improved Quality, Stability, and Variation.

PyTorch implementation of Progressive Growing of GANs for Improved Quality, Stability, and Variation. Warning: the master branch might collapse. To ob

559 Dec 14, 2022
How to use TensorLayer

How to use TensorLayer While research in Deep Learning continues to improve the world, we use a bunch of tricks to implement algorithms with TensorLay

zhangrui 349 Dec 07, 2022
[CVPR 2022 Oral] TubeDETR: Spatio-Temporal Video Grounding with Transformers

TubeDETR: Spatio-Temporal Video Grounding with Transformers Website • STVG Demo • Paper This repository provides the code for our paper. This includes

Antoine Yang 108 Dec 27, 2022
Code for the paper "Zero-shot Natural Language Video Localization" (ICCV2021, Oral).

Zero-shot Natural Language Video Localization (ZSNLVL) by Pseudo-Supervised Video Localization (PSVL) This repository is for Zero-shot Natural Languag

Computer Vision Lab. @ GIST 37 Dec 27, 2022
code for our ECCV 2020 paper "A Balanced and Uncertainty-aware Approach for Partial Domain Adaptation"

Code for our ECCV (2020) paper A Balanced and Uncertainty-aware Approach for Partial Domain Adaptation. Prerequisites: python == 3.6.8 pytorch ==1.1.0

32 Nov 27, 2022
Multispectral Object Detection with Yolov5

Multispectral-Object-Detection Intro Official Code for Cross-Modality Fusion Transformer for Multispectral Object Detection. Multispectral Object Dete

Richard Fang 121 Jan 01, 2023
Code for the ICCV 2021 Workshop paper: A Unified Efficient Pyramid Transformer for Semantic Segmentation.

Unified-EPT Code for the ICCV 2021 Workshop paper: A Unified Efficient Pyramid Transformer for Semantic Segmentation. Installation Linux, CUDA=10.0,

29 Aug 23, 2022
Implementation of Retrieval-Augmented Denoising Diffusion Probabilistic Models in Pytorch

Retrieval-Augmented Denoising Diffusion Probabilistic Models (wip) Implementation of Retrieval-Augmented Denoising Diffusion Probabilistic Models in P

Phil Wang 55 Jan 01, 2023
PHOTONAI is a high level python API for designing and optimizing machine learning pipelines.

PHOTONAI is a high level python API for designing and optimizing machine learning pipelines. We've created a system in which you can easily select and

Medical Machine Learning Lab - University of Münster 57 Nov 12, 2022
WarpDrive: Extremely Fast End-to-End Deep Multi-Agent Reinforcement Learning on a GPU

WarpDrive is a flexible, lightweight, and easy-to-use open-source reinforcement learning (RL) framework that implements end-to-end multi-agent RL on a single GPU (Graphics Processing Unit).

Salesforce 334 Jan 06, 2023
Python codes for Lite Audio-Visual Speech Enhancement.

Lite Audio-Visual Speech Enhancement (Interspeech 2020) Introduction This is the PyTorch implementation of Lite Audio-Visual Speech Enhancement (LAVSE

Shang-Yi Chuang 85 Dec 01, 2022
[CIKM 2021] Enhancing Aspect-Based Sentiment Analysis with Supervised Contrastive Learning

Enhancing Aspect-Based Sentiment Analysis with Supervised Contrastive Learning. This repo contains the PyTorch code and implementation for the paper E

Akuchi 18 Dec 22, 2022