A pure PyTorch implementation of the loss described in "Online Segment to Segment Neural Transduction"

Overview

ssnt-loss

ℹ️ This is a WIP project; the implementation is still being tested.

A pure PyTorch implementation of the loss described in "Online Segment to Segment Neural Transduction" https://arxiv.org/abs/1609.08194.

Usage

There are two versions: a normal version and a memory efficient version. They should give the same output; please let me know if they don't.

def ssnt_loss_mem(
    log_probs: Tensor,
    targets: Tensor,
    log_p_choose: Tensor,
    source_lengths: Tensor,
    target_lengths: Tensor,
    neg_inf: float = -1e4,
    reduction="mean",
):
    """The memory efficient implementation concatenates along the targets
    dimension to reduce wasted computation on padding positions.

    Assuming the summation of all targets in the batch is T_flat, then
    the original B x T x ... tensor is reduced to T_flat x ...

    The input tensors can be obtained by using target mask:
    Example:
        >>> target_mask = targets.ne(pad)   # (B, T)
        >>> targets = targets[target_mask]  # (T_flat,)
        >>> log_probs = log_probs[target_mask]  # (T_flat, S, V)
        >>> log_p_choose = log_p_choose[target_mask]  # (T_flat, S)

    Args:
        log_probs (Tensor): Word prediction log-probs, should be output of log_softmax.
            tensor with shape (T_flat, S, V)
            where T_flat is the summation of all target lengths,
            S is the maximum number of input frames and V is
            the vocabulary of labels.
        targets (Tensor): Tensor with shape (T_flat,) representing the
            reference target labels for all samples in the minibatch.
        log_p_choose (Tensor): emission log-probs, should be output of F.logsigmoid.
            tensor with shape (T_flat, S)
            where T_flat is the summation of all target lengths,
            S is the maximum number of input frames.
        source_lengths (Tensor): Tensor with shape (N,) representing the
            number of frames for each sample in the minibatch.
        target_lengths (Tensor): Tensor with shape (N,) representing the
            length of the transcription for each sample in the minibatch.
        neg_inf (float, optional): The constant representing -inf used for masking.
            Default: -1e4
        reduction (string, optional): Specifies the reduction to apply to the
            output: "mean" or "sum". Default: "mean".
    """

Minimal example

import torch
import torch.nn as nn
import torch.nn.functional as F
from ssnt_loss import ssnt_loss_mem, lengths_to_padding_mask
B, S, H, T, V = 2, 100, 256, 10, 2000

# model
transcriber = nn.LSTM(input_size=H, hidden_size=H, num_layers=1).cuda()
predictor = nn.LSTM(input_size=H, hidden_size=H, num_layers=1).cuda()
joiner_trans = nn.Linear(H, V, bias=False).cuda()
joiner_alpha = nn.Sequential(
    nn.Linear(H, 1, bias=True),
    nn.Tanh()
).cuda()

# inputs
src_embed = torch.rand(B, S, H).cuda().requires_grad_()
tgt_embed = torch.rand(B, T, H).cuda().requires_grad_()
targets = torch.randint(0, V, (B, T)).cuda()
adjust = lambda x, goal: x * goal // x.max()
source_lengths = adjust(torch.randint(1, S+1, (B,)).cuda(), S)
target_lengths = adjust(torch.randint(1, T+1, (B,)).cuda(), T)

# forward
src_feats, (h1, c1) = transcriber(src_embed.transpose(1, 0))
tgt_feats, (h2, c2) = predictor(tgt_embed.transpose(1, 0))

# memory efficient joint
mask = ~lengths_to_padding_mask(target_lengths)
lattice = F.relu(
    src_feats.transpose(0, 1).unsqueeze(1) + tgt_feats.transpose(0, 1).unsqueeze(2)
)[mask]
log_alpha = F.logsigmoid(joiner_alpha(lattice)).squeeze(-1)
lattice = joiner_trans(lattice).log_softmax(-1)

# memory efficient ssnt loss
loss = ssnt_loss_mem(
    lattice,
    targets[mask],
    log_alpha,
    source_lengths=source_lengths,
    target_lengths=target_lengths,
    reduction="sum"
) / (B*T)
loss.backward()
print(loss.item())

Note

This implementation is based on the simplified derivation proposed for monotonic attention, which uses parallelized cumsum and cumprod operations to compute the expected alignment instead of a loop over input frames. Because SSNT and monotonic attention share the same monotonic alignment structure, we infer that the SSNT forward variable alpha(i,j) can be computed in the same way.
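To illustrate the monotonic-attention trick the note refers to, the sketch below computes each alignment row with an exclusive cumprod and a cumsum, then checks it against a naive double loop over the same recurrence. It is not the package's SSNT code: it works in probability space with float64 for clarity, whereas the package's loss works in log space and also involves the word-prediction log-probs. The function names `alignment_parallel` and `alignment_loop` are hypothetical.

```python
import torch


def alignment_parallel(p: torch.Tensor, alpha0: torch.Tensor) -> torch.Tensor:
    """Monotonic alignment computed with cumprod/cumsum.

    p: (T, S) probability of stopping (emitting) at each source frame.
    alpha0: (S,) initial alignment distribution.
    Row i of the result is
        alpha_i[j] = p[i, j] * sum_{k<=j} alpha_{i-1}[k] * prod_{l=k}^{j-1} (1 - p[i, l]).
    """
    rows = []
    prev = alpha0
    for i in range(p.size(0)):
        one_minus = 1.0 - p[i]
        # Exclusive cumprod: entry j holds prod_{l<j} (1 - p[i, l]).
        cumprod_excl = torch.cumprod(
            torch.cat([one_minus.new_ones(1), one_minus[:-1]]), dim=0
        )
        # Parallel over source frames. Dividing by cumprod_excl can underflow
        # on long inputs, so a practical version needs clamping or log space.
        q = cumprod_excl * torch.cumsum(prev / cumprod_excl, dim=0)
        prev = p[i] * q
        rows.append(prev)
    return torch.stack(rows)


def alignment_loop(p: torch.Tensor, alpha0: torch.Tensor) -> torch.Tensor:
    """Reference recurrence: q[j] = q[j-1] * (1 - p[i, j-1]) + alpha_{i-1}[j]."""
    T, S = p.shape
    out = torch.zeros_like(p)
    prev = alpha0
    for i in range(T):
        q = torch.zeros(S, dtype=p.dtype)
        for j in range(S):
            carry = q[j - 1] * (1.0 - p[i, j - 1]) if j > 0 else 0.0
            q[j] = prev[j] + carry
        out[i] = p[i] * q
        prev = out[i]
    return out


torch.manual_seed(0)
p = torch.rand(4, 6, dtype=torch.float64) * 0.8 + 0.1
alpha0 = torch.zeros(6, dtype=torch.float64)
alpha0[0] = 1.0  # start aligned to the first source frame
fast = alignment_parallel(p, alpha0)
slow = alignment_loop(p, alpha0)
print(torch.allclose(fast, slow))
```

The parallel version still loops over target positions i, but each row is computed with vectorized operations across all source frames, which is what makes this formulation attractive on GPUs.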

Feel free to contact me if there are bugs in the code.

Reference

Owner
張致強