Efficient Householder transformation in PyTorch

License: CC BY 4.0

This repository implements the Householder transformation algorithm for calculating orthogonal matrices and Stiefel frames. The algorithm is implemented as a Python package with differentiable bindings to PyTorch. In particular, the package provides an enhanced drop-in replacement for the torch.orgqr function.

Overview

APIs for orthogonal transformations have been around since LAPACK; however, their support in deep learning frameworks is lacking. Recently, orthogonal constraints have become popular in deep learning as a way to regularize models and improve training dynamics [1, 2], and hence the need arose to backpropagate through orthogonal transformations.

PyTorch 1.7 implements the matrix exponential function torch.matrix_exp, which can be repurposed to perform the orthogonal transformation when the input matrix is skew-symmetric. This is the baseline we use in the Speed and Numerical Precision evaluations.

Compared to torch.matrix_exp, the Householder transformation implemented in this package has the following advantages:

  • Orders of magnitude lower memory footprint
  • Ability to transform non-square matrices (Stiefel frames)
  • A significant speed-up for non-square matrices
  • Better numerical precision for all matrix and batch sizes

Usage

Installation

pip install torch-householder

API

The Householder transformation takes a matrix of Householder reflector parameters of shape d x r with d >= r > 0 (referred to as a 'thin' matrix from now on) and produces an orthogonal matrix of the same shape (for d > r, a Stiefel frame with orthonormal columns).

torch_householder_orgqr(param) is the recommended API in the deep learning context. The other arguments of this function are provided for compatibility with the torch.orgqr interface.

Unlike torch.orgqr, torch_householder_orgqr supports:

  • Both CPU and GPU devices
  • Backpropagation through arguments
  • Batched inputs

The parameter param is a matrix of size d x r, or a batch of matrices of size b x d x r, of Householder reflector parameters. The LAPACK convention suggests structuring the matrix of parameters as shown in the figure on the right.

Given a matrix param of size d x r, here is a simple way to construct a valid matrix of Householder reflector parameters from it:

hh = param.tril(diagonal=-1) + torch.eye(d, r)

This result can be used as an argument to torch_householder_orgqr.
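The construction above also works for a batch, because torch.eye(d, r) broadcasts over the leading dimension. The sketch below illustrates this without installing the package: it uses PyTorch's built-in torch.linalg.householder_product as a stand-in for torch_householder_orgqr. The choice tau_i = 2 / ||v_i||^2 is an assumption taken from the issue discussion further down this page, not a documented guarantee of this package.

```python
import torch

torch.manual_seed(0)
b, d, r = 3, 5, 2
param = torch.randn(b, d, r, dtype=torch.float64)

# Unit diagonal, zeros above it, free parameters strictly below it.
# tril acts on the last two dimensions; torch.eye(d, r) broadcasts over the batch.
hh = param.tril(diagonal=-1) + torch.eye(d, r, dtype=torch.float64)

# Assumed convention: H_i = I - tau_i * v_i v_i^T with tau_i = 2 / ||v_i||^2,
# where v_i is the i-th column of hh.
tau = 2.0 / hh.pow(2).sum(dim=-2)               # shape (b, r)

# Stand-in for torch_householder_orgqr (not this package's implementation).
q = torch.linalg.householder_product(hh, tau)   # shape (b, d, r)

# Each r x r Gram matrix should be close to the identity.
gram = q.transpose(-2, -1) @ q
print(gram)
```

With the package installed, torch_householder_orgqr(hh) would presumably take the place of the stand-in call.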

Example

import torch
from torch_householder import torch_householder_orgqr

D, R = 4, 2
param = torch.randn(D, R)
hh = param.tril(diagonal=-1) + torch.eye(D, R)

mat = torch_householder_orgqr(hh)

print(mat)              # a 4x2 Stiefel frame
print(mat.T.mm(mat))    # a 2x2 identity matrix

Output:

tensor([[ 0.4141, -0.3049],
        [ 0.2262,  0.5306],
        [-0.5587,  0.6074],
        [ 0.6821,  0.5066]])
tensor([[ 1.0000e+00, -2.9802e-08],
        [-2.9802e-08,  1.0000e+00]])

Speed

Given a tuple of b (batch size), d (matrix height), and r (matrix width), we generate a random batch of orthogonal parameters of the given shape and perform a fixed number of orthogonal transformations with both torch.matrix_exp and torch_householder_orgqr. We then associate each such tuple with the ratio of the run times of the two functions.

We perform a sweep of matrix dimensions d and r, starting with 1 and doubling each time until reaching 32768. The batch dimension is swept similarly until reaching the maximum size of 2048. The sweeps were performed on a single NVIDIA P-100 GPU with 16 GB of RAM using the code from the benchmark:

Speed chart

Since the ORGQR function's convention assumes only thin matrices with d >= r > 0, we skip the evaluation of fat matrices altogether.
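The sweep described above can be enumerated in a few lines. This is a minimal sketch of the candidate grid only, with hypothetical names (powers_of_two, configs); the actual benchmark code may order or prune the configurations differently, for instance when a shape does not fit into GPU memory.

```python
MAX_DIM = 32768     # largest matrix height/width in the sweep
MAX_BATCH = 2048    # largest batch size in the sweep


def powers_of_two(limit):
    """Return 1, 2, 4, ... up to and including limit."""
    values, v = [], 1
    while v <= limit:
        values.append(v)
        v *= 2
    return values


dims = powers_of_two(MAX_DIM)
batches = powers_of_two(MAX_BATCH)

# Keep only thin shapes (d >= r); fat matrices are skipped.
configs = [(b, d, r) for b in batches for d in dims for r in dims if d >= r]
print(len(dims), len(batches), len(configs))
```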

The torch.matrix_exp function only deals with square matrices; hence, to parameterize a thin matrix with d > r, we transform a square skew-symmetric matrix of size d x d and then take a d x r submatrix of the result. This makes torch.matrix_exp especially inefficient for parameterizing Stiefel frames and gives the Householder transformation major speed gains.
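The baseline described above can be sketched as follows. Taking the first r columns is an assumption made for illustration; any r columns of an orthogonal matrix form a Stiefel frame. The parameter counts show why this route is wasteful for thin matrices: the square skew-symmetric matrix carries more free parameters than the frame needs.

```python
import torch

torch.manual_seed(0)
d, r = 6, 2

x = torch.randn(d, d, dtype=torch.float64)
skew = x - x.T                      # skew-symmetric: skew.T == -skew
q_square = torch.matrix_exp(skew)   # d x d orthogonal matrix
frame = q_square[:, :r]             # assumed choice: keep the first r columns

gram = frame.T @ frame              # should be close to the r x r identity

n_skew_params = d * (d - 1) // 2            # free entries of the d x d skew matrix
n_frame_params = d * r - r * (r + 1) // 2   # dimension of the d x r Stiefel manifold
print(n_skew_params, n_frame_params)
```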

Numerical Precision

The numerical precision of an orthogonal transformation is evaluated using a synthetic test. Given a matrix size d x r, we generate random inputs and perform orthogonal transformation with the tested function. Since the output M of size d x r is expected to be a Stiefel frame, we calculate transformation error using the formula below. This calculation is repeated for each matrix size at least 5 times, and the results are averaged.
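The formula itself did not survive in this copy of the text. As a rough illustration only, one common way to quantify how far a d x r matrix M is from a Stiefel frame is the mean absolute deviation of M^T M from the identity. The function name orthogonality_error and this particular metric are assumptions, not necessarily the formula used in the benchmark.

```python
import torch


def orthogonality_error(m: torch.Tensor) -> torch.Tensor:
    """Mean absolute deviation of M^T M from the identity (an assumed metric)."""
    r = m.shape[-1]
    eye = torch.eye(r, dtype=m.dtype, device=m.device)
    return (m.transpose(-2, -1) @ m - eye).abs().mean()


torch.manual_seed(0)
# Reference frame with orthonormal columns, obtained from a reduced QR decomposition.
q, _ = torch.linalg.qr(torch.randn(8, 3, dtype=torch.float64))

err_good = orthogonality_error(q).item()        # close to zero
err_bad = orthogonality_error(2.0 * q).item()   # scaled columns are no longer orthonormal
print(err_good, err_bad)
```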

We re-use the sweep used for benchmarking speed, compute errors of both functions using the formula above, and report their ratio in the following chart:

Error chart

Conclusions

The speed chart suggests that the Householder transformation is especially efficient when either the matrix width r is smaller than the matrix height d or the batch size is comparable to the matrix dimensions. In these cases, the Householder transformation provides a speed-up of up to a few orders of magnitude. However, for the remaining square-matrix cases, torch.matrix_exp appears to be faster.

Another benefit of torch_householder_orgqr is its memory usage, which is much lower than that of torch.matrix_exp. This property allows us to transform either much larger matrices or larger batches with torch_householder_orgqr. For example, in the sub-figure corresponding to "batch size: 1", the largest matrix transformed with torch_householder_orgqr has size 16384 x 16384, whereas the largest matrix transformed with torch.matrix_exp is only 4096 x 4096.

As can be seen from the precision chart, the Householder transformation satisfies the orthogonality constraints orders of magnitude more accurately in all tested configurations.

Citation

To cite this repository, use the following BibTeX:

@misc{obukhov2021torchhouseholder,
    author = {Anton Obukhov},
    year = 2021,
    title = {Efficient Householder transformation in PyTorch},
    url = {www.github.com/toshas/torch-householder}
}
Comments
  • How are the gradients implemented for non-full rank matrices?

    It is not clear how to implement the "gradient" (adjoint) of the QR decomposition for a matrix that is not full rank (e.g. the zero matrix). How does this package handle this?

    Also, if this is a drop-in replacement for the QR decomposition implemented in PyTorch and it works better, why not make a PR to core PyTorch with it? Where does the speed-up come from compared to LAPACK / MAGMA?

    opened by lezcano 9
  • error when installing torch-householder via pip

    Hello, I am having trouble installing the package. I don't know if there is a dependency problem. Could the installation be failing because of the CUDA version? I am using Ubuntu 18.04 and CUDA 11.3. Is there any difference from installing with CUDA 10.2?

    Collecting torch-householder
      Using cached torch_householder-1.0.1.tar.gz (457 kB)
      Installing build dependencies ... done
      Getting requirements to build wheel ... error
      ERROR: Command errored out with exit status 1:
       command: /home/ubuntu/anaconda3/envs/nsd/bin/python3.9 /home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py get_requires_for_build_wheel /tmp/tmp_p_1kd61
           cwd: /tmp/pip-install-ot451iqp/torch-householder_bee491246ef740cbb618d42e92070ff3
      Complete output (21 lines):
      Traceback (most recent call last):
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 363, in <module>
          main()
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 345, in main
          json_out['return_val'] = hook(**hook_input['kwargs'])
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 130, in get_requires_for_build_wheel
          return hook(config_settings)
        File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 338, in get_requires_for_build_wheel
          return self._get_build_requires(config_settings, requirements=['wheel'])
        File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 320, in _get_build_requires
          self.run_setup()
        File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 335, in run_setup
          exec(code, locals())
        File "<string>", line 4, in <module>
        File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/torch/__init__.py", line 191, in <module>
          _load_global_deps()
        File "/tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/torch/__init__.py", line 153, in _load_global_deps
          ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/ctypes/__init__.py", line 374, in __init__
          self._handle = _dlopen(self._name, mode)
      OSError: /tmp/pip-build-env-40k3d87l/overlay/lib/python3.9/site-packages/torch/lib/../../nvidia/cublas/lib/libcublas.so.11: symbol cublasLtGetStatusString version libcublasLt.so.11 not defined in file libcublasLt.so.11 with link time reference
      ----------------------------------------
    WARNING: Discarding https://files.pythonhosted.org/packages/f3/7d/a87d4ea6c11f23d237fc81c094a6c18909486fdb9914599479cbeb5d089f/torch_householder-1.0.1.tar.gz#sha256=9a4b240c68947491c4e96a78771497562650f9a555001e062a0969fce206f786 (from https://pypi.org/simple/torch-householder/) (requires-python:>=3.6). Command errored out with exit status 1: /home/ubuntu/anaconda3/envs/nsd/bin/python3.9 /home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py get_requires_for_build_wheel /tmp/tmp_p_1kd61 Check the logs for full command output.
      Using cached torch_householder-1.0.0.tar.gz (177 kB)
      Installing build dependencies ... done
      Getting requirements to build wheel ... error
      ERROR: Command errored out with exit status 1:
       command: /home/ubuntu/anaconda3/envs/nsd/bin/python3.9 /home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py get_requires_for_build_wheel /tmp/tmpiqcotimp
           cwd: /tmp/pip-install-ot451iqp/torch-householder_bb93639781d44e6799b67c9f4f83fae9
      Complete output (21 lines):
      Traceback (most recent call last):
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 363, in <module>
          main()
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 345, in main
          json_out['return_val'] = hook(**hook_input['kwargs'])
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py", line 130, in get_requires_for_build_wheel
          return hook(config_settings)
        File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 338, in get_requires_for_build_wheel
          return self._get_build_requires(config_settings, requirements=['wheel'])
        File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 320, in _get_build_requires
          self.run_setup()
        File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/setuptools/build_meta.py", line 335, in run_setup
          exec(code, locals())
        File "<string>", line 4, in <module>
        File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/torch/__init__.py", line 191, in <module>
          _load_global_deps()
        File "/tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/torch/__init__.py", line 153, in _load_global_deps
          ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
        File "/home/ubuntu/anaconda3/envs/nsd/lib/python3.9/ctypes/__init__.py", line 374, in __init__
          self._handle = _dlopen(self._name, mode)
      OSError: /tmp/pip-build-env-x9jwr6cn/overlay/lib/python3.9/site-packages/torch/lib/../../nvidia/cublas/lib/libcublas.so.11: symbol cublasLtGetStatusString version libcublasLt.so.11 not defined in file libcublasLt.so.11 with link time reference
      ----------------------------------------
    WARNING: Discarding https://files.pythonhosted.org/packages/c2/9c/7af0f1414e24c09ddc67439364c70c7e894c65ed83f94aa82a0ff3308673/torch_householder-1.0.0.tar.gz#sha256=e9f06c29685a6bcbc360af5c8cbae9227d9990e3e9acc744b0ea15654ab782c0 (from https://pypi.org/simple/torch-householder/) (requires-python:>=3.6). Command errored out with exit status 1: /home/ubuntu/anaconda3/envs/nsd/bin/python3.9 /home/ubuntu/anaconda3/envs/nsd/lib/python3.9/site-packages/pip/_vendor/pep517/in_process/_in_process.py get_requires_for_build_wheel /tmp/tmpiqcotimp Check the logs for full command output.
    ERROR: Could not find a version that satisfies the requirement torch-householder (from versions: 1.0.0, 1.0.1)
    ERROR: No matching distribution found for torch-householder
    
    opened by jeongwhanchoi 2
  • [Noob] Can I use this somehow as a drop-in replacement for matrix_exp?

    I have some code that uses torch.matrix_exp and would be very happy to speed it up. Is this possible using your library? I sort of lost my way trying to figure it out from the benchmark code. Many thanks in advance!

    opened by generalizedeigenvector 1
  • The parametrisation does not seem to be surjective.

    When having a look at the implementation and looking at the differences with torch.householder_product, I found a weird thing.

    At the moment, when called with a tensor of the form hh = param.tril(diagonal=-1) + torch.eye(d, r) (as per the documentation) we are passing nk - k(k+1)/2 parameters. This is correct, as it is the number of parameters necessary to parametrise the orthogonal matrices (i.e. it is the dimension of the Stiefel manifold).

    Now, when this matrix is passed to torch_householder_orgqr, its columns are normalised: https://github.com/toshas/torch-householder/blob/afe72ebf9a01c257a5070a6d01b3f81be9e8efd6/torch_householder/householder.py#L67 This is another constraint. Each column is normalised, which removes another k degrees of freedom (k dimensions, to be precise). This means that the current implementation cannot represent all possible orthogonal matrices.

    Do you know what is going on here?
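The parameter count quoted in this comment can be checked numerically. The sketch below only verifies that the number of strictly lower-triangular entries of an n x k matrix equals nk - k(k+1)/2; it takes no position on the surjectivity question itself. The helper name stiefel_dim is hypothetical.

```python
import torch


def stiefel_dim(n: int, k: int) -> int:
    """Dimension of the Stiefel manifold of n x k frames: nk - k(k+1)/2."""
    return n * k - k * (k + 1) // 2


n, k = 5, 3
# Ones strictly below the diagonal mark the free parameters of
# hh = param.tril(diagonal=-1) + torch.eye(n, k).
mask = torch.ones(n, k).tril(diagonal=-1)
n_free = int(mask.sum().item())
print(n_free, stiefel_dim(n, k))
```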

    opened by lezcano 1
  • > \prod_{i=1}^k H(v_i)

    \prod_{i=1}^k H(v_i), is that right? Right. Hi toshas, your work is great, thanks! You confirmed that, if we denote a Householder reflection by H(v) = I_n - v v^T / \norm{v}^2, then torch_householder_orgqr() computes \prod_{i=1}^k H(v_i). I have two questions. First, I think the reflection should be H(v) = I_n - 2 v v^T / \norm{v}^2. Second, does \prod_{i=1}^k H(v_i) mean H_1 H_2 … H_(n-1) or H_(n-1) H_(n-2) … H_1? If it is H_1 H_2 … H_(n-1), it corresponds to Q in the QR decomposition, right? But you said earlier that this is not a QR decomposition, so I feel a little confused. I hope to get your reply soon, thank you.
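Both points in this question can be checked against PyTorch's built-in torch.linalg.householder_product, which follows the LAPACK convention H_1 H_2 … H_k with H_i = I - tau_i v_i v_i^T. The sketch below builds the reflectors explicitly with the factor 2 and multiplies them left to right. It illustrates the built-in function only; whether torch_householder_orgqr uses the same ordering is not confirmed here.

```python
import torch


def reflector(v: torch.Tensor) -> torch.Tensor:
    """Householder reflection H(v) = I - 2 v v^T / ||v||^2."""
    n = v.shape[0]
    return torch.eye(n, dtype=v.dtype) - 2.0 * torch.outer(v, v) / v.dot(v)


torch.manual_seed(0)
n, k = 5, 3
a = torch.randn(n, k, dtype=torch.float64)
vs = a.tril(diagonal=-1) + torch.eye(n, k, dtype=torch.float64)  # columns are v_1 .. v_k

# Explicit left-to-right product H_1 H_2 ... H_k.
prod = torch.eye(n, dtype=torch.float64)
for i in range(k):
    prod = prod @ reflector(vs[:, i])

# Built-in product with tau_i = 2 / ||v_i||^2; returns the first k columns.
tau = 2.0 / vs.pow(2).sum(dim=0)
ref = torch.linalg.householder_product(vs, tau)

print((prod[:, :k] - ref).abs().max().item())
```

With the factor 2, each H(v) is symmetric and orthogonal and maps v to -v; without it, I - v v^T / ||v||^2 is an orthogonal projection rather than a reflection.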

    opened by xjzhao001 1
  • Slow performance compared to `torch.linalg.householder_product` in forward pass

    Problem

    I'm using an orthonormally constrained matrix in a learnable filterbank setting. I now want to optimize the training and ran some profiling with torch, but I am getting strange results. I just want to double-check here whether I'm doing something wrong.

    Code

    I'm constructing the matrix during forward pass like this:

    def __init__(self, ..):
           [..]
    
            # householder decomposition
            decomp, tau = torch.geqrf(filters)
    
            # assume that DCT is orthogonal
            filters = decomp.tril(diagonal=-1) + torch.eye(decomp.shape[0], decomp.shape[1])
    
            # register everything as parameter and set gradient flags
            self.filters = torch.nn.Parameter(filters, requires_grad=True)
            self.register_parameter('filter_q', self.filters)
    
    def filters(self):
            valid_coeffs = self.filters.tril(diagonal=-1)
            tau = 2. / (1 + torch.norm(valid_coeffs, dim=0) ** 2)
            return torch.linalg.householder_product(valid_coeffs, tau)
            #return torch_householder_orgqr(valid_coeffs, tau)
    

    Profiles

    All profiles are created with the pytorch profiler with warmup of one and two trial runs:

    Profile torch householder_product (matrix 512x512 f32)

    • forward pass: ~823us
    • backward pass: ~790ms

    Marked forward pass and backward pass visible in light green:

    image

    Profile torch-householder (matrix 512x512 f32)

    • forward pass: ~240ms
    • backward pass: ~513ms

    image

    Questions

    I'm not an expert in torch and do not follow the development closely. There is an issue, https://github.com/pytorch/pytorch/issues/50104, about adding CUDA support to orgqr; could this cause the difference in time?

    • Why is the torch-householder library much slower in the forward pass?
    • Is this performance expected from AD of a matrix w.r.t. its Householder parameters, or am I doing something wrong here?
    • Why do the numbers add up to ~800ms again? This makes me suspect that my profiling is doing something wrong, but I couldn't find a cause.

    I'm also happy to share the traces with you, just ping me :)

    opened by bytesnake 6
  • [POLL] Should the package switch from install-time to run-time native code compilation?

    Currently, the package compiles native code (C++) upon package installation. This saves a few seconds at run time, as the compilation does not happen when the user code starts. However, it hurts when the package is installed from a different environment or machine than the one the code will actually run on. This is the case in most cluster environments, where packages may be installed from a login node rather than from the machine with the GPU.

    Should compilation rather be performed at run time?

    👍 - Move compilation to run time 👎 - Keep as is

    opened by toshas 0
Releases(v1.0.1)