赞
踩
本文目的是针对《Overcoming catastrophic forgetting in neural networks》文中的EWC方法提到的Fisher矩阵进行相关知识调研和记录。
首先针对防止灾难性遗忘的方法EWC中提到的Fisher矩阵,其引用的自然梯度下降方法如下:
参考论文《Revisiting natural gradient for deep networks》
考虑一个密度函数族
F
\mathcal{F}
F将参数
θ
∈
R
P
\theta \in \mathbb{R}^{P}
θ∈RP映射到概率密度函数
p
(
z
)
p(\bold{z})
p(z),
p
:
R
N
→
[
0
,
∞
)
p:\mathbb{R}^N \rightarrow [0,\infin)
p:RN→[0,∞),其中
z
∈
R
N
z \in \mathbb{R}^N
z∈RN。
θ
∈
R
P
\theta \in \mathbb{R}^{P}
θ∈RP的任何选择定义了一个特定的密度函数
p
θ
(
z
)
=
F
(
θ
)
(
z
)
p_{\theta}(\bold{z})=\mathcal{F}(\theta)(\bold{z})
pθ(z)=F(θ)(z),并通过考虑所有可能的
θ
\theta
θ值,探索了集合函数流形
F
\mathcal{F}
F。
在其无穷小的形式中,KL 散度的行为类似于距离度量,因此我们可以定义附近密度函数之间的相似性度量。因此
F
\bold{F}
F是一个黎曼流形,其度量由以下等式中定义的 Fisher 信息矩阵
F
θ
\bold{F}_{\theta}
Fθ给出:
F
θ
=
E
z
[
(
∇
log
p
θ
(
z
)
)
T
(
∇
log
p
θ
(
z
)
)
]
(1)
\bold{F}_{\theta}=\mathbb{E}_z[(\nabla \log p_{\theta}(\bold{z}))^T(\nabla \log p_{\theta}(\bold{z}))] \tag{1}
Fθ=Ez[(∇logpθ(z))T(∇logpθ(z))](1)
也就是说,在某个点
θ
\theta
θ周围局部,该度量定义了向量
u
u
u和
v
v
v之间的内积:
<
u
,
v
>
θ
=
u
F
θ
v
(2)
<u,v>_{\theta}=u\bold{F}_{\theta}v \tag{2}
<u,v>θ=uFθv(2)
因此,它提供了距离的局部度量。假设该矩阵对
θ
\theta
θ的隐式依赖性,接下来为 Fisher 信息矩阵编写
F
\bold{F}
F。
给定由
θ
\theta
θ参数化的损失函数
L
\mathcal{L}
L,自然梯度下降试图通过根据KL发散面的局部曲率校正L的梯度来沿着流形移动,即在方向上移动给定的距离
∇
N
L
(
θ
)
\nabla_N \mathcal{L}(\theta)
∇NL(θ):
∇
N
L
(
θ
)
=
d
e
f
∇
L
(
θ
)
E
z
[
(
∇
log
p
θ
(
z
)
)
T
(
∇
log
p
θ
(
z
)
)
]
−
1
=
d
e
f
∇
L
(
θ
)
F
−
1
对自然梯度使用
∇
N
\nabla_N
∇N,对梯度使用
∇
\nabla
∇,
F
\bold{F}
F是Fisher信息矩阵给出的度量矩阵。在这项工作中,偏导数通常表示为行向量。我们可以通过将自然梯度下降定义为算法来导出这一结果,该算法在每一步都试图选择下降方向,使我们的模型中引起的变化量(在KL意义上)为某个给定值。特别地,当
p
θ
p_θ
pθ和
p
θ
+
∆
θ
p_{θ+∆θ}
pθ+∆θ之间的KL散度的二阶泰勒级数必须是常数时,我们寻找一个小的
∆
θ
∆θ
∆θ,它最小化
L
\mathcal{L}
L的一阶泰勒展开:
arg min
∆
θ
L
(
θ
+
∆
θ
)
s
.
t
.
K
L
(
p
θ
∣
∣
p
θ
+
∆
θ
)
=
c
o
n
s
t
.
使用此约束,确保以恒定速度沿函数流形移动,而不会因其曲率而减慢速度。这也使得学习对模型的重新参数化具有局部鲁棒性,因为
p
p
p的函数行为不取决于它是如何参数化的。
假设
∆
θ
→
0
∆θ→ 0
∆θ→0,我们可以通过其二阶泰勒级数来近似KL散度:
K
L
(
p
θ
∣
∣
p
θ
+
∆
θ
)
≈
(
E
z
[
log
p
θ
]
−
E
z
[
log
p
θ
]
)
−
E
z
[
∇
log
p
θ
(
z
)
]
∆
θ
−
1
2
∆
θ
T
E
z
[
∇
2
log
p
θ
]
∆
θ
=
1
2
∆
θ
T
E
z
[
−
∇
2
log
p
θ
]
∆
θ
=
1
2
∆
θ
T
F
∆
θ
第一项抵消,由于
E
z
[
∇
log
p
θ
(
z
)
]
=
∑
z
(
p
θ
(
z
)
1
p
θ
(
z
)
∂
p
θ
(
z
)
∂
θ
)
=
∂
∂
θ
(
∑
θ
p
θ
(
z
)
)
=
∂
1
∂
θ
=
0
\mathbb{E}_z[\nabla \log p_{\theta}(\bold{z})]= \sum_{\bold{z}}(p_{\theta}(\bold{z}){\frac{1}{p_\theta(\bold{z})} \frac{\partial p_\theta(\bold{z})}{\partial \theta}})=\frac{\partial}{\partial \theta}(\sum_{\theta} p_{\theta}(\bold{z}))=\frac{\partial 1}{\partial \theta}=0
Ez[∇logpθ(z)]=∑z(pθ(z)pθ(z)1∂θ∂pθ(z))=∂θ∂(∑θpθ(z))=∂θ∂1=0,只保留最后一项。另一方面Fisher信息矩阵形式可以通过代数运算从Hessian的期望值获得。
现在将等式 (5) 表示为拉格朗日函数,其中 KL 散度由 (6) 和
L
(
θ
+
∆
θ
)
\mathcal{L}(θ+∆θ)
L(θ+∆θ)通过其一阶泰勒级数
L
(
θ
)
+
∇
L
(
θ
)
∆
θ
\mathcal{L}(θ)+\nabla \mathcal{L}(θ)∆θ
L(θ)+∇L(θ)∆θ近似:
L
(
θ
)
+
∇
L
(
θ
)
∆
θ
+
1
2
λ
∆
θ
T
F
∆
θ
(7)
\mathcal{L}(θ)+\nabla \mathcal{L}(θ)∆θ+\frac{1}{2}\lambda ∆θ^T \bold{F} ∆θ \tag{7}
L(θ)+∇L(θ)∆θ+21λ∆θTF∆θ(7)
求解
∆
θ
∆θ
∆θ的方程(7),得到自然梯度下降公式(4)。得到自然梯度的
2
1
λ
2\frac{1}{\lambda}
2λ1倍的标量因子。我们将此标量折叠到学习率中,现在还控制我们在保持
p
θ
p_θ
pθ 和
p
θ
+
Δ
θ
p_{θ+Δθ}
pθ+Δθ 之间的 KL 距离的权重。我们使用的近似值仅在
θ
θ
θ左右有意义:在 Schul(2012) 研究中,表明采取大步骤可能会损害收敛。通过使用阻尼(即设置
θ
θ
θ周围的信任区域)和正确选择学习率来处理此类问题。
提示:这里的分析结合了“通义千问”的回答
2.1 自然梯度的几何解释:
在优化过程中,我们希望找到使目标函数
L
(
θ
)
\mathcal{L}(θ)
L(θ) 最小化的参数
θ
θ
θ。标准梯度下降是在欧几里得空间中工作,沿梯度
∇
θ
L
(
θ
)
\nabla_\theta \mathcal{L}(θ)
∇θL(θ)的方向是最陡峭下降的方向。然而,在处理概率分布参数时,参数空间往往具有非欧几里得几何结构。Fisher矩阵可以用来定义参数空间上的一个Riemannian度量,即参数间的“自然”距离。在这种几何框架下,自然梯度
G
n
a
t
(
θ
)
G_{nat}(θ)
Gnat(θ) 是在这个参数空间中沿着目标函数最陡峭下降方向的向量。
2.2 Fisher矩阵的物理意义:
Fisher矩阵
I
(
θ
)
I(θ)
I(θ) 可以看作是参数
θ
θ
θ在给定概率模型
p
(
x
∣
θ
)
p(x∣θ)
p(x∣θ)下的信息含量。它量化了当观测数据
x
x
x 固定时,参数
θ
θ
θ的微小变动对对数似然函数
log
p
(
x
∣
θ
)
\log p(x∣θ)
logp(x∣θ)的影响。Fisher矩阵的每个元素
I
i
j
(
θ
)
I_{ij}(θ)
Iij(θ)表示参数
θ
i
θ_i
θi和
θ
j
θ_j
θj之间的二阶偏导数期望,即参数间变化的局部相关性。Fisher矩阵是对称的,且非负定,其逆矩阵
I
(
θ
)
−
1
I(θ)^{-1}
I(θ)−1描述了参数估计的协方差矩阵的逆,反映了参数估计的不确定性。
2.3 Cramér-Rao Bound的关联:
Fisher信息矩阵与参数估计的精度有着直接联系。Cramér-Rao Bound (CRB) 是一个统计学的基本定理,它指出在无偏估计的情况下,任何估计量的协方差矩阵的逆(即精度矩阵)至少要等于Fisher信息矩阵。换句话说,Fisher信息矩阵的逆给出了参数估计误差协方差的下界。因此,参数的Fisher信息量越大,其估计的不确定性理论上就越小,即该参数越重要,我们能更精确地估计它。
2.4 Fisher矩阵与KL散度的关系:
自然梯度与Fisher信息矩阵之间的紧密联系源于它们与Kullback-Leibler (KL) 散度的关系。KL散度是衡量两个概率分布
p
p
p和
q
q
q之间差异的一个常用指标。在自然梯度的背景下,我们关心的是参数
θ
θ
θ的微小变化如何影响模型分布
p
(
x
∣
θ
)
p(x∣θ)
p(x∣θ)。当
θ
θ
θ改变为
θ
+
Δ
θ
θ+Δθ
θ+Δθ时,KL散度
D
K
L
(
p
(
x
∣
θ
)
∣
∣
p
(
x
∣
θ
+
Δ
θ
)
)
D_{KL}(p(x∣θ)∣∣p(x∣θ+Δθ))
DKL(p(x∣θ)∣∣p(x∣θ+Δθ))可以展开为泰勒级数,其中二阶项与Fisher矩阵相关:
D
K
L
(
p
(
x
∣
θ
)
∣
∣
p
(
x
∣
θ
+
Δ
θ
)
)
≈
1
2
∆
θ
T
F
∆
θ
D_{KL}(p(x∣θ)∣∣p(x∣θ+Δθ)) \approx \frac{1}{2}∆θ^T \bold{F} ∆θ
DKL(p(x∣θ)∣∣p(x∣θ+Δθ))≈21∆θTF∆θ
这意味着Fisher矩阵描述了参数变化
Δ
θ
Δθ
Δθ对KL散度的二次贡献。在优化过程中,我们希望最小化KL散度(或者等价地,最大化对数似然),自然梯度的方向正是使得KL散度相对于参数变化
Δ
θ
Δθ
Δθ最陡峭下降的方向。
EWC防止灾难性遗忘方法的训练目标如下:
L
(
θ
)
=
L
B
(
θ
)
+
∑
i
λ
2
F
i
(
θ
i
−
θ
A
,
i
∗
)
2
\mathcal{L}(\theta)=\mathcal{L}_B(\theta)+\sum_i \frac{\lambda}{2}F_i(\theta_i-\theta^*_{A,i})^2
L(θ)=LB(θ)+i∑2λFi(θi−θA,i∗)2
但是面临以下几个问题:
3.1 为什么用fisher矩阵,不用一阶导:
一阶导数仅提供局部梯度信息: Score函数给出的是似然函数在参数空间中的梯度,它指示了似然函数增大的方向,用于找到最大似然估计。然而,梯度仅描述了参数空间中某一点的线性趋势,无法捕捉非线性效应和参数间的相互依赖关系,这些非线性效应和依赖关系对模型的整体行为至关重要。
二阶矩反映曲率和局部稳定性: Fisher信息矩阵包含了对数似然函数的二阶偏导数,它刻画了似然函数在参数空间中的曲率。曲率反映了模型对参数微小变动的响应程度,即参数敏感性。高的曲率意味着参数稍有变化就会导致似然函数显著变化,这对应于参数对模型输出具有高度影响力。此外,曲率还与参数估计的局部稳定性相关,即在估计过程中参数是否容易受到噪声或数据扰动的影响。
统计推断的必要性: 在统计推断中,我们通常关心的是参数的置信区间和估计的精度,这些都与参数估计的方差(或协方差)紧密相关。一阶导数无法直接提供这些信息,而Fisher信息矩阵的逆恰好给出了参数估计误差协方差的下界(Cramér-Rao Bound),这与实际的推断需求直接对应。
3.2 为什么EWC是乘上参数之间的差:
EWC正则化项中的差值
(
θ
i
−
θ
i
∗
)
(\theta_i -\theta_i^*)
(θi−θi∗)表示当前参数与旧任务最优参数之间的差异,平方后乘以Fisher信息量
F
i
F_i
Fi,用来衡量这种差异对于旧任务的影响。大的Fisher信息量意味着参数对于旧任务非常关键,因此当该参数偏离旧任务最优值时,正则化项会施加较大的惩罚,限制其在学习新任务时的变动。
3.3 Fisher矩阵的代码如何实现(只需要计算对角线即可):
def getFisherDiagonal(self, train_loader): fisher = { n: torch.zeros(p.shape).to(self._device) for n, p in self._network.named_parameters() if p.requires_grad } self._network.train() optimizer = optim.SGD(self._network.parameters(), lr=lrate) for i, (_, inputs, targets) in enumerate(train_loader): inputs, targets = inputs.to(self._device), targets.to(self._device) logits = self._network(inputs)["logits"] loss = torch.nn.functional.cross_entropy(logits, targets) optimizer.zero_grad() loss.backward() for n, p in self._network.named_parameters(): if p.grad is not None: fisher[n] += p.grad.pow(2).clone() for n, p in fisher.items(): fisher[n] = p / len(train_loader) fisher[n] = torch.min(fisher[n], torch.tensor(fishermax)) return fisher
本部分内容主要总结了Fisher矩阵的理论知识,为下一步的相关工作进行知识铺垫。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。