当前位置:   article > 正文

【深度学习实战(28)】训练框架之使用yaml和argparse配置训练参数_args.yaml文件

args.yaml文件

一、yaml和argparse

在训练时,通过有许多参数需要管理,这里推荐一种结合yamlargparse的写法。

二、Yaml

yaml文件中设置训练参数
在这里插入图片描述

对应的解析代码

# 解析yaml配置文件
class LoadYaml:
    def __init__(self, path):
        with open(path, encoding='utf8') as f:
            data = yaml.load(f, Loader=yaml.FullLoader)

        self.val_txt = data["DATASET"]["VAL"]
        self.train_txt = data["DATASET"]["TRAIN"]
        self.names = data["DATASET"]["NAMES"]

        self.learn_rate = data["TRAIN"]["LR"]
        self.batch_size = data["TRAIN"]["BATCH_SIZE"]
        self.milestones = data["TRAIN"]["MILESTIONES"]
        self.end_epoch = data["TRAIN"]["END_EPOCH"]

        self.input_width = data["MODEL"]["INPUT_WIDTH"]
        self.input_height = data["MODEL"]["INPUT_HEIGHT"]

        self.category_num = data["MODEL"]["NC"]

        print("Load yaml sucess...")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

三、完整代码

import yaml
import argparse
import os

# 解析yaml配置文件
class LoadYaml:
    def __init__(self, path):
        with open(path, encoding='utf8') as f:
            data = yaml.load(f, Loader=yaml.FullLoader)

        self.val_txt = data["DATASET"]["VAL"]
        self.train_txt = data["DATASET"]["TRAIN"]
        self.names = data["DATASET"]["NAMES"]

        self.learn_rate = data["TRAIN"]["LR"]
        self.batch_size = data["TRAIN"]["BATCH_SIZE"]
        self.milestones = data["TRAIN"]["MILESTIONES"]
        self.end_epoch = data["TRAIN"]["END_EPOCH"]

        self.input_width = data["MODEL"]["INPUT_WIDTH"]
        self.input_height = data["MODEL"]["INPUT_HEIGHT"]

        self.category_num = data["MODEL"]["NC"]

        print("Load yaml sucess...")



parser = argparse.ArgumentParser()
parser.add_argument('--yaml',type=str,default="./train.yaml",help='.yaml config')
parser.add_argument('--pretain_weight',type=str,default='pretrain.pth',help='.weight config')
parser.add_argument('--saved_weight',type=str,default='best.pth',help='.weight config')
parser.add_argument('--last_weight',type=str,default='last.pth',help='.weight config')

opt = parser.parse_args()
assert os.path.exists(opt.yaml), "请指定正确的配置文件路径"

print('pretain_weight:',opt.pretain_weight)
print('saved_weight:',opt.saved_weight)
print('last_weight:',opt.last_weight)

# 解析yaml配置文件
cfg = LoadYaml(opt.yaml)
print('category_num:',cfg.category_num)
print('names:',cfg.names)
print('batch_size:',cfg.batch_size)
print('end_epoch:',cfg.end_epoch)
print('input_width:',cfg.input_width)
print('input_width:',cfg.input_width)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49

在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/551236
推荐阅读
相关标签
  

闽ICP备14008679号