A deep learning tabular classification architecture inspired by TabTransformer with integrated gated multilayer perceptron.

Overview

The GatedTabTransformer.

A deep learning tabular classification architecture inspired by TabTransformer with integrated gated multilayer perceptron. Check out our paper on arXiv.

Architecture

Usage

import torch
import torch.nn as nn
from gated_tab_transformer import GatedTabTransformer

model = GatedTabTransformer(
    categories = (10, 5, 6, 5, 8),      # tuple containing the number of unique values within each category
    num_continuous = 10,                # number of continuous values
    transformer_dim = 32,               # dimension, paper set at 32
    dim_out = 1,                        # binary prediction, but could be anything
    transformer_depth = 6,              # depth, paper recommended 6
    transformer_heads = 8,              # heads, paper recommends 8
    attn_dropout = 0.1,                 # post-attention dropout
    ff_dropout = 0.1,                   # feed forward dropout
    mlp_act = nn.LeakyReLU(0),          # activation for final mlp, defaults to relu, but could be anything else (selu, etc.)
    mlp_depth=4,                        # mlp hidden layers depth
    mlp_dimension=32,                   # dimension of mlp layers
    gmlp_enabled=True                   # gmlp or standard mlp
)

x_categ = torch.randint(0, 5, (1, 5))   # category values, from 0 - max number of categories, in the order as passed into the constructor above
x_cont = torch.randn(1, 10)             # assume continuous values are already normalized individually

pred = model(x_categ, x_cont)
print(pred)

Citation

@misc{cholakov2022gatedtabtransformer,
      title={The GatedTabTransformer. An enhanced deep learning architecture for tabular modeling}, 
      author={Radostin Cholakov and Todor Kolev},
      year={2022},
      eprint={2201.00199},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
@software{GatedTabTransformer,
  author = {{Radostin Cholakov, Todor Kolev}},
  title = {The GatedTabTransformer.},
  url = {https://github.com/radi-cho/GatedTabTransformer},
  version = {0.0.1},
  date = {2021-12-15},
}
Owner
Radi Cho
10th grader, coding for more than 6 years now ;) Looking forward the next big thing!
Radi Cho
A production-ready, scalable Indexer for the Jina neural search framework, based on HNSW and PSQL

🌟 HNSW + PostgreSQL Indexer HNSWPostgreSQLIndexer Jina is a production-ready, scalable Indexer for the Jina neural search framework. It combines the

Jina AI 25 Oct 14, 2022
ML course - EPFL Machine Learning Course, Fall 2021

EPFL Machine Learning Course CS-433 Machine Learning Course, Fall 2021 Repository for all lecture notes, labs and projects - resources, code templates

EPFL Machine Learning and Optimization Laboratory 1k Jan 04, 2023
GDR-Net: Geometry-Guided Direct Regression Network for Monocular 6D Object Pose Estimation. (CVPR 2021)

GDR-Net This repo provides the PyTorch implementation of the work: Gu Wang, Fabian Manhardt, Federico Tombari, Xiangyang Ji. GDR-Net: Geometry-Guided

169 Jan 07, 2023
Revisiting, benchmarking, and refining Heterogeneous Graph Neural Networks.

Heterogeneous Graph Benchmark Revisiting, benchmarking, and refining Heterogeneous Graph Neural Networks. Roadmap We organize our repo by task, and on

THUDM 176 Dec 17, 2022
Gesture-Volume-Control - This Python program can adjust the system's volume by using hand gestures

Gesture-Volume-Control This Python program can adjust the system's volume by usi

VatsalAryanBhatanagar 1 Dec 30, 2021
PyTorch implementation of neural style transfer algorithm

neural-style-pt This is a PyTorch implementation of the paper A Neural Algorithm of Artistic Style by Leon A. Gatys, Alexander S. Ecker, and Matthias

770 Jan 02, 2023
Code for CVPR 2018 paper --- Texture Mapping for 3D Reconstruction with RGB-D Sensor

G2LTex This repository contains the implementation of "Texture Mapping for 3D Reconstruction with RGB-D Sensor (CVPR2018)" based on mvs-texturing. Due

Fu Yanping(付燕平) 129 Dec 30, 2022
MaRS - a recursive filtering framework that allows for truly modular multi-sensor integration

The Modular and Robust State-Estimation Framework, or short, MaRS, is a recursive filtering framework that allows for truly modular multi-sensor integration

Control of Networked Systems - University of Klagenfurt 143 Dec 29, 2022
Predict halo masses from simulations via graph neural networks

HaloGraphNet Predict halo masses from simulations via Graph Neural Networks. Given a dark matter halo and its galaxies, creates a graph with informati

Pablo Villanueva Domingo 20 Nov 15, 2022
Stacked Recurrent Hourglass Network for Stereo Matching

SRH-Net: Stacked Recurrent Hourglass Introduction This repository is supplementary material of our RA-L submission, which helps reviewers to understan

28 Jan 03, 2023
The open-source and free to use Python package miseval was developed to establish a standardized medical image segmentation evaluation procedure

miseval: a metric library for Medical Image Segmentation EVALuation The open-source and free to use Python package miseval was developed to establish

59 Dec 10, 2022
PyTorch implementation of DD3D: Is Pseudo-Lidar needed for Monocular 3D Object detection?

PyTorch implementation of DD3D: Is Pseudo-Lidar needed for Monocular 3D Object detection? (ICCV 2021), Dennis Park*, Rares Ambrus*, Vitor Guizilini, Jie Li, and Adrien Gaidon.

Toyota Research Institute - Machine Learning 364 Dec 27, 2022
PURE: End-to-End Relation Extraction

PURE: End-to-End Relation Extraction This repository contains (PyTorch) code and pre-trained models for PURE (the Princeton University Relation Extrac

Princeton Natural Language Processing 657 Jan 09, 2023
MohammadReza Sharifi 27 Dec 13, 2022
Torch-mutable-modules - Use in-place and assignment operations on PyTorch module parameters with support for autograd

Torch Mutable Modules Use in-place and assignment operations on PyTorch module p

Kento Nishi 7 Jun 06, 2022
A Partition Filter Network for Joint Entity and Relation Extraction EMNLP 2021

EMNLP 2021 - A Partition Filter Network for Joint Entity and Relation Extraction

zhy 127 Jan 04, 2023
Header-only library for using Keras models in C++.

frugally-deep Use Keras models in C++ with ease Table of contents Introduction Usage Performance Requirements and Installation FAQ Introduction Would

Tobias Hermann 927 Jan 05, 2023
Hepsiburada - Hepsiburada Urun Bilgisi Cekme

Hepsiburada Urun Bilgisi Cekme from hepsiburada import Marka nike = Marka("nike"

Ilker Manap 8 Oct 26, 2022
This is an unofficial implementation of the paper “Student-Teacher Feature Pyramid Matching for Unsupervised Anomaly Detection”.

This is an unofficial implementation of the paper “Student-Teacher Feature Pyramid Matching for Unsupervised Anomaly Detection”.

haifeng xia 32 Oct 26, 2022
Discriminative Condition-Aware PLDA

DCA-PLDA This repository implements the Discriminative Condition-Aware Backend described in the paper: L. Ferrer, M. McLaren, and N. Brümmer, "A Speak

Luciana Ferrer 31 Aug 05, 2022