首页 > 编程知识 正文

qpython,python实例

时间:2023-05-03 15:45:52 阅读:22434 作者:4859

研究了三天的多分类pr曲线问题终于在昨天晚上凌晨一点绘制成功了!!

试着记录一下在这里学到的东西。 一个是担心自己忘记了可以温情的事情。 另一个是,希望对抱有同样疑问的铁子们给予启发。

下图是为我画的宣传曲线。 精度超过97%,所以曲线饱和了。

首先了解一下二分类中的pr曲线是怎么画的?“p”是precition,是准确率,也是我们常用的准确率。

“r”是recall,是调查全率,也称为召回率。

上图为测试结果混淆矩阵,显示了一个数据集上的所有测试结果。

其中,竖列均为测试结果,即分类器预测概率大于0.5为正类,小于0.5为负类。横列表示groundtruth,即实际类别。

TP给出正确划分正例的数量;FN表示将正例弄错反例的数量;TN显示正确分类反例的数量;FP表示将反例弄错为正例的数量。

3358www.Sina.com/:p=TP/(TPFP ) http://www.Sina.com/:r=TP/(TPfn ) ) ) ) ) ) ) )。

准确率

我们只能根据预测结果求出“p”值和“r”值的对,因为将阈值默认设定为0.5,若大于0.5则为正例,反之亦然。

如果有飞机和大雁的写真集的话,我想从其中找到飞机的照片。

此时飞机为正例,大雁或其他为反例。

然后,计算所有图像被测试为飞机的概率,(召回率),并按照从大到小的顺序进行排序。

蓝色虚线表示我们将阈值设置为0.5时的分类情况,大于0.5的是被测试为飞机的概率,小于0.5的是未被测试为飞机的概率。

上面是对精确率和召回率的简单介绍,下面进入正题!

当然这里有两组概率,为飞机或者为大雁,这时我们不管为大雁的概率,只关注为飞机的概率

从1、0-1之间等间隔设置。 例如0,0.1,0.2,…,0.9,1.0。 这样可以得到10组“p”“r”的值。 当然也可以缩小间隔,得到更多的组“p”“r”的值。

2、按从小到大的顺序对所有样本的概率预测值进行排序,以此数列为阈值,计算“p' “r”值,可以得到更多组“p' “r”值。

当阈值变小时,更多样本会被测试成飞机,虚线下移。假设取极限,阈值为0,那么所有样本都会被预测为飞机,召回率最大,为1;而精确率为 5/10 等于0.5。同理,阈值变大,虚线上移,精确率会变高,但召回率反而变低。首先多分类方法不能绘制标准的pr曲线。

对于多分类问题,可以得到对应于每种类型的精度和召回率,从而多分类问题可以得到多组“p”、“r”值、(P1,R1 )、(P2,R2 )、…、(Pn,Rn )。

与此相对,计算平均值,得到3358www.Sina.com/(macro-p )和3358www.Sina.com/(macro-p )的平均精度和召回率的组。 从这里画的曲线是在设置阈值的时候,有两种方法:

因此,需要检查数据集的测试结果。 将测试集的ground-truth类、预测类和每个测试样本的预测概率保存在. txt文件中。 下图显示了. txt文件的部分数据。

20.0000.00001.00000.97480.02520.000010.00001.00000.000000.76290.237100.99990.00001.00000000 60.000110.00000.99970.000320.0000.0000.00000.0001.00000.00000.0000000000000000000099999990000000000999990000000000000000000000000000002 上图自左向右分别为绘制多分类的pr曲线"宏精确率”“宏召回率”(三

这是用于提取. txt文件的引用代码。

提取#.txt文件的参照代码clses是分类标签列表,preds是预测结果列表,pred_score是预测得分。 打印(sa

ving files to txt....")with open("pr_curve.txt", 'w') as pr: for i in range(len(clses)): pr.write(str(clses[i]) + " " + str(preds[i]) + " " + str(format(pred_score[i][0], '.4f')) + " " + str(format(pred_score[i][1], '.4f')) + " " + str(format(pred_score[i][2], '.4f')) + "n")print("All files have been written!")

下面是计算“宏pr值”以及绘制pr曲线的代码(含注释):

import numpy as npimport matplotlib.pyplot as pltscore_path = "./pr_curve.txt" # 文件路径with open(score_path, 'r') as f: files = f.readlines() # 读取文件lis_all = []for file in files: _, _, s1, s2, s3 = file.strip().split(" ") lis_all.append(s1) lis_all.append(s2) lis_all.append(s3)lis_order = sorted(set(lis_all)) # 记录所有得分情况,并去重从小到大排序,寻找各个阈值点micro_precis = []micro_recall = []for i in lis_order: true_p0 = 0 # 真阳 true_n0 = 0 # 真阴 false_p0 = 0 # 假阳 false_n0 = 0 # 假阴 true_p1 = 0 true_n1 = 0 false_p1 = 0 false_n1 = 0 true_p2 = 0 true_n2 = 0 false_p2 = 0 false_n2 = 0 for file in files: cls, pd, n0, n1, n2 = file.strip().split(" ") # 分别计算比较各个类别的得分,分开计算,各自为二分类, # 最后求平均,得出宏pr if float(n0) >= float(i) and cls == '0': # 遍历所有样本,第0类为正样本,其他类为负样本, true_p0 = true_p0 + 1 # 大于等于阈值,并且真实为正样本,即为真阳, elif float(n0) >= float(i) and cls != '0': # 大于等于阈值,真实为负样本,即为假阳; false_p0 = false_p0 + 1 # 小于阈值,真实为正样本,即为假阴 elif float(n0) < float(i) and cls == '0': false_n0 = false_n0 + 1 if float(n1) >= float(i) and cls == '1': # 遍历所有样本,第1类为正样本,其他类为负样本 true_p1 = true_p1 + 1 elif float(n1) >= float(i) and cls != '1': false_p1 = false_p1 + 1 elif float(n1) < float(i) and cls == '1': false_n1 = false_n1 + 1 if float(n2) >= float(i) and cls == '2': # 遍历所有样本,第2类为正样本,其他类为负样本 true_p2 = true_p2 + 1 elif float(n2) >= float(i) and cls != '2': false_p2 = false_p2 + 1 elif float(n2) < float(i) and cls == '2': false_n2 = false_n2 + 1 prec0 = (true_p0+0.00000000001) / (true_p0 + false_p0 + 0.000000000001) # 计算各类别的精确率,小数防止分母为0 prec1 = (true_p1+0.00000000001) / (true_p1 + false_p1 + 0.000000000001) prec2 = (true_p2+0.00000000001) / (true_p2 + false_p2 + 0.000000000001) recall0 = (true_p0+0.00000000001)/(true_p0+false_n0 + 0.000000000001) # 计算各类别的召回率,小数防止分母为0 recall1 = (true_p1+0.00000000001) / (true_p1 + false_n1+0.000000000001) recall2 = (true_p2+0.00000000001)/(true_p2+false_n2 + 0.00000000001) precision = (prec0 + prec1 + prec2)/3 recall = (recall0 + recall1 + recall2)/3 # 多分类求得平均精确度和平均召回率,即宏micro_pr micro_precis.append(precision) micro_recall.append(recall)micro_precis.append(1)micro_recall.append(0)print(micro_precis)print(micro_recall)x = np.array(micro_recall)y = np.array(micro_precis)plt.figure()plt.xlim([-0.01, 1.01])plt.ylim([-0.01, 1.01])plt.xlabel('recall')plt.ylabel('precision')plt.title('PR curve')plt.plot(x, y)plt.show()

代码是针对三分类写的,当然五分类多分类等,在代码里添加修改就可以了。

首先是了解pr曲线原理;
然后得到包含标签、预测类别和预测得分的.txt文件;
最后绘制pr曲线。

到这里多分类绘制pr曲线就介绍完毕了,早上写到了10点半,主要是想赶紧记下来,不然后面自己肯定又会懒惰了。

编写本文主要参考了:
周志华的《机器学习》西瓜书
https://sanchom.wordpress.com/tag/average-precision/
https://blog.csdn.net/hysteric314/article/details/54093734

日常学习记录,一起交流讨论吧!侵权联系~

版权声明:该文观点仅代表作者本人。处理文章:请发送邮件至 三1五14八八95#扣扣.com 举报,一经查实,本站将立刻删除。