决策树分类与回归模型的实现和可视化

Overview

DecisionTree

决策树分类与回归模型,以及可视化

ID3

ID3决策树是最朴素的决策树分类器:

  • 无剪枝
  • 只支持离散属性
  • 采用信息增益准则

data.py中,我们记录了一个小的西瓜数据集,用于离散属性的二分类任务。我们可以像下面这样训练一个ID3决策树分类器:

from ID3 import ID3Classifier
from data import load_watermelon2
import numpy as np

X, y = load_watermelon2(return_X_y=True) # 函数参数仿照sklearn.datasets
model = ID3Classifier()
model.fit(X, y)
pred = model.predict(X)
print(np.mean(pred == y))

输出1.0,说明我们生成的决策树是正确的。

C4.5

C4.5决策树分类器对ID3进行了改进:

  • 用信息增益率的启发式方法来选择划分特征;
  • 能够处理离散型和连续型的属性类型,即将连续型的属性进行离散化处理;
  • 剪枝;
  • 能够处理具有缺失属性值的训练数据;

我们实现了前两点,以及第三点中的预剪枝功能(超参数)

data.py中还有一个连续离散特征混合的西瓜数据集,我们用它来测试C4.5决策树的效果:

from C4_5 import C4_5Classifier
from data import load_watermelon3
import numpy as np

X, y = load_watermelon3(return_X_y=True) # 函数参数仿照sklearn.datasets
model = C4_5Classifier()
model.fit(X, y)
pred = model.predict(X)
print(np.mean(pred == y))

输出1.0,说明我们生成的决策树正确.

CART

分类

CART(Classification and Regression Tree)是C4.5决策树的扩展,支持分类和回归。CART分类树算法使用基尼系数选择特征,此外对于离散特征,CART决策树在每个节点二分划分,缓解了过拟合。

这里我们用sklearn中的鸢尾花数据集测试:

from CART import CARTClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

X, y = load_iris(return_X_y=True)
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.7)
model = CARTClassifier()
model.fit(train_X, train_y)
pred = model.predict(test_X)
print(accuracy_score(test_y, pred))

准确率95.55%。

回归

CARTRegressor类实现了决策树回归,以sklearn的波士顿数据集为例:

from CART import CARTRegressor
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

X, y = load_boston(return_X_y=True)
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.7)
model = CARTRegressor()
model.fit(train_X, train_y)
pred = model.predict(test_X)
print(mean_squared_error(test_y, pred))

输出26.352171052631576,sklearn决策树回归的Baseline是22.46,性能近似,说明我们的实现正确。

决策树绘制

分类树

利用python3的graphviz第三方库和Graphviz(需要安装),我们可以将决策树可视化:

from plot import tree_plot
from CART import CARTClassifier
from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True)
model = CARTClassifier()
model.fit(X, y)
tree_plot(model)

运行,文件夹中生成tree.png

iris_tree

如果提供了特征的名词和标签的名称,决策树会更明显:

from plot import tree_plot
from CART import CARTClassifier
from sklearn.datasets import load_iris

iris = load_iris()
model = CARTClassifier()
model.fit(iris.data, iris.target)
tree_plot(model,
          filename="tree2",
          feature_names=iris.feature_names,
          target_names=iris.target_names)

iris_tree2

绘制西瓜数据集2对应的ID3决策树:

from plot import tree_plot
from ID3 import ID3Classifier
from data import load_watermelon2

watermelon = load_watermelon2()
model = ID3Classifier()
model.fit(watermelon.data, watermelon.target)
tree_plot(
    model,
    filename="tree",
    font="SimHei",
    feature_names=watermelon.feature_names,
    target_names=watermelon.target_names,
)

这里要自定义字体,否则无法显示中文:

watermelon

回归树

用同样的方法,我们可以进行回归树的绘制:

from plot import tree_plot
from ID3 import ID3Classifier
from sklearn.datasets import load_boston

boston = load_boston()
model = ID3Classifier(max_depth=5)
model.fit(boston.data, boston.target)
tree_plot(
    model,
    feature_names=boston.feature_names,
)

由于生成的回归树很大,我们限制最大深度再绘制:

regression

调参

CART和C4.5都是有超参数的,我们让它们作为sklearn.base.BaseEstimator的派生类,借助sklearn的GridSearchCV,就可以实现调参:

from plot import tree_plot
from CART import CARTClassifier
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split, GridSearchCV

wine = load_wine()
train_X, test_X, train_y, test_y = train_test_split(
    wine.data,
    wine.target,
    train_size=0.7,
)
model = CARTClassifier()
grid_param = {
    'max_depth': [2, 4, 6, 8, 10],
    'min_samples_leaf': [1, 3, 5, 7],
}

search = GridSearchCV(model, grid_param, n_jobs=4, verbose=5)
search.fit(train_X, train_y)
best_model = search.best_estimator_
print(search.best_params_, search.best_estimator_.score(test_X, test_y))
tree_plot(
    best_model,
    feature_names=wine.feature_names,
    target_names=wine.target_names,
)

输出最优参数和最优模型在测试集上的表现:

{'max_depth': 4, 'min_samples_leaf': 3} 0.8518518518518519

绘制对应的决策树:

wine

剪枝

在ID3和CART回归中加入了REP剪枝,C4.5则支持了PEP剪枝。

对IRIS数据集训练后的决策树进行PEP剪枝:

iris = load_iris()
model = C4_5Classifier()
X, y = iris.data, iris.target
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.7)
model.fit(train_X, train_y)
print(model.score(test_X, test_y))
tree_plot(model,
          filename="src/pre_prune",
          feature_names=iris.feature_names,
          target_names=iris.target_names)
model.pep_pruning()
print(model.score(test_X, test_y))
tree_plot(model,
          filename="src/post_prune",
          feature_names=iris.feature_names,
          target_names=iris.target_names,
)

剪枝前后的准确率分别为97.78%,100%,即泛化性能的提升:

prepre

Owner
Welt Xing
Undergraduate in AI school, Nanjing University. Main interest(for now): Machine learning and deep learning.
Welt Xing
Skoot is a lightweight python library of machine learning transformer classes that interact with scikit-learn and pandas.

Skoot is a lightweight python library of machine learning transformer classes that interact with scikit-learn and pandas. Its objective is to ex

Taylor G Smith 54 Aug 20, 2022
Little Ball of Fur - A graph sampling extension library for NetworKit and NetworkX (CIKM 2020)

Little Ball of Fur is a graph sampling extension library for Python. Please look at the Documentation, relevant Paper, Promo video and External Resour

Benedek Rozemberczki 619 Dec 14, 2022
Machine-Learning with python (jupyter)

Machine-Learning with python (jupyter) 머신러닝 야학 작심 10일과 쥬피터 노트북 기반 데이터 사이언스 시작 들어가기전 https://nbviewer.org/ 페이지를 통해서 쥬피터 노트북 내용을 볼 수 있다. 위 페이지에서 현재 레포 기

HyeonWoo Jeong 1 Jan 23, 2022
Kalman filter library

The kalman filter framework described here is an incredibly powerful tool for any optimization problem, but particularly for visual odometry, sensor fusion localization or SLAM.

comma.ai 276 Jan 01, 2023
Traingenerator 🧙 A web app to generate template code for machine learning ✨

Traingenerator 🧙 A web app to generate template code for machine learning ✨ 🎉 Traingenerator is now live! 🎉

Johannes Rieke 1.2k Jan 07, 2023
List of Data Science Cheatsheets to rule the world

Data Science Cheatsheets List of Data Science Cheatsheets to rule the world. Table of Contents Business Science Business Science Problem Framework Dat

Favio André Vázquez 11.7k Dec 30, 2022
(3D): LeGO-LOAM, LIO-SAM, and LVI-SAM installation and application

SLAM-application: installation and test (3D): LeGO-LOAM, LIO-SAM, and LVI-SAM Tested on Quadruped robot in Gazebo ● Results: video, video2 Requirement

EungChang-Mason-Lee 203 Dec 26, 2022
Python package for stacking (machine learning technique)

vecstack Python package for stacking (stacked generalization) featuring lightweight functional API and fully compatible scikit-learn API Convenient wa

Igor Ivanov 671 Dec 25, 2022
LILLIE: Information Extraction and Database Integration Using Linguistics and Learning-Based Algorithms

LILLIE: Information Extraction and Database Integration Using Linguistics and Learning-Based Algorithms Based on the work by Smith et al. (2021) Query

5 Aug 06, 2022
Time series forecasting with PyTorch

Our article on Towards Data Science introduces the package and provides background information. Pytorch Forecasting aims to ease state-of-the-art time

Jan Beitner 2.5k Jan 02, 2023
🌲 Implementation of the Robust Random Cut Forest algorithm for anomaly detection on streams

🌲 Implementation of the Robust Random Cut Forest algorithm for anomaly detection on streams

Real-time water systems lab 416 Jan 06, 2023
PLUR is a collection of source code datasets suitable for graph-based machine learning.

PLUR (Programming-Language Understanding and Repair) is a collection of source code datasets suitable for graph-based machine learning. We provide scripts for downloading, processing, and loading the

Google Research 76 Nov 25, 2022
MLOps pipeline project using Amazon SageMaker Pipelines

This project shows steps to build an end to end MLOps architecture that covers data prep, model training, realtime and batch inference, build model registry, track lineage of artifacts and model drif

AWS Samples 3 Sep 16, 2022
MLBox is a powerful Automated Machine Learning python library.

MLBox is a powerful Automated Machine Learning python library. It provides the following features: Fast reading and distributed data preprocessing/cle

Axel 1.4k Jan 06, 2023
A simple guide to MLOps through ZenML and its various integrations.

ZenBytes Join our Slack Community and become part of the ZenML family Give the main ZenML repo a GitHub star to show your love ZenBytes is a series of

ZenML 127 Dec 27, 2022
ParaMonte is a serial/parallel library of Monte Carlo routines for sampling mathematical objective functions of arbitrary-dimensions

ParaMonte is a serial/parallel library of Monte Carlo routines for sampling mathematical objective functions of arbitrary-dimensions, in particular, the posterior distributions of Bayesian models in

Computational Data Science Lab 182 Dec 31, 2022
Machine Learning Course with Python:

A Machine Learning Course with Python Table of Contents Download Free Deep Learning Resource Guide Slack Group Introduction Motivation Machine Learnin

Instill AI 6.9k Jan 03, 2023
Upgini : data search library for your machine learning pipelines

Automated data search library for your machine learning pipelines → find & deliver relevant external data & features to boost ML accuracy :chart_with_upwards_trend:

Upgini 175 Jan 08, 2023
Neural Machine Translation (NMT) tutorial with OpenNMT-py

Neural Machine Translation (NMT) tutorial with OpenNMT-py. Data preprocessing, model training, evaluation, and deployment.

Yasmin Moslem 29 Jan 09, 2023
PySpark ML Bank Churn Prediction

PySpark-Bank-Churn Surname: corresponds to the record (row) number and has no effect on the output. CreditScore: contains random values and has no eff

kemalgunay 2 Nov 11, 2021