当前位置:   article > 正文

MSELoss损失函数_mseloss函数

mseloss函数

参考  MSELoss损失函数 - 云+社区 - 腾讯云

MSELoss损失函数中文名字就是:均方损失函数,公式如下所示:

这里 loss, x, y 的维度是一样的,可以是向量或者矩阵,i 是下标。

很多的 loss 函数都有 size_average 和 reduce 两个布尔类型的参数。因为一般损失函数都是直接计算 batch 的数据,因此返回的 loss 结果都是维度为 (batch_size, ) 的向量。

一般的使用格式如下所示:

loss_fn = torch.nn.MSELoss(reduce=True, size_average=True)

 这里注意一下两个入参:

  A reduce = False,返回向量形式的 loss 

  B reduce = True, 返回标量形式的loss

       C  size_average = True,返回 loss.mean();

  D  如果 size_average = False,返回 loss.sum()

 默认情况下:两个参数都为True.

下面的是python的例子:

  1. # -*- coding: utf-8 -*-
  2. import torch
  3. import torch.optim as optim
  4. loss_fn = torch.nn.MSELoss(reduce=False, size_average=False)
  5. #loss_fn = torch.nn.MSELoss(reduce=True, size_average=True)
  6. #loss_fn = torch.nn.MSELoss()
  7. input = torch.autograd.Variable(torch.randn(3,4))
  8. target = torch.autograd.Variable(torch.randn(3,4))
  9. loss = loss_fn(input, target)
  10. print(input); print(target); print(loss)
  11. print(input.size(), target.size(), loss.size())

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/木道寻08/article/detail/951148
推荐阅读
相关标签
  

闽ICP备14008679号