The input to torch.nn.CTCLoss() must first be passed through a log_softmax, because the loss expects log-probabilities rather than raw logits.
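Before the full model, here is a minimal sketch of that input contract, using random data; the shapes and names (T, N, C) are illustrative only, not part of the original example:

import torch
import torch.nn.functional as F
from torch import nn

# Minimal sketch of the nn.CTCLoss input contract; all values are random.
T, N, C = 50, 4, 27                        # time steps, batch size, classes (0 = blank)
logits = torch.randn(T, N, C)
log_probs = F.log_softmax(logits, dim=2)   # CTCLoss expects log-probabilities

targets = torch.randint(1, C, (N, 20), dtype=torch.long)       # label 0 is reserved for blank
input_lengths = torch.full((N,), T, dtype=torch.long)          # frames per sample
target_lengths = torch.randint(10, 20, (N,), dtype=torch.long) # labels per sample

loss = nn.CTCLoss(blank=0)(log_probs, targets, input_lengths, target_lengths)
print(loss.item())

The complete CRNN example below follows the same pattern: the network itself ends in log_softmax, so its output can be fed straight into nn.CTCLoss().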
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import string

import torch
import torch.nn.functional as F
from torch import nn


class CRNN(nn.Module):
    def __init__(self, img_height, input_channel, n_class, hidden_size):
        super().__init__()
        if img_height % 16 != 0:
            raise ValueError('img_height has to be a multiple of 16')

        kernel_size = [3, 3, 3, 3, 3, 3, 2]
        padding_size = [1, 1, 1, 1, 1, 1, 0]
        stride = [1, 1, 1, 1, 1, 1, 1]
        channel = [64, 128, 256, 256, 512, 512, 512]

        def conv_relu(i, batchNormalization=False):
            in_channels = input_channel if i == 0 else channel[i - 1]
            out_channels = channel[i]
            cnn.add_module(f'conv{i}',
                           nn.Conv2d(in_channels, out_channels, kernel_size[i],
                                     stride[i], padding_size[i]))
            if batchNormalization:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(out_channels))
            cnn.add_module(f'relu{i}', nn.ReLU(True))

        # x: 1 x 32 x 320
        cnn = nn.Sequential()
        conv_relu(0)
        cnn.add_module('pooling0', nn.MaxPool2d(2, 2))  # 64 x 16 x 160
        conv_relu(1)
        cnn.add_module('pooling1', nn.MaxPool2d(2, 2))  # 128 x 8 x 80
        conv_relu(2, True)
        conv_relu(3)
        cnn.add_module('pooling2', nn.MaxPool2d(kernel_size=(2, 2),
                                                stride=(2, 1),
                                                padding=(0, 1)))  # 256 x 4 x 81
        conv_relu(4, True)
        conv_relu(5)
        cnn.add_module('pooling3', nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512 x 2 x 82
        conv_relu(6, True)  # 512 x 1 x 81
        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, hidden_size, hidden_size),
            BidirectionalLSTM(hidden_size, hidden_size, n_class)
        )

    def forward(self, x):
        cnn_feature = self.cnn(x)  # 1 x 512 x 1 x 81
        h = cnn_feature.size()[2]
        if h != 1:
            raise ValueError("the height of cnn_feature must be 1")

        cnn_feature = cnn_feature.squeeze(2)
        # 81: sequence length, 1: batch size, 512: feature dimension per step
        cnn_feature = cnn_feature.permute(2, 0, 1)
        output = self.rnn(cnn_feature)  # [81, 1, num_classes]
        # nn.CTCLoss expects log-probabilities, so normalize over the class dimension
        output = F.log_softmax(output, dim=2)
        return output


class BidirectionalLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, out_feature):
        super().__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True)
        self.embedding = nn.Linear(hidden_size * 2, out_feature)

    def forward(self, x):
        # x: [81, 1, 512] -> [sequence_length, batch_size, input_size]
        recurrent, _ = self.rnn(x)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)
        output = self.embedding(t_rec)  # [T * b, out_feature]
        output = output.view(T, b, -1)
        return output


def decode(preds, preds_length):
    length = preds_length[0]
    char_list = []
    for i in range(length):
        # index 0 is the blank label; also skip repeated labels (CTC collapse rule).
        # alphabet already contains 'blank' at index 0, so the class index maps directly.
        if preds[i] != 0 and (not (i > 0 and preds[i - 1] == preds[i])):
            char_list.append(alphabet[preds[i]])
    return ''.join(char_list)


if __name__ == '__main__':
    alphabet = ['blank'] + list(string.ascii_lowercase)
    num_classes = len(alphabet)  # 27

    img = torch.randn((1, 1, 32, 320))
    ctc_loss = nn.CTCLoss()
    crnn = CRNN(32, 1, num_classes, 256)

    # Inference
    preds = crnn(img)

    # Inference: decode the predictions into text.
    # For each time step, take the index of the highest-scoring class.
    _, infer_preds = preds.max(2)  # out: [81, 1]
    infer_preds = infer_preds.transpose(1, 0).contiguous().view(-1)  # out: [81]
    preds_len = torch.IntTensor([infer_preds.shape[0]])
    text = decode(infer_preds, preds_len)
    print(text)

    # Training: compute the loss
    min_seq_length = 10
    max_seq_length = 30
    batch_size = img.shape[0]
    time_step = preds.shape[0]
    input_length = torch.IntTensor([time_step] * batch_size)
    target = torch.randint(low=1, high=num_classes,
                           size=(batch_size, max_seq_length), dtype=torch.long)
    target_length = torch.randint(low=min_seq_length, high=max_seq_length,
                                  size=(batch_size,), dtype=torch.long)

    # preds shape: [81, 1, num_classes]
    # target shape: [1, 30]
    # input_length: [1]
    # target_length: [1]
    loss = ctc_loss(preds, target, input_length, target_length)
    print(preds.shape)
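To make the behaviour of decode() concrete, here is a small self-contained check of the greedy CTC collapse rule (drop repeated labels, then drop blanks); the per-frame argmax sequence below is made up for illustration:

import string

# Hypothetical per-frame argmax indices; 0 is the blank label
frames = [0, 1, 1, 0, 0, 3, 3, 3, 0, 1]
alphabet = ['blank'] + list(string.ascii_lowercase)

chars = []
prev = 0
for p in frames:
    if p != 0 and p != prev:       # skip blanks and skip repeats of the previous label
        chars.append(alphabet[p])
    prev = p
print(''.join(chars))  # -> 'aca'

Because of this rule, a genuinely doubled character (for example the 'oo' in 'book') can only be produced when a blank frame sits between the two identical label frames.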