首页 > 编程知识 正文

Python绘制混淆矩阵图表

时间:2024-04-28 10:06:42 阅读:336091 作者:FUVH

一、介绍

混淆矩阵(Confusion Matrix)是机器学习领域中的一个重要概念,是用于评估分类器性能的一种矩阵。混淆矩阵能够直观的展示分类结果的正确性和错误性,有助于分析分类器的优缺点,优化分类器性能。在实际应用中,我们经常需要使用混淆矩阵来评估模型性能,这时候就需要使用Python进行混淆矩阵的绘制和分析。

Python作为一门强大的编程语言,有着丰富的机器学习库和可视化工具,可以方便地进行混淆矩阵的绘制和分析。本文将介绍如何使用Python库对混淆矩阵进行可视化展示,包括常用的sklearn库和matplotlib库,帮助读者掌握Python绘制混淆矩阵图表的技能。

二、如何绘制混淆矩阵

1. 混淆矩阵是什么

混淆矩阵是分类模型评价指标中的一种,主要提供了分类结果的信息,矩阵的行代表实际的情况,列代表预测的情况。相互交叉的单元格表示分类性能的质量,每一个单元格代表了模型的一种分类情况。

下面是一个示例混淆矩阵:

[[46  2  4  0  2  0  0  0  1  0]
 [ 4 48  0  1  1  0  0  0  1  2]
 [ 1  0 47  2  0  0  0  0  0  0]
 [ 1  3  3 34  0  4  0  0  0  0]
 [ 0  2  1  1 44  0  0  0  0  2]
 [ 0  0  0  1  5 34  2  1  1  1]
 [ 0  0  0  0  0  0 50  0  0  0]
 [ 0  0  0  0  1  0  0 48  1  0]
 [ 0  1  0  0  2  0  0  1 46  0]
 [ 0  0  0  0  0  2  0  0  0 48]] 

在这个混淆矩阵中,每一列代表预测的情况(如预测数值为1的样本),每一行代表实际的情况(如真实数值为1的样本)。例如,第一行的第一列代表实际数值为1,预测为1的数量。在这个例子中,混淆矩阵的对角线表示正确预测的样本数量,非对角线表示被错误预测的样本数量。通过混淆矩阵,我们可以方便地获得分类器的准确率、召回率、F1分数等分类性能的指标。

2. 使用sklearn进行混淆矩阵的计算

在Python中,使用sklearn库可以方便地计算混淆矩阵,并获得分类性能的相关指标。假设我们有如下的真实值和预测值列表:

true_labels = [0, 1, 2, 0, 1, 2, 0, 1, 2]
pred_labels = [0, 1, 2, 0, 2, 1, 0, 1, 2]

我们可以使用sklearn库中的confusion_matrix()函数来计算混淆矩阵:

from sklearn.metrics import confusion_matrix
cm = confusion_matrix(true_labels, pred_labels)
print(cm)

运行结果为:

[[3 0 0]
 [0 2 1]
 [1 0 2]]

sklearn中使用的混淆矩阵默认是按照行展示的,即每一行表示对于一个真实类别,预测结果是在各类别中的数量。例如,第一行第一列的值3表示样本真实为第一类,而预测结果也为第一类的样本数量为3个。同样地,第二行第三列的值1表示真实为第二类,而预测结果为第三类的样本数量为1个。

对于这个混淆矩阵,我们也可以通过sklearn计算出分类器的准确率、召回率、F1分数等性能指标。例如,计算准确率的代码为:

from sklearn.metrics import accuracy_score
accuracy = accuracy_score(true_labels, pred_labels)
print(accuracy)

运行结果为:

0.7777777777777778

3. 使用matplotlib进行混淆矩阵的可视化

在获得混淆矩阵后,我们可以使用Python的matplotlib库进行可视化展示。使用matplotlib可以将混淆矩阵转换成彩色图表的形式,便于直观地评估分类器性能。

下面是一个使用matplotlib绘制混淆矩阵的示例代码:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import plot_confusion_matrix

fig, ax = plt.subplots(figsize=(10, 10))
normalize = 'true'
disp = plot_confusion_matrix(clf, X_test, y_test, display_labels=class_names,cmap=plt.cm.Blues,normalize=normalize, ax=ax)
disp.ax_.set_title('Normalized confusion matrix')
plt.show()

这个代码段使用已经训练好的分类器clf在测试数据集X_test上进行预测,并计算生成混淆矩阵。使用plot_confusion_matrix()函数绘制混淆矩阵图表。函数的参数包括分类器、测试数据、测试标签、类别名称、颜色表、是否进行归一化、绘图的坐标轴等。

运行结果如下:

从图表中,我们可以看出分类器在不同类别上的表现情况。矩阵的对角线上的数字表示分类预测的正确,非对角线表示分类错误。我们可以直观地看出哪些类别分类结果不佳,有助于分析分类器的优缺点并进行改进。

三、总结

本文介绍了如何使用Python绘制混淆矩阵图表,包括对混淆矩阵进行计算和使用matplotlib进行可视化展示。混淆矩阵是评估分类器性能的重要指标之一,绘制混淆矩阵图表能够帮助我们直观地分析分类器的优劣,并优化分类器的性能。在实际应用中,我们可以使用Python进行混淆矩阵的绘制和分析,提高分类器的性能和应用效果。

版权声明:该文观点仅代表作者本人。处理文章:请发送邮件至 三1五14八八95#扣扣.com 举报,一经查实,本站将立刻删除。