A port of muP to JAX/Haiku

Overview

MUP for Haiku

This is a (very preliminary) port of Yang and Hu et al.'s μP repo to Haiku and JAX. It's not feature complete, and I'm very open to suggestions on improving the usability.

Installation

pip install haiku-mup

Learning rate demo

These plots show the evolution of the optimal learning rate for a 3-hidden-layer MLP on MNIST, trained for 10 epochs (5 trials per lr/width combination).

With standard parameterization, the learning rate optimum (w.r.t. training loss) continues changing as the width increases, but μP keeps it approximately fixed:

Here's the same kind of plot for 3-layer transformers on the Penn Treebank, this time showing validation loss instead of training loss, with the number of heads and the embedding dimension scaled simultaneously:

Note that the optima have the same value for n_embd=80. That's because the other hyperparameters were tuned using an SP model with that width, so this shouldn't be biased in favor of μP.

Usage

from functools import partial

import jax
import jax.numpy as jnp
import haiku as hk
import optax

from haiku_mup import apply_mup, Mup, Readout

class MyModel(hk.Module):
    def __init__(self, width, n_classes=10):
        super().__init__(name='model')
        self.width = width
        self.n_classes = n_classes

    def __call__(self, x):
        x = hk.Linear(self.width)(x)
        x = jax.nn.relu(x)
        return Readout(self.n_classes)(x) # 1. Replace output layer with Readout layer

def fn(x, width=100):
    with apply_mup(): # 2. Modify parameter creation with apply_mup()
        return MyModel(width)(x)

mup = Mup()

init_input = jnp.zeros(123)
base_model = hk.transform(partial(fn, width=1))

with mup.init_base(): # 3. Use this context manager when initializing the base model
    base_model.init(jax.random.PRNGKey(0), init_input)

model = hk.transform(fn)

with mup.init_target(): # 4. Use this context manager when initializing the target model
    params = model.init(jax.random.PRNGKey(0), init_input)

model = mup.wrap_model(model) # 5. Modify your model with Mup

optimizer = optax.adam(3e-4)
optimizer = mup.wrap_optimizer(optimizer, adam=True) # 6. Use wrap_optimizer to get layer specific learning rates

# Now the model can be trained as normal
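
"Trained as normal" here means an ordinary JAX update step. The block below is a minimal sketch and not part of haiku-mup: `make_train_step` is a hypothetical helper, and it assumes that the wrapped model still exposes `apply(params, rng, x)` and that the wrapped optimizer follows the optax `init`/`update` convention. It also assumes batched inputs of shape `(batch, features)` and integer class labels of shape `(batch,)`.

```python
import jax
import jax.numpy as jnp


def make_train_step(model, optimizer):
    """Hypothetical helper: build a jitted update step.

    `model` needs an `apply(params, rng, x)` method and `optimizer` an
    optax-style `update(grads, opt_state, params)` method (both assumed).
    """

    def loss_fn(params, rng, x, y):
        logits = model.apply(params, rng, x)
        log_probs = jax.nn.log_softmax(logits)
        # Mean cross-entropy against integer class labels.
        picked = jnp.take_along_axis(log_probs, y[:, None], axis=-1)
        return -jnp.mean(picked)

    @jax.jit
    def train_step(params, opt_state, rng, x, y):
        loss, grads = jax.value_and_grad(loss_fn)(params, rng, x, y)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        # Same effect as optax.apply_updates(params, updates).
        params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
        return params, opt_state, loss

    return train_step


# Usage, given `model`, `optimizer` and `params` from the steps above:
#   train_step = make_train_step(model, optimizer)
#   opt_state = optimizer.init(params)
#   params, opt_state, loss = train_step(params, opt_state, rng, x_batch, y_batch)
```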

Summary

  1. Replace output layers with Readout layers
  2. Modify parameter creation with the apply_mup() context manager
  3. Initialize a base model inside a Mup.init_base() context
  4. Initialize the target model inside a Mup.init_target() context
  5. Wrap the model with Mup.wrap_model
  6. Wrap optimizer with Mup.wrap_optimizer
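
For some intuition on why steps 3, 4 and 6 need both a base and a target model, here's a rough sketch of one way width-dependent learning rates can be derived by comparing parameter shapes. It is not haiku-mup's actual implementation: `infinite_dims`, `adam_lr_multiplier` and the shape dictionaries are hypothetical, and the rule shown (shrink the Adam learning rate of hidden weight matrices by the fan-in width ratio, leave everything else alone) is one reading of the μP paper. Weights are assumed to use Haiku's `(fan_in, fan_out)` layout, and the sketch ignores any output scaling a readout layer would apply.

```python
def infinite_dims(base_shape, target_shape):
    """Indices of the dimensions whose size differs between base and target."""
    return [i for i, (b, t) in enumerate(zip(base_shape, target_shape)) if b != t]


def adam_lr_multiplier(base_shape, target_shape):
    """Hypothetical per-parameter Adam learning-rate multiplier."""
    dims = infinite_dims(base_shape, target_shape)
    if len(dims) == 2:
        # "Matrix-like" hidden weight: both dimensions grow with width,
        # so the learning rate shrinks by the fan-in width ratio.
        fan_in_axis = dims[0]
        return base_shape[fan_in_axis] / target_shape[fan_in_axis]
    # Input weights, readout weights and biases keep the base learning rate.
    return 1.0


# Shapes of a hypothetical two-hidden-layer MLP at base width 1 and target width 100.
base_shapes = {
    "linear/w": (123, 1), "linear/b": (1,),
    "linear_1/w": (1, 1), "linear_1/b": (1,),
    "readout/w": (1, 10), "readout/b": (10,),
}
target_shapes = {
    "linear/w": (123, 100), "linear/b": (100,),
    "linear_1/w": (100, 100), "linear_1/b": (100,),
    "readout/w": (100, 10), "readout/b": (10,),
}

lr_multipliers = {
    name: adam_lr_multiplier(base_shapes[name], target_shapes[name])
    for name in base_shapes
}
```

In this toy setup only `linear_1/w` changes in both dimensions, so it is the only parameter whose learning rate is scaled down; every other entry keeps a multiplier of 1.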

Shared Input/Output embeddings

If you want to use the input embedding matrix as the output layer's weight matrix, make the following two replacements:

# old: embedding_layer = hk.Embed(*args, **kwargs)
# new:
embedding_layer = haiku_mup.SharedEmbed(*args, **kwargs)
input_embeds = embedding_layer(x)

# old: output = hk.Linear(n_classes)(x)
# new:
output = haiku_mup.SharedReadout()(embedding_layer.get_weights(), x)
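
For readers unfamiliar with weight tying, the idea can be sketched in plain JAX, independently of haiku-mup. All names below are hypothetical, and dividing the logits by `width_mult` (target width over base width) is an assumption about what a μP readout does, not a description of `SharedReadout`'s internals.

```python
import jax
import jax.numpy as jnp

vocab_size, d_model = 50, 16

# One matrix serves as both the input embedding table and the output projection.
embedding = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, d_model))

token_ids = jnp.array([3, 7, 42])
input_embeds = embedding[token_ids]  # embedding lookup, shape (3, d_model)

hidden = input_embeds  # stand-in for the output of the model body
width_mult = 1.0  # assumed ratio target_width / base_width

# Tied readout: project back onto the vocabulary with the transposed table.
logits = hidden @ embedding.T / width_mult  # shape (3, vocab_size)
```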