So frustrating. Hoping everything goes smoothly four months from now!!!
0. At first I wanted to summarize every loss function I have come across, but I got stuck on the very first one, the UNet++ loss, and realized my understanding is still far from sufficient. So let me start by summarizing just that one.
1. The algorithm comes from the paper "The Lovász-Softmax loss: A tractable surrogate for the optimization of the intersection-over-union measure in neural networks". At first glance, it is an optimization-friendly formulation of the IoU measure.
The paper first presents the most basic form of the loss: Eqs. 3 and 4.
It then points out the problems with that form and rewrites it: Eqs. 5 and 6.
Still not satisfied with the rewrite, it converts the loss into a convex surrogate: Eqs. 8 and 9.
Eq. 8 is the final form of the loss function.
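To make those equation references easier to follow, here is my reconstruction of what they say, based on my own reading of the paper (treat the notation, and especially the numbering, as a sketch rather than a quotation):

```latex
% Eqs. 3/4: per-class Jaccard index and the corresponding Jaccard loss
J_c(y^*, \tilde{y}) = \frac{\left|\{y^* = c\} \cap \{\tilde{y} = c\}\right|}
                           {\left|\{y^* = c\} \cup \{\tilde{y} = c\}\right|},
\qquad
\Delta_{J_c}(y^*, \tilde{y}) = 1 - J_c(y^*, \tilde{y})

% Eqs. 5/6: the same loss rewritten as a set function of the mispredicted pixels
M_c(y^*, \tilde{y}) = \{y^* = c, \tilde{y} \neq c\} \cup \{y^* \neq c, \tilde{y} = c\},
\qquad
\Delta_{J_c}(M_c) = \frac{|M_c|}{\left|\{y^* = c\} \cup M_c\right|}

% Eqs. 8/9: the Lovász extension of \Delta, a convex surrogate evaluated on a
% vector of pixel errors m, where the permutation \pi sorts m in decreasing order
\overline{\Delta}(m) = \sum_{i=1}^{p} m_{\pi_i}\, g_i(m),
\qquad
g_i(m) = \Delta(\{\pi_1, \dots, \pi_i\}) - \Delta(\{\pi_1, \dots, \pi_{i-1}\})
```

The weights g_i(m) are exactly what the `lovasz_grad` function in the code below computes, following Algorithm 1 of the paper.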
The paper then devotes a section to how Eq. 9 is computed in practice:
There is a pile of theory in there that I did not fully understand, so I went straight to the PyTorch code. Even then, some parts of the implementation were still hazy to me, and none of the explanations I found online were ones I could really follow, so here is my own understanding.
2. The code, with my annotations:
- """
- Lovasz-Softmax and Jaccard hinge loss in PyTorch
- Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
- """
-
- from __future__ import print_function, division
-
- import torch
- from torch.autograd import Variable
- import torch.nn.functional as F
- import numpy as np
- try:
- from itertools import ifilterfalse
- except ImportError: # py3k
- from itertools import filterfalse as ifilterfalse
-
-
- #函数输入是一个排过序的 标签组 越靠近前面的标签 表示这个像素点与真值的误差越大
- def lovasz_grad(gt_sorted):
- """
- Computes gradient of the Lovasz extension w.r.t sorted errors
- See Alg. 1 in paper
- """
- p = len(gt_sorted)
- print("p = ", p)
- print("gt_sorted = ", gt_sorted)
- gts = gt_sorted.sum()#求个和
- #gt_sorted.float().cumsum(0) 1维的 计算的是累加和 例如 【1 2 3 4 5】 做完后就是【1 3 6 10 15】
- #这个intersection是用累加和的值按维度减 累加数组的值,目的是做啥呢 看字面是取交集
- intersection = gts - gt_sorted.float().cumsum(0) #对应论文Algorithm 1的第3行
- union = gts + (1 - gt_sorted).float().cumsum(0) #对应论文Algorithm 1的第4行
- jaccard = 1. - intersection / union #对应论文Algorithm 1的第5行
- if p > 1: # cover 1-pixel case
- jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]#对应论文Algorithm 1的第7行
- return jaccard
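# A quick hand-check of lovasz_grad on a toy input (my own example, not
# from the paper): gt_sorted = [1, 0, 1, 0]
#   gts          = 2
#   intersection = 2 - cumsum([1, 1, 2, 2]) = [1, 1, 0, 0]
#   union        = 2 + cumsum([0, 1, 1, 2]) = [2, 3, 3, 4]
#   jaccard      = 1 - intersection/union   = [0.5, 0.6667, 1.0, 1.0]
#   after the differencing step             = [0.5, 0.1667, 0.3333, 0.0]
# The largest-error pixel gets the largest weight, and the weights telescope
# so that they sum to the final Jaccard loss value (here 1.0).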


def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
    """
    IoU for foreground class
    binary: 1 foreground, 0 background
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        intersection = ((label == 1) & (pred == 1)).sum()
        union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
        if not union:
            iou = EMPTY
        else:
            iou = float(intersection) / float(union)
        ious.append(iou)
    iou = mean(ious)  # mean across images if per_image
    return 100 * iou


def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
    """
    Array of IoU for each (non ignored) class
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        iou = []
        for i in range(C):
            if i != ignore:  # The ignored label is sometimes among predicted classes (ENet - CityScapes)
                intersection = ((label == i) & (pred == i)).sum()
                union = ((label == i) | ((pred == i) & (label != ignore))).sum()
                if not union:
                    iou.append(EMPTY)
                else:
                    iou.append(float(intersection) / float(union))
        ious.append(iou)
    ious = [mean(iou) for iou in zip(*ious)]  # mean across images if per_image
    return 100 * np.array(ious)


# --------------------------- BINARY LOSSES ---------------------------


def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss
    logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
    labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
    per_image: compute the loss per image instead of per batch
    ignore: void class id
    """
    if per_image:
        loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
                    for log, lab in zip(logits, labels))
    else:
        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
    return loss


def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
    logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
    labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    # sort the errors in descending order; perm holds the original indices
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]  # labels reordered to match the sorted errors
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), Variable(grad))
    return loss
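# Worked micro-example (my own numbers): labels = [1, 0], logits = [0.5, -2.0]
#   signs  = [+1, -1]
#   errors = 1 - logits * signs = [0.5, -1.0]
# The confidently-correct background pixel has a negative hinge error, which
# F.relu zeroes out; only the under-confident foreground pixel contributes,
# weighted by lovasz_grad([1, 0]) = [1.0, 0.0], giving loss = 0.5.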


def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flattens predictions in the batch (binary case)
    Remove labels equal to 'ignore'
    """
    scores = scores.view(-1)
    labels = labels.view(-1)
    if ignore is None:
        return scores, labels
    valid = (labels != ignore)
    vscores = scores[valid]
    vlabels = labels[valid]
    return vscores, vlabels


class StableBCELoss(torch.nn.modules.Module):
    def __init__(self):
        super(StableBCELoss, self).__init__()

    def forward(self, input, target):
        neg_abs = - input.abs()
        loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
        return loss.mean()
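# Note (mine): this is the numerically stable form of binary cross-entropy on
# logits: max(x, 0) - x*t + log(1 + exp(-|x|)) is algebraically equal to
# -[t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))] but never exponentiates a
# large positive number -- the same trick used by torch.nn.BCEWithLogitsLoss.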


def binary_xloss(logits, labels, ignore=None):
    """
    Binary Cross entropy loss
    logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
    labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
    ignore: void class id
    """
    logits, labels = flatten_binary_scores(logits, labels, ignore)
    loss = StableBCELoss()(logits, Variable(labels.float()))
    return loss
```
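That covers the binary half of the file. As a sanity check, here is a minimal smoke test I would run for the hinge variant (my own example, not from the repo; it assumes the whole file, including the `mean` helper near the bottom, has already been loaded):

```python
import torch

# a random binary segmentation problem: 2 images of 4x4 pixels
logits = torch.randn(2, 4, 4, requires_grad=True)  # raw scores, any real value
labels = (torch.rand(2, 4, 4) > 0.5).long()        # binary ground truth in {0, 1}

loss = lovasz_hinge(logits, labels, per_image=True)
loss.backward()                                    # gradients flow through the sorted errors
print(loss.item())
```

The multi-class half of the file follows.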
```python
# --------------------------- MULTICLASS LOSSES ---------------------------


def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
    """
    Multi-class Lovasz-Softmax loss
    probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
            Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
    labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
    classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
    per_image: compute the loss per image instead of per batch
    ignore: void class labels
    """
    print("probas.shape = ", probas.shape)
    if per_image:
        loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)
                    for prob, lab in zip(probas, labels))
    else:
        # here lovasz_softmax_flat receives probas [262144, 2] and labels [262144]
        loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
    return loss


# this function is where the loss is actually computed
def lovasz_softmax_flat(probas, labels, classes='present'):
    """
    Multi-class Lovasz-Softmax loss
    probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
    labels: [P] Tensor, ground truth labels (between 0 and C - 1)
    classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
    """
    # P is the number of predicted pixels, e.g. 262144 for a single 512x512 image
    if probas.numel() == 0:  # numel() returns the number of elements
        # only void pixels, the gradients should be 0
        return probas * 0.
    C = probas.size(1)  # number of channels, i.e. number of predicted classes
    losses = []
    # e.g. class_to_sum = [0, 1]: the list of class indices to sum over
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground mask for class c: 1.0 wherever the label is c
        if (classes == 'present' and fg.sum() == 0):
            continue
        if C == 1:
            if len(classes) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]  # predicted probability for class c, a float in [0, 1]
        # errors is the absolute difference between prediction and label
        errors = (Variable(fg) - class_pred).abs()
        # sort the errors in descending order; perm holds the original indices
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        # labels reordered to match the sorted errors
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
    return mean(losses)
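# Tiny worked example (my own numbers): C = 2,
#   probas = [[0.9, 0.1], [0.4, 0.6]], labels = [0, 1]
#   class 0: fg = [1, 0], errors = |fg - probas[:, 0]| = [0.1, 0.4]
#   class 1: fg = [0, 1], errors = |fg - probas[:, 1]| = [0.1, 0.4]
# Each class's sorted errors are dotted with that class's lovasz_grad
# weights, and the per-class losses are averaged.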


def flatten_probas(probas, labels, ignore=None):
    """
    Flattens predictions in the batch
    """
    # here probas is [1, 2, 512, 512] and labels is [1, 1, 512, 512]
    if probas.dim() == 3:  # dim() is the number of tensor dimensions
        # assumes output of a sigmoid layer
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)
    B, C, H, W = probas.size()
    # swap dimensions and reshape: merge the first 3 dims of
    # probas.permute(0, 2, 3, 1) into one, keeping the channel dim
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    labels = labels.view(-1)
    # with the default ignore=None this returns the flattened probas and labels
    # directly; at this point probas is [262144, 2] and labels is [262144]
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    vprobas = probas[valid.nonzero().squeeze()]
    vlabels = labels[valid]
    return vprobas, vlabels


def xloss(logits, labels, ignore=None):
    """
    Cross entropy loss
    """
    return F.cross_entropy(logits, Variable(labels), ignore_index=255)


# --------------------------- HELPER FUNCTIONS ---------------------------

def isnan(x):
    return x != x


def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.
    """
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(l, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n
```
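And a matching smoke test for the multi-class variant (again my own example, not from the repo). Note that the inputs must already be probabilities, so apply a softmax to the raw network output first:

```python
import torch
import torch.nn.functional as F

scores = torch.randn(1, 2, 8, 8, requires_grad=True)  # raw network output [B, C, H, W]
probas = F.softmax(scores, dim=1)                     # turn scores into class probabilities
labels = torch.randint(0, 2, (1, 8, 8))               # ground truth in {0, ..., C-1}

loss = lovasz_softmax(probas, labels, classes='present')
loss.backward()
print(loss.item())
```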
3. How specific parts of the paper correspond to the code.