当前位置:   article > 正文

【PyTorch】深度学习实战之梯度下降算法_pytorch梯度下降实验代码

pytorch梯度下降实验代码

梯度下降算法:可能只会找到局部最优点

随机梯度下降算法:随机梯度下降更容易找到全局最优点,类似于快速排序选基点,随机选取能够提高找到全局最优点的概率

梯度下降算法实现代码:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

w = 1.0

# 定义线性模型
def forward(x):
    return x * w

# 定义所有样本的平均平方误差
def cost(xs, ys):
    cost = 0
    for x, y in zip(xs, ys):
        y_pred = forward(x)
        cost += (y_pred - y) ** 2
    return cost / len(xs)

# 定义梯度函数
def gradient(xs, ys):
    grad = 0
    for x, y in zip(xs, ys):
        grad += 2 * x * (x * w - y)
    return grad / len(xs)


epoch_list = []
loss_list = []
print('Predict (before training)', 4, forward(4))
for epoch in range(100):
    cost_val = cost(x_data, y_data)
    grad_val = gradient(x_data, y_data)
    w -= 0.01 * grad_val # 0.01是学习率
    print('Epoch:', epoch, 'w=', w, 'loss=', cost_val)
    epoch_list.append(epoch)
    loss_list.append(cost_val)
print('Predict (after training)', 4, forward(4))

plt.plot(epoch_list, loss_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47

实现结果:

Predict (before training) 4 4.0
Epoch: 0 w= 1.0933333333333333 loss= 4.666666666666667
Epoch: 1 w= 1.1779555555555554 loss= 3.8362074074074086
Epoch: 2 w= 1.2546797037037036 loss= 3.1535329869958857
Epoch: 3 w= 1.3242429313580246 loss= 2.592344272332262
Epoch: 4 w= 1.3873135910979424 loss= 2.1310222071581117
Epoch: 5 w= 1.4444976559288012 loss= 1.7517949663820642
Epoch: 6 w= 1.4963445413754464 loss= 1.440053319920117
Epoch: 7 w= 1.5433523841804047 loss= 1.1837878313441108
Epoch: 8 w= 1.5859728283235668 loss= 0.9731262101573632
Epoch: 9 w= 1.6246153643467005 loss= 0.7999529948031382
Epoch: 10 w= 1.659651263674342 loss= 0.6575969151946154
Epoch: 11 w= 1.6914171457314033 loss= 0.5405738908195378
Epoch: 12 w= 1.7202182121298057 loss= 0.44437576375991855
Epoch: 13 w= 1.7463311789976905 loss= 0.365296627844598
Epoch: 14 w= 1.7700069356245727 loss= 0.3002900634939416
Epoch: 15 w= 1.7914729549662791 loss= 0.2468517784170642
Epoch: 16 w= 1.8109354791694263 loss= 0.2029231330489788
Epoch: 17 w= 1.8285815011136133 loss= 0.16681183417217407
Epoch: 18 w= 1.8445805610096762 loss= 0.1371267415488235
Epoch: 19 w= 1.8590863753154396 loss= 0.11272427607497944
Epoch: 20 w= 1.872238313619332 loss= 0.09266436490145864
Epoch: 21 w= 1.8841627376815275 loss= 0.07617422636521683
Epoch: 22 w= 1.8949742154979183 loss= 0.06261859959338009
Epoch: 23 w= 1.904776622051446 loss= 0.051475271914629306
Epoch: 24 w= 1.9136641373266443 loss= 0.04231496130368814
Epoch: 25 w= 1.9217221511761575 loss= 0.03478477885657844
Epoch: 26 w= 1.9290280837330496 loss= 0.02859463421027894
Epoch: 27 w= 1.9356521292512983 loss= 0.023506060193480772
Epoch: 28 w= 1.9416579305211772 loss= 0.01932302619282764
Epoch: 29 w= 1.9471031903392007 loss= 0.015884386331668398
Epoch: 30 w= 1.952040225907542 loss= 0.01305767153735723
Epoch: 31 w= 1.9565164714895047 loss= 0.010733986344664803
Epoch: 32 w= 1.9605749341504843 loss= 0.008823813841374291
Epoch: 33 w= 1.9642546069631057 loss= 0.007253567147113681
Epoch: 34 w= 1.9675908436465492 loss= 0.005962754575689583
Epoch: 35 w= 1.970615698239538 loss= 0.004901649272531298
Epoch: 36 w= 1.9733582330705144 loss= 0.004029373553099482
Epoch: 37 w= 1.975844797983933 loss= 0.0033123241439168096
Epoch: 38 w= 1.9780992835054327 loss= 0.0027228776607060357
Epoch: 39 w= 1.980143350378259 loss= 0.002238326453885249
Epoch: 40 w= 1.9819966376762883 loss= 0.001840003826269386
Epoch: 41 w= 1.983676951493168 loss= 0.0015125649231412608
Epoch: 42 w= 1.9852004360204722 loss= 0.0012433955919298103
Epoch: 43 w= 1.9865817286585614 loss= 0.0010221264385926248
Epoch: 44 w= 1.987834100650429 loss= 0.0008402333603648631
Epoch: 45 w= 1.9889695845897222 loss= 0.0006907091659248264
Epoch: 46 w= 1.9899990900280147 loss= 0.0005677936325753796
Epoch: 47 w= 1.9909325082920666 loss= 0.0004667516012495216
Epoch: 48 w= 1.9917788075181404 loss= 0.000383690560742734
Epoch: 49 w= 1.9925461188164473 loss= 0.00031541069384432885
Epoch: 50 w= 1.9932418143935788 loss= 0.0002592816085930997
Epoch: 51 w= 1.9938725783835114 loss= 0.0002131410058905752
Epoch: 52 w= 1.994444471067717 loss= 0.00017521137977565514
Epoch: 53 w= 1.9949629871013967 loss= 0.0001440315413480261
Epoch: 54 w= 1.9954331083052663 loss= 0.0001184003283899171
Epoch: 55 w= 1.9958593515301082 loss= 9.733033217332803e-05
Epoch: 56 w= 1.9962458120539648 loss= 8.000985883901657e-05
Epoch: 57 w= 1.9965962029289281 loss= 6.57716599593935e-05
Epoch: 58 w= 1.9969138906555615 loss= 5.406722767150764e-05
Epoch: 59 w= 1.997201927527709 loss= 4.444566413387458e-05
Epoch: 60 w= 1.9974630809584561 loss= 3.65363112808981e-05
Epoch: 61 w= 1.9976998600690001 loss= 3.0034471708953996e-05
Epoch: 62 w= 1.9979145397958935 loss= 2.4689670610172655e-05
Epoch: 63 w= 1.9981091827482769 loss= 2.0296006560253656e-05
Epoch: 64 w= 1.9982856590251044 loss= 1.6684219437262796e-05
Epoch: 65 w= 1.9984456641827613 loss= 1.3715169898293847e-05
Epoch: 66 w= 1.9985907355257035 loss
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/104000
推荐阅读
相关标签
  

闽ICP备14008679号