赞
踩
在训练时,通常有许多参数需要管理,这里推荐一种结合 yaml 和 argparse 的写法:训练超参数集中写在 yaml 文件中,而配置文件路径、权重文件路径等则通过 argparse 从命令行传入。
在 yaml 文件中设置训练参数(文件需包含 DATASET、TRAIN、MODEL 三个部分,具体键名见下方解析代码)。
对应的解析代码如下:
# Parse the YAML training-configuration file.
class LoadYaml:
    """Load training hyper-parameters from a YAML configuration file.

    The file must contain three top-level mappings, ``DATASET``, ``TRAIN``
    and ``MODEL``; the keys read from each are listed in ``_parse``. Every
    value is stored unchanged as an instance attribute.

    Args:
        path: Path to the YAML file; it is read as UTF-8.

    Raises:
        TypeError: If the file's top level (e.g. an empty file) or one of the
            sections is not a mapping.
        KeyError: If a required section or key is missing; the message names
            the missing section/key.
    """

    def __init__(self, path):
        with open(path, encoding='utf8') as f:
            # NOTE(review): FullLoader can construct some Python-specific tags.
            # yaml.safe_load is sufficient for a plain hyper-parameter file and
            # is the safer choice if configs may come from untrusted sources.
            data = yaml.load(f, Loader=yaml.FullLoader)
        self._parse(data)
        print("Load yaml sucess...")

    @staticmethod
    def _lookup(data, section, *keys):
        """Return ``data[section][key]`` for the first of *keys* that is present.

        Raises KeyError naming the section (and accepted keys) when nothing
        is found, and TypeError when the section is not a mapping.
        """
        try:
            block = data[section]
        except KeyError:
            raise KeyError(f"config is missing section {section!r}") from None
        if not isinstance(block, dict):
            raise TypeError(
                f"config section {section!r} must be a mapping, got {type(block).__name__}"
            )
        for key in keys:
            if key in block:
                return block[key]
        wanted = " or ".join(repr(k) for k in keys)
        raise KeyError(f"config section {section!r} is missing key {wanted}")

    def _parse(self, data):
        """Populate the attributes from the already-loaded YAML document *data*."""
        if not isinstance(data, dict):
            # yaml.load returns None for an empty file; fail with a clear message.
            raise TypeError(f"config must be a YAML mapping, got {type(data).__name__}")

        # DATASET section.
        self.val_txt = self._lookup(data, "DATASET", "VAL")
        self.train_txt = self._lookup(data, "DATASET", "TRAIN")
        self.names = self._lookup(data, "DATASET", "NAMES")

        # TRAIN section.
        self.learn_rate = self._lookup(data, "TRAIN", "LR")
        self.batch_size = self._lookup(data, "TRAIN", "BATCH_SIZE")
        # Existing configs use the misspelled key MILESTIONES; it is checked
        # first for backward compatibility, and the correct spelling
        # MILESTONES is accepted as well.
        self.milestones = self._lookup(data, "TRAIN", "MILESTIONES", "MILESTONES")
        self.end_epoch = self._lookup(data, "TRAIN", "END_EPOCH")

        # MODEL section.
        self.input_width = self._lookup(data, "MODEL", "INPUT_WIDTH")
        self.input_height = self._lookup(data, "MODEL", "INPUT_HEIGHT")
        self.category_num = self._lookup(data, "MODEL", "NC")
import yaml
import argparse
import os
# Parse the YAML training-configuration file.
class LoadYaml:
    """Load training hyper-parameters from a YAML configuration file.

    The file must contain three top-level mappings, ``DATASET``, ``TRAIN``
    and ``MODEL``; the keys read from each are listed in ``_parse``. Every
    value is stored unchanged as an instance attribute.

    Args:
        path: Path to the YAML file; it is read as UTF-8.

    Raises:
        TypeError: If the file's top level (e.g. an empty file) or one of the
            sections is not a mapping.
        KeyError: If a required section or key is missing; the message names
            the missing section/key.
    """

    def __init__(self, path):
        with open(path, encoding='utf8') as f:
            # NOTE(review): FullLoader can construct some Python-specific tags.
            # yaml.safe_load is sufficient for a plain hyper-parameter file and
            # is the safer choice if configs may come from untrusted sources.
            data = yaml.load(f, Loader=yaml.FullLoader)
        self._parse(data)
        print("Load yaml sucess...")

    @staticmethod
    def _lookup(data, section, *keys):
        """Return ``data[section][key]`` for the first of *keys* that is present.

        Raises KeyError naming the section (and accepted keys) when nothing
        is found, and TypeError when the section is not a mapping.
        """
        try:
            block = data[section]
        except KeyError:
            raise KeyError(f"config is missing section {section!r}") from None
        if not isinstance(block, dict):
            raise TypeError(
                f"config section {section!r} must be a mapping, got {type(block).__name__}"
            )
        for key in keys:
            if key in block:
                return block[key]
        wanted = " or ".join(repr(k) for k in keys)
        raise KeyError(f"config section {section!r} is missing key {wanted}")

    def _parse(self, data):
        """Populate the attributes from the already-loaded YAML document *data*."""
        if not isinstance(data, dict):
            # yaml.load returns None for an empty file; fail with a clear message.
            raise TypeError(f"config must be a YAML mapping, got {type(data).__name__}")

        # DATASET section.
        self.val_txt = self._lookup(data, "DATASET", "VAL")
        self.train_txt = self._lookup(data, "DATASET", "TRAIN")
        self.names = self._lookup(data, "DATASET", "NAMES")

        # TRAIN section.
        self.learn_rate = self._lookup(data, "TRAIN", "LR")
        self.batch_size = self._lookup(data, "TRAIN", "BATCH_SIZE")
        # Existing configs use the misspelled key MILESTIONES; it is checked
        # first for backward compatibility, and the correct spelling
        # MILESTONES is accepted as well.
        self.milestones = self._lookup(data, "TRAIN", "MILESTIONES", "MILESTONES")
        self.end_epoch = self._lookup(data, "TRAIN", "END_EPOCH")

        # MODEL section.
        self.input_width = self._lookup(data, "MODEL", "INPUT_WIDTH")
        self.input_height = self._lookup(data, "MODEL", "INPUT_HEIGHT")
        self.category_num = self._lookup(data, "MODEL", "NC")
# Command-line options: the YAML config path plus three weight-file paths.
parser = argparse.ArgumentParser()
parser.add_argument('--yaml', type=str, default="./train.yaml", help='.yaml config')
# '--pretain_weight' is a historical misspelling kept for existing launch
# commands; '--pretrain_weight' is accepted too. Both store into
# opt.pretain_weight, so downstream code is unaffected.
parser.add_argument('--pretain_weight', '--pretrain_weight', dest='pretain_weight',
                    type=str, default='pretrain.pth', help='.weight config')
parser.add_argument('--saved_weight', type=str, default='best.pth', help='.weight config')
parser.add_argument('--last_weight', type=str, default='last.pth', help='.weight config')
opt = parser.parse_args()

# Validate explicitly instead of with assert (asserts are stripped under
# `python -O`). isfile also rejects a directory, which open() would fail on.
# parser.error prints the usage line and exits with status 2.
if not os.path.isfile(opt.yaml):
    parser.error(f"请指定正确的配置文件路径: {opt.yaml}")

print('pretain_weight:', opt.pretain_weight)
print('saved_weight:', opt.saved_weight)
print('last_weight:', opt.last_weight)

# Parse the YAML config file.
cfg = LoadYaml(opt.yaml)
print('category_num:', cfg.category_num)
print('names:', cfg.names)
print('batch_size:', cfg.batch_size)
print('end_epoch:', cfg.end_epoch)
print('input_width:', cfg.input_width)
# Previously printed input_width a second time; report the height instead.
print('input_height:', cfg.input_height)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。