PiRank: Learning to Rank via Differentiable Sorting

This repository provides a reference implementation for learning PiRank-based models as described in the paper:

PiRank: Learning to Rank via Differentiable Sorting
Robin Swezey, Aditya Grover, Bruno Charron and Stefano Ermon.
Paper: https://arxiv.org/abs/2012.06731

Requirements

The codebase is implemented in Python 3.7. To install the necessary base requirements, run the following command:

pip install -r requirements.txt

If you intend to use a GPU, modify requirements.txt to install tensorflow-gpu instead of tensorflow.

You will also need the NeuralSort implementation available here. Make sure it is added to your PYTHONPATH.

Datasets

PiRank was tested on the following two datasets:

  • MSLR-WEB30K

  • Yahoo! C14

Additionally, the code is expected to work with any dataset stored in the standard LibSVM format used for learning-to-rank (LTR) experiments.
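
For reference, an LTR file in LibSVM format holds one document per line: a relevance label, a qid:<query id> token and sparse <index>:<value> features, optionally followed by a # comment. The sketch below parses such lines, groups them per query, and truncates or pads each query to a fixed length in the spirit of the list_size option. It is not the repository's reader (the scripts rely on TF-Ranking for input); the function names are hypothetical, and padding with label -1 is an assumption you should check against your reader's convention.

```python
from typing import Dict, List, Tuple

Doc = Tuple[int, Dict[int, float]]  # (relevance label, sparse features)


def parse_libsvm_line(line: str) -> Tuple[int, str, Dict[int, float]]:
    """Parse '<label> qid:<id> <idx>:<val> ... [# comment]'."""
    tokens = line.split("#", 1)[0].split()  # drop any trailing comment
    label = int(float(tokens[0]))
    qid = tokens[1].split(":", 1)[1]        # tokens[1] looks like 'qid:<id>'
    features = {}
    for token in tokens[2:]:
        idx, val = token.split(":", 1)
        features[int(idx)] = float(val)
    return label, qid, features


def group_by_query(lines: List[str]) -> Dict[str, List[Doc]]:
    """Group non-empty lines by query id (dicts keep first-seen order)."""
    groups: Dict[str, List[Doc]] = {}
    for line in lines:
        if not line.strip():
            continue
        label, qid, features = parse_libsvm_line(line)
        groups.setdefault(qid, []).append((label, features))
    return groups


def to_fixed_list(docs: List[Doc], list_size: int, num_features: int,
                  pad_label: int = -1) -> Tuple[List[int], List[List[float]]]:
    """Truncate or pad one query's documents to exactly `list_size` entries.

    Feature indices are 1-based; missing features stay 0.0 and indices
    above `num_features` are ignored. Padding rows get `pad_label`
    (assumed -1) so a loss can mask them out.
    """
    labels: List[int] = []
    rows: List[List[float]] = []
    for label, features in docs[:list_size]:
        row = [0.0] * num_features
        for idx, val in features.items():
            if 1 <= idx <= num_features:
                row[idx - 1] = val
        labels.append(label)
        rows.append(row)
    while len(labels) < list_size:
        labels.append(pad_label)
        rows.append([0.0] * num_features)
    return labels, rows


if __name__ == "__main__":
    sample = ["2 qid:10 1:0.5 3:1.2", "0 qid:10 2:0.1", "1 qid:11 1:0.3 # docid = 7"]
    groups = group_by_query(sample)
    print(to_fixed_list(groups["10"], list_size=3, num_features=3))
    # ([2, 0, -1], [[0.5, 0.0, 1.2], [0.0, 0.1, 0.0], [0.0, 0.0, 0.0]])
```

Truncation here simply keeps the first list_size documents; a real pipeline may shuffle or sample before truncating.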

Scripts

The code consists of two scripts:

  • pirank_simple.py implements a simple depth-1 PiRank loss (d=1). It is used in the experiments of sections 4.1 (benchmark evaluation on MSLR-WEB30K and Yahoo! C14 datasets), 4.2.1 (effect of temperature parameter), and 4.2.2 (effect of training list size).

  • pirank_deep.py implements the deeper PiRank losses (d>=1). It is used for the experiments of section 4.2.3 and comes with a convenient synthetic data generator as well as more tuning options.

Options

Options are handled by Sacred (see the Examples section below).

pirank_simple.py and pirank_deep.py

PiRank-related:

Parameter | Default value | Description
loss_fn | pirank_simple_loss | The loss function to use (either a TFR RankingLossKey or a loss function defined in the script)
ste | False | Whether to use the Straight-Through Estimator
ndcg_k | 15 | NDCG@k cutoff when using the NS-NDCG loss
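
The ndcg_k option sets the rank cutoff k of NDCG@k. As a reference point for what the NS-NDCG loss relaxes, here is a sketch of the exact, non-differentiable NDCG@k, assuming the usual exponential gain 2^rel - 1 and a 1/log2(rank + 1) discount; the scripts themselves may compute the metric through TF-Ranking. dcg_at_k and ndcg_at_k are hypothetical helpers.

```python
import math
from typing import Sequence


def dcg_at_k(gains_in_rank_order: Sequence[float], k: int) -> float:
    """Sum of gains discounted by 1 / log2(rank + 1), for ranks 1..k."""
    return sum(g / math.log2(rank + 1)
               for rank, g in enumerate(gains_in_rank_order[:k], start=1))


def ndcg_at_k(scores: Sequence[float], labels: Sequence[int], k: int) -> float:
    """Exact NDCG@k of the ranking obtained by sorting `scores` descending."""
    gains = [2.0 ** y - 1.0 for y in labels]
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    ideal = dcg_at_k(sorted(gains, reverse=True), k)
    if ideal == 0.0:
        return 0.0  # assumed convention for queries without relevant documents
    return dcg_at_k([gains[i] for i in order], k) / ideal
```

Because sorting is the step that has no useful gradient, a learning-to-rank loss cannot optimize this function directly; that is the gap a differentiable sorting relaxation fills.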

NeuralSort-related:

Parameter | Default value | Description
tau | 5 | Temperature
taustar | 1e-10 | Temperature for the true labels ("trues") and for straight-through estimation
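
The temperature controls how soft the NeuralSort relaxation (Grover et al., 2019) is: for a score vector s of length n, row i of the relaxed permutation matrix is softmax(((n + 1 - 2i) s - A_s 1) / tau), where A_s[j][k] = |s_j - s_k|. The sketch below implements that formula in pure Python and plugs it into a relaxed NDCG@k loss (1 minus relaxed NDCG), illustrating the idea behind the PiRank simple loss under our own assumptions (exponential gains, log2 discount, exact ideal DCG). It is not the code in pirank_simple.py, and neuralsort and relaxed_ndcg_loss are hypothetical names. A temperature as small as the default taustar makes each row effectively one-hot, i.e. an exact sort.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def neuralsort(scores: Sequence[float], tau: float) -> Matrix:
    """Relaxed permutation matrix that sorts `scores` in descending order.

    Row r (0-based) is softmax(((n + 1 - 2(r + 1)) * s - A_s 1) / tau),
    where (A_s 1)_j = sum_k |s_j - s_k|.
    """
    n = len(scores)
    abs_row_sums = [sum(abs(sj - sk) for sk in scores) for sj in scores]
    p_hat = []
    for r in range(n):
        coeff = n + 1 - 2 * (r + 1)
        logits = [(coeff * s - b) / tau for s, b in zip(scores, abs_row_sums)]
        top = max(logits)  # subtract the max so math.exp never overflows
        exps = [math.exp(z - top) for z in logits]
        total = sum(exps)
        p_hat.append([e / total for e in exps])
    return p_hat


def relaxed_ndcg_loss(scores: Sequence[float], labels: Sequence[int],
                      k: int, tau: float) -> float:
    """1 - relaxed NDCG@k, using P_hat @ gains as the soft sorted gains."""
    gains = [2.0 ** y - 1.0 for y in labels]
    p_hat = neuralsort(scores, tau)
    soft_sorted = [sum(p * g for p, g in zip(row, gains)) for row in p_hat]
    cutoff = min(k, len(gains))
    discounts = [1.0 / math.log2(r + 2) for r in range(cutoff)]
    dcg = sum(soft_sorted[r] * discounts[r] for r in range(cutoff))
    ideal_gains = sorted(gains, reverse=True)
    idcg = sum(ideal_gains[r] * discounts[r] for r in range(cutoff))
    if idcg == 0.0:
        return 0.0  # assumed convention: no relevant documents, zero loss
    return 1.0 - dcg / idcg
```

With tau=1e-10, neuralsort([0.1, 2.0, 1.0], tau) returns the hard permutation [[0, 1, 0], [0, 0, 1], [1, 0, 0]]; a larger tau spreads each row's mass over several documents, which yields smoother but more biased gradients.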

TensorFlow-Ranking and architecture-related:

Parameter | Default value | Description
hidden_layers | "256,tanh,128,tanh,64,tanh" | Hidden layers of the example-wise feedforward network, in the format size,activation,...,size,activation
num_features | 136 | Number of features per document. The default is for MSLR-WEB30K; change it to match your dataset (e.g. 700 for Yahoo!)
list_size | 100 | List size used for training
group_size | 1 | Group size used in the scoring function
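
The hidden_layers string alternates layer sizes and activation names. A minimal sketch of parsing such a spec into (size, activation) pairs, assuming sizes are positive integers and activation names are passed through unchanged; parse_hidden_layers is hypothetical and not the scripts' own parser.

```python
from typing import List, Tuple


def parse_hidden_layers(spec: str) -> List[Tuple[int, str]]:
    """Parse 'size,activation,...,size,activation' into (size, activation) pairs."""
    # Tolerate surrounding double quotes, as in the documented default value.
    tokens = [t.strip() for t in spec.strip().strip('"').split(",") if t.strip()]
    if len(tokens) % 2 != 0:
        raise ValueError(f"expected size,activation pairs, got {spec!r}")
    layers = []
    for size_token, activation in zip(tokens[0::2], tokens[1::2]):
        size = int(size_token)  # raises ValueError for non-integer sizes
        if size <= 0:
            raise ValueError(f"layer size must be positive, got {size}")
        layers.append((size, activation))
    return layers
```

Each pair could then become one dense layer, for example tf.keras.layers.Dense(size, activation=activation), applied to every document independently.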

Training-related:

Parameter | Default value | Description
train_path | "/data/MSLR-WEB30K/Fold*/train.txt" | Input file path used for training
vali_path | "/data/MSLR-WEB30K/Fold*/vali.txt" | Input file path used for validation
test_path | "/data/MSLR-WEB30K/Fold*/test.txt" | Input file path used for testing
model_dir | None | Output directory for models
num_epochs | 200 | Number of training epochs; set to 0 to only run testing
lr | 1e-4 | Initial learning rate
batch_size | 32 | Batch size for training
num_train_steps | None | Number of steps for training
num_vali_steps | None | Number of steps for validation
num_test_steps | None | Number of steps for testing
learning_rate | 0.01 | Learning rate for the optimizer
dropout_rate | 0.5 | Dropout rate before the output layer
optimizer | Adagrad | Optimizer used for gradient descent

Sacred:

In addition, you can use the standard Sacred command-line options (such as -m to log the experiment to MongoDB).

pirank_deep.py only

Parameter | Default value | Description
merge_block_size | None | Block size used when merging; None disables merging
top_k | None | Top-k used for merging, if different from the final NDCG@k used in the loss
straight_backprop | False | Backpropagate on scores only through the NS (NeuralSort) operator
full_loss | False | Use the complete loss at the end of the merge
tau_scheme | None | Temperature scheme used when going deeper (default: constant)
data_generator | None | Data generator (default: TFR's libsvm); set this for synthetic data generation
num_queries | 30000 | Number of queries for the synthetic data generator
num_query_features | 10 | Number of columns used as factors for each query by the synthetic data generator
actual_list_size | None | Actual list size per query in synthetic data generation
train_path | "/data/MSLR-WEB30K/Fold*/train.txt" | Input file path used for training; alternatively, the seed value when using a data generator
vali_path | "/data/MSLR-WEB30K/Fold*/vali.txt" | Input file path used for validation; alternatively, the seed value when using a data generator
test_path | "/data/MSLR-WEB30K/Fold*/test.txt" | Input file path used for testing; alternatively, the seed value when using a data generator
with_opa | True | Include the pairwise metric OPA (ordered pair accuracy)
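
The merge_block_size and top_k options concern merging blocks in the deeper losses. The fact that makes block-wise merging work is easy to check with hard sorting: the top-k of a list equals the top-k of the union of each block's top-k, so a long list can be reduced one round of blocks at a time. The sketch below applies this repeatedly with exact sorting; it is our own illustration of the idea, not the relaxed operator implemented in pirank_deep.py, and top_k_indices and blockwise_top_k are hypothetical names.

```python
from typing import List, Sequence


def top_k_indices(scores: Sequence[float], candidates: List[int], k: int) -> List[int]:
    """Indices from `candidates` with the k largest scores, in descending order."""
    return sorted(candidates, key=lambda i: scores[i], reverse=True)[:k]


def blockwise_top_k(scores: Sequence[float], k: int, block_size: int) -> List[int]:
    """Exact top-k via repeated block-wise selection and merging.

    Each round splits the surviving candidates into blocks of `block_size`,
    keeps the top-k of every block and merges the survivors. Requiring
    block_size > k guarantees that every round shrinks the candidate set.
    """
    if block_size <= k:
        raise ValueError("block_size must exceed k")
    candidates = list(range(len(scores)))
    while len(candidates) > block_size:
        survivors: List[int] = []
        for start in range(0, len(candidates), block_size):
            block = candidates[start:start + block_size]
            survivors.extend(top_k_indices(scores, block, k))
        candidates = survivors
    return top_k_indices(scores, candidates, k)
```

Each item of the true top-k is beaten by fewer than k items overall, so it always survives its own block; this is why only block-level sorts of size block_size are ever needed.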

Examples

Run the benchmark experiment of section 4.1 with the PiRank simple loss on MSLR-WEB30K

cd pirank
python3 pirank_simple.py with loss_fn=pirank_simple_loss \
    ndcg_k=10 \
    tau=5 \
    list_size=80 \
    hidden_layers=256,relu,256,relu,128,relu,64,relu \
    train_path=/data/MSLR-WEB30K/Fold1/train.txt \
    vali_path=/data/MSLR-WEB30K/Fold1/vali.txt \
    test_path=/data/MSLR-WEB30K/Fold1/test.txt \
    num_features=136 \
    optimizer=Adam \
    learning_rate=0.00001 \
    num_epochs=100 \
    batch_size=16 \
    model_dir=/tmp/model

Run the benchmark experiment of section 4.1 with the PiRank simple loss on Yahoo! C14

cd pirank
python3 pirank_simple.py with loss_fn=pirank_simple_loss \
    ndcg_k=10 \
    tau=5 \
    list_size=80 \
    hidden_layers=256,relu,256,relu,128,relu,64,relu \
    train_path=/data/YAHOO/set1.train.txt \
    vali_path=/data/YAHOO/set1.valid.txt \
    test_path=/data/YAHOO/set1.test.txt \
    num_features=700 \
    optimizer=Adam \
    learning_rate=0.00001 \
    num_epochs=100 \
    batch_size=16 \
    model_dir=/tmp/model

Run the benchmark experiment of section 4.1 with classic LambdaRank on MSLR-WEB30K

cd pirank
python3 pirank_simple.py with loss_fn=lambda_rank_loss \
    ndcg_k=10 \
    tau=5 \
    list_size=80 \
    hidden_layers=256,relu,256,relu,128,relu,64,relu \
    train_path=/data/MSLR-WEB30K/Fold1/train.txt \
    vali_path=/data/MSLR-WEB30K/Fold1/vali.txt \
    test_path=/data/MSLR-WEB30K/Fold1/test.txt \
    num_features=136 \
    optimizer=Adam \
    learning_rate=0.00001 \
    num_epochs=100 \
    batch_size=16 \
    model_dir=/tmp/model

Run the scaling ablation experiment of section 4.2.3 using synthetic data generation (d=2)

cd pirank
python3 pirank_deep.py with loss_fn=pirank_deep_loss \
    ndcg_k=10 \
    ste=True \
    merge_block_size=100 \
    tau=5 \
    taustar=1e-10 \
    tau_scheme=square \
    data_generator=synthetic_data_generator \
    actual_list_size=1000 \
    list_size=1000 \
    vali_list_size=1000 \
    test_list_size=1000 \
    full_loss=False \
    train_path=0 \
    vali_path=1 \
    test_path=2 \
    num_queries=1000 \
    num_features=25 \
    num_query_features=5 \
    hidden_layers=256,relu,256,relu,128,relu,128,relu,64,relu,64,relu \
    optimizer=Adam \
    learning_rate=0.00001 \
    num_epochs=100 \
    batch_size=16
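
The MSLR-WEB30K commands above use Fold1 only, while the dataset ships five folds (Fold1 to Fold5). If you want to cover all of them, the sketch below builds one pirank_simple.py command per fold, assuming the same /data layout and the settings of the PiRank simple loss example; giving each fold its own model_dir is our choice. The helper names are hypothetical, and the script only prints the commands.

```python
import shlex
from typing import Dict, List

# Settings copied from the MSLR-WEB30K PiRank simple loss example.
MSLR_CONFIG: Dict[str, object] = {
    "loss_fn": "pirank_simple_loss",
    "ndcg_k": 10,
    "tau": 5,
    "list_size": 80,
    "hidden_layers": "256,relu,256,relu,128,relu,64,relu",
    "num_features": 136,
    "optimizer": "Adam",
    "learning_rate": "0.00001",  # kept as a string to match the example verbatim
    "num_epochs": 100,
    "batch_size": 16,
}


def build_command(script: str, config: Dict[str, object]) -> List[str]:
    """argv for 'python3 <script> with key=value ...' (Sacred config updates)."""
    return ["python3", script, "with"] + [f"{key}={value}" for key, value in config.items()]


def mslr_fold_commands(data_root: str = "/data/MSLR-WEB30K",
                       model_root: str = "/tmp/model") -> List[List[str]]:
    """One command per MSLR-WEB30K fold, each with its own model_dir."""
    commands = []
    for fold in range(1, 6):
        config = dict(MSLR_CONFIG)
        config.update({
            "train_path": f"{data_root}/Fold{fold}/train.txt",
            "vali_path": f"{data_root}/Fold{fold}/vali.txt",
            "test_path": f"{data_root}/Fold{fold}/test.txt",
            "model_dir": f"{model_root}/fold{fold}",
        })
        commands.append(build_command("pirank_simple.py", config))
    return commands


if __name__ == "__main__":
    for argv in mslr_fold_commands():
        # Print a shell-safe line; to launch instead, use
        # subprocess.run(argv, cwd="pirank", check=True).
        print(" ".join(shlex.quote(arg) for arg in argv))
```

Keeping each command as an argument list rather than one string avoids shell-quoting issues if you later pass it to subprocess.run.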

Help

If you need help, reach out to Robin Swezey or raise an issue.

Citing

If you find PiRank useful in your research, please consider citing the following paper:

@inproceedings{
swezey2020pirank,
title={PiRank: Learning to Rank via Differentiable Sorting},
author={Robin Swezey and Aditya Grover and Bruno Charron and Stefano Ermon},
year={2020},
url={},
}
