Implementation of H-Transformer-1D, Hierarchical Attention for Sequence Learning

Overview

H-Transformer-1D

Implementation of H-Transformer-1D, a Transformer using hierarchical attention for sequence learning with subquadratic costs.

For now, the H-Transformer will only act as a long-context encoder.

Install

$ pip install h-transformer-1d

Usage

import torch
from h_transformer_1d import HTransformer1D

model = HTransformer1D(
    num_tokens = 256,          # number of tokens
    dim = 512,                 # dimension
    depth = 2,                 # depth
    max_seq_len = 8192,        # maximum sequence length
    heads = 8,                 # heads
    dim_head = 64,             # dimension per head
    block_size = 128           # block size
)

x = torch.randint(0, 256, (1, 8000))   # variable sequence length
mask = torch.ones((1, 8000)).bool()    # variable mask length

# network will automatically pad to power of 2, do hierarchical attention, etc

logits = model(x, mask = mask) # (1, 8000, 256)
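
The comment above says the network pads the input to a power of two. As a minimal sketch of that arithmetic (the helper name `next_power_of_two` is hypothetical and is not part of the library), the padded length for the 8000-token example can be computed like this:

```python
def next_power_of_two(n: int) -> int:
    """Return the smallest power of two that is >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


seq_len = 8000
padded_len = next_power_of_two(seq_len)   # length after padding
pad_amount = padded_len - seq_len         # number of padding positions added

print(padded_len, pad_amount)  # 8192 192
```

This matches the 8192 vs. 8000 size mismatch reported in one of the issues further down, which is what one would expect if some tensor is built for the padded length and another for the original length.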

Citations

@misc{zhu2021htransformer1d,
    title   = {H-Transformer-1D: Fast One-Dimensional Hierarchical Attention for Sequences}, 
    author  = {Zhenhai Zhu and Radu Soricut},
    year    = {2021},
    eprint  = {2107.11906},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • Masking not working in training, thanks

    Hi, I have tried to train the model on GPU with masking enabled. Line 94, t = torch.flip(t, dims = (2,)), raises RuntimeError: "flip_cuda" not implemented for 'Bool', even after I tried moving the mask to the CPU.

    Any ideas to solve the problem? Thanks a lot.
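
One possible workaround, sketched under the assumption that the tensor being flipped is the boolean mask: cast it to an integer dtype, flip, and cast back. The helper name `flip_bool` is hypothetical; `torch.flip`, `Tensor.to` and `Tensor.bool` are standard PyTorch calls.

```python
import torch


def flip_bool(t: torch.Tensor, dims) -> torch.Tensor:
    """Flip a boolean tensor by round-tripping through uint8.

    Useful on PyTorch builds where flip is not implemented for bool tensors.
    """
    return torch.flip(t.to(torch.uint8), dims=dims).bool()


mask = torch.tensor([[[True, False, False],
                      [True, True, False]]])
flipped = flip_bool(mask, dims=(2,))
print(flipped)
```

The round trip keeps the values intact because uint8 represents True and False as 1 and 0.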

    opened by junyongyou 6
  • Application to sequence classification?

    Hi,

    Forgive the naive question, I am trying to make sense of this paper but it's tough going. If I understand correctly, this attention mechanism focuses mainly on nearby tokens and only attends to distant tokens via a hierarchical, low-rank approximation. In that case, can the usual sequence classification approach of having a global [CLS] token that can attend to all other tokens (and vice versa) still work? If not, how can this attention mechanism handle the text classification tasks in the long range arena benchmark?

    Cheers for whatever insights you can share, and thanks for the great work!

    opened by trpstra 4
  • eos token does not work in batch mode generation

    When generating sequences with the current code, it seems the eos_token only works when generating one sequence at a time: https://github.com/lucidrains/h-transformer-1d/blob/main/h_transformer_1d/autoregressive_wrapper.py#L59
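
A common way to handle an end-of-sequence token in batched generation is to track a "finished" flag per sequence and stop only once every sequence has finished. The sketch below is plain Python and independent of the library: `generate_batch`, the `step` callback and `toy_step` are all hypothetical names used only for illustration.

```python
from typing import Callable, List


def generate_batch(
    step: Callable[[List[List[int]]], List[int]],  # hypothetical: one next token per sequence
    prompts: List[List[int]],
    max_new_tokens: int,
    eos_token: int,
    pad_token: int = 0,
) -> List[List[int]]:
    seqs = [list(p) for p in prompts]
    finished = [False] * len(seqs)
    for _ in range(max_new_tokens):
        next_tokens = step(seqs)
        for i, tok in enumerate(next_tokens):
            # Sequences that already emitted EOS only receive padding.
            seqs[i].append(pad_token if finished[i] else tok)
            if tok == eos_token and not finished[i]:
                finished[i] = True
        # Stop only when every sequence in the batch has produced EOS.
        if all(finished):
            break
    return seqs


def toy_step(seqs: List[List[int]]) -> List[int]:
    """Toy stand-in for a model: the next token is the last token plus one."""
    return [s[-1] + 1 for s in seqs]


out = generate_batch(toy_step, prompts=[[1], [0]], max_new_tokens=10, eos_token=3)
print(out)  # [[1, 2, 3, 0], [0, 1, 2, 3]]
```

The first sequence reaches the EOS token one step earlier and is padded while the second one finishes.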

    opened by tmphex 4
  • RuntimeError: Tensor type unknown to einops <class 'torch.return_types.max'>

    lib/python3.7/site-packages/einops/_backends.py", line 52, in get_backend raise RuntimeError('Tensor type unknown to einops {}'.format(type(tensor)))

    RuntimeError: Tensor type unknown to einops <class 'torch.return_types.max'>

    I understand why this gets raised. Could it be a pytorch version problem? Mine is 1.6.0
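
For context, `torch.return_types.max` is the named tuple that `torch.max` returns when it is called with a `dim` argument. It holds a `values` tensor and an `indices` tensor and is not itself a tensor, so tensor-only functions reject it. The snippet below is a small illustration of that behaviour, not the library's code:

```python
import torch

scores = torch.tensor([[1.0, 5.0, 3.0],
                       [4.0, 2.0, 6.0]])

# With `dim`, torch.max returns a (values, indices) named tuple, not a tensor.
result = torch.max(scores, dim=-1, keepdim=True)

row_max = result.values       # tensor of shape (2, 1) with the maxima
row_argmax = result.indices   # tensor of shape (2, 1) with their positions

print(type(result).__name__, row_max.shape, row_argmax.shape)
```

Passing `result.values` instead of `result` to downstream tensor operations avoids the type error.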

    opened by wajihullahbaig 4
  • Algorithm Mismatch

    Paper Implementation

    In the implementation, we get blocked Q, K, V tensors by level with the code below.

    https://github.com/lucidrains/h-transformer-1d/blob/110cab0038898560d72d460bfef8ca8b7f17f0a5/h_transformer_1d/h_transformer_1d.py#L164-L179

    The final result of the matrix-matrix product in Equation 29 (or 69) is then computed with the for loop below.

    https://github.com/lucidrains/h-transformer-1d/blob/110cab0038898560d72d460bfef8ca8b7f17f0a5/h_transformer_1d/h_transformer_1d.py#L234-L247

    What is the problem?

    However, according to the current code, it is not possible to include information about the level 0 white blocks in the figure below. (Equation 70 of the paper includes the corresponding attention matrix entries.)

    fig2

    I think you should also add an off-diagonal term of near-interaction (level 0) to match Equation 69!

    opened by jinmang2 3
  • I have some questions about implementation details

    Thanks for making your implementation public. I have three questions about your h-transformer 1d implementation.

    1. The number of levels M

    https://github.com/lucidrains/h-transformer-1d/blob/63063d5bb036b56a7205aadc5c8198da02d698f6/h_transformer_1d/h_transformer_1d.py#L105-L114

    In the paper, Eq. (32) gives a guide on how M is determined.

    img1

    In your code implementation,

    • n is the sequence length, which is not padded
    • bsz is the block size (the same as N_r, the numerical rank of the off-diagonal blocks)
    • Because code line 111 already contains level 0, M is equal to int(log2(n // bsz)) - 1

    However, in Section 6.1, Constructing Hierarchical Attention, I found that the sequence length (L) must be a multiple of 2. In my opinion, the L in Eq. (31) is equal to 2^{M+1}. In the implementation, n is the unpadded sequence length, so one level of M is missing.

    Since x is the sequence padded by the processing below, https://github.com/lucidrains/h-transformer-1d/blob/63063d5bb036b56a7205aadc5c8198da02d698f6/h_transformer_1d/h_transformer_1d.py#L83-L91

    I think the above implementation should be modified as below:

    107    num_levels = int(log2(x.size(1) // bsz)) - 1 
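
To make the difference concrete, here is a small check of the formula quoted above using the README's numbers (sequence length 8000, block size 128) and the padded length 8192. The function name `num_levels` is only for illustration.

```python
from math import log2


def num_levels(length: int, block_size: int) -> int:
    """Evaluate the formula from the issue: int(log2(length // block_size)) - 1."""
    return int(log2(length // block_size)) - 1


n, bsz = 8000, 128      # unpadded length and block size from the README example
padded = 8192           # length after padding to the next power of two

print(num_levels(n, bsz))       # 4
print(num_levels(padded, bsz))  # 5
```

With the unpadded length, 8000 // 128 is 62, so the logarithm is truncated and one level is lost, which is the discrepancy described above.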
    

    2. Super- and Sub-diagonal blocks of the coarsened matrix \tilde{A} as level-l

    https://github.com/lucidrains/h-transformer-1d/blob/63063d5bb036b56a7205aadc5c8198da02d698f6/h_transformer_1d/h_transformer_1d.py#L190-L198

    Ys contains the Y and A values computed by calculate_Y_and_A. For example,

    # Ys
    [
        (batch_size*n_heads, N_b(2), N_r), (batch_size*n_heads, N_b(2)),  # level 2, (Y(2), tilde_A(2))
        (batch_size*n_heads, N_b(1), N_r), (batch_size*n_heads, N_b(1)),  # level 1, (Y(1), tilde_A(1))
        (batch_size*n_heads, N_b(0), N_r), (batch_size*n_heads, N_b(0)),  # level 0, (Y(0), tilde_A(0))
    ]
    

    In Eq. (29), Y is calculated as Y = AV = Y(0) + P(0)( Y(1) + P(1)Y(2) ). However, in code line 190, Y is calculated using only the level-0 and level-1 blocks, no matter how many levels M there are: Y = AV = Y(0) + P(0)Y(1).

    Does increasing the number of levels cause performance degradation in the implementation? I'm so curious!
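
The nested form Y(0) + P(0)( Y(1) + P(1)Y(2) ) can be sketched in plain Python. This is only an illustration: it assumes each P(l) acts as piecewise-constant interpolation (every coarse entry is duplicated once), and it uses one scalar per position instead of a feature vector. The names `prolong` and `combine` are hypothetical.

```python
from typing import List


def prolong(y: List[float]) -> List[float]:
    """Assumed action of P(l): duplicate each coarse entry, doubling the length."""
    return [v for v in y for _ in range(2)]


def combine(levels: List[List[float]]) -> List[float]:
    """Evaluate Y0 + P0(Y1 + P1(Y2 + ...)); levels[0] is the finest level."""
    acc = levels[-1]
    for y in reversed(levels[:-1]):
        acc = [a + b for a, b in zip(y, prolong(acc))]
    return acc


y0 = [1.0] * 8                     # level 0, finest resolution
y1 = [10.0, 20.0, 30.0, 40.0]      # level 1
y2 = [100.0, 200.0]                # level 2, coarsest resolution

print(combine([y0, y1, y2]))  # all three levels contribute
print(combine([y0, y1]))      # only levels 0 and 1, as described for code line 190
```

In the three-level call the coarsest values reach every fine position through two successive interpolations, which is the contribution that is missing when only two levels are combined.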


    3. Comparison with Luna: Linear Unified Nested Attention

    h-transformer significantly exceeded the scores of BigBird and Luna on LRA. However, one thing I missed while reading the paper was a comparison of computation time with Luna and other sub-quadratic methods. Is this algorithm much faster than other sub-quadratic methods? And how does it compare to Luna?


    Thanks again for the implementation release! The idea of calculating off-diagonal with flip was amazing and I learned a lot. Thank you!! 😄

    opened by jinmang2 3
  • Add Norm Missing

    I am using the code now, and I wonder whether add & norm is implemented. I can only find the layer norm, but no add operation. Here is the code in h-transformer-1d.py line 489 ... Is this a bug or something? Thanks @Lucidrains

    for ind in range(depth):
        attn = attn_class(dim, dim_head = dim_head, heads = heads, block_size = block_size, pos_emb = self.pos_emb, **attn_kwargs)
        ff = FeedForward(dim, mult = ff_mult)

        if shift_tokens:
            attn, ff = map(lambda t: PreShiftTokens(shift_token_ranges, t), (attn, ff))

        attn, ff = map(lambda t: PreNorm(dim, t), (attn, ff))
        layers.append(nn.ModuleList([attn, ff]))
    
    opened by wwx13 2
  • Mask not working

    def forward(self, x, mask = None):
        b, n, device = *x.shape, x.device
        assert n <= self.max_seq_len, 'sequence length must be less than the maximum sequence length'
        x = self.token_emb(x)
        x = self.layers(x)
        return self.to_logits(x)
    

    I think... Masking does not work ???

    opened by wwx13 2
  • One simple question

    Hi, Phil!

    One simple question, (my math is not good) https://github.com/lucidrains/h-transformer-1d/blob/7c11d036d53926495ec0917a34a1aad7261892b5/train.py#L65

    Why not randint(0, self.data.size(0)-self.seq_len+1)? The high bound is exclusive, after all.
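
The off-by-one reasoning can be checked with a small sketch in plain Python. `random.randrange`, like `torch.randint`, excludes its upper bound. The helper `sample_window` and the `window_len` parameter are hypothetical; what the window length is in train.py (for example seq_len or seq_len + 1) is not shown here, so it is left as a parameter.

```python
import random


def sample_window(data, window_len, rng=random):
    """Pick a random contiguous window of `window_len` items from `data`."""
    n = len(data)
    # The upper bound is exclusive, so "+ 1" keeps the last valid start reachable.
    start = rng.randrange(0, n - window_len + 1)
    return data[start:start + window_len]


data = list(range(10))
window_len = 4

# Every start s with s + window_len <= len(data) is valid: 0, 1, ..., 6.
valid_starts = [s for s in range(len(data)) if s + window_len <= len(data)]
print(len(valid_starts))  # 7, i.e. len(data) - window_len + 1

print(sample_window(data, window_len))
```

Using a smaller upper bound is still safe, it only means the last few windows are never sampled.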

    opened by CiaoHe 2
  • Mini-batching (b > 1) does not work with masking

    When using x and mask with a batch size larger than 1, the following error arises:

    import torch
    from h_transformer_1d import HTransformer1D
    
    model = HTransformer1D(
        num_tokens = 256,          # number of tokens
        dim = 512,                 # dimension
        depth = 2,                 # depth
        causal = False,            # autoregressive or not
        max_seq_len = 8192,        # maximum sequence length
        heads = 8,                 # heads
        dim_head = 64,             # dimension per head
        block_size = 128           # block size
    )
    
    batch_size = 2
    x = torch.randint(0, 256, (batch_size, 8000))   # variable sequence length
    mask = torch.ones((batch_size, 8000)).bool()    # variable mask length
    
    # network will automatically pad to power of 2, do hierarchical attention, etc
    
    logits = model(x, mask = mask) # (2, 8000, 256)
    

    This gives the following error:

    ~/git/h-transformer-1d/h_transformer_1d/h_transformer_1d.py in masked_aggregate(tensor, mask, dim, average)
         19     diff_len = len(tensor.shape) - len(mask.shape)
         20     mask = mask[(..., *((None,) * diff_len))]
    ---> 21     tensor = tensor.masked_fill(~mask, 0.)
         22 
         23     total_el = mask.sum(dim = dim)
    
    RuntimeError: The size of tensor a (2) must match the size of tensor b (16) at non-singleton dimension 0
    

    It seems the tensor has shape heads * batch in dimension 0, not batch as the mask has.
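
One way to reconcile the two shapes is to repeat the mask once per head before applying it. The sketch below is not the library's fix. It assumes the heads are folded into the batch axis in batch-major order, so rows 0 to heads-1 belong to the first batch element; if the layout were head-major, `Tensor.repeat` would be needed instead of `repeat_interleave`.

```python
import torch

batch, heads, seq_len, dim_head = 2, 8, 16, 4

# Assumed layout: (batch * heads, seq_len, dim_head), batch-major.
tensor = torch.randn(batch * heads, seq_len, dim_head)

mask = torch.ones(batch, seq_len, dtype=torch.bool)
mask[1, 10:] = False  # second batch element has 6 padded positions

# Repeat each batch row `heads` times so the mask lines up with the tensor rows.
mask_per_head = mask.repeat_interleave(heads, dim=0)   # (batch * heads, seq_len)

# Add a trailing axis so the mask broadcasts over the feature dimension.
masked = tensor.masked_fill(~mask_per_head[..., None], 0.0)

print(mask_per_head.shape, masked.shape)
```

With batch 2 and 8 heads the leading dimension is 16, which matches the sizes 2 and 16 in the error message above.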

    opened by jaak-s 2
  • Example in README does not work

    Executing the example:

    import torch
    from h_transformer_1d import HTransformer1D
    
    model = HTransformer1D(
        num_tokens = 256,          # number of tokens
        dim = 512,                 # dimension
        depth = 2,                 # depth
        causal = False,            # autoregressive or not
        max_seq_len = 8192,        # maximum sequence length
        heads = 8,                 # heads
        dim_head = 64,             # dimension per head
        block_size = 128           # block size
    )
    
    x = torch.randint(0, 256, (1, 8000))   # variable sequence length
    mask = torch.ones((1, 8000)).bool()    # variable mask length
    
    # network will automatically pad to power of 2, do hierarchical attention, etc
    
    logits = model(x, mask = mask) # (1, 8000, 256)
    

    Gives the following error:

    ~/miniconda3/lib/python3.7/site-packages/rotary_embedding_torch/rotary_embedding_torch.py in apply_rotary_emb(freqs, t, start_index)
         43     assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
         44     t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
    ---> 45     t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
         46     return torch.cat((t_left, t, t_right), dim = -1)
         47 
    
    RuntimeError: The size of tensor a (8192) must match the size of tensor b (8000) at non-singleton dimension 1
    
    opened by jaak-s 2
  • Fix indexing

    I am fixing a few apparent bugs in the code. The upshot is that the attention now supports a block size equal to the input length rounded up to the next power of two, and for this block size it becomes exact. This allows one to look at the systematic error in the output as a function of decreasing block size (and memory usage).

    I've found this module to reduce memory consumption by a factor of two, but the approximation quickly becomes too inaccurate with decreasing block size to use it as a drop-in replacement for an existing (full) attention layer.

    This repository shows how to compute the full attention with linear memory complexity: https://github.com/CHARM-Tx/linear_mem_attention_pytorch

    opened by jglaser 0
  • Approximated values are off

    I wrote a simple test to check the output of the hierarchical transformer self attention against the BERT self attention from huggingface transformers.

    import torch
    import torch.nn as nn
    import math
    
    from h_transformer_1d.h_transformer_1d import HAttention1D
    
    def transpose_for_scores(x, num_attention_heads, attention_head_size):
        new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)
    
    def bert_self_attention(query, key, value, attention_mask=None, num_attention_heads=1):
            dim_head = query.size()[-1] // num_attention_heads
            all_head_size = dim_head*num_attention_heads
    
            query_layer = transpose_for_scores(query, num_attention_heads, dim_head)
            key_layer = transpose_for_scores(key, num_attention_heads, dim_head)
            value_layer = transpose_for_scores(value, num_attention_heads, dim_head)
    
            # Take the dot product between "query" and "key" to get the raw attention scores.
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    
            attention_scores = attention_scores / math.sqrt(dim_head)
    
            if attention_mask is not None:
                # Apply the attention mask (precomputed for all layers in BertModel forward() function)
                attention_scores = attention_scores + attention_mask
    
            # Normalize the attention scores to probabilities.
            attention_probs = nn.functional.softmax(attention_scores, dim=-1)
    
            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            #attention_probs = self.dropout(attention_probs)
    
            context_layer = torch.matmul(attention_probs, value_layer)
    
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (all_head_size,)
            context_layer = context_layer.view(*new_context_layer_shape)
    
            return context_layer, attention_probs
    
    if __name__ == "__main__":
        query = torch.tensor([[[0.1,0.2],[-0.5,0.7],[-0.5,-0.75],[.123,.456]]])
    #    query = torch.tensor([[[0.1,0.2],[-0.5,0.7]]])
        key = value = query
    
        n_heads = 1
        attn, probs = bert_self_attention(query, key, value, num_attention_heads=n_heads)
        print('bert_self_attention out: ', attn)
    
        block_size = 1
        for _ in range(0,2):
            dim_head = query.size()[-1]//n_heads
            h_attn = HAttention1D(
                dim=query.size()[-1],
                heads=n_heads,
                dim_head=dim_head,
                block_size=block_size
            )
    
            h_attn.to_qkv = torch.nn.Identity()
            h_attn.to_out = torch.nn.Identity()
    
            qkv = torch.stack([query, key, value], dim=2)
            qkv = torch.flatten(qkv, start_dim=2)
    
            attn_scores = h_attn(qkv)
            print('hattention_1d: (block_size = {})'.format(block_size), attn_scores)
    
            block_size *= 2
    

    This is the output I get

    bert_self_attention:  tensor([[[-0.1807,  0.1959],
             [-0.2096,  0.2772],
             [-0.2656, -0.0568],
             [-0.1725,  0.2442]]])
    hattention_1d: (block_size = 1) tensor([[[-0.2000,  0.4500],
             [-0.2000,  0.4500],
             [-0.1885, -0.1470],
             [-0.1885, -0.1470]]])
    

    before it errors out with

    assert num_levels >= 0, 'number of levels must be at least greater than 0'
    

    Some of the values are off in absolute magnitude by more than a factor of two.

    Looking at the code, this line seems problematic: https://github.com/lucidrains/h-transformer-1d/blob/8afd75cc6bc41754620bb6ab3737176cb69bdf93/h_transformer_1d/h_transformer_1d.py#L172

    I believe it should read

    num_levels = int(log2(pad_to_len // bsz)) - 1
    

    If I make that change, the approximated attention output is much closer to the exact one:

    bert_self_attention out:  tensor([[[-0.1807,  0.1959],
             [-0.2096,  0.2772],
             [-0.2656, -0.0568],
             [-0.1725,  0.2442]]])
    hattention_1d: (block_size = 1) tensor([[[-0.2590,  0.2020],
             [-0.2590,  0.2020],
             [-0.2590,  0.2020],
             [-0.2590,  0.2020]]])
    hattention_1d: (block_size = 2) tensor([[[-0.1808,  0.1972],
             [-0.1980,  0.2314],
             [-0.2438,  0.0910],
             [-0.1719,  0.2413]]])
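
To put a number on "much closer", the maximum absolute error against the exact output can be computed directly from the values printed above (the corrected run). This is just arithmetic on the reported numbers; the variable names are for illustration only.

```python
# Exact attention output as printed above.
exact = [[-0.1807, 0.1959], [-0.2096, 0.2772], [-0.2656, -0.0568], [-0.1725, 0.2442]]

# Hierarchical attention outputs after the proposed change, as printed above.
approx_bs1 = [[-0.2590, 0.2020]] * 4
approx_bs2 = [[-0.1808, 0.1972], [-0.1980, 0.2314], [-0.2438, 0.0910], [-0.1719, 0.2413]]


def max_abs_err(reference, approx):
    """Largest elementwise absolute difference between two nested lists."""
    return max(
        abs(r - a)
        for ref_row, approx_row in zip(reference, approx)
        for r, a in zip(ref_row, approx_row)
    )


print(round(max_abs_err(exact, approx_bs1), 4))  # 0.2588
print(round(max_abs_err(exact, approx_bs2), 4))  # 0.1478
```

In both cases the largest deviation comes from the second component of the third row.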
    
    opened by jglaser 1
Owner
Phil Wang