[PyTorch] Deep Learning in Practice: Implementing Linear Regression with PyTorch

Implementing Linear Regression in PyTorch

Callable objects:

To use an object as a callable, define a __call__ method when declaring the class:

class Foobar:
    def __init__(self):
        pass

    def __call__(self, *args, **kwargs):  # invoked when an instance is called like a function
        pass
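For illustration, an instance of such a class can then be called like a function; the call dispatches to __call__ (the bodies above are empty, so this only shows the mechanism):

f = Foobar()
f(1, 2, x=3)  # equivalent to f.__call__(1, 2, x=3)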

Here the parameter *args packs the first n positional arguments into an n-tuple, and **kwargs packs the keyword arguments into a dict; both are basic Python syntax:

def func(*args,**kwargs):
    print(args)
    print(kwargs)
    
func(1,2,3,4,x=3,y=5)
"""
(1, 2, 3, 4)
{'x': 3, 'y': 5}
"""
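This is exactly the mechanism torch.nn.Module relies on: calling a module instance goes through Module.__call__, which dispatches to the forward method you define. A minimal sketch:

import torch

layer = torch.nn.Linear(1, 1)  # a module is a callable object
x = torch.tensor([[1.0]])
y = layer(x)  # layer(x) goes through Module.__call__, which calls layer.forward(x)
print(y.shape)  # torch.Size([1, 1])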

The four steps of linear regression in PyTorch:

  • Prepare the training set
  • Design the model as a class (its purpose is the forward pass, i.e. computing y_hat)
  • Construct the loss function and the optimizer (the loss drives backpropagation; the optimizer updates the parameters)
  • Run the training loop (forward to compute the loss, backward to compute the gradients, then update repeatedly)


Each training iteration consists of:

  • Forward pass: compute y_hat (the prediction)
  • Compute the loss from y_hat and y_label (y_data)
  • Backward pass: backward() computes the gradients
  • Update the parameters according to the gradients

Implementation:

import torch
import matplotlib.pyplot as plt
import numpy as np

x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])


class LinearModel(torch.nn.Module):
    def __init__(self):  # constructor
        super(LinearModel, self).__init__()
        # Build the linear layer, specifying the input and output dimensions;
        # the third parameter (bias) defaults to True, meaning the bias b is used.
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        y_pred = self.linear(x)  # the layer is a callable object; computes y = wx + b
        return y_pred


model = LinearModel()  # instantiate the model

criterion = torch.nn.MSELoss(reduction='sum')
# model.parameters() scans all members of the module; any member with learnable
# weights has its parameters added to the set the optimizer will train.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # lr is the learning rate

epoch_list = []
loss_list = []
# for epoch in np.arange(0, 100, 2):
for epoch in range(100):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch_list.append(epoch)
    loss_list.append(loss.item())
print('w=', model.linear.weight.item())
print('b=', model.linear.bias.item())

x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)
plt.plot(epoch_list, loss_list)
plt.xlabel('times')
plt.ylabel('loss')
plt.title('SGD')
plt.show()

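As an aside, at inference time the prediction can be wrapped in torch.no_grad() so that no autograd graph is built; a minimal sketch reusing the model trained above:

with torch.no_grad():  # no gradient tracking needed for pure prediction
    y_test = model(torch.Tensor([[4.0]]))
print('y_pred =', y_test.item())  # should approach 8.0 as w -> 2 and b -> 0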

Output:

0 111.91926574707031
1 49.82788848876953
2 22.186500549316406
3 9.881272315979004
4 4.403273582458496
5 1.9645596742630005
6 0.8788504600524902
7 0.3954624831676483
8 0.18021120131015778
9 0.08432696759700775
10 0.04158348590135574
11 0.022497136145830154
12 0.013943195343017578
13 0.01007873099297285
14 0.008302716538310051
15 0.007457221858203411
16 0.007026821840554476
17 0.006781961768865585
18 0.006620422005653381
19 0.006496733520179987
20 0.006390667520463467
21 0.006293224636465311
22 0.0062002213671803474
23 0.006110009737312794
24 0.006021701730787754
25 0.005934945307672024
26 0.005849512759596109
27 0.0057654669508337975
28 0.005682558752596378
29 0.005600868724286556
30 0.005520401056855917
31 0.005441035609692335
32 0.0053628794848918915
33 0.005285775288939476
34 0.005209808703511953
35 0.005134933162480593
36 0.005061125382781029
37 0.004988380707800388
38 0.00491672195494175
39 0.0048460508696734905
40 0.004776409827172756
41 0.00470777926966548
42 0.004640108905732632
43 0.004573439247906208
44 0.00450771301984787
45 0.004442923702299595
46 0.004379057325422764
47 0.004316150210797787
48 0.004254107363522053
49 0.004192924126982689
50 0.0041326736100018024
51 0.004073282703757286
52 0.004014759790152311
53 0.003957051318138838
54 0.0039002075791358948
55 0.0038441140204668045
56 0.003788899164646864
57 0.0037344531156122684
58 0.003680775174871087
59 0.0036278674378991127
60 0.003575714770704508
61 0.0035243607126176357
62 0.003473697230219841
63 0.003423791378736496
64 0.003374570980668068
65 0.0033260590862482786
66 0.003278267802670598
67 0.003231176408007741
68 0.0031847076024860144
69 0.003138953121379018
70 0.003093830542638898
71 0.0030493782833218575
72 0.0030055600218474865
73 0.0029623594600707293
74 0.002919779857620597
75 0.0028778419364243746
76 0.002836476778611541
77 0.002795706270262599
78 0.0027555148117244244
79 0.0027159445453435183
80 0.002676892327144742
81 0.0026384363882243633
82 0.002600492676720023
83 0.0025631182361394167
84 0.002526274649426341
85 0.0024899819400161505
86 0.002454179571941495
87 0.002418922260403633
88 0.0023841557558625937
89 0.002349911257624626
90 0.002316119149327278
91 0.002282818779349327
92 0.0022500380873680115
93 0.002217694651335478
94 0.002185826888307929
95 0.00215441663749516
96 0.0021234452724456787
97 0.0020929216407239437
98 0.002062862040475011
99 0.0020332084968686104
100 0.002003985922783613
101 0.0019751866348087788
102 0.0019467804813757539
103 ... (output truncated)