Implementation of Uformer, Attention-based Unet, in Pytorch

Uformer - Pytorch

Implementation of Uformer, an attention-based U-Net, in Pytorch. This implementation will only offer the concat-cross-skip connection, one of the skip-connection variants compared in the paper.

This repository will be geared towards use in a project for learning protein structures. Specifically, it will include the ability to condition on time steps (needed for DDPM), as well as 2D relative positional encoding using rotary embeddings (instead of the relative position bias added to the attention matrix in the paper).
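
The README does not spell out how the 2D rotary embedding is applied. One common construction, sketched below as an assumption rather than this repository's actual code, splits each head's features in half, rotates the first half by angles proportional to a token's row index and the second half by its column index. Because each rotation is orthogonal and angles are linear in position, the dot product between a rotated query and a rotated key depends only on their row and column offsets, which is what makes the encoding relative. The helpers `rotary_angles`, `rotate_pairs` and `apply_rotary_2d` are hypothetical names for illustration.

```python
import torch

def rotary_angles(positions: torch.Tensor, dim: int, base: float = 10000.0) -> torch.Tensor:
    # one frequency per pair of features, as in 1D rotary embeddings
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    # (num_tokens, dim // 2): rotation angle for each token and each frequency
    return positions.float()[:, None] * inv_freq[None, :]

def rotate_pairs(x: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    # x: (..., num_tokens, dim); rotate each consecutive feature pair (x[2k], x[2k+1])
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    cos, sin = angles.cos(), angles.sin()
    rotated = torch.stack((x_even * cos - x_odd * sin, x_even * sin + x_odd * cos), dim=-1)
    return rotated.flatten(-2)  # re-interleave the pairs -> (..., num_tokens, dim)

def apply_rotary_2d(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
    # x: (..., height * width, dim_head), tokens in row-major order; dim_head divisible by 4
    half = x.shape[-1] // 2
    rows = torch.arange(height).repeat_interleave(width)  # row index of each token
    cols = torch.arange(width).repeat(height)              # column index of each token
    x_rows = rotate_pairs(x[..., :half], rotary_angles(rows, half))  # first half encodes the row
    x_cols = rotate_pairs(x[..., half:], rotary_angles(cols, half))  # second half encodes the column
    return torch.cat((x_rows, x_cols), dim=-1)
```

In this sketch queries and keys inside a window would both pass through `apply_rotary_2d` before the attention dot product, e.g. `apply_rotary_2d(q, 16, 16)` for a 16 × 16 window with `q` of shape `(..., 256, dim_head)`; values are left unrotated.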

Install

$ pip install uformer-pytorch

Usage

import torch
from uformer_pytorch import Uformer

model = Uformer(
    dim = 64,           # initial dimensions after input projection, which increases by 2x each stage
    stages = 4,         # number of stages
    num_blocks = 2,     # number of transformer blocks per stage
    window_size = 16,   # set window size (along one side) for which to do the attention within
    dim_head = 64,
    heads = 8,
    ff_mult = 4
)

x = torch.randn(1, 3, 256, 256)
pred = model(x) # (1, 3, 256, 256)
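
`window_size` restricts self-attention to non-overlapping square windows, as in the window-based attention of the Uformer paper. The sketch below shows only the partitioning step and its inverse in plain PyTorch; `window_partition` and `window_merge` are hypothetical helpers, not functions exported by `uformer_pytorch`, and they assume the feature map's height and width are divisible by `window_size`.

```python
import torch

def window_partition(x: torch.Tensor, window_size: int) -> torch.Tensor:
    # x: (batch, channels, height, width) -> (batch * num_windows, window_size ** 2, channels)
    b, c, h, w = x.shape
    assert h % window_size == 0 and w % window_size == 0, "height and width must be divisible by window_size"
    x = x.reshape(b, c, h // window_size, window_size, w // window_size, window_size)
    # move the window-grid axes next to batch and group the in-window axes with channels last
    x = x.permute(0, 2, 4, 3, 5, 1)  # (b, n_rows, n_cols, ws, ws, c)
    return x.reshape(b * (h // window_size) * (w // window_size), window_size * window_size, c)

def window_merge(windows: torch.Tensor, window_size: int, b: int, h: int, w: int) -> torch.Tensor:
    # inverse of window_partition: (b * num_windows, ws * ws, c) -> (b, c, h, w)
    c = windows.shape[-1]
    x = windows.reshape(b, h // window_size, w // window_size, window_size, window_size, c)
    x = x.permute(0, 5, 1, 3, 2, 4)  # (b, c, n_rows, ws, n_cols, ws)
    return x.reshape(b, c, h, w)
```

Attention is then computed independently over the tokens of each window. Assuming the input projection keeps the 256 × 256 resolution of the usage example, `window_size = 16` gives 16 × 16 = 256 windows of 256 tokens each, so every attention matrix is 256 × 256 instead of 65,536 × 65,536 for global attention. The comment on `dim` also implies channel widths of 64, 128, 256 and 512 across the four stages; if each stage additionally halves the spatial resolution, as in the Uformer paper's encoder, the last stage works on 32 × 32 maps, which are still divisible by `window_size`.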

To condition on time steps for DDPM training:

import torch
from uformer_pytorch import Uformer

model = Uformer(
    dim = 64,
    stages = 4,
    num_blocks = 2,
    window_size = 16,
    dim_head = 64,
    heads = 8,
    ff_mult = 4,
    time_emb = True    # set this to True to enable time conditioning
)

x = torch.randn(1, 3, 256, 256)
time = torch.arange(1)    # integer timestep(s); here tensor([0])
pred = model(x, time = time) # (1, 3, 256, 256)
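
How `uformer_pytorch` embeds `time` internally is not documented here. The DDPM paper specifies the diffusion timestep with the Transformer sinusoidal position embedding, and implementations often pass it through a small MLP afterwards. The sketch below shows only the sinusoidal part under that assumption; `sinusoidal_time_embedding` is a hypothetical helper, not part of this library.

```python
import math
import torch

def sinusoidal_time_embedding(time: torch.Tensor, dim: int) -> torch.Tensor:
    # time: (batch,) integer or float timesteps -> (batch, dim) embedding; dim must be even
    half = dim // 2
    # geometric progression of frequencies from 1 toward 1/10000
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = time.float()[:, None] * freqs[None, :]
    # first half sines, second half cosines
    return torch.cat((args.sin(), args.cos()), dim=-1)
```

With `time = torch.arange(1)` as in the example above, the single timestep is 0, so the sine half of the embedding is all zeros and the cosine half all ones.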

Citations

@misc{wang2021uformer,
    title   = {Uformer: A General U-Shaped Transformer for Image Restoration}, 
    author  = {Zhendong Wang and Xiaodong Cun and Jianmin Bao and Jianzhuang Liu},
    year    = {2021},
    eprint  = {2106.03106},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Owner: Phil Wang