赞
踩
以简单回归问题为例,实现神经网络的小批量训练、网络参数保存以及参数提取。
简单回归问题的神经网络实现可见:【简单回归问题的神经网络实现-pytorch】
dataLoader是torch提供用于封装数据的工具,可以有效实现网络训练过程中的批量训练问题。
#生成DataLoader数据结构
def dataLoader(x,y):
#将torch转换为Dataset
torch_dataset = Data.TensorDataset(x, y)
#将dataset放入DataLoader
loader = Data.DataLoader(
dataset=torch_dataset,
batch_size=20, #最小训练批量
shuffle=True, #是否对数据进行随机打乱
num_workers=2, #多线程来读数据
)
return loader
使用方法如下:
#模拟数据 x,y=dataSet() loader=dataLoader(x,y) #迭代训练 for epoch in range(40): lossAll=0 for step, (batch_x, batch_y) in enumerate(loader): #预测 prediction=net(batch_x) #计算误差 loss=loss_fun(prediction,batch_y) lossAll+=loss.data.numpy() #梯度降为0 optimizer.zero_grad() #反向传递 loss.backward() #优化梯度 optimizer.step() #打印误差 print('Epoch: ', epoch, '| Step: ', step, '| loss: ',loss.data.numpy())
迭代过程如下:
#保存网络
def saveNet(net,params=False):
if params:
#保存网络参数
torch.save(net.state_dict(),'net_params.pkl')
else:
#保存整个网络
torch.save(net,'net.pkl')
#提取网络
def restoreNet(params=False):
if params:
#提取网络参数->注意需要新建一个相同类型的网络
net1=Net(1,[10,20],1)
net1.load_state_dict(torch.load('net_params.pkl'))
else:
#保存整个网络
net1 = torch.load('net.pkl')
====================================
今天到此为止,后续记录其他神经网络技术的学习过程。
以上学习笔记,如有侵犯,请立即联系并删除!
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。