PyTorch implementation of Soft-DTW: a Differentiable Loss Function for Time-Series in CUDA

Overview

Soft DTW Loss Function for PyTorch in CUDA

This is a Pytorch Implementation of Soft-DTW: a Differentiable Loss Function for Time-Series which is batch supported computation, CUDA-friendly, and feasible to use as a final loss. I can confirm that you can train a (sequential) model with this as a final loss! The following image shows training logs of a TTS model using the Soft-DTW Loss Function.

There are some previous implementations:

  1. mblondel's soft-dtw
  2. lyprince's sdtw_pytorch
  3. Maghoumi's pytorch-softdtw-cuda

But they are either not supported by CUDA-friendly batch computation or not considering the jacobean w.r.t input matrix, which is necessary to be used as a final loss in recent deep learning frameworks. In the current implementation, all conditions are satisfied.

Usage

Same as Maghoumi's pytorch-softdtw-cuda:

from sdtw_cuda_loss import SoftDTW

# Create the sequences
batch_size, len_x, len_y, dims = 8, 15, 12, 5
x = torch.rand((batch_size, len_x, dims), requires_grad=True)
y = torch.rand((batch_size, len_y, dims))

# Create the "criterion" object
sdtw = SoftDTW(use_cuda=True, gamma=0.1)

# Compute the loss value
loss = sdtw(x, y)  # Just like any torch.nn.xyzLoss()

# Aggregate and call backward()
loss.mean().backward()

But the backward will compute the gradient w.r.t input target sequence x (which is not considered in the previous work).

Note

In the current implementation, only use_cuda=True is supported. But you can easily implement the CPU version as in Maghoumi's pytorch-softdtw-cuda.

Citation

@misc{lee2021soft_dtw_loss,
  author = {Lee, Keon},
  title = {Soft-DTW-Loss},
  year = {2021},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/keonlee9420/Soft-DTW-Loss}}
}
You might also like...
Seach Losses of our paper 'Loss Function Discovery for Object Detection via Convergence-Simulation Driven Search', accepted by ICLR 2021.
Seach Losses of our paper 'Loss Function Discovery for Object Detection via Convergence-Simulation Driven Search', accepted by ICLR 2021.

CSE-Autoloss Designing proper loss functions for vision tasks has been a long-standing research direction to advance the capability of existing models

Multi-scale discriminator feature-wise loss function

Multi-Scale Discriminative Feature Loss This repository provides code for Multi-Scale Discriminative Feature (MDF) loss for image reconstruction algor

clDice - a Novel Topology-Preserving Loss Function for Tubular Structure Segmentation
clDice - a Novel Topology-Preserving Loss Function for Tubular Structure Segmentation

README clDice - a Novel Topology-Preserving Loss Function for Tubular Structure Segmentation CVPR 2021 Authors: Suprosanna Shit and Johannes C. Paetzo

HistoSeg : Quick attention with multi-loss function for multi-structure segmentation in digital histology images

HistoSeg : Quick attention with multi-loss function for multi-structure segmentation in digital histology images Histological Image Segmentation This

Supervised Sliding Window Smoothing Loss Function Based on MS-TCN for Video Segmentation

SSWS-loss_function_based_on_MS-TCN Supervised Sliding Window Smoothing Loss Function Based on MS-TCN for Video Segmentation Supervised Sliding Window

[CVPR 2022] Official code for the paper:
[CVPR 2022] Official code for the paper: "A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved Neural Network Calibration"

MDCA Calibration This is the official PyTorch implementation for the paper: "A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved

Official implementation of "DSP: Dual Soft-Paste for Unsupervised Domain Adaptive Semantic Segmentation"

DSP Official implementation of "DSP: Dual Soft-Paste for Unsupervised Domain Adaptive Semantic Segmentation". Accepted by ACM Multimedia 2021. Authors

Softlearning is a reinforcement learning framework for training maximum entropy policies in continuous domains. Includes the official implementation of the Soft Actor-Critic algorithm.

Softlearning Softlearning is a deep reinforcement learning toolbox for training maximum entropy policies in continuous domains. The implementation is

Decorators for maximizing memory utilization with PyTorch & CUDA

torch-max-mem This package provides decorators for memory utilization maximization with PyTorch and CUDA by starting with a maximum parameter size and

Comments
  • Does this supports multi-gpu training?

    Does this supports multi-gpu training?

    Thanks for sharing impl of soft-dtw, I can use it in single-gpu env,but can't use it in multi-gpu envs.Currently, it doesn't support multi-gpu training?

    opened by mayfool 2
  • how to use dtw-loss to fit a curve?

    how to use dtw-loss to fit a curve?

    hello, I tried to fit a curve (discrete points) using Soft-DTW-Loss as a loss function. But the loss does not converge to the exact result in the end. Is there something wrong with the way I am using it? The code is as follows:

    if name == "main":

    batch_size = 1
    len_x = 15
    len_predict = 10
    dims = 1
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    x = torch.unsqueeze(torch.linspace(1, 4, steps=len_x, requires_grad=True), dim=0)
    y = x ** 2
    y = y.view(1, len_x, 1)
    x = x.view(1, len_x, 1)
    
    #(batch,length,dims)---->(1,15,2)
    truth_points = torch.cat((y, x), dim=2).cuda()
    
    #(1,20)
    input = torch.unsqueeze(torch.linspace(1, 4, steps=len_predict*2, requires_grad=True), dim=0).cuda()
    
    
    class testNN(torch.nn.Module):
        def __init__(self):
            super(testNN, self).__init__()
            self.layer = nn.Sequential(
                nn.Linear(20, 50),
                nn.ReLU(),
                nn.Linear(50, 200),
                nn.ReLU(),
                nn.Linear(200, 50),
                nn.ReLU(),
                nn.Linear(50, 20),
                nn.ReLU(),
            )
        def forward(self, x):
            x = self.layer(x)
            return x
    
    
    test = testNN()
    test = test.to(device)
    
    loss_function = SoftDTW(use_cuda=True, gamma=0.01, normalize=False)
    optimizer = torch.optim.Adam(test.parameters(), lr=0.01)
    
    
    for epoch in range(1000):
    
    
        predict = test(input)
        #(1,20) reshape to (1,10,2)
        predict = predict.reshape(1, len_predict, 2)
        loss = loss_function(predict, truth_points)
        optimizer.zero_grad()
        loss.mean().backward(retain_graph=True)
        optimizer.step()
    
    
        if epoch % 10 == 0:
            print("epoch : %d | loss : %f" % (epoch, loss))
            plt_predict = predict.cpu().detach().numpy()
            # print(plt_predict)
            plt_predict = plt_predict.reshape(1, len_predict, 2)
            print(plt_predict[0, :, 0])
            print(plt_predict[0, :, 1])
    
    opened by visionlyx 0
Releases(v1.0.0)
Owner
Keon Lee
Expressive Speech Synthesis | Conversational AI | Open-domain Dialog | NLP | Generative Models | Empathic Computing | HCI
Keon Lee
render sprites into your desktop environment as shaped windows using GTK

spritegtk render static or animated sprites into your desktop environment as dynamic shaped windows using GTK requires pycairo and PYGobject: pip inst

hermit 20 Oct 27, 2022
Playable Video Generation

Playable Video Generation Playable Video Generation Willi Menapace, Stéphane Lathuilière, Sergey Tulyakov, Aliaksandr Siarohin, Elisa Ricci Paper: ArX

Willi Menapace 136 Dec 31, 2022
For storing the complete exploration of Visual Question Answering for our B.Tech Project

Multi-Image vqa @authors: Akhilesh, Janhavi, Harsh Paper summary, Ideas tried and their corresponding results: on wiki Other discussions: on discussio

Harsh Raj 3 Jun 16, 2022
🍅🍅🍅YOLOv5-Lite: lighter, faster and easier to deploy. Evolved from yolov5 and the size of model is only 1.7M (int8) and 3.3M (fp16). It can reach 10+ FPS on the Raspberry Pi 4B when the input size is 320×320~

YOLOv5-Lite:lighter, faster and easier to deploy Perform a series of ablation experiments on yolov5 to make it lighter (smaller Flops, lower memory, a

pogg 1.5k Jan 05, 2023
Official PyTorch implementation of the NeurIPS 2021 paper StyleGAN3

Alias-Free Generative Adversarial Networks (StyleGAN3) Official PyTorch implementation of the NeurIPS 2021 paper Alias-Free Generative Adversarial Net

Eugenio Herrera 92 Nov 18, 2022
Simple ray intersection library similar to coldet - succedeed by libacc

Ray Intersection This project offers a header only acceleration structure library including implementations for a BVH- and KD-Tree. Applications may i

Nils Moehrle 29 Jun 23, 2022
An educational resource to help anyone learn deep reinforcement learning.

Status: Maintenance (expect bug fixes and minor updates) Welcome to Spinning Up in Deep RL! This is an educational resource produced by OpenAI that ma

OpenAI 7.6k Jan 09, 2023
Code for the paper “The Peril of Popular Deep Learning Uncertainty Estimation Methods”

Uncertainty Estimation Methods Code for the paper “The Peril of Popular Deep Learning Uncertainty Estimation Methods” Reference If you use this code,

EPFL Machine Learning and Optimization Laboratory 4 Apr 05, 2022
Learning Saliency Propagation for Semi-supervised Instance Segmentation

Learning Saliency Propagation for Semi-supervised Instance Segmentation PyTorch Implementation This repository contains: the PyTorch implementation of

Berkeley DeepDrive 68 Oct 18, 2022
Let Python optimize the best stop loss and take profits for your TradingView strategy.

TradingView Machine Learning TradeView is a free and open source Trading View bot written in Python. It is designed to support all major exchanges. It

Robert Roman 473 Jan 09, 2023
A PyTorch Toolbox for Face Recognition

FaceX-Zoo FaceX-Zoo is a PyTorch toolbox for face recognition. It provides a training module with various supervisory heads and backbones towards stat

JDAI-CV 1.6k Jan 06, 2023
Training vision models with full-batch gradient descent and regularization

Stochastic Training is Not Necessary for Generalization -- Training competitive vision models without stochasticity This repository implements trainin

Jonas Geiping 32 Jan 06, 2023
The tl;dr on a few notable transformer/language model papers + other papers (alignment, memorization, etc).

The tl;dr on a few notable transformer/language model papers + other papers (alignment, memorization, etc).

Will Thompson 166 Jan 04, 2023
商品推荐系统

商品top50推荐系统 问题建模 本项目的数据集给出了15万左右的用户以及12万左右的商品, 以及对应的经过脱敏处理的用户特征和经过预处理的商品特征,旨在为用户推荐50个其可能购买的商品。 推荐系统架构方案 本项目采用传统的召回+排序的方案。

107 Dec 29, 2022
Differential Privacy for Heterogeneous Federated Learning : Utility & Privacy tradeoffs

Differential Privacy for Heterogeneous Federated Learning : Utility & Privacy tradeoffs In this work, we propose an algorithm DP-SCAFFOLD(-warm), whic

19 Nov 10, 2022
Based on the given clinical dataset, Predict whether the patient having Heart Disease or Not having Heart Disease

Heart_Disease_Classification Based on the given clinical dataset, Predict whether the patient having Heart Disease or Not having Heart Disease Dataset

Ashish 1 Jan 30, 2022
NeurIPS workshop paper 'Counter-Strike Deathmatch with Large-Scale Behavioural Cloning'

Counter-Strike Deathmatch with Large-Scale Behavioural Cloning Tim Pearce, Jun Zhu Offline RL workshop, NeurIPS 2021 Paper: https://arxiv.org/abs/2104

Tim Pearce 169 Dec 26, 2022
PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

Thalles Silva 1.7k Dec 28, 2022
Many Class Activation Map methods implemented in Pytorch for CNNs and Vision Transformers. Including Grad-CAM, Grad-CAM++, Score-CAM, Ablation-CAM and XGrad-CAM

Class Activation Map methods implemented in Pytorch pip install grad-cam ⭐ Tested on many Common CNN Networks and Vision Transformers. ⭐ Includes smoo

Jacob Gildenblat 6.6k Jan 06, 2023
QRec: A Python Framework for quick implementation of recommender systems (TensorFlow Based)

Introduction QRec is a Python framework for recommender systems (Supported by Python 3.7.4 and Tensorflow 1.14+) in which a number of influential and

Yu 1.4k Jan 01, 2023