赞
踩
PyTorch实现线性回归
可调用对象:
如果要使用一个可调用对象,那么在类的声明的时候要定义一个call函数
class Foobar:
def __init__(self):
pass
def __call__(self,*args,**kwargs):
pass
其中参数*args代表把前n个参数变成n元组,**kwargsd会把参数变成一个词典,这些都是python的基础语法
def func(*args,**kwargs):
print(args)
print(kwargs)
func(1,2,3,4,x=3,y=5)
"""
(1, 2, 3, 4)
{'x': 3, 'y': 5}
"""
PyTorch线性回归的四个过程:
每一次训练的过程就是:
实现代码:
import torch import matplotlib.pyplot as plt import numpy as np x_data = torch.Tensor([[1.0], [2.0], [3.0]]) y_data = torch.Tensor([[2.0], [4.0], [6.0]]) class LinearModel(torch.nn.Module): def __init__(self): # 构造函数 super(LinearModel, self).__init__() self.linear = torch.nn.Linear(1, 1) # 构造对象,并说明输入输出的维数,第三个参数默认为true,表示用到b def forward(self, x): y_pred = self.linear(x) # 可调用对象,计算y=wx+b return y_pred model = LinearModel() # 实例化模型 criterion = torch.nn.MSELoss(reduction='sum') # model.parameters()会扫描module中的所有成员,如果成员中有相应的权重,那么都会将结果加到要训练的集合参数上 optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # lr为学习率 epoch_list = [] loss_list = [] # for epoch in np.arange(0, 100, 2): for epoch in range(100): y_pred = model(x_data) loss = criterion(y_pred, y_data) print(epoch, loss.item()) optimizer.zero_grad() loss.backward() optimizer.step() epoch_list.append(epoch) loss_list.append(loss.item()) print('w=', model.linear.weight.item()) print('b=', model.linear.bias.item()) x_test = torch.Tensor([[4.0]]) y_test = model(x_test) print('y_pred = ', y_test.data) plt.plot(epoch_list, loss_list) plt.xlabel('times') plt.ylabel('loss') plt.title('SGD') plt.show()
运行结果:
0 111.91926574707031 1 49.82788848876953 2 22.186500549316406 3 9.881272315979004 4 4.403273582458496 5 1.9645596742630005 6 0.8788504600524902 7 0.3954624831676483 8 0.18021120131015778 9 0.08432696759700775 10 0.04158348590135574 11 0.022497136145830154 12 0.013943195343017578 13 0.01007873099297285 14 0.008302716538310051 15 0.007457221858203411 16 0.007026821840554476 17 0.006781961768865585 18 0.006620422005653381 19 0.006496733520179987 20 0.006390667520463467 21 0.006293224636465311 22 0.0062002213671803474 23 0.006110009737312794 24 0.006021701730787754 25 0.005934945307672024 26 0.005849512759596109 27 0.0057654669508337975 28 0.005682558752596378 29 0.005600868724286556 30 0.005520401056855917 31 0.005441035609692335 32 0.0053628794848918915 33 0.005285775288939476 34 0.005209808703511953 35 0.005134933162480593 36 0.005061125382781029 37 0.004988380707800388 38 0.00491672195494175 39 0.0048460508696734905 40 0.004776409827172756 41 0.00470777926966548 42 0.004640108905732632 43 0.004573439247906208 44 0.00450771301984787 45 0.004442923702299595 46 0.004379057325422764 47 0.004316150210797787 48 0.004254107363522053 49 0.004192924126982689 50 0.0041326736100018024 51 0.004073282703757286 52 0.004014759790152311 53 0.003957051318138838 54 0.0039002075791358948 55 0.0038441140204668045 56 0.003788899164646864 57 0.0037344531156122684 58 0.003680775174871087 59 0.0036278674378991127 60 0.003575714770704508 61 0.0035243607126176357 62 0.003473697230219841 63 0.003423791378736496 64 0.003374570980668068 65 0.0033260590862482786 66 0.003278267802670598 67 0.003231176408007741 68 0.0031847076024860144 69 0.003138953121379018 70 0.003093830542638898 71 0.0030493782833218575 72 0.0030055600218474865 73 0.0029623594600707293 74 0.002919779857620597 75 0.0028778419364243746 76 0.002836476778611541 77 0.002795706270262599 78 0.0027555148117244244 79 0.0027159445453435183 80 0.002676892327144742 81 0.0026384363882243633 82 0.002600492676720023 83 0.0025631182361394167 84 0.002526274649426341 85 0.0024899819400161505 86 0.002454179571941495 87 0.002418922260403633 88 0.0023841557558625937 89 0.002349911257624626 90 0.002316119149327278 91 0.002282818779349327 92 0.0022500380873680115 93 0.002217694651335478 94 0.002185826888307929 95 0.00215441663749516 96 0.0021234452724456787 97 0.0020929216407239437 98 0.002062862040475011 99 0.0020332084968686104 100 0.002003985922783613 101 0.0019751866348087788 102 0.0019467804813757539 103
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。