赞
踩
接触过神经网络的朋友,都会知道正向反向传播理论。本篇博客旨在对其中的理论进行详细的解释。
推荐这本书超级好懂
https://www.ituring.com.cn/book/1921
矩形代表的是数据节点,圆形代表操作节点
我们先看下乘法的正向反向传播过程:
上述代表公式:
w
∗
x
=
z
w*x=z
w∗x=z,这就是一个简单的传播过程。
当我们对
w
w
w进行
∂
z
∂
w
=
x
\frac{\partial z}{\partial w}=x
∂w∂z=x,当我们对
x
x
x进行
∂
z
∂
x
=
w
\frac{\partial z}{\partial x}=w
∂x∂z=w
我们在看下加法的正向反向传播过程:
这个比较简单。
再来个练习,假设:
z
=
w
3
w
2
w
1
x
z=w_3w_2w_1x
z=w3w2w1x,这里为了方便,我没有加b
它的正向反向传播过程(以下m表示乘法):
进入乘法节点的时候,注意把非叶子结点当成是一个整体。
我们可以再对图做一个抽象
上图依然表示的是 w ∗ x = z w*x=z w∗x=z, d z dz dz表示的是上游传过来的梯度(误差)。 d z ∗ w dz*w dz∗w中的 w w w是该环节求导得到的值。
上面是乘法,如果是加法(+)呢?比如 z = x + b z=x+b z=x+b
上图圆形中是个参数,如果是一个函数,我又该怎么画呢?
看这个等式: z = f ( x ) z=f(x) z=f(x),正向反向传播传播图如下:
和前面差不多,要注意 x x x接收到的导数(误差)是 d z ∗ f ′ dz*f' dz∗f′,其中 f ′ f' f′表示 f ′ ( x ) f'(x) f′(x)。
通过上图大家应该更直观的知道正向反向传播过程。还有d后面加个字母,表示的是在该数据节点返回的梯度值(以下称误差值)。比如在上面的图中 x x x数据节点可以怎么写 d x dx dx,也就是说 d x = d z ∗ f ′ dx=dz*f' dx=dz∗f′。
下面进行理论上的推导:
理论推导来源https://zhuanlan.zhihu.com/p/33876102有兴趣的可以看一下。
定义输入: a [ l − 1 ] a^{[l-1]} a[l−1] ; 输出: a [ l ] a^{[l]} a[l]; 参数: w [ l ] w^{[l]} w[l], b [ l ] b^{[l]} b[l] ; 缓存变量(没有经过激活函数之前): z [ l ] z^{[l]} z[l]。
正向传播过程:
z
[
l
]
=
w
[
l
]
a
[
l
−
1
]
+
b
[
l
]
z^{[l]}=w^{[l]}a^{[l-1]}+b^{[l]}
z[l]=w[l]a[l−1]+b[l]
a
[
l
]
=
g
[
l
]
(
z
[
l
]
)
a^{[l]}=g^{[l]}(z^{[l]})
a[l]=g[l](z[l]),
g
g
g表示激活函数
根据上图,反向传播过程为:
d
z
[
l
]
=
d
a
[
l
]
∗
g
[
l
]
′
(
z
[
l
]
)
dz^{[l]}=da^{[l]}*g^{[l]'}(z^{[l]})
dz[l]=da[l]∗g[l]′(z[l])
d
w
[
l
]
=
d
z
[
l
]
∗
a
[
l
−
1
]
dw^{[l]}=dz^{[l]}*a^{[l-1]}
dw[l]=dz[l]∗a[l−1]
d
b
[
l
]
=
d
z
[
l
]
db^{[l]}=dz^{[l]}
db[l]=dz[l]
d
a
[
l
−
1
]
=
w
[
l
]
∗
d
z
[
l
]
da^{[l-1]}=w^{[l]}*dz^{[l]}
da[l−1]=w[l]∗dz[l]
上式第一行和第四行相结合:
d
z
[
l
]
=
w
[
l
+
1
]
∗
d
z
[
l
+
1
]
∗
g
[
l
]
′
(
z
[
l
]
)
dz^{[l]}=w^{[l+1]}*dz^{[l+1]}*g^{[l]'}(z^{[l]})
dz[l]=w[l+1]∗dz[l+1]∗g[l]′(z[l])
当然我们知道全连接神经网络是一个矩阵
W
W
W代表中间处理过程,而不是一个
w
w
w。在这里,我们需要了解对矩阵进行求导的公式:
我们关注一下第一个公式:
x
=
(
x
1
,
x
2
,
.
.
.
x
N
)
T
x=(x_1,x_2,...x_N)^T
x=(x1,x2,...xN)T
以下我们使用了第一个公式:
d
z
[
l
]
=
d
a
[
l
]
∗
g
[
l
]
′
(
z
[
l
]
)
dz^{[l]}=da^{[l]}*g^{[l]'}(z^{[l]})
dz[l]=da[l]∗g[l]′(z[l])
d
w
[
l
]
=
d
z
[
l
]
∗
a
[
l
−
1
]
dw^{[l]}=dz^{[l]}*a^{[l-1]}
dw[l]=dz[l]∗a[l−1]
d
b
[
l
]
=
d
z
[
l
]
db^{[l]}=dz^{[l]}
db[l]=dz[l]
d
a
[
l
−
1
]
=
W
[
l
]
T
∗
d
z
[
l
]
da^{[l-1]}=W^{[l]T}*dz^{[l]}
da[l−1]=W[l]T∗dz[l]
第一行和第四行推导可得:
d
z
[
l
]
=
W
[
l
+
1
]
T
∗
d
z
[
l
+
1
]
∗
g
[
l
]
′
(
z
[
l
]
)
dz^{[l]}=W^{[l+1]T}*dz^{[l+1]}*g^{[l]'}(z^{[l]})
dz[l]=W[l+1]T∗dz[l+1]∗g[l]′(z[l])
上式反映了
d
z
[
l
]
dz^{[l]}
dz[l]和
d
z
[
l
+
1
]
dz^{[l+1]}
dz[l+1]之间的关系。
如果数据有m条,公式如下:
求1/m是为了得到平均值。
参考文献:
https://zhuanlan.zhihu.com/p/33876102
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。