Implementation of gMLP, an all-MLP replacement for Transformers, in PyTorch

Overview

gMLP - Pytorch

Install

$ pip install g-mlp-pytorch

Usage

For masked language modelling

import torch
from g_mlp_pytorch import gMLP

model = gMLP(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 256
)

x = torch.randint(0, 20000, (1, 256))
emb = model(x) # (1, 256, 512)

For image classification

import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6
)

img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)

You can also add a small amount of single-headed attention to boost performance, which the paper calls aMLP, by passing one extra keyword argument, attn_dim. This applies to both gMLPVision and gMLP.

import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    attn_dim = 64
)

img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)

Citations

@misc{liu2021pay,
    title   = {Pay Attention to MLPs}, 
    author  = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
    year    = {2021},
    eprint  = {2105.08050},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • Custom image sizes?

    Hi, thanks for your great (and very fast) contribution! I was wondering if you could help me figure out how to apply this to a different image size. It's not really an image, but rather a 2D tensor of size 4096 x 100.

    I saw that I can change the number of channels, so I could just set channels to 1. But I see that, firstly, your implementation is for square images, and secondly, it requires the image size to be divisible by the patch size.

    Since you've written this implementation, perhaps you could help me adapt it to my needs (and maybe help other users with their cases)?

    Maybe I could pad the length to 128 so that both dimensions would be divisible by 16, for example? But then where do I set different h and w?

    Thanks.

    opened by danarte 3
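The padding idea in this question can be worked through as plain arithmetic. The sketch below is only an illustration: `pad_to_multiple` and `patch_grid` are hypothetical helpers, not part of g-mlp-pytorch, and whether gMLPVision accepts a non-square input is not confirmed here.

```python
def pad_to_multiple(size: int, multiple: int) -> int:
    """Smallest value >= size that is divisible by `multiple`."""
    return -(-size // multiple) * multiple  # ceiling division, then scale back up


def patch_grid(height: int, width: int, patch_size: int):
    """Padded height and width, plus the number of patches they yield."""
    padded_h = pad_to_multiple(height, patch_size)
    padded_w = pad_to_multiple(width, patch_size)
    num_patches = (padded_h // patch_size) * (padded_w // patch_size)
    return padded_h, padded_w, num_patches


# The 4096 x 100 tensor from the question, with 16 x 16 patches.
padded_h, padded_w, num_patches = patch_grid(4096, 100, 16)
print(padded_h, padded_w, num_patches)
```

Under these assumptions, padding the short side to 112 is already enough for a patch size of 16; 128 also works but adds one more column of mostly empty patches.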
  • Parameter count doesn't line up with paper

    Just a note (and correct me if I misunderstood the paper):

    The parameter count for the Tiny gMLP doesn't line up with the parameter count from the paper for 30 layers, 128 dim and 6 ff_mult. That's probably due to the doubling of parameters here - https://github.com/lucidrains/g-mlp-pytorch/blob/main/g_mlp_pytorch/g_mlp_pytorch.py#L111

    If this is halved back to dim_ff, all 3 lines here also need to halve their respective dims - https://github.com/lucidrains/g-mlp-pytorch/blob/main/g_mlp_pytorch/g_mlp_pytorch.py#L64-L66

    Then the parameter count is roughly 5.5M.

    opened by titu1994 2
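The effect of the doubling can be checked with a back-of-the-envelope count. This is a sketch under stated assumptions, not the repository's code: it assumes each block is a LayerNorm, an input projection from dim to a hidden width, a spatial gating unit that splits the hidden width in half (with a LayerNorm and a seq_len x seq_len spatial projection on one half), and an output projection back to dim. The sequence length of 196 (a 224 x 224 image with 16 x 16 patches) is also an assumption, and the patch embedding and classifier head are left out.

```python
def linear_params(d_in: int, d_out: int) -> int:
    """Weights plus biases of a dense layer."""
    return d_in * d_out + d_out


def block_params(dim: int, hidden: int, seq_len: int) -> int:
    """Parameter count of one assumed gMLP block."""
    half = hidden // 2
    pre_norm = 2 * dim                         # LayerNorm scale and shift
    proj_in = linear_params(dim, hidden)
    gate_norm = 2 * half                       # LayerNorm on the gated half
    spatial = linear_params(seq_len, seq_len)  # token-mixing projection
    proj_out = linear_params(half, dim)
    return pre_norm + proj_in + gate_norm + spatial + proj_out


dim, ff_mult, depth, seq_len = 128, 6, 30, 196  # seq_len is an assumption

per_paper = depth * block_params(dim, dim * ff_mult, seq_len)
doubled = depth * block_params(dim, 2 * dim * ff_mult, seq_len)
print(per_paper, doubled)
```

With a hidden width of dim * ff_mult the blocks come to about 5.6M parameters, in line with the rough figure quoted above, while a doubled hidden width gives about 10.1M.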
  • Add Support for Stochastic Depth

    This PR adds support for stochastic depth, which is used in the paper for the vision experiments. I went ahead and added it to gMLP as well for completeness.

    I tried my best to match your style. Let me know if there are any problems, or if you want me to refactor anything.

    opened by mlw214 2
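For readers unfamiliar with the technique, stochastic depth skips whole layers at random during training and runs all of them at inference. The following is a minimal sketch of one common layer-dropping variant, not the code from this PR: `sample_kept_layers` and `prob_survival` are hypothetical names, and keeping at least one layer is an assumption made here so the network never becomes an identity.

```python
import random


def sample_kept_layers(num_layers: int, prob_survival: float, rng: random.Random):
    """Pick the indices of the layers to run for one training step."""
    if prob_survival >= 1.0:
        return list(range(num_layers))  # nothing is dropped

    # Each layer survives independently with probability prob_survival.
    kept = [i for i in range(num_layers) if rng.random() < prob_survival]

    if not kept:
        # Assumption: always run at least one layer.
        kept = [rng.randrange(num_layers)]
    return kept


rng = random.Random(0)
for step in range(3):
    print(step, sample_kept_layers(6, 0.8, rng))
```

At inference time the same function would be called with a survival probability of 1.0, so every layer runs.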
  • Don't you think this is more legible?

    `
class SpatialGatingUnit(nn.Module):
    def __init__(self, dim, dim_seq, causal = False, act = nn.Identity(), init_eps = 1e-3):
        super().__init__()
        dim_out = dim // 2
        self.causal = causal

        self.norm = nn.LayerNorm(dim_out)
        #self.proj = nn.Conv1d(dim_seq, dim_seq, 1)
    
        self.dim_seq = dim_seq
        self.w_ = nn.Parameter(torch.zeros(dim_seq, dim_seq), requires_grad=True)   ####
        self.b_ = nn.Parameter(torch.ones(dim_seq), requires_grad=True)  ####
    
        self.act = act
    
        init_eps /= dim_seq
        #nn.init.uniform_(self.proj.weight, -init_eps, init_eps)
        #nn.init.constant_(self.proj.bias, 1.)
    
    def forward(self, x, gate_res = None): # x -> bsz, len, hidden*6
        device, n = x.device, x.shape[1]
    
        res, gate = x.chunk(2, dim = -1)
        gate = self.norm(gate)
    
        weight, bias = self.w_, self.b_ # weight -> len, len     bias -> len
    
        if self.causal:
            # keep the first n positions and zero the strictly upper triangle,
            # so position i never mixes in positions j > i
            weight, bias = weight[:n, :n], bias[:n]
            mask = torch.ones(weight.shape[:2], device = device).triu_(1).bool()
            weight = weight.masked_fill(mask, 0.)
    
        gate = torch.matmul(weight, gate) + bias[None, :self.dim_seq, None]   # WZ + b
    
        #gate = F.conv1d(gate, weight, bias)   # WZ + b
    
        if exists(gate_res):
            gate = gate + gate_res
    
        return self.act(gate) * res
    

    `

    opened by ZIZUN 0
  • Potentially missing the highway path

    Hello,

    Maybe I missed it, but would you mind pointing out where the highway path of the gMLP block is in the code? Based on the paper, there is a highway path (an addition) between the input and the output. I couldn't find it in the gMLPBlock code.

    Thank you

    opened by Vincent-Li-9701 1
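A skip connection of this kind does not have to live inside the block itself; one common pattern is a small wrapper that adds the input back to whatever the wrapped function returns. The sketch below illustrates that pattern on plain Python lists. `Residual` and `toy_block` are illustrative names, and whether this repository places the addition in such a wrapper is not confirmed here.

```python
class Residual:
    """Wrap a function so that its input is added back to its output."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x):
        out = self.fn(x)
        # Element-wise addition of the block output and the untouched input.
        return [o + xi for o, xi in zip(out, x)]


def toy_block(x):
    """Stand-in for a gMLP block: here it just doubles every value."""
    return [2.0 * v for v in x]


layer = Residual(toy_block)
print(layer([1.0, 2.0]))
```

With this arrangement the block code contains no addition at all, which is one reason a reader might not find the skip path when looking only at the block.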
Owner
Phil Wang
Working with Attention. It's all we need.