[ICML 2021] A fast algorithm for fitting robust decision trees.

Overview

GROOT: Growing Robust Trees

Growing Robust Trees (GROOT) is an algorithm that fits binary classification decision trees that are robust against user-specified adversarial examples. It closely resembles algorithms for fitting regular decision trees (e.g. CART), but changes the splitting criterion and the way samples propagate when a split is created.
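
As a rough illustration of what a robust splitting criterion can look like, the sketch below scores one candidate threshold by its worst-case weighted Gini impurity: every sample within eps of the threshold may be pushed to either side by the attacker, and the split is judged by the worst resulting partition. This is a brute-force sketch under that assumption, not the groot implementation; the names gini, worst_case_gini and eps are hypothetical.

```python
from itertools import product


def gini(n0, n1):
    """Gini impurity of a node holding n0 samples of class 0 and n1 of class 1."""
    total = n0 + n1
    if total == 0:
        return 0.0
    p = n0 / total
    return 2.0 * p * (1.0 - p)  # equals 1 - p^2 - (1 - p)^2


def worst_case_gini(values, labels, threshold, eps):
    """Worst-case weighted Gini of the split `value <= threshold`.

    Assumption: the attacker may shift each value by at most eps, so samples
    within eps of the threshold can end up on either side.
    """
    left, right, movable = [0, 0], [0, 0], [0, 0]
    for v, y in zip(values, labels):
        if v + eps <= threshold:
            left[y] += 1      # stays left under every perturbation
        elif v - eps > threshold:
            right[y] += 1     # stays right under every perturbation
        else:
            movable[y] += 1   # the attacker chooses the side

    n = len(values)
    worst = 0.0
    # Brute force: try every number of movable samples (per class) sent left.
    for m0, m1 in product(range(movable[0] + 1), range(movable[1] + 1)):
        l0, l1 = left[0] + m0, left[1] + m1
        r0, r1 = right[0] + movable[0] - m0, right[1] + movable[1] - m1
        score = ((l0 + l1) * gini(l0, l1) + (r0 + r1) * gini(r0, r1)) / n
        worst = max(worst, score)
    return worst


values = [0.0, 0.2, 0.8, 1.0]
labels = [0, 0, 1, 1]
print(worst_case_gini(values, labels, threshold=0.5, eps=0.0))  # 0.0
print(worst_case_gini(values, labels, threshold=0.5, eps=0.4))  # 0.5
```

Without an attacker the split is pure, while an attacker with eps=0.4 can move the two middle samples across the threshold and make both sides maximally impure. A robust learner would prefer thresholds for which this worst-case score stays low.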

This repository contains the module groot, which implements GROOT as a Scikit-learn compatible classifier, an adversary for model evaluation, and convenience functions for importing datasets. For documentation, see https://groot.cyber-analytics.nl

Simple example

To train and evaluate GROOT on a toy dataset against an attacker that can move samples by 0.5 in each direction, use the following code:

from groot.adversary import DecisionTreeAdversary
from groot.model import GrootTreeClassifier

from sklearn.datasets import make_moons

X, y = make_moons(noise=0.3, random_state=0)
X_test, y_test = make_moons(noise=0.3, random_state=1)

attack_model = [0.5, 0.5]
is_numerical = [True, True]
tree = GrootTreeClassifier(attack_model=attack_model, is_numerical=is_numerical, random_state=0)

tree.fit(X, y)
accuracy = tree.score(X_test, y_test)
adversarial_accuracy = DecisionTreeAdversary(tree, "groot").adversarial_accuracy(X_test, y_test)

print("Accuracy:", accuracy)
print("Adversarial Accuracy:", adversarial_accuracy)
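
To make the adversarial accuracy reported above more concrete, here is a minimal pure-Python sketch of the idea for a single decision tree: a test sample counts as correct only if every leaf it can reach, after each feature is shifted by at most the attacker's radius, predicts its true label. The Node class and the functions reachable_labels and adversarial_accuracy are hypothetical and written for illustration; DecisionTreeAdversary in groot may compute this differently.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    """Hypothetical minimal tree node; a leaf has feature=None and a label."""
    feature: Optional[int] = None
    threshold: float = 0.0
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    label: Optional[int] = None


def reachable_labels(node, x, eps):
    """Labels of all leaves reachable when x[i] may shift by at most eps[i].

    Assumes a well-formed tree, i.e. no path tests contradictory thresholds
    on the same feature.
    """
    if node.feature is None:
        return {node.label}
    labels = set()
    value, radius = x[node.feature], eps[node.feature]
    if value - radius <= node.threshold:   # some perturbation goes left
        labels |= reachable_labels(node.left, x, eps)
    if value + radius > node.threshold:    # some perturbation goes right
        labels |= reachable_labels(node.right, x, eps)
    return labels


def adversarial_accuracy(root, X, y, eps):
    """Fraction of samples whose reachable leaves all predict the true label."""
    robust = sum(reachable_labels(root, x, eps) == {label} for x, label in zip(X, y))
    return robust / len(X)


# A one-split tree: feature 0 <= 0.5 predicts class 0, otherwise class 1.
stump = Node(feature=0, threshold=0.5, left=Node(label=0), right=Node(label=1))
X_demo = [[0.0], [0.4], [0.6], [1.0]]
y_demo = [0, 0, 1, 1]
print(adversarial_accuracy(stump, X_demo, y_demo, eps=[0.0]))  # 1.0
print(adversarial_accuracy(stump, X_demo, y_demo, eps=[0.2]))  # 0.5
```

With eps=[0.2] the two samples closest to the threshold can be pushed across it, so only half of the samples remain correctly classified under attack.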

Installation

groot can be installed from PyPI: pip install groot-trees

To use Kantchelian's MILP attack you need GUROBI installed along with its Python package: python -m pip install -i https://pypi.gurobi.com gurobipy

Specific dependency versions

To reproduce our experiments with exact package versions you can clone the repository and run: pip install -r requirements.txt

We recommend using virtual environments.

Reproducing 'Efficient Training of Robust Decision Trees Against Adversarial Examples' (article)

To reproduce the results from the paper we provide generate_k_fold_results.py, a script that takes the trained models (in JSON format) and generates tables and figures. The resulting figures are written to /out/.

To retrain all models in addition to generating the results, we include the scripts train_kfold_models.py and fit_chen_xgboost.py. The first script runs the algorithms in parallel for each dataset and writes its output to /out/trees/ and /out/forests/. Warning: the script can take a long time to run (about a day given 16 cores). The second script trains only the Chen et al. boosting ensembles. /out/results.zip contains all results from when we ran the scripts.

To experiment on image datasets we provide the script image_experiments.py, which fits the models and outputs the results. In this script, one can set the dataset variable to 'mnist' or 'fmnist' to switch between the two.

The scripts summarize_datasets.py and visualize_threat_models.py output some figures we used in the text.

Implementation details

The TREANT implementation (groot.treant.py) is copied almost completely from the authors of TREANT at https://github.com/gtolomei/treant with small modifications to better interface with the experiments.

The heuristic by Chen et al. runs in the GROOT code, only with a different score function. This score function can be enabled by setting chen_heuristic=True on a GrootTreeClassifier before calling .fit(X, y).

The provably robust boosting implementation comes almost completely from the authors' code at https://github.com/max-andr/provably-robust-boosting and we use a small wrapper (groot.provably_robust_boosting.wrapper.py) around it. When we recorded the runtimes we turned off all parallel options in the @jit annotations in that code.

The implementation of Chen et al. boosting can be found in their own repo at https://github.com/chenhongge/RobustTrees from which the xgboost binary needs to be compiled and copied to the current directory. The script fit_chen_xgboost.py then calls this binary and uses its command line interface to fit all models.

Important note on TREANT

To encode L-infinity norms correctly we had to modify TREANT to NOT apply rules recursively. This means we added a single break statement in the treant.Attacker.__compute_attack() method. If you are planning on using TREANT with recursive attacker rules then you should remove this statement or use TREANT's unmodified code at https://github.com/gtolomei/treant .

Contact

For any questions or comments please create an issue or contact me directly.

Comments
  • Reproducing results from the article, issue with runtimes.csv

    Hello! I am trying to reproduce the results from the article, and I can't figure out a certain problem. I first tried to run train_kfold_models, but the code always outputs an error: "ImportError: cannot import name 'GrootTree' from 'groot.model'". Is there something wrong with the .py file I am trying to run, or is this a problem that doesn't occur for you and everyone else (i.e. something wrong with my computer, files or environment)?

    Onni Mansikkamäki

    opened by OnniMansikkamaki 3
  • is_numerical argument GrootTreeClassifier

    Running the example code on the make moons data in the README I get:

    Traceback (most recent call last):
      File "/home/.../groot_test.py", line 11, in <module>
        tree = GrootTreeClassifier(attack_model=attack_model, is_numerical=is_numerical, random_state=0)
    TypeError: __init__() got an unexpected keyword argument 'is_numerical'
    

    Leaving out the argument and having this line instead: tree = GrootTreeClassifier(attack_model=attack_model, random_state=0) results in this error:

    Traceback (most recent call last):
      File "/home/.../groot_test.py", line 15, in <module>
        adversarial_accuracy = DecisionTreeAdversary(tree, "groot").adversarial_accuracy(X_test, y_test)
      File "/home/.../venv/lib/python3.9/site-packages/groot/adversary.py", line 259, in __init__
        self.is_numeric = self.decision_tree.is_numerical
    AttributeError: 'GrootTreeClassifier' object has no attribute 'is_numerical'
    

    I'm guessing the code got an update, but the readme didn't. Or I made a stupid mistake, also very possible.

    opened by laudv 2
  • Reproducing result from paper

    Hello! I am trying to reproduce the results from the paper. I am struggling to find where these files are provided: generate_k_fold_results.py, train_kfold_models.py, fit_chen_xgboost.py, image_experiments.py, summarize_datasets.py and visualize_threat_models.py.

    Onni Mansikkamäki

    opened by OnniMansikkamaki 0
  • Regression decision trees and random forests

    This PR adds GROOT decision trees and random forests that use the adversarial sum of absolute errors to make splits. It also adds new tests, speeds them up and updates the documentation.

    opened by daniel-vos 0
  • Add regression, tests and refactor into base class

    This PR adds a regression GROOT tree based on the adversarial sum of absolute errors, more tests and refactors GROOT trees into a base class (BaseGrootTree) with subclasses GrootTreeClassifier and GrootTreeRegressor extending it.

    opened by daniel-vos 0
Releases(v0.0.1)
Owner
Cyber Analytics Lab
@ Delft University of Technology