赞
踩
学习一门编程语言最快速的途径便是学习案例,然后自己再独立去实现案例,本文将介绍PyTorch的第一个实战案例——线性回归算法。
案例为:利用PyTorch设计神经网络拟合直线y=Wx+b,其中W=[2,-3.4]T , b=4.2。
该案例有两个特征,分别是W的两个维度,有一个标签,为输出y。
定义构造数据集函数:
#根据带有噪声的线性模型构造一个人造数据集
def synthetic_data(w,b,num_examples):
'''生成y=wx+b+e(噪声)'''
x=torch.normal(0,1,(num_examples,len(w))) # 形状:num_examples*len(w)
y=torch.matmul(x,w)+b # 形状:num_examples
y+=torch.normal(0,0.01,y.shape)
return x,y.reshape(-1,1) #将y转成列向量
首先,初始化服从正态分布的随机数据x,形状为num_examples*len(w),然后计算相应的y,最后再加入随机噪声,以增强生成的数据的通用性,此时的y是一个行向量,最后返回x和转成列向量的y
接着,将w和b的真实值传入构造数据集,构造400个数据集样本:
# y=Wx+b W=[2,-3.4]T b=4.2
true_w=torch.tensor([2,-3.4])
true_b=4.2
n=400 #样本数量
features,labels=synthetic_data(true_w,true_b,n) #生成数据features和labels
我们可以打印输出x和y的尺寸:
print(features.size(),labels.size())
尺寸如下:
torch.Size([400, 2]) torch.Size([400, 1])
我们也可以利用matplotlib库对features和labels进行可视化操作:
# 将features和labels进行可视化
plt.figure(figsize = (12,5))
ax1 = plt.subplot(121)
ax1.scatter(features[:,0],labels, c = "b",label = "features[:,0]")
ax1.legend()
plt.xlabel("x1")
plt.ylabel("y",rotation = 0)
ax2 = plt.subplot(122)
ax2.scatter(features[:,1],labels, c = "g",label = "features[:,1]")
ax2.legend()
plt.xlabel("x2")
plt.ylabel("y",rotation = 0)
运行结果:
定义PyTorch数据迭代器函数:
def load_array(data_arrays,batch_size,is_train=True):
'''构造一个PyTorch数据迭代器'''
dataset=data.TensorDataset(*data_arrays) #将样本数据和样本标签包装成datasets
return data.DataLoader(dataset,batch_size,shuffle=is_train) #加载包装好的数据集
TensorDataset 函数可以用来对 tensor 进行打包,就好像 python 中的 zip 功能,其用法如下:
torch.utils.data.TensorDataset(data_tensor, target_tensor)
该函数的用法是将样本数据data_tensor和样本标签target_tensor进行包装。
【注意】:TensorDataset函数经常与DataLoader函数结合使用
包装完毕后,利用DataLoader加载数据,并对数据进行采样,生成batch迭代器:
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
定义完成数据迭代器后,将先前生成的features和labels传入该函数:
batch_size=10
data_iter=load_array((features,labels),batch_size) #从datasets中抽取batch_size个数据
我们可以抽取抽取数据集中的第一个data_iter,包含10个features和10个labels,查看一下他们的内容:
next(iter(data_iter)) #抽取数据集中的第一个data_iter,包含10个features和10个labels
运行结果:
[tensor([[-0.0977, 0.2305],
[ 1.5306, -0.5823],
[-1.2912, 0.1854],
[-2.5628, -0.7940],
[-0.5865, 0.6999],
[-0.3598, 2.8408],
[ 0.1976, -1.5368],
[-0.1743, 0.0667],
[-1.0448, -1.5595],
[ 0.4903, 0.1679]]),
tensor([[ 3.2290],
[ 9.2385],
[ 0.9945],
[ 1.7771],
[ 0.6469],
[-6.1814],
[ 9.8219],
[ 3.6375],
[ 7.4208],
[ 4.6111]])]
定义一个单层神经网络:
#使用预定义好的层
from torch import nn
net=nn.Sequential(nn.Linear(2,1)) # 全连接层
可以初始化模型参数(也可以不进行初始化):
#初始化模型参数(可选可不选)
net[0].weight.data.normal_(0,0.01) #正态分布
net[0].bias.data.fill_(0) #全零
然后定义损失函数MSEloss:
#损失函数
loss=nn.MSELoss()
定义SGD优化器:
optimizer=torch.optim.SGD(net.parameters(),lr=0.03)#实例化SGD优化器
开始训练代码如下:
#开始训练
#开始训练
num_epochs=5
for epoch in range(num_epochs):
for x,y in data_iter:
l=loss(net(x),y)
optimizer.zero_grad() #梯度清零
l.backward() #进行反向传播
optimizer.step() #模型更新
l=loss(net(features),labels)
print(f'epoch{epoch+1},loss{l:f}')
训练的基本流程如上所示,首先计算以下网络和真实标签的损失,然后将优化器的梯度清零,对损失进行反向传播,优化器进行模型更新,如此贩毒迭代epoch次,每次迭代完成后计算最终损失l,然后打印输出迭代次数和损失。
运行结果如下:
epoch1,loss0.000102
epoch2,loss0.000101
epoch3,loss0.000102
epoch4,loss0.000101
epoch5,loss0.000101
运行完成后,可以查看最终的w和b:
net[0].weight.data
运行结果:tensor([[ 1.9995, -3.4004]])
net[0].bias.data
运行结果:
tensor([4.2002])
查看一下二者的误差:
print(f'w的估计误差:{true_w-net[0].weight.data}')
print(f'b的估计误差:{true_b-net[0].bias.data}')
运行结果:
w的估计误差:tensor([[0.0005, 0.0004]])
b的估计误差:tensor([-0.0002])
利用最终训练出来的w和b可以进行最终输出结果可视化:
#结果可视化 plt.figure(figsize = (12,5)) ax1 = plt.subplot(121) ax1.scatter(features[:,0],labels, c = "b",label = "features[:,0]") ax1.plot(features[:,0],net[0].weight.data[:,0]*features[:,0]+net[0].bias.data,"-r",linewidth = 5.0,label = "model") ax1.legend() plt.xlabel("x1") plt.ylabel("y",rotation = 0) ax2 = plt.subplot(122) ax2.scatter(features[:,1],labels, c = "g",label = "features[:,1]") ax2.plot(features[:,1],net[0].weight.data[:,1]*features[:,1]+net[0].bias.data,"-r",linewidth = 5.0,label = "model") ax2.legend() plt.xlabel("x2") plt.ylabel("y",rotation = 0)
运行结果:
蓝色和绿色分别表示输入训练的散点数据,红色直线表示训练后的w和b组成的直线。可以发现,线性回归对数据具有较强的拟合能力。
完整代码可以参考:https://download.csdn.net/download/didi_ya/37378255
ok,以上便是本文的全部内容了,看完了之后记得一定要亲自独立动手实践一下呀~
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。