当前位置:   article > 正文

线性回归的简单实现_线性模型训练过程中,标签

线性模型训练过程中,标签

代码详解

  1. # 将训练数据的特征和标签组合
  2. dataset = Data.TensorDataset(features, labels)
  3. # 随机读取小批量
  4. data_iter = Data.DataLoader(dataset, batch_size, shuffle=True)
  5. for X, y in data_iter:
  6. print(X, y)
  7. break
  8. #利用nn.module实现模型
  9. # class LinearNet(nn.Module):
  10. # def __init__(self, n_feature):
  11. # super(LinearNet, self).__init__()
  12. # self.linear = nn.Linear(n_feature, 1)
  13. # # forward 定义前向传播
  14. # def forward(self, x):
  15. # y = self.linear(x)
  16. # return y
  17. #
  18. # net = LinearNet(num_inputs)
  19. # print(net) # 使用print可以打印出网络的结构
  20. #利用squntial容器搭建网络
  21. net = nn.Sequential(
  22. nn.Linear(num_inputs, 1)
  23. # 此处还可以传入其他层
  24. )
  25. print(net)
  26. print(net[0])
  27. #查看所有可学习参数
  28. for param in net.parameters():
  29. print(param)
  30. #初始化模型参数
  31. init.normal_(net[0].weight, mean=0, std=0.01) #将权重参数每个元素初始化为随机采样于均值为0、标准差为0.01的正态分布
  32. init.constant_(net[0].bias, val=0) # 也可以直接修改bias的data: net[0].bias.data.fill_(0)
  33. print(net[0].weight)
  34. print(net[0].bias)
  35. #定义损失函数
  36. loss = nn.MSELoss()
  37. #定义优化算法
  38. optimizer = optim.SGD(net.parameters(), lr=0.03)
  39. #调整学习率
  40. # optimizer.param_groups: 是长度为2的list,其中的元素是2个字典;
  41. # optimizer.param_groups[0]: 长度为6的字典,包括[‘amsgrad’, ‘params’, ‘lr’, ‘betas’, ‘weight_decay’, ‘eps’]这6个参数;
  42. # optimizer.param_groups[1]: 好像是表示优化器的状态的一个字典;
  43. optimizer.param_groups[0]['lr'] *= 0.1
  44. #训练模型
  45. num_epochs = 3
  46. for epoch in range(1, num_epochs + 1):
  47. for X, y in data_iter:
  48. output = net(X)
  49. #torch.view(-1, 参数b),则表示在参数a未知,参数b已知的情况下自动补齐行向量长度
  50. l = loss(output, y.view(-1, 1))
  51. optimizer.zero_grad() # 梯度清零,等价于net.zero_grad()
  52. l.backward()
  53. optimizer.step()#更新优化器参数
  54. print('epoch %d, loss: %f' % (epoch, l.item()))
  55. #w b 比较
  56. dense = net[0] #第零层网络
  57. print(true_w, dense.weight)
  58. print(true_b, dense.bias)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/144968
推荐阅读
相关标签
  

闽ICP备14008679号