当前位置:   article > 正文

config配置文件原理及使用_config文件

config文件

今天看完了EEGNet的论文准备搭建一下EEGNet的网络,然后想到之前看过网络配置文件的内容,然后想着以后开发自己的网络的能够规范和方便,所以就学习一下,并在这里记录一下,方便以后查阅。

config配置文件

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

config代码

a.yaml

DATA:
  BATCH_SIZE: 512
MODEL:
  TRANS:
    EMBED_DIM: 768
  • 1
  • 2
  • 3
  • 4
  • 5

config.py

from yacs.config import CfgNode as CN
import yaml


# 设置默认参数
_C = CN()

_C.DATA = CN()
_C.DATA.DATASET = 'cifar10'
_C.DATA.BATCH_SIZE = 128

_C.MODEL = CN()
_C.MODEL.NUM_CLASSES = 10

_C.MODEL.TRANS = CN()
_C.MODEL.TRANS.EMBED_DIM = 96
_C.MODEL.TRANS.DEPTHS = [2, 2, 6, 2]
_C.MODEL.TRANS.QKV_BIAS = False


# 通过yaml更新参数
def _update_config_from_file(config, cfg_file):
    config.defrost()
    config.merge_from_file(cfg_file) # .yaml

# 通过argparser.ArgumentParser更新参数
def update_config(config, args):
    if args.cfg:
        _update(config, args.cfg)
    if args.dataset:
        config.DATA.DATASET = args.datasert
    if args.batch_size:
        config.DATA.BATCH_SIZE = args.batch_size
    return config

def get_config(cfg_file=None):
    config = _C.clone()
    if cfg_file:
        _update_config_from_file(config, cfg_file)
    return config

def main():
    cfg  = get_config('./a.yaml')
    print(cfg)

if __name__ == "__main__":
    main()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47

输出:

在这里插入图片描述

argparse.py

import argparse
from config import get_config
from config import update_config

def get_argument():
    parser = argparse.ArgumentParser('ViT')
    parser.add_argument('-cfg', type=str, default=None)
    parser.add_argument('-dataset', type=str, default=None)
    parser.add_argument('-batch_size', type=str, default=None)

    arguments = parser.parse_args()
    return arguments

def main():
    cfg = get_config()
    print(cfg)
    print('-----------------')
    cfg = get_config('./a.yaml')
    print(cfg)
    print('-----------------')
    args = get_argument()
    cfg = update_config(cfg, args)
    print(cfg)

if __name__ == "__main__":
    main()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

输出:

在这里插入图片描述

config配置文件的使用

PaddleViTSwinTransformer为例:

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

参考资料

自监督ViT算法:BeiT和MAE

PaddleViT

SwinTransformer

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/83951
推荐阅读
相关标签
  

闽ICP备14008679号