首页 > 编程知识 正文

多分类混淆矩阵,混淆矩阵只有二分类有吗

时间:2023-05-03 16:44:53 阅读:247910 作者:3784

绘制混淆矩阵代码 # 绘制混淆矩阵def plotCM(matrix, classes): def plot_confusion_matrix(cm, labels,title='Confusion Matrix', cmap = plt.cm.Blues): # plt.figure(figsize=(950,950)) plt.imshow(cm, interpolation='nearest', cmap=cmap) plt.title(title) plt.colorbar() xlocations = np.array(range(len(labels))) plt.xticks(xlocations, labels, rotation=90) plt.yticks(xlocations, labels) plt.ylabel('True label') plt.xlabel('Predicted label') plt.savefig('./HAR_cm.png') plt.show() """classes: a list of class names""" # Normalize by row cm_normalized = matrix.astype('float')/matrix.sum(axis=1)[:, np.newaxis] print(cm_normalized.shape) # plot fig = plt.figure() ind_array = np.arange(len(classes)) x, y = np.meshgrid(ind_array, ind_array) print(x.shape) tick_marks = np.array(range(len(classes))) + 0.5 for x_val, y_val in zip(x.flatten(), y.flatten()): c = cm_normalized[y_val][x_val] if (c > 0.01): plt.text(x_val, y_val, "%0.2f" %(c,), color='red', fontsize=7, va='center', ha='center') #offset the tick plt.gca().set_xticks(tick_marks, minor=True) plt.gca().set_yticks(tick_marks, minor=True) plt.gca().xaxis.set_ticks_position('none') plt.gca().yaxis.set_ticks_position('none') plt.grid(True, which='minor', linestyle='-') plt.gcf().subplots_adjust(bottom=0.15) plot_confusion_matrix(cm_normalized, classes,title='Normalized confusion matrix') 调用代码 classes代表类别标签列表,注意要按照y_true中的顺序排列。label_id_name_dict 是一个字典,key是索引,value是标签名字 # 绘制混淆矩阵 # classes = list(infer.label_id_name_dict.values()) # classes = list(range(infer.num_classes)) classes = list(set(y_true)) classes.sort() classes_name = [infer.label_id_name_dict[c] for c in classes] cf_matrix = confusion_matrix(y_true,y_pred) plotCM(cf_matrix,classes_name) 效果图

PCA降维绘图 降低到3维 colors是颜色列表 x_train是特征向量(降维后)、y_train是特征标签(降维后)、class_names是标签列表(按照y_train的顺序排序)降低到2维就把Axes3D换成Axes2D即可 def plot_pca_scatter(x_train,y_train,class_names): print(x_train.shape) print(y_train.shape) colors = ['r', 'g', 'b', 'c', 'k', 'm', 'y', 'r', 'g', 'b', 'c', 'k', 'm', 'y', 'r', 'g', 'b', 'c', 'k', 'm', 'y', 'r', 'g', 'b', 'c', 'k', 'm', 'y', 'r', 'g', 'b', 'c', 'k', 'm', 'y', 'r', 'g', 'b', 'c', 'k', 'm', 'y', 'r', 'g', 'b', 'c', 'k', 'm', 'y', 'r'] ax = Axes3D(plt.figure()) # for c, i, target_name in zip(colors,list(range(len(class_names))), class_names): # plt.scatter(x_train[y_train==i, 0], x_train[y_train==i, 1], c=c, label=target_name) for c, i, target_name in zip(colors,list(range(len(class_names))), class_names): ax.scatter(x_train[y_train==i, 0], x_train[y_train==i, 1],x_train[y_train==i, 2], c=c, label=target_name) #设置每个坐标的取值范围 # plt.axis([-20,20,-20,20]) # plt.xlabel('Dimension1') # plt.ylabel('Dimension2') plt.title('data distribution') plt.legend() plt.show() 效果图

版权声明:该文观点仅代表作者本人。处理文章:请发送邮件至 三1五14八八95#扣扣.com 举报,一经查实,本站将立刻删除。