赞
踩
-
- from torch import nn
- import torch
- from torch.nn import functional as F
-
class focal_loss(nn.Module):
    """Multi-class focal loss: FL(p_t) = -alpha_t * (1 - p_t) ** gamma * log(p_t).

    ``p_t`` is the softmax probability of each sample's target class and
    ``alpha_t`` is the weight of that target class.

    Args:
        alpha: Per-class weights.
            - A float/int ``a`` gives class 0 weight ``a`` and every other
              class weight ``1 - a`` (meant for the case where class 0 is
              the under-represented class).
            - A list/tuple/tensor sets one weight per class explicitly.
            - ``None`` weights every class by 1.
        gamma: Focusing exponent. ``gamma = 0`` reduces the loss to
            alpha-weighted cross-entropy.
        num_classes: Number of classes. Only used when ``alpha`` is a scalar
            or ``None``.
        size_average: If True, return the mean over the batch; otherwise
            return the sum.

    Raises:
        TypeError: If ``alpha`` is of an unsupported type.
    """

    def __init__(self, alpha=0.25, gamma=2, num_classes=5, size_average=True):
        super(focal_loss, self).__init__()
        self.size_average = size_average
        self.gamma = gamma
        if alpha is None:
            self.alpha = torch.ones(num_classes)
        elif isinstance(alpha, (float, int)):
            # Only the first class gets `alpha`; all remaining classes get 1 - alpha.
            self.alpha = torch.full((num_classes,), 1.0 - alpha)
            self.alpha[0] = alpha
        elif isinstance(alpha, (list, tuple, torch.Tensor)):
            # Every class weight is given explicitly; keep a private float32 copy.
            self.alpha = torch.as_tensor(alpha, dtype=torch.float32).detach().clone()
        else:
            raise TypeError(
                "alpha must be None, a number, or a sequence/tensor of per-class "
                "weights, got %r" % type(alpha).__name__
            )

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs: Raw logits of shape (N, C). Softmax is taken over dim 1.
            targets: Integer class indices with N elements.

        Returns:
            A scalar tensor: the mean (``size_average=True``) or the sum of
            the per-sample losses.
        """
        ids = targets.reshape(-1, 1).long()
        # log_softmax is numerically stable; log(softmax(x)) can hit log(0) = -inf.
        log_p = F.log_softmax(inputs, dim=1).gather(1, ids).view(-1)
        probs = log_p.exp()
        # Select each sample's weight by its *target class* (alpha_t). Broadcasting
        # the whole (C,) alpha vector against an (N,) loss row would instead pair
        # alpha[i] with sample i, and would fail outright when N != C.
        alpha = self.alpha.to(device=log_p.device, dtype=log_p.dtype)[ids.view(-1)]
        batch_loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p

        if self.size_average:
            return batch_loss.mean()
        return batch_loss.sum()
-
# Multi-class demo with five classes. Class 0 is the few-sample class, so with
# the default alpha=0.25 it gets weight 0.25 and every other class gets 0.75.

# Fixed seed so the random logits and targets are identical on every run.
torch.manual_seed(50)
input = torch.randn((5, 5), dtype=torch.float32).requires_grad_()
targets = torch.randint(5, size=(5,))
print('targets值为\n', targets)

criterion = focal_loss()
loss = criterion(input, targets)
loss.backward()
- # 针对多分类任务的 CELoss 和 Focal Loss
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
class CELoss(nn.Module):
    """Multi-class cross-entropy computed from clamped softmax probabilities.

    Args:
        class_num: Number of classes C. ``pred`` is reshaped to (-1, C).
        alpha: Per-class weights. Only used when ``use_alpha`` is True.
        use_alpha: If True, multiply each sample's loss by the weight of its
            target class.
        size_average: If True, return the batch mean; otherwise return the
            batch sum.

    The loss is accumulated in float64, so the returned scalar is float64.
    Everything is created on the device of ``pred``, so the loss works on CPU
    as well as CUDA.
    """

    def __init__(self, class_num, alpha=None, use_alpha=False, size_average=True):
        super(CELoss, self).__init__()
        self.class_num = class_num
        self.alpha = alpha
        if use_alpha:
            # Stored on the CPU here and moved lazily to the input's device in
            # forward(); an unconditional .cuda() would crash without a GPU.
            self.alpha = torch.as_tensor(alpha)

        self.softmax = nn.Softmax(dim=1)
        self.use_alpha = use_alpha
        self.size_average = size_average

    def forward(self, pred, target):
        """Return the (optionally class-weighted) cross-entropy of ``pred`` vs ``target``.

        Args:
            pred: Logits with ``class_num`` entries per sample.
            target: Integer class indices, one per sample.
        """
        prob = self.softmax(pred.reshape(-1, self.class_num))
        # Keep log() away from 0. torch.clamp passes no gradient for values
        # outside [min, max].
        prob = prob.clamp(min=0.0001, max=1.0)

        # One-hot encoding of the targets, allocated on the predictions' device.
        target_ = prob.new_zeros(target.size(0), self.class_num)
        target_.scatter_(1, target.reshape(-1, 1).long(), 1.)

        if self.use_alpha:
            # Move the weights once and cache them on the new device.
            if self.alpha.device != prob.device:
                self.alpha = self.alpha.to(prob.device)
            batch_loss = - self.alpha.double() * prob.log().double() * target_.double()
        else:
            batch_loss = - prob.log().double() * target_.double()

        # Only the target column is non-zero, so the row sum is the sample loss.
        batch_loss = batch_loss.sum(dim=1)

        if self.size_average:
            return batch_loss.mean()
        return batch_loss.sum()
-
class FocalLoss(nn.Module):
    """Multi-class focal loss: -alpha_t * (1 - p_t) ** gamma * log(p_t).

    Args:
        class_num: Number of classes C. ``pred`` is reshaped to (-1, C).
        alpha: Per-class weights. Only used when ``use_alpha`` is True.
        gamma: Focusing exponent. With ``gamma = 0`` this equals the
            clamped-softmax cross-entropy.
        use_alpha: If True, weight each sample's loss by its target class's alpha.
        size_average: If True, return the batch mean; otherwise return the
            batch sum.

    The loss is accumulated in float64, so the returned scalar is float64.
    Everything is created on the device of ``pred``, so the loss works on CPU
    as well as CUDA.
    """

    def __init__(self, class_num, alpha=None, gamma=2, use_alpha=False, size_average=True):
        super(FocalLoss, self).__init__()
        self.class_num = class_num
        self.alpha = alpha
        self.gamma = gamma
        if use_alpha:
            # Stored on the CPU here and moved lazily to the input's device in
            # forward(); an unconditional .cuda() would crash without a GPU.
            self.alpha = torch.as_tensor(alpha)

        self.softmax = nn.Softmax(dim=1)
        self.use_alpha = use_alpha
        self.size_average = size_average

    def forward(self, pred, target):
        """Return the focal loss of ``pred`` (logits) against ``target`` (class indices)."""
        prob = self.softmax(pred.reshape(-1, self.class_num))
        # Keep log() away from 0. torch.clamp passes no gradient for values
        # outside [min, max].
        prob = prob.clamp(min=0.0001, max=1.0)

        # One-hot encoding of the targets, allocated on the predictions' device.
        target_ = prob.new_zeros(target.size(0), self.class_num)
        target_.scatter_(1, target.reshape(-1, 1).long(), 1.)

        # (1 - p)^gamma down-weights well-classified entries; the one-hot mask
        # keeps only each sample's target column.
        focal = torch.pow(1 - prob, self.gamma).double() * prob.log().double() * target_.double()
        if self.use_alpha:
            # Move the weights once and cache them on the new device.
            if self.alpha.device != prob.device:
                self.alpha = self.alpha.to(prob.device)
            batch_loss = - self.alpha.double() * focal
        else:
            batch_loss = - focal

        batch_loss = batch_loss.sum(dim=1)

        if self.size_average:
            return batch_loss.mean()
        return batch_loss.sum()
-
-
# Demo: five samples with five classes each.
# Fixed seed so the random logits and targets are identical on every run.
torch.manual_seed(50)
input = torch.randn(5, 5, dtype=torch.float32, requires_grad=True)
targets = torch.randint(5, (5, ))
print('targets值为\n', targets)

# FocalLoss's class_num has no default, so calling FocalLoss() raises TypeError;
# pass the number of classes (the logits' second dimension) explicitly.
criterion = FocalLoss(class_num=input.size(1))
loss = criterion(input, targets)
loss.backward()
-
-
- FL 中 γ（gamma）取 0 时，(1 - p_t)^γ 恒等于 1，Focal Loss 即退化为（带 α 类别权重的）交叉熵 CE。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。