Fundamentals Series: CRNN Core Code

Introduction
  • CRNN is a classic text recognition algorithm. This post uses it to solidify the fundamentals: the core ideas behind CRNN and a PyTorch implementation.
Basic Principles
  • CRNN adopts a CNN + RNN + CTC architecture:
    • CNN: a deep CNN extracts features from the input image, producing a feature map
    • RNN: a bidirectional RNN predicts over the resulting feature sequence, learning from each feature vector and outputting a label distribution per time step
    • CTC: the CTC loss converts the sequence of label distributions produced by the recurrent layer into the final label sequence (see the decoding sketch after this list)
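To make the CTC decoding rule concrete, here is a minimal, framework-free sketch of greedy (best-path) decoding; the toy alphabet and input sequence are invented for illustration:

# Greedy CTC decoding demo: merge consecutive repeats, then drop blanks.
# Assumption: class 0 is the blank label; classes 1..N map to characters.
alphabet = ['-', 'a', 'b']  # '-' stands for blank

def ctc_greedy_decode(ids):
    chars = []
    prev = None
    for i in ids:
        # Keep a class only if it is non-blank and differs from the previous step
        if i != 0 and i != prev:
            chars.append(alphabet[i])
        prev = i
    return ''.join(chars)

print(ctc_greedy_decode([1, 1, 0, 2, 2, 0, 1]))  # -> 'aba'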
Core Code Implementation (copy-and-run ready)
  • The input to torch.nn.CTCLoss() must first be passed through log_softmax (see the minimal shape example after this list)
  • The code below covers four parts: training, loss computation, inference, and decoding
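Before the full implementation, here is a minimal sketch of how nn.CTCLoss expects its inputs (the shapes follow PyTorch's documented convention; the concrete sizes below are arbitrary placeholders, not from the original post):

import torch
from torch import nn
import torch.nn.functional as F

T, N, C = 50, 4, 27  # time steps, batch size, class count (including blank)
logits = torch.randn(T, N, C)
log_probs = F.log_softmax(logits, dim=2)  # CTCLoss expects log-probabilities

targets = torch.randint(1, C, (N, 20), dtype=torch.long)  # class 0 is reserved for blank
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.randint(5, 20, (N,), dtype=torch.long)

loss = nn.CTCLoss(blank=0)(log_probs, targets, input_lengths, target_lengths)

The full, runnable implementation follows: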
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import string

import torch
from torch import nn
import torch.nn.functional as F


class CRNN(nn.Module):
    def __init__(self, img_height, input_channel, n_class, hidden_size):
        super().__init__()

        if img_height % 16 != 0:
            raise ValueError('img_height has to be a multiple of 16')

        kernel_size = [3, 3, 3, 3, 3, 3, 2]
        padding_size = [1, 1, 1, 1, 1, 1, 0]
        stride = [1, 1, 1, 1, 1, 1, 1]
        channel = [64, 128, 256, 256, 512, 512, 512]

        def conv_relu(i, batchNormalization=False):
            in_channels = input_channel if i == 0 else channel[i - 1]
            out_channels = channel[i]
            cnn.add_module(f'conv{i}',
                           nn.Conv2d(in_channels, out_channels,
                                     kernel_size[i],
                                     stride[i],
                                     padding_size[i]))

            if batchNormalization:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(out_channels))
            cnn.add_module(f'relu{i}', nn.ReLU(True))

        # x: 1 x 32 x 320
        cnn = nn.Sequential()
        conv_relu(0)
        cnn.add_module('pooling0', nn.MaxPool2d(2, 2))  # 64x16x160

        conv_relu(1)
        cnn.add_module('pooling1', nn.MaxPool2d(2, 2))  # 128x8x80

        conv_relu(2, True)
        conv_relu(3)
        cnn.add_module('pooling2',
                       nn.MaxPool2d(kernel_size=(2, 2),
                                    stride=(2, 1),
                                    padding=(0, 1)))  # 256x4x81

        conv_relu(4, True)
        conv_relu(5)
        cnn.add_module('pooling3',
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x82
        conv_relu(6, True)  # 512x1x81

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, hidden_size, hidden_size),
            BidirectionalLSTM(hidden_size, hidden_size, n_class)
        )

    def forward(self, x):
        cnn_feature = self.cnn(x)

        # 1 x 512 x 1 x 81
        h = cnn_feature.size()[2]
        if h != 1:
            raise ValueError("the height of cnn_feature must be 1")

        cnn_feature = cnn_feature.squeeze(2)

        # 81: sequence length, 1: batch size, 512: feature dimension
        cnn_feature = cnn_feature.permute(2, 0, 1)

        output = self.rnn(cnn_feature)
        # [81, 1, num_classes]
        output = F.log_softmax(output, dim=2)
        return output


class BidirectionalLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, out_feature):
        super().__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True)
        self.embedding = nn.Linear(hidden_size * 2, out_feature)

    def forward(self, x):
        # x: [81, 1, 512] → [sequence_length, batch_size, input_size]
        recurrent, _ = self.rnn(x)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)

        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)
        return output


def decode(preds, preds_length):
    length = preds_length[0]
    char_list = []
    for i in range(length):
        # Class 0 is the blank label; skip blanks and merge consecutive repeats
        if preds[i] != 0 and (not (i > 0 and preds[i - 1] == preds[i])):
            char_list.append(alphabet[preds[i]])
    return ''.join(char_list)


if __name__ == '__main__':
    alphabet = ['blank'] + list(string.ascii_lowercase)
    num_classes = len(alphabet)  # 27
    
    img = torch.randn((1, 1, 32, 320))

    ctc_loss = nn.CTCLoss()

    crnn = CRNN(32, 1, num_classes, 256)

    # Inference
    preds = crnn(img)

    # Inference: decode the predictions into text
    # For each time step, take the index of the highest-scoring class among num_classes
    _, infer_preds = preds.max(2)  # preds out: [81, 1]
    infer_preds = infer_preds.transpose(1, 0).contiguous().view(-1)  # out: [81]
    preds_len = torch.IntTensor([infer_preds.shape[0]])
    text = decode(infer_preds, preds_len)
    print(text)

    # Training: compute the loss
    min_seq_length = 10
    max_seq_length = 30
    batch_size = img.shape[0]
    time_step = preds.shape[0]
    input_length = torch.IntTensor([time_step] * batch_size)
    target = torch.randint(low=1, high=num_classes,
                           size=(batch_size, max_seq_length),
                           dtype=torch.long)
    target_length = torch.randint(low=min_seq_length,
                                  high=max_seq_length,
                                  size=(batch_size,), dtype=torch.long)

    # preds shape: [81, 1, num_classes]
    # target shape: [1, 30]
    # input_length: [1]
    # target_length: [1]
    loss = ctc_loss(preds, target, input_length, target_length)

    print(preds.shape)
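The listing above stops at a single forward pass and loss value. As a follow-up, a minimal one-step training sketch might look like this (the optimizer choice and learning rate are assumptions for illustration, not from the original post):

# Hypothetical single training step built on the objects defined above
optimizer = torch.optim.Adam(crnn.parameters(), lr=1e-3)

optimizer.zero_grad()
preds = crnn(img)  # [81, 1, num_classes], already log_softmax-ed in forward()
loss = ctc_loss(preds, target, input_length, target_length)
loss.backward()
optimizer.step()
print(loss.item())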