当前位置:   article > 正文

计算图--反向传播_反向传播 计算图

反向传播 计算图

一、计算图

计算图可以用于表示一个复杂的函数,它通过将复杂函数分解为简单的计算来得到最后的计算结果

下面用一个简单的例子来介绍计算图

为了方便表示,在此说明一下:

  • 矩形节点:输入功能或者输出功能

  • 圆形节点:接受一个数据的输入并按照节点运算规则计算,并将计算结果输出到下一个节点

函数: f ( x , y , z ) = ( x + y ) z f(x, y, z)=(x+y)z f(x,y,z)=(x+y)z 的计算图表示如下:

-2
5
3
4
12
x
+
y
*
z
output

顺着计算图的箭头方向进行计算的过程被称为前向传播,在这个例子中,前向传播的最后结果是12

看到这里,你可能会奇怪,为什么一个简单函数需要分解为好几步计算,事实上,上面的节点可以反复使用,当函数稍微复杂时,采用公式直接计算就不如上面的简洁明了

更重要的是在神经网络的反向传播算法中,需要计算梯度,这涉及到了导数,一般的人计算稍微复杂的导数都要花上点时间,更不要说计算机了,通过计算图的形式计算梯度,计算机可以快速地得出结果

二、反向传播

1、基本概念

前向传播是通过神经网络不断计算,向前推进,最后在输出层得到计算结果,但是这个计算结果可能与标准结果存在一定的误差,为了衡量这个误差,引入了损失函数的概念,通过损失函数来定义误差的概念,为了减小这个误差,以表达到优化的目的,采用了反向传播算法(Backpropagation Algorithm)来优化每层网络的参数,利用梯度下降算法,可以得到损失函数在局部的最小值,如果这个损失函数是凸函数,那么这个局部最小值就是全局最小值

梯度下降算法迭代公式:
w = w − α ∂ J ( w , b ) ∂ w w=w-\alpha \frac {\partial J(w,b) }{\partial w} w=wαwJ(w,b)

这个公式的难点在于如何计算梯度,因为梯度计算涉及到了求偏导,对于复杂的神经网络来说,求导不是一件容易的事

接下来看一个例子,使用计算图来计算偏导数

注:例子中没有使用偏导符号,是因为在计算图里面,每个节点都是一些最简运算,都是一对一的,所以使用求导运算,为了方便好看,我就没有写偏导符号,直接写求导符号了,这对最后结果没有任何影响,只是表示符号不一样而已

2、例子

以Logistic回归中的预测函数为例:

输出: y ^ = σ ( z ) \hat y=\sigma (z) y^=σ(z)

其中: σ ( z ) = 1 1 + e − z , z = w T x + b \sigma (z)=\frac {1}{1+e^{-z}},z=w^Tx+b σ(z)=1+ez1z=wTx+b

为了方便观察,这里我们取:$w^T=[2.00,-3.00],\ x^T=[-1.00, -2.00],b=-3.00 $

画出计算图的正向传播过程,如下图所示:

2
-1
-3
-2
-2
6
4
-3
1
-1
0.37
1.37
0.73
w0
*
x0
w1
*
x1
+
+
b
*-1
exp
+1
1/x
output

对于反向传播,我们需要使用求导的链式法则: d y ^ d w 0 = d y ^ d z d z d w 0 \frac {d\hat y}{dw_0}=\frac {d\hat y}{dz} \frac {dz}{dw_0} dw0dy^=dzdy^dw0dz

因此,先从最后的输出开始计算梯度,为了方便表示,这里暂且规定一下,对于每个节点,输入记为 z z z,输出记为 y y y,节点的运算规则可以表示为: f f f,以 f , z , y f,z,y f,z,y的下标来区分不同的节点

整个计算图所表示的是: y ^ = f ( z ) = σ ( w T x + b ) = 1 1 + e − ( w T + b ) = 1 1 + e − ( w 0 x 0 + w 1 x 1 + b ) \hat y=f(z)=\sigma (w^Tx+b)=\frac {1}{1+e^{-(w^T+b)}}=\frac {1}{1+e^{-(w_0x_0+w_1x_1+b)}} y^=f(z)=σ(wTx+b)=1+e(wT+b)1=1+e(w0x0+w1x1+b)1

现在,我们要计算 d y ^ d w 0 \frac {d\hat y}{dw_0} dw0dy^,这里需要注意的是每步计算尽量减少计算过程

由于是反向传播,规定输出节点为第一个

计算步骤如下:

(1)第一个节点

输出节点:因为输出节点没有运算,输出等于输入,因此: y ^ = y 0 = f 0 ( z 0 ) = z 0 \hat y=y_0=f_0(z_0)=z_0 y^=y0=f0(z0)=z0

导数: d y ^ d z 0 = d f 0 d z 0 = 1 \frac {d\hat y}{dz_0}=\frac {df_0}{dz_0}=1 dz0dy^=dz0df0=1

y ^ \hat y y^在第一个节点的导数: d y ^ d z 0 = 1 \frac {d\hat y}{dz_0}=1 dz0dy^=1

(2)第二个节点

第二个节点的运算: y 1 = f 1 ( z 1 ) = 1 z 1 y_1=f_1(z_1)=\frac {1}{z_1} y1=f1(z1)=z11

导数: d y 1 d z 1 = − 1 z 1 2 \frac {dy_1}{dz_1}=\frac {-1}{z_1^2} dz1dy1=z121

因为我们通过前向传播,已经计算出了 y 1 y_1 y1的值,可以用 y 1 y_1 y1化简这个式子, d y 1 d z 1 = − y 1 2 \frac {dy_1}{dz_1}=-y_1^2 dz1dy1=y12

并且: z 0 = y 1 = 0.73 z_0=y_1=0.73 z0=y1=0.73

y ^ \hat y y^在第二个节点的导数: d y ^ d z 1 = d y ^ d z 0 d z 0 d z 1 = d y ^ d z 0 d y 1 d z 1 = 1 ∗ ( − y 1 2 ) = − 0.53 \frac {d\hat y}{dz_1}=\frac {d\hat y}{dz_0}\frac {dz_0}{dz_1}=\frac {d\hat y}{dz_0}\frac {dy_1}{dz_1}=1*(-y_1^2)=-0.53 dz1dy^=dz0dy^dz1dz0=dz0dy^dz1dy1=1(y12)=0.53

(3)第三个节点

第三个节点的运算: y 2 = f 2 ( z 2 ) = z 2 + 1 y_2=f_2(z_2)=z_2+1 y2=f2(z2)=z2+1

导数: d y 2 d z 2 = 1 \frac {dy_2}{dz_2}=1 dz2dy2=1

并且: z 1 = y 2 = 1.37 z_1=y_2=1.37 z1=y2=1.37

d y ^ d z 2 = d y ^ d z 1 d z 1 d z 2 = d y ^ d z 1 d y 2 d z 2 \frac {d\hat y}{dz_2}=\frac {d\hat y}{dz_1} \frac {dz_1}{dz_2}=\frac {d\hat y}{dz_1} \frac {dy_2}{dz_2} dz2dy^=dz1dy^dz2dz1=dz1dy^dz2dy2

在上一步中已经计算出了 d y ^ d z 1 = − 0.53 \frac {d\hat y}{dz_1} =-0.53 dz1dy^=0.53,可以直接使用:

y ^ \hat y y^在第三个节点的导数: d y ^ d z 2 = ( − 0.53 ) ∗ 1 = − 0.53 \frac {d\hat y}{dz_2}=(-0.53)*1=-0.53 dz2dy^=(0.53)1=0.53

(4)第四个节点

第四个节点运算为: y 3 = f 3 ( z 3 ) = e z 3 y_3=f_3(z_3)=e^{z_3} y3=f3(z3)=ez3

导数: d y 3 d z 3 = e z 3 = y 3 \frac {dy_3}{dz_3}=e^{z_3}=y_3 dz3dy3=ez3=y3

y ^ \hat y y^在第四个节点的导数: d y ^ d z 3 = d y ^ d z 2 d y 3 d z 3 = ( − 0.53 ) ∗ y 3 = − 0.20 \frac {d\hat y}{dz_3}=\frac {d\hat y}{dz_2} \frac {dy_3}{dz_3}=(-0.53)*y_3=-0.20 dz3dy^=dz2dy^dz3dy3=(0.53)y3=0.20

(5)小结

通过这几步,我想应该可以很好地解释之后的步骤了:

每次计算 y ^ \hat y y^在当前节点的导数时,都会用到链式法则,使用 y ^ \hat y y^在上一节点的导数值( d y ^ d z i − 1 \frac {d\hat y}{dz_{i-1}} dzi1dy^)和当前节点的导数值( d y i d z i \frac {dy_i}{dz_i} dzidyi),因为节点都是一些最简单的运算,因此求解当前节点的导数值是一件很容易的事,并且得到的形式也非常简单。

注意:可以使用输出值 y i y_i yi去替换导数中某些复杂的部分

以下是反向传播计算图:

1
-0.53
-0.53
-0.20
0.20
0.20
0.20
0.20
0.20
0.40
-0.20
-0.60
-0.40
output
1/x
+1
exp
*-1
+
b
+
*
*
x0
w0
x1
w1
本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/405659
推荐阅读
相关标签
  

闽ICP备14008679号