PyTorch implementation of popular datasets and models in remote sensing

PyTorch Remote Sensing (torchrs)

(WIP) PyTorch implementation of popular datasets and models in remote sensing tasks (Change Detection, Image Super Resolution, Land Cover Classification/Segmentation, Image-to-Image Translation, etc.) for various Optical (Sentinel-2, Landsat, etc.) and Synthetic Aperture Radar (SAR) (Sentinel-1) sensors.

Installation

# pypi
pip install torch-rs

# latest
pip install git+https://github.com/isaaccorley/torchrs

Datasets

PROBA-V Super Resolution

The PROBA-V Super Resolution Challenge dataset is a Multi-image Super Resolution (MISR) dataset of images taken by the ESA PROBA-Vegetation satellite. The dataset contains sets of unregistered 300m low resolution (LR) images which can be used to generate single 100m high resolution (HR) images for both the Near Infrared (NIR) and Red bands. In addition, Quality Masks (QM) for each LR image and Status Masks (SM) for each HR image are available. PROBA-V carries sensors that capture imagery at 100m and 300m spatial resolution with 5-day and 1-day revisit rates, respectively, so generating HR estimates from the more frequent LR acquisitions would effectively increase how often HR imagery is available for vegetation monitoring.
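To make the MISR setup concrete, below is a minimal NumPy sketch of a naive baseline (not RAMS, and not part of torchrs): average each pixel over the LR frames whose quality mask marks it as clear, then upsample by 3 (300m / 100m) with nearest-neighbour repetition. The helper name masked_mean_upsample and the convention that a QM value of 1 means a clear pixel are assumptions for illustration.

```python
import numpy as np


def masked_mean_upsample(lr, qm, scale=3):
    """Naive MISR baseline (hypothetical helper, not a torchrs API).

    lr: (t, 1, h, w) float array of low resolution frames
    qm: (t, 1, h, w) quality masks; ASSUMED 1 = clear pixel, 0 = concealed
    returns: (1, h * scale, w * scale) high resolution estimate
    """
    mask = qm.astype(lr.dtype)
    counts = mask.sum(axis=0)                 # (1, h, w) clear frames per pixel
    clear_sum = (lr * mask).sum(axis=0)       # (1, h, w) sum over clear frames only
    fallback = lr.mean(axis=0)                # used where no frame is clear
    fused = np.where(counts > 0, clear_sum / np.maximum(counts, 1), fallback)
    # Nearest-neighbour upsampling: repeat every pixel `scale` times along h and w
    return fused.repeat(scale, axis=1).repeat(scale, axis=2)


lr = np.stack([np.full((1, 2, 2), 1.0), np.full((1, 2, 2), 3.0)])  # t=2 frames
qm = np.stack([np.zeros((1, 2, 2)), np.ones((1, 2, 2))])            # only frame 2 is clear
print(masked_mean_upsample(lr, qm).shape)  # (1, 6, 6)
```

Real PROBA-V frames are unregistered, so a practical baseline would also need to align the frames before averaging; the sketch skips registration entirely.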

The dataset can be downloaded (0.83GB) using scripts/download_probav.sh and instantiated below:

from torchrs.transforms import Compose, ToTensor
from torchrs.datasets import PROBAV

transform = Compose([ToTensor()])

dataset = PROBAV(
    root="path/to/dataset/",
    split="train",  # or 'test'
    band="RED",     # or 'NIR'
    lr_transform=transform,
    hr_transform=transform
)

x = dataset[0]
"""
x: dict(
    lr: low res images  (t, 1, 128, 128)
    qm: quality masks   (t, 1, 128, 128)
    hr: high res image  (1, 384, 384)
    sm: status mask     (1, 384, 384)
)
t varies by set of images (minimum of 9)
"""

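Because t varies from set to set (minimum of 9), samples cannot be stacked into a batch directly, while a model such as RAMS is configured for a fixed number of frames. One option, sketched below under the same assumption that a QM value of 1 marks a clear pixel, is to keep the k clearest frames of each set. select_clearest_frames is a hypothetical helper, not part of torchrs.

```python
import numpy as np


def select_clearest_frames(lr, qm, k=9):
    """Keep the k LR frames with the largest fraction of clear pixels.

    lr, qm: (t, 1, h, w) arrays with t >= k; qm ASSUMED 1 = clear, 0 = concealed
    returns: ((k, 1, h, w) frames, (k, 1, h, w) masks), clearest frame first
    """
    t = lr.shape[0]
    if t < k:
        raise ValueError(f"need at least {k} frames, got {t}")
    clear_fraction = qm.reshape(t, -1).mean(axis=1)         # (t,)
    order = np.argsort(-clear_fraction, kind="stable")[:k]  # indices, clearest first
    return lr[order], qm[order]


# Example: 12 frames of 4x4 pixels; frame i has value i and i clear pixels
t, h, w = 12, 4, 4
lr = np.arange(t, dtype=np.float32).reshape(t, 1, 1, 1) * np.ones((t, 1, h, w), np.float32)
qm = np.zeros((t, 1, h, w), np.uint8)
for i in range(t):
    qm[i, 0].flat[:i] = 1
frames, masks = select_clearest_frames(lr, qm)
print(frames.shape)  # (9, 1, 4, 4)
```

Stacking the per-sample outputs with np.stack gives a (bs, 9, 1, h, w) array, which matches the (bs, t, c, h, w) input layout used in the RAMS example once converted with torch.from_numpy.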
ETCI 2021 Flood Detection

The ETCI 2021 Dataset is a Flood Detection segmentation dataset of SAR images taken by the ESA Sentinel-1 satellite. The dataset contains pairs of VV and VH polarization images processed by the Hybrid Pluggable Processing Pipeline (hyp3) along with corresponding binary flood and water body ground truth masks.

The dataset can be downloaded (5.6GB) using scripts/download_etci2021.sh and instantiated below:

from torchrs.transforms import Compose, ToTensor
from torchrs.datasets import ETCI2021

transform = Compose([ToTensor()])

dataset = ETCI2021(
    root="path/to/dataset/",
    split="train",  # or 'val', 'test'
    transform=transform
)

x = dataset[0]
"""
x: dict(
    vv:         (3, 256, 256)
    vh:         (3, 256, 256)
    flood_mask: (1, 256, 256)
    water_mask: (1, 256, 256)
)
"""

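flood_mask and water_mask are binary masks, so predictions against them are commonly scored with intersection over union (IoU). A minimal NumPy sketch follows; the helper name and the choice to return 1.0 when both masks are empty are assumptions, and the official ETCI 2021 scoring protocol may differ.

```python
import numpy as np


def binary_iou(pred, target):
    """IoU between two binary masks of the same shape, e.g. (1, 256, 256)."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    union = np.logical_or(pred, target).sum()
    if union == 0:
        return 1.0  # both masks empty: treated as a perfect match (a convention)
    intersection = np.logical_and(pred, target).sum()
    return float(intersection / union)


pred = np.array([[[1, 1], [0, 0]]])
target = np.array([[[1, 0], [0, 0]]])
print(binary_iou(pred, target))  # 0.5
```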
Onera Satellite Change Detection (OSCD)

The Onera Satellite Change Detection (OSCD) dataset, proposed in "Urban Change Detection for Multispectral Earth Observation Using Convolutional Neural Networks", Daudt et al., is a Change Detection dataset of 13-band multispectral (MS) images taken by the ESA Sentinel-2 satellite. The dataset contains 24 registered image pairs from multiple continents, acquired between 2015 and 2018, along with binary change masks.

The dataset can be downloaded (0.73GB) using scripts/download_oscd.sh and instantiated below:

from torchrs.transforms import Compose, ToTensor
from torchrs.datasets import OSCD

transform = Compose([ToTensor(permute_dims=False)])

dataset = OSCD(
    root="path/to/dataset/",
    split="train",  # or 'test'
    transform=transform,
)

x = dataset[0]
"""
x: dict(
    x: (2, 13, h, w)
    mask: (1, h, w)
)
"""

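Each OSCD sample stacks the two acquisition dates along the first axis of x. The NumPy sketch below shows two simple ways to consume that layout: early fusion, which concatenates both dates along the band axis into a (26, h, w) input for a single-stream network, and a naive change vector analysis baseline that thresholds the per-pixel spectral difference. Neither function is a torchrs API, and the threshold is an arbitrary placeholder that would need tuning.

```python
import numpy as np


def early_fusion(x):
    """(2, c, h, w) image pair -> (2 * c, h, w): date 1 bands followed by date 2 bands."""
    t, c, h, w = x.shape
    return x.reshape(t * c, h, w)


def change_vector_mask(x, threshold):
    """Naive change map: L2 norm of the per-pixel spectral difference between dates.

    x: (2, c, h, w); returns a (1, h, w) uint8 mask in the same layout as `mask`.
    """
    diff = x[1].astype(np.float64) - x[0].astype(np.float64)  # (c, h, w)
    magnitude = np.sqrt((diff ** 2).sum(axis=0))              # (h, w)
    return (magnitude > threshold)[None].astype(np.uint8)     # (1, h, w)


x = np.zeros((2, 13, 4, 4), dtype=np.float32)
x[1, :, 0, 0] = 1.0  # only pixel (0, 0) changes, by 1 in every band
print(early_fusion(x).shape)                       # (26, 4, 4)
print(change_vector_mask(x, threshold=1.0).sum())  # 1
```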
Remote Sensing Visual Question Answering (RSVQA) Low Resolution (LR)

The RSVQA LR dataset, proposed in "RSVQA: Visual Question Answering for Remote Sensing Data", Lobry et al., is a visual question answering (VQA) dataset of RGB images taken by the ESA Sentinel-2 satellite. Each image is annotated with a set of questions and their corresponding answers. Among other applications, this dataset can be used to train VQA models to perform scene understanding of medium resolution remote sensing imagery.

The dataset can be downloaded (0.2GB) using scripts/download_rsvqa_lr.sh and instantiated below:

import torchvision.transforms as T
from torchrs.datasets import RSVQALR

transform = T.Compose([T.ToTensor()])

dataset = RSVQALR(
    root="path/to/dataset/",
    split="train",  # or 'val', 'test'
    transform=transform
)

x = dataset[0]
"""
x: dict(
    x:         (3, 256, 256)
    questions:  List[str]
    answers:    List[str]
    types:      List[str]
)
"""

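Since questions, answers and types come as parallel lists, one common way to train a VQA model is to treat answering as classification over a fixed answer vocabulary. The sketch below builds such a vocabulary from the answer lists of several samples; the helper names and the -1 index for unseen answers are illustrative assumptions, not part of torchrs.

```python
from collections import Counter


def build_answer_vocab(answer_lists, top_k=None):
    """Map each answer string to a class index, most frequent answers first.

    answer_lists: iterable of List[str], e.g. the `answers` field of several samples
    top_k: keep only the k most frequent answers (None keeps all)
    """
    counts = Counter(a for answers in answer_lists for a in answers)
    return {answer: i for i, (answer, _) in enumerate(counts.most_common(top_k))}


def encode_answers(answers, vocab, unknown=-1):
    """Convert answer strings to class indices; unseen answers map to `unknown`."""
    return [vocab.get(a, unknown) for a in answers]


vocab = build_answer_vocab([["yes", "no"], ["yes", "5"]])
print(vocab)                                   # {'yes': 0, 'no': 1, '5': 2}
print(encode_answers(["no", "rural"], vocab))  # [1, -1]
```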
Remote Sensing Image Captioning Dataset (RSICD)

The RSICD dataset, proposed in "Exploring Models and Data for Remote Sensing Image Caption Generation", Lu et al., is an image captioning dataset with 5 captions per image for 10,921 RGB images extracted using Google Earth, Baidu Map, MapABC and Tianditu. Although it is one of the larger remote sensing image captioning datasets, its captions use very repetitive language with little detail, and many of them are duplicated.

The dataset can be downloaded (0.57GB) using scripts/download_rsicd.sh and instantiated below:

import torchvision.transforms as T
from torchrs.datasets import RSICD

transform = T.Compose([T.ToTensor()])

dataset = RSICD(
    root="path/to/dataset/",
    split="train",  # or 'val', 'test'
    transform=transform
)

x = dataset[0]
"""
x: dict(
    x:        (3, 224, 224)
    captions: List[str]
)
"""

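Given the duplicated captions noted above, it can be worth deduplicating the five captions per image before training or evaluation. A small pure-Python sketch; treating captions as duplicates when they match after lowercasing and collapsing whitespace is an assumption, not the dataset's own definition.

```python
def unique_captions(captions):
    """Drop repeated captions, keeping the first occurrence of each.

    Two captions count as duplicates if they match after lowercasing and
    collapsing whitespace (an illustrative choice).
    """
    seen = set()
    unique = []
    for caption in captions:
        key = " ".join(caption.lower().split())
        if key not in seen:
            seen.add(key)
            unique.append(caption)
    return unique


def duplicate_rate(captions):
    """Fraction of captions that duplicate an earlier one."""
    if not captions:
        return 0.0
    return 1.0 - len(unique_captions(captions)) / len(captions)


caps = ["Many planes are parked.", "many  planes are parked.", "A runway near grass."]
print(unique_captions(caps))  # ['Many planes are parked.', 'A runway near grass.']
```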
Remote Sensing Image Scene Classification (RESISC45)

The RESISC45 dataset, proposed in "Remote Sensing Image Scene Classification: Benchmark and State of the Art", Cheng et al., is an image classification dataset of 31,500 RGB images extracted using Google Earth Engine. The dataset contains 45 scene classes with 700 images per class from over 100 countries and was selected to optimize for high variability in image conditions (spatial resolution, occlusion, weather, illumination, etc.).

The dataset can be downloaded (0.47GB) using scripts/download_resisc45.sh and instantiated below:

import torchvision.transforms as T
from torchrs.datasets import RESISC45

transform = T.Compose([T.ToTensor()])

dataset = RESISC45(
    root="path/to/dataset/",
    transform=transform
)

x, y = dataset[0]
"""
x: (3, 256, 256)
y: int
"""

dataset.classes
"""
['airplane', 'airport', 'baseball_diamond', 'basketball_court', 'beach', 'bridge', 'chaparral',
'church', 'circular_farmland', 'cloud', 'commercial_area', 'dense_residential', 'desert', 'forest',
'freeway', 'golf_course', 'ground_track_field', 'harbor', 'industrial_area', 'intersection', 'island',
'lake', 'meadow', 'medium_residential', 'mobile_home_park', 'mountain', 'overpass', 'palace', 'parking_lot',
'railway', 'railway_station', 'rectangular_farmland', 'river', 'roundabout', 'runway', 'sea_ice', 'ship',
'snowberg', 'sparse_residential', 'stadium', 'storage_tank', 'tennis_court', 'terrace', 'thermal_power_station', 'wetland']
"""

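The RESISC45 example above passes no split argument, so train/test splits are left to the user. With 700 images per class, a per-class (stratified) split keeps all 45 classes balanced. The sketch below works on a plain list of integer labels; stratified_split is a hypothetical helper, and the 80/20 ratio and seed are arbitrary choices.

```python
import random
from collections import defaultdict


def stratified_split(labels, train_frac=0.8, seed=0):
    """Split sample indices so each class keeps the same train/test ratio.

    labels: sequence of int class labels, one per sample (e.g. the y of dataset[i])
    returns: (train_indices, test_indices)
    """
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)

    rng = random.Random(seed)  # seeded so the split is reproducible
    train, test = [], []
    for y in sorted(by_class):
        idxs = by_class[y]
        rng.shuffle(idxs)
        n_train = round(train_frac * len(idxs))
        train.extend(idxs[:n_train])
        test.extend(idxs[n_train:])
    return train, test


labels = [i // 700 for i in range(45 * 700)]  # 45 classes x 700 images, as in RESISC45
train_idx, test_idx = stratified_split(labels)
print(len(train_idx), len(test_idx))  # 25200 6300
```

The index lists can then be wrapped with torch.utils.data.Subset(dataset, train_idx). Collecting the labels from the real dataset means iterating over it once, which also loads every image.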
EuroSAT

The EuroSAT dataset, proposed in "EuroSAT: A Novel Dataset and Deep Learning Benchmark for Land Use and Land Cover Classification", Helber et al., is a land cover classification dataset of 27,000 images taken by the ESA Sentinel-2 satellite. The dataset contains 10 land cover classes with 2-3k images per class from over 34 European countries. It is available either as RGB only or with all 13 Multispectral (MS) Sentinel-2 bands. The dataset is fairly easy: a ResNet-50 achieves ~98.6% accuracy.

The datasets can be downloaded (0.13GB and 2.8GB, respectively) using scripts/download_eurosat_rgb.sh or scripts/download_eurosat_ms.sh and instantiated below:

import torchvision.transforms as T
from torchrs.transforms import ToTensor
from torchrs.datasets import EuroSATRGB, EuroSATMS

transform = T.Compose([T.ToTensor()])

dataset = EuroSATRGB(
    root="path/to/dataset/",
    transform=transform
)

x, y = dataset[0]
"""
x: (3, 64, 64)
y: int
"""

transform = T.Compose([ToTensor()])

dataset = EuroSATMS(
    root="path/to/dataset/",
    transform=transform
)

x, y = dataset[0]
"""
x: (13, 64, 64)
y: int
"""

dataset.classes
"""
['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial',
'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake']
"""

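The multispectral variant returns 13 bands whose value ranges can differ considerably, so per-band standardisation is a typical preprocessing step. Below is a NumPy sketch that accumulates per-band mean and standard deviation over (c, h, w) images and applies them; the helper names are hypothetical, and the statistics would normally be computed on the training split only.

```python
import numpy as np


def per_band_stats(images):
    """Per-band mean and std over an iterable of (c, h, w) arrays.

    Uses running sums so the whole dataset never has to sit in memory at once.
    returns: (mean, std), each of shape (c,)
    """
    total = total_sq = None
    n_pixels = 0
    for x in images:
        x = x.astype(np.float64)
        s = x.sum(axis=(1, 2))
        sq = (x ** 2).sum(axis=(1, 2))
        total = s if total is None else total + s
        total_sq = sq if total_sq is None else total_sq + sq
        n_pixels += x.shape[1] * x.shape[2]
    mean = total / n_pixels
    var = np.maximum(total_sq / n_pixels - mean ** 2, 0.0)  # clip tiny negatives from rounding
    return mean, np.sqrt(var)


def standardize(x, mean, std, eps=1e-8):
    """Apply per-band standardisation to a (c, h, w) array."""
    return (x - mean[:, None, None]) / (std[:, None, None] + eps)


images = [np.full((13, 2, 2), 1.0), np.full((13, 2, 2), 3.0)]
mean, std = per_band_stats(images)
print(mean[0], std[0])  # 2.0 1.0
```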
Models

RAMS

Residual Attention Multi-image Super-resolution Network (RAMS) from "Multi-Image Super Resolution of Remotely Sensed Images Using Residual Attention Deep Neural Networks", Salvetti et al. (2021)

RAMS is currently one of the top performers on the PROBA-V Super Resolution Challenge. This Multi-image Super Resolution (MISR) architecture uses attention-based methods to extract spatial and spatiotemporal features from a set of low resolution images to form a single high resolution image. Note that the attention methods are effectively Squeeze-and-Excitation blocks from "Squeeze-and-Excitation Networks", Hu et al.

import torch
from torchrs.models import RAMS

# increase resolution by factor of 3 (e.g. 128x128 -> 384x384)
model = RAMS(
    scale_factor=3,
    t=9,
    c=1,
    num_feature_attn_blocks=12
)

# Input should be of shape (bs, t, c, h, w), where t is the number
# of low resolution input images and c is the number of channels/bands
lr = torch.randn(1, 9, 1, 128, 128)
sr = model(lr) # (1, 1, 384, 384)
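For reference, the Squeeze-and-Excitation idea mentioned above fits in a few lines: average-pool each channel (squeeze), pass the channel vector through a two-layer bottleneck with ReLU and sigmoid (excitation), and rescale the channels by the resulting weights. This is a generic 2D NumPy sketch following Hu et al., with weights passed in explicitly; it is not taken from the torchrs RAMS code.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def squeeze_excitation(x, w1, b1, w2, b2):
    """Generic Squeeze-and-Excitation on a single feature map.

    x:  (c, h, w) features
    w1: (c // r, c), b1: (c // r,)  bottleneck layer with reduction ratio r
    w2: (c, c // r), b2: (c,)       expansion layer
    returns: x with each channel rescaled by a weight in (0, 1)
    """
    s = x.mean(axis=(1, 2))           # squeeze: global average pool -> (c,)
    z = np.maximum(w1 @ s + b1, 0.0)  # excitation, step 1: FC + ReLU -> (c // r,)
    a = sigmoid(w2 @ z + b2)          # excitation, step 2: FC + sigmoid -> (c,)
    return x * a[:, None, None]       # scale: reweight every channel


c, r = 8, 4
rng = np.random.default_rng(0)
x = rng.standard_normal((c, 16, 16))
w1, b1 = rng.standard_normal((c // r, c)), np.zeros(c // r)
w2, b2 = rng.standard_normal((c, c // r)), np.zeros(c)
print(squeeze_excitation(x, w1, b1, w2, b2).shape)  # (8, 16, 16)
```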

Tests

$ pytest -ra
Comments
  • Error in training example

    Following the example from the README:

    ValueError                                Traceback (most recent call last)
    <ipython-input-11-f854c515c2ab> in <module>()
    ----> 1 trainer.fit(model, datamodule=dm)
    
    23 frames
    /usr/local/lib/python3.7/dist-packages/torchmetrics/functional/classification/stat_scores.py in _stat_scores_update(preds, target, reduce, mdmc_reduce, num_classes, top_k, threshold, multiclass, ignore_index)
        123         if not mdmc_reduce:
        124             raise ValueError(
    --> 125                 "When your inputs are multi-dimensional multi-class, you have to set the `mdmc_reduce` parameter"
        126             )
        127         if mdmc_reduce == "global":
    
    ValueError: When your inputs are multi-dimensional multi-class, you have to set the `mdmc_reduce` parameter
    
    opened by robmarkcole 3
  • probaV key issue

    Listed as

        def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
            """ Returns a dict containing lrs, qms, hr, sm
            lrs: (t, 1, h, w) low resolution images
    

    but the key is actually lr. Similarly, it's not qms but qm.

    opened by robmarkcole 3
  • Hello, I am training the FCEFModule model on GID data, but the following error is reported. Is it because of a version problem? Thanks

    Here is my code:

    import torch
    import torch.nn as nn
    import pytorch_lightning as pl
    from torchrs.datasets import GID15
    from torchrs.train.modules import FCEFModule
    from torchrs.train.datamodules import GID15DataModule
    from torchrs.transforms import Compose, ToTensor
    
    
    def collate_fn(batch):
        x = torch.stack([x["x"] for x in batch])
        y = torch.cat([x["mask"] for x in batch])
        x = x.to(torch.float32)
        y = y.to(torch.long)
        return x, y
    
    transform = Compose([ToTensor()])
    
    dm = GID15DataModule(
        root="./datasets/gid-15",
        transform=transform,
        batch_size=128,
        num_workers=1,
        prefetch_factor=1,
        collate_fn=collate_fn,
        test_collate_fn=collate_fn,
    )
    
    
    model = FCEFModule(channels=3, t=2, num_classes=15, lr=1E-3)
    
    
    callbacks = [
        pl.callbacks.ModelCheckpoint(monitor="val_loss", mode="min", verbose=True, save_top_k=1),
        pl.callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=10)
    ]
    
    trainer = pl.Trainer(
        gpus=1,
        precision=16,
        accumulate_grad_batches=1,
        max_epochs=25,
        callbacks=callbacks,
        weights_summary="top"
    )
    #
    trainer.fit(model, datamodule=dm)
    trainer.test(datamodule=dm)
    

    Error

    opened by olongfen 2
  • Torchmetrics update

    • Updated segmentation metric input args (namely, added the mdmc_average="global" arg), e.g.
    torchmetrics.Accuracy(threshold=0.5, num_classes=num_classes, average="micro", mdmc_average="global"),
    
    opened by isaaccorley 0
  • Change Detection models

    • Added EarlyFusion (EF) and Siamese (Siam) from the OSCD dataset paper
    • Added Fully convolutional EarlyFusion (FC-EF), Siamese Concatenation (FC-Siam-conc), and Siamese Difference (FC-Siam-diff)
    opened by isaaccorley 0
  • converted to datamodules and modules, updated readme, added install e…

    • Created LightningDataModule for each Dataset
    • Created LightningModule for each Model
    • Moved some config in setup.py to setup.cfg
    • Some minor fixes to a few datasets
    opened by isaaccorley 0
  • fair1m Only small part 1 dataset

    I noticed the official dataset is split into parts 1 & 2, with the bulk of the images being in part 2

    The download script in this repo only fetches a subset of 1,733 images, which I believe are the part 1 images?

    opened by robmarkcole 2
  • probav training memory error

    Using Colab Pro with nominally 25 GB, I still run out of memory at epoch 17 using your probav example notebook. Is there any way to free memory on the fly? I was able to train the TensorFlow RAMS implementation to 50 epochs on Colab Pro.

    CUDA out of memory. Tried to allocate 46.00 MiB (GPU 0; 15.90 GiB total capacity; 14.01 GiB already allocated; 25.75 MiB free; 14.96 GiB reserved in total by PyTorch)
    
    opened by robmarkcole 5
Releases (0.0.4)
  • 0.0.4 (Sep 3, 2021)

    • Added RSVQAHR, Sydney Captions, UC Merced Captions, S2MTCP, ADVANCE, SAT-4, SAT-6, HRSCD, Inria AIL, TiSeLaC, GID-15, ZueriCrop, AID, Dubai Segmentation, HKH Glacier Mapping, UC Merced, PatternNet, WHU-RS19, RSSCN7, Brazilian Coffee Scenes datasets & datamodules
    • Added methods for creating train/val/test splits in datamodules
    • Added ExtractChips transform
    • Added dataset and datamodules local tests
    • Added h5py, imagecodecs, torchaudio reqs
    • Some minor fixes+additions to current datasets
    • Added FCEF, FCSiamConc, FCSiamDiff PL modules
  • 0.0.3 (Aug 2, 2021)

  • 0.0.2 (Jul 26, 2021)

Owner
isaac
Senior Computer Vision Engineer @ BlackSky, Ph.D. Student at the University of Texas at San Antonio. Former @housecanary, @boozallen, @swri, @ornl