Global Filter Networks for Image Classification

Overview

Global Filter Networks for Image Classification

Created by Yongming Rao, Wenliang Zhao, Zheng Zhu, Jiwen Lu, Jie Zhou

This repository contains PyTorch implementation for GFNet.

Global Filter Networks is a transformer-style architecture that learns long-term spatial dependencies in the frequency domain with log-linear complexity. Our architecture replaces the self-attention layer in vision transformers with three key operations: a 2D discrete Fourier transform, an element-wise multiplication between frequency-domain features and learnable global filters, and a 2D inverse Fourier transform.

intro

Our code is based on pytorch-image-models and DeiT.

[Project Page] [arXiv]

Global Filter Layer

GFNet is a conceptually simple yet computationally efficient architecture, which consists of several stacking Global Filter Layers and Feedforward Networks (FFN). The Global Filter Layer mixes tokens with log-linear complexity benefiting from the highly efficient Fast Fourier Transform (FFT) algorithm. The layer is easy to implement:

import torch
import torch.nn as nn
import torch.fft

class GlobalFilter(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)
        self.w = w
        self.h = h

    def forward(self, x):
        B, H, W, C = x.shape
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        weight = torch.view_as_complex(self.complex_weight)
        x = x * weight
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
        return x

Compared to self-attention and spatial MLP, our Global Filter Layer is much more efficient to process high-resolution feature maps:

efficiency

Model Zoo

We provide our GFNet models pretrained on ImageNet:

name arch Params FLOPs [email protected] [email protected] url
GFNet-Ti gfnet-ti 7M 1.3G 74.6 92.2 Tsinghua Cloud / Google Drive
GFNet-XS gfnet-xs 16M 2.8G 78.6 94.2 Tsinghua Cloud / Google Drive
GFNet-S gfnet-s 25M 4.5G 80.0 94.9 Tsinghua Cloud / Google Drive
GFNet-B gfnet-b 43M 7.9G 80.7 95.1 Tsinghua Cloud / Google Drive
GFNet-H-Ti gfnet-h-ti 15M 2.0G 80.1 95.1 Tsinghua Cloud / Google Drive
GFNet-H-S gfnet-h-s 32M 4.5G 81.5 95.6 Tsinghua Cloud / Google Drive
GFNet-H-B gfnet-h-b 54M 8.4G 82.9 96.2 Tsinghua Cloud / Google Drive

Usage

Requirements

  • torch>=1.8.1
  • torchvision
  • timm

Data preparation: download and extract ImageNet images from http://image-net.org/. The directory structure should be

│ILSVRC2012/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......

Evaluation

To evaluate a pre-trained GFNet model on the ImageNet validation set with a single GPU, run:

python infer.py --data-path /path/to/ILSVRC2012/ --arch arch_name --path /path/to/model

Training

ImageNet

To train GFNet models on ImageNet from scratch, run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_gfnet.py  --output_dir logs/gfnet-xs --arch gfnet-xs --batch-size 128 --data-path /path/to/ILSVRC2012/

To finetune a pre-trained model at higher resolution, run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_gfnet.py  --output_dir logs/gfnet-xs-img384 --arch gfnet-xs --input-size 384 --batch-size 64 --data-path /path/to/ILSVRC2012/ --lr 5e-6 --weight-decay 1e-8 --min-lr 5e-6 --epochs 30 --finetune /path/to/model

Transfer Learning Datasets

To finetune a pre-trained model on a transfer learning dataset, run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_gfnet_transfer.py  --output_dir logs/gfnet-xs-cars --arch gfnet-xs --batch-size 64 --data-set CARS --data-path /path/to/stanford_cars --epochs 1000 --dist-eval --lr 0.0001 --weight-decay 1e-4 --clip-grad 1 --warmup-epochs 5 --finetune /path/to/model 

License

MIT License

Citation

If you find our work useful in your research, please consider citing:

@article{rao2021global,
  title={Global Filter Networks for Image Classification},
  author={Rao, Yongming and Zhao, Wenliang and Zhu, Zheng and Lu, Jiwen and Zhou, Jie},
  journal={arXiv preprint arXiv:2107.00645},
  year={2021}
}
code for "Self-supervised edge features for improved Graph Neural Network training",

Self-supervised edge features for improved Graph Neural Network training Data availability: Here is a link to the raw data for the organoids dataset.

Neal Ravindra 23 Dec 02, 2022
Losslandscapetaxonomy - Taxonomizing local versus global structure in neural network loss landscapes

Taxonomizing local versus global structure in neural network loss landscapes Int

Yaoqing Yang 8 Dec 30, 2022
MoCoGAN: Decomposing Motion and Content for Video Generation

MoCoGAN: Decomposing Motion and Content for Video Generation This repository contains an implementation and further details of MoCoGAN: Decomposing Mo

Sergey Tulyakov 514 Dec 18, 2022
In-Place Activated BatchNorm for Memory-Optimized Training of DNNs

In-Place Activated BatchNorm In-Place Activated BatchNorm for Memory-Optimized Training of DNNs In-Place Activated BatchNorm (InPlace-ABN) is a novel

1.3k Dec 29, 2022
Microsoft Cognitive Toolkit (CNTK), an open source deep-learning toolkit

CNTK Chat Windows build status Linux build status The Microsoft Cognitive Toolkit (https://cntk.ai) is a unified deep learning toolkit that describes

Microsoft 17.3k Dec 29, 2022
StyleGAN2-ADA - Official PyTorch implementation

Abstract: Training generative adversarial networks (GAN) using too little data typically leads to discriminator overfitting, causing training to diverge. We propose an adaptive discriminator augmenta

NVIDIA Research Projects 3.2k Dec 30, 2022
[WACV 2020] Reducing Footskate in Human Motion Reconstruction with Ground Contact Constraints

Reducing Footskate in Human Motion Reconstruction with Ground Contact Constraints Official implementation for Reducing Footskate in Human Motion Recon

Virginia Tech Vision and Learning Lab 38 Nov 01, 2022
This repository contains the code for our fast polygonal building extraction from overhead images pipeline.

Polygonal Building Segmentation by Frame Field Learning We add a frame field output to an image segmentation neural network to improve segmentation qu

Nicolas Girard 186 Jan 04, 2023
Unsupervised 3D Human Mesh Recovery from Noisy Point Clouds

Unsupervised 3D Human Mesh Recovery from Noisy Point Clouds Xinxin Zuo, Sen Wang, Minglun Gong, Li Cheng Prerequisites We have tested the code on Ubun

41 Dec 12, 2022
Deep Anomaly Detection with Outlier Exposure (ICLR 2019)

Outlier Exposure This repository contains the essential code for the paper Deep Anomaly Detection with Outlier Exposure (ICLR 2019). Requires Python 3

Dan Hendrycks 464 Dec 27, 2022
Towards Implicit Text-Guided 3D Shape Generation (CVPR2022)

Towards Implicit Text-Guided 3D Shape Generation Towards Implicit Text-Guided 3D Shape Generation (CVPR2022) Code for the paper [Towards Implicit Text

55 Dec 16, 2022
[2021][ICCV][FSNet] Full-Duplex Strategy for Video Object Segmentation

Full-Duplex Strategy for Video Object Segmentation (ICCV, 2021) Authors: Ge-Peng Ji, Keren Fu, Zhe Wu, Deng-Ping Fan*, Jianbing Shen, & Ling Shao This

Daniel-Ji 55 Dec 22, 2022
Code for "My(o) Armband Leaks Passwords: An EMG and IMU Based Keylogging Side-Channel Attack" paper

Myo Keylogging This is the source code for our paper My(o) Armband Leaks Passwords: An EMG and IMU Based Keylogging Side-Channel Attack by Matthias Ga

Secure Mobile Networking Lab 7 Jan 03, 2023
Code and Datasets from the paper "Self-supervised contrastive learning for volcanic unrest detection from InSAR data"

Code and Datasets from the paper "Self-supervised contrastive learning for volcanic unrest detection from InSAR data" You can download the pretrained

Bountos Nikos 3 May 07, 2022
A Python Package for Portfolio Optimization using the Critical Line Algorithm

PyCLA A Python Package for Portfolio Optimization using the Critical Line Algorithm Getting started To use PyCLA, clone the repo and install the requi

19 Oct 11, 2022
"Structure-Augmented Text Representation Learning for Efficient Knowledge Graph Completion"(WWW 2021)

STAR_KGC This repo contains the source code of the paper accepted by WWW'2021. "Structure-Augmented Text Representation Learning for Efficient Knowled

Bo Wang 60 Dec 26, 2022
Crowd-Kit is a powerful Python library that implements commonly-used aggregation methods for crowdsourced annotation and offers the relevant metrics and datasets

Crowd-Kit: Computational Quality Control for Crowdsourcing Documentation Crowd-Kit is a powerful Python library that implements commonly-used aggregat

Toloka 125 Dec 30, 2022
Repository for self-supervised landmark discovery

self-supervised-landmarks Repository for self-supervised landmark discovery Requirements pytorch pynrrd (for 3d images) Usage The use of this models i

Riddhish Bhalodia 2 Apr 18, 2022
Locally Differentially Private Distributed Deep Learning via Knowledge Distillation (LDP-DL)

Locally Differentially Private Distributed Deep Learning via Knowledge Distillation (LDP-DL) A preprint version of our paper: Link here This is a samp

Di Zhuang 3 Jan 08, 2023
This is an example of object detection on Micro bacterium tuberculosis using Mask-RCNN

Mask-RCNN on Mycobacterium tuberculosis This is an example of object detection on Mycobacterium Tuberculosis using Mask RCNN. Implement of Mask R-CNN

Jun-En Ding 1 Sep 16, 2021