Implementation of LambdaNetworks, a new approach to image recognition that reaches SOTA with less compute

Overview

Lambda Networks - Pytorch

Implementation of λ Networks, a new approach to image recognition that reaches SOTA on ImageNet. The new method uses the λ layer, which captures interactions by transforming contexts into linear functions, termed lambdas, and applying these linear functions to each input separately.
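
As a rough illustration of the sentence above, the content part of a lambda layer can be sketched in a few lines of PyTorch. This is a minimal single-head sketch with one intra-depth dimension, not the library's implementation: the function name content_lambda and the tensor layout are assumptions made for this example, and the positional lambdas are left out.

```python
import torch

def content_lambda(queries, keys, values):
    """Hypothetical helper: apply one content lambda to every query.

    queries: (batch, n, k)  one query per input position
    keys:    (batch, m, k)  one key per context position
    values:  (batch, m, v)  one value per context position
    returns: (batch, n, v)
    """
    # Normalize the keys across the m context positions.
    keys = keys.softmax(dim=1)
    # Summarize the context into one linear function (a k x v matrix) per example.
    lam = torch.einsum('bmk,bmv->bkv', keys, values)
    # Apply that linear function to each query separately.
    return torch.einsum('bnk,bkv->bnv', queries, lam)

q = torch.randn(2, 64, 16)
k = torch.randn(2, 64, 16)
v = torch.randn(2, 64, 8)
out = content_lambda(q, k, v)
print(out.shape)  # torch.Size([2, 64, 8])
```

In this sketch no n x m attention map is ever formed: the context is first compressed into a k x v matrix, which is then applied to each query.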

Yannic Kilcher's paper review

Install

$ pip install lambda-networks

Usage

Global context

import torch
from lambda_networks import LambdaLayer

layer = LambdaLayer(
    dim = 32,       # channels going in
    dim_out = 32,   # channels out
    n = 64,         # size of the receptive window - max(height, width)
    dim_k = 16,     # key dimension
    heads = 4,      # number of heads, for multi-query
    dim_u = 1       # 'intra-depth' dimension
)

x = torch.randn(1, 32, 64, 64)
layer(x) # (1, 32, 64, 64)

Localized context

import torch
from lambda_networks import LambdaLayer

layer = LambdaLayer(
    dim = 32,
    dim_out = 32,
    r = 23,         # the receptive field for relative positional encoding (23 x 23)
    dim_k = 16,
    heads = 4,
    dim_u = 4
)

x = torch.randn(1, 32, 64, 64)
layer(x) # (1, 32, 64, 64)

For fun, you can also import this as follows

from lambda_networks import λLayer

Tensorflow / Keras version

Shinel94 has added a Keras implementation! It won't be officially supported in this repository, so either copy / paste the code under ./lambda_networks/tfkeras.py or make sure to install tensorflow and keras before running the following.

import tensorflow as tf
from lambda_networks.tfkeras import LambdaLayer

layer = LambdaLayer(
    dim_out = 32,
    r = 23,
    dim_k = 16,
    heads = 4,
    dim_u = 1
)

x = tf.random.normal((1, 64, 64, 16)) # channel last format
layer(x) # (1, 64, 64, 32)

Citations

@inproceedings{
    anonymous2021lambdanetworks,
    title={LambdaNetworks: Modeling long-range Interactions without Attention},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=xTJEN-ggl1b},
    note={under review}
}
Comments
  • Contiguity problem:

    "RuntimeError: cuDNN error: CUDNN_STATUS_NOT_SUPPORTED. This error may appear if you passed in a non-contiguous input."

    It seems that LambdaLayer breaks contiguity when I try it.

    layer(x).is_contiguous()
    >> False

    I have to call .contiguous() on the output wherever I train with it. Is that normal?

    opened by Whiax 5
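
For reference, here is a small, library-independent sketch of what that error message refers to. It does not use LambdaLayer itself. It only shows that permuting a tensor's dimensions yields a non-contiguous view, and that .contiguous() copies it into a dense layout. Whether a permute is the exact cause inside the layer is an assumption.

```python
import torch

x = torch.randn(1, 64, 64, 32)

# permute only changes the strides; the underlying memory is untouched
y = x.permute(0, 3, 1, 2)
print(y.is_contiguous())  # False

# .contiguous() copies the data into a dense, row-major layout
z = y.contiguous()
print(z.is_contiguous())  # True
print(torch.equal(y, z))  # True: same values, different memory layout
```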
  • Warning: Mixed memory format inputs detected while calling the operator.

    I have added lambda layers to every block of ResNet, but the following warning appears. Will it affect the results?

    Warning: Mixed memory format inputs detected while calling the operator. The operator will output channels_last tensor even if some of the inputs are not in channels_last format. (function operator())

    opened by Pluto1314 4
  • Implementation of Lambda convolution

    Thanks for the great work in the implementations!

    I would like to ask whether there is a difference in using 'Conv2d' as suggested in Eq. 3 in the paper and your implementation of 'Conv3d'. These two convs treat the (h x w)-dimension as a 1-d sequence and a 2-d image, respectively. I believe they are quite different in concept.

    Please point out if I have misunderstood.

    Thx a lot.

    opened by romulus0914 3
  • How does the lambda layer handle downsampling in LambdaResNet?

    Hi, thanks for your clear code. I am trying to implement LambdaResNet. Does the lambda layer replace all conv2d layers? If so, how does the lambda layer handle the downsampling done by conv2d, e.g. stride=2? Or should I keep conv2d where stride=2 and replace only the conv2d layers with stride=1?

    opened by qiaoran-dawnlight 2
  • Question about hybrid LambdaResNet

    Hi,

    In the paper, there is this paragraph:

    When working with hybrid LambdaNetworks, we use a single lambda layer in c4 for LambdaResNet50, 3 lambda layers for LambdaResNet101, 6 lambda layers for LambdaResNet-152/200/270/350 and 8 lambda layers for LambdaResNet-420.

    I have several questions about constructing the hybrid LambdaResNet:

    1. Do we only need to replace the 3x3 conv with a lambda layer in the C4 stage, rather than in both C4 and C5 (as in the ablation study)?
    2. When there is more than one lambda layer, as in LambdaResNet101, do we replace the 3x3 conv with 3 lambda layers? And in the ResNet50 case, do we replace the 3x3 conv with 1 lambda layer?

    opened by CoinCheung 2
  • Question: Is there an easy way to visualise lambdas?

    I want to train a classifier and tell which regions it pays the most, well, attention to :) I would like to do this at inference time, without using Grad-CAM etc. Can I do this?

    opened by lebionick 1
  • Fix relative positional attention (position lambda)

    Hi lucidrains, thanks for your nice work.

    I found an error in the position lambda (relative positional attention) implementation. The position lambda, λp, should be translation equivariant, as written in Sec. 3.2 of the paper. This means the positional embedding must satisfy the constraint E[n, m] = E[t(n), t(m)], but this is missing from the current implementation. This PR fixes it by adding the translation equivariance constraint. I checked that this PR improves the result in my experiment.

    NOTE that this PR changes the meaning of the function parameter n, from the total area (n=w*h) to the length of each side (n=w=h).

    opened by khanrc 1
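
The constraint E[n, m] = E[t(n), t(m)] means the embedding may depend only on the offset between positions n and m, not on their absolute locations. One common way to enforce this is to index a small table of embeddings by relative offset. The sketch below is an assumption about how that can be done, not the code from this PR; the helper name relative_offsets and the table layout are hypothetical.

```python
import torch

def relative_offsets(size):
    """Hypothetical helper: pairwise (dy, dx) offsets on a size x size grid.

    Returns a (size*size, size*size, 2) integer tensor whose entries are
    shifted into the range [0, 2*size - 2] so they can index a table.
    """
    coords = torch.stack(torch.meshgrid(
        torch.arange(size), torch.arange(size), indexing='ij'), dim=-1)
    coords = coords.reshape(-1, 2)                 # (n, 2) with n = size*size
    rel = coords[None, :, :] - coords[:, None, :]  # (n, n, 2): position m minus position n
    return rel + (size - 1)

size, dim_k = 4, 8
rel = relative_offsets(size)
# One embedding per possible offset instead of one per (n, m) pair.
table = torch.randn(2 * size - 1, 2 * size - 1, dim_k)
E = table[rel[..., 0], rel[..., 1]]                # (n, n, dim_k)
print(E.shape)  # torch.Size([16, 16, 8])
```

With this layout, translating both positions by the same shift leaves the looked-up embedding unchanged, and the table holds (2*size - 1)^2 embeddings rather than one per pair of positions.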
  • Use of Keras Lambda

    Hey! Thanks for the awesome implementations :D

    I was wondering why tf.keras.layers.Lambda is used? It seems unnecessary; regular calls to TF operations work and are more readable.

    https://github.com/lucidrains/lambda-networks/blob/06a48f2a5b41f3cd278aee67838c32051a0a9bed/lambda_networks/tfkeras.py#L73

    You can also call the functional version of the softmax instead.

    opened by cgarciae 1
  • Lambda for a sequence of images

    Thanks for the quick implementation!

    I have a problem where I have a sequence of images rather than one (a video). So instead of the dimensions batch, channels, height, width, I also have a length dimension after batch that gives the sequence length.

    Given a known max_length (for the positional embedding), should conv4d be used in forward instead of conv3d, to allow interaction between frames?

    In the paper they do mention this could serve as a general framework for sequences of images, so I wonder if you explored that in implementation (where obviously a single image is just a case where length=1)

    opened by AmitMY 1
  • Why are the FLOPs so high?

    I used ResNet50 and changed the C4 layer into LambdaBottleNeck, but the FLOPs are very high, about 20G with an input size of 224*244. Is that right, or is something wrong with my implementation?

    opened by zisuina 0
  • How to load_model correctly

    Hi everyone, I am struggling to load this model from a saved .h5 file.

    What is the correct way to load this network? If I use custom_objects, I get: __init__() got an unexpected keyword argument 'name'

    opened by JNaranjo-Alcazar 0
  • Please add clarity to code

    Phil, I love your work, but I wish you could go a few extra steps to help out users. I found this class by François-Guillaume (@frgfm), which adds clear math comments. I want to merge it, but there is a bit of code drift and I don't want to introduce any bugs. I beseech you to go the extra step to help users bridge from papers to code.

    https://github.com/frgfm/Holocron/blob/bcc3ea19a477e4b28dc5973cdbe92a9b05c690bb/holocron/nn/modules/lambda_layer.py

    E.g. please annotate return types: def forward(self, x: torch.Tensor) -> torch.Tensor:

    Please document the arguments and steps, e.g. # Project input and context to get queries, keys & values

    Throw in some maths as comments. This is great, as it bridges the paper to the code, e.g.

    # B x (num_heads * dim_k) * H * W -> B x num_heads x dim_k x (H * W)

    import torch
    from torch import nn, einsum
    import torch.nn.functional as F
    from typing import Optional
    
    __all__ = ['LambdaLayer']
    
    
    class LambdaLayer(nn.Module):
        """Lambda layer from `"LambdaNetworks: Modeling long-range interactions without attention"
        <https://openreview.net/pdf?id=xTJEN-ggl1b>`_. The implementation was adapted from `lucidrains'
        <https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py>`.
        Args:
            in_channels (int): input channels
            out_channels (int, optional): output channels
            dim_k (int): key dimension
            n (int, optional): number of input pixels
            r (int, optional): receptive field for relative positional encoding
            num_heads (int, optional): number of attention heads
            dim_u (int, optional): intra-depth dimension
        """
        def __init__(
            self,
            in_channels: int,
            out_channels: int,
            dim_k: int,
            n: Optional[int] = None,
            r: Optional[int] = None,
            num_heads: int = 4,
            dim_u: int = 1
        ) -> None:
            super().__init__()
            self.u = dim_u
            self.num_heads = num_heads
    
            if out_channels % num_heads != 0:
                raise AssertionError('values dimension must be divisible by number of heads for multi-head query')
            dim_v = out_channels // num_heads
    
            # Project input and context to get queries, keys & values
            self.to_q = nn.Conv2d(in_channels, dim_k * num_heads, 1, bias=False)
            self.to_k = nn.Conv2d(in_channels, dim_k * dim_u, 1, bias=False)
            self.to_v = nn.Conv2d(in_channels, dim_v * dim_u, 1, bias=False)
    
            self.norm_q = nn.BatchNorm2d(dim_k * num_heads)
            self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
    
            self.local_contexts = r is not None
            if r is not None:
                if r % 2 != 1:
                    raise AssertionError('Receptive kernel size should be odd')
                self.padding = r // 2
                self.R = nn.Parameter(torch.randn(dim_k, dim_u, 1, r, r))
            else:
                if n is None:
                    raise AssertionError('You must specify the total sequence length (h x w)')
                self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
    
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            b, c, h, w = x.shape
    
            # Project inputs & context to retrieve queries, keys and values
            q = self.to_q(x)
            k = self.to_k(x)
            v = self.to_v(x)
    
            # Normalize queries & values
            q = self.norm_q(q)
            v = self.norm_v(v)
    
            # B x (num_heads * dim_k) * H * W -> B x num_heads x dim_k x (H * W)
            q = q.reshape(b, self.num_heads, -1, h * w)
            # B x (dim_k * dim_u) * H * W -> B x dim_u x dim_k x (H * W)
            k = k.reshape(b, -1, self.u, h * w).permute(0, 2, 1, 3)
            # B x (dim_v * dim_u) * H * W -> B x dim_u x dim_v x (H * W)
            v = v.reshape(b, -1, self.u, h * w).permute(0, 2, 1, 3)
    
            # Normalized keys
            k = k.softmax(dim=-1)
    
            # Content function
            λc = einsum('b u k m, b u v m -> b k v', k, v)
            Yc = einsum('b h k n, b k v -> b n h v', q, λc)
    
            # Position function
            if self.local_contexts:
                # B x dim_u x dim_v x (H * W) -> B x dim_u x dim_v x H x W
                v = v.reshape(b, self.u, v.shape[2], h, w)
                λp = F.conv3d(v, self.R, padding=(0, self.padding, self.padding))
                Yp = einsum('b h k n, b k v n -> b n h v', q, λp.flatten(3))
            else:
                λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
                Yp = einsum('b h k n, b n k v -> b n h v', q, λp)
    
            Y = Yc + Yp
            # B x (H * W) x num_heads x dim_v -> B x (num_heads * dim_v) x H x W
            out = Y.permute(0, 2, 3, 1).reshape(b, self.num_heads * v.shape[2], h, w)
            return out
    
    opened by johndpope 1
  • Image Size

    Are non-square image blocks allowed for context? Using global context and non-square dimensions (96, 128), I get an error about dimension size on this line:

    λp = einsum('n m k u, b u v m -> b n k v', rel_pos_emb, v)

    opened by anklebreaker 0
  • LambdaResNet Implementation?

    I have been looking around and found one implementation of LambdaResNets, although there seem to be some metric performance problems and I've found wall-clock performance problems (runs ~7x slower than normal resnets).

    Do you plan on putting out a lambdaresnet model in this repository?

    opened by nollied 4
Owner
Phil Wang
Working with Attention. It's all we need.