当前位置:   article > 正文

线性分类代码示例_线性分类的源码

线性分类的源码

以下是一个基于Python语言的线性分类代码示例,主要用到pytorch,numpy,matplotlib,

  1. import numpy as np
  2. import torch
  3. import matplotlib.pyplot as plt
  4. # 生成随机数据
  5. x_ = np.linspace(1, 0.1, 100).reshape(100, 1)
  6. y_ = 3 * x_ + np.random.uniform(0, 1, (100, 1))
  7. # 转成tensor
  8. x = torch.from_numpy(x_).float()
  9. y = torch.from_numpy(y_).float()
  10. linear = torch.nn.Linear(1, 1)  # 线性分类
  11. optimizer = torch.optim.SGD(linear.parameters(), lr=0.1)  # 优化器
  12. loss_func = torch.nn.MSELoss()  # 损失函数
  13. EPOCH_NUM = 200
  14. for epoch in range(EPOCH_NUM):
  15.     plt.scatter(x, y, s=5)
  16.     y_pred = linear(x)  # 线性分类
  17.     loss = loss_func(y_pred, y)  # 计算损失
  18.     print(loss)
  19.     optimizer.zero_grad()  # 梯度清零
  20.     loss.backward()  # 损失回传
  21.     optimizer.step()  # 进行一次优化,更新参数
  22.     # 线性分类表达式,f(x,W)=Wx+b
  23.     f = linear.state_dict()['weight'].numpy()[0][0] * x_ + linear.state_dict()['bias'].numpy()[0]
  24.     plt.plot(x_, f, 'r-', linewidth=1, label='f(x)')
  25.     plt.pause(0.001)
  26.     plt.clf()
  27. print(linear.state_dict())
  28. f = linear.state_dict()['weight'].numpy()[0][0] * x_ + linear.state_dict()['bias'].numpy()[0]
  29. plt.scatter(x_, y_, s=5)
  30. plt.plot(x_, f, 'r-', linewidth=1, label='f(x)')
  31. plt.show()
  32. # torch.save(linear.state_dict(), 'linear.pt')  # 保存模型

对线性拟合过程做了可视化如下:

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/354666?site
推荐阅读
相关标签
  

闽ICP备14008679号