赞
踩
def item_use():
# 理解:
# 1.item()取出张量具体位置的元素元素值
# 2.并且返回的是该位置元素值的高精度值
# 3.保持原元素类型不变;必须指定位置
# 一般用在求loss或者accuracy时,使用.item()
import torch
loss = torch.randn(2, 2)
print(loss)
print(loss[1, 1])
print(loss[1, 1].item())
# tensor([[-2.0274, -1.5974],
# [-1.4775, 1.9320]])
# tensor(1.9320)
# 1.9319512844085693
def unsqueeze_use():
import torch
input=torch.arange(0,6)
print(input)
print(input.shape)
print(input.unsqueeze(0))
print(input.unsqueeze(0).shape)
print(input.unsqueeze(1))
print(input.unsqueeze(1).shape)
def MSELOSS_use():
import torch
import numpy as np
loss_fn = torch.nn.MSELoss(reduce=False, size_average=False)
a = np.array([[1, 2], [3, 4]])
b = np.array([[2, 3], [4, 5]])
input = torch.autograd.Variable(torch.from_numpy(a))
target = torch.autograd.Variable(torch.from_numpy(b))
loss = loss_fn(input.float(), target.float())
print(input.float())
print(target.float())
print(loss)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。