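# [web-page widget text captured with the snippet: "赞" (like)]
# [web-page widget text captured with the snippet: "踩" (dislike)]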
import os

import cv2
import numpy as np
import torch
import torchvision.transforms as T
from torch import nn
from torch.nn import LSTM, Linear
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from tqdm import tqdm

# Image size (height, width) every sample is resized to.
IMAGE_SHAPE = (28, 135)
transform = T.Compose([
    T.ToPILImage(),
    T.Resize(IMAGE_SHAPE),
    T.ToTensor(),
])
# Character table. '_' (index 0) is the placeholder: it pads short labels and
# is also the CTC blank (nn.CTCLoss defaults to blank=0). It is required for
# variable-length labels.
LABEL_MAP = list('_0123456789-+=')
Max_label_len = 6


class MyDataset(Dataset):
    """Images plus text labels read from a flat directory.

    The label of each file is the part of its file name before the first
    underscore, e.g. ``12+3=_0001.png`` -> ``12+3=``.

    Args:
        data_path: Directory containing the image files.
        label_map: Iterable of characters; a character's position is its
            class index. Index 0 is used as padding.
        max_label_len: Length every encoded label is padded to.
    """

    def __init__(self, data_path, label_map, max_label_len):
        super().__init__()
        self.data = [
            (os.path.join(data_path, file), file.split('_')[0])
            for file in os.listdir(data_path)
        ]
        self.label_map = list(label_map)
        self.label_map_len = len(self.label_map)
        self.max_label_len = max_label_len

    def __getitem__(self, index):
        """Return ``(image_tensor, padded_label, unpadded_label_length)``.

        Raises:
            ValueError: If the image cannot be decoded, the label is longer
                than ``max_label_len``, or it contains a character that is
                not in the label map.
        """
        file, text = self.data[index]
        raw_len = len(text)
        if raw_len > self.max_label_len:
            raise ValueError(
                "label %r of %s is longer than max_label_len=%d"
                % (text, file, self.max_label_len)
            )

        # Read the raw bytes and decode them in memory.
        im = cv2.imdecode(np.fromfile(file, dtype=np.uint8), cv2.IMREAD_COLOR)
        if im is None:
            raise ValueError("cannot decode image file: %s" % file)
        # NOTE(review): OpenCV decodes to BGR while ToPILImage assumes RGB.
        # Train and test share this pipeline, so it is self-consistent, but it
        # matters if pretrained weights are ever loaded.
        im = transform(im)

        label = [self.label_map.index(ch) for ch in text]
        label.extend([0] * (self.max_label_len - raw_len))  # pad with index 0
        label = np.asarray(label, dtype='int32').reshape(self.max_label_len)
        return im, label, raw_len

    def __len__(self):
        return len(self.data)


class Net(nn.Module):
    """CRNN: truncated ResNet-18 backbone followed by two bidirectional LSTMs.

    ``forward`` returns raw (un-normalised) scores of shape
    ``(time_steps, batch, len(LABEL_MAP))``; apply ``log_softmax`` over the
    last axis before feeding them to ``nn.CTCLoss``.
    """

    def __init__(self):
        super().__init__()
        # Drop the last three children of ResNet-18 (layer4, the adaptive
        # average pool and the fully connected head) so the backbone still
        # outputs a spatial feature map.
        self.resnet18 = nn.Sequential(*list(models.resnet18().children())[0:-3])
        bone_output_shape = self._cal_shape()
        self.lstm = LSTM(bone_output_shape, bone_output_shape,
                         num_layers=1, bidirectional=True)
        self.linear = Linear(bone_output_shape * 2, 256)
        self.lstm1 = LSTM(256, bone_output_shape,
                          num_layers=1, bidirectional=True)
        self.linear1 = Linear(bone_output_shape * 2, len(LABEL_MAP))

    def _cal_shape(self):
        """Return channels * height of the backbone output for IMAGE_SHAPE."""
        # Probe in eval mode and without autograd so that the dummy all-zero
        # input does not disturb the BatchNorm running statistics.
        was_training = self.resnet18.training
        self.resnet18.eval()
        with torch.no_grad():
            # Output layout: (batch, channels, height, width).
            shape = self.resnet18(torch.zeros((1, 3) + IMAGE_SHAPE)).shape
        self.resnet18.train(was_training)
        return shape[1] * shape[2]

    def forward(self, x):
        x = self.resnet18(x)                  # (batch, channels, height, width)
        x = x.permute(3, 0, 1, 2)             # width becomes the time axis
        w, b, c, h = x.shape
        x = x.reshape(w, b, c * h)            # (time, batch, features)

        x, _ = self.lstm(x)
        time_step, batch_size, h = x.shape
        x = self.linear(x.reshape(time_step * batch_size, h))
        x = x.reshape(time_step, batch_size, -1)

        x, _ = self.lstm1(x)
        time_step, batch_size, h = x.shape
        x = self.linear1(x.reshape(time_step * batch_size, h))
        return x.reshape(time_step, batch_size, -1)


def tranfromlabel(label):
    """Map an iterable of class indices to the corresponding string."""
    return ''.join(LABEL_MAP[int(i)] for i in label)


def ctc_to_str(data):
    """Greedy CTC decoding.

    Collapses consecutive repeats and removes blanks (index 0); a blank
    between two equal indices keeps both of them.

    :param data: iterable of per-time-step class indices
    :return: decoded text
    """
    result = []
    last = -1
    for i in data:
        i = int(i)
        if i == 0:
            last = -1
        elif i != last:
            result.append(i)
            last = i
    return tranfromlabel(result)


train = DataLoader(
    dataset=MyDataset(r'./train', label_map=LABEL_MAP, max_label_len=Max_label_len),
    batch_size=32, shuffle=True, num_workers=3)
test = DataLoader(
    dataset=MyDataset(r'./test', label_map=LABEL_MAP, max_label_len=Max_label_len),
    batch_size=4, shuffle=True, num_workers=0)


def _ctc_loss(loss_func, logits, labels, target_lengths):
    """Compute the CTC loss from raw ``(time, batch, classes)`` scores."""
    time_steps, batch_size = logits.shape[0], logits.shape[1]
    # nn.CTCLoss expects log-probabilities, not raw scores.
    log_probs = logits.log_softmax(2)
    # One entry per sample (shape (batch,)): every sequence uses all steps.
    input_lengths = torch.full((batch_size,), time_steps, dtype=torch.long)
    targets = labels.to(device=logits.device, dtype=torch.long)
    return loss_func(log_probs, targets, input_lengths, target_lengths)


def main():
    """Train for 100 epochs, validating and saving a checkpoint each epoch."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Net()
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_func = nn.CTCLoss()
    # 'max' because the scheduler is stepped with validation accuracy.
    scheduler = ReduceLROnPlateau(optimizer, 'max', patience=3)
    os.makedirs("models", exist_ok=True)

    for epoch in range(0, 100):
        model.train()
        bar = tqdm(train, 'Training')
        for images, labels, target_lengths in bar:
            images = images.to(device)
            predict = model(images)
            loss = _ctc_loss(loss_func, predict, labels, target_lengths)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr = optimizer.param_groups[0]['lr']
            bar.set_description("Train epoch %d, loss %.4f, lr %.6f" % (
                epoch, loss.item(), lr
            ))

        # Evaluate with BatchNorm in inference mode and without autograd.
        model.eval()
        bar = tqdm(test, 'Validating')
        correct = count = 0
        with torch.no_grad():
            for images, labels, target_lengths in bar:
                images = images.to(device)
                predicts = model(images)
                # Best class per step, rearranged to (batch, time) lists.
                best_paths = predicts.argmax(2).t().tolist()
                for path, label, length in zip(
                        best_paths, labels.tolist(), target_lengths.tolist()):
                    count += 1
                    label_text = tranfromlabel(label)[:length]
                    predict_text = ctc_to_str(path)
                    if label_text == predict_text:
                        correct += 1
                loss = _ctc_loss(loss_func, predicts, labels, target_lengths)
                lr = optimizer.param_groups[0]['lr']
                bar.set_description("Valid epoch %d, acc %.4f, loss %.4f, lr %.6f" % (
                    epoch, correct / count, loss.item(), lr
                ))

        # An empty test set would otherwise divide by zero.
        accuracy = correct / count if count else 0.0
        scheduler.step(accuracy)
        torch.save(model.state_dict(), "models/save_%d.model" % epoch)


if __name__ == '__main__':
    main()
# Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。