A multi-entity Transformer for multi-agent spatiotemporal modeling.

Overview

baller2vec

This is the repository for the paper:

Michael A. Alcorn and Anh Nguyen. baller2vec: A Multi-Entity Transformer For Multi-Agent Spatiotemporal Modeling. arXiv. 2021.

Left: the input for baller2vec at each time step t is an unordered set of feature vectors containing information about the identities and locations of NBA players on the court. Right: baller2vec generalizes the standard Transformer to the multi-entity setting by employing a novel self-attention mask tensor. The mask is then reshaped into a matrix for compatibility with typical Transformer implementations.
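The mask idea can be illustrated with a minimal NumPy sketch. The README does not spell out the exact rule, so this assumes that the entity at time step t may attend to every entity at the same or earlier time steps, and that (time, entity) pairs are flattened in time-major order. The function name `multi_entity_mask` is hypothetical and not part of the repository.

```python
import numpy as np


def multi_entity_mask(n_steps: int, n_entities: int) -> np.ndarray:
    """Build an additive self-attention mask for n_steps x n_entities tokens.

    Assumption: token (t1, k1) may attend to any token (t2, k2) with t2 <= t1,
    i.e. the mask is causal in time but unrestricted across entities.
    """
    t = np.arange(n_steps)
    # Shape (T, 1, T, 1): True where the key time step t2 <= query time step t1.
    allowed = t[:, None, None, None] >= t[None, None, :, None]
    # Expand to the full 4D mask tensor of shape (T, K, T, K).
    allowed = np.broadcast_to(allowed, (n_steps, n_entities, n_steps, n_entities))
    # 0.0 keeps an attention score; -inf removes it before the softmax.
    mask = np.where(allowed, 0.0, -np.inf)
    # Reshape the tensor into the (T*K, T*K) matrix that standard
    # Transformer implementations expect (time-major token order).
    return mask.reshape(n_steps * n_entities, n_steps * n_entities)


mask = multi_entity_mask(n_steps=3, n_entities=2)
print(mask.shape)  # (6, 6)
```

Row i of the matrix corresponds to time step i // n_entities, so all entities at one time step can see each other while future steps are hidden. A float matrix like this should be usable (after conversion to a float32 tensor) as the `mask` argument of PyTorch's `nn.TransformerEncoder`.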
By exclusively learning to predict the trajectory of the ball, baller2vec was able to infer idiosyncratic player attributes.
Further, nearest neighbors in baller2vec's embedding space are plausible doppelgängers. Credit for the images: Erik Drost, Keith Allison, Jose Garcia, Keith Allison, Verse Photography, and Joe Glorioso.
Additionally, several attention heads in baller2vec appear to perform different basketball-relevant functions, such as anticipating passes. Code to generate the GIF was adapted from @linouk23's NBA-Player-Movements repository.
Here, a baller2vec model trained to simultaneously predict the trajectories of all the players on the court uses both the historical and current context to forecast the target player's trajectory at each time step. The left grid shows the target player's true trajectory at each time step while the right grid shows baller2vec's forecast distribution. The blue-bordered center cell is the "stationary" trajectory.
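The forecast grid treats trajectory prediction as classification over discretized displacements. The sketch below shows one plausible way to map a per-step displacement (dx, dy) to a cell of an n x n grid, assuming uniform bins over [-max_move, max_move] on each axis. The defaults mirror the `player_traj_n: 11` and `max_player_move: 4.5` options in the training script, but the binning scheme and the `displacement_to_class` helper are assumptions, not the repository's code.

```python
import math


def displacement_to_class(dx: float, dy: float, n_bins: int = 11, max_move: float = 4.5) -> int:
    """Map a displacement to a flat grid-cell index in [0, n_bins**2).

    Assumption: each axis is split into n_bins equal-width bins spanning
    [-max_move, max_move]; larger moves are clipped into the edge bins.
    """
    width = 2 * max_move / n_bins

    def to_bin(d: float) -> int:
        b = math.floor((d + max_move) / width)
        return min(max(b, 0), n_bins - 1)  # clip to the grid

    # Row-major flattening: row from dy, column from dx.
    return to_bin(dy) * n_bins + to_bin(dx)


center = displacement_to_class(0.0, 0.0)
print(center)  # 60: row 5, column 5 of an 11 x 11 grid
```

With an odd number of bins there is exactly one center cell containing zero displacement, which is consistent with the "stationary" center cell in the forecast grid.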

Citation

If you use this code for your own research, please cite:

@article{alcorn2021baller2vec,
   title={baller2vec: A Multi-Entity Transformer For Multi-Agent Spatiotemporal Modeling},
   author={Alcorn, Michael A. and Nguyen, Anh},
   journal={arXiv preprint arXiv:2102.03291},
   year={2021}
}

Training baller2vec

Setting up .basketball_profile

After you've cloned the repository to your desired location, create a file called .basketball_profile in your home directory:

nano ~/.basketball_profile

and paste in the contents of the repository's .basketball_profile, replacing each variable's value with a path relevant to your environment. Next, add the following line to the end of your ~/.bashrc:

source ~/.basketball_profile

and either log out and log back in again or run:

source ~/.bashrc

You should now be able to copy and paste all of the commands in the instruction sections below. For example:

echo ${PROJECT_DIR}

should print the path you set for PROJECT_DIR in .basketball_profile.
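If a later command fails because a path is empty, a quick check like the following can help. This is a hypothetical helper, not part of the repository. It only looks for the variables that the commands in this README use (`PROJECT_DIR`, `DATA_DIR`, `GAMES_DIR`, `EXPERIMENTS_DIR`); the real .basketball_profile may define others.

```python
import os
from typing import List, Mapping, Optional

# Variables referenced by the commands in this README (the actual
# .basketball_profile may define more).
REQUIRED_VARS = ("PROJECT_DIR", "DATA_DIR", "GAMES_DIR", "EXPERIMENTS_DIR")


def missing_profile_vars(env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return the required variables that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if not env.get(name)]


if __name__ == "__main__":
    missing = missing_profile_vars()
    if missing:
        print("Not set (did you source ~/.basketball_profile?):", ", ".join(missing))
    else:
        print("All profile variables are set.")
```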

Installing the necessary Python packages

cd ${PROJECT_DIR}
pip3 install --upgrade -r requirements.txt

Organizing the play-by-play and tracking data

  1. Copy events.zip (which I acquired from here [mirror here] using https://downgit.github.io) to the DATA_DIR directory and unzip it:
mkdir -p ${DATA_DIR}
cp ${PROJECT_DIR}/events.zip ${DATA_DIR}
cd ${DATA_DIR}
unzip -q events.zip
rm events.zip

Descriptions for the various EVENTMSGTYPEs can be found here (mirror here).
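To get a feel for the play-by-play data, the sketch below tallies event types in one game's play-by-play file. It assumes each unzipped file is a CSV with an `EVENTMSGTYPE` column, and the code-to-name table reflects the commonly documented NBA stats mapping; verify both against the linked descriptions. `count_event_types` is a hypothetical helper.

```python
import csv
from collections import Counter
from typing import TextIO

# Commonly documented NBA stats EVENTMSGTYPE codes; confirm against the
# linked descriptions before relying on them.
EVENTMSGTYPE_NAMES = {
    1: "made field goal",
    2: "missed field goal",
    3: "free throw",
    4: "rebound",
    5: "turnover",
    6: "foul",
    7: "violation",
    8: "substitution",
    9: "timeout",
    10: "jump ball",
    11: "ejection",
    12: "start of period",
    13: "end of period",
}


def count_event_types(csv_file: TextIO) -> Counter:
    """Tally event types in a play-by-play CSV (assumes an EVENTMSGTYPE column)."""
    counts: Counter = Counter()
    for row in csv.DictReader(csv_file):
        code = int(row["EVENTMSGTYPE"])
        counts[EVENTMSGTYPE_NAMES.get(code, f"other ({code})")] += 1
    return counts
```

For a real file, open it with `open(path, newline="")` and pass the file object in; the exact file names inside events.zip are not documented here.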

  2. Clone the tracking data from here (mirror here) to the DATA_DIR directory:
cd ${DATA_DIR}
git clone git@github.com:linouk23/NBA-Player-Movements.git

A description of the tracking data can be found here.
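The tracking files are commonly described as JSON documents (possibly compressed, so they may need extracting first) whose events each hold a list of "moments". The sketch below unpacks a single moment, assuming the widely used SportVU layout `[period, timestamp_ms, game_clock, shot_clock, unused, positions]`, where each position is `[team_id, player_id, x, y, z]` and the ball is the entry with `team_id == -1`. Treat that layout, and the hypothetical `parse_moment` helper, as assumptions to check against the linked description.

```python
from typing import Any, Dict, List, Optional, Tuple


def parse_moment(moment: List[Any]) -> Dict[str, Any]:
    """Split one SportVU-style moment into ball and player positions (assumed layout)."""
    period, _timestamp_ms, game_clock, shot_clock, _unused, positions = moment
    ball: Optional[Tuple[float, float, float]] = None
    players: List[Tuple[int, int, float, float]] = []
    for team_id, player_id, x, y, z in positions:
        if team_id == -1:  # the ball is assumed to use team_id -1
            ball = (x, y, z)
        else:  # z is dropped for players
            players.append((team_id, player_id, x, y))
    return {
        "period": period,
        "game_clock": game_clock,
        "shot_clock": shot_clock,
        "ball": ball,
        "players": players,
    }


# Synthetic moment with made-up values, for illustration only.
example = [3, 0, 115.0, 14.2, None,
           [[-1, -1, 47.0, 25.0, 6.5], [1, 101, 40.0, 20.0, 0.0], [2, 202, 45.0, 30.0, 0.0]]]
parsed = parse_moment(example)
print(parsed["ball"], len(parsed["players"]))  # (47.0, 25.0, 6.5) 2
```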

Generating the training data

cd ${PROJECT_DIR}
nohup python3 generate_game_numpy_arrays.py > data.log &

You can monitor its progress with:

top

or:

ls -U ${GAMES_DIR} | wc -l

There should be 1,262 NumPy arrays (corresponding to 631 X/y pairs) when finished.
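As an alternative to counting files with `ls`, the sketch below checks that every game has both arrays, assuming the `${gameid}_X.npy` / `${gameid}_y.npy` naming used in the `scp` command of the animation instructions. `count_game_pairs` is a hypothetical helper; a finished run should report 631 pairs and no unpaired games.

```python
import os
from pathlib import Path
from typing import List, Tuple


def count_game_pairs(games_dir: str) -> Tuple[int, List[str]]:
    """Return (complete X/y pairs, sorted game IDs that are missing one half)."""
    root = Path(games_dir)
    # Strip the suffix to recover the game ID from e.g. "0021500622_X.npy".
    x_ids = {p.name[: -len("_X.npy")] for p in root.glob("*_X.npy")}
    y_ids = {p.name[: -len("_y.npy")] for p in root.glob("*_y.npy")}
    # Symmetric difference = IDs that have only one of the two arrays.
    return len(x_ids & y_ids), sorted(x_ids ^ y_ids)


games_dir = os.environ.get("GAMES_DIR")
if games_dir:  # only runs when .basketball_profile has been sourced
    n_pairs, unpaired = count_game_pairs(games_dir)
    print(f"{n_pairs} complete pairs; unpaired: {unpaired or 'none'}")
```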

Animating a sequence

  1. If you don't have a display hooked up to your GPU server, you'll need to first clone the repository to your local machine and retrieve certain files from the remote server:
# From your local machine.
mkdir -p ~/scratch
cd ~/scratch

username=michael
server=gpu3.cse.eng.auburn.edu
data_dir=/home/michael/baller2vec_data
scp ${username}@${server}:${data_dir}/baller2vec_config.pydict .

games_dir=${data_dir}/games
gameid=0021500622

scp ${username}@${server}:${games_dir}/\{${gameid}_X.npy,${gameid}_y.npy\} .
  2. You can then run this code in the Python interpreter from within the repository (make sure you source .basketball_profile first if running locally):
import os

from animator import Game
from settings import DATA_DIR, GAMES_DIR

gameid = "0021500622"
try:
    game = Game(DATA_DIR, GAMES_DIR, gameid)
except FileNotFoundError:
    home_dir = os.path.expanduser("~")
    DATA_DIR = f"{home_dir}/scratch"
    GAMES_DIR = f"{home_dir}/scratch"
    game = Game(DATA_DIR, GAMES_DIR, gameid)

# https://youtu.be/FRrh_WkyXko?t=109
start_period = 3
start_time = "1:55"
stop_period = 3
stop_time = "1:51"
game.show_seq(start_period, start_time, stop_period, stop_time)

to generate an animation of the sequence.
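`show_seq` identifies a clip by period and game clock, and because the game clock counts down, the stop time reads earlier on the clock than the start time. The hypothetical helpers below turn such a pair into a clip length in seconds, assuming 12-minute (720-second) regulation periods and ignoring overtime. They can be handy for sanity-checking a clip before animating it.

```python
def clock_to_seconds(clock: str) -> float:
    """Convert a "M:SS" game-clock string to seconds remaining in the period."""
    minutes, seconds = clock.split(":")
    return int(minutes) * 60 + float(seconds)


def elapsed_game_seconds(period: int, clock: str, period_secs: int = 720) -> float:
    """Seconds since tip-off, assuming equal-length regulation periods (no overtime)."""
    return (period - 1) * period_secs + (period_secs - clock_to_seconds(clock))


def clip_length(start_period: int, start_time: str, stop_period: int, stop_time: str) -> float:
    """Length in seconds of the clip passed to show_seq (hypothetical helper)."""
    return elapsed_game_seconds(stop_period, stop_time) - elapsed_game_seconds(start_period, start_time)


print(clip_length(3, "1:55", 3, "1:51"))  # 4.0
```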

Running the training script

Run (or copy and paste) the following script, editing the variables as appropriate.

#!/usr/bin/env bash

# Experiment identifier. Output will be saved to ${EXPERIMENTS_DIR}/${JOB}.
JOB=$(date +%Y%m%d%H%M%S)

# Training options.
echo "train:" >> ${JOB}.yaml
task=ball_traj  # ball_traj, ball_loc, event, player_traj, score, or seq2seq.
echo "  task: ${task}" >> ${JOB}.yaml
echo "  min_playing_time: 0" >> ${JOB}.yaml  # 0/13314/39917/1.0e+6 --> 100%/75%/50%/0%.
echo "  train_valid_prop: 0.95" >> ${JOB}.yaml
echo "  train_prop: 0.95" >> ${JOB}.yaml
echo "  train_samples_per_epoch: 20000" >> ${JOB}.yaml
echo "  valid_samples: 1000" >> ${JOB}.yaml
echo "  workers: 10" >> ${JOB}.yaml
echo "  learning_rate: 1.0e-5" >> ${JOB}.yaml
if [[ ("$task" = "event") || ("$task" = "score") ]]
then
    prev_model=False
    echo "  prev_model: ${prev_model}" >> ${JOB}.yaml
    if [[ "$prev_model" != "False" ]]
    then
        echo "  patience: 5" >> ${JOB}.yaml
    fi
fi

# Dataset options.
echo "dataset:" >> ${JOB}.yaml
echo "  hz: 5" >> ${JOB}.yaml
echo "  secs: 4" >> ${JOB}.yaml
echo "  player_traj_n: 11" >> ${JOB}.yaml
echo "  max_player_move: 4.5" >> ${JOB}.yaml
echo "  ball_traj_n: 19" >> ${JOB}.yaml
echo "  max_ball_move: 8.5" >> ${JOB}.yaml
echo "  n_players: 10" >> ${JOB}.yaml
echo "  next_score_change_time_max: 35" >> ${JOB}.yaml
echo "  n_time_to_next_score_change: 36" >> ${JOB}.yaml
echo "  n_ball_loc_x: 95" >> ${JOB}.yaml
echo "  n_ball_loc_y: 51" >> ${JOB}.yaml
echo "  ball_future_secs: 2" >> ${JOB}.yaml

# Model options.
echo "model:" >> ${JOB}.yaml
echo "  embedding_dim: 20" >> ${JOB}.yaml
echo "  sigmoid: none" >> ${JOB}.yaml
echo "  mlp_layers: [128, 256, 512]" >> ${JOB}.yaml
echo "  nhead: 8" >> ${JOB}.yaml
echo "  dim_feedforward: 2048" >> ${JOB}.yaml
echo "  num_layers: 6" >> ${JOB}.yaml
echo "  dropout: 0.0" >> ${JOB}.yaml
if [[ "$task" != "seq2seq" ]]
then
    echo "  use_cls: False" >> ${JOB}.yaml
    echo "  embed_before_mlp: True" >> ${JOB}.yaml
fi

# Save experiment settings.
mkdir -p ${EXPERIMENTS_DIR}/${JOB}
mv ${JOB}.yaml ${EXPERIMENTS_DIR}/${JOB}/

# Start training the model.
gpu=0
cd ${PROJECT_DIR}
nohup python3 train_baller2vec.py ${JOB} ${gpu} > ${EXPERIMENTS_DIR}/${JOB}/train.log &
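Several dataset options only make sense in combination. The sketch below computes a few derived quantities from the values in the script above. The interpretations are assumptions, not documented behavior: that a sample spans `hz * secs` time steps, that each player at each time step is one input token, that trajectory targets are `n x n` grids (consistent with the grid-shaped forecast distribution described in the overview), and that ball-location targets form an `n_ball_loc_x x n_ball_loc_y` grid. `derived_sizes` is a hypothetical helper.

```python
from typing import Dict

# Dataset options copied from the training script above.
DATASET_CFG = {
    "hz": 5,
    "secs": 4,
    "player_traj_n": 11,
    "ball_traj_n": 19,
    "n_players": 10,
    "n_ball_loc_x": 95,
    "n_ball_loc_y": 51,
    "ball_future_secs": 2,
}


def derived_sizes(cfg: Dict[str, float]) -> Dict[str, int]:
    """Compute derived sizes under the assumptions stated above."""
    seq_len = int(cfg["hz"] * cfg["secs"])  # time steps per training sample
    return {
        "seq_len": seq_len,
        "tokens": seq_len * int(cfg["n_players"]),  # one token per player per step
        "player_traj_classes": int(cfg["player_traj_n"]) ** 2,
        "ball_traj_classes": int(cfg["ball_traj_n"]) ** 2,
        "ball_loc_classes": int(cfg["n_ball_loc_x"] * cfg["n_ball_loc_y"]),
        "ball_future_steps": int(cfg["hz"] * cfg["ball_future_secs"]),
    }


for name, value in derived_sizes(DATASET_CFG).items():
    print(f"{name}: {value}")
```

Under these assumptions the default settings give 20 time steps and 200 tokens per sample, so a flattened multi-entity attention mask would be 200 x 200.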