Implementation and Visualization of Decision Tree Classification and Regression Models

Overview

DecisionTree

Decision tree classification and regression models, plus visualization.

ID3

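The ID3 decision tree is the most basic decision tree classifier:

  • no pruning
  • supports discrete attributes only
  • chooses splits by the information gain criterion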
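The information gain criterion can be made concrete with a short sketch. This is not the repository's ID3 code; the function names are hypothetical. It computes entropy, the gain from a multiway split on one discrete feature (one branch per value), and picks the feature with the largest gain.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Shannon entropy (in bits) of a sequence of class labels."""
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())


def information_gain(feature_values, labels):
    """Entropy reduction from splitting on one discrete feature, one branch per value."""
    n = len(labels)
    groups = {}
    for v, y in zip(feature_values, labels):
        groups.setdefault(v, []).append(y)
    conditional = sum(len(g) / n * entropy(g) for g in groups.values())
    return entropy(labels) - conditional


def best_feature(X, y):
    """Index of the column with the largest information gain (ID3's split rule)."""
    return max(range(len(X[0])), key=lambda j: information_gain([row[j] for row in X], y))


# Column 0 separates the classes perfectly, column 1 carries no information.
X = [["a", "x"], ["a", "y"], ["b", "x"], ["b", "y"]]
y = [0, 0, 1, 1]
print(best_feature(X, y))  # 0
```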

data.py contains a small watermelon dataset for binary classification with discrete attributes. An ID3 classifier can be trained on it as follows:

from ID3 import ID3Classifier
from data import load_watermelon2
import numpy as np

X, y = load_watermelon2(return_X_y=True)  # arguments mimic the sklearn.datasets loaders
model = ID3Classifier()
model.fit(X, y)
pred = model.predict(X)
print(np.mean(pred == y))

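This prints 1.0, i.e. the tree classifies every training sample correctly, a basic sanity check that the generated tree is correct.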

C4.5

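The C4.5 decision tree classifier improves on ID3:

  • it selects the splitting feature with the information gain ratio heuristic;
  • it handles both discrete and continuous attributes, by discretizing the continuous ones;
  • it supports pruning;
  • it handles training data with missing attribute values.

We implemented the first two points, plus the pre-pruning part of the third (exposed as hyperparameters).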
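A minimal sketch of the two implemented C4.5 ingredients, assuming continuous attributes are bi-partitioned at midpoints between adjacent sorted values (the usual textbook approach; the repository's exact method is not shown). `gain_ratio` divides information gain by the attribute's own entropy (its split information), which penalizes attributes with many distinct values. All function names are hypothetical.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Shannon entropy (in bits) of a sequence of values."""
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())


def gain_ratio(feature_values, labels):
    """Information gain divided by split information (entropy of the feature itself)."""
    n = len(labels)
    groups = {}
    for v, y in zip(feature_values, labels):
        groups.setdefault(v, []).append(y)
    gain = entropy(labels) - sum(len(g) / n * entropy(g) for g in groups.values())
    split_info = entropy(feature_values)  # intrinsic value of the attribute
    return gain / split_info if split_info > 0 else 0.0


def best_threshold(values, labels):
    """Try midpoints between sorted distinct values of a continuous feature and
    return (threshold, information_gain) for the best binary split v <= t / v > t."""
    distinct = sorted(set(values))
    base, n = entropy(labels), len(labels)
    best = (None, -1.0)
    for lo, hi in zip(distinct, distinct[1:]):
        t = (lo + hi) / 2
        left = [y for v, y in zip(values, labels) if v <= t]
        right = [y for v, y in zip(values, labels) if v > t]
        gain = base - len(left) / n * entropy(left) - len(right) / n * entropy(right)
        if gain > best[1]:
            best = (t, gain)
    return best


# An ID-like feature (all values distinct) has the same gain as a clean two-valued
# feature, but its gain ratio is halved by the larger split information.
print(gain_ratio(["a", "a", "b", "b"], [0, 0, 1, 1]))  # 1.0
print(gain_ratio(["p", "q", "r", "s"], [0, 0, 1, 1]))  # 0.5
print(best_threshold([1.0, 2.0, 3.0, 4.0], [0, 0, 1, 1]))  # (2.5, 1.0)
```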

data.py also contains a watermelon dataset that mixes continuous and discrete features; we use it to test the C4.5 tree:

from C4_5 import C4_5Classifier
from data import load_watermelon3
import numpy as np

X, y = load_watermelon3(return_X_y=True)  # arguments mimic the sklearn.datasets loaders
model = C4_5Classifier()
model.fit(X, y)
pred = model.predict(X)
print(np.mean(pred == y))

This prints 1.0, so the generated tree fits the training data correctly.

CART

Classification

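Compared with C4.5, CART (Classification and Regression Tree) supports both classification and regression. The CART classification tree selects features by the Gini index; in addition, it splits every node into exactly two branches, even for discrete features, which helps mitigate overfitting.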
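A minimal sketch of the Gini criterion with a binary split on a discrete feature. It assumes a one-vs-rest partition ("value == a" vs. "value != a"), a common simplification; full CART can consider arbitrary subsets of values, and the repository's exact scheme is not shown. Function names are hypothetical.

```python
from collections import Counter


def gini(labels):
    """Gini impurity: 1 - sum_k p_k^2."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_binary_split(feature_values, labels):
    """Try every 'value == a' vs 'value != a' partition of a discrete feature and
    return (a, weighted_gini) with the lowest weighted Gini impurity of the children."""
    n = len(labels)
    best = (None, float("inf"))
    for a in sorted(set(feature_values)):
        left = [y for v, y in zip(feature_values, labels) if v == a]
        right = [y for v, y in zip(feature_values, labels) if v != a]
        if not right:  # feature has a single distinct value: nothing to split
            continue
        score = len(left) / n * gini(left) + len(right) / n * gini(right)
        if score < best[1]:
            best = (a, score)
    return best


# Separating "r" from {"g", "b"} yields two pure children.
print(best_binary_split(["r", "g", "b", "b"], [0, 1, 1, 1]))  # ('r', 0.0)
```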

Here we test it on the iris dataset from sklearn:

from CART import CARTClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

X, y = load_iris(return_X_y=True)
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.7)
model = CARTClassifier()
model.fit(train_X, train_y)
pred = model.predict(test_X)
print(accuracy_score(test_y, pred))

The accuracy is 95.55% (the train/test split is random, so the exact figure varies between runs).

Regression

The CARTRegressor class implements decision tree regression. Take sklearn's Boston housing dataset as an example:

from CART import CARTRegressor
from sklearn.datasets import load_boston  # removed in scikit-learn 1.2; requires an older version
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

X, y = load_boston(return_X_y=True)
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.7)
model = CARTRegressor()
model.fit(train_X, train_y)
pred = model.predict(test_X)
print(mean_squared_error(test_y, pred))

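This prints an MSE of 26.352171052631576. sklearn's decision tree regressor has a baseline of 22.46, so the performance is comparable, which suggests the implementation is correct.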
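CART regression trees usually pick the split that minimizes the summed squared error of the two children, with each leaf predicting the mean of its targets. A minimal sketch of that criterion for one continuous feature, assuming (not confirmed from the source) that CARTRegressor uses this standard rule; function names are hypothetical.

```python
def sse(values):
    """Sum of squared deviations from the mean: the cost of a leaf predicting the mean."""
    if not values:
        return 0.0
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values)


def best_regression_split(x, y):
    """Scan midpoints of continuous feature x and return (threshold, total_sse)
    that minimizes the children's summed squared error."""
    pairs = sorted(zip(x, y))
    best = (None, float("inf"))
    for i in range(1, len(pairs)):
        if pairs[i - 1][0] == pairs[i][0]:
            continue  # no threshold can separate equal feature values
        t = (pairs[i - 1][0] + pairs[i][0]) / 2
        left = [target for _, target in pairs[:i]]
        right = [target for _, target in pairs[i:]]
        cost = sse(left) + sse(right)
        if cost < best[1]:
            best = (t, cost)
    return best


print(best_regression_split([1, 2, 3, 4], [1.0, 1.0, 5.0, 5.0]))  # (2.5, 0.0)
```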

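Plotting decision trees

Classification trees

Using the third-party graphviz package for Python 3 together with Graphviz itself (which must be installed separately), the decision tree can be visualized: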

from plot import tree_plot
from CART import CARTClassifier
from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True)
model = CARTClassifier()
model.fit(X, y)
tree_plot(model)

Running this produces tree.png in the current folder:

iris_tree
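The internals of tree_plot are not shown here. As a hypothetical sketch of what such a function does under the hood, the code below serializes a nested-dict tree (a made-up node format, not the repository's) into Graphviz DOT text. The optional fontname sets the node and edge font, which is presumably what a custom font such as SimHei controls when labels contain Chinese characters.

```python
def tree_to_dot(tree, fontname=None):
    """Serialize a nested-dict decision tree into Graphviz DOT source.

    Hypothetical node format: an internal node is
    {"feature": str, "children": {edge_label: subtree, ...}}; a leaf is {"label": str}.
    Labels are inserted verbatim, so they must not contain double quotes.
    """
    lines = ["digraph Tree {"]
    if fontname:  # e.g. "SimHei" so that CJK labels render
        lines.append(f'  node [fontname="{fontname}"];')
        lines.append(f'  edge [fontname="{fontname}"];')
    counter = [0]  # next free node id

    def visit(node):
        node_id = counter[0]
        counter[0] += 1
        if "label" in node:  # leaf: class (or value) box
            lines.append(f'  n{node_id} [label="{node["label"]}", shape=box];')
            return node_id
        lines.append(f'  n{node_id} [label="{node["feature"]}"];')
        for edge_label, child in node["children"].items():
            child_id = visit(child)
            lines.append(f'  n{node_id} -> n{child_id} [label="{edge_label}"];')
        return node_id

    visit(tree)
    lines.append("}")
    return "\n".join(lines)


tree = {"feature": "petal length <= 2.45",
        "children": {"True": {"label": "setosa"}, "False": {"label": "other"}}}
print(tree_to_dot(tree))
```

Writing the returned text to tree.dot and running `dot -Tpng tree.dot -o tree.png` renders it with the Graphviz command-line tool.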

Supplying feature names and class names makes the tree easier to read:

from plot import tree_plot
from CART import CARTClassifier
from sklearn.datasets import load_iris

iris = load_iris()
model = CARTClassifier()
model.fit(iris.data, iris.target)
tree_plot(model,
          filename="tree2",
          feature_names=iris.feature_names,
          target_names=iris.target_names)

iris_tree2

Plotting the ID3 tree for watermelon dataset 2:

from plot import tree_plot
from ID3 import ID3Classifier
from data import load_watermelon2

watermelon = load_watermelon2()
model = ID3Classifier()
model.fit(watermelon.data, watermelon.target)
tree_plot(
    model,
    filename="tree",
    font="SimHei",
    feature_names=watermelon.feature_names,
    target_names=watermelon.target_names,
)

A custom font must be set here; otherwise Chinese characters are not displayed:

watermelon

Regression trees

Regression trees can be plotted the same way:

from plot import tree_plot
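from CART import CARTRegressor
from sklearn.datasets import load_boston  # removed in scikit-learn 1.2; requires an older version

boston = load_boston()
model = CARTRegressor(max_depth=5)  # a regression tree needs CARTRegressor, not ID3Classifier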
model.fit(boston.data, boston.target)
tree_plot(
    model,
    feature_names=boston.feature_names,
)

Because the full regression tree is very large, we limit the maximum depth before plotting:

regression
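Limiting max_depth is a form of pre-pruning: the tree stops growing before it becomes too large or overfits. A minimal sketch of the kind of stopping check a tree builder might run before splitting a node. The function name is hypothetical, and max_depth / min_samples_leaf mirror the hyperparameter names in the tuning example below rather than any confirmed internals.

```python
def should_stop(labels, depth, max_depth=None, min_samples_leaf=1):
    """Pre-pruning check run before splitting a node: True means 'make it a leaf'.

    Hypothetical helper; a real builder would also reject individual candidate
    splits that leave fewer than min_samples_leaf samples on either side.
    """
    if len(set(labels)) <= 1:  # node is already pure
        return True
    if max_depth is not None and depth >= max_depth:  # depth budget exhausted
        return True
    if len(labels) < 2 * min_samples_leaf:  # no split can satisfy the leaf size
        return True
    return False


print(should_stop([0, 1, 1, 0], depth=5, max_depth=5))  # True
print(should_stop([0, 1, 1, 0], depth=2, max_depth=5))  # False
```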

Hyperparameter tuning

CART and C4.5 both have hyperparameters. Because we make them subclasses of sklearn.base.BaseEstimator, sklearn's GridSearchCV can be used to tune them:

from plot import tree_plot
from CART import CARTClassifier
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split, GridSearchCV

wine = load_wine()
train_X, test_X, train_y, test_y = train_test_split(
    wine.data,
    wine.target,
    train_size=0.7,
)
model = CARTClassifier()
grid_param = {
    'max_depth': [2, 4, 6, 8, 10],
    'min_samples_leaf': [1, 3, 5, 7],
}

search = GridSearchCV(model, grid_param, n_jobs=4, verbose=5)
search.fit(train_X, train_y)
best_model = search.best_estimator_
print(search.best_params_, search.best_estimator_.score(test_X, test_y))
tree_plot(
    best_model,
    feature_names=wine.feature_names,
    target_names=wine.target_names,
)

This prints the best parameters and the best model's accuracy on the test set:

{'max_depth': 4, 'min_samples_leaf': 3} 0.8518518518518519

The corresponding decision tree:

wine
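Why subclassing BaseEstimator is enough: GridSearchCV clones the estimator and calls set_params on it, and BaseEstimator derives get_params/set_params from the `__init__` signature, provided each argument is stored unchanged under its own name. A minimal sketch with a hypothetical placeholder classifier (it ignores its hyperparameters and predicts the majority class). ClassifierMixin supplies the default score method (accuracy) that GridSearchCV uses when no scoring is given; it is listed before BaseEstimator, the order scikit-learn recommends for mixins.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.model_selection import GridSearchCV


class MajorityTree(ClassifierMixin, BaseEstimator):
    """Hypothetical placeholder: predicts the majority class, ignoring its hyperparameters."""

    def __init__(self, max_depth=None, min_samples_leaf=1):
        # Store arguments as-is under the same names; no validation or logic here.
        self.max_depth = max_depth
        self.min_samples_leaf = min_samples_leaf

    def fit(self, X, y):
        values, counts = np.unique(y, return_counts=True)
        self.classes_ = values
        self.majority_ = values[np.argmax(counts)]
        return self

    def predict(self, X):
        return np.full(len(X), self.majority_)


X = np.arange(20).reshape(10, 2)
y = np.array([0] * 6 + [1] * 4)
search = GridSearchCV(MajorityTree(), {"max_depth": [2, 4], "min_samples_leaf": [1, 3]}, cv=2)
search.fit(X, y)
print(search.best_params_, search.best_score_)
```

A real tree would use max_depth and min_samples_leaf inside fit; the parameter plumbing stays the same.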

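Pruning

REP (reduced-error pruning) was added to ID3 and to CART regression, and C4.5 supports PEP (pessimistic error pruning).

PEP pruning of a tree trained on the Iris dataset: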
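As a rough illustration of the two pruning rules (not the repository's code), here are the standard decision tests applied at an internal node. REP compares errors on a held-out validation set; PEP, in Quinlan's formulation, uses only training errors with a 0.5 continuity correction per leaf and prunes when a single leaf's corrected error is within one standard error of the subtree's. Function names and arguments are hypothetical.

```python
from math import sqrt


def rep_should_prune(subtree_val_errors, leaf_val_errors):
    """Reduced-error pruning: collapse the subtree into a leaf if that does not
    increase the number of misclassified validation samples."""
    return leaf_val_errors <= subtree_val_errors


def pep_should_prune(subtree_train_errors, n_leaves, leaf_train_errors, n_samples):
    """Pessimistic error pruning on training data only.

    subtree_train_errors: errors of the subtree's leaves on the n_samples reaching it
    leaf_train_errors:    errors if the node were a single majority-class leaf
    """
    e_subtree = subtree_train_errors + 0.5 * n_leaves  # corrected subtree error
    se = sqrt(e_subtree * (n_samples - e_subtree) / n_samples)  # its standard error
    e_leaf = leaf_train_errors + 0.5  # corrected error of the collapsed leaf
    return e_leaf <= e_subtree + se


# 4 leaves with no training errors on 40 samples vs. a leaf making 2 errors: prune.
print(pep_should_prune(0, 4, 2, 40))   # True
print(pep_should_prune(0, 2, 10, 40))  # False
```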

from C4_5 import C4_5Classifier
from plot import tree_plot
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()
model = C4_5Classifier()
X, y = iris.data, iris.target
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.7)
model.fit(train_X, train_y)
print(model.score(test_X, test_y))  # test accuracy before pruning
tree_plot(model,
          filename="src/pre_prune",
          feature_names=iris.feature_names,
          target_names=iris.target_names)
model.pep_pruning()  # prune the fitted tree
print(model.score(test_X, test_y))  # test accuracy after pruning
tree_plot(model,
          filename="src/post_prune",
          feature_names=iris.feature_names,
          target_names=iris.target_names)

The test accuracy before and after pruning is 97.78% and 100% respectively, i.e. pruning improved generalization on this split:

prepre

Owner
Welt Xing
Undergraduate in AI school, Nanjing University. Main interest (for now): machine learning and deep learning.