当前位置:   article > 正文

PyTorch自定义损失函数

pytorch自定义损失函数

PyTorch自定义损失函数

1. 直接使用tensor提供的function接口和python内建的方法

import torch
import torch.nn as nn
import torch.nn.functional as func
class TripletLossFunc(nn.Module):
    def __init__(self, t1, t2, beta):
        super(TripletLossFunc, self).__init__()
        self.t1 = t1
        self.t2 = t2
        self.beta = beta
        return

    def forward(self, anchor, positive, negative):
        matched = torch.pow(func.pairwise_distance(anchor, positive), 2)
        mismatched = torch.pow(func.pairwise_distance(anchor, negative), 2)
        part_1 = torch.clamp(matched - mismatched, min=self.t1)
        part_2 = torch.clamp(matched, min=self.t2)
        dist_hinge = part_1 + self.beta * part_2
        loss = torch.mean(dist_hinge)
        return loss
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

可以在__init__()函数里面定义,计算需要的超参数和损失函数里面可训练的参数(一般应该没有),这种方法适用于比较简单的损失函数,不需要为自己编写autograd的扩展。使用时也很方便,也就是:

a = TripletLossFunc(...)
loss = a(anchor, positive, negative)
  • 1
  • 2

2. 扩展Pytorch

如果计算的过程中需要使用Pytorch之外的算子的话,就需要对Pytorch进行扩展。

2.1 扩展autograd

在autograd上添加操作的话,需要为每一个操作编写一个新的Function子类。autograd通过调用Function函数来计算前向结果和反向梯度,因此每一个Function都需要实现两个方法forward()和backward()。

详情可以参考链接2,这里只记录一些注意事项:

  1. forward()过程中可以调用一些特殊的函数,比如save_for_backward()可以把需要的变量保存下来,以便于计算反向传播。
  2. forward()可以有多个返回值。
  3. backward()的输入会有和forward()输出一样多的Tensor类型的参数,分别代表计算图中关于输出的梯度(猜测是和forward()的return顺序是一致的),返回是和forward()的输入个数一样的Tensor,分别是计算图对输出的参数(注意多输入的时候需要链式求导法则:同一路相乘,不同路相加)

下面是官方给出的一个简单的例子:

# Inherit from Function
class LinearFunction(Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36

为了便于使用可以对apply进行重命名

linear = LinearFunction.apply
  • 1

在function之上,可以编写nn.module,常见的损失函数均为nn.model

2.2 使用numpy和scipy扩展

参考链接: Creating Extensions Using numpy and scipy — PyTorch Tutorials 1.9.0+cu102 documentation

注意此方法本质上还是扩展Function,因此需要自己实现反向传播,而且由于调用了numpy和scipy无法使用cuda加速。

2.3 编写cuda扩展

如果PyTorch并没有提供相应的算子,而且还需要cuda进行加速的话,需要自己编写cuda扩展

参考链接: Custom C++ and CUDA Extensions — PyTorch Tutorials 1.9.0+cu102 documentation

2.4 梯度检查

为了检查自己定义的梯度计算公式是否正确,pytorch提供了梯度检查函数torch.autograd.gradcheck(),原理是 f ′ ( x 0 ) ≈ ( f ( x 0 + e p s ) − f ( x 0 ) ) / e p s f'(x_0)\approx (f(x_0+eps) - f(x_0)) / eps f(x0)(f(x0+eps)f(x0))/eps即使用微小增量的函数差分对梯度进行估计,并与使用梯度计算公式计算所得结果进行比较,若误差在容忍度范围则返回true。

参考链接:torch.autograd.gradcheck — PyTorch 1.9.0 documentation

3 参考链接:

知乎三种方法

Extending PyTorch — PyTorch 1.9.0 documentation

Creating Extensions Using numpy and scipy — PyTorch Tutorials 1.9.0+cu102 documentation

Custom C++ and CUDA Extensions — PyTorch Tutorials 1.9.0+cu102 documentation

torch.autograd.gradcheck — PyTorch 1.9.0 documentation

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/321694
推荐阅读
相关标签
  

闽ICP备14008679号