SmallInitEmb - LayerNorm(SmallInit(Embedding)) in a Transformer to improve convergence

Overview

SmallInitEmb

LayerNorm(SmallInit(Embedding)) in a Transformer

I find that when training a transformer, the embedding matrix moves slowly, hence it's difficult for the model to jump out of the initial noisy embedding.

(initial embedding)
[[-0.0073  0.0062 -0.0261 ...  0.0086  0.0107 -0.008 ] ... ]
 (after 1 step, the directions of the embedding vectors are not moved much because the numbers change by ~LR = ~4e-4)
[[-0.0069  0.0066 -0.0265 ...  0.009   0.0111 -0.0084] ... ]

So I propose initializing the embedding matrix to tiny values, and put another LayerNorm after it (before all the SA & FFN layers):

if isinstance(module, (nn.Embedding)):
    nn.init.uniform_(module.weight, a=-1e-4, b=1e-4) # SmallInit(Emb)
...
if self.config.USE_SMALL_EMB and self.layer_id == 0:
    x = self.lnPre(x) # LN(SmallInit(Emb))
x = x + self.att(self.ln1(x))
x = x + self.ffn(self.ln2(x))

And then you get improved convergence (especially for BPE models) because the model can quickly jump out of the tiny initial embedding (small changes after 1 step -> significant changes of directions -> significant changes after LayerNorm).

Loss curve comparison: https://wandb.ai/blinkdl/SmallEmbTest

(the gap between LayerNorm(SmallEmb)) and baseline persists after more training)

Moreover, you can directly train PostLN models without warmup with SmallInit(Emb)

if isinstance(module, (nn.Embedding)):
    nn.init.uniform_(module.weight, a=-1e-4, b=1e-4) # SmallInit(Emb)
...
x = self.ln1(x) # this plays the same role as the lnPre in the above PreLN code
x = x + self.att(x)
x = self.ln2(x)
x = x + self.ffn(x)
(note you shall have another LN after the final ffn)
Owner
PENG Bo
http://zhihu.com/people/bopengbopeng
PENG Bo
MapReader: A computer vision pipeline for the semantic exploration of maps at scale

MapReader A computer vision pipeline for the semantic exploration of maps at scale MapReader is an end-to-end computer vision (CV) pipeline designed b

Living with Machines 25 Dec 26, 2022
Source code for the paper: Variance-Aware Machine Translation Test Sets (NeurIPS 2021 Datasets and Benchmarks Track)

Variance-Aware-MT-Test-Sets Variance-Aware Machine Translation Test Sets License See LICENSE. We follow the data licensing plan as the same as the WMT

NLP2CT Lab, University of Macau 5 Dec 21, 2021
Fibonacci Method Gradient Descent

An implementation of the Fibonacci method for gradient descent, featuring a TKinter GUI for inputting the function / parameters to be examined and a matplotlib plot of the function and results.

Emma 1 Jan 28, 2022
Mall-Customers-Segmentation - Customer Segmentation Using K-Means Clustering

Overview Customer Segmentation is one the most important applications of unsupervised learning. Using clustering techniques, companies can identify th

NelakurthiSudheer 2 Jan 03, 2022
Official code for Spoken ObjectNet: A Bias-Controlled Spoken Caption Dataset

Official code for our Interspeech 2021 - Spoken ObjectNet: A Bias-Controlled Spoken Caption Dataset [1]*. Visually-grounded spoken language datasets c

Ian Palmer 3 Jan 26, 2022
A pytorch-version implementation codes of paper: "BSN++: Complementary Boundary Regressor with Scale-Balanced Relation Modeling for Temporal Action Proposal Generation"

BSN++: Complementary Boundary Regressor with Scale-Balanced Relation Modeling for Temporal Action Proposal Generation A pytorch-version implementation

11 Oct 08, 2022
TensorFlow Implementation of Unsupervised Cross-Domain Image Generation

Domain Transfer Network (DTN) TensorFlow implementation of Unsupervised Cross-Domain Image Generation. Requirements Python 2.7 TensorFlow 0.12 Pickle

Yunjey Choi 865 Nov 17, 2022
PyTorch Autoencoders - Implementing a Variational Autoencoder (VAE) Series in Pytorch.

PyTorch Autoencoders Implementing a Variational Autoencoder (VAE) Series in Pytorch. Inspired by this repository Model List check model paper conferen

Subin An 8 Nov 21, 2022
This repo is about implementing different approaches of pose estimation and also is a sub-task of the smart hospital bed project :smile:

Pose-Estimation This repo is a sub-task of the smart hospital bed project which is about implementing the task of pose estimation 😄 Many thanks to th

Max 11 Oct 17, 2022
Offical implementation for "Trash or Treasure? An Interactive Dual-Stream Strategy for Single Image Reflection Separation".

Trash or Treasure? An Interactive Dual-Stream Strategy for Single Image Reflection Separation (NeurIPS 2021) by Qiming Hu, Xiaojie Guo. Dependencies P

Qiming Hu 31 Dec 20, 2022
PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)

PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)

Yonglong Tian 2.2k Jan 08, 2023
A high-performance anchor-free YOLO. Exceeding yolov3~v5 with ONNX, TensorRT, NCNN, and Openvino supported.

YOLOX is an anchor-free version of YOLO, with a simpler design but better performance! It aims to bridge the gap between research and industrial communities. For more details, please refer to our rep

7.7k Jan 06, 2023
MVFNet: Multi-View Fusion Network for Efficient Video Recognition (AAAI 2021)

MVFNet: Multi-View Fusion Network for Efficient Video Recognition (AAAI 2021) Overview We release the code of the MVFNet (Multi-View Fusion Network).

2 Jan 29, 2022
CAUSE: Causality from AttribUtions on Sequence of Events

CAUSE: Causality from AttribUtions on Sequence of Events

Wei Zhang 21 Dec 01, 2022
An AI Assistant More Than a Toolkit

tymon An AI Assistant More Than a Toolkit The reason for creating framework tymon is simple. making AI more like an assistant, helping us to complete

TymonXie 46 Oct 24, 2022
Airbus Ship Detection Challenge

Airbus Ship Detection Challenge This is an open solution to the Airbus Ship Detection Challenge. Our goals We are building entirely open solution to t

minerva.ml 55 Nov 29, 2022
FairyTailor: Multimodal Generative Framework for Storytelling

FairyTailor: Multimodal Generative Framework for Storytelling

Eden Bens 172 Dec 30, 2022
Tracking Pipeline helps you to solve the tracking problem more easily

Tracking_Pipeline Tracking_Pipeline helps you to solve the tracking problem more easily I integrate detection algorithms like: Yolov5, Yolov4, YoloX,

VNOpenAI 32 Dec 21, 2022
NHL 94 AI contests

nhl94-ai The end goals of this project is to: Train Models that play NHL 94 Support AI vs AI contests in NHL 94 Provide an improved AI opponent for NH

Mathieu Poliquin 2 Dec 06, 2021
Code for "Neural Parts: Learning Expressive 3D Shape Abstractions with Invertible Neural Networks", CVPR 2021

Neural Parts: Learning Expressive 3D Shape Abstractions with Invertible Neural Networks This repository contains the code that accompanies our CVPR 20

Despoina Paschalidou 161 Dec 20, 2022