Implementation of TabTransformer, attention network for tabular data, in Pytorch

Overview

Tab Transformer

Implementation of Tab Transformer, attention network for tabular data, in Pytorch. This simple architecture came within a hair's breadth of GBDT's performance.

Install

$ pip install tab-transformer-pytorch

Usage

import torch
from tab_transformer_pytorch import TabTransformer

cont_mean_std = torch.randn(10, 2)

model = TabTransformer(
    categories = (10, 5, 6, 5, 8),      # tuple containing the number of unique values within each category
    num_continuous = 10,                # number of continuous values
    dim = 32,                           # dimension, paper set at 32
    dim_out = 1,                        # binary prediction, but could be anything
    depth = 6,                          # depth, paper recommended 6
    heads = 8,                          # heads, paper recommends 8
    attn_dropout = 0.1,                 # post-attention dropout
    ff_dropout = 0.1,                   # feed forward dropout
    mlp_hidden_mults = (4, 2),          # relative multiples of each hidden dimension of the last mlp to logits
    mlp_act = nn.ReLU(),                # activation for final mlp, defaults to relu, but could be anything else (selu etc)
    continuous_mean_std = cont_mean_std # (optional) - normalize the continuous values before layer norm
)

x_categ = torch.randint(0, 5, (1, 5))     # category values, from 0 - max number of categories, in the order as passed into the constructor above
x_cont = torch.randn(1, 10)               # assume continuous values are already normalized individually

pred = model(x_categ, x_cont)

Unsupervised Training

To undergo the type of unsupervised training described in the paper, you can first convert your categories tokens to the appropriate unique ids, and then use Electra on model.transformer.

Citations

@misc{huang2020tabtransformer,
    title={TabTransformer: Tabular Data Modeling Using Contextual Embeddings}, 
    author={Xin Huang and Ashish Khetan and Milan Cvitkovic and Zohar Karnin},
    year={2020},
    eprint={2012.06678},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
Comments
  • Minor Bug: actuation function being applied to output layer in class MLP

    Minor Bug: actuation function being applied to output layer in class MLP

    The code for class MLP is mistakingly applying the actuation function to the last (i.e. output) layer. The error is in the evaluation of the is_last flag. The current code is:

    class MLP(nn.Module):
        def __init__(self, dims, act = None):
            super().__init__()
            dims_pairs = list(zip(dims[:-1], dims[1:]))
            layers = []
            for ind, (dim_in, dim_out) in enumerate(dims_pairs):
                is_last = ind >= (len(dims) - 1)
    

    The last line should be changed to is_last = ind >= (len(dims) - 2):

    class MLP(nn.Module):
        def __init__(self, dims, act = None):
            super().__init__()
            dims_pairs = list(zip(dims[:-1], dims[1:]))
            layers = []
            for ind, (dim_in, dim_out) in enumerate(dims_pairs):
                is_last = ind >= (len(dims) - 2)
    

    If you like, I can do a pull request.

    opened by rminhas 1
  • Update tab_transformer_pytorch.py

    Update tab_transformer_pytorch.py

    Add activation function out of the loop for the whole model, not after each of the linear layers. 'if is_last' condition was creating linear output all the time no matter what the activation function was.

    opened by EveryoneDirn 0
  • Unindent continuous_mean_std buffer

    Unindent continuous_mean_std buffer

    Problem: continuous_mean_std is not an attribute of TabTransformer if not defined in the argument explicitly. Example reproducing AttributeError:

    model = TabTransformer(
        categories = (10, 5, 6, 5, 8),      # tuple containing the number of unique values within each category
        num_continuous = 10,                # number of continuous values
        dim = 32,                           # dimension, paper set at 32
        dim_out = 1,                        # binary prediction, but could be anything
        depth = 6,                          # depth, paper recommended 6
        heads = 8,                          # heads, paper recommends 8
        attn_dropout = 0.1,                 # post-attention dropout
        ff_dropout = 0.1,                   # feed forward dropout
        mlp_hidden_mults = (4, 2),          # relative multiples of each hidden dimension of the last mlp to logits
        mlp_act = nn.ReLU(),                # activation for final mlp, defaults to relu, but could be anything else (selu etc)
    # continuous_mean_std = cont_mean_std # (optional) - normalize the continuous values before layer norm)
    x_categ = torch.randint(0, 5, (1, 5))     # category values, from 0 - max number of categories, in the order as passed into the constructor above
    x_cont = torch.randn(1, 10)               # assume continuous values are already normalized individually
    pred = model(x_categ, x_cont) # gives AttributeError
    
    

    Solution: Simply un-indenting the buffer registration of continuous_mean_std.

    opened by spliew 0
  • low gpu usage,

    low gpu usage,

    Hi.

    I'm having a problem with running your code with my dataset. It's pretty slow. GPU runs at 50% usage in average and each epoch takes almost 900 seconds to run.

    My dataset has 590540 rows, 24 categorical features, and 192 continuous features. Categories are encoded using Label encoder. Total dataset size is around 600Mb. My gpu is an integrated NVIDIA RTX 3060 with 6Gb of RAM. Optimizer is Adam.

    These are the software versions:

    Windows 10

    Python: 3.7.11 Pytorch: 1.7.0+cu110 Numpy: 1.21.2

    Let me know if you need more info from my side.

    Thanks.

    Xin.

    opened by xinqiao123 0
  • Intended usage of num_special_tokens?

    Intended usage of num_special_tokens?

    From what I understand, these are supposed to be reserved for oov values. Is the intended usage to set oov values in the input to some negative number and overwrite the offset? That is what it seems like it would take to achieve the desired outcome, but also seems somewhat confusing and clunky to do. Or perhaps I am misunderstanding its purpose? Thanks!

    opened by LLYX 2
  • No Category Shared Embedding?

    No Category Shared Embedding?

    I noticed that this implementation does not seem to have the feature of a shared embedding between each value belonging to the same category (unless I missed it) that the paper mentions (c_phi_i). If it's indeed missing, do you have plans to add that?

    Thanks for this implementation!

    opened by LLYX 3
  • index -1 is out of bounds for dimension 1 with size 17

    index -1 is out of bounds for dimension 1 with size 17

    I encountered this problem during the training process. What is the possible reason for this problem, and how can I solve this problem? Thanks!

      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 583, in forward
        return self.tabnet(x)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 468, in forward
        steps_output, M_loss = self.encoder(x)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 160, in forward
        M = self.att_transformers[step](prior, att)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 637, in forward
        x = self.selector(x)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/sparsemax.py", line 109, in forward
        return sparsemax(input, self.dim)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/sparsemax.py", line 52, in forward
        tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/sparsemax.py", line 94, in _threshold_and_support
        tau = input_cumsum.gather(dim, support_size - 1)
    RuntimeError: index -1 is out of bounds for dimension 1 with size 17
    Experiment has terminated.
    
    opened by hengzhe-zhang 2
  • Is there any training example about tabtransformer?

    Is there any training example about tabtransformer?

    Hi, I want to use it in a tabular dataset to finish a supervised learning,But I dont really know how to train this model with dataset(it seems that there is no such content in the readme file ). Could you please help me? thank you.

    opened by pancodex 0
Owner
Phil Wang
Working with Attention. It's all we need.
Phil Wang
[ICLR'19] Trellis Networks for Sequence Modeling

TrellisNet for Sequence Modeling This repository contains the experiments done in paper Trellis Networks for Sequence Modeling by Shaojie Bai, J. Zico

CMU Locus Lab 460 Oct 13, 2022
PyTorch implementation for View-Guided Point Cloud Completion

PyTorch implementation for View-Guided Point Cloud Completion

22 Jan 04, 2023
The implemetation of Dynamic Nerual Garments proposed in Siggraph Asia 2021

DynamicNeuralGarments Introduction This repository contains the implemetation of Dynamic Nerual Garments proposed in Siggraph Asia 2021. ./GarmentMoti

42 Dec 27, 2022
Synthetic Humans for Action Recognition, IJCV 2021

SURREACT: Synthetic Humans for Action Recognition from Unseen Viewpoints Gül Varol, Ivan Laptev and Cordelia Schmid, Andrew Zisserman, Synthetic Human

Gul Varol 59 Dec 14, 2022
ML-based medical imaging using Azure

Disclaimer This code is provided for research and development use only. This code is not intended for use in clinical decision-making or for any other

Microsoft Azure 68 Dec 23, 2022
Tackling the Class Imbalance Problem of Deep Learning Based Head and Neck Organ Segmentation

Info This is the code repository of the work Tackling the Class Imbalance Problem of Deep Learning Based Head and Neck Organ Segmentation from Elias T

2 Apr 20, 2022
Supplementary code for the experiments described in the 2021 ISMIR submission: Leveraging Hierarchical Structures for Few Shot Musical Instrument Recognition.

Music Trees Supplementary code for the experiments described in the 2021 ISMIR submission: Leveraging Hierarchical Structures for Few Shot Musical Ins

Hugo Flores García 32 Nov 22, 2022
Gluon CV Toolkit

Gluon CV Toolkit | Installation | Documentation | Tutorials | GluonCV provides implementations of the state-of-the-art (SOTA) deep learning models in

Distributed (Deep) Machine Learning Community 5.4k Jan 06, 2023
Baseline and template code for node21 detection track

Nodule Detection Algorithm This codebase implements a baseline model, Faster R-CNN, for the nodule detection track in NODE21. It contains all necessar

node21challenge 11 Jan 15, 2022
Bayesian dessert for Lasagne

Gelato Bayesian dessert for Lasagne Recent results in Bayesian statistics for constructing robust neural networks have proved that it is one of the be

Maxim Kochurov 84 May 11, 2020
R interface to fast.ai

R interface to fastai The fastai package provides R wrappers to fastai. The fastai library simplifies training fast and accurate neural nets using mod

113 Dec 20, 2022
SEOVER: Sentence-level Emotion Orientation Vector based Conversation Emotion Recognition Model

SEOVER-Master This code is the implementation of paper: SEOVER: Sentence-level Emotion Orientation Vector based Conversation Emotion Recognition Model

4 Feb 24, 2022
Grounding Representation Similarity with Statistical Testing

Grounding Representation Similarity with Statistical Testing This repo contains code to replicate the results in our paper, which evaluates representa

26 Dec 02, 2022
Some toy examples of score matching algorithms written in PyTorch

toy_gradlogp This repo implements some toy examples of the following score matching algorithms in PyTorch: ssm-vr: sliced score matching with variance

Ending Hsiao 21 Dec 26, 2022
Machine Learning University: Accelerated Computer Vision Class

Machine Learning University: Accelerated Computer Vision Class This repository contains slides, notebooks, and datasets for the Machine Learning Unive

AWS Samples 1.3k Dec 28, 2022
This project helps to colorize grayscale images using multiple exemplars.

Multiple Exemplar-based Deep Colorization (Pytorch Implementation) Pretrained Model [Jitendra Chautharia](IIT Jodhpur)1,3, Prerequisites Python 3.6+ N

jitendra chautharia 3 Aug 05, 2022
Hierarchical Metadata-Aware Document Categorization under Weak Supervision (WSDM'21)

Hierarchical Metadata-Aware Document Categorization under Weak Supervision This project provides a weakly supervised framework for hierarchical metada

Yu Zhang 53 Sep 17, 2022
Anagram Generator in Python

Anagrams Generator This is a program for computing multiword anagrams. It makes no effort to come up with sentences that make sense; it only finds ana

Day Fundora 5 Nov 17, 2022
Code for "MetaMorph: Learning Universal Controllers with Transformers", Gupta et al, ICLR 2022

MetaMorph: Learning Universal Controllers with Transformers This is the code for the paper MetaMorph: Learning Universal Controllers with Transformers

Agrim Gupta 50 Jan 03, 2023
PyTorch code for training MM-DistillNet for multimodal knowledge distillation

There is More than Meets the Eye: Self-Supervised Multi-Object Detection and Tracking with Sound by Distilling Multimodal Knowledge MM-DistillNet is a

51 Dec 20, 2022