赞
踩
需要训练的有两个模型,一个是文本识别模型,一个是图像识别模型。在训练的时候,尝试了ResNet50、ResNet101、MobileNetV2,三种模型,前两个残差神经网络模型的参数比较大,训练比较耗时,精度上也逊色于第三个模型。尝试了RTX 2080、RTX 2070、Tesla K80三种GPU,三者的训练速度依次递减,但差距不是很大。训练平台在篇尾再进行介绍。
两个模型的训练基本上可以共用一套代码,只有一些细微的差异。
首先介绍共用的代码部分,其中加载数据的函数,数据的根目录根据自己的情况,需要自行修改。
#导入必要的包 from tensorflow.compat.v1 import ConfigProto from tensorflow.compat.v1 import InteractiveSession config = ConfigProto() config.gpu_options.allow_growth = True session = InteractiveSession(config=config) import tensorflow as tf from tensorflow import keras #from tensorflow.keras.applications import ResNet101 from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2 from keras.callbacks import EarlyStopping import numpy as np import os import shutil import collections import math import random import pathlib #加载数据 def dataloader(): data_root = '/mnt/Train_Valid_text' #根目录自行更改 train_data_root = pathlib.Path(data_root + "/Train") valid_data_root = pathlib.Path(data_root + "/Valid") #train_valid_data_root = pathlib.Path(data_root + "/Train_valid") # test_data_root = pathlib.Path(data_root+"/test") label_names = sorted(item.name for item in train_data_root.glob('*/') if item.is_dir()) print(label_names) label_to_index = dict((name, index) for index, name in enumerate(label_names)) print(label_to_index) train_all_image_paths = [str(path) for path in list(train_data_root.glob('*/*'))] valid_all_image_paths = [str(path) for path in list(valid_data_root.glob('*/*'
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。