赞
踩
bce = BCEWithLogitsLoss()
pred为网络输出,y为标签值。
pred1.shape:[1],pred1.dtype:torch.float32
y1.shape:[1],pred1.dtype:torch.float32
- >>> pred1 = torch.Tensor([0.3])
- >>> y1 = torch.Tensor([1])
- >>> bce(pred1,y1)
- tensor(0.5544)
- >>> import math
- >>> e = math.e
- >>> log = math.log
- >>> -log(1/(1+1/math.pow(e,0.3)))
- 0.5543552444685272
pred1.shape:[10],pred1.dtype:torch.float32
y1.shape:[10],pred1.dtype:torch.float32
- >>> y2 = torch.ones([10], dtype=torch.float32)
- >>> pred2 = torch.full([10], 1.5)
- >>> criterion(pred2, y2)
- tensor(0.2014)
pred1.shape:[10,64],pred1.dtype:torch.float32
y1.shape:[10,64],pred1.dtype:torch.float32
- >>> y3 = torch.ones([4, 3], dtype=torch.float32)
- >>> pred3 = torch.full([4, 3], 1.5)
- >>> criterion = torch.nn.BCEWithLogitsLoss()
- >>> y3[0] = 0
- >>> pred3
- tensor([[1.5000, 1.5000, 1.5000],
- [1.5000, 1.5000, 1.5000],
- [1.5000, 1.5000, 1.5000],
- [1.5000, 1.5000, 1.5000]])
- >>> y3
- tensor([[0., 0., 0.],
- [1., 1., 1.],
- [1., 1., 1.],
- [1., 1., 1.]])
- >>> criterion(pred3, y3)
- tensor(0.5764)
- '''
- 若标签为1,且输出为1.5,则结果为:-log(1/(1+1/math.pow(e,1.5)))=0.2014
- 若标签为0,且输出为1.5,则结果为:-log(1-1/(1+1/math.pow(e,1.5)))=1.7014
- y3矩阵中,有3个0,9个1,表示(3*1.7014+9*0.2014)/12=0.5764
- '''
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。