Implementation of ETSformer, a state-of-the-art time-series Transformer, in PyTorch

ETSformer - Pytorch

Install

$ pip install etsformer-pytorch

Usage

import torch
from etsformer_pytorch import ETSFormer

model = ETSFormer(
    time_features = 4,
    model_dim = 512,                # in paper they use 512
    embed_kernel_size = 3,          # kernel size for 1d conv for input embedding
    layers = 2,                     # number of encoder and corresponding decoder layers
    heads = 8,                      # number of exponential smoothing attention heads
    K = 4,                          # num frequencies with highest amplitude to keep (attend to)
    dropout = 0.2                   # dropout (in paper they did 0.2)
)

timeseries = torch.randn(1, 1024, 4)

pred = model(timeseries, num_steps_forecast = 32) # (1, 32, 4) - (batch, num steps forecast, num time features)
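
The K argument above controls how many of the highest-amplitude frequencies are kept. As a rough illustration of that idea only (this is not the library's code, which works on batched tensors), the dependency-free sketch below runs a naive discrete Fourier transform over a toy series and keeps the K strongest frequency bins. The helper name `top_k_frequencies`, the toy signal, and the decision to skip the constant bin are all assumptions made for this example.

```python
import cmath
import math

def top_k_frequencies(series, k):
    """Return the k non-constant frequency bins with the largest amplitude.

    Hypothetical helper for illustration: a naive O(n^2) DFT over the
    one-sided spectrum (bins 1 .. n//2), skipping the constant bin 0.
    """
    n = len(series)
    amplitudes = []
    for freq in range(1, n // 2 + 1):
        coeff = sum(
            x * cmath.exp(-2j * math.pi * freq * t / n)
            for t, x in enumerate(series)
        )
        amplitudes.append((abs(coeff), freq))
    # sort by amplitude, largest first, and keep the k strongest bins
    amplitudes.sort(reverse=True)
    return sorted(freq for _, freq in amplitudes[:k])

# toy signal: strong cycles at 3 and 7 per window, a weak one at 11
n = 64
signal = [
    2.0 * math.sin(2 * math.pi * 3 * t / n)
    + 1.0 * math.sin(2 * math.pi * 7 * t / n)
    + 0.1 * math.sin(2 * math.pi * 11 * t / n)
    for t in range(n)
]

dominant = top_k_frequencies(signal, k=2)
print(dominant)  # [3, 7]
```

With k=2 the weak cycle at bin 11 is discarded, which is the sense in which a small K restricts the model to the dominant seasonal patterns.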

To use ETSFormer for classification, wrap it in ClassificationWrapper, which applies cross attention pooling over all latents and the level output:

import torch
from etsformer_pytorch import ETSFormer, ClassificationWrapper

etsformer = ETSFormer(
    time_features = 1,
    model_dim = 512,
    embed_kernel_size = 3,
    layers = 2,
    heads = 8,
    K = 4,
    dropout = 0.2
)

adapter = ClassificationWrapper(
    etsformer = etsformer,
    dim_head = 32,
    heads = 16,
    dropout = 0.2,
    level_kernel_size = 5,
    num_classes = 10
)

timeseries = torch.randn(1, 1024)

logits = adapter(timeseries) # (1, 10)
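
Cross attention pooling collapses a sequence of latents into one fixed-size vector by letting a single query attend over them. The dependency-free sketch below shows only that mechanism. `attention_pool`, the toy latents and the query are assumptions made for illustration; the actual ClassificationWrapper is multi-headed (see heads and dim_head above) and operates on batched tensors, so treat this as a simplified picture rather than its implementation.

```python
import math

def attention_pool(query, latents):
    """Pool a list of latent vectors into one vector with a single query.

    Hypothetical helper: scaled dot-product attention with one query,
    using the latents as both keys and values (no learned projections).
    """
    dim = len(query)
    scale = 1.0 / math.sqrt(dim)
    scores = [
        scale * sum(q * x for q, x in zip(query, latent))
        for latent in latents
    ]
    # numerically stable softmax over the sequence positions
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # weighted sum of the latents, one value per feature dimension
    pooled = [
        sum(w * latent[d] for w, latent in zip(weights, latents))
        for d in range(dim)
    ]
    return pooled, weights

latents = [
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [4.0, 0.0, 0.0, 0.0],
]
query = [1.0, 0.0, 0.0, 0.0]  # stands in for a learned parameter

pooled, weights = attention_pool(query, latents)
print(len(pooled), round(sum(weights), 6))  # 4 1.0
```

The output length depends only on the feature dimension, not on the number of latents, which is what makes this kind of pooling convenient in front of a classification head.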

Citation

@misc{woo2022etsformer,
    title   = {ETSformer: Exponential Smoothing Transformers for Time-series Forecasting}, 
    author  = {Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi},
    year    = {2022},
    eprint  = {2202.01381},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • What are your thoughts on using latents for additional classification task

    Hi! I was wondering if you have thought about aggregating the seasonal and growth latents for additional tasks (for example, classification). What are the possible ways to bring the latents into a single feature vector, in your opinion? The easiest one would be to just take the mean along the layer and time dimensions, but that seems too naive. Another idea I had is to use a cross attention mechanism with a single query to aggregate the latents:

    all_latents = torch.cat([latent_growths, latent_seasonals], dim=-1)
    num_layers = all_latents.shape[2]  # 'l' axis, needed to undo the merge below
    all_latents = rearrange(all_latents, 'b n l d -> (b l) n d')
    # q = nn.Parameter(torch.randn(all_latents_dim))
    q = repeat(q, 'd -> b 1 d', b = all_latents.shape[0])
    agg_latent = cross_attention(query=q, context=all_latents)
    # split the merged (b l) axis again; operate on agg_latent, not all_latents
    agg_latent = rearrange(agg_latent, '(b l) n d -> b (l n) d', l = num_layers)
    agg_latent = agg_latent.mean(dim=1) # maybe we should have done it before cross attention?
    

    Would be great to hear your thoughts

    opened by inspirit 15
  • Pre LayerNorm might be required for k,v?

    https://github.com/lucidrains/ETSformer-pytorch/blob/2561053007e919409b3255eb1d0852c68799d24f/etsformer_pytorch/etsformer_pytorch.py#L440

    In my early tests I see some instability in the training results. I was wondering if it might be a good idea to apply LayerNorm to the latents before constructing the keys and values?

    opened by inspirit 5
  • growth_term calculation error

    https://github.com/lucidrains/ETSformer-pytorch/blob/e1d8514b44d113ead523aa6307986833e68eecc5/etsformer_pytorch/etsformer_pytorch.py#L233-L235

    It looks like you are not using growth and growth_smoothing_weights to calculate growth_term

    opened by inspirit 4
  • Backward gradient error

    Hello,

    I was trying to run the provided class and got the following error: Function ScatterBackward0 returned an invalid gradient at index 1 - got [64, 4, 128] but expected shape compatible with [64, 33, 128]

    model = ETSFormer(
                time_features = 9,
                model_dim = 128,
                embed_kernel_size = 3,
                layers = 2,
                heads = 4,
                K = 4,
                dropout = 0.2
            )
    

    input = torch.rand(64, 64, 9)
    x = model(input, num_steps_forecast = 16)

    opened by inspirit 3
  • Does ETS-Former allow adding features

    @lucidrains Thanks for making the code of the model available!

    In your paper, you state that the model infers seasonal patterns itself, so that there is no need to add time features like week, month, etc.

    Still, to increase the applicability of your approach, does the current implementation allow adding any (time-invariant and time-varying) features, e.g., categorical or numeric?

    opened by StatMixedML 2
  • wrong order of arguments

    https://github.com/lucidrains/ETSformer-pytorch/blob/2e0d465576c15fc8d84c4673f93fdd71d45b799c/etsformer_pytorch/etsformer_pytorch.py#L327

    You pass the latents to the Level module in the wrong order: according to its forward method, growth should come first and then seasonal.

    opened by inspirit 1
  • Clarification regarding data pre-processing

    Hello,

    I was trying to run ETSformer on the ETT dataset. The paper mentions that the dataset is split 60/20/20 into train, validation and test. Could you give some insight into how the dataset split happens in the code?

    Thank you.

    opened by vageeshmaiya 2
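
Regarding the 60/20/20 split raised in the last comment: for forecasting, such a split is usually done chronologically, so that validation and test data come strictly after the training data. The sketch below is a generic illustration of that convention, not code from this repository or from the paper's authors; `chronological_split` and its ratio arguments are hypothetical names.

```python
def chronological_split(series, train_ratio=0.6, val_ratio=0.2):
    """Split a sequence into train/val/test segments without shuffling.

    Hypothetical helper: whatever remains after train and val goes to test.
    """
    n = len(series)
    train_end = round(n * train_ratio)
    val_end = train_end + round(n * val_ratio)
    return series[:train_end], series[train_end:val_end], series[val_end:]

timestamps = list(range(100))
train, val, test = chronological_split(timestamps)
print(len(train), len(val), len(test))  # 60 20 20
```

Any normalization statistics would then typically be computed on the training segment only, to avoid leaking information from later time steps.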
Owner
Phil Wang