当前位置:   article > 正文

多分类focal loss代码实现,基于正常loss函数获取多分类概率后,用文中几行代码公式,计算得到新的loss_focal loss的多分类代码

focal loss的多分类代码
  1. from torch import nn
  2. import torch
  3. from torch.nn import functional as F
  4. class focal_loss(nn.Module):
  5. def __init__(self, alpha=0.25, gamma=2, num_classes=5, size_average=True):
  6. super(focal_loss, self).__init__()
  7. self.size_average = size_average
  8. if isinstance(alpha, (float, int)): #仅仅设置第一类别的权重
  9. print("33afsd33")
  10. self.alpha = torch.zeros(num_classes)
  11. print(self.alpha, alpha)
  12. self.alpha[0] += alpha
  13. print(self.alpha[0])
  14. self.alpha[1:] += (1 - alpha)
  15. print(self.alpha[1:])
  16. if isinstance(alpha, list): #全部权重自己设置
  17. print("333d3")
  18. self.alpha = torch.Tensor(alpha)
  19. self.gamma = gamma
  20. def forward(self, inputs, targets):
  21. alpha = self.alpha
  22. N = inputs.size(0)
  23. C = inputs.size(1)
  24. # 下面这些只是为了获取四个样本的概率probs
  25. P = F.softmax(inputs,dim=1)
  26. # ---------one hot start--------------#
  27. class_mask = inputs.data.new(N, C).fill_(0) # 生成和input一样shape的tensor
  28. class_mask = class_mask.requires_grad_() # 需要更新, 所以加入梯度计算
  29. ids = targets.view(-1, 1) # 取得目标的索引
  30. class_mask.data.scatter_(1, ids.data, 1.) # 利用scatter将索引丢给mask
  31. # ---------one hot end-------------------#
  32. probs = (P * class_mask).sum(1).view(-1, 1)
  33. print('留下targets的概率(1的部分),0的部分消除\n', probs)
  34. # 将softmax * one_hot 格式,0的部分被消除 留下1的概率, shape = (5, 1), 5就是每个target的概率
  35. #
  36. # 上面那些不需要管,重点看下面的focal loss公式;其实魔改自己多分类的,就是这里加上
  37. log_p = probs.log()
  38. # 取得对数
  39. print("1 - probs",1 - probs)
  40. loss = torch.pow((1 - probs), self.gamma) * log_p
  41. print("loss", loss)
  42. batch_loss = -alpha *loss.t() # 對應下面公式
  43. print('每一个batch的loss\n', batch_loss)
  44. # batch_loss就是取每一个batch的loss值
  45. # 最终将每一个batch的loss加总后平均
  46. if self.size_average:
  47. loss = batch_loss.mean()
  48. else:
  49. loss = batch_loss.sum()
  50. print('loss值为\n', loss)
  51. return loss
  52. #多分类 五类数据,第一类少样本数据,a= 0.25,其他都是0.75
  53. torch.manual_seed(50) #随机种子确保每次input tensor值是一样的
  54. input = torch.randn(5, 5, dtype=torch.float32, requires_grad=True)
  55. # print('input值为\n', input)
  56. targets = torch.randint(5, (5, ))
  57. print('targets值为\n', targets)
  58. criterion = focal_loss()
  59. loss = criterion(input, targets)
  60. loss.backward()
  1. # 针对多分类任务的 CELoss 和 Focal Loss
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. class CELoss(nn.Module):
  6. def __init__(self, class_num, alpha=None, use_alpha=False, size_average=True):
  7. super(CELoss, self).__init__()
  8. self.class_num = class_num
  9. self.alpha = alpha
  10. if use_alpha:
  11. self.alpha = torch.tensor(alpha).cuda()
  12. self.softmax = nn.Softmax(dim=1)
  13. self.use_alpha = use_alpha
  14. self.size_average = size_average
  15. def forward(self, pred, target):
  16. prob = self.softmax(pred.view(-1,self.class_num))
  17. prob = prob.clamp(min=0.0001,max=1.0)
  18. target_ = torch.zeros(target.size(0),self.class_num).cuda()
  19. target_.scatter_(1, target.view(-1, 1).long(), 1.)
  20. if self.use_alpha:
  21. batch_loss = - self.alpha.double() * prob.log().double() * target_.double()
  22. else:
  23. batch_loss = - prob.log().double() * target_.double()
  24. batch_loss = batch_loss.sum(dim=1)
  25. # print(prob[0],target[0],target_[0],batch_loss[0])
  26. # print('--')
  27. if self.size_average:
  28. loss = batch_loss.mean()
  29. else:
  30. loss = batch_loss.sum()
  31. return loss
  32. class FocalLoss(nn.Module):
  33. def __init__(self, class_num, alpha=None, gamma=2, use_alpha=False, size_average=True):
  34. super(FocalLoss, self).__init__()
  35. self.class_num = class_num
  36. self.alpha = alpha
  37. self.gamma = gamma
  38. if use_alpha:
  39. self.alpha = torch.tensor(alpha).cuda()
  40. self.softmax = nn.Softmax(dim=1)
  41. self.use_alpha = use_alpha
  42. self.size_average = size_average
  43. def forward(self, pred, target):
  44. prob = self.softmax(pred.view(-1,self.class_num))
  45. prob = prob.clamp(min=0.0001,max=1.0)
  46. target_ = torch.zeros(target.size(0),self.class_num).cuda()
  47. target_.scatter_(1, target.view(-1, 1).long(), 1.)
  48. if self.use_alpha:
  49. batch_loss = - self.alpha.double() * torch.pow(1-prob,self.gamma).double() * prob.log().double() * target_.double()
  50. else:
  51. batch_loss = - torch.pow(1-prob,self.gamma).double() * prob.log().double() * target_.double()
  52. batch_loss = batch_loss.sum(dim=1)
  53. if self.size_average:
  54. loss = batch_loss.mean()
  55. else:
  56. loss = batch_loss.sum()
  57. return loss
  58. torch.manual_seed(50) #随机种子确保每次input tensor值是一样的
  59. input = torch.randn(5, 5, dtype=torch.float32, requires_grad=True)
  60. # print('input值为\n', input)
  61. targets = torch.randint(5, (5, ))
  62. print('targets值为\n', targets)
  63. criterion = FocalLoss()
  64. loss = criterion(input, targets)
  65. loss.backward()
  66. FL中伽马等于0,就是CE

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/114164
推荐阅读
相关标签
  

闽ICP备14008679号