当前位置:   article > 正文

BP(全连接网络)正向反向传播理论推导_bp神经网络正向传递的公式

bp神经网络正向传递的公式

  接触过神经网络的朋友,都会知道正向反向传播理论。本篇博客旨在对其中的理论进行详细的解释。
推荐这本书超级好懂
https://www.ituring.com.cn/book/1921

画图理解

矩形代表的是数据节点,圆形代表操作节点
我们先看下乘法的正向反向传播过程:

w
*
x
z

上述代表公式: w ∗ x = z w*x=z wx=z,这就是一个简单的传播过程。
当我们对 w w w进行 ∂ z ∂ w = x \frac{\partial z}{\partial w}=x wz=x,当我们对 x x x进行 ∂ z ∂ x = w \frac{\partial z}{\partial x}=w xz=w

1
1*x
1*w
z
*
w
x

我们在看下加法的正向反向传播过程:

w
+
x
z
1
1*1
1*1
z
+
w
x

这个比较简单。
再来个练习,假设:
z = w 3 w 2 w 1 x z=w_3w_2w_1x z=w3w2w1x,这里为了方便,我没有加b
它的正向反向传播过程(以下m表示乘法):

*
*
*
x
w1
w2
w3
z
1
1*w2*w1*x
1*w3
1*w3*w1*x
1*w3*w2
1*w3*w2*x
1*w3*w2*w1
*
*
*
z
w3
w2
w1
x

进入乘法节点的时候,注意把非叶子结点当成是一个整体。
我们可以再对图做一个抽象

dz
dz*w
x
*w
z

  上图依然表示的是 w ∗ x = z w*x=z wx=z d z dz dz表示的是上游传过来的梯度(误差)。 d z ∗ w dz*w dzw中的 w w w是该环节求导得到的值。

上面是乘法,如果是加法(+)呢?比如 z = x + b z=x+b z=x+b

dz
dz
x
+b
z

  上图圆形中是个参数,如果是一个函数,我又该怎么画呢?

  看这个等式: z = f ( x ) z=f(x) z=f(x),正向反向传播传播图如下:

dz
dz*f'
x
f
z

   和前面差不多,要注意 x x x接收到的导数(误差)是 d z ∗ f ′ dz*f' dzf,其中 f ′ f' f表示 f ′ ( x ) f'(x) f(x)

理论推导

  通过上图大家应该更直观的知道正向反向传播过程。还有d后面加个字母,表示的是在该数据节点返回的梯度值(以下称误差值)。比如在上面的图中 x x x数据节点可以怎么写 d x dx dx,也就是说 d x = d z ∗ f ′ dx=dz*f' dx=dzf

下面进行理论上的推导:
理论推导来源https://zhuanlan.zhihu.com/p/33876102有兴趣的可以看一下。

定义输入: a [ l − 1 ] a^{[l-1]} a[l1] ; 输出: a [ l ] a^{[l]} a[l]; 参数: w [ l ] w^{[l]} w[l], b [ l ] b^{[l]} b[l] ; 缓存变量(没有经过激活函数之前): z [ l ] z^{[l]} z[l]

正向传播过程:
z [ l ] = w [ l ] a [ l − 1 ] + b [ l ] z^{[l]}=w^{[l]}a^{[l-1]}+b^{[l]} z[l]=w[l]a[l1]+b[l]
a [ l ] = g [ l ] ( z [ l ] ) a^{[l]}=g^{[l]}(z^{[l]}) a[l]=g[l](z[l]) g g g表示激活函数

根据上图,反向传播过程为:
d z [ l ] = d a [ l ] ∗ g [ l ] ′ ( z [ l ] ) dz^{[l]}=da^{[l]}*g^{[l]'}(z^{[l]}) dz[l]=da[l]g[l](z[l])
d w [ l ] = d z [ l ] ∗ a [ l − 1 ] dw^{[l]}=dz^{[l]}*a^{[l-1]} dw[l]=dz[l]a[l1]
d b [ l ] = d z [ l ] db^{[l]}=dz^{[l]} db[l]=dz[l]
d a [ l − 1 ] = w [ l ] ∗ d z [ l ] da^{[l-1]}=w^{[l]}*dz^{[l]} da[l1]=w[l]dz[l]
上式第一行和第四行相结合:
d z [ l ] = w [ l + 1 ] ∗ d z [ l + 1 ] ∗ g [ l ] ′ ( z [ l ] ) dz^{[l]}=w^{[l+1]}*dz^{[l+1]}*g^{[l]'}(z^{[l]}) dz[l]=w[l+1]dz[l+1]g[l](z[l])

当然我们知道全连接神经网络是一个矩阵 W W W代表中间处理过程,而不是一个 w w w。在这里,我们需要了解对矩阵进行求导的公式:
在这里插入图片描述
我们关注一下第一个公式:
x = ( x 1 , x 2 , . . . x N ) T x=(x_1,x_2,...x_N)^T x=(x1,x2,...xN)T

以下我们使用了第一个公式:
d z [ l ] = d a [ l ] ∗ g [ l ] ′ ( z [ l ] ) dz^{[l]}=da^{[l]}*g^{[l]'}(z^{[l]}) dz[l]=da[l]g[l](z[l])
d w [ l ] = d z [ l ] ∗ a [ l − 1 ] dw^{[l]}=dz^{[l]}*a^{[l-1]} dw[l]=dz[l]a[l1]
d b [ l ] = d z [ l ] db^{[l]}=dz^{[l]} db[l]=dz[l]
d a [ l − 1 ] = W [ l ] T ∗ d z [ l ] da^{[l-1]}=W^{[l]T}*dz^{[l]} da[l1]=W[l]Tdz[l]
第一行和第四行推导可得:
d z [ l ] = W [ l + 1 ] T ∗ d z [ l + 1 ] ∗ g [ l ] ′ ( z [ l ] ) dz^{[l]}=W^{[l+1]T}*dz^{[l+1]}*g^{[l]'}(z^{[l]}) dz[l]=W[l+1]Tdz[l+1]g[l](z[l])
上式反映了 d z [ l ] dz^{[l]} dz[l] d z [ l + 1 ] dz^{[l+1]} dz[l+1]之间的关系。

如果数据有m条,公式如下:
求1/m是为了得到平均值。
在这里插入图片描述在这里插入图片描述在这里插入图片描述在这里插入图片描述
在这里插入图片描述
参考文献:
https://zhuanlan.zhihu.com/p/33876102

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/在线问答5/article/detail/1019860
推荐阅读
相关标签
  

闽ICP备14008679号