当前位置:   article > 正文

文本识别网络CRNN

crnn

简介

CRNN,全称Convolutional Recurrent Neural Network,卷积循环神经网络
它是一种基于图像的序列识别网络,可以对不定长的文字序列进行端到端的识别。
它集成了卷积神经网络(CNN)和循环神经网络(RNN)的优点。

网络结构

CRNN网络结构由CNN + BLSTM + 转录层(CTC Loss)组成。
在这里插入图片描述

CNN层

CNN层可以采用VGG网络,如下所示特征提取过程中用了7次Convolution,4次MaxPooling层,最后两次的池化核大小由2 x 2变为1 x 2,所以图像经过最后一层MaxPooling层后特征图维度变为:
高度方向减半4次(24),变为原来的1/16
宽度方向减半2次(22),变为原来的1/4
但是,最后一层Convolution,使高宽又减半一次,所以:
高度方向变为原图的1/32
宽度方向变为原图的1/8
在这里插入图片描述

LSTM

RNN层采用双向LSTM层。因为:
LSTM相对于RNN可以防止训练时梯度的消失;
BLSTM相对于LSTM,序列的前向信息和后向信息都可以得到照顾。
在这里插入图片描述

CTC Loss

转录层是根据每帧的预测找到具有最高概率组合的标签序列。

CTC(Connectionist Temporal Classification)损失函数代替Softmax Loss,来对 CNN 和 RNN 进行端到端的联合训练,CTC引入了blank字符,可以解决不定长的文字对齐问题。

对于LSTM,输入为x,则输出为l的概率为:
在这里插入图片描述
CTC的训练过程,其实就是通过求P (l|x) 对于LSTM的参数w的梯度,来更新LSTM的过程。

CTC通过用一种前向-后向算法(The CTC Forward-Backward Algorithm),其和HMM中的forward-backward算法类似,来计算P (l|x) 。

LSTM输入x,经过softmax,输出y的后验概率矩阵,y输入CTC,P (l|x)与Forward和Backward递推公式之间的关系为:
在这里插入图片描述
求导计算梯度,用来更新LSTM的参数:
在这里插入图片描述

代码实现

CNN层

inputs = Input(shape=(picture_height, picture_width, 1), name='pic_inputs') # H×W×1 32*128*1
x = Conv2D(64, (3,3), strides=(1,1), padding="same", kernel_initializer=initializer, use_bias=True, name='conv2d_1')(inputs) # 32*128*64 
x = BatchNormalization(name="BN_1")(x)
x = Activation("relu", name="relu_1")(x)
x = MaxPooling2D(pool_size=(2,2), strides=2, padding='valid', name='maxpl_1')(x) # 16*64*64

x = Conv2D(128, (3,3), strides=(1,1), padding="same", kernel_initializer=initializer, use_bias=True, name='conv2d_2')(x) # 16*64*128
x = BatchNormalization(name="BN_2")(x)
x = Activation("relu", name="relu_2")(x)
x = MaxPooling2D(pool_size=(2,2), strides=2, padding='valid', name='maxpl_2')(x) # 8*32*128

x = Conv2D(256, (3,3), strides=(1,1), padding="same", kernel_initializer=initializer, use_bias=True, name='conv2d_3')(x)  # 8*32*256
x = BatchNormalization(name="BN_3")(x)
x = Activation("relu", name="relu_3")(x)
x = Conv2D(256, (3,3), strides=(1,1), padding="same", kernel_initializer=initializer, use_bias=True, name='conv2d_4')(x) # 8*32*256
x = BatchNormalization(name="BN_4")(x)
x = Activation("relu", name="relu_4")(x)
x = MaxPooling2D(pool_size=(2,1), strides=(2,1), name='maxpl_3')(x) # 4*32*256

x = Conv2D(512, (3,3), strides=(1,1), padding="same", kernel_initializer=initializer, use_bias=True, name='conv2d_5')(x) # 4*32*512
x = BatchNormalization(axis=-1, name='BN_5')(x)
x = Activation("relu", name='relu_5')(x)
x = Conv2D(512, (3,3), strides=(1,1), padding="same", kernel_initializer=initializer, use_bias=True, name='conv2d_6')(x) # 4*32*512
x = BatchNormalization(axis=-1, name='BN_6')(x)
x = Activation("relu", name='relu_6')(x)
x = MaxPooling2D(pool_size=(2,1), strides=(2,1), name='maxpl_4')(x) # 2*32*512

x = Conv2D(512, (2,2), strides=(1,1), padding='same', activation='relu', kernel_initializer=initializer, use_bias=True, name='conv2d_7')(x) # 2*32*512
x = BatchNormalization(name="BN_7")(x)
x = Activation("relu", name="relu_7")(x)
conv_otput = MaxPooling2D(pool_size=(2, 1), name="conv_output")(x) # 1*32*512
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31

Map2Sequence层

x = Permute((2, 3, 1), name='permute')(conv_otput) # 32*512*1
rnn_input = TimeDistributed(Flatten(), name='for_flatten_by_time')(x) # 32*512
  • 1
  • 2

RNN层

y = Bidirectional(LSTM(256, kernel_initializer=initializer, return_sequences=True), merge_mode='sum', name='LSTM_1')(rnn_input) # 32*512
y = BatchNormalization(name='BN_8')(y)
  • 1
  • 2

CTC Loss

y_pred = Dense(num_classes, activation='softmax', name='y_pred')(y)
y_true = Input(shape=[max_label_length], name='y_true')
y_pred_length = Input(shape=[1], name='y_pred_length')
y_true_length = Input(shape=[1], name='y_true_length')

ctc_loss_output = Lambda(K.ctc_batch_cost(y_true, y_pred, pred_length, label_length), output_shape=(1,), name='ctc_loss_output')([y_true, y_pred, y_pred_length, y_true_length])
model = Model(inputs=[y_true, inputs, y_pred_length, y_true_length], outputs=ctc_loss_output)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/219311
推荐阅读
相关标签
  

闽ICP备14008679号