当前位置:   article > 正文

pytorch(二):分类_torch.normal(2*data,1)是什么意思

torch.normal(2*data,1)是什么意思
  1. import torch
  2. import torch.nn.functional as f
  3. from torch.autograd import Variable
  4. import matplotlib.pyplot as plt
  5. # 建造数据集
  6. data = torch.ones((100, 2))
  7. x0 = torch.normal(2*data, 1)
  8. y0 = torch.zeros(100) # y0是标签 shape(100,),是一维
  9. x1 = torch.normal(-2*data, 1)
  10. y1 = torch.ones(100) # y1也是标签 shape(100,),是一维
  11. x = torch.cat((x0, x1), 0).type(torch.FloatTensor) # 参数0表示维度,在纵向方向将x0,x1合并,合并后shape(200, 2))
  12. y = torch.cat((y0, y1), 0).type(torch.LongTensor) # 标签是01,类型为整数,LongTensor = 64-bit integer,
  13. x, y = Variable(x), Variable(y) # 训练神经网络只能接受变量输入,故要把x, y转化为变量
  14. plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], # 这两个参数分别代表x,y轴坐标
  15. c=y.data.numpy(), s=100, cmap='RdYlGn') # c为color,y有两种标签,代表两种颜色的点,'RdYlGn'红色和绿色
  16. plt.show()
  17. # 建造神经网络模型
  18. class Net(torch.nn.Module):
  19. def __init__(self, n_feature, n_hidden, n_output):
  20. super(Net, self).__init__()
  21. self.hidden = torch.nn.Linear(n_feature, n_hidden)
  22. self.out = torch.nn.Linear(n_hidden, n_output)
  23. def forward(self, x):
  24. x = f.relu(self.hidden(x))
  25. y = self.out(x)
  26. return y
  27. # 定义神经网络
  28. net = Net(n_feature=2, n_hidden=10, n_output=2)
  29. # n_output=2,因为它返回一个元素为2的列表。[0, 1]表示学习到的内容为标签1,[1, 0]表示学习到的内容为标签0
  30. print(net)
  31. # 训练神经网络模型并将训练过程可视化
  32. optimizer = torch.optim.SGD(net.parameters(), lr=0.02)
  33. loss_func = torch.nn.CrossEntropyLoss()
  34. plt.ion()
  35. for i in range(100):
  36. out = net(x)
  37. loss = loss_func(out, y)
  38. optimizer.zero_grad()
  39. loss.backward()
  40. optimizer.step()
  41. # 绘图
  42. if i % 2 == 0:
  43. plt.cla()
  44. # torch.max(a,1) 返回每一行中最大值的那个元素,且返回其索引(返回最大元素在这一行的列索引
  45. # f.softmax(out)是将out的内容以概率表示。
  46. # torch.max()返回的是两个Variable,第一个Variable存的是最大值,第二个存的是其对应的位置索引index。这里我们想要得到的是索引,所以后面用[1]。
  47. prediction = torch.max(f.softmax(out), 1)[1]
  48. pred_y = prediction.data.numpy().squeeze()
  49. target_y = y.data.numpy()
  50. plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, cmap='RdYlGn')
  51. accuracy = sum(pred_y == target_y)/200
  52. plt.text(1.5, -4, 'accuracy=%.2f'%accuracy, fontdict={'size':10, 'color':'red'})
  53. plt.pause(0.1)
  54. plt.ioff()
  55. plt.show()

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/358159
推荐阅读
相关标签
  

闽ICP备14008679号