import tensorflow as tf
import os
import argparse
from object_detection.utils import config_util
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.builders import model_builder
from object_detection.data_decoders import tf_example_decoder
from object_detection.core import standard_fields as fields
tf.get_logger().setLevel('ERROR')
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', type=str, required=True, help='Path to the directory which contains the model configuration and checkpoints.')
parser.add_argument('--pipeline_config_path', type=str, required=True, help='Path to the pipeline configuration file.')
parser.add_argument('--train_data_path', type=str, required=True, help='Path to the training data directory.')
parser.add_argument('--eval_data_path', type=str, required=True, help='Path to the evaluation data directory.')
parser.add_argument('--num_train_steps', type=int, required=True, help='Number of training steps.')
parser.add_argument('--num_eval_steps', type=int, required=True, help='Number of evaluation steps.')
parser.add_argument('--batch_size', type=int, default=16, help='Batch size.')
parser.add_argument('--num_classes', type=int, default=1, help='Number of classes.')
parser.add_argument('--learning_rate', type=float, default=0.0001, help='Learning rate.')
parser.add_argument('--checkpoint_every_n', type=int, default=1000, help='Save a checkpoint every n steps.')
parser.add_argument('--eval_every_n', type=int, default=1000, help='Run evaluation every n steps.')
args = parser.parse_args()
# Load the pipeline configuration
configs = config_util.get_configs_from_pipeline_file(args.pipeline_config_path)
model_config = configs['model']
train_config = configs['train_config']
eval_config = configs['eval_config']
input_config = configs['train_input_config']
# Load the label map
label_map_path = os.path.join(args.train_data_path, 'label_map.pbtxt')
label_map = label_map_util.load_labelmap(label_map_path)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=args.num_classes, use_display_name=True)
# Define the model
detection_model = model_builder.build(model_config=model_config, is_training=True)
# Define the data input pipeline. The TFRecords are parsed with the Object
# Detection API's TfExampleDecoder, which yields the image together with its
# ground-truth boxes and class labels.
decoder = tf_example_decoder.TfExampleDecoder()
def parse_record(serialized_example):
    tensors = decoder.decode(serialized_example)
    image = tf.cast(tensors[fields.InputDataFields.image], tf.float32)
    boxes = tensors[fields.InputDataFields.groundtruth_boxes]
    classes = tensors[fields.InputDataFields.groundtruth_classes]
    return image, boxes, classes
files_train = tf.io.gfile.glob(os.path.join(args.train_data_path, '*.record'))
dataset_train = tf.data.TFRecordDataset(files_train)
dataset_train = dataset_train.shuffle(input_config.shuffle_buffer_size)
dataset_train = dataset_train.map(parse_record, num_parallel_calls=tf.data.experimental.AUTOTUNE)
# Images and box counts vary per example, so pad to the largest in each batch;
# drop_remainder keeps the batch dimension static for tf.unstack in the train step.
dataset_train = dataset_train.padded_batch(args.batch_size, drop_remainder=True)
dataset_train = dataset_train.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
files_eval = tf.io.gfile.glob(os.path.join(args.eval_data_path, '*.record'))
dataset_eval = tf.data.TFRecordDataset(files_eval)
dataset_eval = dataset_eval.map(parse_record, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset_eval = dataset_eval.padded_batch(args.batch_size, drop_remainder=True)
# Define the optimizer and the training step
optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)
@tf.function
def train_step(data):
    images, boxes, classes = data
    # Class ids in the TFRecords are 1-based; shift them to 0-based one-hot vectors.
    boxes_list = tf.unstack(boxes)
    classes_list = tf.unstack(tf.one_hot(tf.cast(classes, tf.int32) - 1, args.num_classes))
    detection_model.provide_groundtruth(
        groundtruth_boxes_list=boxes_list,
        groundtruth_classes_list=classes_list)
    with tf.GradientTape() as tape:
        preprocessed_images, true_shapes = detection_model.preprocess(images)
        prediction_dict = detection_model.predict(preprocessed_images, true_shapes)
        losses = detection_model.loss(prediction_dict, true_shapes)
        total_loss = tf.add_n(list(losses.values()))
    gradients = tape.gradient(total_loss, detection_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, detection_model.trainable_variables))
    return total_loss
# Define the evaluation function
@tf.function
def eval_step(data):
    images, boxes, classes = data
    boxes_list = tf.unstack(boxes)
    classes_list = tf.unstack(tf.one_hot(tf.cast(classes, tf.int32) - 1, args.num_classes))
    detection_model.provide_groundtruth(
        groundtruth_boxes_list=boxes_list,
        groundtruth_classes_list=classes_list)
    preprocessed_images, true_shapes = detection_model.preprocess(images)
    prediction_dict = detection_model.predict(preprocessed_images, true_shapes)
    losses = detection_model.loss(prediction_dict, true_shapes)
    return tf.add_n(list(losses.values()))
# Define the training loop
ckpt = tf.train.Checkpoint(model=detection_model, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, args.model_dir, max_to_keep=10)
train_loss = tf.keras.metrics.Mean(name='train_loss')
eval_loss = tf.keras.metrics.Mean(name='eval_loss')
# The training data is repeated so the loop runs for exactly --num_train_steps steps.
dataset_train = dataset_train.repeat()
for step, data in enumerate(dataset_train.take(args.num_train_steps), start=1):
    loss = train_step(data)
    train_loss.update_state(loss)
    if step % args.checkpoint_every_n == 0:
        ckpt_manager.save()
    if step % args.eval_every_n == 0:
        eval_loss.reset_states()
        for data_eval in dataset_eval.take(args.num_eval_steps):
            loss_eval = eval_step(data_eval)
            eval_loss.update_state(loss_eval)
        print('Step {} - Train Loss {:.4f} - Eval Loss {:.4f}'.format(
            step, train_loss.result().numpy(), eval_loss.result().numpy()))
        train_loss.reset_states()
# Save a final checkpoint once training finishes.
ckpt_manager.save()
Data preprocessing and visualization utilities are omitted here. In the code above, the model and training parameters are read from the pipeline configuration file, the datasets are loaded, and the data input pipeline is defined. The object detection model is built with the model builder and trained on the training set, the evaluation set is used to monitor the model during training, and checkpoints are saved along the way.
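For reference, a typical invocation of the script might look like the following; the script name train_od.py, the directory layout, and the step counts are placeholders, so adjust them to your own setup (note that label_map.pbtxt is expected inside the --train_data_path directory):

python train_od.py \
    --model_dir=./training/checkpoints \
    --pipeline_config_path=./training/pipeline.config \
    --train_data_path=./data/train \
    --eval_data_path=./data/eval \
    --num_train_steps=20000 \
    --num_eval_steps=100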
This is an original post; please do not repost it without the author's permission. Original article: https://cplusplus.blog.csdn.net/article/details/133968719