当前位置:   article > 正文

利用神经网络识别12306验证码——(五)训练模型_ocr resnet lstm

ocr resnet lstm

需要训练的有两个模型,一个是文本识别模型,一个是图像识别模型。在训练的时候,尝试了ResNet50、ResNet101、MobileNetV2,三种模型,前两个残差神经网络模型的参数比较大,训练比较耗时,精度上也逊色于第三个模型。尝试了RTX 2080、RTX 2070、Tesla K80三种GPU,三者的训练速度依次递减,但差距不是很大。训练平台在篇尾再进行介绍。
两个模型的训练基本上可以共用一套代码,只有一些细微的差异。
首先介绍共用的代码部分,其中加载数据的函数,数据的根目录根据自己的情况,需要自行修改。

#导入必要的包
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
import tensorflow as tf
from tensorflow import keras
#from tensorflow.keras.applications import ResNet101
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
from keras.callbacks import EarlyStopping
import numpy as np
import os
import shutil
import collections
import math
import random
import pathlib

#加载数据
def dataloader():
    data_root = '/mnt/Train_Valid_text'  #根目录自行更改
    train_data_root = pathlib.Path(data_root + "/Train")
    valid_data_root = pathlib.Path(data_root + "/Valid")
    #train_valid_data_root = pathlib.Path(data_root + "/Train_valid")
    # test_data_root = pathlib.Path(data_root+"/test")
    label_names = sorted(item.name for item in train_data_root.glob('*/') if item.is_dir())
    print(label_names)
    label_to_index = dict((name, index) for index, name in enumerate(label_names))
    print(label_to_index)
    
    train_all_image_paths = [str(path) for path in list(train_data_root.glob('*/*'))]
    valid_all_image_paths = [str(path) for path in list(valid_data_root.glob('*/*'
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/118713
推荐阅读
相关标签
  

闽ICP备14008679号