Implementation of Bottleneck Transformer in Pytorch

Overview

Bottleneck Transformer - Pytorch

Implementation of Bottleneck Transformer, a SotA visual recognition model combining convolution and attention that outperforms EfficientNet and DeiT in terms of the performance-compute trade-off, in PyTorch

Install

$ pip install bottleneck-transformer-pytorch

Usage

import torch
from torch import nn
from bottleneck_transformer_pytorch import BottleStack

layer = BottleStack(
    dim = 256,              # channels in
    fmap_size = 64,         # feature map size
    dim_out = 2048,         # channels out
    proj_factor = 4,        # projection factor
    downsample = True,      # downsample on first layer or not
    heads = 4,              # number of heads
    dim_head = 128,         # dimension per head, defaults to 128
    rel_pos_emb = False,    # use relative positional embedding - uses absolute if False
    activation = nn.ReLU()  # activation throughout the network
)

fmap = torch.randn(2, 256, 64, 64) # feature map from previous resnet block(s)

layer(fmap) # (2, 2048, 32, 32)
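
The output shape in the comment above can be predicted from the constructor arguments. The helper below is a minimal sketch, not part of the library: the name `bottlestack_output_shape` is hypothetical, and it assumes that `downsample = True` halves the spatial size (consistent with 64 -> 32 in the example) and that the size is unchanged otherwise.

```python
def bottlestack_output_shape(batch, dim_out, fmap_size, downsample):
    """Predict the (batch, channels, height, width) shape after a BottleStack.

    Assumption: downsampling halves the feature map (floor division for odd
    sizes), and no downsampling leaves the spatial size as it was.
    """
    out_size = fmap_size // 2 if downsample else fmap_size
    return (batch, dim_out, out_size, out_size)


# Same settings as the usage example: 64 x 64 input, 2048 output channels
shape = bottlestack_output_shape(batch=2, dim_out=2048, fmap_size=64, downsample=True)
print(shape)
```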

BotNet

With some simple model surgery off a resnet, you can have the 'BotNet' (what a weird name) for training.

import torch
from torch import nn
from torchvision.models import resnet50

from bottleneck_transformer_pytorch import BottleStack

layer = BottleStack(
    dim = 256,
    fmap_size = 56,        # set specifically for imagenet's 224 x 224
    dim_out = 2048,
    proj_factor = 4,
    downsample = True,
    heads = 4,
    dim_head = 128,
    rel_pos_emb = True,
    activation = nn.ReLU()
)

resnet = resnet50()

# model surgery

backbone = list(resnet.children())

model = nn.Sequential(
    *backbone[:5],
    layer,
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(1),
    nn.Linear(2048, 1000)
)

# use the 'BotNet'

img = torch.randn(2, 3, 224, 224)
preds = model(img) # (2, 1000)
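
The value `fmap_size = 56` follows from the resnet stem: the first convolution and the max pool each halve the resolution, and the first residual stage keeps it. The sketch below works this out with the standard convolution output-size formula. The helper names are hypothetical, and the kernel, stride, and padding values are assumed to match torchvision's `resnet50`.

```python
def conv_out_size(size, kernel, stride, padding):
    """Standard output-size formula for a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet50_layer1_fmap_size(image_size):
    """Spatial size of the feature map produced by backbone[:5]."""
    size = conv_out_size(image_size, kernel=7, stride=2, padding=3)  # conv1
    size = conv_out_size(size, kernel=3, stride=2, padding=1)        # maxpool
    return size  # layer1 is assumed to keep the resolution unchanged


fmap_size = resnet50_layer1_fmap_size(224)
print(fmap_size)
```

If the input resolution changes, `fmap_size` would need to be recomputed the same way before constructing `BottleStack`.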

Citations

@misc{srinivas2021bottleneck,
    title   = {Bottleneck Transformers for Visual Recognition}, 
    author  = {Aravind Srinivas and Tsung-Yi Lin and Niki Parmar and Jonathon Shlens and Pieter Abbeel and Ashish Vaswani},
    year    = {2021},
    eprint  = {2101.11605},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Comments
  • How should I modify the code if the input feature map has unequal width and height?

    Assume that the width and height of the feature map are 10 and 8, respectively. Could you please help me to check that the modification about the class RelPosEmb is correct?

    class RelPosEmb(nn.Module):
        def __init__(self, fmap_size, dim_head):
            super().__init__()
            scale = dim_head ** -0.5
            self.fmap_size = fmap_size
            self.scale = scale
            # self.rel_height = nn.Parameter(torch.randn(fmap_size * 2 - 1, dim_head) * scale)
            self.rel_height = nn.Parameter(torch.randn(8 * 2 - 1, dim_head) * scale)
            # self.rel_width = nn.Parameter(torch.randn(fmap_size * 2 - 1, dim_head) * scale)
            self.rel_width = nn.Parameter(torch.randn(10 * 2 - 1, dim_head) * scale)

        def forward(self, q):
            q = rearrange(q, 'b h (x y) d -> b h x y d', x = 8)
            rel_logits_w = relative_logits_1d(q, self.rel_width)
            rel_logits_w = rearrange(rel_logits_w, 'b h x i y j -> b h (x y) (i j)')

            q = rearrange(q, 'b h x y d -> b h y x d')
            rel_logits_h = relative_logits_1d(q, self.rel_height)
            rel_logits_h = rearrange(rel_logits_h, 'b h x i y j -> b h (y x) (j i)')
            return rel_logits_w + rel_logits_h
    
    opened by ShuweiShao 4
  • Feature map size

    Hi. In my case, the input images all have different sizes, so the feature map size keeps changing. How should the fmap_size parameter of BottleStack be set in this case? Is it possible to train with a variable feature map size?

    opened by benlee73 3
  • A little bug.

    https://github.com/lucidrains/bottleneck-transformer-pytorch/blob/b789de6db39f33854862fbc9bcee27c697cf003c/bottleneck_transformer_pytorch/bottleneck_transformer_pytorch.py#L16

    The device needs to be specified here.

    flat_pad = torch.zeros((b, h, l - 1), device = device, dtype = dtype) 
    
    opened by lartpang 1
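
To show the fix in context, here is a sketch of a pad-and-reshape helper of the kind that line belongs to. Only the `flat_pad` line comes from the issue. The function name `rel_to_abs_sketch` and the rest of the body are assumptions for illustration, not a copy of the library code. Creating every padding tensor on the input's device and dtype means `torch.cat` never mixes CPU and GPU tensors.

```python
import torch


def rel_to_abs_sketch(x):
    """Map relative logits of shape (b, h, l, 2l - 1) to absolute ones (b, h, l, l)."""
    b, h, l, _ = x.shape
    # Reuse the input's device and dtype for all padding tensors
    opts = {"device": x.device, "dtype": x.dtype}
    col_pad = torch.zeros((b, h, l, 1), **opts)
    x = torch.cat((x, col_pad), dim=3)                    # (b, h, l, 2l)
    flat_x = x.reshape(b, h, l * 2 * l)
    flat_pad = torch.zeros((b, h, l - 1), **opts)         # the line from the issue
    flat_x_padded = torch.cat((flat_x, flat_pad), dim=2)  # (b, h, (l + 1) * (2l - 1))
    final_x = flat_x_padded.reshape(b, h, l + 1, 2 * l - 1)
    return final_x[:, :, :l, (l - 1):]


logits = torch.randn(2, 4, 5, 9, dtype=torch.float64)
out = rel_to_abs_sketch(logits)
print(out.shape, out.dtype, out.device)
```
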
  • fix inplace operations

    Recent versions of PyTorch throw runtime errors for in-place operations like *= and += on tensors that require gradients. This pull request fixes the issue by replacing them with their out-of-place binary versions.

    opened by AminRezaei0x443 0
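
The snippet below illustrates one case where an in-place update is rejected: a leaf tensor that requires gradients. It is a minimal sketch with hypothetical helper names, and it does not reproduce the exact lines changed in the pull request.

```python
import torch


def scale_in_place(t, factor):
    t *= factor          # in-place: modifies t itself
    return t


def scale_out_of_place(t, factor):
    return t * factor    # binary version: returns a new tensor


w = torch.ones(3, requires_grad=True)

try:
    scale_in_place(w, 2.0)
    in_place_failed = False
except RuntimeError:
    # Autograd refuses in-place changes to a leaf tensor that requires grad
    in_place_failed = True

out = scale_out_of_place(w, 2.0)
out.sum().backward()     # gradients flow normally through the new tensor
print(in_place_failed, w.grad)
```
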
  • Could you explain the implementation of the relative position embedding?

    Reference: https://github.com/tensorflow/tensor2tensor/blob/5f9dd2db6d7797162e53adf152310ed13e9fc711/tensor2tensor/layers/common_attention.py

    def _generate_relative_positions_matrix(length_q, length_k,
                                            max_relative_position,
                                            cache=False):
      """Generates matrix of relative positions between inputs."""
      if not cache:
        if length_q == length_k:
          range_vec_q = range_vec_k = tf.range(length_q)
        else:
          range_vec_k = tf.range(length_k)
          range_vec_q = range_vec_k[-length_q:]
        distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
      else:
        distance_mat = tf.expand_dims(tf.range(-length_k+1, 1, 1), 0)
      distance_mat_clipped = tf.clip_by_value(distance_mat, -max_relative_position,
                                              max_relative_position)
      # Shift values to be >= 0. Each integer still uniquely identifies a relative
      # position difference.
      final_mat = distance_mat_clipped + max_relative_position
      return final_mat
    
    
    def _generate_relative_positions_embeddings(length_q, length_k, depth,
                                                max_relative_position, name,
                                                cache=False):
      """Generates tensor of size [1 if cache else length_q, length_k, depth]."""
      with tf.variable_scope(name):
        relative_positions_matrix = _generate_relative_positions_matrix(
            length_q, length_k, max_relative_position, cache=cache)
        vocab_size = max_relative_position * 2 + 1
        # Generates embedding for each relative position of dimension depth.
        embeddings_table = tf.get_variable("embeddings", [vocab_size, depth])
        embeddings = tf.gather(embeddings_table, relative_positions_matrix)
        return embeddings
    
    
    def _relative_attention_inner(x, y, z, transpose):
      """Relative position-aware dot-product attention inner calculation.
      This batches matrix multiply calculations to avoid unnecessary broadcasting.
      Args:
        x: Tensor with shape [batch_size, heads, length or 1, length or depth].
        y: Tensor with shape [batch_size, heads, length or 1, depth].
        z: Tensor with shape [length or 1, length, depth].
        transpose: Whether to transpose inner matrices of y and z. Should be true if
            last dimension of x is depth, not length.
      Returns:
        A Tensor with shape [batch_size, heads, length, length or depth].
      """
      batch_size = tf.shape(x)[0]
      heads = x.get_shape().as_list()[1]
      length = tf.shape(x)[2]
    
      # xy_matmul is [batch_size, heads, length or 1, length or depth]
      xy_matmul = tf.matmul(x, y, transpose_b=transpose)
      # x_t is [length or 1, batch_size, heads, length or depth]
      x_t = tf.transpose(x, [2, 0, 1, 3])
      # x_t_r is [length or 1, batch_size * heads, length or depth]
      x_t_r = tf.reshape(x_t, [length, heads * batch_size, -1])
      # x_tz_matmul is [length or 1, batch_size * heads, length or depth]
      x_tz_matmul = tf.matmul(x_t_r, z, transpose_b=transpose)
      # x_tz_matmul_r is [length or 1, batch_size, heads, length or depth]
      x_tz_matmul_r = tf.reshape(x_tz_matmul, [length, batch_size, heads, -1])
      # x_tz_matmul_r_t is [batch_size, heads, length or 1, length or depth]
      x_tz_matmul_r_t = tf.transpose(x_tz_matmul_r, [1, 2, 0, 3])
      return xy_matmul + x_tz_matmul_r_t
    
    
    def dot_product_attention_relative(q,
                                       k,
                                       v,
                                       bias,
                                       max_relative_position,
                                       dropout_rate=0.0,
                                       image_shapes=None,
                                       save_weights_to=None,
                                       name=None,
                                       make_image_summary=True,
                                       cache=False,
                                       allow_memory=False,
                                       hard_attention_k=0,
                                       gumbel_noise_weight=0.0):
      """Calculate relative position-aware dot-product self-attention.
      The attention calculation is augmented with learned representations for the
      relative position between each element in q and each element in k and v.
      Args:
        q: a Tensor with shape [batch, heads, length, depth].
        k: a Tensor with shape [batch, heads, length, depth].
        v: a Tensor with shape [batch, heads, length, depth].
        bias: bias Tensor.
        max_relative_position: an integer specifying the maximum distance between
            inputs that unique position embeddings should be learned for.
        dropout_rate: a floating point number.
        image_shapes: optional tuple of integer scalars.
        save_weights_to: an optional dictionary to capture attention weights
          for visualization; the weights tensor will be appended there under
          a string key created from the variable scope (including name).
        name: an optional string.
        make_image_summary: Whether to make an attention image summary.
        cache: whether use cache mode
        allow_memory: whether to assume that recurrent memory is in use. If True,
          the length dimension of k/v/bias may be longer than the queries, and it is
          assumed that the extra memory entries precede the non-memory entries.
        hard_attention_k: integer, if > 0 triggers hard attention (picking top-k)
        gumbel_noise_weight: if > 0, apply Gumbel noise with weight
          `gumbel_noise_weight` before picking top-k. This is a no op if
          hard_attention_k <= 0.
      Returns:
        A Tensor.
      Raises:
        ValueError: if max_relative_position is not > 0.
      """
      if not max_relative_position:
        raise ValueError("Max relative position (%s) should be > 0 when using "
                         "relative self attention." % (max_relative_position))
      with tf.variable_scope(
          name, default_name="dot_product_attention_relative",
          values=[q, k, v]) as scope:
    
        # This calculation only works for self attention.
        # q, k and v must therefore have the same shape, unless memory is enabled.
        if not cache and not allow_memory:
          q.get_shape().assert_is_compatible_with(k.get_shape())
          q.get_shape().assert_is_compatible_with(v.get_shape())
    
        # Use separate embeddings suitable for keys and values.
        depth = k.get_shape().as_list()[3]
        length_k = common_layers.shape_list(k)[2]
        length_q = common_layers.shape_list(q)[2] if allow_memory else length_k
        relations_keys = _generate_relative_positions_embeddings(
            length_q, length_k, depth, max_relative_position,
            "relative_positions_keys", cache=cache)
        relations_values = _generate_relative_positions_embeddings(
            length_q, length_k, depth, max_relative_position,
            "relative_positions_values", cache=cache)
    
        # Compute self attention considering the relative position embeddings.
        logits = _relative_attention_inner(q, k, relations_keys, True)
        if bias is not None:
          logits += bias
        weights = tf.nn.softmax(logits, name="attention_weights")
        if hard_attention_k > 0:
          weights = harden_attention_weights(weights, hard_attention_k,
                                             gumbel_noise_weight)
        if save_weights_to is not None:
          save_weights_to[scope.name] = weights
          save_weights_to[scope.name + "/logits"] = logits
        weights = tf.nn.dropout(weights, 1.0 - dropout_rate)
        if (not tf.get_variable_scope().reuse and
            common_layers.should_generate_summaries() and
            make_image_summary):
          attention_image_summary(weights, image_shapes)
        return _relative_attention_inner(weights, v, relations_values, False)
    

    This corresponds to the formula clip(x, k) = max(-k, min(k, x)).

    But in your code there is a randn-initialised tensor with gradients. I don't understand this. Could you explain it?

    opened by AncientRemember 0
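
For readers comparing the two approaches, the sketch below reproduces in plain Python what the non-cached branch of `_generate_relative_positions_matrix` computes when `length_q == length_k`: the offset `j - i`, clipped to `[-k, k]` and shifted to be non-negative. The function name `relative_position_indices` is hypothetical.

```python
def relative_position_indices(length, max_relative_position):
    """Index into an embeddings table for every (query i, key j) pair."""
    k = max_relative_position
    return [
        [max(-k, min(k, j - i)) + k for j in range(length)]
        for i in range(length)
    ]


clipped = relative_position_indices(4, 2)
for row in clipped:
    print(row)

# With k = length - 1 nothing is clipped, so 2 * length - 1 distinct indices appear
unclipped = relative_position_indices(4, 3)
num_distinct = len({idx for row in unclipped for idx in row})
print(num_distinct)
```

With no clipping, the table needs `2 * length - 1` rows, the same count as the `fmap_size * 2 - 1` rows of the randn-initialised parameters in the `RelPosEmb` class quoted in the first comment. This suggests, as an interpretation rather than a statement from the author, that those learnable parameters play the role of the embeddings table.
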
  • Is it possible to modify these codes to support 3D images as well?

    Thank you for your great work!

    I was wondering if it is possible to modify this code to support 3D images as well (i.e. adding a z-axis).

    I can't imagine how to change the dimensions of vectors in the "content-position" part. E.g. Hx1xd and 1xWxd -> Hx1x1xd and 1x1xZxd and 1xWx1xd ?

    Thank you for your answer!

    opened by kyuchoi 0
  • Hello, I get the following error during training. How can I solve it?

    einops.EinopsError: Error while processing rearrange-reduction pattern "b h (x y) d -> b h x y d". Input tensor shape: torch.Size([2, 4, 900, 128]). Additional info: {'x': 26, 'y': 26}. Shape mismatch, 900 != 676

    opened by glt999 1
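
The numbers in this error can be checked by hand: the pattern expects `x * y = 26 * 26 = 676` positions, but the tensor carries 900. The sketch below recovers the side length of a square feature map from the flattened length. The helper name is hypothetical, and the reading that `fmap_size` was set to 26 while the incoming feature map was 30 x 30 is an assumption drawn only from the message.

```python
import math


def infer_square_fmap_size(num_positions):
    """Return the side length if num_positions is a perfect square."""
    side = math.isqrt(num_positions)
    if side * side != num_positions:
        raise ValueError(f"{num_positions} positions do not form a square feature map")
    return side


expected = 26 * 26                       # what the rearrange pattern asks for
actual_side = infer_square_fmap_size(900)
print(expected, actual_side)
```

Under that assumption, constructing `BottleStack` with `fmap_size` equal to the inferred side length, or resizing the input so the feature map is 26 x 26, would make the two counts agree.
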
  • The size of tensor a (9) must match the size of tensor b (10) at non-singleton dimension 3

    Hello, I have a question. The input feature map is 228 x 304, but I get this error: the size of tensor a (9) must match the size of tensor b (10) at non-singleton dimension 3.

    opened by shezhi 1
  • The 2D relative position embedding is not inductive; maybe the FLOATER embedding is better

    opened by AncientRemember 2
Owner
Phil Wang
Working with Attention. It's all we need.