Finite-temperature variational Monte Carlo calculation of uniform electron gas using neural canonical transformation.

Overview

CoulombGas


This code implements the neural canonical transformation approach to the thermodynamic properties of the uniform electron gas. Built on JAX, it uses both forward- and backward-mode automatic differentiation, together with the pmap mechanism, to achieve large-scale single-program multiple-data (SPMD) training on multiple GPUs.

Requirements

  • JAX with Nvidia GPU support
  • A handful of GPUs. The more the better :P
  • haiku
  • optax
  • mpmath, which we use to compute analytically the thermal entropy of a non-interacting Fermi gas in the canonical ensemble with arbitrary-precision arithmetic.
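The README does not say how this entropy is evaluated. One standard route, assumed here and not necessarily the one in the code, is the recursion for the canonical partition function of N non-interacting fermions, Z_N = (1/N) Σ_{k=1}^{N} (-1)^{k+1} z(kβ) Z_{N-k}, where z is the one-body partition function. Its alternating signs cause heavy cancellation, which is why arbitrary precision helps. To stay dependency-free this sketch uses Python's standard `decimal` module in place of mpmath; the function names and the toy level lists are hypothetical.

```python
from decimal import Decimal, getcontext
from itertools import combinations

# Extra digits absorb the cancellation caused by the alternating signs below.
getcontext().prec = 60


def canonical_fermi_thermo(levels, n, beta):
    """(ln Z_N, E, S) for n non-interacting fermions on the given single-particle
    levels at inverse temperature beta, with k_B = 1."""
    beta = Decimal(beta)
    eps = [Decimal(e) for e in levels]

    def z(x):  # one-body partition function: sum_i exp(-x * eps_i)
        return sum((-x * e).exp() for e in eps)

    def w(x):  # sum_i eps_i * exp(-x * eps_i), i.e. -dz/dx
        return sum(e * (-x * e).exp() for e in eps)

    Z, dZ = [Decimal(1)], [Decimal(0)]  # Z_0 = 1 and its beta-derivative
    for m in range(1, n + 1):
        acc, dacc = Decimal(0), Decimal(0)
        for k in range(1, m + 1):
            sign = 1 if k % 2 else -1
            zk, wk = z(k * beta), w(k * beta)
            acc += sign * zk * Z[m - k]
            # d/dbeta [z(k beta) Z_{m-k}] = -k w(k beta) Z_{m-k} + z(k beta) dZ_{m-k}
            dacc += sign * (-k * wk * Z[m - k] + zk * dZ[m - k])
        Z.append(acc / m)
        dZ.append(dacc / m)

    ln_z = Z[n].ln()
    energy = -dZ[n] / Z[n]
    entropy = beta * energy + ln_z  # S = beta (E - F) with F = -ln Z / beta
    return ln_z, energy, entropy


def brute_force_thermo(levels, n, beta):
    """Same quantities by enumerating every n-particle occupation pattern."""
    beta = Decimal(beta)
    eps = [Decimal(e) for e in levels]
    energies = [sum(eps[i] for i in occ) for occ in combinations(range(len(eps)), n)]
    weights = [(-beta * e).exp() for e in energies]
    part = sum(weights)
    energy = sum(e * wt for e, wt in zip(energies, weights)) / part
    return part.ln(), energy, beta * energy + part.ln()


print(canonical_fermi_thermo([0, 1, 1, 2, 3], 2, "0.7"))
print(brute_force_thermo([0, 1, 1, 2, 3], 2, "0.7"))
```

The brute-force enumeration is only feasible for tiny systems and serves as a cross-check; the recursion costs O(N² × number of levels) and works for realistic level lists.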

Demo run

To start, try running the following commands to launch training of 13 spin-polarized electrons in 2D with dimensionless density parameter rs = 10.0 and reduced temperature Theta = 0.15 on 8 GPUs:

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python main.py --n 13 --dim 2 --rs 10.0 --Theta 0.15 --Emax 25 --sr --batch 4096 --num_devices 8 --acc_steps 2

Note that each training step effectively draws a batch of 8192 samples in total (batch × acc_steps = 4096 × 2). A batch this large, however, would consume more memory than 8 GPUs can accommodate. To overcome this, we split the batch into two equal pieces and accumulate the gradient and various observables over the two pieces in two sequential substeps. In other words, the --batch argument in the command above actually stands for the batch size per accumulation step.
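The accumulation idea can be sketched with a toy loss, which is an assumption and not the actual variational objective: process the batch in `acc_steps` equal pieces one after another and average the per-piece mean gradients. Only one piece needs to be in memory at a time, yet the result matches the full-batch gradient. All names here are hypothetical.

```python
import math
import random


def per_sample_grad(theta, x):
    # Gradient of a toy per-sample loss (theta - x)**2 with respect to theta.
    return 2.0 * (theta - x)


def mean_grad(theta, samples):
    # Plain sample mean of the per-sample gradients.
    return sum(per_sample_grad(theta, x) for x in samples) / len(samples)


def accumulated_grad(theta, samples, acc_steps):
    """Process `samples` in acc_steps equal pieces, one after another, and
    average the per-piece mean gradients. Only one piece is held at a time."""
    if len(samples) % acc_steps:
        raise ValueError("the batch must split evenly into acc_steps pieces")
    piece = len(samples) // acc_steps
    total = 0.0
    for s in range(acc_steps):
        total += mean_grad(theta, samples[s * piece:(s + 1) * piece])
    return total / acc_steps


random.seed(0)
samples = [random.gauss(0.0, 1.0) for _ in range(8192)]  # 4096 per piece x 2
print(accumulated_grad(0.3, samples, 2), mean_grad(0.3, samples))
```

The two results agree because the toy gradient is a plain sample mean and the pieces have equal size. Estimators that subtract a batch-wide baseline (for instance a mean local energy) are not simple sample means, so the baseline has to be handled across pieces for accumulation to reproduce the full-batch estimate.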

If you have only, say, 4 GPUs, you can set --batch, --num_devices and --acc_steps to 2048, 4 and 4, respectively, to launch the same training process at the cost of doubling the running time. The total GPU hours nevertheless remain the same.
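The bookkeeping behind the two settings can be checked with a few lines of arithmetic. This sketch assumes that --batch is split evenly over the devices within each accumulation substep, as in a pmap-style SPMD layout; `training_plan` is a hypothetical helper, not part of main.py.

```python
def training_plan(batch, num_devices, acc_steps):
    """Bookkeeping for one training step, assuming --batch is split evenly
    over the devices within each accumulation substep."""
    if batch % num_devices:
        raise ValueError("batch must be divisible by num_devices")
    return {
        "samples_per_step": batch * acc_steps,
        "samples_per_device_per_substep": batch // num_devices,  # sets peak memory
        "sequential_substeps": acc_steps,  # wall time grows with this
        "device_substeps": num_devices * acc_steps,  # GPU hours grow with this
    }


eight = training_plan(4096, 8, 2)
four = training_plan(2048, 4, 4)
print(eight)
print(four)
```

Both settings give 8192 samples per step and 512 samples per device per substep (so the same per-GPU memory), while the 4-GPU run needs twice as many sequential substeps for the same 16 device-substeps.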

For the detailed meaning of the other command-line arguments, run

python main.py --help

or directly refer to the source code.

Trained model and data

A training process from scratch actually consists of two stages. In the first stage, a variational autoregressive network is pretrained to approximate the Boltzmann distribution of the corresponding non-interacting electron gas. The resulting model can be saved and loaded later. In fact, for your convenience we provide such a model file for the parameter settings of the previous section, so you can quickly get a feel for the second stage: training the truly interacting system of interest. We encourage you to delete this file and pretrain the model yourself; pretraining is actually much faster than the second stage of training.
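The README only says that the pretrained network approximates the Boltzmann distribution. A common objective for such variational networks, assumed here, is the variational free energy F_q = Σ_s q(s)[E(s) + T ln q(s)], which is bounded below by the exact free energy and reaches it only at the Boltzmann distribution. This sketch replaces the autoregressive network by an explicit probability table over the many-body states of a tiny, hypothetical set of single-particle levels.

```python
import math
from itertools import combinations


def many_body_energies(levels, n):
    # Total energy of every way to place n spinless fermions on the levels.
    return [sum(levels[i] for i in occ) for occ in combinations(range(len(levels)), n)]


def variational_free_energy(probs, energies, temperature):
    # F_q = sum_s q(s) [E(s) + T ln q(s)]; states with q(s) = 0 contribute nothing.
    return sum(q * (e + temperature * math.log(q))
               for q, e in zip(probs, energies) if q > 0)


def boltzmann(energies, temperature):
    # Exact Boltzmann distribution and free energy F = -T ln Z.
    weights = [math.exp(-e / temperature) for e in energies]
    part = sum(weights)
    return [w / part for w in weights], -temperature * math.log(part)


energies = many_body_energies([0.0, 1.0, 1.0, 2.0, 3.0], 2)
T = 0.8
p, exact_f = boltzmann(energies, T)
uniform = [1.0 / len(energies)] * len(energies)
print(exact_f, variational_free_energy(p, energies, T),
      variational_free_energy(uniform, energies, T))
```

Since F_q - F = T·KL(q‖p) ≥ 0, minimizing F_q pushes the model distribution q toward the Boltzmann distribution p, which is what the first stage aims for.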

To facilitate further development, we also provide the trained models and logged data for various calculations in the paper, located in the data directory.

To cite

arxiv

Owner: FermiFlow (ab-initio study of fermions at finite temperature)