Global Filter Networks for Image Classification

Overview

Created by Yongming Rao, Wenliang Zhao, Zheng Zhu, Jiwen Lu, Jie Zhou

This repository contains the PyTorch implementation of GFNet.

Global Filter Networks is a transformer-style architecture that learns long-term spatial dependencies in the frequency domain with log-linear complexity. Our architecture replaces the self-attention layer in vision transformers with three key operations: a 2D discrete Fourier transform, an element-wise multiplication between frequency-domain features and learnable global filters, and a 2D inverse Fourier transform.
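The three operations can be written compactly as below. The notation is ours and is only a sketch: x stands for the H × W × D spatial feature map, K for the learnable global filter, F for the 2D discrete Fourier transform over the two spatial axes, and ⊙ for element-wise multiplication.

```latex
X = \mathcal{F}[x], \qquad
\tilde{X} = K \odot X, \qquad
x \leftarrow \mathcal{F}^{-1}[\tilde{X}]
```

With a fast Fourier transform, each channel costs O(HW log(HW)) for the forward and inverse transforms and O(HW) for the element-wise product, which is where the log-linear complexity comes from.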

[Figure: intro]

Our code is based on pytorch-image-models and DeiT.

[Project Page] [arXiv]

Global Filter Layer

GFNet is a conceptually simple yet computationally efficient architecture that stacks Global Filter Layers and Feedforward Networks (FFN). Thanks to the highly efficient Fast Fourier Transform (FFT) algorithm, the Global Filter Layer mixes tokens with log-linear complexity. The layer is easy to implement:

import torch
import torch.nn as nn
import torch.fft

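class GlobalFilter(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        # Learnable frequency-domain filter; the last axis holds (real, imag) parts.
        # For an H x W token grid, rfft2 keeps W // 2 + 1 frequencies, so w = W // 2 + 1.
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)
        self.w = w
        self.h = h

    def forward(self, x):
        # x: (batch, height, width, channels)
        B, H, W, C = x.shape
        # 2D real FFT over the spatial axes -> complex tensor of shape (B, H, W // 2 + 1, C)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        # Element-wise product with the global filter (broadcast over the batch axis)
        weight = torch.view_as_complex(self.complex_weight)
        x = x * weight
        # Inverse FFT back to the spatial domain with the original size (H, W)
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
        return x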
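By the convolution theorem, multiplying features by a filter in the frequency domain is equivalent to a circular convolution whose kernel is as large as the input. The sketch below checks this on a 1D toy signal using only the Python standard library. All names (`dft`, `circular_conv`, `global_filter_1d`) are illustrative and not part of this repository, and the naive O(n^2) transform uses the unnormalized-forward, 1/n-inverse convention, which differs from `norm='ortho'` only by a constant scale.

```python
import cmath


def dft(seq, inverse=False):
    """Naive discrete Fourier transform (unnormalized forward, 1/n inverse)."""
    n = len(seq)
    sign = 1 if inverse else -1
    out = [
        sum(seq[t] * cmath.exp(sign * 2j * cmath.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]
    return [v / n for v in out] if inverse else out


def circular_conv(x, kernel):
    """Direct circular convolution with a kernel as long as the signal."""
    n = len(x)
    return [sum(x[(i - j) % n] * kernel[j] for j in range(n)) for i in range(n)]


def global_filter_1d(x, freq_filter):
    """1D analogue of the layer: transform, multiply element-wise, transform back."""
    spectrum = dft(x)
    filtered = [a * b for a, b in zip(spectrum, freq_filter)]
    return [v.real for v in dft(filtered, inverse=True)]


# Toy data (hypothetical values chosen for illustration only).
x = [1.0, 2.0, 0.0, -1.0, 3.0, 0.5]
kernel = [0.5, 0.25, 0.0, 0.0, 0.0, 0.25]

via_fft = global_filter_1d(x, dft(kernel))
via_conv = circular_conv(x, kernel)
max_err = max(abs(a - b) for a, b in zip(via_fft, via_conv))
assert max_err < 1e-9  # both routes agree up to floating-point error
```

In the layer itself the filter is learned directly in the frequency domain, so no spatial kernel is ever materialized.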

Compared to self-attention and spatial MLP, our Global Filter Layer processes high-resolution feature maps much more efficiently:

[Figure: efficiency]
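To give a rough feel for how the gap grows with resolution, the sketch below compares order-of-growth estimates for the token-mixing step. These are asymptotic expressions with constants dropped, not measured FLOPs, and the channel width `d=384` and the grid sizes are assumptions picked for illustration.

```python
import math


def token_mixing_cost(h, w, d):
    """Order-of-growth estimates for mixing L = h * w tokens of width d."""
    L = h * w
    return {
        # linear projections (L * d^2) plus the L x L attention matrix (L^2 * d)
        "self_attention": L * d * d + L * L * d,
        # a dense L x L mixing matrix applied to each channel
        "spatial_mlp": L * L * d,
        # FFT and inverse FFT (L * d * log L) plus the element-wise product (L * d)
        "global_filter": L * d * math.ceil(math.log2(L)) + L * d,
    }


ratios = []
for size in (14, 28, 56):
    cost = token_mixing_cost(size, size, d=384)
    ratio = cost["self_attention"] / cost["global_filter"]
    ratios.append(ratio)
    print(f"{size}x{size} tokens: self-attention / global filter ~ {ratio:.0f}x")
```

The ratio widens as the token grid grows because the quadratic term in the number of tokens dominates self-attention and spatial MLP, while the global filter grows only log-linearly.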

Model Zoo

We provide our GFNet models pretrained on ImageNet:

| name | arch | Params | FLOPs | acc@1 | acc@5 | url |
| --- | --- | --- | --- | --- | --- | --- |
| GFNet-Ti | gfnet-ti | 7M | 1.3G | 74.6 | 92.2 | Tsinghua Cloud / Google Drive |
| GFNet-XS | gfnet-xs | 16M | 2.8G | 78.6 | 94.2 | Tsinghua Cloud / Google Drive |
| GFNet-S | gfnet-s | 25M | 4.5G | 80.0 | 94.9 | Tsinghua Cloud / Google Drive |
| GFNet-B | gfnet-b | 43M | 7.9G | 80.7 | 95.1 | Tsinghua Cloud / Google Drive |
| GFNet-H-Ti | gfnet-h-ti | 15M | 2.0G | 80.1 | 95.1 | Tsinghua Cloud / Google Drive |
| GFNet-H-S | gfnet-h-s | 32M | 4.5G | 81.5 | 95.6 | Tsinghua Cloud / Google Drive |
| GFNet-H-B | gfnet-h-b | 54M | 8.4G | 82.9 | 96.2 | Tsinghua Cloud / Google Drive |

Usage

Requirements

  • torch>=1.8.1
  • torchvision
  • timm

Data preparation: download and extract the ImageNet images from http://image-net.org/. The directory structure should be:

│ILSVRC2012/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
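Before launching a long run, it can help to check that the extracted data matches the layout above. The following is a minimal sketch using only the standard library. The helper `summarize_imagenet_layout` is hypothetical and not part of this repository, and the demo builds a tiny throwaway tree rather than touching real data.

```python
import tempfile
from pathlib import Path


def summarize_imagenet_layout(root):
    """Return {split: (num_class_dirs, num_jpeg_files)} for train/ and val/."""
    root = Path(root)
    summary = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        class_dirs = [p for p in split_dir.iterdir() if p.is_dir()]
        num_images = sum(
            1 for c in class_dirs for f in c.iterdir() if f.suffix == ".JPEG"
        )
        summary[split] = (len(class_dirs), num_images)
    return summary


# Demo on a temporary tree that mimics the expected structure.
with tempfile.TemporaryDirectory() as tmp:
    samples = (
        ("train", "n01440764_10026.JPEG"),
        ("val", "ILSVRC2012_val_00000293.JPEG"),
    )
    for split, file_name in samples:
        class_dir = Path(tmp) / split / "n01440764"
        class_dir.mkdir(parents=True)
        (class_dir / file_name).touch()
    demo_summary = summarize_imagenet_layout(tmp)
    print(demo_summary)
```

On the real dataset, call the helper with the same path passed to `--data-path`. For the full ILSVRC2012 data, each split should report 1000 class directories.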

Evaluation

To evaluate a pre-trained GFNet model on the ImageNet validation set with a single GPU, run:

python infer.py --data-path /path/to/ILSVRC2012/ --arch arch_name --path /path/to/model

Training

ImageNet

To train GFNet models on ImageNet from scratch, run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_gfnet.py --output_dir logs/gfnet-xs --arch gfnet-xs --batch-size 128 --data-path /path/to/ILSVRC2012/

To finetune a pre-trained model at higher resolution, run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_gfnet.py --output_dir logs/gfnet-xs-img384 --arch gfnet-xs --input-size 384 --batch-size 64 --data-path /path/to/ILSVRC2012/ --lr 5e-6 --weight-decay 1e-8 --min-lr 5e-6 --epochs 30 --finetune /path/to/model

Transfer Learning Datasets

To finetune a pre-trained model on a transfer learning dataset, run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_gfnet_transfer.py --output_dir logs/gfnet-xs-cars --arch gfnet-xs --batch-size 64 --data-set CARS --data-path /path/to/stanford_cars --epochs 1000 --dist-eval --lr 0.0001 --weight-decay 1e-4 --clip-grad 1 --warmup-epochs 5 --finetune /path/to/model

License

MIT License

Citation

If you find our work useful in your research, please consider citing:

@article{rao2021global,
  title={Global Filter Networks for Image Classification},
  author={Rao, Yongming and Zhao, Wenliang and Zhu, Zheng and Lu, Jiwen and Zhou, Jie},
  journal={arXiv preprint arXiv:2107.00645},
  year={2021}
}