A deep learning tabular classification architecture inspired by TabTransformer with integrated gated multilayer perceptron.

Overview

The GatedTabTransformer.

A deep learning tabular classification architecture inspired by TabTransformer with integrated gated multilayer perceptron. Check out our paper on arXiv.

Architecture

Usage

import torch
import torch.nn as nn
from gated_tab_transformer import GatedTabTransformer

model = GatedTabTransformer(
    categories = (10, 5, 6, 5, 8),      # tuple containing the number of unique values within each category
    num_continuous = 10,                # number of continuous values
    transformer_dim = 32,               # dimension, paper set at 32
    dim_out = 1,                        # binary prediction, but could be anything
    transformer_depth = 6,              # depth, paper recommended 6
    transformer_heads = 8,              # heads, paper recommends 8
    attn_dropout = 0.1,                 # post-attention dropout
    ff_dropout = 0.1,                   # feed forward dropout
    mlp_act = nn.LeakyReLU(0),          # activation for final mlp, defaults to relu, but could be anything else (selu, etc.)
    mlp_depth=4,                        # mlp hidden layers depth
    mlp_dimension=32,                   # dimension of mlp layers
    gmlp_enabled=True                   # gmlp or standard mlp
)

x_categ = torch.randint(0, 5, (1, 5))   # category values, from 0 - max number of categories, in the order as passed into the constructor above
x_cont = torch.randn(1, 10)             # assume continuous values are already normalized individually

pred = model(x_categ, x_cont)
print(pred)

Citation

@misc{cholakov2022gatedtabtransformer,
      title={The GatedTabTransformer. An enhanced deep learning architecture for tabular modeling}, 
      author={Radostin Cholakov and Todor Kolev},
      year={2022},
      eprint={2201.00199},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
@software{GatedTabTransformer,
  author = {{Radostin Cholakov, Todor Kolev}},
  title = {The GatedTabTransformer.},
  url = {https://github.com/radi-cho/GatedTabTransformer},
  version = {0.0.1},
  date = {2021-12-15},
}
Owner
Radi Cho
10th grader, coding for more than 6 years now ;) Looking forward the next big thing!
Radi Cho
This is the official implement of paper "ActionCLIP: A New Paradigm for Action Recognition"

This is an official pytorch implementation of ActionCLIP: A New Paradigm for Video Action Recognition [arXiv] Overview Content Prerequisites Data Prep

268 Jan 09, 2023
paper list in the area of reinforcenment learning for recommendation systems

paper list in the area of reinforcenment learning for recommendation systems

HenryZhao 23 Jun 09, 2022
Voxel Set Transformer: A Set-to-Set Approach to 3D Object Detection from Point Clouds (CVPR 2022)

Voxel Set Transformer: A Set-to-Set Approach to 3D Object Detection from Point Clouds (CVPR2022)[paper] Authors: Chenhang He, Ruihuang Li, Shuai Li, L

Billy HE 141 Dec 30, 2022
ManipNet: Neural Manipulation Synthesis with a Hand-Object Spatial Representation - SIGGRAPH 2021

ManipNet: Neural Manipulation Synthesis with a Hand-Object Spatial Representation - SIGGRAPH 2021 Dataset Code Demos Authors: He Zhang, Yuting Ye, Tak

HE ZHANG 194 Dec 06, 2022
PyTorch implementation of Convolutional Neural Fabrics http://arxiv.org/abs/1606.02492

PyTorch implementation of Convolutional Neural Fabrics arxiv:1606.02492 There are some minor differences: The raw image is first convolved, to obtain

Anuvabh Dutt 25 Dec 22, 2021
Continual learning with sketched Jacobian approximations

Continual learning with sketched Jacobian approximations This repository contains the code for reproducing figures and results in the paper ``Provable

Machine Learning and Information Processing Laboratory 1 Jun 30, 2022
Code for EmBERT, a transformer model for embodied, language-guided visual task completion.

Code for EmBERT, a transformer model for embodied, language-guided visual task completion.

41 Jan 03, 2023
Asterisk is a framework to generate high-quality training datasets at scale

Asterisk is a framework to generate high-quality training datasets at scale

Mona Nashaat 44 Apr 25, 2022
Official implementation of our CVPR2021 paper "OTA: Optimal Transport Assignment for Object Detection" in Pytorch.

OTA: Optimal Transport Assignment for Object Detection This project provides an implementation for our CVPR2021 paper "OTA: Optimal Transport Assignme

217 Jan 03, 2023
Official PyTorch implementation of UACANet: Uncertainty Aware Context Attention for Polyp Segmentation

UACANet: Uncertainty Aware Context Attention for Polyp Segmentation Official pytorch implementation of UACANet: Uncertainty Aware Context Attention fo

Taehun Kim 85 Dec 14, 2022
Neural Scene Flow Fields using pytorch-lightning, with potential improvements

nsff_pl Neural Scene Flow Fields using pytorch-lightning. This repo reimplements the NSFF idea, but modifies several operations based on observation o

AI葵 178 Dec 21, 2022
Surrogate- and Invariance-Boosted Contrastive Learning (SIB-CL)

Surrogate- and Invariance-Boosted Contrastive Learning (SIB-CL) This repository contains all source code used to generate the results in the article "

Charlotte Loh 3 Jul 23, 2022
Optimized Gillespie algorithm for simulating Stochastic sPAtial models of Cancer Evolution (OG-SPACE)

OG-SPACE Introduction Optimized Gillespie algorithm for simulating Stochastic sPAtial models of Cancer Evolution (OG-SPACE) is a computational framewo

Data and Computational Biology Group UNIMIB (was BI*oinformatics MI*lan B*icocca) 0 Nov 17, 2021
Official repository for Natural Image Matting via Guided Contextual Attention

GCA-Matting: Natural Image Matting via Guided Contextual Attention The source codes and models of Natural Image Matting via Guided Contextual Attentio

Li Yaoyi 349 Dec 26, 2022
Pytorch reimplement of the paper "A Novel Cascade Binary Tagging Framework for Relational Triple Extraction" ACL2020. The original code is written in keras.

CasRel-pytorch-reimplement Pytorch reimplement of the paper "A Novel Cascade Binary Tagging Framework for Relational Triple Extraction" ACL2020. The o

longlongman 170 Dec 01, 2022
CasualHealthcare's Pneumonia detection with Artificial Intelligence (Convolutional Neural Network)

CasualHealthcare's Pneumonia detection with Artificial Intelligence (Convolutional Neural Network) This is PneumoniaDiagnose, an artificially intellig

Azhaan 2 Jan 03, 2022
PyTorch implementation of ARM-Net: Adaptive Relation Modeling Network for Structured Data.

A ready-to-use framework of latest models for structured (tabular) data learning with PyTorch. Applications include recommendation, CRT prediction, healthcare analytics, and etc.

48 Nov 30, 2022
Liver segmentation using MONAI and pytorch

Machine Learning use case in the field of Healthcare. In this project MONAI and pytorch frameworks are used for 3D Liver segmentation.

Abhishek Gajbhiye 2 May 30, 2022
An open source library for face detection in images. The face detection speed can reach 1000FPS.

libfacedetection This is an open source library for CNN-based face detection in images. The CNN model has been converted to static variables in C sour

Shiqi Yu 11.4k Dec 27, 2022
CCAFNet: Crossflow and Cross-scale Adaptive Fusion Network for Detecting Salient Objects in RGB-D Images

Code and result about CCAFNet(IEEE TMM) 'CCAFNet: Crossflow and Cross-scale Adaptive Fusion Network for Detecting Salient Objects in RGB-D Images' IEE

zyrant丶 14 Dec 29, 2021