A PyTorch implementation of a Factorization Machine module in Cython.

Overview

fmpytorch

A library for factorization machines in PyTorch. A factorization machine is like a linear model, except that it also models multiplicative interaction terms between pairs of input variables.

The input to a factorization machine layer is a vector, and the output is a scalar. Batching is fully supported.

This is a work in progress. Feedback and bugfixes welcome! Hopefully you find the code useful.

Usage

The factorization machine layers in fmpytorch can be used just like any other built-in PyTorch module. Here's a simple feed-forward model that maps a 100-D input to 50-D with a linear layer and dropout, then applies a factorization machine over those 50 dimensions, modeling interactions with k=5 factors.

import torch
from fmpytorch.second_order.fm import FactorizationMachine

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = torch.nn.Linear(100, 50)
        self.dropout = torch.nn.Dropout(.5)
        # This makes an FM layer mapping from 50-D to 1-D.
        # The number of factors is 5.
        self.fm = FactorizationMachine(50, 5)

    def forward(self, x):
        x = self.linear(x)
        x = self.dropout(x)
        x = self.fm(x)
        return x

See examples/toy.py or examples/regression.py for fuller examples.

Installation

This package requires PyTorch, NumPy, and Cython.

To install, you can run:

cd fmpytorch
sudo python setup.py install

Factorization Machine brief intro

Given an input vector x, a linear model predicts its output y as

$$ y = w_0 + \sum_{i=1}^{n} w_i x_i $$

where w are the learnable weights of the model.

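However, in this model each input variable x_i contributes purely additively, so interactions between variables are not captured. In some cases, it might be useful to model those interactions, e.g., x_i * x_j. You could add terms to your model like

$$ y = w_0 + \sum_{i=1}^{n} w_i x_i + \sum_{i=1}^{n} \sum_{j=i+1}^{n} w2_{ij} x_i x_j $$
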
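As a rough illustration of how many weights this adds, here is a minimal NumPy sketch of the explicit pairwise model, with one free weight w2_ij per pair i < j stored in the upper triangle of a dense matrix `W2`. The function name `pairwise_model` and the dense-matrix storage are illustrative choices, not part of fmpytorch. For n = 50 inputs (the size used in the usage example) there are n(n-1)/2 = 1225 pairwise weights.

```python
import numpy as np

def pairwise_model(x, w0, w, W2):
    """Second-order model with one free weight per pair (i < j).

    Hypothetical helper for illustration. x: (n,) input, w0: bias,
    w: (n,) linear weights, W2: (n, n) matrix of which only the strict
    upper triangle (i < j) is used. Costs O(n^2) per input.
    """
    n = x.shape[0]
    y = w0 + w @ x
    for i in range(n):
        for j in range(i + 1, n):
            y += W2[i, j] * x[i] * x[j]
    return y

n = 50
num_pair_weights = n * (n - 1) // 2  # one w2_ij per pair i < j
print(num_pair_weights)  # 1225
```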
However, this introduces a large number of w2 variables: one for each interaction pair, i.e., O(n^2) parameters in total. A factorization machine approximates w2 using low-dimensional factors, i.e.,

$$ y = w_0 + \sum_{i=1}^{n} w_i x_i + \sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle v_i, v_j \rangle x_i x_j $$

where each v_i is a k-dimensional vector (k is the number of factors) and <v_i, v_j> is the dot product of v_i and v_j, which takes the place of w2_ij. This is the forward pass of a second-order factorization machine. This low-rank reformulation reduces the number of additional parameters to O(k*n). Better still, the forward (and backward) pass can be rewritten so that it is computed in O(k*n) time, rather than the O(k*n^2) time needed to evaluate the double sum above directly.
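The O(k*n) reformulation rests on a standard identity from Rendle (2010), obtained by summing over all ordered pairs and subtracting the diagonal terms:

$$ \sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle v_i, v_j \rangle x_i x_j = \frac{1}{2} \sum_{f=1}^{k} \left[ \left( \sum_{i=1}^{n} v_{i,f} x_i \right)^2 - \sum_{i=1}^{n} v_{i,f}^2 x_i^2 \right] $$

The NumPy sketch below checks numerically that the naive double loop and the O(k*n) form agree. It assumes the factors are stored as an (n, k) matrix `V` whose i-th row is v_i; the function names are hypothetical, and this is not the library's Cython code.

```python
import numpy as np

def fm_interactions_naive(x, V):
    """O(k * n^2): sum over pairs i < j of <v_i, v_j> * x_i * x_j."""
    n = x.shape[0]
    total = 0.0
    for i in range(n):
        for j in range(i + 1, n):
            total += (V[i] @ V[j]) * x[i] * x[j]
    return total

def fm_interactions_fast(x, V):
    """O(k * n): 0.5 * sum_f [(sum_i v_if x_i)^2 - sum_i v_if^2 x_i^2]."""
    xv = x @ V                  # shape (k,): sum_i v_if * x_i
    x2v2 = (x ** 2) @ (V ** 2)  # shape (k,): sum_i v_if^2 * x_i^2
    return 0.5 * np.sum(xv ** 2 - x2v2)

def fm_forward(x, w0, w, V):
    """Second-order FM output for a single input vector x."""
    return w0 + w @ x + fm_interactions_fast(x, V)

rng = np.random.default_rng(0)
n, k = 50, 5
x = rng.normal(size=n)
V = rng.normal(size=(n, k))
print(np.isclose(fm_interactions_naive(x, V), fm_interactions_fast(x, V)))  # True
print(V.size)  # 250
```

With n = 50 and k = 5, `V` holds 250 parameters, compared with 1225 free pairwise weights in the explicit model.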

Currently supported features

Currently, only a second-order factorization machine is supported. The forward and backward passes are implemented in Cython. Compared to an autodiff solution, the Cython passes run several orders of magnitude faster. I've only tested it with Python 2 at the moment.
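For comparison with the autodiff solution mentioned above, here is a minimal sketch of a pure-PyTorch second-order FM layer that relies on autograd for the backward pass. It is not fmpytorch's code: the class name `AutogradFM`, the use of `torch.nn.Linear` for the w_0 and w_i terms, and the 0.01 initialization scale are all assumptions made for illustration, and it assumes Python 3 with a recent PyTorch. It maps a batch of shape (batch, n) to (batch, 1), using the O(k*n) form of the pairwise term.

```python
import torch

class AutogradFM(torch.nn.Module):
    """Hypothetical pure-PyTorch second-order FM layer (not fmpytorch's code).

    Maps inputs of shape (batch, n) to outputs of shape (batch, 1);
    gradients come from autograd rather than hand-written Cython passes.
    """

    def __init__(self, n, k):
        super().__init__()
        self.linear = torch.nn.Linear(n, 1)                     # w_0 + sum_i w_i x_i
        self.V = torch.nn.Parameter(0.01 * torch.randn(n, k))   # row i is v_i

    def forward(self, x):
        xv = x @ self.V                  # (batch, k): sum_i v_if * x_i
        x2v2 = (x ** 2) @ (self.V ** 2)  # (batch, k): sum_i v_if^2 * x_i^2
        pairwise = 0.5 * (xv ** 2 - x2v2).sum(dim=1, keepdim=True)  # (batch, 1)
        return self.linear(x) + pairwise

fm = AutogradFM(50, 5)
x = torch.randn(8, 50)
out = fm(x)
out.sum().backward()      # autograd fills in gradients for V and the linear part
print(out.shape)          # torch.Size([8, 1])
print(fm.V.grad.shape)    # torch.Size([50, 5])
```

Because every operation here is a standard differentiable tensor op, autograd derives the backward pass automatically; a hand-written Cython backward pass trades that convenience for speed.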

TODOs

  1. Support for sparse tensors
  2. More interesting usage examples
  3. More testing, e.g., with Python 3
  4. Make sure all of the code plays nice with torch-specific features, e.g., GPUs
  5. Support for arbitrary-order factorization machines
  6. Better organization and code cleanup

Thanks to

Vlad Niculae (@vene) for his sage wisdom.

This layer is based on the original factorization machine paper:

@inproceedings{rendle2010factorization,
  title={Factorization machines},
  author={Rendle, Steffen},
  booktitle={ICDM},
  pages={995--1000},
  year={2010},
  organization={IEEE}
}

Owner
Jack Hessel, Research Scientist @ AI2 (previously PhD in CS from Cornell)