Pre-trained NFNets with 99% of the accuracy of the official paper

Overview

NFNet PyTorch Implementation

This repo contains pretrained NFNet models F0-F6 with high ImageNet accuracy from the paper High-Performance Large-Scale Image Recognition Without Normalization. The small models are as accurate as an EfficientNet-B7 but train up to 8.7 times faster. The large models set a new state-of-the-art (SOTA) top-1 accuracy on ImageNet.

NFNet                               F0     F1     F2     F3     F4     F5     F6+SAM
Top-1 accuracy (%), Brock et al.    83.6   84.7   85.1   85.7   85.9   86.0   86.5
Top-1 accuracy (%), this repo       82.82  84.63  84.90  85.46  85.66  85.62  TBD

All credit goes to the authors of the original paper. This repo is heavily inspired by their JAX implementation in the official repository; visit their repo for citation information.

Get started

git clone https://github.com/benjs/nfnets_pytorch.git
cd nfnets_pytorch
pip3 install -r requirements.txt

Download pretrained weights from the official repository and place them in the pretrained folder.

from pretrained import pretrained_nfnet
model_F0 = pretrained_nfnet('pretrained/F0_haiku.npz')
model_F1 = pretrained_nfnet('pretrained/F1_haiku.npz')
# ...

The model variant is automatically derived from the parameter count in the pretrained weights file.
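The detection code itself is not reproduced here; the following is a minimal sketch of the idea, assuming the weights can be loaded into a mapping of parameter name to NumPy array (for a plain .npz archive, dict(np.load(path)) would do, though the official files may be stored differently). The names count_params, guess_variant and PAPER_PARAMS_M are hypothetical, and the reference counts are the approximate figures reported in the paper, to be double-checked before relying on them.

```python
from typing import Mapping

import numpy as np

# Approximate parameter counts in millions, as reported in the NFNet paper.
# These are assumptions of this sketch; verify them against the paper.
PAPER_PARAMS_M = {
    "F0": 71.5, "F1": 132.6, "F2": 193.8, "F3": 254.9,
    "F4": 316.1, "F5": 377.2, "F6": 438.4,
}


def count_params(arrays: Mapping[str, np.ndarray]) -> int:
    """Total number of scalar weights across all named arrays."""
    return sum(int(np.asarray(a).size) for a in arrays.values())


def guess_variant(num_params: int,
                  reference: Mapping[str, float] = PAPER_PARAMS_M) -> str:
    """Hypothetical helper: pick the variant whose reference count is closest."""
    millions = num_params / 1e6
    return min(reference, key=lambda name: abs(reference[name] - millions))


# Example with small made-up arrays
weights = {"stem/w": np.zeros((3, 3, 3, 16)), "stem/b": np.zeros(16)}
print(count_params(weights))      # 448
print(guess_variant(71_500_000))  # F0
```

Matching on the nearest count rather than an exact value keeps the lookup tolerant of small differences, for example extra bookkeeping arrays stored alongside the weights.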

Validate yourself

python3 eval.py --pretrained pretrained/F0_haiku.npz --dataset path/to/imagenet/valset/

You can download the ImageNet validation set from the ILSVRC2012 challenge site after requesting access, for instance with your .edu email address.
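eval.py reports top-1 accuracy, the metric in the table above. As a rough sketch of what that number means (not the script's actual code), assuming a loader that yields (images, labels) batches such as a torch.utils.data.DataLoader, the computation reduces to counting argmax hits. top1_correct and evaluate_top1 are hypothetical names.

```python
import torch
import torch.nn as nn


def top1_correct(logits: torch.Tensor, labels: torch.Tensor) -> int:
    """Count samples whose highest-scoring class equals the ground-truth label."""
    return int((logits.argmax(dim=1) == labels).sum().item())


@torch.no_grad()
def evaluate_top1(model: nn.Module, loader, device: str = "cpu") -> float:
    """Hypothetical evaluation loop returning top-1 accuracy in percent."""
    model.eval()
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        correct += top1_correct(model(images), labels)
        total += labels.numel()
    return 100.0 * correct / total


# Toy check: an identity "model" applied to hand-written logits
logits = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
labels = torch.tensor([1, 0, 0])
print(evaluate_top1(nn.Identity(), [(logits, labels)]))  # about 66.67
```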

Scaled weight standardization convolutions in your own model

Simply replace all your nn.Conv2d layers with WSConv2D and all your nn.ReLU activations with VPReLU or VPGELU (variance-preserving ReLU/GELU).

import torch.nn as nn
from model import WSConv2D, VPReLU, VPGELU

# Simply replace your nn.Conv2d layers
class MyNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.activation = VPReLU(inplace=True)  # or VPGELU
        # Add the remaining nn.Conv2d arguments (stride, padding, ...) as needed
        self.conv0 = WSConv2D(in_channels=128, out_channels=256, kernel_size=1)
        # ...

    def forward(self, x):
        out = self.activation(self.conv0(x))
        # ...
        return out
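To show what such layers do, here is a minimal sketch of scaled weight standardization, W_hat = (W - mean) / (std * sqrt(N)) per output channel with N the fan-in, plus a learnable gain, together with a ReLU rescaled by gamma = sqrt(2 / (1 - 1/pi)) so that a standard-normal input yields unit-variance output. The class names ScaledWSConv2d and ScaledReLU, the eps default and the analytic gamma are assumptions of this sketch, not the repo's WSConv2D/VPReLU code.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledWSConv2d(nn.Conv2d):
    """Hypothetical Conv2d with scaled weight standardization (not the repo's WSConv2D)."""

    def __init__(self, *args, eps: float = 1e-4, **kwargs):
        super().__init__(*args, **kwargs)
        self.eps = eps
        # Learnable per-output-channel gain, initialised to 1
        self.gain = nn.Parameter(torch.ones(self.out_channels))

    def standardized_weight(self) -> torch.Tensor:
        w = self.weight.reshape(self.out_channels, -1)  # (out, fan_in)
        fan_in = w.shape[1]
        mean = w.mean(dim=1, keepdim=True)
        var = ((w - mean) ** 2).mean(dim=1, keepdim=True)
        # (W - mean) / sqrt(fan_in * Var), with eps guarding tiny variances
        scale = torch.rsqrt(torch.clamp(var * fan_in, min=self.eps))
        w_hat = (w - mean) * scale * self.gain.unsqueeze(1)
        return w_hat.reshape_as(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumes the default padding_mode='zeros'
        return F.conv2d(x, self.standardized_weight(), self.bias,
                        self.stride, self.padding, self.dilation, self.groups)


class ScaledReLU(nn.Module):
    """ReLU rescaled so a standard-normal input gives unit-variance output."""

    # For z ~ N(0, 1): Var[relu(z)] = (1 - 1/pi) / 2, so gamma = 1 / sqrt(Var)
    GAMMA = math.sqrt(2.0 / (1.0 - 1.0 / math.pi))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.relu(x) * self.GAMMA


# Usage: drop-in replacements inside a model
conv = ScaledWSConv2d(128, 256, kernel_size=1)
act = ScaledReLU()
out = act(conv(torch.randn(2, 128, 8, 8)))
print(out.shape)  # torch.Size([2, 256, 8, 8])
```

Standardizing the weight on every forward pass, rather than once at initialization, keeps each output channel's weights at zero mean and unit total energy throughout training; the learnable gain lets the network rescale them where useful.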

SGD with adaptive gradient clipping in your own model

Simply replace your SGD optimizer with SGD_AGC.

from optim import SGD_AGC

optimizer = SGD_AGC(
        named_params=model.named_parameters(), # Pass named parameters
        lr=1e-3,
        momentum=0.9,
        clipping=0.1, # New clipping parameter
        weight_decay=2e-5, 
        nesterov=True)

It is important to exclude certain layers from clipping or weight decay. The authors recommend excluding the final fully connected (linear) layer from clipping and the bias/gain parameters from weight decay:

import re

# Parameter names whose weight decay should be disabled (biases and gains)
NO_DECAY = re.compile(r'stem.*(bias|gain)|conv.*(bias|gain)|skip_gain')

for group in optimizer.param_groups:
    name = group['name']

    # Exclude from weight decay
    if NO_DECAY.search(name):
        group['weight_decay'] = 0

    # Exclude the final linear layer from clipping
    if name.startswith('linear'):
        group['clipping'] = None
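For intuition, the paper's unit-wise clipping rule rescales a unit's gradient G_i to lambda * max(||W_i||, eps) / ||G_i|| * G_i whenever ||G_i|| / max(||W_i||, eps) > lambda, where lambda presumably corresponds to the clipping argument above. The sketch below applies that rule in plain PyTorch; it is not SGD_AGC's code, and the helper names, the per-output-row definition of a "unit" and eps = 1e-3 are assumptions.

```python
import torch
import torch.nn as nn


def unitwise_norm(x: torch.Tensor) -> torch.Tensor:
    """L2 norm per output unit (dim 0 in PyTorch layouts); whole norm for vectors."""
    if x.ndim <= 1:
        return x.pow(2).sum().sqrt()
    dims = tuple(range(1, x.ndim))
    return x.pow(2).sum(dim=dims, keepdim=True).sqrt()


@torch.no_grad()
def adaptive_grad_clip_(parameters, clipping: float = 0.01, eps: float = 1e-3) -> None:
    """Hypothetical in-place unit-wise AGC, applied between backward() and step()."""
    for p in parameters:
        if p.grad is None:
            continue
        # Threshold: clipping * max(||W_i||, eps) for each unit i
        max_norm = unitwise_norm(p).clamp(min=eps) * clipping
        grad_norm = unitwise_norm(p.grad)
        # Rescale only units whose ratio ||G_i|| / max(||W_i||, eps) exceeds clipping
        scale = max_norm / grad_norm.clamp(min=1e-6)
        p.grad.copy_(torch.where(grad_norm > max_norm, p.grad * scale, p.grad))


# Toy check: a huge gradient on a 4x3 linear weight gets clipped per row
layer = nn.Linear(3, 4, bias=False)
nn.init.ones_(layer.weight)
layer.weight.grad = torch.full((4, 3), 100.0)
adaptive_grad_clip_(layer.parameters(), clipping=0.01)
print(unitwise_norm(layer.weight.grad).flatten())  # each about 0.01 * sqrt(3)
```

To mirror the recommendation above, the final linear layer's parameters would simply be left out of the iterable passed to adaptive_grad_clip_.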

Train your own NFNet

Adjust your desired parameters in default_config.yaml and start training.

python3 train.py --dataset /path/to/imagenet/

Some parts are still missing for complete training from scratch:

  • Multi-GPU training
  • Data augmentations
  • FP16 activations and gradients

Contribute

The implementation is still at an early stage in terms of usability and testing. If you have an idea to improve this repo, open an issue, start a discussion or submit a pull request.

Development status

  • Pre-trained NFNet Models
    • F0-F5
    • F6+SAM
    • Scaled weight standardization
    • Squeeze and excite
    • Stochastic depth
    • FP16 activations
  • SGD with unit adaptive gradient clipping (SGD-AGC)
    • Exclude certain layers from weight-decay, clipping
    • FP16 gradients
  • PyPI package
  • PyTorch hub submission
  • Label smoothing loss from Szegedy et al.
  • Training on ImageNet
  • Pre-trained weights
  • Tensorboard support
  • General usability improvements
  • Multi-GPU support
  • Data augmentation
  • Signal propagation plots (from the first paper)

Comments
  • ModuleNotFoundError: No module named 'haiku'

    When I run "python3 eval.py --pretrained pretrained/F0_haiku.npz --dataset ***", I get this error. Have you encountered it before? How can I fix it?

    opened by Rianusr 2
  • Trained without data augmentation?

    Thanks for the great work on the PyTorch implementation of NFNet! The accuracies achieved by this implementation are also pretty impressive, and I am wondering whether these training results were simply derived from the training script, that is, without data augmentation.

    opened by nandi-zhang 2
  • from_pretrained_haiku

    https://github.com/benjs/nfnets_pytorch/blob/7b4d1cc701c7de4ee273ded01ce21cbdb1e60c48/nfnets/pretrained.py#L90

    model = from_pretrained_haiku(args.pretrained)

    Where is the 'from_pretrained_haiku' method?

    opened by vkmavani 0
  • About WSconv2d

    I looked at the author's code and found that his WSconv2d uses pad_mode 'same'. PyTorch's conv2d doesn't have a pad_mode, and I think your padding should be greater than 0, but your padding is always 0. Why is that?

    I also see that the learning rate in your train.py is constant. Why? Thank you!

    opened by fancyshun 3
  • AveragePool

    Hi, I noticed that the AveragePool ('pool' layer) is not used in the forward function; instead, forward uses torch.mean. Removing the layer doesn't change the pooling behavior. I tried using this model as a feature extractor and was a bit confused for a moment.

    opened by bogdankjastrzebski 1
Releases: v0.0.1

Owner: Benjamin Schmidt, Engineering Student