当前位置:   article > 正文

批量训练、参数保存与提取-pytorch_python中将dataloader类型保存

python中将dataloader类型保存

简单回归问题

以简单回归问题为例,实现神经网络的小批量训练、网络参数保存以及参数提取。
简单回归问题的神经网络实现可见:【简单回归问题的神经网络实现-pytorch】

dataLoader定义

dataLoader是torch提供用于封装数据的工具,可以有效实现网络训练过程中的批量训练问题。

#生成DataLoader数据结构
def dataLoader(x,y):
    #将torch转换为Dataset
    torch_dataset = Data.TensorDataset(x, y)
    #将dataset放入DataLoader
    loader = Data.DataLoader(
        dataset=torch_dataset,
        batch_size=20,      #最小训练批量
        shuffle=True,       #是否对数据进行随机打乱
        num_workers=2,      #多线程来读数据
    )
    return loader
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

使用方法如下:

	#模拟数据
    x,y=dataSet()
    loader=dataLoader(x,y)
	#迭代训练
    for epoch in range(40):
        lossAll=0
        for step, (batch_x, batch_y) in enumerate(loader):
            #预测
            prediction=net(batch_x)
            #计算误差
            loss=loss_fun(prediction,batch_y)
            lossAll+=loss.data.numpy()
            #梯度降为0
            optimizer.zero_grad()
            #反向传递
            loss.backward()
            #优化梯度
            optimizer.step()
            #打印误差
            print('Epoch: ', epoch, '| Step: ', step, '| loss: ',loss.data.numpy())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

迭代过程如下:
在这里插入图片描述

参数保存

#保存网络
def saveNet(net,params=False):
    if params:
        #保存网络参数
        torch.save(net.state_dict(),'net_params.pkl')
    else:
        #保存整个网络
        torch.save(net,'net.pkl')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

参数提取

#提取网络
def restoreNet(params=False):
    if params:
        #提取网络参数->注意需要新建一个相同类型的网络
        net1=Net(1,[10,20],1)
        net1.load_state_dict(torch.load('net_params.pkl'))
    else:
        #保存整个网络
        net1 = torch.load('net.pkl')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

====================================
今天到此为止,后续记录其他神经网络技术的学习过程。
以上学习笔记,如有侵犯,请立即联系并删除!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/439336
推荐阅读
相关标签
  

闽ICP备14008679号