
CenterNet (Objects as Points): Training Tips


Contents

1. Training on your own dataset (VOC-style example)

1. Prepare the data: your JPEGImages + Annotations

2. Convert .xml to .json (I'm more familiar with .xml, but the author uses COCO-style .json)

3. Arrange the folders to match the author's generated VOC dataset

4. Modify a few files

5. Try a training run

6. Start training (training runs normally)

7. Modify files before testing

8. mAP test results

9. Single-image detection results

2. About pretrained models

1. How to convert your own trained .pth model into one usable for pretraining

2. Verify that the generated pretrained model works


1. Training on your own dataset (VOC-style example)

(I'm more familiar with VOC, so the example uses a VOC-style dataset.) Earlier environment setup and VOC training were covered here: https://blog.csdn.net/weixin_38715903/article/details/98039181

1. Prepare the data: your JPEGImages + Annotations

2. Convert .xml to .json (I'm more familiar with .xml, but the author uses COCO-style .json)

Thanks to this blogger: https://blog.csdn.net/u010397980/article/details/90341223

I tweaked the code a bit; the final version is below:

#coding:utf-8
# pip install lxml
import os
import glob
import json
import shutil
import numpy as np
import xml.etree.ElementTree as ET

path2 = "./INF"
START_BOUNDING_BOX_ID = 1

def get(root, name):
    return root.findall(name)

def get_and_check(root, name, length):
    vars = root.findall(name)
    if len(vars) == 0:
        raise NotImplementedError('Can not find %s in %s.' % (name, root.tag))
    if length > 0 and len(vars) != length:
        raise NotImplementedError('The size of %s is supposed to be %d, but is %d.' % (name, length, len(vars)))
    if length == 1:
        vars = vars[0]
    return vars

def convert(xml_list, json_file):
    json_dict = {"images": [], "type": "instances", "annotations": [], "categories": []}
    categories = pre_define_categories.copy()
    bnd_id = START_BOUNDING_BOX_ID
    all_categories = {}
    for index, line in enumerate(xml_list):
        # print("Processing %s" % (line))
        xml_f = line
        tree = ET.parse(xml_f)
        root = tree.getroot()
        filename = os.path.basename(xml_f)[:-4] + ".jpg"
        image_id = 1 + index
        size = get_and_check(root, 'size', 1)
        width = int(get_and_check(size, 'width', 1).text)
        height = int(get_and_check(size, 'height', 1).text)
        image = {'file_name': filename, 'height': height, 'width': width, 'id': image_id}
        json_dict['images'].append(image)
        ## Currently we do not support segmentation
        # segmented = get_and_check(root, 'segmented', 1).text
        # assert segmented == '0'
        for obj in get(root, 'object'):
            category = get_and_check(obj, 'name', 1).text
            if category in all_categories:
                all_categories[category] += 1
            else:
                all_categories[category] = 1
            if category not in categories:
                if only_care_pre_define_categories:
                    continue
                new_id = len(categories) + 1
                print("[warning] category '{}' not in 'pre_define_categories'({}), create new id: {} automatically".format(category, pre_define_categories, new_id))
                categories[category] = new_id
            category_id = categories[category]
            bndbox = get_and_check(obj, 'bndbox', 1)
            xmin = int(float(get_and_check(bndbox, 'xmin', 1).text))
            ymin = int(float(get_and_check(bndbox, 'ymin', 1).text))
            xmax = int(float(get_and_check(bndbox, 'xmax', 1).text))
            ymax = int(float(get_and_check(bndbox, 'ymax', 1).text))
            assert (xmax > xmin), "xmax <= xmin, {}".format(line)
            assert (ymax > ymin), "ymax <= ymin, {}".format(line)
            o_width = abs(xmax - xmin)
            o_height = abs(ymax - ymin)
            ann = {'area': o_width * o_height, 'iscrowd': 0, 'image_id': image_id,
                   'bbox': [xmin, ymin, o_width, o_height],
                   'category_id': category_id, 'id': bnd_id, 'ignore': 0,
                   'segmentation': []}
            json_dict['annotations'].append(ann)
            bnd_id = bnd_id + 1
    for cate, cid in categories.items():
        cat = {'supercategory': 'none', 'id': cid, 'name': cate}
        json_dict['categories'].append(cat)
    json_fp = open(json_file, 'w')
    json_str = json.dumps(json_dict)
    json_fp.write(json_str)
    json_fp.close()
    print("------------create {} done--------------".format(json_file))
    print("find {} categories: {} -->>> your pre_define_categories {}: {}".format(len(all_categories), all_categories.keys(), len(pre_define_categories), pre_define_categories.keys()))
    print("category: id --> {}".format(categories))
    print(categories.keys())
    print(categories.values())

if __name__ == '__main__':
    classes = ['car', 'person', 'bicycle']
    pre_define_categories = {}
    for i, cls in enumerate(classes):
        pre_define_categories[cls] = i + 1
    # pre_define_categories = {'a1': 1, 'a3': 2, 'a6': 3, 'a9': 4, "a10": 5}
    only_care_pre_define_categories = True
    # only_care_pre_define_categories = False
    train_ratio = 0.9
    save_json_train = './INF/annotations/INF_train.json'
    save_json_val = './INF/annotations/INF_test.json'
    xml_dir = "./INF/Annotations"
    img_dir = "./INF/JPEGImages"
    xml_list = glob.glob(xml_dir + "/*.xml")  # list of paths of all matching .xml files
    xml_list = np.sort(xml_list)
    np.random.seed(100)
    np.random.shuffle(xml_list)
    # print(xml_list[:100])
    train_num = int(len(xml_list) * train_ratio)
    xml_list_train = xml_list[:train_num]
    xml_list_val = xml_list[train_num:]
    if os.path.exists(path2 + "/annotations"):
        shutil.rmtree(path2 + "/annotations")
    os.makedirs(path2 + "/annotations")
    if os.path.exists(path2 + "/images/train2019"):
        shutil.rmtree(path2 + "/images/train2019")
    os.makedirs(path2 + "/images/train2019")
    if os.path.exists(path2 + "/images/val2019"):
        shutil.rmtree(path2 + "/images/val2019")
    os.makedirs(path2 + "/images/val2019")
    convert(xml_list_train, save_json_train)
    convert(xml_list_val, save_json_val)
    f1 = open("./INF/train.txt", "w")
    for xml in xml_list_train:
        img1 = img_dir + xml[17:-4] + ".jpg"  # 17 is the length of './INF/Annotations'
        # print(img1)
        f1.write(os.path.basename(xml)[:-4] + "\n")
        shutil.copyfile(img1, path2 + "/images/train2019/" + os.path.basename(img1))
    f2 = open("./INF/test.txt", "w")
    for xml in xml_list_val:
        img2 = img_dir + xml[17:-4] + ".jpg"  # 17 is the length of './INF/Annotations'
        f2.write(os.path.basename(xml)[:-4] + "\n")
        shutil.copyfile(img2, path2 + "/images/val2019/" + os.path.basename(img2))
    f1.close()
    f2.close()
    print("-------------------------------")
    print("train number:", len(xml_list_train))
    print("val number:", len(xml_list_val))

3. Arrange the folders to match the author's generated VOC dataset:

voc_INF |-- annotations |-- INF_test.json
        |               |-- INF_train.json
        |-- images
        |-- VOCdevkit

PS:
1. 'annotations' holds the .json files; if you split into train, val, and test, also run merge_pascal_json.py to merge train and val into a single .json.
2. 'images' holds all the images.
3. 'VOCdevkit' holds an ordinary VOC dataset, including JPEGImages, Annotations, and ImageSets.
4. Where these folders live matters for the path edits in the later steps.

4. Modify a few files:

  • In the ~/CenterNet/src/lib/datasets/dataset folder

Copy pascal.py to pascal_INF.py and change the path-related code as follows:

# Starting around line 13
# The class name matches the file name
class PascalINF(data.Dataset):
  # number of classes (20 in the original pascal.py; 3 here)
  num_classes = 3
  default_resolution = [384, 384]
  mean = np.array([0.485, 0.456, 0.406],
                  dtype=np.float32).reshape(1, 1, 3)
  std = np.array([0.229, 0.224, 0.225],
                 dtype=np.float32).reshape(1, 1, 3)

  def __init__(self, opt, split):
    super(PascalINF, self).__init__()
    # data_dir is the folder holding your data; mine is ~/CenterNet/data/voc_INF/
    self.data_dir = os.path.join(opt.data_dir, 'voc_INF')
    self.img_dir = os.path.join(self.data_dir, 'images')
    # Match the file names in annotations/; I only have train and test
    _ann_name = {'train': 'train', 'val': 'test'}
    # Match the .json naming scheme; mine are INF_train and INF_test
    self.annot_path = os.path.join(
      self.data_dir, 'annotations',
      'INF_{}.json').format(_ann_name[split])
    self.max_objs = 50
    # Your classes, in the same order used when generating the json files
    # (this matters for class_id matching)
    self.class_name = ['__background__', 'car', 'person', 'bicycle']
    # 4 = num_classes + 1 (background)
    self._valid_ids = np.arange(1, 4, dtype=np.int32)
    self.cat_ids = {v: i for i, v in enumerate(self._valid_ids)}
    self._data_rng = np.random.RandomState(123)
    self._eig_val = np.array([0.2141788, 0.01817699, 0.00341571],
                             dtype=np.float32)
    self._eig_vec = np.array([
      [-0.58752847, -0.69563484, 0.41340352],
      [-0.5832747, 0.00994535, -0.81221408],
      [-0.56089297, 0.71832671, 0.41158938]
    ], dtype=np.float32)
    self.split = split
    self.opt = opt
  • In the ~/CenterNet/src/lib/datasets folder, modify dataset_factory.py

# Add around line 10
from .dataset.coco import COCO
from .dataset.pascal import PascalVOC
# Add your own dataset
from .dataset.pascal_INF import PascalINF
from .dataset.kitti import KITTI
from .dataset.coco_hp import COCOHP

# Add around line 17
dataset_factory = {
  'coco': COCO,
  'pascal': PascalVOC,
  'kitti': KITTI,
  'coco_hp': COCOHP,
  # Add your own dataset
  'inf': PascalINF
}
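For context, the stock dataset_factory.py also defines get_dataset, which main.py calls to resolve --dataset, so the new entry is picked up automatically and nothing else needs to change (shown roughly as it appears upstream; _sample_factory maps tasks like 'ctdet' to their sampling classes and already lives in the same file):

def get_dataset(dataset, task):
  class Dataset(dataset_factory[dataset], _sample_factory[task]):
    pass
  return Dataset

# get_dataset('inf', 'ctdet') now returns a class combining PascalINF
# (annotation loading) with the ctdet sampling/augmentation logic.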

5. Try a training run:

cd src
# train
sudo python3.6 main.py ctdet --exp_id INF --dataset inf --num_epochs 70 --lr_step 45,60 --batch_size 32 --master_batch 16 --lr 1.25e-4 --gpus 0,1

PS:
1. --exp_id INF: the folder name under which logs and checkpoints are saved
2. --dataset inf: your dataset key, as registered in dataset_factory above

6. Start training (training now runs normally)

I'll test the model results tomorrow; that's it for now.

7. Modify files before testing:

  • File to change: \CenterNet\src\lib\debugger.py

# Around lines 44-50, add an elif dataset == 'inf' branch:
elif num_classes == 80 or dataset == 'coco':
  self.names = coco_class_name
elif num_classes == 20 or dataset == 'pascal':
  self.names = pascal_class_name
elif dataset == 'inf':
  self.names = inf_class_name

# Around lines 440-447, add inf_class_name:
gta_class_name = [
  'p', 'v'
]
inf_class_name = ["car", "person", "bicycle"]
pascal_class_name = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus",
  "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike",
  "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
  • \CenterNet\src\tools\reval.py --> \CenterNet\src\tools\reval_inf.py

1. Copy reval.py and rename it reval_inf.py.
2. Modify the code as needed (3 places):

# Change the import
from datasets.pascal_inf import pascal_inf

# Around lines 33-37, change the default imdb name
parser.add_argument('--imdb', dest='imdb_name',
                    help='dataset to re-evaluate',
                    default='INF_test', type=str)

def from_dets(imdb_name, detection_file, args):
  # Change the constructor call
  imdb = pascal_inf('test')
  imdb.competition_mode(args.comp_mode)
  imdb.config['matlab_eval'] = args.matlab_eval
  with open(os.path.join(detection_file), 'rb') as f:
    if 'json' in detection_file:
      dets = json.load(f)
    else:
      dets = pickle.load(f, encoding='latin1')
  # import pdb; pdb.set_trace()
  if args.apply_nms:
    print('Applying NMS to all detections')
    test_nms = 0.3
    nms_dets = apply_nms(dets, test_nms)
  else:
    nms_dets = dets
  print('Evaluating detections')
  imdb.evaluate_detections(nms_dets)
  • \CenterNet\src\tools\voc_eval_lib\datasets\pascal_voc.py --> \CenterNet\src\tools\voc_eval_lib\datasets\pascal_inf.py

1. Copy pascal_voc.py and rename it pascal_inf.py.
2. Modify the code as needed:

1) Add # -*- coding: utf-8 -*- at the top.

2) The __init__ section (5 places):

# Rename the class
class pascal_inf(imdb):
  # Change the constructor arguments
  def __init__(self, image_set, use_diff=False):
    # Change the name scheme, e.g. INF_test
    name = 'INF_' + image_set
    if use_diff:
      name += '_diff'
    imdb.__init__(self, name)
    self._image_set = image_set
    self._devkit_path = self._get_default_path()
    # Data path: cfg.DATA_DIR + 'voc_INF' + 'VOCdevkit' (== self._devkit_path)
    self._data_path = os.path.join(self._devkit_path)
    # Change to your own classes:
    self._classes = ('__background__',  # always index 0
                     'car', 'person', 'bicycle')
    self._class_to_ind = dict(list(zip(self.classes,
                                       list(range(self.num_classes)))))
    self._image_ext = '.jpg'
    self._image_index = self._load_image_set_index()
    # Default to roidb handler
    self._roidb_handler = self.gt_roidb
    self._salt = str(uuid.uuid4())
    self._comp_id = 'comp4'

3) Change the data path to ~/CenterNet/data/voc_INF/VOCdevkit/ (1 place):

def _get_default_path(self):
  """
  Return the default path where PASCAL VOC is expected to be installed.
  """
  return os.path.join(cfg.DATA_DIR, 'voc_INF', 'VOCdevkit')

4) Remove the year handling, since my dataset has no years (1 place):

def rpn_roidb(self):
  # Changed here: the year check is dropped
  if self._image_set != 'test':
    gt_roidb = self.gt_roidb()
    rpn_roidb = self._load_rpn_roidb(gt_roidb)
    roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb)
  else:
    roidb = self._load_rpn_roidb(None)
  return roidb

5) Change where results are stored (1 place):

def _get_voc_results_file_template(self):
  # VOCdevkit/results/VOC2007/Main/<comp_id>_det_test_aeroplane.txt
  filename = self._get_comp_id() + '_det_' + self._image_set + '_{:s}.txt'
  # Results are now stored under ~/CenterNet/data/voc_INF/VOCdevkit/results/
  path = os.path.join(
    self._devkit_path,
    'results',
    filename)
  return path

6) Modify the _do_python_eval parameters (1 place):

def _do_python_eval(self, output_dir=None):
  annopath = os.path.join(
    self._devkit_path,
    'Annotations',
    '{:s}.xml')
  imagesetfile = os.path.join(
    self._devkit_path,
    'ImageSets',
    'Main',
    self._image_set + '.txt')
  cachedir = os.path.join(self._devkit_path, 'annotations_cache')
  aps = []
  # The PASCAL VOC metric changed in 2010. My dataset has no year, so the
  # year check is removed; AP is computed with the 07 metric here, though
  # other modes would work too.
  use_07_metric = True
  if output_dir is not None and not os.path.isdir(output_dir):
    os.mkdir(output_dir)
  for i, cls in enumerate(self._classes):
    if cls == '__background__':
      continue
    filename = self._get_voc_results_file_template().format(cls)
    rec, prec, ap = voc_eval(
      filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
      use_07_metric=use_07_metric, use_diff=self.config['use_diff'])
    aps += [ap]
    print(('AP for {} = {:.4f}'.format(cls, ap)))
    if output_dir is not None:
      with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f:
        pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
  print(('Mean AP = {:.4f}'.format(np.mean(aps))))
  print('~~~~~~~~')

7) Update the module name at the bottom (1 place):

if __name__ == '__main__':
  from datasets.pascal_inf import pascal_inf
  d = pascal_inf('trainval')  # was pascal_voc('trainval'); use the new class
  res = d.roidb
  from IPython import embed;
  embed()
  • /CenterNet/src/lib/datasets/dataset/pascal_INF.py (the file we changed for training)

At the end of the file, change which reval script gets called:

os.system('python tools/reval_inf.py ' + \
          '{}/results.json'.format(save_dir))

8. mAP test results:

sudo python3.6 test.py ctdet --exp_id INF --dataset inf --load_model ../exp/ctdet/INF/model_last.pth --flip_test

9. Single-image detection results:

Likewise, modify demo.py and opts.py (copy each file and rename it):

1) opts_inf.py: change

default_dataset_info = {
  'ctdet': {'default_resolution': [512, 512], 'num_classes': 80,
            'mean': [0.408, 0.447, 0.470], 'std': [0.289, 0.274, 0.278],
            'dataset': 'coco'},

to

default_dataset_info = {
  'ctdet': {'default_resolution': [512, 512], 'num_classes': 3,
            'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225],
            'dataset': 'inf'},

2) demo_inf.py: change

from opts import opts

to

from opts_inf import opts
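With the two copies in place, single-image detection can then be run along the lines of the stock demo command (the image path below is only a placeholder):

sudo python3.6 demo_inf.py ctdet --demo /path/to/your/image.jpg --load_model ../exp/ctdet/INF/model_last.pth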

(My dataset is confidential, so I grabbed a few images from Baidu to show results; the images below are from Baidu.)

2. About pretrained models

1. How to convert your own trained .pth model into one usable for pretraining

  • First, inspect the structure of the author's pretrained model:

model_pre = torch.load('dla34-ba72cf86.pth')
for name in model_pre:
    print(name)

Output:

base_layer.0.weight
base_layer.1.weight
base_layer.1.bias
base_layer.1.running_mean
base_layer.1.running_var
level0.0.weight
level0.1.weight
level0.1.bias
level0.1.running_mean
level0.1.running_var
level1.0.weight
level1.1.weight
level1.1.bias
level1.1.running_mean
level1.1.running_var
level2.tree1.conv1.weight
level2.tree1.bn1.weight
level2.tree1.bn1.bias
level2.tree1.bn1.running_mean
level2.tree1.bn1.running_var
level2.tree1.conv2.weight
level2.tree1.bn2.weight
level2.tree1.bn2.bias
level2.tree1.bn2.running_mean
level2.tree1.bn2.running_var
level2.tree2.conv1.weight
level2.tree2.bn1.weight
level2.tree2.bn1.bias
level2.tree2.bn1.running_mean
level2.tree2.bn1.running_var
level2.tree2.conv2.weight
level2.tree2.bn2.weight
level2.tree2.bn2.bias
level2.tree2.bn2.running_mean
level2.tree2.bn2.running_var
level2.root.conv.weight
level2.root.bn.weight
level2.root.bn.bias
level2.root.bn.running_mean
level2.root.bn.running_var
level2.project.0.weight
level2.project.1.weight
level2.project.1.bias
level2.project.1.running_mean
level2.project.1.running_var
level3.tree1.tree1.conv1.weight
level3.tree1.tree1.bn1.weight
level3.tree1.tree1.bn1.bias
level3.tree1.tree1.bn1.running_mean
level3.tree1.tree1.bn1.running_var
level3.tree1.tree1.conv2.weight
level3.tree1.tree1.bn2.weight
level3.tree1.tree1.bn2.bias
level3.tree1.tree1.bn2.running_mean
level3.tree1.tree1.bn2.running_var
level3.tree1.tree2.conv1.weight
level3.tree1.tree2.bn1.weight
level3.tree1.tree2.bn1.bias
level3.tree1.tree2.bn1.running_mean
level3.tree1.tree2.bn1.running_var
level3.tree1.tree2.conv2.weight
level3.tree1.tree2.bn2.weight
level3.tree1.tree2.bn2.bias
level3.tree1.tree2.bn2.running_mean
level3.tree1.tree2.bn2.running_var
level3.tree1.root.conv.weight
level3.tree1.root.bn.weight
level3.tree1.root.bn.bias
level3.tree1.root.bn.running_mean
level3.tree1.root.bn.running_var
level3.tree1.project.0.weight
level3.tree1.project.1.weight
level3.tree1.project.1.bias
level3.tree1.project.1.running_mean
level3.tree1.project.1.running_var
level3.tree2.tree1.conv1.weight
level3.tree2.tree1.bn1.weight
level3.tree2.tree1.bn1.bias
level3.tree2.tree1.bn1.running_mean
level3.tree2.tree1.bn1.running_var
level3.tree2.tree1.conv2.weight
level3.tree2.tree1.bn2.weight
level3.tree2.tree1.bn2.bias
level3.tree2.tree1.bn2.running_mean
level3.tree2.tree1.bn2.running_var
level3.tree2.tree2.conv1.weight
level3.tree2.tree2.bn1.weight
level3.tree2.tree2.bn1.bias
level3.tree2.tree2.bn1.running_mean
level3.tree2.tree2.bn1.running_var
level3.tree2.tree2.conv2.weight
level3.tree2.tree2.bn2.weight
level3.tree2.tree2.bn2.bias
level3.tree2.tree2.bn2.running_mean
level3.tree2.tree2.bn2.running_var
level3.tree2.root.conv.weight
level3.tree2.root.bn.weight
level3.tree2.root.bn.bias
level3.tree2.root.bn.running_mean
level3.tree2.root.bn.running_var
level3.project.0.weight
level3.project.1.weight
level3.project.1.bias
level3.project.1.running_mean
level3.project.1.running_var
level4.tree1.tree1.conv1.weight
level4.tree1.tree1.bn1.weight
level4.tree1.tree1.bn1.bias
level4.tree1.tree1.bn1.running_mean
level4.tree1.tree1.bn1.running_var
level4.tree1.tree1.conv2.weight
level4.tree1.tree1.bn2.weight
level4.tree1.tree1.bn2.bias
level4.tree1.tree1.bn2.running_mean
level4.tree1.tree1.bn2.running_var
level4.tree1.tree2.conv1.weight
level4.tree1.tree2.bn1.weight
level4.tree1.tree2.bn1.bias
level4.tree1.tree2.bn1.running_mean
level4.tree1.tree2.bn1.running_var
level4.tree1.tree2.conv2.weight
level4.tree1.tree2.bn2.weight
level4.tree1.tree2.bn2.bias
level4.tree1.tree2.bn2.running_mean
level4.tree1.tree2.bn2.running_var
level4.tree1.root.conv.weight
level4.tree1.root.bn.weight
level4.tree1.root.bn.bias
level4.tree1.root.bn.running_mean
level4.tree1.root.bn.running_var
level4.tree1.project.0.weight
level4.tree1.project.1.weight
level4.tree1.project.1.bias
level4.tree1.project.1.running_mean
level4.tree1.project.1.running_var
level4.tree2.tree1.conv1.weight
level4.tree2.tree1.bn1.weight
level4.tree2.tree1.bn1.bias
level4.tree2.tree1.bn1.running_mean
level4.tree2.tree1.bn1.running_var
level4.tree2.tree1.conv2.weight
level4.tree2.tree1.bn2.weight
level4.tree2.tree1.bn2.bias
level4.tree2.tree1.bn2.running_mean
level4.tree2.tree1.bn2.running_var
level4.tree2.tree2.conv1.weight
level4.tree2.tree2.bn1.weight
level4.tree2.tree2.bn1.bias
level4.tree2.tree2.bn1.running_mean
level4.tree2.tree2.bn1.running_var
level4.tree2.tree2.conv2.weight
level4.tree2.tree2.bn2.weight
level4.tree2.tree2.bn2.bias
level4.tree2.tree2.bn2.running_mean
level4.tree2.tree2.bn2.running_var
level4.tree2.root.conv.weight
level4.tree2.root.bn.weight
level4.tree2.root.bn.bias
level4.tree2.root.bn.running_mean
level4.tree2.root.bn.running_var
level4.project.0.weight
level4.project.1.weight
level4.project.1.bias
level4.project.1.running_mean
level4.project.1.running_var
level5.tree1.conv1.weight
level5.tree1.bn1.weight
level5.tree1.bn1.bias
level5.tree1.bn1.running_mean
level5.tree1.bn1.running_var
level5.tree1.conv2.weight
level5.tree1.bn2.weight
level5.tree1.bn2.bias
level5.tree1.bn2.running_mean
level5.tree1.bn2.running_var
level5.tree2.conv1.weight
level5.tree2.bn1.weight
level5.tree2.bn1.bias
level5.tree2.bn1.running_mean
level5.tree2.bn1.running_var
level5.tree2.conv2.weight
level5.tree2.bn2.weight
level5.tree2.bn2.bias
level5.tree2.bn2.running_mean
level5.tree2.bn2.running_var
level5.root.conv.weight
level5.root.bn.weight
level5.root.bn.bias
level5.root.bn.running_mean
level5.root.bn.running_var
level5.project.0.weight
level5.project.1.weight
level5.project.1.bias
level5.project.1.running_mean
level5.project.1.running_var
fc.weight
fc.bias
  • Our trained model has the following structure:

model_weights |-- epoch
              |-- state_dict |-- parameters similar to the pretrained model's

The contents of model_weights['state_dict']:

model_weights = torch.load('ctdet_pascal_dla_384.pth')
for name in model_weights['state_dict']:
    print(name)

Compared with the pretrained model, the printed names carry an extra 'base.' prefix, include parameters for the later layers (names starting with 'dla_up.', 'ida_up.', and the head names), and include extra entries ending in 'tracked':

base.base_layer.0.weight
base.base_layer.1.weight
base.base_layer.1.bias
base.base_layer.1.running_mean
base.base_layer.1.running_var
base.base_layer.1.num_batches_tracked
base.level0.0.weight
base.level0.1.weight
base.level0.1.bias
base.level0.1.running_mean
base.level0.1.running_var
base.level0.1.num_batches_tracked
base.level1.0.weight
base.level1.1.weight
base.level1.1.bias
base.level1.1.running_mean
base.level1.1.running_var
base.level1.1.num_batches_tracked
base.level2.tree1.conv1.weight
base.level2.tree1.bn1.weight
base.level2.tree1.bn1.bias
base.level2.tree1.bn1.running_mean
base.level2.tree1.bn1.running_var
base.level2.tree1.bn1.num_batches_tracked
base.level2.tree1.conv2.weight
base.level2.tree1.bn2.weight
base.level2.tree1.bn2.bias
base.level2.tree1.bn2.running_mean
base.level2.tree1.bn2.running_var
base.level2.tree1.bn2.num_batches_tracked
base.level2.tree2.conv1.weight
base.level2.tree2.bn1.weight
base.level2.tree2.bn1.bias
base.level2.tree2.bn1.running_mean
base.level2.tree2.bn1.running_var
base.level2.tree2.bn1.num_batches_tracked
base.level2.tree2.conv2.weight
base.level2.tree2.bn2.weight
base.level2.tree2.bn2.bias
base.level2.tree2.bn2.running_mean
base.level2.tree2.bn2.running_var
base.level2.tree2.bn2.num_batches_tracked
base.level2.root.conv.weight
base.level2.root.bn.weight
base.level2.root.bn.bias
base.level2.root.bn.running_mean
base.level2.root.bn.running_var
base.level2.root.bn.num_batches_tracked
base.level2.project.0.weight
base.level2.project.1.weight
base.level2.project.1.bias
base.level2.project.1.running_mean
base.level2.project.1.running_var
base.level2.project.1.num_batches_tracked
base.level3.tree1.tree1.conv1.weight
base.level3.tree1.tree1.bn1.weight
base.level3.tree1.tree1.bn1.bias
base.level3.tree1.tree1.bn1.running_mean
base.level3.tree1.tree1.bn1.running_var
base.level3.tree1.tree1.bn1.num_batches_tracked
base.level3.tree1.tree1.conv2.weight
base.level3.tree1.tree1.bn2.weight
base.level3.tree1.tree1.bn2.bias
base.level3.tree1.tree1.bn2.running_mean
base.level3.tree1.tree1.bn2.running_var
base.level3.tree1.tree1.bn2.num_batches_tracked
base.level3.tree1.tree2.conv1.weight
base.level3.tree1.tree2.bn1.weight
base.level3.tree1.tree2.bn1.bias
base.level3.tree1.tree2.bn1.running_mean
base.level3.tree1.tree2.bn1.running_var
base.level3.tree1.tree2.bn1.num_batches_tracked
base.level3.tree1.tree2.conv2.weight
base.level3.tree1.tree2.bn2.weight
base.level3.tree1.tree2.bn2.bias
base.level3.tree1.tree2.bn2.running_mean
base.level3.tree1.tree2.bn2.running_var
base.level3.tree1.tree2.bn2.num_batches_tracked
base.level3.tree1.root.conv.weight
base.level3.tree1.root.bn.weight
base.level3.tree1.root.bn.bias
base.level3.tree1.root.bn.running_mean
base.level3.tree1.root.bn.running_var
base.level3.tree1.root.bn.num_batches_tracked
base.level3.tree1.project.0.weight
base.level3.tree1.project.1.weight
base.level3.tree1.project.1.bias
base.level3.tree1.project.1.running_mean
base.level3.tree1.project.1.running_var
base.level3.tree1.project.1.num_batches_tracked
base.level3.tree2.tree1.conv1.weight
base.level3.tree2.tree1.bn1.weight
base.level3.tree2.tree1.bn1.bias
base.level3.tree2.tree1.bn1.running_mean
base.level3.tree2.tree1.bn1.running_var
base.level3.tree2.tree1.bn1.num_batches_tracked
base.level3.tree2.tree1.conv2.weight
base.level3.tree2.tree1.bn2.weight
base.level3.tree2.tree1.bn2.bias
base.level3.tree2.tree1.bn2.running_mean
base.level3.tree2.tree1.bn2.running_var
base.level3.tree2.tree1.bn2.num_batches_tracked
base.level3.tree2.tree2.conv1.weight
base.level3.tree2.tree2.bn1.weight
base.level3.tree2.tree2.bn1.bias
base.level3.tree2.tree2.bn1.running_mean
base.level3.tree2.tree2.bn1.running_var
base.level3.tree2.tree2.bn1.num_batches_tracked
base.level3.tree2.tree2.conv2.weight
base.level3.tree2.tree2.bn2.weight
base.level3.tree2.tree2.bn2.bias
base.level3.tree2.tree2.bn2.running_mean
base.level3.tree2.tree2.bn2.running_var
base.level3.tree2.tree2.bn2.num_batches_tracked
base.level3.tree2.root.conv.weight
base.level3.tree2.root.bn.weight
base.level3.tree2.root.bn.bias
base.level3.tree2.root.bn.running_mean
base.level3.tree2.root.bn.running_var
base.level3.tree2.root.bn.num_batches_tracked
base.level3.project.0.weight
base.level3.project.1.weight
base.level3.project.1.bias
base.level3.project.1.running_mean
base.level3.project.1.running_var
base.level3.project.1.num_batches_tracked
base.level4.tree1.tree1.conv1.weight
base.level4.tree1.tree1.bn1.weight
base.level4.tree1.tree1.bn1.bias
base.level4.tree1.tree1.bn1.running_mean
base.level4.tree1.tree1.bn1.running_var
base.level4.tree1.tree1.bn1.num_batches_tracked
base.level4.tree1.tree1.conv2.weight
base.level4.tree1.tree1.bn2.weight
base.level4.tree1.tree1.bn2.bias
base.level4.tree1.tree1.bn2.running_mean
base.level4.tree1.tree1.bn2.running_var
base.level4.tree1.tree1.bn2.num_batches_tracked
base.level4.tree1.tree2.conv1.weight
base.level4.tree1.tree2.bn1.weight
base.level4.tree1.tree2.bn1.bias
base.level4.tree1.tree2.bn1.running_mean
base.level4.tree1.tree2.bn1.running_var
base.level4.tree1.tree2.bn1.num_batches_tracked
base.level4.tree1.tree2.conv2.weight
base.level4.tree1.tree2.bn2.weight
base.level4.tree1.tree2.bn2.bias
base.level4.tree1.tree2.bn2.running_mean
base.level4.tree1.tree2.bn2.running_var
base.level4.tree1.tree2.bn2.num_batches_tracked
base.level4.tree1.root.conv.weight
base.level4.tree1.root.bn.weight
base.level4.tree1.root.bn.bias
base.level4.tree1.root.bn.running_mean
base.level4.tree1.root.bn.running_var
base.level4.tree1.root.bn.num_batches_tracked
base.level4.tree1.project.0.weight
base.level4.tree1.project.1.weight
base.level4.tree1.project.1.bias
base.level4.tree1.project.1.running_mean
base.level4.tree1.project.1.running_var
base.level4.tree1.project.1.num_batches_tracked
base.level4.tree2.tree1.conv1.weight
base.level4.tree2.tree1.bn1.weight
base.level4.tree2.tree1.bn1.bias
base.level4.tree2.tree1.bn1.running_mean
base.level4.tree2.tree1.bn1.running_var
base.level4.tree2.tree1.bn1.num_batches_tracked
base.level4.tree2.tree1.conv2.weight
base.level4.tree2.tree1.bn2.weight
base.level4.tree2.tree1.bn2.bias
base.level4.tree2.tree1.bn2.running_mean
base.level4.tree2.tree1.bn2.running_var
base.level4.tree2.tree1.bn2.num_batches_tracked
base.level4.tree2.tree2.conv1.weight
base.level4.tree2.tree2.bn1.weight
base.level4.tree2.tree2.bn1.bias
base.level4.tree2.tree2.bn1.running_mean
base.level4.tree2.tree2.bn1.running_var
base.level4.tree2.tree2.bn1.num_batches_tracked
base.level4.tree2.tree2.conv2.weight
base.level4.tree2.tree2.bn2.weight
base.level4.tree2.tree2.bn2.bias
base.level4.tree2.tree2.bn2.running_mean
base.level4.tree2.tree2.bn2.running_var
base.level4.tree2.tree2.bn2.num_batches_tracked
base.level4.tree2.root.conv.weight
base.level4.tree2.root.bn.weight
base.level4.tree2.root.bn.bias
base.level4.tree2.root.bn.running_mean
base.level4.tree2.root.bn.running_var
base.level4.tree2.root.bn.num_batches_tracked
base.level4.project.0.weight
base.level4.project.1.weight
base.level4.project.1.bias
base.level4.project.1.running_mean
base.level4.project.1.running_var
base.level4.project.1.num_batches_tracked
base.level5.tree1.conv1.weight
base.level5.tree1.bn1.weight
base.level5.tree1.bn1.bias
base.level5.tree1.bn1.running_mean
base.level5.tree1.bn1.running_var
base.level5.tree1.bn1.num_batches_tracked
base.level5.tree1.conv2.weight
base.level5.tree1.bn2.weight
base.level5.tree1.bn2.bias
base.level5.tree1.bn2.running_mean
base.level5.tree1.bn2.running_var
base.level5.tree1.bn2.num_batches_tracked
base.level5.tree2.conv1.weight
base.level5.tree2.bn1.weight
base.level5.tree2.bn1.bias
base.level5.tree2.bn1.running_mean
base.level5.tree2.bn1.running_var
base.level5.tree2.bn1.num_batches_tracked
base.level5.tree2.conv2.weight
base.level5.tree2.bn2.weight
base.level5.tree2.bn2.bias
base.level5.tree2.bn2.running_mean
base.level5.tree2.bn2.running_var
base.level5.tree2.bn2.num_batches_tracked
base.level5.root.conv.weight
base.level5.root.bn.weight
base.level5.root.bn.bias
base.level5.root.bn.running_mean
base.level5.root.bn.running_var
base.level5.root.bn.num_batches_tracked
base.level5.project.0.weight
base.level5.project.1.weight
base.level5.project.1.bias
base.level5.project.1.running_mean
base.level5.project.1.running_var
base.level5.project.1.num_batches_tracked
base.fc.weight
base.fc.bias
dla_up.ida_0.proj_1.actf.0.weight
dla_up.ida_0.proj_1.actf.0.bias
dla_up.ida_0.proj_1.actf.0.running_mean
dla_up.ida_0.proj_1.actf.0.running_var
dla_up.ida_0.proj_1.actf.0.num_batches_tracked
dla_up.ida_0.proj_1.conv.weight
dla_up.ida_0.proj_1.conv.bias
dla_up.ida_0.proj_1.conv.conv_offset_mask.weight
dla_up.ida_0.proj_1.conv.conv_offset_mask.bias
dla_up.ida_0.up_1.weight
dla_up.ida_0.node_1.actf.0.weight
dla_up.ida_0.node_1.actf.0.bias
dla_up.ida_0.node_1.actf.0.running_mean
dla_up.ida_0.node_1.actf.0.running_var
dla_up.ida_0.node_1.actf.0.num_batches_tracked
dla_up.ida_0.node_1.conv.weight
dla_up.ida_0.node_1.conv.bias
dla_up.ida_0.node_1.conv.conv_offset_mask.weight
dla_up.ida_0.node_1.conv.conv_offset_mask.bias
dla_up.ida_1.proj_1.actf.0.weight
dla_up.ida_1.proj_1.actf.0.bias
dla_up.ida_1.proj_1.actf.0.running_mean
dla_up.ida_1.proj_1.actf.0.running_var
dla_up.ida_1.proj_1.actf.0.num_batches_tracked
dla_up.ida_1.proj_1.conv.weight
dla_up.ida_1.proj_1.conv.bias
dla_up.ida_1.proj_1.conv.conv_offset_mask.weight
dla_up.ida_1.proj_1.conv.conv_offset_mask.bias
dla_up.ida_1.up_1.weight
dla_up.ida_1.node_1.actf.0.weight
dla_up.ida_1.node_1.actf.0.bias
dla_up.ida_1.node_1.actf.0.running_mean
dla_up.ida_1.node_1.actf.0.running_var
dla_up.ida_1.node_1.actf.0.num_batches_tracked
dla_up.ida_1.node_1.conv.weight
dla_up.ida_1.node_1.conv.bias
dla_up.ida_1.node_1.conv.conv_offset_mask.weight
dla_up.ida_1.node_1.conv.conv_offset_mask.bias
dla_up.ida_1.proj_2.actf.0.weight
dla_up.ida_1.proj_2.actf.0.bias
dla_up.ida_1.proj_2.actf.0.running_mean
dla_up.ida_1.proj_2.actf.0.running_var
dla_up.ida_1.proj_2.actf.0.num_batches_tracked
dla_up.ida_1.proj_2.conv.weight
dla_up.ida_1.proj_2.conv.bias
dla_up.ida_1.proj_2.conv.conv_offset_mask.weight
dla_up.ida_1.proj_2.conv.conv_offset_mask.bias
dla_up.ida_1.up_2.weight
dla_up.ida_1.node_2.actf.0.weight
dla_up.ida_1.node_2.actf.0.bias
dla_up.ida_1.node_2.actf.0.running_mean
dla_up.ida_1.node_2.actf.0.running_var
dla_up.ida_1.node_2.actf.0.num_batches_tracked
dla_up.ida_1.node_2.conv.weight
dla_up.ida_1.node_2.conv.bias
dla_up.ida_1.node_2.conv.conv_offset_mask.weight
dla_up.ida_1.node_2.conv.conv_offset_mask.bias
dla_up.ida_2.proj_1.actf.0.weight
dla_up.ida_2.proj_1.actf.0.bias
dla_up.ida_2.proj_1.actf.0.running_mean
dla_up.ida_2.proj_1.actf.0.running_var
dla_up.ida_2.proj_1.actf.0.num_batches_tracked
dla_up.ida_2.proj_1.conv.weight
dla_up.ida_2.proj_1.conv.bias
dla_up.ida_2.proj_1.conv.conv_offset_mask.weight
dla_up.ida_2.proj_1.conv.conv_offset_mask.bias
dla_up.ida_2.up_1.weight
dla_up.ida_2.node_1.actf.0.weight
dla_up.ida_2.node_1.actf.0.bias
dla_up.ida_2.node_1.actf.0.running_mean
dla_up.ida_2.node_1.actf.0.running_var
dla_up.ida_2.node_1.actf.0.num_batches_tracked
dla_up.ida_2.node_1.conv.weight
dla_up.ida_2.node_1.conv.bias
dla_up.ida_2.node_1.conv.conv_offset_mask.weight
dla_up.ida_2.node_1.conv.conv_offset_mask.bias
dla_up.ida_2.proj_2.actf.0.weight
dla_up.ida_2.proj_2.actf.0.bias
dla_up.ida_2.proj_2.actf.0.running_mean
dla_up.ida_2.proj_2.actf.0.running_var
dla_up.ida_2.proj_2.actf.0.num_batches_tracked
dla_up.ida_2.proj_2.conv.weight
dla_up.ida_2.proj_2.conv.bias
dla_up.ida_2.proj_2.conv.conv_offset_mask.weight
dla_up.ida_2.proj_2.conv.conv_offset_mask.bias
dla_up.ida_2.up_2.weight
dla_up.ida_2.node_2.actf.0.weight
dla_up.ida_2.node_2.actf.0.bias
dla_up.ida_2.node_2.actf.0.running_mean
dla_up.ida_2.node_2.actf.0.running_var
dla_up.ida_2.node_2.actf.0.num_batches_tracked
dla_up.ida_2.node_2.conv.weight
dla_up.ida_2.node_2.conv.bias
dla_up.ida_2.node_2.conv.conv_offset_mask.weight
dla_up.ida_2.node_2.conv.conv_offset_mask.bias
dla_up.ida_2.proj_3.actf.0.weight
dla_up.ida_2.proj_3.actf.0.bias
dla_up.ida_2.proj_3.actf.0.running_mean
dla_up.ida_2.proj_3.actf.0.running_var
dla_up.ida_2.proj_3.actf.0.num_batches_tracked
dla_up.ida_2.proj_3.conv.weight
dla_up.ida_2.proj_3.conv.bias
dla_up.ida_2.proj_3.conv.conv_offset_mask.weight
dla_up.ida_2.proj_3.conv.conv_offset_mask.bias
dla_up.ida_2.up_3.weight
dla_up.ida_2.node_3.actf.0.weight
dla_up.ida_2.node_3.actf.0.bias
dla_up.ida_2.node_3.actf.0.running_mean
dla_up.ida_2.node_3.actf.0.running_var
dla_up.ida_2.node_3.actf.0.num_batches_tracked
dla_up.ida_2.node_3.conv.weight
dla_up.ida_2.node_3.conv.bias
dla_up.ida_2.node_3.conv.conv_offset_mask.weight
dla_up.ida_2.node_3.conv.conv_offset_mask.bias
ida_up.proj_1.actf.0.weight
ida_up.proj_1.actf.0.bias
ida_up.proj_1.actf.0.running_mean
ida_up.proj_1.actf.0.running_var
ida_up.proj_1.actf.0.num_batches_tracked
ida_up.proj_1.conv.weight
ida_up.proj_1.conv.bias
ida_up.proj_1.conv.conv_offset_mask.weight
ida_up.proj_1.conv.conv_offset_mask.bias
ida_up.up_1.weight
ida_up.node_1.actf.0.weight
ida_up.node_1.actf.0.bias
ida_up.node_1.actf.0.running_mean
ida_up.node_1.actf.0.running_var
ida_up.node_1.actf.0.num_batches_tracked
ida_up.node_1.conv.weight
ida_up.node_1.conv.bias
ida_up.node_1.conv.conv_offset_mask.weight
ida_up.node_1.conv.conv_offset_mask.bias
ida_up.proj_2.actf.0.weight
ida_up.proj_2.actf.0.bias
ida_up.proj_2.actf.0.running_mean
ida_up.proj_2.actf.0.running_var
ida_up.proj_2.actf.0.num_batches_tracked
ida_up.proj_2.conv.weight
ida_up.proj_2.conv.bias
ida_up.proj_2.conv.conv_offset_mask.weight
ida_up.proj_2.conv.conv_offset_mask.bias
ida_up.up_2.weight
ida_up.node_2.actf.0.weight
ida_up.node_2.actf.0.bias
ida_up.node_2.actf.0.running_mean
ida_up.node_2.actf.0.running_var
ida_up.node_2.actf.0.num_batches_tracked
ida_up.node_2.conv.weight
ida_up.node_2.conv.bias
ida_up.node_2.conv.conv_offset_mask.weight
ida_up.node_2.conv.conv_offset_mask.bias
hm.0.weight
hm.0.bias
hm.2.weight
hm.2.bias
wh.0.weight
wh.0.bias
wh.2.weight
wh.2.bias
reg.0.weight
reg.0.bias
reg.2.weight
reg.2.bias
So the names to drop match these patterns:

name[-7:] == 'tracked'
name[:6] == 'dla_up'
name[:6] == 'ida_up'
name[:2] == 'hm'
name[:2] == 'wh'
name[:3] == 'reg'

The final code that produces a state dict matching the pretrained model's structure:

import torch

new_model = {}
model_weights = torch.load('ctdet_pascal_dla_384.pth')
for name in model_weights['state_dict']:
    print(name)
    n = 0
    if name[-7:] == 'tracked':
        n = 1
    elif name[:6] == 'dla_up' or name[:6] == 'ida_up':
        n = 1
    elif name[:2] == 'hm' or name[:2] == 'wh' or name[:3] == 'reg':
        n = 1
    if n == 0:
        new_model[name[5:]] = model_weights['state_dict'][name]  # strip the 'base.' prefix
for name in new_model:
    print(name)
torch.save(new_model, 'pre_pascal_dla_384.pth')
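One way to use the converted file (a sketch, not the author's exact procedure, assuming model is a CenterNet DLASeg instance created by create_model, whose backbone sub-module is model.base, matching the 'base.' prefix stripped above):

# Hedged sketch: load the converted weights into the DLA backbone of a CenterNet
# model. strict=False because the num_batches_tracked buffers were deliberately
# dropped during conversion; everything else should match by name and shape.
pretrained = torch.load('pre_pascal_dla_384.pth')
model.base.load_state_dict(pretrained, strict=False)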

2. Verify that the generated pretrained model works:
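A minimal key-level check compares the converted file against the official dla34-ba72cf86.pth (a sketch, assuming both files sit in the working directory); since the official checkpoint contains no num_batches_tracked entries, as the listing above shows, the two key sets should line up exactly:

import torch

official = torch.load('dla34-ba72cf86.pth')
converted = torch.load('pre_pascal_dla_384.pth')

# Keys present in one dict but not the other; ideally both are empty.
print('missing keys:', set(official) - set(converted) or 'none')
print('unexpected keys:', set(converted) - set(official) or 'none')

# Shared keys should also agree in tensor shape.
for name in set(official) & set(converted):
    assert official[name].shape == converted[name].shape, name
print('all shared shapes match')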

 

 
