Implementation of ETSformer, state of the art time-series Transformer, in Pytorch

Overview

ETSformer - Pytorch

Implementation of ETSformer, state of the art time-series Transformer, in Pytorch

Install

$ pip install etsformer-pytorch

Usage

import torch
from etsformer_pytorch import ETSFormer

model = ETSFormer(
    time_features = 4,
    model_dim = 512,                # in paper they use 512
    embed_kernel_size = 3,          # kernel size for 1d conv for input embedding
    layers = 2,                     # number of encoder and corresponding decoder layers
    heads = 8,                      # number of exponential smoothing attention heads
    K = 4,                          # num frequencies with highest amplitude to keep (attend to)
    dropout = 0.2                   # dropout (in paper they did 0.2)
)

timeseries = torch.randn(1, 1024, 4)

pred = model(timeseries, num_steps_forecast = 32) # (1, 32, 4) - (batch, num steps forecast, num time features)

For using ETSFormer for classification, using cross attention pooling on all latents and level output

import torch
from etsformer_pytorch import ETSFormer, ClassificationWrapper

etsformer = ETSFormer(
    time_features = 1,
    model_dim = 512,
    embed_kernel_size = 3,
    layers = 2,
    heads = 8,
    K = 4,
    dropout = 0.2
)

adapter = ClassificationWrapper(
    etsformer = etsformer,
    dim_head = 32,
    heads = 16,
    dropout = 0.2,
    level_kernel_size = 5,
    num_classes = 10
)

timeseries = torch.randn(1, 1024)

logits = adapter(timeseries) # (1, 10)

Citation

@misc{woo2022etsformer,
    title   = {ETSformer: Exponential Smoothing Transformers for Time-series Forecasting}, 
    author  = {Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi},
    year    = {2022},
    eprint  = {2202.01381},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • What are your thoughts on using latents for additional classification task

    What are your thoughts on using latents for additional classification task

    Hi! I was wondering if you have thought about aggregating seasonal and growth latents for additional tasks (for example classification)? What are the possible ways to bring latents into single feature vector in your opinion? The easiest one would be just get the mean along layers and time dimensions but that seams to be too naive. Another idea I had it to use Cross Attention mechanic with single time query key to aggregate latents:

    all_latents = torch.cat([latent_growths, latent_seasonals], dim=-1)
    all_latents = rearrange(all_latents, 'b n l d -> (b l) n d')
    # q = nn.Parameter(torch.randn(all_latents_dim))
    q = repeat(q, 'd -> b 1 d', b = all_latents.shape[0])
    agg_latent = cross_attention(query=q, context=all_latents)
    agg_latent = rearrange(all_latents, '(b l) n d -> b (l n) d')
    agg_latent = agg_latent.mean(dim=1) # may be we should have done it before cross attention?
    

    Would be great to hear your thoughts

    opened by inspirit 15
  • Pre LayerNorm might be required for k,v?

    Pre LayerNorm might be required for k,v?

    https://github.com/lucidrains/ETSformer-pytorch/blob/2561053007e919409b3255eb1d0852c68799d24f/etsformer_pytorch/etsformer_pytorch.py#L440

    In my early tests I see some instability in training results, I was wondering if it might be good idea to LayerNorm latents before constructing key and values?

    opened by inspirit 5
  • growth_term calculation error

    growth_term calculation error

    https://github.com/lucidrains/ETSformer-pytorch/blob/e1d8514b44d113ead523aa6307986833e68eecc5/etsformer_pytorch/etsformer_pytorch.py#L233-L235

    It looks like you are not using growth and growth_smoothing_weightsto calculate growth_term

    opened by inspirit 4
  • Backward gradient error

    Backward gradient error

    Hello,

    i was trying to run the provided class and see following error: Function ScatterBackward0 returned an invalid gradient at index 1 - got [64, 4, 128] but expected shape compatible with [64, 33, 128]

    model = ETSFormer(
                time_features = 9,
                model_dim = 128,
                embed_kernel_size = 3,
                layers = 2,
                heads = 4,
                K = 4,
                dropout = 0.2
            )
    

    input = torch.rand(64, 64, 9) x = model(input, num_steps_forecast = 16)

    opened by inspirit 3
  • Does ETS-Former allow adding features

    Does ETS-Former allow adding features

    @lucidrains Thanks for making the code of the model available!

    In your paper, you state that the model infers seasonal patterns itself, so that there is no need to add time features like week, month, etc.

    Still, to increase the applicability of your approach, does the current implementation allow to add any (time-invariant and time-varying) features, e.g., categorical or numeric?

    opened by StatMixedML 2
  • wrong order of arguments

    wrong order of arguments

    https://github.com/lucidrains/ETSformer-pytorch/blob/2e0d465576c15fc8d84c4673f93fdd71d45b799c/etsformer_pytorch/etsformer_pytorch.py#L327

    you pass latents on wrong order to Level module: according to forward method first should be growth and then seasonal

    opened by inspirit 1
  • Clarification regarding data pre-processing

    Clarification regarding data pre-processing

    Hello,

    I was trying to run the ETSformer for ETT dataset. The paper mentions that the dataset is split as 60/20/20 for train, validation and test. Could you give some insight as to how the dataset split is happening in the code.

    Thank you.

    opened by vageeshmaiya 2
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
This application explain how we can easily integrate Deepface framework with Python Django application

deepface_suite This application explain how we can easily integrate Deepface framework with Python Django application install redis cache install requ

Mohamed Naji Aboo 3 Apr 18, 2022
Joint-task Self-supervised Learning for Temporal Correspondence (NeurIPS 2019)

Joint-task Self-supervised Learning for Temporal Correspondence Project | Paper Overview Joint-task Self-supervised Learning for Temporal Corresponden

Sifei Liu 167 Dec 14, 2022
Official pytorch implementation of "DSPoint: Dual-scale Point Cloud Recognition with High-frequency Fusion"

DSPoint Official implementation of "DSPoint: Dual-scale Point Cloud Recognition with High-frequency Fusion". Paper link: https://arxiv.org/abs/2111.10

Ziyao Zeng 14 Feb 26, 2022
StocksMA is a package to facilitate access to financial and economic data of Moroccan stocks.

Creating easier access to the Moroccan stock market data What is StocksMA ? StocksMA is a package to facilitate access to financial and economic data

Salah Eddine LABIAD 28 Jan 04, 2023
Lightwood is Legos for Machine Learning.

Lightwood is like Legos for Machine Learning. A Pytorch based framework that breaks down machine learning problems into smaller blocks that can be glu

MindsDB Inc 312 Jan 08, 2023
Code for the paper "Reinforced Active Learning for Image Segmentation"

Reinforced Active Learning for Image Segmentation (RALIS) Code for the paper Reinforced Active Learning for Image Segmentation Dependencies python 3.6

Arantxa Casanova 79 Dec 19, 2022
This is the pytorch implementation of the paper - Axiomatic Attribution for Deep Networks.

Integrated Gradients This is the pytorch implementation of "Axiomatic Attribution for Deep Networks". The original tensorflow version could be found h

Tianhong Dai 150 Dec 23, 2022
TPH-YOLOv5: Improved YOLOv5 Based on Transformer Prediction Head for Object Detection on Drone-Captured Scenarios

TPH-YOLOv5 This repo is the implementation of "TPH-YOLOv5: Improved YOLOv5 Based on Transformer Prediction Head for Object Detection on Drone-Captured

cv516Buaa 439 Dec 22, 2022
Chinese Advertisement Board Identification(Pytorch)

Chinese-Advertisement-Board-Identification. We use YoloV5 to extract the ROI of the location of the chinese word. Next, we sort the bounding box and recognize every chinese words which we extracted.

Li-Wei Hsiao 12 Jul 21, 2022
An efficient PyTorch implementation of the winning entry of the 2017 VQA Challenge.

Bottom-Up and Top-Down Attention for Visual Question Answering An efficient PyTorch implementation of the winning entry of the 2017 VQA Challenge. The

Hengyuan Hu 731 Jan 03, 2023
The PASS dataset: pretrained models and how to get the data - PASS: Pictures without humAns for Self-Supervised Pretraining

The PASS dataset: pretrained models and how to get the data - PASS: Pictures without humAns for Self-Supervised Pretraining

Yuki M. Asano 249 Dec 22, 2022
Optical machine for senses sensing using speckle and deep learning

# Senses-speckle [Remote Photonic Detection of Human Senses Using Secondary Speckle Patterns](https://doi.org/10.21203/rs.3.rs-724587/v1) paper Python

Zeev Kalyuzhner 0 Sep 26, 2021
RL-GAN: Transfer Learning for Related Reinforcement Learning Tasks via Image-to-Image Translation

RL-GAN: Transfer Learning for Related Reinforcement Learning Tasks via Image-to-Image Translation RL-GAN is an official implementation of the paper: T

42 Nov 10, 2022
CV backbones including GhostNet, TinyNet and TNT, developed by Huawei Noah's Ark Lab.

CV Backbones including GhostNet, TinyNet, TNT (Transformer in Transformer) developed by Huawei Noah's Ark Lab. GhostNet Code TinyNet Code TNT Code Pyr

HUAWEI Noah's Ark Lab 3k Jan 08, 2023
PyTorch implementation of the supervised learning experiments from the paper Model-Agnostic Meta-Learning (MAML)

pytorch-maml This is a PyTorch implementation of the supervised learning experiments from the paper Model-Agnostic Meta-Learning (MAML): https://arxiv

Kate Rakelly 516 Jan 05, 2023
A Java implementation of the experiments for the paper "k-Center Clustering with Outliers in Sliding Windows"

OutliersSlidingWindows A Java implementation of the experiments for the paper "k-Center Clustering with Outliers in Sliding Windows" Dataset generatio

PaoloPellizzoni 0 Jan 05, 2022
The original weights of some Caffe models, ported to PyTorch.

pytorch-caffe-models This repo contains the original weights of some Caffe models, ported to PyTorch. Currently there are: GoogLeNet (Going Deeper wit

Katherine Crowson 9 Nov 04, 2022
Implementation of RegretNet with Pytorch

Dependencies are Python 3, a recent PyTorch, numpy/scipy, tqdm, future and tensorboard. Plotting with Matplotlib. Implementation of the neural network

Horris zhGu 1 Nov 05, 2021
Graph Self-Supervised Learning for Optoelectronic Properties of Organic Semiconductors

SSL_OSC Graph Self-Supervised Learning for Optoelectronic Properties of Organic Semiconductors

zaixizhang 2 May 14, 2022
VM3000 Microphones

VM3000-Microphones This project was completed by Ricky Leman under the supervision of Dr Ben Travaglione and Professor Melinda Hodkiewicz as part of t

UWA System Health Lab 0 Jun 04, 2021