yolov5增加iou loss(SIoU,EIoU,WIoU),无痛涨点trick

import math
import torch


def bbox_iou(box1,
             box2,
             xywh=True,
             GIoU=False,
             DIoU=False,
             CIoU=False,
             SIoU=False,
             EIoU=False,
             WIoU=False,
             Focal=False,
             alpha=1,
             gamma=0.5,
             scale=False,
             monotonous=False,
             eps=1e-7):
    """Compute IoU (or an IoU-based variant) between predicted and target boxes.

    Args:
        box1: predicted boxes; the last dimension holds the 4 box coordinates.
        box2: target boxes; the last dimension holds the 4 box coordinates.
        xywh: if True the boxes are (x_center, y_center, w, h) and are
            converted to corner form; otherwise they are (x1, y1, x2, y2).
        GIoU: compute GIoU.
        DIoU: compute DIoU.
        CIoU: compute CIoU.
        SIoU: compute SIoU.
        EIoU: compute EIoU.
        WIoU: compute WIoU.
        Focal: additionally return the focal weight ``IoU ** gamma`` so the
            caller can build the Focal variant of the selected metric
            (e.g. ``CIoU=True, Focal=True`` gives Focal-CIoU). Ignored by the
            WIoU branch, which has its own return layout.
        alpha: AlphaIoU exponent; 1 gives the ordinary metric.
        gamma: exponent of the focal weight.
        scale: if True a ``WIoU_Scale`` is built from ``1 - inter / union``
            (which updates the running mean) and, for WIoU, its scaling
            factor is returned.
        monotonous: forwarded to ``WIoU_Scale``. None: original v1,
            True: monotonic FM v2, False: non-monotonic FM v3.
        eps: small constant that guards against division by zero.

    Returns:
        * ``Focal=False`` and not WIoU: a tensor with the selected metric.
        * ``Focal=True`` and not WIoU: ``(metric, focal_weight)``.
        * ``WIoU=True, scale=False``: ``(iou, exp(rho2 / c2))``.
        * ``WIoU=True, scale=True``:
          ``(scale_factor, (1 - iou) * exp(rho2 / c2), iou)``.

        When several metric flags are set, the first match in the order
        CIoU, EIoU, SIoU, WIoU, DIoU, GIoU is used.
    """
    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)
        w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
            (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)

    # Union area
    union = w1 * h1 + w2 * h2 - inter + eps
    if scale:
        # Constructing the object also updates the class-level running mean.
        wise_scale = WIoU_Scale(1 - (inter / union), monotonous=monotonous)

    # IoU raised to alpha (alpha == 1 is the ordinary IoU)
    iou = torch.pow(inter / (union + eps), alpha)

    if not (CIoU or DIoU or GIoU or EIoU or SIoU or WIoU):
        if Focal:
            return iou, torch.pow(inter / (union + eps), gamma)  # Focal_IoU
        return iou  # IoU

    cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
    ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height

    if CIoU or DIoU or EIoU or SIoU or WIoU:  # Distance or Complete IoU, arXiv:1911.08287
        c2 = (cw ** 2 + ch ** 2) ** alpha + eps  # convex diagonal squared
        rho2 = (((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                 (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4) ** alpha  # center dist ** 2

        if CIoU:
            v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
            with torch.no_grad():
                alpha_ciou = v / (v - iou + (1 + eps))
            ciou = iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha))
            if Focal:
                return ciou, torch.pow(inter / (union + eps), gamma)  # Focal_CIoU
            return ciou  # CIoU

        if EIoU:
            rho_w2 = ((b2_x2 - b2_x1) - (b1_x2 - b1_x1)) ** 2
            rho_h2 = ((b2_y2 - b2_y1) - (b1_y2 - b1_y1)) ** 2
            cw2 = torch.pow(cw ** 2 + eps, alpha)
            ch2 = torch.pow(ch ** 2 + eps, alpha)
            eiou = iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2)
            if Focal:
                return eiou, torch.pow(inter / (union + eps), gamma)  # Focal_EIoU
            return eiou  # EIoU

        if SIoU:  # SIoU loss, arXiv:2205.12740
            s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + eps
            s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + eps
            sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
            sin_alpha_1, sin_alpha_2 = torch.abs(s_cw) / sigma, torch.abs(s_ch) / sigma
            threshold = pow(2, 0.5) / 2
            sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
            angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
            rho_x, rho_y = (s_cw / cw) ** 2, (s_ch / ch) ** 2
            # Kept in its own name: reusing ``gamma`` here would clobber the
            # focal exponent used by the Focal return below.
            siou_gamma = angle_cost - 2
            distance_cost = 2 - torch.exp(siou_gamma * rho_x) - torch.exp(siou_gamma * rho_y)
            omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
            omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
            shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + \
                torch.pow(1 - torch.exp(-1 * omiga_h), 4)
            siou = iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha)
            if Focal:
                return siou, torch.pow(inter / (union + eps), gamma)  # Focal_SIoU
            return siou  # SIoU

        if WIoU:  # Wise-IoU, arXiv:2301.10051
            if scale:
                return (WIoU_Scale._scaled_loss(wise_scale),
                        (1 - iou) * torch.exp((rho2 / c2)),
                        iou)  # scaled WIoU
            return iou, torch.exp((rho2 / c2))  # WIoU v1

        if Focal:
            return iou - rho2 / c2, torch.pow(inter / (union + eps), gamma)  # Focal_DIoU
        return iou - rho2 / c2  # DIoU

    # GIoU, arXiv:1902.09630
    c_area = cw * ch + eps  # convex area
    giou = iou - torch.pow((c_area - union) / c_area + eps, alpha)
    if Focal:
        return giou, torch.pow(inter / (union + eps), gamma)  # Focal_GIoU
    return giou  # GIoU


class WIoU_Scale:
    """Scaling factor for WIoU based on a running mean of the IoU loss.

    monotonous: {
        None: origin v1
        True: monotonic FM v2
        False: non-monotonic FM v3
    }
    momentum: The momentum of running mean
    """

    # Running mean of the IoU loss, shared by all instances (class-level state).
    iou_mean = 1.
    # ``**`` instead of ``pow(0.5, exp=...)``: built-in pow() only accepts
    # keyword arguments from Python 3.8 onwards.
    _momentum = 1 - 0.5 ** (1 / 7000)
    # The running mean is only updated while this flag is True.
    _is_train = True

    def __init__(self, iou, monotonous=False):
        """Store the IoU loss tensor and update the class-level running mean."""
        self.iou = iou
        self.monotonous = monotonous
        self._update(self)

    @classmethod
    def _update(cls, self):
        """Fold the mean of ``self.iou`` into ``cls.iou_mean`` when training."""
        if cls._is_train:
            cls.iou_mean = (1 - cls._momentum) * cls.iou_mean + \
                cls._momentum * self.iou.detach().mean().item()

    @classmethod
    def _scaled_loss(cls, self, gamma=1.9, delta=3):
        """Return the scaling factor selected by ``self.monotonous``.

        True gives ``sqrt(beta)``, False gives ``beta / (delta * gamma ** (beta - delta))``
        with ``beta = iou / iou_mean``; any non-bool value (e.g. None) gives 1.
        """
        if isinstance(self.monotonous, bool):
            if self.monotonous:
                return (self.iou.detach() / self.iou_mean).sqrt()
            beta = self.iou.detach() / self.iou_mean
            alpha = delta * torch.pow(gamma, beta - delta)
            return beta / alpha
        return 1


class ComputeLoss:
    """Compute the box, objectness and classification losses for a detection model.

    The box loss is built from ``bbox_iou(..., WIoU=True, scale=True)``; the
    tuple/tensor result of ``bbox_iou`` is handled in ``__call__``.
    """

    sort_obj_iou = False

    # Compute losses
    def __init__(self, model, autobalance=False):
        """Read hyperparameters and Detect() metadata from ``model``.

        Args:
            model: model exposing ``parameters()``, ``hyp`` (hyperparameter
                dict) and, after ``de_parallel``, a ``model[-1]`` detection
                module with ``nl``, ``na``, ``nc``, ``stride`` and ``anchors``.
            autobalance: if True the per-layer objectness weights are
                re-balanced during training.
        """
        device = next(model.parameters()).device  # get model device
        h = model.hyp  # hyperparameters

        # Define criteria
        BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
        BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))

        # Class label smoothing, arXiv:1902.04103 eqn 3
        self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0))  # positive, negative BCE targets

        # Focal loss
        g = h['fl_gamma']  # focal loss gamma
        if g > 0:
            BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

        m = de_parallel(model).model[-1]  # Detect() module
        self.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02])  # P3-P7
        self.ssi = list(m.stride).index(16) if autobalance else 0  # stride 16 index
        self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
        self.na = m.na  # number of anchors
        self.nc = m.nc  # number of classes
        self.nl = m.nl  # number of layers
        self.anchors = m.anchors
        self.device = device

    def __call__(self, p, targets):  # predictions, targets
        """Return ``(total_loss * batch_size, detached cat(lbox, lobj, lcls))``."""
        lcls = torch.zeros(1, device=self.device)  # class loss
        lbox = torch.zeros(1, device=self.device)  # box loss
        lobj = torch.zeros(1, device=self.device)  # object loss
        tcls, tbox, indices, anchors = self.build_targets(p, targets)  # targets

        # Losses
        for i, pi in enumerate(p):  # layer index, layer predictions
            b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
            tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device)  # target obj

            n = b.shape[0]  # number of targets
            if n:
                pxy, pwh, _, pcls = pi[b, a, gj, gi].split((2, 2, 1, self.nc), 1)  # target-subset of predictions

                # Regression
                pxy = pxy.sigmoid() * 2 - 0.5
                pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i]
                pbox = torch.cat((pxy, pwh), 1)  # predicted box

                # bbox_iou returns a plain tensor, a 2-tuple or a 3-tuple
                # depending on the flags; each layout has its own box loss.
                iou = bbox_iou(pbox, tbox[i], WIoU=True, scale=True)
                if isinstance(iou, tuple):
                    if len(iou) == 2:
                        # (metric, weight): detached weight times (1 - metric)
                        lbox += (iou[1].detach().squeeze() * (1 - iou[0].squeeze())).mean()
                        iou = iou[0].squeeze()
                    else:
                        # (scale factor, loss term, iou)
                        lbox += (iou[0] * iou[1]).mean()
                        iou = iou[2].squeeze()
                else:
                    lbox += (1.0 - iou.squeeze()).mean()  # iou loss
                    iou = iou.squeeze()

                # Objectness
                iou = iou.detach().clamp(0).type(tobj.dtype)
                if self.sort_obj_iou:
                    j = iou.argsort()
                    b, a, gj, gi, iou = b[j], a[j], gj[j], gi[j], iou[j]
                if self.gr < 1:
                    iou = (1.0 - self.gr) + self.gr * iou
                tobj[b, a, gj, gi] = iou  # iou ratio

                # Classification
                if self.nc > 1:  # cls loss (only if multiple classes)
                    t = torch.full_like(pcls, self.cn, device=self.device)  # targets
                    t[range(n), tcls[i]] = self.cp
                    lcls += self.BCEcls(pcls, t)  # BCE

            obji = self.BCEobj(pi[..., 4], tobj)
            lobj += obji * self.balance[i]  # obj loss
            if self.autobalance:
                self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()

        if self.autobalance:
            self.balance = [x / self.balance[self.ssi] for x in self.balance]
        lbox *= self.hyp['box']
        lobj *= self.hyp['obj']
        lcls *= self.hyp['cls']
        bs = tobj.shape[0]  # batch size

        return (lbox + lobj + lcls) * bs, torch.cat((lbox, lobj, lcls)).detach()

    def build_targets(self, p, targets):
        """Build per-layer targets for ``__call__``; input targets are (image, class, x, y, w, h).

        Returns:
            ``(tcls, tbox, indices, anch)``, four lists with one entry per
            detection layer: class ids, box targets (xy offset within the
            cell plus wh in grid units), ``(image, anchor, gridy, gridx)``
            index tuples, and the matched anchors.
        """
        na, nt = self.na, targets.shape[0]  # number of anchors, targets
        tcls, tbox, indices, anch = [], [], [], []
        gain = torch.ones(7, device=self.device)  # normalized to gridspace gain
        ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1, nt)  # same as .repeat_interleave(nt)
        targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None]), 2)  # append anchor indices

        g = 0.5  # bias
        off = torch.tensor(
            [
                [0, 0],
                [1, 0],
                [0, 1],
                [-1, 0],
                [0, -1],  # j,k,l,m
            ],
            device=self.device).float() * g  # offsets

        for i in range(self.nl):
            anchors, shape = self.anchors[i], p[i].shape
            gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]]  # xyxy gain

            # Match targets to anchors
            t = targets * gain  # shape(3,n,7)
            if nt:
                # Matches
                r = t[..., 4:6] / anchors[:, None]  # wh ratio
                j = torch.max(r, 1 / r).max(2)[0] < self.hyp['anchor_t']  # compare
                t = t[j]  # filter

                # Offsets
                gxy = t[:, 2:4]  # grid xy
                gxi = gain[[2, 3]] - gxy  # inverse
                j, k = ((gxy % 1 < g) & (gxy > 1)).T
                l, m = ((gxi % 1 < g) & (gxi > 1)).T
                j = torch.stack((torch.ones_like(j), j, k, l, m))
                t = t.repeat((5, 1, 1))[j]
                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
            else:
                t = targets[0]
                offsets = 0

            # Define
            bc, gxy, gwh, a = t.chunk(4, 1)  # (image, class), grid xy, grid wh, anchors
            a, (b, c) = a.long().view(-1), bc.long().T  # anchors, image, class
            gij = (gxy - offsets).long()
            gi, gj = gij.T  # grid indices

            # Append
            indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1)))  # image, anchor, grid
            tbox.append(torch.cat((gxy - gij, gwh), 1))  # box
            anch.append(anchors[a])  # anchors
            tcls.append(c)  # class

        return tcls, tbox, indices, anch


# from utils.metrics import bbox_iou
from utils.iou import bbox_iou

         如果需要使用某种IoU loss的Focal变体,将Focal设置为True,并同时将对应的IoU开关设置为True即可,例如CIoU=True、Focal=True时为Focal-CIoU;此时可以调整gamma,默认值为0.5。


        更新:新增WIoU。monotonous有3种取值,分别对应WIoU的3个版本——None: origin v1;True: monotonic FM v2;False: non-monotonic FM v3。使用时还需要设置scale:scale为True时,WIoU会乘以一个系数,与monotonous配合即可得到WIoU的3个版本。



