【动手学深度学习】之非极大值抑制(NMS)代码实现

编程入门 行业动态 更新时间:2024-10-16 02:28:50

【动手学深度学习】之非<a href=https://www.elefans.com/category/jswz/34/1753253.html style=极大值抑制(NMS)代码实现"/>

【动手学深度学习】之非极大值抑制(NMS)代码实现

import torch
from d2l import torch as d2l# 更改打印设置
torch.set_printoptions(2)def show_bboxes(axes, bboxes, labels=None, colors=None):"""显⽰所有边界框"""def _make_list(obj, default_values=None):if obj is None:obj = default_valueselif not isinstance(obj, (list, tuple)):obj = [obj]return objlabels = _make_list(labels)colors = _make_list(colors, ['b', 'g', 'r', 'm', 'c'])for i, bbox in enumerate(bboxes):color = colors[i % len(colors)]rect = d2l.bbox_to_rect(bbox.detach().numpy(), color)axes.add_patch(rect)if labels and len(labels) > i:text_color = 'k' if color == 'w' else 'w'axes.text(rect.xy[0], rect.xy[1], labels[i],va='center', ha='center', fontsize=9, color=text_color,bbox=dict(facecolor=color, lw=0))def box_iou(boxes1, boxes2):"""计算两个锚框或边界框列表中成对的交并比"""box_area = lambda boxes: ((boxes[:, 2] - boxes[:, 0]) *(boxes[:, 3] - boxes[:, 1]))# boxes1,boxes2,areas1,areas2的形状:# boxes1:(boxes1的数量,4),# boxes2:(boxes2的数量,4),# areas1:(boxes1的数量,),# areas2:(boxes2的数量,)# 两个锚框的面积areas1 = box_area(boxes1)areas2 = box_area(boxes2)# inter_upperlefts,inter_lowerrights,inters的形状:# (boxes1的数量,boxes2的数量,2)inter_upperlefts  = torch.max(boxes1[:, None, :2], boxes2[:, :2])inter_lowerrights = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])inters = (inter_lowerrights - inter_upperlefts).clamp(min=0)# 交集inter_areas = inters[:,:,0] * inters[:,:,1]# 并集union_areas = areas1[:, None] + areas2 - inter_areasreturn inter_areas / union_areasdef offset_inverse(anchors, offset_preds):"""根据带有预测偏移量的锚框来预测边界框"""anc = d2l.box_corner_to_center(anchors)pred_bbox_xy = (offset_preds[:, :2] * anc[:, 2:] / 10) + anc[:, :2]pred_bbox_wh = torch.exp(offset_preds[:, 2:] / 5) * anc[:, 2:]pred_bbox = torch.cat((pred_bbox_xy, pred_bbox_wh), axis=1)predicted_bbox = d2l.box_center_to_corner(pred_bbox)return predicted_bboxdef nms(boxes, scores, iou_threshold):""" 对预测边界框的置信度降序排列"""# argsort()函数默认将元素从小到大排列,提取其对应的索引输出。B= torch.argsort(scores, dim=-1, descending=True)# 保留预测边界框的指标keep = []while B.numel()>0:# 取B的第一个元素,也就是最大的预测概率i = B[0]keep.append(i)# 如果取完所有的类if B.numel() == 1: break# 计算最大预测概率相应的锚框与其他所有锚框的iouiou = box_iou(boxes[i,:].reshape(-1,4),boxes[B[1:],:].reshape(-1,4)).reshape(-1)# 找出iou中小于0.5的索引,其对应元素可能是另一类的物体# 大于0.5的iou一般是重复的锚框inds = torch.nonzero(iou <= iou_threshold).reshape(-1)# 将B截止到第二高预测边界框,(+1是因为boxes[B[1:],:]从第二个框开始)B = B[inds + 1]return torch.tensor(keep, device=boxes.device)def multibox_detection(cls_probs, offset_preds, anchors, nms_threshold=0.5,
pos_threshold=0.009999999):"""NMS预测边界框"""# 找到数据位置(device)与概率的batch_sizedevice, batch_size = cls_probs.device, cls_probs.shape[0]anchors = anchors.squeeze(0)# 得到类别数与锚框数num_classes, num_anchors = cls_probs.shape[1], cls_probs.shape[2]out = []for i in range(batch_size):# 得到预测概率与偏移值cls_prob, offset_pred = cls_probs[i], offset_preds[i].reshape(-1, 4)# 找到每一个锚框中最大的预测概率,并返回值与索引conf, class_id = torch.max(cls_prob[1:], 0)# 根据预测偏移得到预测边界框predicted_bb = offset_inverse(anchors, offset_pred)# 运用NMS算法keep = nms(predicted_bb, conf, nms_threshold)# 找到所有的non_keep索引,并将类设置为背景all_idx = torch.arange(num_anchors, dtype=torch.long, device=device)# 将keep与新生成的all_idx组合起来combined = torch.cat((keep, all_idx))# 返回参数数组中所有不同的值,并从小到大排序# return_counts=True: 统计新列表元素中出现过的次数uniques, counts = combined.unique(return_counts=True)# 找出只出现过一次的元素索引non_keep = uniques[counts == 1]# 得到所有的锚框ID,前面是最大预测概率,后面是被抑制的锚框all_id_sorted = torch.cat((keep, non_keep))# 将被抑制的锚框标注在置信度索引中class_id[non_keep] = -1# 将class_id按照all_id_sorted顺序排序class_id = class_id[all_id_sorted]# 将conf, predicted_bb按照all_id_sorted顺序排序conf, predicted_bb = conf[all_id_sorted], predicted_bb[all_id_sorted]# pos_threshold是⼀个⽤于⾮背景预测的阈值below_min_idx = (conf < pos_threshold)# 将背景的对应锚框ID改为-1class_id[below_min_idx] = -1# 将背景对应锚框取预测概率相反值conf[below_min_idx] = 1 - conf[below_min_idx]# 将锚框所有属性拼接起来# 第一个索引为预测的类索引# 第二个索引是预测边界框的置信度# 第三到第六个索引是锚框坐标pred_info = torch.cat((class_id.unsqueeze(1),conf.unsqueeze(1),predicted_bb), dim=1)# print(pred_info)out.append(pred_info)return torch.stack(out)anchors = torch.tensor([[0.1, 0.08, 0.52, 0.92], [0.08, 0.2, 0.56, 0.95],[0.15, 0.3, 0.62, 0.91], [0.55, 0.2, 0.9, 0.88]])
offset_preds = torch.tensor([0] * anchors.numel())
cls_probs = torch.tensor([[0] * 4, # 背景的预测概率[0.9, 0.8, 0.7, 0.1], # 狗的预测概率[0.1, 0.2, 0.3, 0.9]]) # 猫的预测概率output = multibox_detection(cls_probs.unsqueeze(dim=0),offset_preds.unsqueeze(dim=0),anchors.unsqueeze(dim=0),nms_threshold=0.5)# 导入图片
img = d2l.plt.imread('catdog.jpg')
h, w = img.shape[:2]
d2l.set_figsize()
bbox_scale = torch.tensor((w, h, w, h))fig = d2l.plt.imshow(img)
show_bboxes(fig.axes, anchors * bbox_scale,['dog=0.9', 'dog=0.8', 'dog=0.7', 'cat=0.9'])fig = d2l.plt.imshow(img)
for i in output[0].detach().numpy():if i[0] == -1:continuelabel = ('dog=', 'cat=')[int(i[0])] + str(i[1])show_bboxes(fig.axes, [torch.tensor(i[2:]) * bbox_scale], label)

【展示】

 (自学李沐老师《动手学深度学习》使用,仅供参考,侵权删除)

更多推荐

【动手学深度学习】之非极大值抑制(NMS)代码实现

本文发布于:2024-03-10 17:22:24,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1728570.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:极大值   抑制   深度   代码   NMS

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!