Surrogate- and Invariance-Boosted Contrastive Learning (SIB-CL)

This repository contains all source code used to generate the results in the article "Surrogate- and invariance-boosted contrastive learning for data-scarce applications in science". (url: to-be-updated)

  • The folder generate_datasets contains all numerical programs used to generate the datasets, for both Photonic Crystals (PhC) and the Time-independent Schrödinger Equation (TISE).
  • main.py is the main script used to train the neural networks (explained in detail below).

Dependencies

Please install the required Python packages: pip install -r requirements.txt

Optionally, a Python 3 environment can be created before installing the packages: conda create -n sibcl python=3.8; conda activate sibcl

Access to MATLAB is required to calculate the density-of-states (DOS) of PhCs.

Dataset Generation

Photonic Crystals (PhCs)

The relevant code is stored in generate_datasets/PhC/. Periodic unit cells are defined using a level set of a Fourier sum; different unit cells can be generated using the get_random() method of the FourierPhC class defined in fourier_phc.py.
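To illustrate what "a level set of a Fourier sum" means, here is a minimal sketch using NumPy. It is not the FourierPhC implementation: the function name random_fourier_unit_cell, the parameters n_pix, max_f and threshold, and the way amplitudes and phases are sampled are all assumptions made for illustration.

```python
import numpy as np

def random_fourier_unit_cell(n_pix=32, max_f=1, threshold=0.0, seed=None):
    """Binary n_pix x n_pix unit cell from a level set of a random Fourier sum.

    Hypothetical sketch; the sampling scheme is an assumption, not the repository's.
    """
    rng = np.random.default_rng(seed)
    # Fractional coordinates spanning exactly one period of the unit cell.
    x = np.arange(n_pix) / n_pix
    xx, yy = np.meshgrid(x, x, indexing="ij")
    field = np.zeros((n_pix, n_pix))
    # Sum plane waves with integer wavevectors up to max_f, so the field is periodic.
    for nx in range(-max_f, max_f + 1):
        for ny in range(-max_f, max_f + 1):
            amplitude = rng.normal()
            phase = rng.uniform(0.0, 2.0 * np.pi)
            field += amplitude * np.cos(2.0 * np.pi * (nx * xx + ny * yy) + phase)
    # Level set: pixels above the threshold are one material, the rest the other.
    return (field > threshold).astype(np.uint8)

cell = random_fourier_unit_cell(seed=1)
```

Because only integer wavevectors enter the sum, the resulting pattern tiles seamlessly, which is what makes it usable as a periodic unit cell.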

To generate the labeled PhC datasets, we first compute their band structures using MPB.

For the target dataset of random Fourier unit cells, run: python phc_gendata.py --h5filename="mf1-s1" --pol="tm" --nsam=5000 --maxF=1 --seed=1

For the source dataset of simple cylinders, run: python phc_gencylin.py --h5filename="cylin" --pol="tm" --nsam=10000

Each program creates a dataset with the eigenfrequencies, group velocities, etc., stored in a .h5 file (which can be accessed using the h5py package). We then calculate the DOS using the GGR method provided by the MATLAB code https://github.com/boyuanliuoptics/DOS-calculation/blob/master/DOS_GGR.m. To do so, we first parse the data to create the .txt files required as inputs to the program, compute the DOS using MATLAB, and then add the DOS labels back to the original .h5 files. These steps are executed automatically by running the shell script get_DOS.sh after modifying the h5 filename identifier defined at the top. Note that for this to run smoothly, python and MATLAB must first be added to PATH.
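For readers unfamiliar with the DOS label, the sketch below shows the quantity being computed in its simplest form: a histogram of eigenfrequencies over sampled k-points. This is a plain histogram estimate and does not reproduce the GGR method or the MATLAB code; the function name histogram_dos and the toy input array are assumptions for illustration only.

```python
import numpy as np

def histogram_dos(freqs, n_bins=50, freq_range=None):
    """Naive DOS estimate from eigenfrequencies of shape (n_kpoints, n_bands).

    Normalised so that the DOS integrates to the number of bands
    when every frequency falls inside the histogram range.
    """
    freqs = np.asarray(freqs, dtype=float)
    n_kpoints, _ = freqs.shape
    counts, edges = np.histogram(freqs.ravel(), bins=n_bins, range=freq_range)
    bin_width = edges[1] - edges[0]
    dos = counts / (n_kpoints * bin_width)
    centers = 0.5 * (edges[:-1] + edges[1:])
    return centers, dos

# Toy stand-in for band-structure data: 100 k-points, 3 bands (not real MPB output).
toy_freqs = np.random.default_rng(0).uniform(0.0, 1.0, size=(100, 3))
centers, dos = histogram_dos(toy_freqs, n_bins=20)
```

A histogram like this converges slowly with the number of k-points, which is one reason to prefer a dedicated method such as the one used here.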

Time-independent Schrödinger Equation (TISE)

The relevant code is stored in generate_datasets/TISE/. Example usage:

To generate the target dataset, e.g. in 3D: python tise_gendata.py --h5filename="tise3d" --ndim 3 --nsam 5000

To generate the low-resolution dataset: python tise_gendata.py --h5filename='tise3d_lr' --ndim 3 --nsam 10000 --lowres --orires=32 (--orires defines the resolution of the input to the neural network)

To generate the quantum harmonic oscillator (QHO) dataset: python tise_genqho.py --h5filename='tise2d_qho' --ndim 2 --nsam 10000
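The QHO is a convenient surrogate because its ground-state energy is known in closed form, so labels cost almost nothing to produce. The sketch below illustrates this in units where ħ = m = 1, with potential V(x) = ½ Σ ωᵢ² xᵢ² and ground-state energy E₀ = ½ Σ ωᵢ. The function names, the grid extent half_width, and the choice of frequencies are assumptions; how tise_genqho.py parametrises its potentials is not specified here.

```python
import numpy as np

def qho_potential(omegas, n_grid=32, half_width=5.0):
    """Sample V(x) = 0.5 * sum_i omega_i^2 x_i^2 on a regular grid (hbar = m = 1)."""
    axes = [np.linspace(-half_width, half_width, n_grid) for _ in omegas]
    grids = np.meshgrid(*axes, indexing="ij")
    potential = np.zeros(grids[0].shape)
    for omega, grid in zip(omegas, grids):
        potential += 0.5 * omega**2 * grid**2
    return potential

def qho_ground_state_energy(omegas):
    """Analytic ground-state energy E0 = 0.5 * sum_i omega_i (hbar = 1)."""
    return 0.5 * float(np.sum(omegas))

omegas = [1.0, 2.0]  # one frequency per spatial dimension (2D example)
potential = qho_potential(omegas)
energy = qho_ground_state_energy(omegas)
```

Each (potential, energy) pair is an input-label sample obtained without running an eigensolver.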

SIB-CL and baselines training

Training of the neural networks for all problems introduced in the article (i.e. PhC DOS prediction, PhC band structure prediction, and TISE ground-state energy prediction using either low-resolution or QHO data as surrogate) can be executed using main.py by indicating the appropriate flags (see below). The same script trains via the SIB-CL framework or any of the baselines, again selected with the appropriate flag. It also contains other prediction problems not presented in the article, such as predicting higher energy states of TISE, TISE wavefunctions and a single band structure.

Important flags:

--path_to_h5: indicate directory where h5 datasets are located. The h5 filenames defined in the dataset classes in datasets_PhC_SE.py should also be modified according to the names used during dataset generation.

--predict: defines prediction task. Options: 'DOS', 'bandstructures', 'eigval', 'oneband', 'eigvec'

--train: specify if training via SIB-CL or baselines. Options: 'sibcl', 'tl', 'sl', 'ssl' ('ssl' performs regular contrastive learning without surrogate dataset). For invariance-boosted baselines, e.g. TL-I or SL-I, specify 'tl' or 'sl' here and add the relevant invariances flags (see below).

--iden: required; specify identifier for saving of models, training logs and results

Invariance flags:

  • --translate_pbc: set this flag to include rolling translations
  • --pg_uniform: set this flag to uniformly sample the point group symmetry transformations
  • --scale: set this flag to scale the unit cell (used for PhC)
  • --rotate: set this flag to do 4-fold rotations
  • --flip: set this flag to perform horizontal and vertical mirrors

If --pg_uniform is used, there is no need to include --rotate and --flip.
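As a rough illustration of what these invariances do to a 2D unit cell, the sketch below applies a rolling (periodic) translation followed by a uniformly sampled symmetry of the square (4 rotations, each with or without a mirror, giving 8 elements). The function names are hypothetical and this is not the augmentation code in main.py; in particular, treating --pg_uniform as uniform sampling over these 8 elements is an assumption, and scaling is omitted.

```python
import numpy as np

def translate_pbc(cell, rng):
    """Roll the cell by a random integer shift along each axis (periodic boundaries)."""
    shift_x = int(rng.integers(cell.shape[0]))
    shift_y = int(rng.integers(cell.shape[1]))
    return np.roll(cell, shift=(shift_x, shift_y), axis=(0, 1))

def sample_point_group(cell, rng):
    """Apply one of the 8 symmetries of the square, chosen uniformly at random."""
    n_quarter_turns = int(rng.integers(4))
    out = np.rot90(cell, k=n_quarter_turns, axes=(0, 1))
    if int(rng.integers(2)) == 1:
        out = np.flip(out, axis=0)  # a mirror combined with rotations covers all reflections
    return out

def augment(cell, rng):
    """Compose the two transformations; the label of the cell is assumed unchanged."""
    return sample_point_group(translate_pbc(cell, rng), rng)

rng = np.random.default_rng(0)
cell = np.arange(16).reshape(4, 4)
augmented = augment(cell, rng)
```

Both operations only permute pixels, so the augmented cell contains exactly the same values as the original.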

Other optional flags can be displayed via python main.py --help. Examples of shell scripts can be found in the sh_scripts folder.
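To show how the flags above fit together, here is a small sketch that assembles a training command from them. The directory ./h5_data and the identifier sibcl_dos_run1 are placeholders, and passing values as separate space-delimited arguments is an assumption; check python main.py --help and the scripts in sh_scripts for the exact usage.

```python
import shlex

def build_train_command(path_to_h5, predict, train, iden, invariances=()):
    """Assemble a main.py invocation from the flags documented above."""
    command = [
        "python", "main.py",
        "--path_to_h5", path_to_h5,
        "--predict", predict,
        "--train", train,
        "--iden", iden,
    ]
    # Invariance flags are boolean switches, so they take no value.
    command += [f"--{name}" for name in invariances]
    return command

command = build_train_command(
    path_to_h5="./h5_data",          # placeholder directory
    predict="DOS",
    train="sibcl",
    iden="sibcl_dos_run1",           # placeholder identifier
    invariances=("translate_pbc", "pg_uniform", "scale"),
)
print(shlex.join(command))
```

For an invariance-boosted baseline such as TL-I, the same helper would be called with train="tl" and the desired invariance names.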

Training outputs:

By default, running main.py will create 3 subdirectories:

  • ./pretrained_models/: state dictionaries of pretrained models at various epochs indicated in the eplist variable will be saved to this directory. These models are used for further fine-tuning.
  • ./dicts/: stores the evaluation losses on the test set as dictionaries saved as .json files. The results can then be plotted using plot_results.py.
  • ./tlogs/: training curves for pre-training and fine-tuning are stored in dictionaries saved as .json files. The training curves can be plotted using get_training_logs.py. Alternatively, the --log_to_tensorboard flag can be set and training curves can be viewed using tensorboard; in this case, the dictionaries will not be generated.
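
If you want to inspect the saved .json dictionaries without the provided plotting scripts, a generic loader along the following lines may help. The internal structure of the files is not documented here, so the sketch only lists top-level keys; load_result_dicts is a hypothetical helper, not part of the repository.

```python
import json
from pathlib import Path

def load_result_dicts(directory):
    """Map each .json file stem in `directory` to its parsed contents."""
    results = {}
    for path in sorted(Path(directory).glob("*.json")):
        with path.open() as handle:
            results[path.stem] = json.load(handle)
    return results

# Only runs if the default output directory exists in the working directory.
if Path("./dicts").is_dir():
    for name, content in load_result_dicts("./dicts").items():
        summary = sorted(content) if isinstance(content, dict) else type(content).__name__
        print(name, summary)
```

The same function can be pointed at ./tlogs/ to look at the training-curve dictionaries.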

Releases: v1.0

Owner: Charlotte Loh, PhD candidate at MIT EECS