当前位置:   article > 正文

机器学习——反向传播推导与理解(Backpropagation, BP)_机器学习backpropce

机器学习backpropce

什么是反向传播

在训练神经网络的时候,会经常听到的一个名词:Backpropagation,反向传播。那么究竟什么是反向传播?我们知道训练神经网络的过程其实就是寻找一组较优参数的过程。这一过程经常需要用到梯度下降算法,即通过反复迭代,最终得到一组较优的参数。在线性回归中我们已经知道,参数的更新公式可以表示成:
θ j ≔ θ j − α ∂ ∂ θ j J ( θ ) \theta_j\coloneqq\theta_j-\alpha\frac{\partial}{\partial\theta_j}J\left(\boldsymbol{\theta}\right) θj:=θjαθjJ(θ)
其中, J ( θ ) = 1 2 m ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) 2 J\left(\boldsymbol{\theta}\right)=\frac{1}{2m}\sum_{i=1}^m\left(h_\boldsymbol{\theta}\left(x^{\left(i\right)}\right)-y^{\left(i\right)}\right)^2 J(θ)=2m1i=1m(hθ(x(i))y(i))2
从公式中可以看到,参数更新时用到了微分,具体而言就是针对每个参数,求其偏微分,然后带入到更新公式中即可。这在线性回归中是行得通的,因为往往这类问题涉及到的参数相对来说比较少,就算用人工的方式一个一个推导,似乎也可以得到最终的结果。但是在神经网络中,参数的数量和线性回归相比,根本不在一个量级,涉及到的函数也变成了多层复合函数。此时刚才的方法就行不通了,好在我们有“反向传播”。反向传播算法可以说是求解多层复合函数的所有变量的偏导数的利器,其具体思想就是我们熟知的链式法则求导。下面将作具体介绍。

反向传播理论推导

在讲解反向传播之前,先来回顾一下梯度下降(图片来源为李宏毅2021春机器学习课程的课件)。

在梯度下降中,我们首先会随机初始化一组参数 θ 0 \theta^0 θ0,然后通过对 θ 0 \theta^0 θ0中的每一个参数求偏导,得到由所有的参数构成的向量 ∇ L ( θ ) \nabla{L(\theta)} L(θ),即图中左边部分内容。将其代入到图中右边公式: θ 1 = θ 0 − η ∇ L ( θ 0 ) \theta^1=\theta^0-\eta\nabla{L(\theta^0)} θ1=θ0ηL(θ0),对参数进行更新,得到 θ 1 \theta^1 θ1。重复上述过程,直到得到一组较优的参数 θ \theta θ
由于神经网络通常涉及大量参数,为了更有效的计算梯度,引入了反向传播。

在训练时,我们通常会有一个损失函数,在图中,损失函数是所有样本损失之和。即假设第 n n n个训练样本输入到神经网络后,得到输出 y n y^n yn,同时,这一样本对应的真实值是 y ^ n \hat{y}^n y^n,则 C n C^n Cn表示两者之间的误差,例如我们可以定义 C n = y n − y ^ n C^n=y^n-\hat{y}^n Cn=yny^n,将所有样本的误差相加就可以得到最终的损失函数 L ( θ ) L(\theta) L(θ)。于是求偏导的公式就可以写成图中形式:
∂ L ( θ ) ∂ ω = ∑ n = 1 N ∂ C n ( θ ) ∂ ω \frac{\partial{L(\theta)}}{\partial{\omega}}=\sum_{n=1}^N\frac{\partial{C^n(\theta)}}{\partial{\omega}} ωL(θ)=n=1NωCn(θ)
于是,我们只需要关注某一个样本的 ∂ C n ( θ ) ∂ ω \frac{\partial{C^n(\theta)}}{\partial{\omega}} ωCn(θ)。以一个神经元的计算为例,即上图中的红色三角形部分。

现在假设神经网络的输入为 x 1 , x 2 x_1,x_2 x1x2,对应参数为 ω 1 , ω 2 \omega_1,\omega_2 ω1ω2,偏量为 b b b,则我们可以得到 z = x 1 ω 1 + x 2 ω 2 + b z=x_1\omega_1+x_2\omega_2+b z=x1ω1+x2ω2+b z z z通过激活函数后,再根据一系列类似的计算最终得到输出 y 1 , y 2 y_1,y_2 y1y2。根据链式求导法则,可以得到:
∂ C ∂ ω = ∂ z ∂ ω ∂ C ∂ z \frac{\partial{C}}{\partial{\omega}}=\frac{\partial{z}}{\partial{\omega}}\frac{\partial{C}}{\partial{z}} ωC=ωzzC
接下来就需要分别计算这两部分。其中计算 ∂ z ∂ ω \frac{\partial{z}}{\partial{\omega}} ωz叫做forward pass,计算 ∂ C ∂ z \frac{\partial{C}}{\partial{z}} zC叫做backward pass

Forward pass很简单,就是一次函数的求导:
∂ z ∂ ω 1 = x 1 ∂ z ∂ ω 2 = x 2 \frac{\partial{z}}{\partial{\omega_1}}=x_1 \\[4pt] \frac{\partial{z}}{\partial{\omega_2}}=x_2 ω1z=x1ω2z=x2
从中我们可以得到规律: ∂ z ∂ ω i \frac{\partial{z}}{\partial{\omega_i}} ωiz的结果就是与参数 ω i \omega_i ωi相连接的输入值。例如,在图中与 ω 1 \omega_1 ω1相连的是 x 1 x_1 x1,那么 ∂ z ∂ ω 1 = x 1 \frac{\partial{z}}{\partial{\omega_1}}=x_1 ω1z=x1

我们再用具体的数值作一下理解。现在我们的输入为-1和1,对应着四个参数: ω 11 = 1 , ω 12 = − 1 , ω 21 = − 2 , ω 22 = 1 \omega_{11}=1,\omega_{12}=-1,\omega_{21}=-2,\omega_{22}=1 ω11=1,ω12=1,ω21=2,ω22=1。按照刚才的规律·,我们可以直接得到 ∂ z 2 ∂ ω 22 = − 1 \frac{\partial{z_2}}{\partial{\omega_{22}}}=-1 ω22z2=1,根据图中的值,我们可以计算 z 2 = 1 × ( − 1 ) + ( − 1 ) × 1 + 0 = − 2 z_2=1\times(-1)+(-1)\times1+0=-2 z2=1×(1)+(1)×1+0=2,经过激活函数函数后(假设为sigmoid),得到 a 2 = 1 1 + e − 2 ≈ 0.12 a_2=\frac{1}{1+e^{-2}}\approx0.12 a2=1+e210.12,于是按照同样的方法,我们可以得到 ∂ z ∂ ω = 0.12 \frac{\partial{z}}{\partial{\omega}}=0.12 ωz=0.12
接下来是Backward pass,即计算 ∂ C ∂ z \frac{\partial{C}}{\partial{z}} zC

我们假设 z z z经过的激活函数为sigmoid函数,则我们可以得到 a = σ ( z ) a=\sigma(z) a=σ(z),则
∂ C ∂ z = ∂ a ∂ z ∂ C ∂ a = σ ′ ( z ) ∂ C ∂ a \frac{\partial{C}}{\partial{z}}=\frac{\partial{a}}{\partial{z}}\frac{\partial{C}}{\partial{a}}=\sigma'(z)\frac{\partial{C}}{\partial{a}} zC=zaaC=σ(z)aC
σ ( z ) \sigma(z) σ(z)及其导数的图像为:

此时,我们将 a a a作为输入,进行上述和求 z z z同样的计算,可以得到 z ′ , z ′ ′ z',z'' z,z,那么 ∂ C ∂ a \frac{\partial{C}}{\partial{a}} aC可以写成:
∂ C ∂ a = ∂ z ′ ∂ a ∂ C ∂ z ′ + ∂ z ′ ′ ∂ a ∂ C ∂ z ′ ′ \frac{\partial{C}}{\partial{a}}=\frac{\partial{z'}}{\partial{a}}\frac{\partial{C}}{\partial{z'}}+\frac{\partial{z''}}{\partial{a}}\frac{\partial{C}}{\partial{z''}} aC=azzC+azzC

∂ a ∂ z \frac{\partial{a}}{\partial{z}} za ∂ C ∂ a \frac{\partial{C}}{\partial{a}} aC两部分的结果进行整理,可以得到:
∂ C ∂ z = σ ′ ( z ) ( ω 3 ∂ C ∂ z ′ + ω 4 ∂ C ∂ z ′ ′ ) \frac{\partial{C}}{\partial{z}}=\sigma'(z)\left( \omega_3 \frac{\partial{C}}{\partial{z'}} +\omega_4 \frac{\partial{C}}{\partial{z''}} \right) zC=σ(z)(ω3zC+ω4zC)
从另一个角度看上述公式:

此时,我们的输入是 ∂ C ∂ z ′ \frac{\partial{C}}{\partial{z'}} zC ∂ C ∂ z ′ ′ \frac{\partial{C}}{\partial{z''}} zC,参数是 ω 3 和 ω 4 \omega_3和\omega_4 ω3ω4,和前向传播类似,通过相乘再相加,最后再乘以一个常数 σ ′ ( z ) \sigma'(z) σ(z),我们就可以得到 ∂ C ∂ z \frac{\partial{C}}{\partial{z}} zC具体表达式,这就是反向传播的内核所在,即从相反的角度作和前向传播类似的计算(前向传播需要经过一个激活函数,而反向传播需要乘以一个常数)。
那么现在的问题就只剩下如何计算 ∂ C ∂ z ′ \frac{\partial{C}}{\partial{z'}} zC ∂ C ∂ z ′ ′ \frac{\partial{C}}{\partial{z''}} zC这两项了。下面分情况进行讨论:
case1: z ′ 和 z ′ ′ z'和z'' zz经过激活函数后,直接输出(例如 y 1 = a ′ = s i g m o i d ( z ′ ) y_1=a'=sigmoid(z') y1=a=sigmoid(z))。

那么,此时具体的计算为:
∂ C ∂ z ′ = ∂ y 1 ∂ z ′ ∂ C ∂ y 1 ∂ C ∂ z ′ ′ = ∂ y 2 ∂ z ′ ′ ∂ C ∂ y 2 \frac{\partial{C}}{\partial{z'}}= \frac{\partial{y_1}}{\partial{z'}} \frac{\partial{C}}{\partial{y_1}} \\[4pt] \frac{\partial{C}}{\partial{z''}}= \frac{\partial{y_2}}{\partial{z''}} \frac{\partial{C}}{\partial{y_2}} zC=zy1y1CzC=zy2y2C
∂ C ∂ y 1 \frac{\partial{C}}{\partial{y_1}} y1C ∂ C ∂ y 2 \frac{\partial{C}}{\partial{y_2}} y2C的具体表达式需要根据损失函数 C C C的具体形式决定。
case2: z ′ 和 z ′ ′ z'和z'' zz经过激活函数后,后面仍有内容,即处于中间某一环节。

此时,
∂ C ∂ z ′ = σ ′ ( z ′ ) ( ω 5 ∂ C ∂ z a + ω 6 ∂ C ∂ z b ) \frac{\partial{C}}{\partial{z'}}=\sigma'(z')\left( \omega_5 \frac{\partial{C}}{\partial{z_a}} + \omega_6 \frac{\partial{C}}{\partial{z_b}} \right) zC=σ(z)(ω5zaC+ω6zbC)
于是,我们的任务又变成了求 ∂ C ∂ z a , ∂ C ∂ z b \frac{\partial{C}}{\partial{z_a}},\frac{\partial{C}}{\partial{z_b}} zaC,zbC。如此反复进行下去,直到我们到达输出层。
以上是我们从正向的角度逐一对参数进行分解求导的。如果我们从输出层开始反向计算,相当于将神经网络计算方向反过来,那么问题就会变得很简单

从输出层,我们可以很容易得计算出 ∂ C ∂ z 5 , ∂ C ∂ z 6 \frac{\partial{C}}{\partial{z_5}},\frac{\partial{C}}{\partial{z_6}} z5C,z6C,于是按照上面公式,就可以得到 ∂ C ∂ z 3 , ∂ C ∂ z 4 \frac{\partial{C}}{\partial{z_3}},\frac{\partial{C}}{\partial{z_4}} z3C,z4C,再进行类似得计算,最后就可以得到 ∂ C ∂ z 1 , ∂ C ∂ z 2 \frac{\partial{C}}{\partial{z_1}},\frac{\partial{C}}{\partial{z_2}} z1C,z2C

以上就是反向传播的全部内容!
(主要参考李宏毅机器学习

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/405432
推荐阅读
相关标签
  

闽ICP备14008679号