当前位置:   article > 正文

总结下LovaszSoftmax损失函数(pytorch版)_lovasz-softmax loss

lovasz-softmax loss

闹心,希望四个月后一切顺利!!!!!!!!!!!!!!!

0.开始是想把自己遇到的损失函数都总结一下,结果第一个unet++的损失函数就卡住了。发现自己理解的还远远不够啊。那就先总结它吧。

1.这个算法出自论文《The Lovász-Softmax loss: A tractable surrogate for the optimization of the intersection-over-union measure in neural networks》,粗看就是针对 IoU(交并比)指标的一种优化方法。

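先是提出了loss的最基础形式:公式3 和4(第 c 类的 Jaccard 指数及其对应的损失):

$$J_c(y^*, \tilde{y}) = \frac{|\{y^* = c\} \cap \{\tilde{y} = c\}|}{|\{y^* = c\} \cup \{\tilde{y} = c\}|} \tag{3}$$

$$\Delta_{J_c}(y^*, \tilde{y}) = 1 - J_c(y^*, \tilde{y}) \tag{4}$$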

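然后说这个有啥啥啥问题。将它变个形,改写成“预测错误的像素集合”的函数:就是公式5 和6 

$$M_c(y^*, \tilde{y}) = \{y^* = c,\ \tilde{y} \neq c\} \cup \{y^* \neq c,\ \tilde{y} = c\} \tag{5}$$

$$\Delta_{J_c}: M_c \in \{0,1\}^p \mapsto \frac{|M_c|}{|\{y^* = c\} \cup M_c|} \tag{6}$$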

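变完了形还是不满意,再通过 Lovász 扩展转成凸的连续形式:就是公式8 和9 

$$\bar{\Delta}: m \in \mathbb{R}^p \mapsto \sum_{i=1}^{p} m_i\, g_i(m) \tag{8}$$

$$g_i(m) = \Delta(\{\pi_1, \dots, \pi_i\}) - \Delta(\{\pi_1, \dots, \pi_{i-1}\}) \tag{9}$$

其中 $\pi$ 是把误差向量 $m$ 的各分量按从大到小排序得到的排列。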

 8就是最终变形的loss函数形式。

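然后专门讲了公式9的实现过程(论文 Algorithm 1,下面代码注释里的“第3/4/5/7”指的就是这里的行号):

    输入:误差向量 m(c) ∈ [0,1]^p,第 c 类的前景标记 δ = {y* = c} ∈ {0,1}^p
    输出:g(m)
    1: π ← 把 m 按从大到小排序的排列
    2: δ_π ← (δ_{π_i}), i = 1..p
    3: intersection ← sum(δ) − cumulative_sum(δ_π)
    4: union ← sum(δ) + cumulative_sum(1 − δ_π)
    5: g ← 1 − intersection / union
    6: if p > 1 then
    7:     g[2:p] ← g[2:p] − g[1:p−1]
    8: end if
    9: return g_{π^{-1}}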

 

里面的理论一大堆,没咋看懂。后来就直接看pytorch代码了。有些地方实现的过程还是懵懵懂懂啊。在网上找相关资料发现翻译的我能看明白的也基本木有啊,那贴上自己的理解吧。

2.贴代码加注释:

  1. """
  2. Lovasz-Softmax and Jaccard hinge loss in PyTorch
  3. Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
  4. """
  5. from __future__ import print_function, division
  6. import torch
  7. from torch.autograd import Variable
  8. import torch.nn.functional as F
  9. import numpy as np
  10. try:
  11. from itertools import ifilterfalse
  12. except ImportError: # py3k
  13. from itertools import filterfalse as ifilterfalse
  14. #函数输入是一个排过序的 标签组 越靠近前面的标签 表示这个像素点与真值的误差越大
  15. def lovasz_grad(gt_sorted):
  16. """
  17. Computes gradient of the Lovasz extension w.r.t sorted errors
  18. See Alg. 1 in paper
  19. """
  20. p = len(gt_sorted)
  21. print("p = ", p)
  22. print("gt_sorted = ", gt_sorted)
  23. gts = gt_sorted.sum()#求个和
  24. #gt_sorted.float().cumsum(0) 1维的 计算的是累加和 例如 【1 2 3 4 5】 做完后就是【1 3 6 10 15
  25. #这个intersection是用累加和的值按维度减 累加数组的值,目的是做啥呢 看字面是取交集
  26. intersection = gts - gt_sorted.float().cumsum(0) #对应论文Algorithm 1的第3
  27. union = gts + (1 - gt_sorted).float().cumsum(0) #对应论文Algorithm 1的第4
  28. jaccard = 1. - intersection / union #对应论文Algorithm 1的第5
  29. if p > 1: # cover 1-pixel case
  30. jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]#对应论文Algorithm 1的第7
  31. return jaccard
  32. def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
  33. """
  34. IoU for foreground class
  35. binary: 1 foreground, 0 background
  36. """
  37. if not per_image:
  38. preds, labels = (preds,), (labels,)
  39. ious = []
  40. for pred, label in zip(preds, labels):
  41. intersection = ((label == 1) & (pred == 1)).sum()
  42. union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
  43. if not union:
  44. iou = EMPTY
  45. else:
  46. iou = float(intersection) / float(union)
  47. ious.append(iou)
  48. iou = mean(ious) # mean accross images if per_image
  49. return 100 * iou
  50. def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
  51. """
  52. Array of IoU for each (non ignored) class
  53. """
  54. if not per_image:
  55. preds, labels = (preds,), (labels,)
  56. ious = []
  57. for pred, label in zip(preds, labels):
  58. iou = []
  59. for i in range(C):
  60. if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes)
  61. intersection = ((label == i) & (pred == i)).sum()
  62. union = ((label == i) | ((pred == i) & (label != ignore))).sum()
  63. if not union:
  64. iou.append(EMPTY)
  65. else:
  66. iou.append(float(intersection) / float(union))
  67. ious.append(iou)
  68. ious = [mean(iou) for iou in zip(*ious)] # mean accross images if per_image
  69. return 100 * np.array(ious)
  70. # --------------------------- BINARY LOSSES ---------------------------
  71. def lovasz_hinge(logits, labels, per_image=True, ignore=None):
  72. """
  73. Binary Lovasz hinge loss
  74. logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
  75. labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
  76. per_image: compute the loss per image instead of per batch
  77. ignore: void class id
  78. """
  79. if per_image:
  80. loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
  81. for log, lab in zip(logits, labels))
  82. else:
  83. loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
  84. return loss
  85. def lovasz_hinge_flat(logits, labels):
  86. """
  87. Binary Lovasz hinge loss
  88. logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
  89. labels: [P] Tensor, binary ground truth labels (0 or 1)
  90. ignore: label to ignore
  91. """
  92. if len(labels) == 0:
  93. # only void pixels, the gradients should be 0
  94. return logits.sum() * 0.
  95. signs = 2. * labels.float() - 1.
  96. errors = (1. - logits * Variable(signs))
  97. errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
  98. perm = perm.data
  99. gt_sorted = labels[perm]
  100. grad = lovasz_grad(gt_sorted)
  101. loss = torch.dot(F.relu(errors_sorted), Variable(grad))
  102. return loss
  103. def flatten_binary_scores(scores, labels, ignore=None):
  104. """
  105. Flattens predictions in the batch (binary case)
  106. Remove labels equal to 'ignore'
  107. """
  108. scores = scores.view(-1)
  109. labels = labels.view(-1)
  110. if ignore is None:
  111. return scores, labels
  112. valid = (labels != ignore)
  113. vscores = scores[valid]
  114. vlabels = labels[valid]
  115. return vscores, vlabels
  116. class StableBCELoss(torch.nn.modules.Module):
  117. def __init__(self):
  118. super(StableBCELoss, self).__init__()
  119. def forward(self, input, target):
  120. neg_abs = - input.abs()
  121. loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
  122. return loss.mean()
  123. def binary_xloss(logits, labels, ignore=None):
  124. """
  125. Binary Cross entropy loss
  126. logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
  127. labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
  128. ignore: void class id
  129. """
  130. logits, labels = flatten_binary_scores(logits, labels, ignore)
  131. loss = StableBCELoss()(logits, Variable(labels.float()))
  132. return loss
  133. # --------------------------- MULTICLASS LOSSES ---------------------------
  134. def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
  135. """
  136. Multi-class Lovasz-Softmax loss
  137. probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
  138. Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
  139. labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
  140. classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
  141. per_image: compute the loss per image instead of per batch
  142. ignore: void class labels
  143. """
  144. print("probas.shape = ", probas.shape)
  145. if per_image:
  146. loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)
  147. for prob, lab in zip(probas, labels))
  148. else:
  149. #lovasz_softmax_flat的输入就是probas 【262144 2】 labels【262144
  150. loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
  151. return loss
  152. #这个函数是计算损失函数的部位
  153. def lovasz_softmax_flat(probas, labels, classes='present'):
  154. """
  155. Multi-class Lovasz-Softmax loss
  156. probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
  157. labels: [P] Tensor, ground truth labels (between 0 and C - 1)
  158. classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
  159. """
  160. #预测像素点个数,一张512*512的图
  161. if probas.numel() == 0:#返回数组中元素的个数
  162. # only void pixels, the gradients should be 0
  163. return probas * 0.
  164. C = probas.size(1)#获得通道数呗 就是预测几类
  165. losses = []
  166. #class_to_sum = [0 1] 类的种类总数 用list存储
  167. class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
  168. for c in class_to_sum:
  169. fg = (labels == c).float() # foreground for class c 如果语义标注数据与符合第c类,fg中存储1.0样数据
  170. if (classes is 'present' and fg.sum() == 0):
  171. continue
  172. if C == 1:
  173. if len(classes) > 1:
  174. raise ValueError('Sigmoid output possible only with 1 class')
  175. class_pred = probas[:, 0]
  176. else:
  177. class_pred = probas[:, c]#取出第c类预测值 是介于 0~1之间的float数
  178. #errors 是预测结果与标签结果差的绝对值
  179. errors = (Variable(fg) - class_pred).abs()
  180. #对误差排序 从大到小排 perm是下标值 errors_sorted 是排序后的预测值
  181. errors_sorted, perm = torch.sort(errors, 0, descending=True)
  182. perm = perm.data
  183. #排序后的标签值
  184. fg_sorted = fg[perm]
  185. losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
  186. return mean(losses)
  187. def flatten_probas(probas, labels, ignore=None):
  188. """
  189. Flattens predictions in the batch
  190. """
  191. #在这维度为probas 【1 2 512 512】 labels维度为【1 1 512 512
  192. if probas.dim() == 3:#dim()数组维度
  193. # assumes output of a sigmoid layer
  194. B, H, W = probas.size()
  195. probas = probas.view(B, 1, H, W)
  196. B, C, H, W = probas.size()#数组维度
  197. #维度交换并变形 将probas.permute(0, 2, 3, 1)变换后的前3维合并成1维,通道不变
  198. probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C
  199. #
  200. labels = labels.view(-1)
  201. #我的代码是用默认值 直接返回了 probas labels 两个压缩完事的东西
  202. #在这维度为probas 【262144 2】 labels维度为【262144
  203. if ignore is None:
  204. return probas, labels
  205. valid = (labels != ignore)
  206. vprobas = probas[valid.nonzero().squeeze()]
  207. vlabels = labels[valid]
  208. return vprobas, vlabels
  209. def xloss(logits, labels, ignore=None):
  210. """
  211. Cross entropy loss
  212. """
  213. return F.cross_entropy(logits, Variable(labels), ignore_index=255)
  214. # --------------------------- HELPER FUNCTIONS ---------------------------
  215. def isnan(x):
  216. return x != x
  217. def mean(l, ignore_nan=False, empty=0):
  218. """
  219. nanmean compatible with generators.
  220. """
  221. l = iter(l)
  222. if ignore_nan:
  223. l = ifilterfalse(isnan, l)
  224. try:
  225. n = 1
  226. acc = next(l)
  227. except StopIteration:
  228. if empty == 'raise':
  229. raise ValueError('Empty mean')
  230. return empty
  231. for n, v in enumerate(l, 2):
  232. acc += v
  233. if n == 1:
  234. return acc
  235. return acc / n

3.论文中的某些地方与代码的对应。 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/461306
推荐阅读
相关标签
  

闽ICP备14008679号