- 深入浅出Python量化交易实战
- 段小手
- 1682字
- 2022-07-29 16:01:42
3.2.2 KNN算法用于分类
1.载入数据集并查看
scikit-learn内置了一些供大家学习的玩具数据集(toy dataset),其中有些是分类任务的数据,有些是回归任务的数据。首先我们使用一个最简单的数据集来给小瓦演示KNN算法在分类中的应用。输入代码如下:
#首先导入鸢尾花数据载入工具 from sklearn.datasets import load_iris #导入KNN分类模型 from sklearn.neighbors import KNeighborsClassifier #为了方便可视化,我们再导入matplotlib和seaborn import matplotlib.pyplot as plt import seaborn as sns
【结果分析】运行代码,如果程序没有报错,就说明所有的库都已经成功载入。接下来我们就用数据集载入工具加载数据。输入代码如下:
#加载鸢尾花数据集,赋值给iris变量 iris = load_iris() #查看数据集的键名 iris.keys()
运行代码,我们会得到下面的结果:
dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names', 'filename'])
【结果分析】如果读者朋友们也得到了同样的结果,就说明代码运行成功。我们看到,该数据集存储了若干个键(key),这里我们重点关注一下其中的target和feature_names,因为这两个键对应的分别是样本的分类标签和特征名称。
首先我们看下数据集存储了样本的哪些特征,输入代码如下:
#查看数据集的特征名称 iris.feature_names
运行代码,可以得到如下结果:
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
【结果分析】从上述代码结果中可以看出,数据集中的样本共有4个特征,分别是sepal length(萼片长度)、sepal width(萼片宽度)、petal length(花瓣长度)和petal width(花瓣宽度)。
下面再来看一下这些样本被分为几类,输入代码如下:
#查看数据集中的样本分类 iris.target
运行代码,会得到结果如下:
【结果分析】观察代码的运行结果,我们可以发现系统返回了一个数组,数组中的数字有0、1和2。这说明数据集中的样本分为3类,分别用0、1、2这3个数字来表示。
到这里,相信小瓦也已经明白,这个数据集的目的是:根据样本鸢尾花萼片和花瓣的长度及宽度,结合分类标签来训练模型,以便让模型可以预测出某一种鸢尾花属于哪个分类。
2.拆分数据集
下面,我们就来把数据集拆分为训练集和验证集,以便验证模型的准确率。先输入如下代码:
#将样本的特征和标签分别赋值给X和y X, y = iris.data, iris.target #查看是否成功 X.shape
运行代码,会得到以下结果:
(150, 4)
【结果分析】从上面的代码运行结果可以看出,我们将数据集的特征赋值给了X,而将分类标签赋值给了y。通过查看X的形态,可知样本数量共有150个,每个样本有4个特征。
下面来对数据集进行拆分,输入代码如下:
#导入数据集拆分工具 from sklearn.model_selection import train_test_split #将X和y拆分为训练集和验证集 X_train, X_test, y_train, y_test =\ train_test_split(X, y) #查看拆分情况 X_train.shape
运行代码,会得到以下结果:
(112, 4)
【结果分析】从上面的代码运行结果可以看到,通过拆分,训练集中的样本数量为112个,其余的38个样本则进入了验证集。
3.训练模型并评估准确率
下面训练一个最简单的KNN模型,输入代码如下:
#创建KNN分类器,参数保持默认设置 knn_clf = KNeighborsClassifier() #使用训练集拟合模型 knn_clf.fit(X_train, y_train) #查看模型在训练集和验证集中的准确率 print('训练集准确率:%.2f'%knn_clf.score(X_train, y_train)) print('验证集准确率:%.2f'%knn_clf.score(X_test, y_test))
运行代码,会得到以下结果:
训练集准确率:0.98 验证集准确率:0.95
【结果分析】从上面的代码运行结果可以看到,使用KNN算法训练的分类模型,在训练集中的准确率达到了98%,在验证集中的准确率达到了0.95%。这是一个非常不错的成绩。
需要说明的是,在scikit-learn中,KNN可以通过调节n_neighbors参数来改进模型的性能。在不手动指定的情况下,KNN默认的近邻参数n_neighbors为5。那么这个参数是最优的吗?我们可以使用网格搜索法来寻找到模型的最优参数。输入代码如下:
运行代码,会得到以下结果:
{'n_neighbors': 6}
【结果分析】从上面的代码运行结果可以看到,程序将网格搜索找到的最优参数进行了返回——KNN分类器的最优n_neighbors参数是6。也就是说,当n_neighbors参数为6时,模型的准确率是最高的。
下面我们就来看一下当把n_neighbors设置为6时,模型的准确率。输入代码如下:
#创建KNN分类器,n_neighbors设置为6 knn_clf = KNeighborsClassifier(n_neighbors=6) #使用训练集拟合模型 knn_clf.fit(X_train, y_train) #查看模型在训练集和验证集中的准确率 print('训练集准确率:%.2f'%knn_clf.score(X_train, y_train)) print('验证集准确率:%.2f'%knn_clf.score(X_test, y_test))
运行代码,会得到以下结果:
训练集准确率:0.99 验证集准确率:0.95
【结果分析】从上面的代码运行结果可以看到,当把n_neighbors参数设置为6时,模型在训练集中的准确率提高到了99%,可以说这是非常不错的成绩了;而在验证集中的准确率依旧保持在95%左右,没有显著的提升。