Rax is a Learning-to-Rank library written in JAX

Related tags

Deep Learningrax
Overview

🦖 Rax: Composable Learning to Rank using JAX

Rax is a Learning-to-Rank library written in JAX. Rax provides off-the-shelf implementations of ranking losses and metrics to be used with JAX. It provides the following functionality:

  • Ranking losses (rax.*_loss): rax.softmax_loss, rax.pairwise_logistic_loss, ...
  • Ranking metrics (rax.*_metric): rax.mrr_metric, rax.ndcg_metric, ...
  • Transformations (rax.*_t12n): rax.approx_t12n, rax.gumbel_t12n, ...

Ranking

A ranking problem is different from traditional classification/regression problems in that its objective is to optimize for the correctness of the relative order of a list of examples (e.g., documents) for a given context (e.g., a query). Rax provides support for ranking problems within the JAX ecosystem. It can be used in, but is not limited to, the following applications:

  • Search: ranking a list of documents with respect to a query.
  • Recommendation: ranking a list of items given a user as context.
  • Question Answering: finding the best answer from a list of candidates.
  • Dialogue System: finding the best response from a list of responses.

Synopsis

In a nutshell, given the scores and labels for a list of items, Rax can compute various ranking losses and metrics:

import jax.numpy as jnp
import rax

scores = jnp.asarray([2.2, -1.3, 5.4])  # output of a model.
labels = jnp.asarray([1., 0., 0.])      # indicates doc 1 is relevant.

rax.ndcg_metric(scores, labels)         # computes a ranking metric.
rax.pairwise_hinge_loss(scores, labels) # computes a ranking loss.

All of the Rax losses and metrics are purely functional and compose well with standard JAX transformations. Additionally, Rax provides ranking-specific transformations so you can build new ranking losses. An example is rax.approx_t12n, which can be used to transform any (non-differentiable) ranking metric into a differentiable loss. For example:

loss_fn = rax.approx_t12n(rax.ndcg_metric)
loss_fn(scores, labels)            # differentiable approx ndcg loss.
jax.grad(loss_fn)(scores, labels)  # computes gradients w.r.t. scores.

Examples

See the examples/ directory for complete examples on how to use Rax.

You might also like...
[ICLR 2021] Rank the Episodes: A Simple Approach for Exploration in Procedurally-Generated Environments.
[ICLR 2021] Rank the Episodes: A Simple Approach for Exploration in Procedurally-Generated Environments.

[ICLR 2021] RAPID: A Simple Approach for Exploration in Reinforcement Learning This is the Tensorflow implementation of ICLR 2021 paper Rank the Episo

Rank 1st in the public leaderboard of ScanRefer (2021-03-18)
Rank 1st in the public leaderboard of ScanRefer (2021-03-18)

InstanceRefer InstanceRefer: Cooperative Holistic Understanding for Visual Grounding on Point Clouds through Instance Multi-level Contextual Referring

Code for
Code for "LoRA: Low-Rank Adaptation of Large Language Models"

LoRA: Low-Rank Adaptation of Large Language Models This repo contains the implementation of LoRA in GPT-2 and steps to replicate the results in our re

Official PyTorch Implementation of Rank & Sort Loss [ICCV2021]
Official PyTorch Implementation of Rank & Sort Loss [ICCV2021]

Rank & Sort Loss for Object Detection and Instance Segmentation The official implementation of Rank & Sort Loss. Our implementation is based on mmdete

This is the pytorch implementation for the paper: Generalizable Mixed-Precision Quantization via Attribution Rank Preservation, which is accepted to ICCV2021.

GMPQ: Generalizable Mixed-Precision Quantization via Attribution Rank Preservation This is the pytorch implementation for the paper: Generalizable Mix

 COD-Rank-Localize-and-Segment (CVPR2021)
COD-Rank-Localize-and-Segment (CVPR2021)

COD-Rank-Localize-and-Segment (CVPR2021) Simultaneously Localize, Segment and Rank the Camouflaged Objects Full camouflage fixation training dataset i

Gradient-free global optimization algorithm for multidimensional functions based on the low rank tensor train format

ttopt Description Gradient-free global optimization algorithm for multidimensional functions based on the low rank tensor train (TT) format and maximu

ViViT: Curvature access through the generalized Gauss-Newton's low-rank structure

ViViT is a collection of numerical tricks to efficiently access curvature from the generalized Gauss-Newton (GGN) matrix based on its low-rank structure. Provided functionality includes computing

This is the solution for 2nd rank in Kaggle competition: Feedback Prize - Evaluating Student Writing.

Feedback Prize - Evaluating Student Writing This is the solution for 2nd rank in Kaggle competition: Feedback Prize - Evaluating Student Writing. The

Comments
  • Add Kendall's tau

    Add Kendall's tau

    Kendall's tau (https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient) is a ranking metric which we could add to Rax. We can reference the SciPy version of Kendall's tau (https://github.com/scipy/scipy/blob/v1.9.0/scipy/stats/_stats_py.py#L5015-L5222) as inspiration or comparison for the Rax implementation.

    Expected usage:

    scores = jnp.array([0., 1., 2.])
    labels = jnp.array([1., 2., 0.])
    rax.kendalltau_metric(scores, labels)  # should produce Kendall's tau-b.
    
    enhancement 
    opened by rjagerman 0
Releases(v0.2.0)
Owner
Google
Google ❤️ Open Source
Google
JDet is Object Detection Framework based on Jittor.

JDet is Object Detection Framework based on Jittor.

135 Dec 14, 2022
A commany has recently introduced a new type of bidding, the average bidding, as an alternative to the bid given to the current maximum bidding

Business Problem A commany has recently introduced a new type of bidding, the average bidding, as an alternative to the bid given to the current maxim

Kübra Bilinmiş 1 Jan 15, 2022
Consistency Regularization for Adversarial Robustness

Consistency Regularization for Adversarial Robustness Official PyTorch implementation of Consistency Regularization for Adversarial Robustness by Jiho

40 Dec 17, 2022
Implementation for our AAAI2021 paper (Entity Structure Within and Throughout: Modeling Mention Dependencies for Document-Level Relation Extraction).

SSAN Introduction This is the pytorch implementation of the SSAN model (see our AAAI2021 paper: Entity Structure Within and Throughout: Modeling Menti

benfeng 69 Nov 15, 2022
LSUN Dataset Documentation and Demo Code

LSUN Please check LSUN webpage for more information about the dataset. Data Release All the images in one category are stored in one lmdb database fil

Fisher Yu 426 Jan 02, 2023
Unofficial reimplementation of ECAPA-TDNN for speaker recognition (EER=0.86 for Vox1_O when train only in Vox2)

Introduction This repository contains my unofficial reimplementation of the standard ECAPA-TDNN, which is the speaker recognition in VoxCeleb2 dataset

Tao Ruijie 277 Dec 31, 2022
[CVPRW 21] "BNN - BN = ? Training Binary Neural Networks without Batch Normalization", Tianlong Chen, Zhenyu Zhang, Xu Ouyang, Zechun Liu, Zhiqiang Shen, Zhangyang Wang

BNN - BN = ? Training Binary Neural Networks without Batch Normalization Codes for this paper BNN - BN = ? Training Binary Neural Networks without Bat

VITA 40 Dec 30, 2022
Official Pytorch implementation of "Learning Debiased Representation via Disentangled Feature Augmentation (Neurips 2021, Oral)"

Learning Debiased Representation via Disentangled Feature Augmentation (Neurips 2021, Oral): Official Project Webpage This repository provides the off

Kakao Enterprise Corp. 68 Dec 17, 2022
Distributed Asynchronous Hyperparameter Optimization better than HyperOpt.

UltraOpt : Distributed Asynchronous Hyperparameter Optimization better than HyperOpt. UltraOpt is a simple and efficient library to minimize expensive

98 Aug 16, 2022
Simple tools for logging and visualizing, loading and training

TNT TNT is a library providing powerful dataloading, logging and visualization utilities for Python. It is closely integrated with PyTorch and is desi

1.5k Jan 02, 2023
This is an easy python software which allows to sort images with faces by gender and after by age.

Gender-age Classifier This is an easy python software which allows to sort images with faces by gender and after by age. Usage First install Deepface

Claudio Ciccarone 6 Sep 17, 2022
Code for the submitted paper Surrogate-based cross-correlation for particle image velocimetry

Surrogate-based cross-correlation (SBCC) This repository contains code for the submitted paper Surrogate-based cross-correlation for particle image ve

5 Jun 30, 2022
TakeInfoatNistforICS - Take Information in NIST NVD for ICS

Take Information in NIST NVD for ICS This project developed with Python. When yo

5 Sep 05, 2022
TorchX is a library containing standard DSLs for authoring and running PyTorch related components for an E2E production ML pipeline.

TorchX is a library containing standard DSLs for authoring and running PyTorch related components for an E2E production ML pipeline

193 Dec 22, 2022
[CVPR 2021] Official PyTorch Implementation for "Iterative Filter Adaptive Network for Single Image Defocus Deblurring"

IFAN: Iterative Filter Adaptive Network for Single Image Defocus Deblurring Checkout for the demo (GUI/Google Colab)! The GUI version might occasional

Junyong Lee 173 Dec 30, 2022
Much faster than SORT(Simple Online and Realtime Tracking), a little worse than SORT

QSORT QSORT(Quick + Simple Online and Realtime Tracking) is a simple online and realtime tracking algorithm for 2D multiple object tracking in video s

Yonghye Kwon 8 Jul 27, 2022
A graph neural network (GNN) model to predict protein-protein interactions (PPI) with no sample features

A graph neural network (GNN) model to predict protein-protein interactions (PPI) with no sample features

2 Jul 25, 2022
An implementation for the ICCV 2021 paper Deep Permutation Equivariant Structure from Motion.

Deep Permutation Equivariant Structure from Motion Paper | Poster This repository contains an implementation for the ICCV 2021 paper Deep Permutation

72 Dec 27, 2022
Source code of NeurIPS 2021 Paper ''Be Confident! Towards Trustworthy Graph Neural Networks via Confidence Calibration''

CaGCN This repo is for source code of NeurIPS 2021 paper "Be Confident! Towards Trustworthy Graph Neural Networks via Confidence Calibration". Paper L

6 Dec 19, 2022
PyTorch Implementation of the SuRP algorithm by the authors of the AISTATS 2022 paper "An Information-Theoretic Justification for Model Pruning"

PyTorch Implementation of the SuRP algorithm by the authors of the AISTATS 2022 paper "An Information-Theoretic Justification for Model Pruning".

Berivan Isik 8 Dec 08, 2022