IoUs, PA_Recall, Precision计算,来源于憨批的语义分割3——unet模型详解以及训练自己的unet模型(划分斑马线)_Bubbliiiing的博客-CSDN博客_憨批语义分割的github
import csv
import os
from os.path import join

import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras import backend from PIL import Image def Iou_score(smooth = 1e-5, threhold = 0.5): def _Iou_score(y_true, y_pred): # score calculation y_pred = backend.greater(y_pred, threhold) y_pred = backend.cast(y_pred, backend.floatx()) intersection = backend.sum(y_true[...,:-1] * y_pred, axis=[0,1,2]) union = backend.sum(y_true[...,:-1] + y_pred, axis=[0,1,2]) - intersection score = (intersection + smooth) / (union + smooth) return score return _Iou_score def f_score(beta=1, smooth = 1e-5, threhold = 0.5): def _f_score(y_true, y_pred): y_pred = backend.greater(y_pred, threhold) y_pred = backend.cast(y_pred, backend.floatx()) tp = backend.sum(y_true[..., :-1] * y_pred, axis=[0,1,2]) fp = backend.sum(y_pred , axis=[0, 1, 2]) - tp fn = backend.sum(y_true[..., :-1], axis=[0, 1, 2]) - tp score = ((1 + beta ** 2) * tp + smooth) \ / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth) return score return _f_score # 设标签宽W,长H def fast_hist(a, b, n): #--------------------------------------------------------------------------------# # a是转化成一维数组的标签,形状(H×W,);b是转化成一维数组的预测结果,形状(H×W,) #--------------------------------------------------------------------------------# k = (a >= 0) & (a < n) #--------------------------------------------------------------------------------# # np.bincount计算了从0到n**2-1这n**2个数中每个数出现的次数,返回值形状(n, n) # 返回中,写对角线上的为分类正确的像素点 #--------------------------------------------------------------------------------# return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n) def per_class_iu(hist): return np.diag(hist) / np.maximum((hist.sum(1) + hist.sum(0) - np.diag(hist)), 1) def per_class_PA_Recall(hist): return np.diag(hist) / np.maximum(hist.sum(1), 1) def per_class_Precision(hist): return np.diag(hist) / np.maximum(hist.sum(0), 1) def per_Accuracy(hist): return np.sum(np.diag(hist)) / np.maximum(np.sum(hist), 1) def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes=None): print('Num classes', num_classes) #-----------------------------------------# # 创建一个全是0的矩阵,是一个混淆矩阵 #-----------------------------------------# hist = np.zeros((num_classes, num_classes)) #------------------------------------------------# # 获得验证集标签路径列表,方便直接读取 # 获得验证集图像分割结果路径列表,方便直接读取 #------------------------------------------------# gt_imgs = [join(gt_dir, x+".png" ) for x in png_name_list] pred_imgs = [join(pred_dir, x+".png") for x in png_name_list] #------------------------------------------------# # 读取每一个(图片-标签)对 #------------------------------------------------# for ind in range(len(gt_imgs)): #------------------------------------------------# # 读取一张图像分割结果,转化成numpy数组 #------------------------------------------------# pred = np.array(Image.open(pred_imgs[ind])) #------------------------------------------------# # 读取一张对应的标签,转化成numpy数组 #------------------------------------------------# print(gt_imgs[ind]) label = np.array(Image.open(gt_imgs[ind])) #---------------------------------------------------------# #根据自己的标签修改 label[label==255]=1 label[label==200]=2 #---------------------------------------------------------# # 如果图像分割结果与标签的大小不一样,这张图片就不计算 if len(label.flatten()) != len(pred.flatten()): print( 'Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format( len(label.flatten()), len(pred.flatten()), gt_imgs[ind], pred_imgs[ind])) continue #------------------------------------------------# # 对一张图片计算21×21的hist矩阵,并累加 #------------------------------------------------# hist += fast_hist(label.flatten(), pred.flatten(), num_classes) # 每计算10张就输出一下目前已计算的图片中所有类别平均的mIoU值 if name_classes is not None and ind > 0 and ind % 10 == 0: print('{:d} / {:d}: mIou-{:0.2f}%; mPA-{:0.2f}%; Accuracy-{:0.2f}%'.format( ind, len(gt_imgs), 100 * np.nanmean(per_class_iu(hist)), 100 * np.nanmean(per_class_PA_Recall(hist)), 100 * per_Accuracy(hist) ) ) #------------------------------------------------# # 计算所有验证集图片的逐类别mIoU值 #------------------------------------------------# IoUs = per_class_iu(hist) PA_Recall = per_class_PA_Recall(hist) Precision = per_class_Precision(hist) #------------------------------------------------# # 逐类别输出一下mIoU值 #------------------------------------------------# if name_classes is not None: for ind_class in range(num_classes): print('===>' + name_classes[ind_class] + ':\tIou-' + str(round(IoUs[ind_class] * 100, 2)) \ + '; Recall (equal to the PA)-' + str(round(PA_Recall[ind_class] * 100, 2))+ '; Precision-' + str(round(Precision[ind_class] * 100, 2))) #-----------------------------------------------------------------# # 在所有验证集图像上求所有类别平均的mIoU值,计算时忽略NaN值 #-----------------------------------------------------------------# print('===> mIoU: ' + str(round(np.nanmean(IoUs) * 100, 2)) + '; mPA: ' + str(round(np.nanmean(PA_Recall) * 100, 2)) + '; Accuracy: ' + str(round(per_Accuracy(hist) * 100, 2))) return np.array(hist, np.int), IoUs, PA_Recall, Precision def adjust_axes(r, t, fig, axes): bb = t.get_window_extent(renderer=r) text_width_inches = bb.width / fig.dpi current_fig_width = fig.get_figwidth() new_fig_width = current_fig_width + text_width_inches propotion = new_fig_width / current_fig_width x_lim = axes.get_xlim() axes.set_xlim([x_lim[0], x_lim[1] * propotion]) def draw_plot_func(values, name_classes, plot_title, x_label, output_path, tick_font_size = 12, plt_show = True): fig = plt.gcf() axes = plt.gca() plt.barh(range(len(values)), values, color='royalblue') plt.title(plot_title, fontsize=tick_font_size + 2) plt.xlabel(x_label, fontsize=tick_font_size) plt.yticks(range(len(values)), name_classes, fontsize=tick_font_size) r = fig.canvas.get_renderer() for i, val in enumerate(values): str_val = " " + str(val) if val < 1.0: str_val = " {0:.2f}".format(val) t = plt.text(val, i, str_val, color='royalblue', va='center', fontweight='bold') if i == (len(values)-1): adjust_axes(r, t, fig, axes) fig.tight_layout() fig.savefig(output_path) if plt_show: plt.show() plt.close() def show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes, tick_font_size = 12): draw_plot_func(IoUs, name_classes, "mIoU = {0:.2f}%".format(np.nanmean(IoUs)*100), "Intersection over Union", \ os.path.join(miou_out_path, "mIoU.png"), tick_font_size = tick_font_size, plt_show = True) print("Save mIoU out to " + os.path.join(miou_out_path, "mIoU.png")) draw_plot_func(PA_Recall, name_classes, "mPA = {0:.2f}%".format(np.nanmean(PA_Recall)*100), "Pixel Accuracy", \ os.path.join(miou_out_path, "mPA.png"), tick_font_size = tick_font_size, plt_show = False) print("Save mPA out to " + os.path.join(miou_out_path, "mPA.png")) draw_plot_func(PA_Recall, name_classes, "mRecall = {0:.2f}%".format(np.nanmean(PA_Recall)*100), "Recall", \ os.path.join(miou_out_path, "Recall.png"), tick_font_size = tick_font_size, plt_show = False) print("Save Recall out to " + os.path.join(miou_out_path, "Recall.png")) draw_plot_func(Precision, name_classes, "mPrecision = {0:.2f}%".format(np.nanmean(Precision)*100), "Precision", \ os.path.join(miou_out_path, "Precision.png"), tick_font_size = tick_font_size, plt_show = False) print("Save Precision out to " + os.path.join(miou_out_path, "Precision.png")) with open(os.path.join(miou_out_path, "confusion_matrix.csv"), 'w', newline='') as f: writer = csv.writer(f) writer_list = [] writer_list.append([' '] + [str(c) for c in name_classes]) for i in range(len(hist)): writer_list.append([name_classes[i]] + [str(x) for x in hist[i]]) writer.writerows(writer_list) print("Save confusion_matrix out to " + os.path.join(miou_out_path, "confusion_matrix.csv")) 

使用方法,可绘制柱状图

miou_out_path为存放图片地址,
name_classes为自己要分割的类,包含背景,大小为需要识别的目标数加一,目标像素外的像素都为属于背景。
hist, IoUs, PA_Recall, Precision = compute_mIoU(gt_dir, pred_dir, image_ids, num_classes, name_classes)  # 执行计算mIoU的函数
print("Get miou done.")
show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes)

 

原文地址:http://www.cnblogs.com/dark-blue/p/16863253.html

1. 本站所有资源来源于用户上传和网络,如有侵权请邮件联系站长! 2. 分享目的仅供大家学习和交流,请务用于商业用途! 3. 如果你也有好源码或者教程,可以到用户中心发布,分享有积分奖励和额外收入! 4. 本站提供的源码、模板、插件等等其他资源,都不包含技术服务请大家谅解! 5. 如有链接无法下载、失效或广告,请联系管理员处理! 6. 本站资源售价只是赞助,收取费用仅维持本站的日常运营所需! 7. 如遇到加密压缩包,默认解压密码为"gltf",如遇到无法解压的请联系管理员! 8. 因为资源和程序源码均为可复制品,所以不支持任何理由的退款兑现,请斟酌后支付下载 声明:如果标题没有注明"已测试"或者"测试可用"等字样的资源源码均未经过站长测试.特别注意没有标注的源码不保证任何可用性