赞
踩
今天看完了EEGNet的论文准备搭建一下EEGNet的网络,然后想到之前看过网络配置文件的内容,然后想着以后开发自己的网络的能够规范和方便,所以就学习一下,并在这里记录一下,方便以后查阅。
a.yaml
DATA:
BATCH_SIZE: 512
MODEL:
TRANS:
EMBED_DIM: 768
config.py
from yacs.config import CfgNode as CN import yaml # 设置默认参数 _C = CN() _C.DATA = CN() _C.DATA.DATASET = 'cifar10' _C.DATA.BATCH_SIZE = 128 _C.MODEL = CN() _C.MODEL.NUM_CLASSES = 10 _C.MODEL.TRANS = CN() _C.MODEL.TRANS.EMBED_DIM = 96 _C.MODEL.TRANS.DEPTHS = [2, 2, 6, 2] _C.MODEL.TRANS.QKV_BIAS = False # 通过yaml更新参数 def _update_config_from_file(config, cfg_file): config.defrost() config.merge_from_file(cfg_file) # .yaml # 通过argparser.ArgumentParser更新参数 def update_config(config, args): if args.cfg: _update(config, args.cfg) if args.dataset: config.DATA.DATASET = args.datasert if args.batch_size: config.DATA.BATCH_SIZE = args.batch_size return config def get_config(cfg_file=None): config = _C.clone() if cfg_file: _update_config_from_file(config, cfg_file) return config def main(): cfg = get_config('./a.yaml') print(cfg) if __name__ == "__main__": main()
输出:
argparse.py
import argparse from config import get_config from config import update_config def get_argument(): parser = argparse.ArgumentParser('ViT') parser.add_argument('-cfg', type=str, default=None) parser.add_argument('-dataset', type=str, default=None) parser.add_argument('-batch_size', type=str, default=None) arguments = parser.parse_args() return arguments def main(): cfg = get_config() print(cfg) print('-----------------') cfg = get_config('./a.yaml') print(cfg) print('-----------------') args = get_argument() cfg = update_config(cfg, args) print(cfg) if __name__ == "__main__": main()
输出:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。