首页 > 编程知识 正文

iris数据集决策树可视化代码,使用sklearn的决策树算法对iris数据集进行分类

时间:2023-05-05 16:08:43 阅读:213159 作者:2208

此处主要学习决策树的分类问题——DecisionTreeClassifier

1、决策树算法的环境搭建

GraphViz是将决策树模型可视化的一个模块。Anaconda不自带该模块,因此想要可视化决策树则需要安装Graphviz,执行以下步骤:

(1)可通过网址https://graphviz.io/_pages/Download/Download_windows.html下载安装Graphviz。如果计算机系统是Linux,可以用apt-get或者yum方法安装。若是Windows系统,在官网下载GraphViz-2.38.msi文件并安装。无论是Linux还是Windows,装完后都要设置环境变量,将GraphViz的大方的舞蹈目录加入PATH。如果是Windows系统,将C:/Program Files(x86)/Graphviz2.38/大方的舞蹈/加入PATH。

(2)安装Python插件GraphViz,在Anaconda Prompt弹出的窗口中运行下面的命令:

pip install graphviz

(3)安装Python插件pydotplus.

conda install -c conda-forge pydotpluspip install pydotplus

这样环境就搭好了,有时候仍然找不到graphviz,这时可以在代码里面加入这一行:

import os os.environ["PATH"]+=os.pathsep+'C:/Program Files(x86)/Graphviz2.38/大方的舞蹈/'

2、使用决策树对鸢尾花数据集iris分析

主要是数据集构造决策树,先生成DecisionTreeClassifier类的一个实例(如clf_tree);然后使用该实例调用fit()方法进行训练;对于训练好的决策树模型,可以使用predict()方法对新的样本进行预测,其predict()将返回新样本值的预测类别;sklearn.tree模块提供了训练决策树模型的文本描述输出方法export_graphviz(),该方法可查看训练的决策树模型参数。

(1)先导入iris数据集,观察数据集的基本信息

from sklearn import datasetsiris=datasets.load_iris()print('iris.data的形状为',iris.data.shape)print('iris.data的特征名称为:',iris.feature_names)print('iris.target的内容为:n',iris.target)print('iris.target的形状为',iris.target.shape)print('iris.target的鸢尾花名称为',iris.target_names)

(2)导入tree模块,生成DecisionTreeClassifier()类的实例,训练模型,并输出模型数据文件

x=iris.data # 数据特征y=iris.target # 数据特征from sklearn import tree # 导入scikit-learn的tree模块# 评价标准为 criterion='entropy',决策树最大深度为 max_depth=2clf_tree=tree.DecisionTreeClassifier(criterion='entropy',max_depth=2) clf_tree.fit(x,y)dot_data=tree.export_graphviz(clf_tree,out_file=None, feature_names=iris.feature_names, class_names=True,filled=True,rounded=True)print('dot_data决策结果数据文件为:n',dot_data)

(3)为了能直观地观察训练好的决策树,则利用pydotplus+GraphViz将决策树可视化,有2种方法,下面我们一一例举

# 决策树可视化方法一:(可直接把图产生在python的notebook)from IPython.display import Image import pydotplus graph = pydotplus.graph_from_dot_data(dot_data) Image(graph.create_png())

# 决策树可视化方法二(用pydotplus生成iris.pdf)import pydotplus graph = pydotplus.graph_from_dot_data(dot_data) graph.write_pdf("iris.pdf")

可数化结果保存在iris.pdf文件中

 决策树可视化的方法中,个人比较推荐第一种的做法,因为这样可以直接把图产生在ipython的notebook,直接观察其结果。

(4)使用训练好的决策树模型clf_tree对数据集进行预测,将预测结果与真实类标签进行可视化对比,观察其预测结果。

# 预测结果部分y_predict=clf_tree.predict(x)# 可视化部分import matplotlib.pyplot as plt plt.rcParams['font.sans-serif']='SimHei' # 设置字体为SimHei以显示中文plt.rcParams['axes.unicode_minus']=False # 坐标轴刻度显示负号plt.rc('font',size=(14))plt.scatter(range(len(y)),y,marker='o')plt.scatter(range(len(y)),y_predict+0.1,marker='*')plt.legend(['真实类别','预测类别'])plt.title('使用决策树对iris数据集的预测结果与真实类别进行对比')plt.show()

从上图中,我们可以看出有6个样本的预测结果是错误的。

(5)改变评价标准然后观察其预测结果

# 调用决策树,将评价标准改为:giniclf_tree2=tree.DecisionTreeClassifier(criterion='gini',max_depth=2)clf_tree2.fit(x,y)dot_data=tree.export_graphviz(clf_tree2,out_file=None, feature_names=iris.feature_names, class_names=True, filled=True,rounded=True)graph = pydotplus.graph_from_dot_data(dot_data) Image(graph.create_png())

# 预测值部分y_predict2=clf_tree2.predict(x)# 可视化部分plt.figure(figsize=(10,4))plt.scatter(range(len(y)),y,marker='o')plt.scatter(range(len(y)),y_predict2+0.1,marker='*')plt.legend(['真实类别','预测类别'])plt.title('使用决策树对iris数据集的预测结果与真实类别进行对比')plt.show()

我们从图中可以看出,还是有6个样本的预测结果是错误的,这说明不能通过改变评价标准的方式提高预测结果的准确率。

(6)改变深度然后观察其预测结果

# 调用决策树,改变最大深度clf_tree3=tree.DecisionTreeClassifier(criterion='entropy',max_depth=3)clf_tree3.fit(x,y)dot_data=tree.export_graphviz(clf_tree3,out_file=None, feature_names=iris.feature_names, class_names=True, filled=True,rounded=True)graph = pydotplus.graph_from_dot_data(dot_data) Image(graph.create_png())

# 预测结果部分y_predict3=clf_tree3.predict(x)# 可视化部分plt.figure(figsize=(10,4))plt.scatter(range(len(y)),y,marker='o')plt.scatter(range(len(y)),y_predict3+0.1,marker='*')plt.legend(['真实类别','预测类别'])plt.title('使用决策树对iris数据集的预测结果与真实类别进行对比')plt.show()

 从上图我们可以看到,预测结果的4个样本是错误的,这说明我们可以通过改变深度来提高分类的准确率。

版权声明:该文观点仅代表作者本人。处理文章:请发送邮件至 三1五14八八95#扣扣.com 举报,一经查实,本站将立刻删除。