赞
踩
本文解读的是一篇来自2015年的一篇文字识别论文 [ 1 ] ^{[1]} [1]。里面的CTC Loss相关内容的理解有一定的挑战性,本文是对自己当前理解的一份记录。
首先,先看一下CRNN的前向推理过程,来了解其文字识别的整体流程,如下图所示。
action1 : 一张
10
∗
40
∗
3
10*40*3
10∗40∗3的文字图片块,经过CNN层特征提取,下采样为
1
∗
10
∗
512
1*10*512
1∗10∗512的特征图。高度压缩为1,宽度下采样4倍,每一个特征是维度为512。
action2 : 通过深度双向LSTM网络将
10
∗
512
10*512
10∗512的Feature sequence做了一个特征的进一步转换和提取变为一个
10
∗
(
26
+
1
)
10*(26+1)
10∗(26+1)的预测分布概率矩阵。这里使用双向LSTM是期待特征序列做更加充分的贯通,例如在预测“state”
中“a”的时候既采纳了“st”的信息又采纳了"te"的信息。
action3 : 通过转录层操作,根据分布概率矩阵可以获得最终的预测结果。例如
a
r
g
m
a
x
(
y
,
d
i
m
=
1
)
argmax(y, dim=1)
argmax(y,dim=1),可以得到预测值的初始形态:
- | s | - | t | - | a | a | t | t | e |
---|
然后合并成为最终的预测结果: state。合并的基本规则是:
前向推理过程比较明晰,然而,训练过程会遇到如下疑惑,如果按照上述例子,我们会把这一个序列作为预测概率矩阵 y y y的GT。然后就相当于并行做10个(26+1)类的分类任务学习。
- | s | - | t | - | a | a | t | t | e |
---|
这样的问题在于:
抛出问题1 : 对于同一张图片可以有不同的GT方案。
例如,下列序列作为“state”对应的的分布概率矩阵GT,也是不违背任何逻辑的。事实上,这种不违背逻辑的方案还有很多。
- | - | s | t | - | a | a | t | t | e |
---|
尝试解决问题1 :尝试列举出所有可能的方案, 在训练的过程中随机给出一个gt。
这样做理论上是可行的。但是会有一个时间复杂度问题。采用暴力求解的方法罗列出所有可能是
(
26
+
1
)
10
(26+1)^{10}
(26+1)10。即使模型的最大预测字符串长度为10,仅为26个字母这种简易场景,这种级别的时间复杂度是不可以接受的。
但不管怎样,至此,上述整个过程是一种理论上完备的训练、推理流程,只不过训练速度会很慢(或者说慢到不可接受)。
CTC Loss 或者说CTC 算法是来源于HMM(隐马尔可夫),用一句话总结:就是通过“动态规划”算法来替代“暴力求解”来解决所有方案的概率和。并将问题的loss定义为一个最大似然问题:使得学到尽可能的网络参数使得 p ( l x ) p(\frac{l}{x}) p(xl)最大,论文中将loss定义为 − l o g ( p ( l x ) ) -log(p(\frac{l}{x})) −log(p(xl))。CTC的过程可以总结为以下四步骤:
以下例子来自于torch官网。为了便于描述,将参数的规模进行了缩小。
>>> import torch >>> import torch.nn as nn >>> # Target are to be padded >>> T = 5 # torch 官网为50 # Input sequence length >>> C = 7 # torch 官网为20 # Number of classes (including blank, 0 class) >>> N = 1 # torch 官网为16 # Batch size >>> S = 3 # 30 # Target sequence length of longest target in batch (padding length) >>> S_min = 2 # 10 # Minimum target length, for demonstration purposes >>> >>> # Initialize random batch of input vectors, for *size = (T,N,C) >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_() >>> >>> # Initialize random batch of targets (0 = blank, 1:C = classes) >>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long) >>> >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) >>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long) >>> ctc_loss = nn.CTCLoss() >>> loss = ctc_loss(input, target, input_lengths, target_lengths) >>> loss.backward()
其中,input表示的是预测概率的
l
o
g
log
log矩阵:
预测概率矩阵
y
=
e
i
n
p
u
t
y=e^{input}
y=einput,如下所示:
举一个简单的例子target为
f
e
fe
fe: 根据1.2.1节中步骤3可以根据
y
y
y矩阵动态递归算得
α
s
(
t
)
α_s(t)
αs(t)矩阵:
根据1.2.2节,步骤4可以根据
y
y
y矩阵动态递归算得
β
s
(
t
)
β_s(t)
βs(t)矩阵:
根据1.2.1节中步骤1可以根据α矩阵和β矩阵计算得到两者的联合概率:
l
o
s
s
=
−
l
o
g
(
0.001247115
/
3
)
=
3.38
loss = -log(0.001247115/3)=3.38
loss=−log(0.001247115/3)=3.38, 与pytorch的输出一致。
本文主要以CTC是如何做的角度来写,并通过pytorch和自己手算结果的对比来验证自己理解的正确性。后续如果有新的理解,应该会补充上一些更多的细节。
[1] 原始论文:An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition
[2]Pytorch ctc demo example
[3]公式
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。