Efficient Householder transformation in PyTorch

PyPiVersion PythonVersion PyPiDownloads License License: CC BY 4.0

This repository implements the Householder transformation algorithm for calculating orthogonal matrices and Stiefel frames. The algorithm is implemented as a Python package with differentiable bindings to PyTorch. In particular, the package provides an enhanced drop-in replacement for the torch.orgqr function.

Overview

APIs for orthogonal transformations have been around since LAPACK; however, their support in deep learning frameworks is lacking. Recently, orthogonal constraints have become popular in deep learning as a way to regularize models and improve training dynamics [1, 2], and hence the need to backpropagate through orthogonal transformations arose.

PyTorch 1.7 implements the matrix exponential function torch.matrix_exp, which can be repurposed to perform an orthogonal transformation when the input matrix is skew-symmetric. This is the baseline we use in the Speed and Numerical Precision evaluations.
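As a minimal sketch of this baseline (the variable names are illustrative and not taken from the package), the snippet below builds a skew-symmetric matrix and checks that its matrix exponential is orthogonal up to floating-point error:

```python
import torch

torch.manual_seed(0)
d = 4
a = torch.randn(d, d, dtype=torch.float64)

# A - A^T is skew-symmetric by construction: skew.T == -skew.
skew = a - a.T

# The exponential of a skew-symmetric matrix is an orthogonal matrix.
q = torch.matrix_exp(skew)

# Largest deviation of Q^T Q from the identity.
err = (q.T @ q - torch.eye(d, dtype=torch.float64)).abs().max().item()
print(err)
```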

Compared to torch.matrix_exp, the Householder transformation implemented in this package has the following advantages:

  • Orders of magnitude lower memory footprint
  • Ability to transform non-square matrices (Stiefel frames)
  • A significant speed-up for non-square matrices
  • Better numerical precision for all matrix and batch sizes

Usage

Installation

pip install torch-householder

API

The Householder transformation takes a matrix of Householder reflector parameters of shape d x r with d >= r > 0 (referred to as a 'thin' matrix from now on) and produces an orthogonal matrix of the same shape.

torch_householder_orgqr(param) is the recommended API in the deep learning context. The other arguments of this function are provided for compatibility with the torch.orgqr interface.

Unlike torch.orgqr, torch_householder_orgqr supports:

  • Both CPU and GPU devices
  • Backpropagation through arguments
  • Batched inputs

The parameter param is a matrix of size d x r, or a batch of matrices of size b x d x r, of Householder reflector parameters. The LAPACK convention is to structure the matrix of parameters with ones on the main diagonal, zeros above it, and the free parameters below it (the layout illustrated by the figure in the original README).

Given a matrix param of size d x r, here is a simple way to construct a valid matrix of Householder reflectors parameters from it:

hh = param.tril(diagonal=-1) + torch.eye(d, r)

This result can be used as an argument to torch_householder_orgqr.
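The same construction carries over to batched inputs, since tril acts on the last two dimensions and torch.eye(d, r) broadcasts over the batch dimension. A small sketch, with illustrative sizes B, D, R that are not taken from the original text:

```python
import torch

torch.manual_seed(0)
B, D, R = 3, 4, 2

# A batch of unconstrained parameters of shape b x d x r.
param = torch.randn(B, D, R)

# Keep only the strictly lower-triangular part of every matrix in the batch
# and put ones on the main diagonal; torch.eye(D, R) broadcasts over B.
hh = param.tril(diagonal=-1) + torch.eye(D, R)

print(hh.shape)  # torch.Size([3, 4, 2])
```

The resulting hh has the b x d x r layout described above and can be passed to torch_householder_orgqr in the same way as a single matrix.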

Example

import torch
from torch_householder import torch_householder_orgqr

D, R = 4, 2
param = torch.randn(D, R)
hh = param.tril(diagonal=-1) + torch.eye(D, R)

mat = torch_householder_orgqr(hh)

print(mat)              # a 4x2 Stiefel frame
print(mat.T.mm(mat))    # a 2x2 identity matrix

Output:

tensor([[ 0.4141, -0.3049],
        [ 0.2262,  0.5306],
        [-0.5587,  0.6074],
        [ 0.6821,  0.5066]])
tensor([[ 1.0000e+00, -2.9802e-08],
        [-2.9802e-08,  1.0000e+00]])
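For intuition about what such a transformation computes, here is a naive pure-PyTorch reference that multiplies explicit reflections H(v) = I - 2 v v^T / ||v||^2, one per column of hh, and keeps the first r columns of the product. This is only a sketch under that assumed convention: the helper name householder_reference is hypothetical, it is not the package's implementation, and it makes no claim about matching the package's output sign or ordering conventions.

```python
import torch


def householder_reference(hh: torch.Tensor) -> torch.Tensor:
    """Naive product of Householder reflections for a single d x r matrix."""
    d, r = hh.shape
    eye = torch.eye(d, dtype=hh.dtype)
    q = eye.clone()
    for i in range(r):
        v = hh[:, i:i + 1]                        # i-th reflector, shape (d, 1)
        h = eye - 2.0 * (v @ v.T) / (v.T @ v)     # reflection H(v)
        q = q @ h                                 # accumulate H_1 H_2 ... H_r
    return q[:, :r]                               # thin d x r result


torch.manual_seed(0)
d, r = 6, 3
param = torch.randn(d, r, dtype=torch.float64)
hh = param.tril(diagonal=-1) + torch.eye(d, r, dtype=torch.float64)

q = householder_reference(hh)
err = (q.T @ q - torch.eye(r, dtype=torch.float64)).abs().max().item()
print(q.shape, err)
```

Each reflection is orthogonal, so their product is orthogonal and its first r columns are orthonormal. The explicit d x d products make this sketch far slower than a dedicated implementation.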

Speed

Given a tuple of b (batch size), d (matrix height), and r (matrix width), we generate a random batch of orthogonal parameters of the given shape and perform a fixed number of orthogonal transformations with both the torch.matrix_exp and torch_householder_orgqr functions. We then associate each such tuple with the ratio of the run times of the two functions.

We perform a sweep of matrix dimensions d and r, starting with 1 and doubling each time until reaching 32768. The batch dimension is swept similarly until reaching the maximum size of 2048. The sweeps were performed on a single NVIDIA P-100 GPU with 16 GB of RAM using the code from the benchmark:

Speed chart

Since the ORGQR function's convention assumes only thin matrices with d >= r > 0, we skip the evaluation of fat matrices altogether.
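The sweep described above can be enumerated as follows. This is a sketch of the configuration grid only: the helper powers_of_two is hypothetical, and the actual benchmark code may order the configurations differently or skip those that do not fit in memory.

```python
def powers_of_two(limit):
    """Yield 1, 2, 4, ... up to and including limit."""
    value = 1
    while value <= limit:
        yield value
        value *= 2


# (b, d, r) tuples: batch size up to 2048, matrix dimensions up to 32768,
# keeping only thin or square matrices (d >= r).
configs = [
    (b, d, r)
    for b in powers_of_two(2048)
    for d in powers_of_two(32768)
    for r in powers_of_two(32768)
    if d >= r
]

print(len(configs))  # 12 batch sizes * 136 (d, r) pairs = 1632
```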

torch.matrix_exp only deals with square matrices; hence, to parameterize a thin matrix with d > r, we transform a square skew-symmetric matrix of size d x d and then take a d x r minor from the result. This makes torch.matrix_exp especially inefficient for parameterizing Stiefel frames and gives the Householder transformation major speed gains.
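A minimal sketch of this baseline for thin matrices is given below. The helper name stiefel_via_matrix_exp is hypothetical, and taking the first r columns is one assumed way of choosing the d x r minor; the benchmark code may differ in detail.

```python
import torch


def stiefel_via_matrix_exp(param: torch.Tensor, r: int) -> torch.Tensor:
    """Map an unconstrained d x d matrix to a d x r matrix with orthonormal columns."""
    lower = param.tril(diagonal=-1)
    skew = lower - lower.T                 # d x d skew-symmetric matrix
    return torch.matrix_exp(skew)[:, :r]   # full d x d exponential, then slice


torch.manual_seed(0)
d, r = 8, 2
param = torch.randn(d, d, dtype=torch.float64)
frame = stiefel_via_matrix_exp(param, r)

err = (frame.T @ frame - torch.eye(r, dtype=torch.float64)).abs().max().item()
print(frame.shape, err)
```

Note that the full d x d exponential is computed even though only r columns are kept, which is the source of the inefficiency described above.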

Numerical Precision

The numerical precision of an orthogonal transformation is evaluated using a synthetic test. Given a matrix size d x r, we generate random inputs and perform an orthogonal transformation with the tested function. Since the output M of size d x r is expected to be a Stiefel frame, we calculate the transformation error using the formula below. This calculation is repeated at least 5 times for each matrix size, and the results are averaged.
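One way to measure how far M is from a Stiefel frame is the deviation of M^T M from the identity. The sketch below uses the largest absolute entry of M^T M - I as the error; this particular norm is an assumption, and the formula used for the chart may differ. The helper name orthogonality_error is hypothetical. The input is the 4x2 matrix printed in the Example section, so the result is limited by the four printed decimal digits.

```python
import torch


def orthogonality_error(m: torch.Tensor) -> float:
    """Largest absolute entry of M^T M - I for a d x r matrix M (assumed metric)."""
    r = m.shape[-1]
    gram = m.T @ m
    return (gram - torch.eye(r, dtype=m.dtype)).abs().max().item()


# The Stiefel frame printed in the Example section, rounded to 4 decimals.
mat = torch.tensor([[ 0.4141, -0.3049],
                    [ 0.2262,  0.5306],
                    [-0.5587,  0.6074],
                    [ 0.6821,  0.5066]], dtype=torch.float64)

err = orthogonality_error(mat)
print(err)
```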

We re-use the sweep used for benchmarking speed, compute errors of both functions using the formula above, and report their ratio in the following chart:

Error chart

Conclusions

The speed chart suggests that the Householder transformation is especially efficient either when the matrix width r is smaller than the matrix height d or when the batch size is comparable to the matrix dimensions. In these cases, the Householder transformation provides a speed-up of up to a few orders of magnitude. However, for the rest of the square matrix cases, torch.matrix_exp appears to be faster.

Another benefit of torch_householder_orgqr is its memory usage, which is much lower than that of torch.matrix_exp. This property allows us to transform either much larger matrices or larger batches with torch_householder_orgqr. For example, in the sub-figure corresponding to "batch size: 1", the largest matrix transformed with torch_householder_orgqr has the size 16384 x 16384, whereas the largest matrix transformed with torch.matrix_exp is only 4096 x 4096.

As can be seen from the precision chart, the Householder transformation satisfies the orthogonality constraints orders of magnitude more accurately in all tested configurations.

Citation

To cite this repository, use the following BibTeX:

@misc{obukhov2021torchhouseholder,
    author = {Anton Obukhov},
    year = 2021,
    title = {Efficient Householder transformation in PyTorch},
    url = {www.github.com/toshas/torch-householder}
}
Comments
  • How are the gradients implemented for non-full rank matrices?

    It is not clear how to implement the "gradient" (adjoint) of the QR decomposition for a matrix that is not full rank (e.g. the zero matrix). How does this package handle this?

    Also, if this is a drop-in replacement for the QR decomposition implemented in PyTorch and it works better, why not make a PR to core PyTorch with this? Where does the speed-up come from vs LAPACK / MAGMA?

    opened by lezcano 9
  • error when installing torch-householder via pip

    Hello, I am having trouble installing the package. I don't know if there is a dependency problem. Could the installation problem be due to the CUDA version? I am using Ubuntu 18.04 and CUDA 11.3. Is there any difference from installing with CUDA 10.2?

    Collecting torch-householder
      Using cached torch_householder-1.0.1.tar.gz (457 kB)
      Installing build dependencies ... done
      Getting requirements to build wheel ... error
      ERROR: Command errored out with exit status 1:
       command: /home/ubuntu/anaconda3/envs/nsd/bin/python3.9 /home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py get_requires_for_build_wheel /tmp/tmp_p_1kd61
           cwd: /tmp/pip-install-ot451iqp/torch-householder_bee491246ef740cbb618d42e92070ff3
      Complete output (21 lines):
      Traceback (most recent call last):
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 363, in <module>
          main()
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 345, in main
          json_out['return_val'] = hook(**hook_input['kwargs'])
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 130, in get_requires_for_build_wheel
          return hook(config_settings)
        File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 338, in get_requires_for_build_wheel
          return self._get_build_requires(config_settings, requirements=['wheel'])
        File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 320, in _get_build_requires
          self.run_setup()
        File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 335, in run_setup
          exec(code, locals())
        File "<string>", line 4, in <module>
        File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/torch/__init__.py", line 191, in <module>
          _load_global_deps()
        File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/torch/__init__.py", line 153, in _load_global_deps
          ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/ctypes/__init__.py", line 374, in __init__
          self._handle = _dlopen(self._name, mode)
      OSError: /tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/torch/lib/../../nvidia/cublas/lib/libcublas.so.11: symbol cublasLtGetStatusString version libcublasLt.so.11 not defined in file libcublasLt.so.11 with link time reference
      ----------------------------------------
    WARNING: Discarding https://files.pythonhosted.org/packages/f3/7d/a87d4ea6c11f23d237fc81c094a6c18909486fdb9914599479cbeb5d089f/torch_householder-1.0.1.tar.gz#sha256=9a4b240c68947491c4e96a78771497562650f9a555001e062a0969fce206f786 (from https://pypi.org/simple/torch-householder/) (requires-python:>=3.6). Command errored out with exit status 1: /home/ubuntu/anaconda3/envs/nsd/bin/python3.9 /home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py get_requires_for_build_wheel /tmp/tmp_p_1kd61 Check the logs for full command output.
      Using cached torch_householder-1.0.0.tar.gz (177 kB)
      Installing build dependencies ... done
      Getting requirements to build wheel ... error
      ERROR: Command errored out with exit status 1:
       command: /home/ubuntu/anaconda3/envs/nsd/bin/python3.9 /home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py get_requires_for_build_wheel /tmp/tmpiqcotimp
           cwd: /tmp/pip-install-ot451iqp/torch-householder_bb93639781d44e6799b67c9f4f83fae9
      Complete output (21 lines):
      Traceback (most recent call last):
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 363, in <module>
          main()
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 345, in main
          json_out['return_val'] = hook(**hook_input['kwargs'])
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 130, in get_requires_for_build_wheel
          return hook(config_settings)
        File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 338, in get_requires_for_build_wheel
          return self._get_build_requires(config_settings, requirements=['wheel'])
        File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 320, in _get_build_requires
          self.run_setup()
        File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 335, in run_setup
          exec(code, locals())
        File "<string>", line 4, in <module>
        File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/torch/__init__.py", line 191, in <module>
          _load_global_deps()
        File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/torch/__init__.py", line 153, in _load_global_deps
          ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/ctypes/__init__.py", line 374, in __init__
          self._handle = _dlopen(self._name, mode)
      OSError: /tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/torch/lib/../../nvidia/cublas/lib/libcublas.so.11: symbol cublasLtGetStatusString version libcublasLt.so.11 not defined in file libcublasLt.so.11 with link time reference
      ----------------------------------------
    WARNING: Discarding https://files.pythonhosted.org/packages/c2/9c/7af0f1414e24c09ddc67439364c70c7e894c65ed83f94aa82a0ff3308673/torch_householder-1.0.0.tar.gz#sha256=e9f06c29685a6bcbc360af5c8cbae9227d9990e3e9acc744b0ea15654ab782c0 (from https://pypi.org/simple/torch-householder/) (requires-python:>=3.6). Command errored out with exit status 1: /home/ubuntu/anaconda3/envs/nsd/bin/python3.9 /home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py get_requires_for_build_wheel /tmp/tmpiqcotimp Check the logs for full command output.
    ERROR: Could not find a version that satisfies the requirement torch-householder (from versions: 1.0.0, 1.0.1)
    ERROR: No matching distribution found for torch-householder
    
    opened by jeongwhanchoi 2
  • [Noob] Can I use this somehow as a drop-in replacement for matrix_exp?

    I have some code that uses torch.matrix_exp and would be very happy to speed it up. Is this possible using your library? I sort of lost my way trying to figure it out from the benchmark code. Many thanks in advance!

    opened by generalizedeigenvector 1
  • The parametrisation does not seem to be surjective.

    When having a look at the implementation and looking at the differences with torch.householder_product, I found a weird thing.

    At the moment, when called with a tensor of the form hh = param.tril(diagonal=-1) + torch.eye(d, r) (as per the documentation) we are passing nk - k(k+1)/2 parameters. This is correct, as it is the number of parameters necessary to parametrise the orthogonal matrices (i.e. it is the dimension of the Stiefel manifold).

    Now, when this matrix is passed to torch_householder_orgqr, its columns are normalised: https://github.com/toshas/torch-householder/blob/afe72ebf9a01c257a5070a6d01b3f81be9e8efd6/torch_householder/householder.py#L67 This is another constraint. Each column is normalised, which removes another k degrees of freedom (k dimensions, to be precise). This means that the current implementation cannot represent all the possible orthogonal matrices.

    Do you know what is going on here?

    opened by lezcano 1
  • > \prod_{i=1}^k H(v_i)

    \prod_{i=1}^k H(v_i), is that right? Right.

    Hi toshas, your work is great, thanks! You confirmed that if we denote by H(v) = I_n - v v^T / \norm{v}^2 a Householder reflection, then torch_householder_orgqr() computes \prod_{i=1}^k H(v_i). I have two questions. First, I think it should be H(v) = I_n - 2 v v^T / \norm{v}^2. Second, does \prod_{i=1}^k H(v_i) mean H_1 H_2 … H_(n-1) or H_(n-1) H_(n-2) … H_1? If it is H_1 H_2 … H_(n-1), it refers to Q in the QR decomposition, right? But you said earlier that this is not a QR decomposition, so I feel a little confused. I hope to get a prompt reply, thank you.

    opened by xjzhao001 1
  • Slow performance compared to `torch.linalg.householder_product` in forward pass

    Problem

    I'm using an orthonormally constrained matrix in a learnable filterbank setting. Now I want to optimize the training, so I ran some profiling with torch, but I am getting strange results. I just want to double-check here whether I'm doing something wrong.

    Code

    I'm constructing the matrix during forward pass like this:

    def __init__(self, ..):
        [..]

        # Householder decomposition of the initial filters
        decomp, tau = torch.geqrf(filters)

        # assume that DCT is orthogonal; keep the reflectors below the diagonal
        filters = decomp.tril(diagonal=-1) + torch.eye(decomp.shape[0], decomp.shape[1])

        # register as a learnable parameter named 'filter_q'
        # (a different name than the filters() method below, to avoid a clash)
        self.filter_q = torch.nn.Parameter(filters, requires_grad=True)

    def filters(self):
        # rebuild the orthogonal matrix from the reflector parameters
        valid_coeffs = self.filter_q.tril(diagonal=-1)
        tau = 2. / (1 + torch.norm(valid_coeffs, dim=0) ** 2)
        return torch.linalg.householder_product(valid_coeffs, tau)
        #return torch_householder_orgqr(valid_coeffs, tau)
    

    Profiles

    All profiles were created with the PyTorch profiler, with one warm-up run and two trial runs:

    Profile torch householder_product (matrix 512x512 f32)

    • forward pass: ~823us
    • backward pass: ~790ms

    Marked forward pass and backward pass visible in light green:

    image

    Profile torch-householder (matrix 512x512 f32)

    • forward pass: ~240ms
    • backward pass: ~513ms

    image

    Questions

    I'm not an expert in torch and do not follow the development closely. There is an issue https://github.com/pytorch/pytorch/issues/50104 for integrating CUDA support into orgqr; could this cause the difference in time?

    • Why is the torch-householder library much slower in the forward pass?
    • Is this performance expected from AD of a matrix w.r.t. its Householder parameters, or am I doing something wrong here?
    • Why do the numbers add up to ~800ms again? This makes me suspect that my profiling is doing something wrong, but I couldn't find a cause.

    I'm also happy to share the traces with you, please just ping me then :)

    opened by bytesnake 6
  • [POLL] Should the package switch install-time to run-time native code compilation?

    Currently, the package compiles native code (C++) upon package installation. This saves a few seconds at run time, as the compilation does not happen when the user code starts. However, it hurts when the package is installed from a different environment or machine than the one the code will actually run on. This is the case in most cluster environments, where packages may be installed from a login node rather than from the actual machine with the GPU.

    Should compilation be rather performed at run time?

    👍 - Move compilation to run time 👎 - Keep as is

    opened by toshas 0
Releases(v1.0.1)
Owner
Anton Obukhov
CV+ML PhD student with industrial past