How to handle pixels labeled 255 in semantic segmentation
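Cityscapes, a commonly used semantic segmentation dataset, sets the label of pixels that should not be used to 255. This often confuses beginners: what should we do with the 255 label during training or evaluation? The answer is to ignore those pixels.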
Handling the label when computing the training loss
import torch
import torch.nn.functional as F
from torch import nn


class CrossEntropy2d(nn.Module):
    def __init__(self, ignore_label=255):
        super().__init__()
        self.ignore_label = ignore_label

    def forward(self, predict, target):
        """
        :param predict: [batch, num_class, height, width]
        :param target: [batch, height, width]
        :return: entropy loss
        """
        # Select the labels of all pixels that should be trained on
        target_mask = target != self.ignore_label  # [batch, height, width]
        target = target[target_mask]  # [num_pixels]
        predict = predict.permute(0, 2, 3, 1)  # [batch, height, width, num_class]
        # Boolean indexing over the first three dims keeps the class dim intact
        predict = predict[target_mask]  # [num_pixels, num_class]
        loss = F.cross_entropy(predict, target)
        return loss
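The core of the code above is to use indexing to pick out the pixels that should be trained on, and to compute the cross-entropy loss over those pixels only.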
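PyTorch's built-in cross-entropy also accepts an `ignore_index` argument, which should give the same value as the manual indexing when the default mean reduction is used. The following is a minimal sketch to check this. The tensor shapes, the random inputs, and the positions of the ignored pixels are made up for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_class = 4  # illustrative number of classes, not tied to any dataset

logits = torch.randn(2, num_class, 3, 3)         # [batch, num_class, height, width]
labels = torch.randint(0, num_class, (2, 3, 3))  # [batch, height, width]
labels[0, 0, 0] = 255                            # mark two pixels as "ignore"
labels[1, 2, 1] = 255

# Manual approach: index out the valid pixels, then compute the loss
mask = labels != 255
manual_loss = F.cross_entropy(logits.permute(0, 2, 3, 1)[mask], labels[mask])

# Built-in approach: let cross_entropy skip the pixels whose label is 255
builtin_loss = F.cross_entropy(logits, labels, ignore_index=255)

print(manual_loss.item(), builtin_loss.item())
```

Both losses are averaged over the non-ignored pixels only, so the two printed numbers should agree up to floating-point error.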
Computing pixel accuracy and mean IoU at evaluation time
def eval_metrics(predict, target, ignore_label=255):
    # Preprocessing: drop the pixels that carry the ignore label
    target_mask = target != ignore_label  # [batch, height, width]
    target = target[target_mask]  # [num_pixels]
    num_class = predict.size(1)
    predict = predict.permute(0, 2, 3, 1)  # [batch, height, width, num_class]
    predict = predict[target_mask]  # [num_pixels, num_class]
    predict = predict.argmax(dim=1)  # [num_pixels]

    # Pixel accuracy
    num_pixels = target.numel()
    correct = (predict == target).sum()
    pixel_acc = correct.float() / num_pixels

    # Mean IoU: shift classes to 1..num_class so that 0 can mark a wrong prediction
    predict = predict + 1
    target = target + 1
    intersection = predict * (predict == target).long()
    area_inter = torch.histc(intersection.float(), bins=num_class, max=num_class, min=1)
    area_pred = torch.histc(predict.float(), bins=num_class, max=num_class, min=1)
    area_label = torch.histc(target.float(), bins=num_class, max=num_class, min=1)
    area_union = area_pred + area_label - area_inter
    # Average the per-class IoU over the classes that occur in the prediction or the label
    valid = area_union > 0
    mIoU = (area_inter[valid] / area_union[valid]).mean()
    return pixel_acc, mIoU
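To make the two metrics concrete, here is a small worked example that computes them from a confusion matrix instead of histograms. It is only a sketch: `confusion_matrix` is a hypothetical helper written for this illustration, and the 2x3 label and prediction maps are made up.

```python
import torch


def confusion_matrix(pred, label, num_class, ignore_label=255):
    """Hypothetical helper: rows are true classes, columns are predicted classes."""
    keep = label != ignore_label            # drop ignored pixels first
    pred, label = pred[keep], label[keep]   # both become 1-D
    index = label * num_class + pred        # unique id per (true, predicted) pair
    counts = torch.bincount(index, minlength=num_class ** 2)
    return counts.view(num_class, num_class)


# Toy label map with one ignored pixel, and a toy class prediction per pixel
label = torch.tensor([[0, 0, 1], [1, 2, 255]])
pred = torch.tensor([[0, 1, 1], [1, 2, 0]])

cm = confusion_matrix(pred, label, num_class=3)
inter = torch.diag(cm).float()                    # correctly predicted pixels per class
union = (cm.sum(0) + cm.sum(1)).float() - inter   # predicted + labeled - intersection
pixel_acc = inter.sum() / cm.sum()                # 4 of 5 valid pixels are correct
iou = inter / union                               # every class occurs here, so no 0/0
miou = iou.mean()

print(pixel_acc.item(), iou.tolist(), miou.item())
```

The ignored pixel is predicted as class 0 but does not affect either metric, because it is removed before the confusion matrix is built. In this toy case the per-class IoU values are 1/2, 2/3 and 1.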