NFNets and Adaptive Gradient Clipping for SGD implemented in PyTorch

Overview

PyTorch implementation of Normalizer-Free Networks and SGD - Adaptive Gradient Clipping

Python Package Docs

Paper: https://arxiv.org/abs/2102.06171

Original code: https://github.com/deepmind/deepmind-research/tree/master/nfnets

Do star this repository if it helps your work!

Note: See this comment for a generic implementation for any optimizer as a temporary reference for anyone who needs it.

Installation

Install from PyPi:

pip3 install nfnets-pytorch

or install the latest code using:

pip3 install git+https://github.com/vballoli/nfnets-pytorch

Usage

WSConv2d

Use WSConv2d and WSConvTranspose2d like any other torch.nn.Conv2d or torch.nn.ConvTranspose2d module.

import torch
from torch import nn
from nfnets import WSConv2d, WSConvTranspose2d

conv = nn.Conv2d(3,6,3)
w_conv = WSConv2d(3,6,3)

conv_t = nn.ConvTranspose2d(3,6,3)
w_conv_t = WSConvTranspose2d(3,6,3)

SGD - Adaptive Gradient Clipping

Similarly, use SGD_AGC like torch.optim.SGD.

import torch
from torch import nn, optim
from nfnets import WSConv2d, SGD_AGC

conv = nn.Conv2d(3,6,3)
w_conv = WSConv2d(3,6,3)

optimizer = optim.SGD(conv.parameters(), 1e-3)
optimizer_agc = SGD_AGC(conv.parameters(), 1e-3)

Generic AGC

import torch
from torch import nn, optim
from nfnets import WSConv2d
from nfnets.agc import AGC # Needs testing

conv = nn.Conv2d(3,6,3)
w_conv = WSConv2d(3,6,3)

optimizer = optim.SGD(conv.parameters(), 1e-3)
optimizer_agc = AGC(conv.parameters(), optimizer) # Needs testing

Using it within any PyTorch model

import torch
from torch import nn
from torchvision.models import resnet18

from nfnets import replace_conv

model = resnet18()
replace_conv(model)
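
For reference, here is a minimal end-to-end sketch (not part of the original README) that combines replace_conv with SGD_AGC; shapes and hyperparameters are only illustrative:

import torch
from torch import nn
from torchvision.models import resnet18

from nfnets import replace_conv, SGD_AGC

model = resnet18()
replace_conv(model)  # swaps every nn.Conv2d in the model for WSConv2d, in place
optimizer = SGD_AGC(model.parameters(), 1e-3)

x = torch.randn(2, 3, 224, 224)
target = torch.randint(0, 1000, (2,))

optimizer.zero_grad()
loss = nn.functional.cross_entropy(model(x), target)
loss.backward()
optimizer.step()  # gradients are adaptively clipped before the SGD update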

Docs

Find the docs at readthedocs

TODO

  • WSConv2d
  • SGD - Adaptive Gradient Clipping
  • Function to automatically replace Convolutions in any module with WSConv2d
  • Documentation
  • Generic AGC wrapper (see this comment for a reference implementation; needs testing for now)
  • WSConvTranspose2d
  • NFNets
  • NF-ResNets

Cite Original Work

To cite the original paper, use:

@article{brock2021high,
  author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
  title={High-Performance Large-Scale Image Recognition Without Normalization},
  journal={arXiv preprint arXiv:2102.06171},
  year={2021}
}
Comments
  • optim_agc.step() raises AttributeError: 'NoneType' object has no attribute 'ndim'

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F
    import numpy as np
    import torchvision
    from torchvision import *
    from torch.utils.data import Dataset, DataLoader
    from nfnets import replace_conv,SGD_AGC
    
    net = models.resnet50(pretrained=True)
    replace_conv(net)
    net = net.cuda() if device else net
    num_ftrs = net.fc.in_features
    net.fc = nn.Linear(num_ftrs, 13)
    net.fc = net.fc.cuda() if use_cuda else net.fc
    ct = 0
    for name, child in net.named_children():
        ct += 1
        if ct < 8:
            for name2, params in child.named_parameters():
                params.requires_grad = False
    
    criterion = nn.CrossEntropyLoss()
    optim_agc = SGD_AGC(net.parameters(), 1e-3,momentum=0.9)
    
    n_epochs = 5
    print_every = 10
    valid_loss_min = np.Inf
    val_loss = []
    val_acc = []
    train_loss = []
    train_acc = []
    total_step = len(train_dataloader)
    
    for epoch in range(1, n_epochs+1):
        running_loss = 0.0
        correct = 0
        total=0
        print(f'Epoch {epoch}\n')
        for batch_idx, (data_, target_) in enumerate(train_dataloader):
            data_, target_ = data_.to(device), target_.to(device)
            optim_agc.zero_grad()
            outputs = net(data_)
            loss = criterion(outputs, target_)
            loss.backward()
            optim_agc.step()
            running_loss += loss.item()
            _,pred = torch.max(outputs, dim=1)
            correct += torch.sum(pred==target_).item()
            total += target_.size(0)
            if (batch_idx) % 20 == 0:
                print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                       .format(epoch, n_epochs, batch_idx, total_step, loss.item()))
        train_acc.append(100 * correct / total)
        train_loss.append(running_loss/total_step)
    		
    

    Traceback (most recent call last):
      File "image_product_classifier.py", line 74, in <module>
        optim_agc.step()
      File "/home/sachin_mohan/venv/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
        return func(*args, **kwargs)
      File "/home/sachin_mohan/venv/lib/python3.6/site-packages/nfnets/sgd_agc.py", line 105, in step
        grad_norm = unitwise_norm(p.grad)
      File "/home/sachin_mohan/venv/lib/python3.6/site-packages/nfnets/utils.py", line 27, in unitwise_norm
        if x.ndim <= 1:
    AttributeError: 'NoneType' object has no attribute 'ndim'

    bug Fixed 
    opened by sachin22225 13
  • Is this a PyTorch compatibility issue?

    I used the code; however, I got the following errors.

    At first, I got the following:

    base.py", line 262, in __init__
        dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)
    TypeError: __init__() got an unexpected keyword argument 'padding_mode'
    

    When I remove padding_mode=padding_mode, I get another error, as shown below:

    utils.py", line 23, in replace_conv
        setattr(module, name, torch.nn.Identity())
    AttributeError: module 'torch.nn' has no attribute 'Identity'
    

    How to solve this?

    question 
    opened by Choneke 11
  • Example in readme does not work

    Describe the bug Running either replace_conv call in this code from the front-page README:

    model = vgg16()
    replace_conv(model, WSConv2d)         # This repo's original implementation
    replace_conv(model, ScaledStdConv2d)  # From timm

    Results in this error:

      File "/opt/conda/lib/python3.8/site-packages/nfnets/utils.py", line 25, in replace_conv
        replace_conv(mod, conv_class)
      File "/opt/conda/lib/python3.8/site-packages/nfnets/utils.py", line 18, in replace_conv
        setattr(module, name, conv_class(target_mod.in_channels, target_mod.out_channels, target_mod.kernel_size,
      File "/opt/conda/lib/python3.8/site-packages/nfnets/base.py", line 262, in __init__
        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
      File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 386, in __init__
        super(Conv2d, self).__init__(
      File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 108, in __init__
        if bias:
    RuntimeError: Boolean value of Tensor with more than one value is ambiguous
    
    enhancement WIP will be delayed 
    opened by RandomString123 4
  • Implement `param_groups` for AGC

    Describe the bug Using LambdaLR, it will call len(optimizer.param_groups), but this is not implemented for AGC.

    To Reproduce

    model = torch.nn.Conv1d(10,20,4)
    optimizer = optim.AdamW(model.parameters())
    optimizer_agc = AGC(model.parameters(),optimizer)
    
    lambda1 = lambda iteration: 0.05*iteration
    scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=lambda1)
    scheduler_warmup_agc = torch.optim.lr_scheduler.LambdaLR(optimizer_agc,lr_lambda=lambda1)
    
    ---------------------------------------------------------------------------
    AttributeError                            Traceback (most recent call last)
    <ipython-input-58-f68dd026a6de> in <module>
    ----> 1 scheduler_warmup2 = torch.optim.lr_scheduler.LambdaLR(optimizer2,lr_lambda=lambda1)
    
    /gpfs/alpine/proj-shared/fus131/conda-envs/torch1.5.0v2/lib/python3.6/site-packages/torch/optim/lr_scheduler.py in __init__(self, optimizer, lr_lambda, last_epoch)
        180
        181         if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
    --> 182             self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
        183         else:
        184             if len(lr_lambda) != len(optimizer.param_groups):
    
    AttributeError: 'AGC' object has no attribute 'param_groups'
    

    Pytorch v1.5.0
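
    Until param_groups is exposed on the wrapper, one possible workaround (a sketch, assuming AGC only clips and delegates the actual update and learning rate to the wrapped optimizer) is to attach the scheduler to the base optimizer:

    import torch
    from torch import optim
    from nfnets.agc import AGC

    model = torch.nn.Conv1d(10, 20, 4)
    base_optimizer = optim.AdamW(model.parameters())
    optimizer_agc = AGC(model.parameters(), base_optimizer)

    lambda1 = lambda iteration: 0.05 * iteration
    # Schedule the base optimizer: its param_groups hold the learning rate used for the update.
    scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(base_optimizer, lr_lambda=lambda1)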

    bug 
    opened by rmchurch 4
  • torch.max doesn't check for tensors being on different devices.


    To Reproduce Steps to reproduce the behavior:

    1. Go to example and instantiate a resnet18
    2. Send the model to torch.device('cuda')
    3. Define a tensor on the gpu
    4. Call model.forward()
    5. RuntimeError: iter.device(arg).is_cuda() INTERNAL ASSERT FAILED at "/pytorch/aten/src/ATen/native/cuda/Loops.cuh":94, please report a bug to PyTorch.

    Expected behavior Regular output

    Screenshots

    See here

    My Solution:

    It's obviously hacky, but it works. Simply replace https://github.com/vballoli/nfnets-pytorch/blob/867860eebffcc70fb87a389d770cfd4a73c6b30c/nfnets/base.py#L22 with:

    scale = torch.rsqrt(torch.max(var * fan_in, torch.tensor(eps).to(var.device))) * self.gain.view_as(var).to(var.device)

    Fixed 
    opened by bfialkoff 4
  • AGC without modifying the optimizer

    Hello,

    Is there a way to apply AGC externally without modifying the optimizer code?

    I am using optimizers from the torch_optimizer package, so that would be great.
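
    For reference, the generic nfnets.agc.AGC wrapper mentioned in the README is meant to cover exactly this. A minimal sketch with a torch_optimizer optimizer (untested; RAdam is just an example choice):

    import torch_optimizer
    from torch import nn
    from nfnets.agc import AGC  # generic wrapper from this repo, still marked as needing testing

    model = nn.Conv2d(3, 6, 3)
    base_opt = torch_optimizer.RAdam(model.parameters(), lr=1e-3)  # any torch_optimizer optimizer
    opt_agc = AGC(model.parameters(), base_opt)  # clipping is applied before delegating to base_opt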

    Fixed 
    opened by kayuksel 4
  • Add WSConv1d

    I just made a minor modification to WSConv2d to implement a weight-standardized Conv1d. It works fine with my own model, a time-series model using Transformer + Conv1d, and it removes the overfitting issue in my model. Thanks
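
    For anyone looking for a starting point, a rough sketch of such a modification (hypothetical code mirroring WSConv2d's weight standardization, not the exact code from this PR):

    import torch
    import torch.nn.functional as F
    from torch import nn

    class WSConv1d(nn.Conv1d):
        """Weight-standardized Conv1d: filters standardized over their (in_channels, kernel) fan-in."""

        def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                     dilation=1, groups=1, bias=True):
            super().__init__(in_channels, out_channels, kernel_size, stride=stride,
                             padding=padding, dilation=dilation, groups=groups, bias=bias)
            # Learnable per-output-channel gain, as in Scaled Weight Standardization.
            self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1))
            self.eps = 1e-4

        def standardized_weight(self):
            # Standardize each filter over its fan-in, then rescale by the learnable gain.
            mean = self.weight.mean(dim=[1, 2], keepdim=True)
            var = self.weight.var(dim=[1, 2], unbiased=False, keepdim=True)
            fan_in = self.weight[0].numel()
            scale = torch.rsqrt(torch.clamp(var * fan_in, min=self.eps)) * self.gain
            return (self.weight - mean) * scale

        def forward(self, x):
            return F.conv1d(x, self.standardized_weight(), self.bias, self.stride,
                            self.padding, self.dilation, self.groups)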

    opened by shi27feng 3
  • Model is None

    Hi @vballoli, there seems to be a bug in the AGC code:

    if model is not None:
        assert ignore_agc not in [None, []], "You must specify ignore_agc for AGC to ignore fc-like(or other) layers"
        names = [name for name, module in model.named_modules()]
    
        for module_name in ignore_agc:
            if module_name not in names:
                raise ModuleNotFoundError("Module name {} not found in the model".format(module_name))
            params = [{"params": list(module.parameters())} for name,
                              module in model.named_modules() if name not in ignore_agc]
    else:
        params = [{"params": list(module.parameters())} for name,
                           module in model.named_modules()]
    

    When model is None, the else branch cannot get name and module from model.named_modules().

    Thanks
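
    A possible fix, as a sketch of the intended logic (build_param_groups is a hypothetical helper name, not code from the repo):

    def build_param_groups(params, model=None, ignore_agc=()):
        if model is not None:
            names = [name for name, _ in model.named_modules()]
            for module_name in ignore_agc:
                if module_name not in names:
                    raise ModuleNotFoundError("Module name {} not found in the model".format(module_name))
            return [{"params": list(module.parameters())}
                    for name, module in model.named_modules() if name not in ignore_agc]
        # No model given: wrap the raw parameter iterable instead of calling model.named_modules().
        return [{"params": list(params)}]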

    bug 
    opened by shi27feng 2
  • AGC condition

    Hello! What's the reason you changed the AGC condition from trigger = grad_norm > max_norm, as in the original paper and code?

    https://github.com/vballoli/nfnets-pytorch/blob/1513cea2c39e09189f4883ad90c2337a8fb9f9ed/nfnets/agc.py#L71

    commit: 4b03658f16d5746d45ea27711f5b33551a472e00
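
    For reference, the rule in the original paper and code is roughly the following (a sketch, not this repository's exact implementation):

    import torch
    from nfnets.utils import unitwise_norm  # helper referenced elsewhere in this repo

    def agc_clip(param, grad, clipping=0.01, eps=1e-3):
        # Unit-wise norms of the parameter and its gradient.
        param_norm = torch.clamp(unitwise_norm(param), min=eps)
        grad_norm = unitwise_norm(grad)
        max_norm = param_norm * clipping
        # Trigger from the paper/official code: clip only the units where grad_norm > max_norm.
        trigger = grad_norm > max_norm
        clipped_grad = grad * (max_norm / torch.clamp(grad_norm, min=1e-6))
        return torch.where(trigger, clipped_grad, grad)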

    opened by haigh1510 2
  • Conv layer from timm works better

    A resnet18 with this layer swapped in learns better, but it still needs about twice as many epochs as the original to reach the same result on CIFAR10.

    import torch
    import torch.nn.functional as F
    from torch import nn
    from timm.models.layers import get_padding  # timm helper used for the default padding below

    class ScaledStdConv2d(nn.Conv2d):
        """Conv2d layer with Scaled Weight Standardization.
        Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
            https://arxiv.org/abs/2101.08692
        """
    
        def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
                     bias=True, gain=True, gamma=1.0, eps=1e-5, use_layernorm=False):
            if padding is None:
                padding = get_padding(kernel_size, stride, dilation)
            super().__init__(
                in_channels, out_channels, kernel_size, stride=stride,
                padding=padding, dilation=dilation, groups=groups, bias=bias)
            self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) if gain else None
            self.scale = gamma * self.weight[0].numel() ** -0.5  # gamma * 1 / sqrt(fan-in)
            self.eps = eps ** 2 if use_layernorm else eps
            self.use_layernorm = use_layernorm  # experimental, slightly faster/less GPU memory use
    
        def get_weight(self):
            if self.use_layernorm:
                weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
            else:
                std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
                weight = self.scale * (self.weight - mean) / (std + self.eps)
            if self.gain is not None:
                weight = weight * self.gain
            return weight
    
        def forward(self, x):
            return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
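
    A quick drop-in usage sketch, assuming the class above is in scope (padding is passed explicitly, so timm's get_padding helper is not strictly needed):

    import torch

    conv = ScaledStdConv2d(3, 6, 3, padding=1)  # drop-in replacement for nn.Conv2d(3, 6, 3, padding=1)
    y = conv(torch.randn(1, 3, 32, 32))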
    
    enhancement help wanted Discussion 
    opened by attashe 2
  • Adaptive gradient clipping shouldn't be used in the final classifier layer

    Describe the bug In the paper it is mentioned that AGC is NOT used in the final layer.

    Expected behavior It would be great if AGC could be disabled for the final layer. Adding a check here may be sufficient to get the desired behavior. Please let me know if I am misunderstanding something. Thanks.
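
    For reference, the AGC wrapper quoted in the "Model is None" issue above already accepts model and ignore_agc arguments; a possible usage sketch (untested, argument names taken from that snippet) that skips the classifier of a torchvision ResNet:

    from torch import optim
    from torchvision.models import resnet18
    from nfnets import replace_conv
    from nfnets.agc import AGC

    model = resnet18()
    replace_conv(model)
    base_opt = optim.SGD(model.parameters(), lr=1e-3)
    # 'fc' is the final classifier module name in torchvision ResNets.
    optimizer = AGC(model.parameters(), base_opt, model=model, ignore_agc=['fc'])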

    opened by sidml 2
Releases (0.1.3)
  • 0.0.6(Feb 17, 2021)

  • 0.0.4(Feb 15, 2021)

  • 0.0.3(Feb 15, 2021)

  • 0.0.2(Feb 15, 2021)

  • 0.0.1(Feb 14, 2021)

    PyTorch implementation of Normalizer-Free Networks and SGD - Adaptive Gradient Clipping

    Paper: https://arxiv.org/abs/2102.06171
    Original code: https://github.com/deepmind/deepmind-research/tree/master/nfnets

    Installation

    pip3 install git+https://github.com/vballoli/nfnets-pytorch

    Usage

    WSConv2d

    Use WSConv2d like any other torch.nn.Conv2d.

    import torch
    from torch import nn
    from nfnets import WSConv2d
    
    conv = nn.Conv2d(3,6,3)
    w_conv = WSConv2d(3,6,3)
    

    SGD - Adaptive Gradient Clipping

    Similarly, use SGD_AGC like torch.optim.SGD

    import torch
    from torch import nn, optim
    from nfnets import WSConv2d, SGD_AGC
    
    conv = nn.Conv2d(3,6,3)
    w_conv = WSConv2d(3,6,3)
    
    optim = optim.SGD(conv.parameters(), 1e-3)
    optim_agc = SGD_AGC(conv.parameters(), 1e-3)
    

    Using it within any PyTorch model

    import torch
    from torch import nn
    from torchvision.models import resnet18
    
    from nfnets import replace_conv
    
    model = resnet18()
    replace_conv(model)
    

    Docs

    Find the docs at readthedocs

    TODO

    • [x] WSConv2d
    • [x] SGD - Adaptive Gradient Clipping
    • [x] Function to automatically replace Convolutions in any module with WSConv2d
    • [x] Documentation
    • [ ] NFNets
    • [ ] NF-ResNets

    Cite Original Work

    To cite the original paper, use:

    @article{brock2021high,
      author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
      title={High-Performance Large-Scale Image Recognition Without Normalization},
      journal={arXiv preprint arXiv:2102.06171},
      year={2021}
    }
    
Owner
Vaibhav Balloli
ML, RL, Hardware Systems.