首页 > 编程知识 正文

混淆矩阵画图：二维混淆矩阵

时间:2023-05-04 05:39:50 阅读:247931 作者:4986

画图代码如下：

# This script plots confusion-matrix figures.
import itertools

import matplotlib.pyplot as plt
import numpy as np  # was missing: np is used throughout plot_confusion_matrix
from sklearn import metrics
from sklearn.metrics import confusion_matrix


def plot_confusion_matrix(cm, target_names, plot_names, title='Confusion matrix',
                          cmap='Blues',  # colour theme of the matrix; 'Blues' looks clean
                          normalize=True):
    """Draw a confusion matrix, save it as ``<plot_names>.png`` and show it.

    Args:
        cm: Square confusion matrix (array-like); rows are true labels,
            columns are predicted labels (as laid out by the axis labels below).
        target_names: Tick labels for both axes, or ``None`` for no ticks.
        plot_names: Output path without extension; ``.png`` is appended.
        title: Figure title.
        cmap: Matplotlib colormap name/object; ``None`` falls back to 'Blues'.
        normalize: If true, each row is divided by its sum and cells are
            printed with 4 decimals; otherwise raw counts are printed.
    """
    cm = np.asarray(cm)
    total = float(np.sum(cm))
    # Accuracy/misclassification are computed from the raw counts, before
    # any normalization. Guard an all-zero matrix instead of dividing by 0.
    accuracy = np.trace(cm) / total if total else 0.0
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    fig = plt.figure(figsize=(9, 7))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=45)
        plt.yticks(tick_marks, target_names)

    if normalize:
        # Row-normalize; rows with no samples stay 0 instead of becoming NaN.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)

    # Cells above this value get white text so they stay readable on dark colours.
    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    cell_format = "{:0.4f}" if normalize else "{:,}"
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cell_format.format(cm[i, j]),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label', size=15)
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass),
               size=15)
    # Lay out after the labels exist so they are taken into account.
    plt.tight_layout()
    plt.savefig(plot_names + '.png', format='png', bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls (e.g. a threshold sweep) don't pile up figures.
    plt.close(fig)

计算不同阈值下的混淆矩阵并批量画图：

import numpy as np  # np is used below but was never imported in the original

# Convert one-hot ground truth back to integer class labels.
# NOTE(review): y_test, one_hots and plot_confusion_matrix are defined outside
# this snippet; presumably y_test is one-hot encoded and one_hots[:, 1] holds
# the positive-class score for each sample — confirm against the caller.
y_test2 = [np.argmax(one_hot) for one_hot in y_test]
threshold = [0.5, 0.6, 0.7, 0.8, 0.9]
for thr in threshold:
    # Predict class 1 only when the score is strictly greater than thr.
    predict = [1 if score > thr else 0 for score in one_hots[:, 1]]
    # Pin the label set so the matrix is always 2x2 and lines up with
    # target_names, even when a threshold yields a single class.
    conf_mat = confusion_matrix(y_test2, predict, labels=[0, 1])
    plot_confusion_matrix(conf_mat, normalize=False, target_names=['0', '1'],
                          title='Confusion Matrix threshold:' + str(thr),
                          plot_names='Confusion_Matrix_threshold' + str(thr))
Java中cas的实现原理是什么

版权声明:该文观点仅代表作者本人。处理文章:请发送邮件至 三1五14八八95#扣扣.com 举报,一经查实,本站将立刻删除。