2.2.3 CART分类决策树的编程实践

针对2.2.2节的天气与是否打网球的数据集(PlayTennis数据集),我们利用Python和PyTorch编码展示CART分类树模型的细节。

2.2.3.1 整体流程

首先介绍整体流程,如代码段2.1所示。程序主要由四部分组成:数据集加载、模型训练、模型预测和决策树可视化。

代码段2.1 CART分类树测试主程序(源码位于Chapter02/test_CartClassifier.py)

1. 数据集加载(第14~23行)

首先,第17行使用“with open as”语法打开指定数据文件,并获取文件句柄f。该语法会在执行完毕“with open”作用域内的代码后自动关闭数据文件。open函数的第一个参数为数据文件的路径;第二个参数为文件打开方式,“r”代表以只读方式打开;encoding参数指定读取文件的编码格式,“gbk”代表使用GBK编码。

然后,第18行使用csv库读取数据文件内容并将其转换成Python内置的list类型。使用csv库需要在代码片段开头添加“import csv”语句。读取数据时使用csv库的reader函数,参数传入“with open”获取的文件句柄f。

最后,第19~23行利用Python切片和列表生成式将原始数据集分割成属性名列表feature_names、目标变量名y_name、属性集X、目标变量集y。由于本例是解决数据集比较小的分类问题,因此令训练集(X_train和y_train)和测试集(X_test和y_test)使用相同的数据集(X和y),而通常的做法是将原始数据集按8:1:1或6:2:2的比例划分成训练集、测试集和验证集。另外,为了便于使用PyTorch进行GPU加速,我们将数据集从list类型进一步转换为numpy类型(PyTorch内部集成了numpy与Tensor的快速转换方法)。同样,使用numpy库也需要导入相应的包,语句为“import numpy as np”。

2. 决策树模型的训练和生成(第25~32行)

CART分类树的创建和训练过程被封装成CartClassifier类,通过使用“from cart import CartClassifier”将其导入当前环境。

在第27行创建决策树的过程中,触发CartClassifier类的构造函数。在这里,设置use_gpu参数为True,代表启用GPU加速。此外,该构造函数还可以传入其他参数,后面的内容将对其进行展开介绍。

在第31行CART分类树的训练过程中,调用CartClassifier类的train成员函数,传入训练集数据X_train、y_train和feature_names。训练完成后,返回模型数据如下:

该model变量实际上是由一组规则表示的。以上输出结果为决策树的字典(树形结构)数据结构形式,在这棵model树中,从根节点到每个叶子节点的每条路径都代表一条规则。为了更清晰地表示规则,我们可以将以上数据结构转换成“if-then”的格式,如下所示:

事实上,CART分类决策树与一组“if-then”规则是等价的。

3. 决策树模型的使用(第34~41行)

在第36~38行的模型预测阶段,调用CartClassifier类的成员函数predict,传入测试集数据X_test,返回numpy.array类型的预测结果y_pred,并且打印输出测试集的真实值y_test和预测值y_pred。

在第39~41行的模型评估阶段,首先使用Python的列表生成式生成测试集中预测值与真实值相等的元素,每种相等的情况用int型变量1表示。然后使用numpy的sum函数对上述列表求和,统计出预测正确的计数。最后打印预测正确的样本计数、总样本计数以及预测准确率。实际执行结果如下:

4. 决策树的可视化(第43~45行)

决策树可视化阶段使用了tree_plotter包的tree_plot函数,传入前面训练好的模型model。tree_plotter包是我们使用Matplotlib自定义的决策树绘图包,在随后的内容中我们将详细介绍,在此先展示一下可视化的效果,如图2.9所示。

2.2.3.2 训练和创建过程

首先介绍用到的构造函数,见代码段2.2。从第8~9行可以看到,CartClassifier类的实现依赖于torch和numpy。

在CartClassifier类的构造函数__init__中,需要提供use_gpu和min_samples_split两个参数。其中,use_gpu是一个布尔值,代表该类是否启用GPU加速,默认为False,代表使用CPU;min_samples_split是一个整型数,代表决策树分裂完成后叶子节点的最少样本数,默认值为1,代表树完全分裂。

图2.9 PlayTennis数据集生成的CART分类树

代码段2.2 CART分类树的构造函数(源码位于Chapter02/CartClassifier.py)

另外,构造函数中还维护一系列类的核心变量。其中,self.tree为存储树模型的核心结构,初始情况下为空dict;self.feature_names为存储数据集属性名的numpy数组;self.str_map、self.num_map、self.x_use_map、self.y_use_map为字符串与数值之间的映射器和开关,用于将numpy数组中的字符串类型映射成数值类型,以兼容PyTorch,与之相关的函数接口为__deal_value_map和__get_value,在后文中将逐一介绍。

接下来介绍CART分类树的训练函数train,如代码段2.3所示。由于Python函数中传递的numpy变量是引用,为了避免后续对数据集进行分割时破坏原始数据集,首先在第36~37行执行numpy.array的copy函数制作X和y的副本。然后在第41行进行数据预处理,将X_copy和y_copy中的字符串通过__deal_value_map函数映射成数值。之后在第44~48行将numpy数组X_copy和y_copy转换成Tensor数组,并根据self.use_gpu的值决定是否启用GPU加速。其中,torch.from_numpy函数是PyTorch提供的内置函数,负责将numpy.array数组转化成torch.Tensor格式,torch.Tensor.cuda函数也是PyTorch提供的内置函数,负责对当前的tensor数组启用GPU加速。最后,第51行进入创建CART分类树的核心函数__create_tree。

代码段2.3 CART分类树训练过程(源码位于Chapter02/CartClassifier.py)

第41行提到了一个关键的字符串映射函数__deal_value_map,在这里我们对它进行详细介绍。之所以将X_copy和y_copy中的字符串通过该函数映射成数值,是因为在执行计算时,为了实现GPU加速,numpy.array需要转换成PyTorch的torch.Tensor数据结构,而torch.Tensor仅支持数值类型。具体实现过程见代码段2.4。

代码段2.4 建立字符串与数值的映射的函数__deal_value_map(源码位于Chapter02/CartClassifier.py)

代码段2.4展示了从字符串到数值建立映射的过程。首先,在第91~103行处理X,改变self.x_use_map的标记,遍历X中的每个元素,以元素值为key,以当前self.str_map的长度为value,在self.str_map中建立映射,同时,在self.num_map中建立反向的映射。然后,在第106~117行处理y,同理,在y不为None的情况下,在当前self.str_map和self.num_map的基础上继续建立字符串映射。最后,在第120~123行返回映射好的X和y。

代码段2.5则实现将数值映射回字符串的功能。其中分为两种情况:一种是使用了字符串到数值的映射的key(通过to_X、self.x_use_map和self.y_use_map的值可以判别,如第132~133行所示),此时使用self.num_map映射回字符串;另一种是没有使用映射的key(如第134~137行所示),这种情况直接返回key或者key.item(key为Tensor中的数值类型时)的值。

代码段2.5 从数值到字符串的映射中还原值(源码位于Chapter02/CartClassifier.py)

接下来,我们回到代码段2.3。在代码段2.3的第51行调用了__create_tree函数,它完成了决策树的创建过程,其具体代码见代码段2.6。

代码段2.6 创建CART分类树的核心代码(源码位于Chapter02/CartClassifier.py)

在上述代码段中,__create_tree函数是一个递归创建决策树的过程。首先,在第147~153行判断三种递归终止条件:X中样本全部属于同一类别、当前节点样本数小于self.min_samples_split、属性集上的取值均相同。若满足终止条件,则调用__get_value函数返回从数值到字符串的映射值,若未满足终止条件,则继续往下计算。然后,在第155~157行根据基尼增益从属性值中选择最优分裂属性的最优切分点,具体过程如__choose_best_point_to_split函数所示。最后,在第159~169行根据最优切分点对子树进行划分,对于其子树再继续执行__create_tree函数完成划分过程。

在代码段2.6的第153行调用了__majority_y_id函数,它用于计算节点中出现次数最多的类别,具体见代码段2.7。它首先在第191行进行合法性检查,确保输入参数y_tensor的元素个数大于0。然后在第193~197行初始化一个空dict,遍历y_tensor并对其元素进行计数。最后在第199~203行从字典y_count中查找出现次数最多的类别ID(或映射值)。

在代码段2.6的第156行,调用了__choose_best_point_to_split函数,它用于选择最优切分点,具体见代码段2.8。

代码段2.7 计算节点中出现次数最多的类别(源码位于Chapter02/CartClassifier.py)

代码段2.8 选择最优切分点(源码位于Chapter02/CartClassifier.py)

代码段2.8是CART分类树中最核心的函数,该函数负责选择最优切分点。根据前面的理论推导,该函数的目的是计算取得最大基尼增益的属性值。首先在第219行调用__cal_gini_impurity函数计算总数据集的基尼不纯度G(root)。然后在第220~237行遍历每个属性的每个属性值,根据是否等于属性值(二分类问题)将数据集分割到左右子树,依次计算左右子树的基尼不纯度G(left)和G(right),以及左右子树中数据样本在总样本中占的比例P(left)和P(right),并且将G(root)、G(left)、G(right)、P(left)和P(right)代入__cal_gini_gain函数中计算基尼增益。最后在第239~243行选出具有最大基尼增益的属性值,作为当前节点的最优切分点,并返回最优切分点和最优分裂属性索引。

在代码段2.8的第236行,调用__cal_gini_gain函数来计算基尼增益。在计算基尼增益之前,我们需要知道如何计算一个数据集的基尼不纯度。如代码段2.8的第229行和第232行所示,通过调用__cal_gini_impurity函数来计算基尼不纯度,它的具体实现见代码段2.9。

代码段2.9 计算基尼不纯度(源码位于Chapter02/CartClassifier.py)

在代码段2.9的第253~258行,我们分析导入的数据集的最后一列(一般默认为数据类别),根据不同类别按出现次数统计到分类字典中。在第260~265行遍历该字典,根据公式用1减去不同的类分布概率的平方和,得到最终的基尼不纯度。接下来在计算基尼不纯度的基础上进一步实现基尼增益的计算,即__cal_gini_gain函数。它的具体代码见代码段2.10。

代码段2.10 计算基尼增益(源码位于Chapter02/CartClassifier.py)

求解基尼指数的过程与求解基尼增益的过程有着相似之处,它们都需要划分数据求出基尼不纯度,以及左右子树中类的比例,只不过基尼指数不需要求总数据集的基尼不纯度,而是将pro_left*gini_impurity_left与pro_right*gini_impurity_right累加求和。因此,在选择最优切分点时,我们选择具有最大基尼增益的属性值,或者具有最小基尼指数的属性值。

2.2.3.3 预测过程

代码段2.11a和代码段2.11b演示了CART分类树进行预测时的整体过程。在预测过程中,依然首先在代码段2.11a的第61~69行拷贝数据集、处理字符串映射和判断是否启用GPU加速,然后在代码段2.11a的第72行和代码段2.11b的第178~182行遍历测试集X_tensor的每个样本,使用__classify函数分别对其进行预测,最终返回拼接好的预测结果。从代码段2.11b的结构中可以看出非常好的并行性,因此在多核CPU机器上处理大数据预测时,使用多线程将其并行化可以大大提升预测效率,对此不做赘述。

代码段2.11a CART分类树预测过程(源码位于Chapter02/CartClassifier.py)

代码段2.11b CART分类树预测过程(源码位于Chapter02/CartClassifier.py)

在代码段2.11b的第180行,通过调用__classify进行预测分类,其具体代码见代码段2.12。在函数__classify的参数中,树模型tree是字典结构,它的每两层代表了实际意义上的一层决策树。因此在第291~304行的递归遍历过程中,每次取出tree的前两层(根节点和根节点的左右孩子节点),其中根节点代表属性,根节点的左右孩子节点代表属性的取值及路由方向。根据以上特点,从根节点开始,递归遍历CART分类树,最终路由到某个叶子节点,叶子节点上的值即为该决策树的预测结果。

代码段2.12 CART分类树预测的核心代码(源码位于Chapter02/CartClassifier.py)

2.2.3.4 可视化过程

对于可视化过程,可以借助Matplotlib库来实现,为此,我们结合树的遍历特点,封装了一套适用于上述决策树的tree_plotter可视化包。

在代码段2.13a和2.13b中,tree_plot函数为该包对外提供的决策树绘制接口,其整体算法的思路可分为两个步骤:首先绘制自身节点,然后判断自身节点类型,若为非叶子节点则继续递归创建子树,若为叶子节点则直接绘制。关于更详细的实现细节,请读者自行阅读源码,由于篇幅原因在此不再展开。

代码段2.13a tree_plotter可视化包(源码位于Chapter02/tree_plotter.py)

代码段2.13b tree_plotter可视化包(源码位于Chapter02/tree_plotter.py)