当前位置:   article > 正文

不定长图文模型训练_resnetlstm(nn.module)

resnetlstm(nn.module)

生成数据集

import os
import random
from PIL import Image, ImageDraw, ImageFont, ImageFilter
from io import BytesIO
import time


def main():
    """Generate one random arithmetic captcha.

    The captcha text is ``<a><op><b>`` where ``a`` and ``b`` are random integers
    in [1, 1000] and ``op`` is one of '加', '减', '乘', '+', '-', '*'.

    :return: tuple ``(image_bytes, text)``: the encoded image (PNG by default)
        and the captcha text without separators, e.g. ``'12+345'``.
    """
    _first_num = random.randint(1, 1000)
    _code_style = ['加', '减', '乘', '+', '-', '*']
    _last_num = random.randint(1, 1000)
    init_chars = [str(_first_num), random.choices(_code_style)[0], str(_last_num)]

    def text_size(font, text):
        """Return the ``(width, height)`` of ``text`` drawn with ``font``.

        ``FreeTypeFont.getsize`` was removed in Pillow 10. The right/bottom
        edges of ``getbbox`` (available since Pillow 8) measure the same
        extent for this horizontal text, so prefer it and fall back to
        ``getsize`` only on older Pillow versions.
        """
        if hasattr(font, 'getbbox'):
            _, _, right, bottom = font.getbbox(text)
            return right, bottom
        return font.getsize(text)

    def create_validate_code(size=(150, 50),
                             chars=init_chars,
                             img_type="PNG",
                             mode="RGB",
                             bg_color=(255, 255, 255),
                             fg_color=(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)),
                             font_size=18,
                             font_type="./msyh.ttc",
                             draw_lines=True,
                             n_line=(1, 2),
                             draw_points=True,
                             point_chance=1):
        """Render ``chars`` into a noisy, slightly distorted captcha image.

        :param size: image size as ``(width, height)``.
        :param chars: list of strings to draw; they are joined with single
            spaces (plus one leading and one trailing space) before drawing.
        :param img_type: format passed to ``Image.save`` (default "PNG").
        :param mode: PIL image mode (default "RGB").
        :param bg_color: background color (default white).
        :param fg_color: text color; a random RGB triple chosen once per
            ``main()`` call.
        :param font_size: size passed to ``ImageFont.truetype``.
        :param font_type: path of the font file.
        :param draw_lines: whether to draw black interference lines.
        :param n_line: inclusive ``(min, max)`` range for the number of lines.
        :param draw_points: whether to draw black interference points.
        :param point_chance: clamped to [0, 100]; each pixel is blackened when
            ``randint(0, 100) > 100 - point_chance``, i.e. with probability
            ``point_chance / 101``.
        :return: ``(image_bytes, text)`` where ``text`` is ``''.join(chars)``.
        """
        width, height = size
        img = Image.new(mode, size, bg_color)
        draw = ImageDraw.Draw(img)

        def create_lines():
            """Draw a random number of black interference lines."""
            line_num = random.randint(*n_line)
            for _ in range(line_num):
                begin = (random.randint(0, size[0]), random.randint(0, size[1]))
                end = (random.randint(0, size[0]), random.randint(0, size[1]))
                draw.line([begin, end], fill=(0, 0, 0))

        def create_points():
            """Blacken individual pixels at random."""
            chance = min(100, max(0, int(point_chance)))
            for w in range(width):
                for h in range(height):
                    tmp = random.randint(0, 100)
                    if tmp > 100 - chance:
                        draw.point((w, h), fill=(0, 0, 0))

        def create_strs():
            """Draw the captcha characters and return them without spaces."""
            strs = ' %s ' % ' '.join(chars)  # space-separate the characters
            font = ImageFont.truetype(font_type, font_size)
            font_width, font_height = text_size(font, strs)
            font_width /= 0.7
            font_height /= 0.7
            draw.text(((width - font_width) / 3, (height - font_height) / 3),
                      strs, font=font, fill=fg_color)
            return ''.join(chars)

        if draw_lines:
            create_lines()
        if draw_points:
            create_points()
        strs = create_strs()

        # Perspective-distortion coefficients (slight random warp).
        params = [1 - float(random.randint(1, 2)) / 80,
                  0,
                  0,
                  0,
                  1 - float(random.randint(1, 10)) / 80,
                  float(random.randint(3, 5)) / 450,
                  0.001,
                  float(random.randint(3, 5)) / 450
                  ]
        img = img.transform(size, Image.PERSPECTIVE, params)
        output_buffer = BytesIO()
        img.save(output_buffer, format=img_type)
        return output_buffer.getvalue(), strs

    return create_validate_code()


# Interactive generator: ask how many captchas to create and write each one to
# ``output_dir`` as "<label>_<unix time>_<index>.png". The dataset code takes
# everything before the first '_' of the file name as the label.
output_dir = './picture'  # use './test' when generating the test set
os.makedirs(output_dir, exist_ok=True)
while 1:
    number = input('请输入要生成的验证码数量:')
    try:
        count = int(number)
    except ValueError:
        print('请输入一个数字,不要输入乱七八糟的东西,打你哦')
        continue
    try:
        for i in range(count):
            res = main()
            # '*' is not allowed in Windows file names, so it is saved as '乘'.
            # The index suffix keeps names unique when the same label is
            # generated twice within the same second.
            file_name = '{0}_{1}_{2}.png'.format(res[1].replace('*', '乘'), int(time.time()), i)
            with open(os.path.join(output_dir, file_name), 'wb') as f:
                f.write(res[0])
            print('生成第', i+1, '个图片成功')
    except Exception:
        import traceback
        traceback.print_exc()
        break
    input('理论上生成完成了~,QAQ 共生成了' + number + '个验证码')
input('出现未知错误,错误已打印')

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137

先创建 picture 和 test 两个目录:picture 作为训练数据集目录,test 作为测试数据集目录(生成测试集时,需要把生成脚本中的输出目录改为 ./test)。

先用以上程序生成3000张训练数据图片集:

在这里插入图片描述

再生成200张测试数据集:

在这里插入图片描述

模型选择

本次我们采用一种基于长短期记忆网络(LSTM)和残差网络(ResNet)相融合的网络模型,模型代码实现:

import torch
import torch.nn as nn
from torch.nn import functional as F


class RestNetBasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages added onto the input.

    The shortcut is the identity, so the input and output shapes must agree:
    the block is only usable with ``in_channels == out_channels`` and
    ``stride == 1``, which is how every caller in this file constructs it.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        conv_options = dict(kernel_size=3, stride=stride, padding=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_options)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))
        return F.relu(x + residual)


class RestNetDownBlock(nn.Module):
    """Residual block whose shortcut is a strided 1x1 conv + batch norm.

    ``stride`` is a two-element sequence: ``stride[0]`` is used by the first
    3x3 conv and by the 1x1 shortcut conv, ``stride[1]`` by the second 3x3
    conv. The shortcut and main path only line up when ``stride[1] == 1``.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        first_stride = stride[0]
        second_stride = stride[1]
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=first_stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=second_stride, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        shortcut_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=first_stride, padding=0)
        self.extra = nn.Sequential(shortcut_conv, nn.BatchNorm2d(out_channels))

    def forward(self, x):
        shortcut = self.extra(x)
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(shortcut + hidden)

class ResNetLstm_shape(nn.Module):
    """Backbone-only twin of ``ResNetLstm`` used to measure its feature-map shape.

    ``forward`` applies the same stages ``ResNetLstm.forward`` runs before the
    LSTM (conv1 -> layer1 -> layer2 -> layer3) and returns only the output's
    ``torch.Size``, e.g. ``[batch, 256, 7, 19]`` for a 50x150 input.
    ``bn1`` and ``maxpool`` are constructed but not used in ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = nn.Sequential(*[RestNetBasicBlock(64, 64, 1) for _ in range(2)])
        self.layer2 = nn.Sequential(RestNetDownBlock(64, 128, [2, 1]),
                                    RestNetBasicBlock(128, 128, 1))
        self.layer3 = nn.Sequential(RestNetDownBlock(128, 256, [2, 1]),
                                    RestNetBasicBlock(256, 256, 1))

    def forward(self, x):
        features = x
        for stage in (self.conv1, self.layer1, self.layer2, self.layer3):
            features = stage(features)
        return features.shape


class ResNetLstm(nn.Module):
    """ResNet feature extractor followed by a bidirectional LSTM over columns.

    Args:
        image_shape: ``(height, width)`` of the input images. It is used once,
            at construction, to size the LSTM input (channels * feature-map
            height).
        num_classes: size of the per-time-step output, i.e. the number of
            label classes including the blank. Defaults to 17, the length of
            the ``'_0123456789加减乘+-*'`` mapping used by the dataset and test
            code.

    ``forward`` returns unnormalized per-class scores of shape
    ``(W, batch, num_classes)``, where ``W`` is the width of the final feature
    map (each feature-map column is one time step). Apply ``log_softmax`` over
    the last dim before feeding them to ``nn.CTCLoss``.

    NOTE: ``bn1`` and ``maxpool`` are created (and therefore stored in the
    state_dict) but ``forward`` does not use them.
    """

    def __init__(self, image_shape, num_classes=17):
        super(ResNetLstm, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = nn.Sequential(RestNetBasicBlock(64, 64, 1),
                                    RestNetBasicBlock(64, 64, 1))

        self.layer2 = nn.Sequential(RestNetDownBlock(64, 128, [2, 1]),
                                    RestNetBasicBlock(128, 128, 1))

        self.layer3 = nn.Sequential(RestNetDownBlock(128, 256, [2, 1]),
                                    RestNetBasicBlock(256, 256, 1))
        # Measure the backbone output shape with a throw-away twin network.
        # no_grad: this one-off pass needs no autograd graph.
        x = torch.zeros((1, 3) + tuple(image_shape))
        with torch.no_grad():
            size = ResNetLstm_shape()(x)
        input_size = size[1] * size[2]  # channels * feature-map height
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=input_size, num_layers=1, bidirectional=True)
        self.fc = nn.Linear(input_size * 2, num_classes)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)           # (batch, 256, h, w), e.g. (2, 256, 7, 19)
        out = out.permute(3, 0, 1, 2)    # (w, batch, 256, h): columns become time steps
        steps, batch, channels, height = out.shape
        out = out.reshape(steps, batch, channels * height)  # (w, batch, 256*h)
        out, _ = self.lstm(out)          # (w, batch, 2 * hidden)
        steps, batch, features = out.shape
        out = self.fc(out.reshape(steps * batch, features))
        return out.view(steps, batch, -1)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104

计算均值和标准差

from torch.utils.data import Dataset
import os
from PIL import Image
import torch
from tqdm import tqdm
import numpy as np
from torchvision import transforms


class Letter2Dataset(Dataset):
    """Captcha images in a flat directory, labelled by their file names.

    Every non-directory entry directly inside ``root`` is one sample. Its
    label is the part of the file name before the first '_' (e.g.
    ``'12+3_1700000000.png'`` -> ``'12+3'``), encoded as indices into
    ``self.mapping``. Index 0 ('_') is the CTC blank and is only used as
    padding after the real label.

    Args:
        root: directory containing the images.
        transform: optional callable applied to each (RGB) PIL image.
        max_label_len: length every label tensor is padded to. The default 9
            fits the longest label the generator can produce ('1000加1000').
    """

    def __init__(self, root: str, transform=None, max_label_len: int = 9):
        super(Letter2Dataset, self).__init__()
        self.path = root
        self.transform = transform
        self.max_label_len = max_label_len
        self.mapping = [i for i in '_0123456789加减乘+-*']
        self._index = {ch: idx for idx, ch in enumerate(self.mapping)}
        # List the directory once instead of on every __len__/__getitem__
        # call; sorting makes the index -> file mapping deterministic.
        with os.scandir(root) as entries:
            self._pictures = sorted(e.name for e in entries if not e.is_dir())

    def load_picture_path(self):
        """Return the sample file names (a fresh list)."""
        return list(self._pictures)

    def __len__(self):
        return len(self._pictures)

    def _open_image(self, name):
        """Load ``name`` as RGB (closing the file) and apply the transform."""
        with Image.open(os.path.join(self.path, name)) as img:
            image = img.convert('RGB')  # the model expects 3 channels
        if self.transform:
            image = self.transform(image)
        return image

    def _encode_label(self, name):
        """Return ``(padded int64 label tensor, true label length)`` for ``name``.

        Raises:
            ValueError: if the label contains a character outside
                ``self.mapping`` or is longer than ``max_label_len``.
        """
        text = name.split('_')[0]
        try:
            labels = [self._index[ch] for ch in text]
        except KeyError as err:
            raise ValueError('unsupported character {} in file name {!r}'.format(err, name)) from err
        length = len(labels)
        if length > self.max_label_len:
            raise ValueError('label {!r} is longer than max_label_len={}'.format(text, self.max_label_len))
        # Pad at the END: nn.CTCLoss ignores target entries past the target
        # length, but the blank index must never appear inside a real target.
        labels.extend([0] * (self.max_label_len - length))
        return torch.as_tensor(labels, dtype=torch.int64), length

    def __getitem__(self, item):
        name = self._pictures[item]
        image = self._open_image(name)
        labels, length = self._encode_label(name)
        return image, labels, length

    def slice(self, start, end, step=1):
        """Load samples ``range(start, end, step)`` and stack them into one tensor.

        Requires ``transform`` to return tensors (e.g. ``ToTensor``); the
        result is float32 with shape ``(count, *image_shape)``.
        """
        images = [self._open_image(self._pictures[i]) for i in range(start, end, step)]
        if not images:
            return torch.Tensor([])
        return torch.stack(images).float()


# Per-channel mean and std of the training images (streaming version, slower
# than get_ms2; prefer get_ms2).
def get_ms():
    """Compute per-channel pixel mean and (population) std over ./picture.

    Each image is loaded exactly once and split into its channels; the
    statistics are then taken over all pixels of each channel.

    Returns:
        ``(means, stds)``: two 3-element lists, one value per RGB channel.
    """
    transform = transforms.Compose([transforms.ToTensor(), ])
    my_train = Letter2Dataset(root="./picture", transform=transform)
    per_channel = [[], [], []]
    for i in tqdm(range(len(my_train))):
        image = np.asarray(my_train[i][0])  # (C, H, W) after ToTensor
        for channel, values in enumerate(per_channel):
            values.append(image[channel])

    res_total = [np.mean(values) for values in per_channel]
    res_std = [np.std(values) for values in per_channel]
    return res_total, res_std


# Per-channel mean and std of the training images, computed in one batch.
def get_ms2():
    """Compute per-channel pixel mean and std over ./picture in memory.

    All images are stacked into a single ``(N, C, H, W)`` tensor. The mean is
    the mean of per-image means (equal to the pixel mean when all images have
    the same size); the std uses torch's default (unbiased) estimator over
    all pixels of each channel.

    Returns:
        ``(means, stds)``: two lists with one value per channel.
    """
    to_tensor = transforms.Compose([transforms.ToTensor(), ])
    dataset = Letter2Dataset(root="./picture", transform=to_tensor)
    stacked = dataset.slice(0, len(dataset))
    per_image_mean = stacked.mean(dim=(2, 3), keepdim=True)
    channel_mean = per_image_mean.mean(dim=0, keepdim=True)
    channel_std = stacked.std(dim=(0, 2, 3), keepdim=True)
    return channel_mean.reshape(-1).tolist(), channel_std.reshape(-1).tolist()


if __name__ == '__main__':
    # Print the statistics from both implementations, separated by '===='.
    for position, compute_stats in enumerate((get_ms, get_ms2)):
        if position:
            print('==' * 2)
        res_total, res_std = compute_stats()
        print(res_total, res_std)


  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88

在这里插入图片描述

训练代码

from torch import save, load
from test_p2 import test
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
from torch import optim
from tqdm import tqdm
from model import ResNetLstm
from MyDataset import get_ms,Letter2Dataset
import os
import numpy as np
import torch
# get_ms2 is imported here because it is not among the MyDataset imports above.
from MyDataset import get_ms2

res_total, res_std = get_ms2()

print(res_total, res_std)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# Build the model for 50x150 (height x width) inputs.
size = (50, 150)
model = ResNetLstm(size)
model = model.to(device)
optimizer = optim.Adam(model.parameters())
batch_size = 16
# Resume from a previous checkpoint if one exists. map_location lets a
# checkpoint saved on GPU be loaded on a CPU-only machine (and vice versa).
os.makedirs('./models', exist_ok=True)
if os.path.exists('./models/model.pkl'):
    model.load_state_dict(load("./models/model.pkl", map_location=device))
    if os.path.exists('./models/optimizer.pkl'):
        optimizer.load_state_dict(load("./models/optimizer.pkl", map_location=device))

loss_function = nn.CTCLoss()
my_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=res_total, std=res_std)
    ]
)
mnist_train = Letter2Dataset(root="./picture", transform=my_transforms)
def train(epoch):
    """Run one training epoch over ``mnist_train`` and checkpoint the result.

    Uses the module-level ``model``, ``optimizer``, ``loss_function``,
    ``device`` and ``batch_size``; saves model and optimizer state to
    ./models after the epoch.

    Args:
        epoch: epoch number, only used in the printed summary.
    """
    total_loss = []
    dataloader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
    dataloader = tqdm(dataloader, total=len(dataloader))
    model.train()
    for images, labels, labels_lengths in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        output = model(images)  # (T, N, C) unnormalized scores
        # nn.CTCLoss expects log-probabilities over the class dimension,
        # not raw scores.
        log_probs = output.log_softmax(2)
        # Every sample uses all T time steps of the model output.
        input_lengths = torch.IntTensor([log_probs.shape[0]] * log_probs.shape[1])
        loss = loss_function(log_probs, labels, input_lengths, labels_lengths)
        total_loss.append(loss.item())
        dataloader.set_description('loss:{}'.format(np.mean(total_loss)))
        loss.backward()
        optimizer.step()

    save(model.state_dict(), './models/model.pkl')
    save(optimizer.state_dict(), './models/optimizer.pkl')
    print('第{}个epoch,成功率, 损失为{}'.format(epoch, np.mean(total_loss)))

# Train for 15 epochs, printing the test-set accuracy after each one.
for epoch_idx in range(15):
    train(epoch_idx)
    accuracy = test()
    print(accuracy)
    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69

刚开始跑了14轮,成功率只有80%多;后来又跑了十几轮,成功率达到了97%左右。

在这里插入图片描述

测试集测试

from torch import save, load
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import optim
from tqdm import tqdm
import os
import numpy as np
from model import ResNetLstm
import itertools
from MyDataset import get_ms,Letter2Dataset
# Per-channel mean/std of the images in ./picture; test() normalizes with them.
res_total, res_std = get_ms()

# Class index -> character; index 0 ('_') is the blank, skipped when decoding.
mapping = list('_0123456789加减乘+-*')
def test():
    """Evaluate the saved checkpoint on ./test and return the sequence accuracy.

    A sample counts as correct only when the greedy CTC decoding of the model
    output (argmax per time step, collapse repeats, drop blanks) equals its
    label exactly.

    Returns:
        Fraction of correctly decoded samples in ./test (0.0 if there are none).
    """
    model = ResNetLstm((50, 150))
    batch_size = 16
    # Only the model weights are needed for evaluation. map_location keeps
    # this working on a CPU-only machine when the checkpoint was saved on GPU.
    if os.path.exists('./models/model.pkl'):
        model.load_state_dict(load("./models/model.pkl", map_location='cpu'))
    my_transforms = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=res_total, std=res_std)
        ]
    )
    test_set = Letter2Dataset(root="./test", transform=my_transforms)
    success = 0
    total = 0
    # Evaluate every test image: no shuffling, and keep the last partial batch.
    dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=False, drop_last=False)
    dataloader = tqdm(dataloader, total=len(dataloader))
    model.eval()
    with torch.no_grad():
        for images, labels, _ in dataloader:
            output = model(images).permute(1, 0, 2)  # (batch, T, classes)
            for sample_idx in range(output.shape[0]):
                best_path = output[sample_idx].argmax(-1).cpu().numpy()
                labels_s = [mapping[c] for c in labels[sample_idx].cpu().numpy() if mapping[c] != '_']
                output_s = [mapping[k] for k, _ in itertools.groupby(best_path) if k != 0]
                if labels_s == output_s:
                    success += 1
                total += 1

    return success / total if total else 0.0

if __name__ == '__main__':
    accuracy = test()
    print(accuracy)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55

在这里插入图片描述
测试了200张图片,成功率达97%。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/118750?site
推荐阅读
相关标签
  

闽ICP备14008679号