赞
踩
以下是一个基于Python语言的线性分类代码示例,主要用到pytorch,numpy,matplotlib,
- import numpy as np
- import torch
- import matplotlib.pyplot as plt
-
- # 生成随机数据
- x_ = np.linspace(1, 0.1, 100).reshape(100, 1)
- y_ = 3 * x_ + np.random.uniform(0, 1, (100, 1))
-
- # 转成tensor
- x = torch.from_numpy(x_).float()
- y = torch.from_numpy(y_).float()
-
- linear = torch.nn.Linear(1, 1) # 线性分类
- optimizer = torch.optim.SGD(linear.parameters(), lr=0.1) # 优化器
- loss_func = torch.nn.MSELoss() # 损失函数
-
- EPOCH_NUM = 200
- for epoch in range(EPOCH_NUM):
- plt.scatter(x, y, s=5)
- y_pred = linear(x) # 线性分类
- loss = loss_func(y_pred, y) # 计算损失
- print(loss)
- optimizer.zero_grad() # 梯度清零
- loss.backward() # 损失回传
- optimizer.step() # 进行一次优化,更新参数
- # 线性分类表达式,f(x,W)=Wx+b
- f = linear.state_dict()['weight'].numpy()[0][0] * x_ + linear.state_dict()['bias'].numpy()[0]
- plt.plot(x_, f, 'r-', linewidth=1, label='f(x)')
- plt.pause(0.001)
- plt.clf()
- print(linear.state_dict())
- f = linear.state_dict()['weight'].numpy()[0][0] * x_ + linear.state_dict()['bias'].numpy()[0]
- plt.scatter(x_, y_, s=5)
- plt.plot(x_, f, 'r-', linewidth=1, label='f(x)')
- plt.show()
- # torch.save(linear.state_dict(), 'linear.pt') # 保存模型

对线性拟合过程做了可视化如下:
赞
踩
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。