PiRank: Learning to Rank via Differentiable Sorting

This repository provides a reference implementation for learning PiRank-based models as described in the paper:

PiRank: Learning to Rank via Differentiable Sorting
Robin Swezey, Aditya Grover, Bruno Charron and Stefano Ermon.
Paper: https://arxiv.org/abs/2012.06731

Requirements

The codebase is implemented in Python 3.7. To install the necessary base requirements, run the following command:

pip install -r requirements.txt

If you intend to use a GPU, modify requirements.txt to install tensorflow-gpu instead of tensorflow.

You will also need the NeuralSort reference implementation. Make sure its directory is added to your PYTHONPATH.

Datasets

PiRank was tested on the following two datasets:

  • MSLR-WEB30K

  • Yahoo! C14

Additionally, the code is expected to work with any dataset stored in the standard LibSVM format used for LTR experiments.
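For readers unfamiliar with the LibSVM format as used in LTR, each line typically holds a relevance label, a query identifier, and sparse `index:value` feature pairs. The following is a minimal sketch of a parser for one such line; the function name `parse_ltr_line` and the sample line are hypothetical and are not part of this repository, which reads data through TensorFlow Ranking.

```python
from typing import Dict, Tuple


def parse_ltr_line(line: str) -> Tuple[int, str, Dict[int, float]]:
    """Parse one `label qid:<id> <index>:<value> ...` line (illustrative helper)."""
    # Drop an optional trailing comment (everything after '#').
    body = line.split("#", 1)[0].strip()
    tokens = body.split()
    label = int(float(tokens[0]))  # relevance grade
    if not tokens[1].startswith("qid:"):
        raise ValueError("expected a qid:<id> token after the label")
    qid = tokens[1][len("qid:"):]
    features = {}
    for token in tokens[2:]:
        index, value = token.split(":", 1)
        features[int(index)] = float(value)
    return label, qid, features


if __name__ == "__main__":
    # Hypothetical sample line, not taken from any dataset.
    print(parse_ltr_line("2 qid:10 1:0.5 2:0.0 136:1.25 # docid = 42"))
```

Feature indices are kept as they appear in the file, so missing indices simply stay absent from the dictionary.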

Scripts

There are two scripts for the code:

  • pirank_simple.py implements a simple depth-1 PiRank loss (d=1). It is used in the experiments of sections 4.1 (benchmark evaluation on MSLR-WEB30K and Yahoo! C14 datasets), 4.2.1 (effect of temperature parameter), and 4.2.2 (effect of training list size).

  • pirank_deep.py implements the deeper PiRank losses (d>=1). It is used for the experiments of section 4.2.3 and comes with a convenient synthetic data generator as well as more tuning options.

Options

Options are handled by Sacred (see Examples section below).
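Sacred reads configuration overrides from the command line as `key=value` pairs placed after the keyword `with`. As a rough illustration, a small helper can assemble such a command from a dictionary; `sacred_command` is a hypothetical name introduced here and is not part of Sacred or of this repository.

```python
import shlex
from typing import Dict, List


def sacred_command(script: str, overrides: Dict[str, object]) -> List[str]:
    """Build an argument list of the form: python3 <script> with k1=v1 k2=v2 ..."""
    args = ["python3", script]
    if overrides:
        args.append("with")
        args.extend(f"{key}={value}" for key, value in overrides.items())
    return args


if __name__ == "__main__":
    cmd = sacred_command("pirank_simple.py", {"tau": 5, "ndcg_k": 10})
    # Quote each argument so the line can be pasted into a shell.
    print(" ".join(shlex.quote(arg) for arg in cmd))
```

The returned list can be passed to `subprocess.run` directly, which avoids shell quoting issues when a value contains commas or wildcards.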

pirank_simple.py and pirank_deep.py

PiRank-related:

| Parameter | Default Value | Description |
| --- | --- | --- |
| loss_fn | pirank_simple_loss | The loss function to use (either a TFR RankingLossKey, or loss function from the script) |
| ste | False | Whether to use the Straight-Through Estimator |
| ndcg_k | 15 | NDCG@k cutoff when using NS-NDCG loss |
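To make the role of the `ndcg_k` cutoff concrete, here is a small sketch of the (non-differentiable) NDCG@k metric. It assumes the common convention of an exponential gain `2^rel - 1` and a `log2` position discount; the repository's own metric code may use a different convention, and the function names are illustrative.

```python
import math
from typing import Sequence


def dcg_at_k(relevances: Sequence[float], k: int) -> float:
    """Discounted cumulative gain over the first k positions."""
    return sum(
        (2.0 ** rel - 1.0) / math.log2(rank + 2)  # rank is 0-based
        for rank, rel in enumerate(relevances[:k])
    )


def ndcg_at_k(scores: Sequence[float], labels: Sequence[float], k: int) -> float:
    """NDCG@k of the ranking induced by sorting documents by descending score."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    ranked_labels = [labels[i] for i in order]
    ideal_dcg = dcg_at_k(sorted(labels, reverse=True), k)
    if ideal_dcg == 0.0:
        return 0.0  # no relevant document in the list
    return dcg_at_k(ranked_labels, k) / ideal_dcg


if __name__ == "__main__":
    print(ndcg_at_k(scores=[0.1, 0.9, 0.4], labels=[0, 2, 1], k=2))
```

Only the top `k` positions contribute, which is why a smaller cutoff focuses the objective on the head of the ranking.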

NeuralSort-related:

| Parameter | Default Value | Description |
| --- | --- | --- |
| tau | 5 | Temperature |
| taustar | 1e-10 | Temperature for trues and straight-through estimation. |
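The effect of the temperature can be illustrated with a plain-Python sketch of the NeuralSort relaxation, in which row i of the relaxed permutation matrix is a softmax over `((n + 1 - 2i) * s - A_s 1) / tau`, with `A_s[j][k] = |s_j - s_k|`. This is a simplified illustration written from the published formulation, not the TensorFlow code used by this repository, and the function name is hypothetical.

```python
import math
from typing import List


def neuralsort(scores: List[float], tau: float) -> List[List[float]]:
    """Return a row-stochastic relaxation of the descending-sort permutation matrix."""
    n = len(scores)
    # b[j] = sum_k |s_j - s_k|, i.e. the row sums of the pairwise distance matrix.
    b = [sum(abs(sj - sk) for sk in scores) for sj in scores]
    p_hat = []
    for i in range(1, n + 1):  # i-th largest element, 1-based
        logits = [((n + 1 - 2 * i) * scores[j] - b[j]) / tau for j in range(n)]
        top = max(logits)  # subtract the max for numerical stability
        exps = [math.exp(logit - top) for logit in logits]
        total = sum(exps)
        p_hat.append([e / total for e in exps])
    return p_hat


if __name__ == "__main__":
    s = [1.0, 3.0, 2.0]
    print(neuralsort(s, tau=5.0))    # soft rows
    print(neuralsort(s, tau=1e-10))  # rows collapse to one-hot vectors
```

With a large `tau` the rows are smooth distributions over documents; as `tau` approaches zero they approach the hard permutation matrix, which is consistent with using a very small `taustar` for true labels and straight-through estimation.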

TensorFlow-Ranking and architecture-related:

| Parameter | Default Value | Description |
| --- | --- | --- |
| hidden_layers | "256,tanh,128,tanh,64,tanh" | Hidden layers for an example-wise feedforward network in the format size,activation,...,size,activation |
| num_features | 136 | Number of features per document. The default value is for MSLR and depends on the dataset (e.g. for Yahoo!, please change to 700). |
| list_size | 100 | List size used for training |
| group_size | 1 | Group size used in score function |
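The `hidden_layers` string alternates layer sizes and activation names. A minimal sketch of how such a string could be split into `(size, activation)` pairs follows; `parse_hidden_layers` is a hypothetical helper and may not match how the scripts actually parse the option.

```python
from typing import List, Tuple


def parse_hidden_layers(spec: str) -> List[Tuple[int, str]]:
    """Split 'size,activation,...,size,activation' into (size, activation) pairs."""
    tokens = [token.strip() for token in spec.split(",") if token.strip()]
    if len(tokens) % 2 != 0:
        raise ValueError("expected an even number of comma-separated fields")
    return [(int(tokens[i]), tokens[i + 1]) for i in range(0, len(tokens), 2)]


if __name__ == "__main__":
    print(parse_hidden_layers("256,tanh,128,tanh,64,tanh"))
```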

Training-related:

| Parameter | Default Value | Description |
| --- | --- | --- |
| train_path | "/data/MSLR-WEB30K/Fold*/train.txt" | Input file path used for training |
| vali_path | "/data/MSLR-WEB30K/Fold*/vali.txt" | Input file path used for validation |
| test_path | "/data/MSLR-WEB30K/Fold*/test.txt" | Input file path used for testing |
| model_dir | None | Output directory for models |
| num_epochs | 200 | Number of epochs to train, set 0 to just test |
| lr | 1e-4 | Initial learning rate |
| batch_size | 32 | The batch size for training |
| num_train_steps | None | Number of steps for training |
| num_vali_steps | None | Number of steps for validation |
| num_test_steps | None | Number of steps for testing |
| learning_rate | 0.01 | Learning rate for optimizer |
| dropout_rate | 0.5 | The dropout rate before output layer |
| optimizer | Adagrad | The optimizer for gradient descent |

Sacred:

In addition, you can use regular parameters from Sacred (such as -m for logging the experiment to MongoDB).

pirank_deep.py only

| Parameter | Default Value | Description |
| --- | --- | --- |
| merge_block_size | None | Block size used if merging, None if not merging |
| top_k | None | Use a different Top-k for merging than final NDCG@k for loss |
| straight_backprop | False | Backpropagate on scores only through NS operator |
| full_loss | False | Use the complete loss at the end of merge |
| tau_scheme | None | Which scheme to use for temperature going deeper (default: constant) |
| data_generator | None | Data generator (default: TFR's libsvm); use this for synthetic generation |
| num_queries | 30000 | Number of queries for synthetic data generator |
| num_query_features | 10 | Number of columns used as factors for each query by synthetic data generator |
| actual_list_size | None | Size of actual list per query in synthetic data generation |
| train_path | "/data/MSLR-WEB30K/Fold*/train.txt" | Input file path used for training; alternatively value of seed if using data generator |
| vali_path | "/data/MSLR-WEB30K/Fold*/vali.txt" | Input file path used for validation; alternatively value of seed if using data generator |
| test_path | "/data/MSLR-WEB30K/Fold*/test.txt" | Input file path used for testing; alternatively value of seed if using data generator |
| with_opa | True | Include pairwise metric OPA |

Examples

Run the benchmark experiment of section 4.1 with PiRank simple loss on MSLR-WEB30K

cd pirank
python3 pirank_simple.py with loss_fn=pirank_simple_loss \
    ndcg_k=10 \
    tau=5 \
    list_size=80 \
    hidden_layers=256,relu,256,relu,128,relu,64,relu \
    train_path=/data/MSLR-WEB30K/Fold1/train.txt \
    vali_path=/data/MSLR-WEB30K/Fold1/vali.txt \
    test_path=/data/MSLR-WEB30K/Fold1/test.txt \
    num_features=136 \
    optimizer=Adam \
    learning_rate=0.00001 \
    num_epochs=100 \
    batch_size=16 \
    model_dir=/tmp/model

Run the benchmark experiment of section 4.1 with PiRank simple loss on Yahoo! C14

cd pirank
python3 pirank_simple.py with loss_fn=pirank_simple_loss \
    ndcg_k=10 \
    tau=5 \
    list_size=80 \
    hidden_layers=256,relu,256,relu,128,relu,64,relu \
    train_path=/data/YAHOO/set1.train.txt \
    vali_path=/data/YAHOO/set1.valid.txt \
    test_path=/data/YAHOO/set1.test.txt \
    num_features=700 \
    optimizer=Adam \
    learning_rate=0.00001 \
    num_epochs=100 \
    batch_size=16 \
    model_dir=/tmp/model

Run the benchmark experiment of section 4.1 with classic LambdaRank on MSLR-WEB30K

cd pirank
python3 pirank_simple.py with loss_fn=lambda_rank_loss \
    ndcg_k=10 \
    tau=5 \
    list_size=80 \
    hidden_layers=256,relu,256,relu,128,relu,64,relu \
    train_path=/data/MSLR-WEB30K/Fold1/train.txt \
    vali_path=/data/MSLR-WEB30K/Fold1/vali.txt \
    test_path=/data/MSLR-WEB30K/Fold1/test.txt \
    num_features=136 \
    optimizer=Adam \
    learning_rate=0.00001 \
    num_epochs=100 \
    batch_size=16 \
    model_dir=/tmp/model

Run the scaling ablation experiment of section 4.2.3 using synthetic data generation (d=2)

cd pirank
python3 pirank_deep.py with loss_fn=pirank_deep_loss \
    ndcg_k=10 \
    ste=True \
    merge_block_size=100 \
    tau=5 \
    taustar=1e-10 \
    tau_scheme=square \
    data_generator=synthetic_data_generator \
    actual_list_size=1000 \
    list_size=1000 \
    vali_list_size=1000 \
    test_list_size=1000 \
    full_loss=False \
    train_path=0 \
    vali_path=1 \
    test_path=2 \
    num_queries=1000 \
    num_features=25 \
    num_query_features=5 \
    hidden_layers=256,relu,256,relu,128,relu,128,relu,64,relu,64,relu \
    optimizer=Adam \
    learning_rate=0.00001 \
    num_epochs=100 \
    batch_size=16

Help

If you need help, reach out to Robin Swezey or raise an issue.

Citing

If you find PiRank useful in your research, please consider citing the following paper:

@inproceedings{
swezey2020pirank,
title={PiRank: Learning to Rank via Differentiable Sorting},
author={Robin Swezey and Aditya Grover and Bruno Charron and Stefano Ermon},
year={2020},
url={},
}
