
lucidrains/g-mlp-gpt

GPT - gMLP

This repository attempts to crack long-context autoregressive language modeling (GPT) using variations of gMLPs. Specifically, it contains a variant that applies gMLP over local sliding windows. The hope is to train context lengths of 4096 and above on a single GPU, both efficiently and with good results.
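To give a rough sense of why local windows can help at long context lengths, the sketch below compares the size of a spatial projection over the full sequence with one restricted to a window. It assumes, as a simplification that may not match this repository's implementation, that the sequence is split into non-overlapping chunks and each chunk is mixed by a shared `window x window` weight matrix. `spatial_cost` and its `channels` parameter are hypothetical names used only for this illustration.

```python
def spatial_cost(seq_len, window, channels):
    """Rough cost of one spatial (token-mixing) projection.

    Assumption of this sketch: the sequence is split into
    seq_len // window chunks, and each chunk is mixed by the same
    (window x window) weight matrix across `channels` features.
    """
    if seq_len % window != 0:
        raise ValueError("this sketch assumes window divides seq_len")
    num_windows = seq_len // window
    weights = window * window                              # entries in the weight matrix
    mult_adds = num_windows * window * window * channels   # = seq_len * window * channels
    return weights, mult_adds


full_weights, full_ops = spatial_cost(seq_len=4096, window=4096, channels=512)
local_weights, local_ops = spatial_cost(seq_len=4096, window=256, channels=512)

print(full_weights // local_weights)  # 256: weight matrix shrinks quadratically
print(full_ops // local_ops)          # 16: multiply-adds shrink linearly in the window
```

Under these assumptions, the multiply-add count grows as `seq_len * window` rather than `seq_len ** 2`, which is why small windows in most layers keep a long sequence affordable.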

You can also add the "tiny" attention (as described in the paper) with the attn_dim keyword argument, which sets the dimension of the single attention head (64 is recommended). Pass a tuple to use a different dimension for each layer.
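As a sketch of the two ways of passing attn_dim described above, the snippet below builds keyword arguments with either a single value or one value per layer. `per_layer` is a hypothetical helper, not part of g_mlp_gpt, and the requirement that a tuple has exactly `depth` entries is an assumption of this example. The per-layer values are arbitrary.

```python
def per_layer(value, depth):
    """Expand a setting to one entry per layer (hypothetical helper)."""
    if isinstance(value, tuple):
        # Assumption: a tuple must supply exactly one entry per layer.
        if len(value) != depth:
            raise ValueError(f"expected {depth} entries, got {len(value)}")
        return value
    return (value,) * depth


depth = 4
base = dict(
    num_tokens=20000,
    dim=512,
    depth=depth,
    seq_len=1024,
    window=(128, 256, 512, 1024),
)

# Same tiny-attention head dimension in every layer.
same_everywhere = dict(base, attn_dim=64)

# A different head dimension per layer (values chosen only for illustration).
varied = dict(base, attn_dim=per_layer((32, 32, 64, 64), depth))

# With the package installed, either dict can be unpacked into the constructor:
#   from g_mlp_gpt import gMLPGPT
#   model = gMLPGPT(**same_everywhere)
```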

Install

$ pip install g-mlp-gpt

Usage

import torch
from g_mlp_gpt import gMLPGPT

model = gMLPGPT(
    num_tokens = 20000,
    dim = 512,
    depth = 4,
    seq_len = 1024,
    window = (128, 256, 512, 1024) # window sizes for each depth
)

x = torch.randint(0, 20000, (1, 1000))
logits = model(x) # (1, 1000, 20000)

16k context length

import torch
from torch import nn
from g_mlp_gpt import gMLPGPT

model = gMLPGPT(
    num_tokens = 20000,
    dim = 512,
    seq_len = 16384,
    reversible = True,    # reversible networks
    act = nn.Tanh(),      # tanh activation for spatial gating
    depth = 12,
    window = (
        128,
        128,
        256,
        512,
        1024,
        1024,
        (2048, 2),    # window size of 2048, axial of 2
        (2048, 2),
        (4096, 4),
        (4096, 4),
        (8192, 8),    # window size of 8192, axial of 8
        (8192, 8)
    )
).cuda()

x = torch.randint(0, 20000, (1, 16384)).cuda()
logits = model(x) # (1, 16384, 20000)
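The window schedule above mixes plain integers with (window, axial) tuples, so it can be easy to miscount entries. The following is a minimal sketch of a sanity check to run before building the model. `describe_windows` is a hypothetical helper, and its rules (one entry per layer, each window dividing seq_len) are assumptions of this example rather than documented requirements of the library.

```python
def describe_windows(windows, depth, seq_len):
    """Normalize a window schedule to (window_size, axial) pairs.

    Plain ints become (size, None). The checks below are assumptions
    of this sketch, not documented library behavior.
    """
    if len(windows) != depth:
        raise ValueError(f"expected {depth} window entries, got {len(windows)}")
    normalized = []
    for layer, entry in enumerate(windows):
        size, axial = entry if isinstance(entry, tuple) else (entry, None)
        if size > seq_len or seq_len % size != 0:
            raise ValueError(f"layer {layer}: window {size} does not divide {seq_len}")
        normalized.append((size, axial))
    return normalized


schedule = (
    128, 128, 256, 512, 1024, 1024,
    (2048, 2), (2048, 2), (4096, 4), (4096, 4), (8192, 8), (8192, 8),
)

for layer, (size, axial) in enumerate(describe_windows(schedule, depth=12, seq_len=16384)):
    print(layer, size, axial)
```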

Citations

@misc{liu2021pay,
    title   = {Pay Attention to MLPs}, 
    author  = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
    year    = {2021},
    eprint  = {2105.08050},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}