赞
踩
目录
1.准备数据:准备好自己要用的JFPGImages+Annotations
2.xml转json:【我比较熟悉的是.xml文件,但作者用的是coco类似的.json文件】
1.这里是如何将自己生成的.pth模型,转换为可以用来训练的模型
【因为更熟悉VOC,所以还是以VOC数据集为例】之前的环境配置和VOC训练:https://blog.csdn.net/weixin_38715903/article/details/98039181
感恩博主:https://blog.csdn.net/u010397980/article/details/90341223
更改了一下代码,最终代码如下:
- #coding:utf-8
- # pip install lxml
-
- import os
- import glob
- import json
- import shutil
- import numpy as np
- import xml.etree.ElementTree as ET
-
- path2 = "./INF"
-
- START_BOUNDING_BOX_ID = 1
-
def get(root, name):
    """Return all elements under *root* matching *name* (an empty list if none)."""
    matches = root.findall(name)
    return matches
-
def get_and_check(root, name, length):
    """Find elements under *root* matching *name* and validate their count.

    Args:
        root: Element to search.
        name: Tag name / ElementPath expression passed to ``findall``.
        length: Expected number of matches; a non-positive value disables
            the count check (at least one match is still required).

    Returns:
        The single matching element when ``length == 1``, otherwise the
        list of matches.

    Raises:
        NotImplementedError: if nothing matches, or the number of matches
            differs from a positive *length*.
    """
    found = root.findall(name)
    if not found:
        raise NotImplementedError('Can not find %s in %s.' % (name, root.tag))
    if length > 0 and len(found) != length:
        raise NotImplementedError(
            'The size of %s is supposed to be %d, but is %d.' % (name, length, len(found)))
    return found[0] if length == 1 else found
-
def convert(xml_list, json_file, predefined_categories=None, skip_unknown=None):
    """Convert Pascal VOC XML annotation files into a single COCO-style JSON file.

    Args:
        xml_list: Paths of VOC ``.xml`` files. Image ids are ``1..len(xml_list)``
            in iteration order; each image's ``file_name`` is the XML file's
            base name with its 4-character extension replaced by ``.jpg``.
        json_file: Path of the JSON file to write.
        predefined_categories: ``{class_name: category_id}`` mapping. Defaults
            to the module-level ``pre_define_categories`` (assigned in the
            ``__main__`` section), so the original two-argument call still works.
        skip_unknown: If true, objects whose class is not in the mapping are
            dropped; if false they are given a fresh id. Defaults to the
            module-level ``only_care_pre_define_categories``, which (as before)
            is only read when an unknown class is actually encountered.

    Raises:
        NotImplementedError: if a required XML tag is missing or duplicated
            (raised by ``get_and_check``).
        AssertionError: if a box has ``xmax <= xmin`` or ``ymax <= ymin``.
    """
    if predefined_categories is None:
        predefined_categories = pre_define_categories
    json_dict = {"images": [], "type": "instances", "annotations": [], "categories": []}
    categories = predefined_categories.copy()
    bnd_id = START_BOUNDING_BOX_ID
    all_categories = {}  # class name -> number of objects seen (including skipped ones)
    for index, xml_f in enumerate(xml_list):
        root = ET.parse(xml_f).getroot()

        filename = os.path.basename(xml_f)[:-4] + ".jpg"
        image_id = 1 + index
        size = get_and_check(root, 'size', 1)
        width = int(get_and_check(size, 'width', 1).text)
        height = int(get_and_check(size, 'height', 1).text)
        json_dict['images'].append(
            {'file_name': filename, 'height': height, 'width': width, 'id': image_id})
        # Segmentation is not supported; only bounding boxes are exported.
        for obj in get(root, 'object'):
            category = get_and_check(obj, 'name', 1).text
            all_categories[category] = all_categories.get(category, 0) + 1
            if category not in categories:
                only_care = (only_care_pre_define_categories
                             if skip_unknown is None else skip_unknown)
                if only_care:
                    continue
                # max id + 1 (not len + 1): a non-contiguous mapping such as
                # {'a': 1, 'b': 3} would otherwise hand out the duplicate id 3.
                new_id = max(categories.values(), default=0) + 1
                print("[warning] category '{}' not in 'pre_define_categories'({}), create new id: {} automatically".format(category, predefined_categories, new_id))
                categories[category] = new_id
            category_id = categories[category]
            bndbox = get_and_check(obj, 'bndbox', 1)
            xmin = int(float(get_and_check(bndbox, 'xmin', 1).text))
            ymin = int(float(get_and_check(bndbox, 'ymin', 1).text))
            xmax = int(float(get_and_check(bndbox, 'xmax', 1).text))
            ymax = int(float(get_and_check(bndbox, 'ymax', 1).text))
            assert xmax > xmin, "xmax <= xmin, {}".format(xml_f)
            assert ymax > ymin, "ymax <= ymin, {}".format(xml_f)
            o_width = abs(xmax - xmin)
            o_height = abs(ymax - ymin)
            # COCO boxes are [x, y, width, height].
            ann = {'area': o_width * o_height, 'iscrowd': 0, 'image_id': image_id,
                   'bbox': [xmin, ymin, o_width, o_height],
                   'category_id': category_id, 'id': bnd_id, 'ignore': 0,
                   'segmentation': []}
            json_dict['annotations'].append(ann)
            bnd_id = bnd_id + 1

    for cate, cid in categories.items():
        json_dict['categories'].append({'supercategory': 'none', 'id': cid, 'name': cate})
    # `with` guarantees the file is closed even if serialization fails.
    with open(json_file, 'w') as json_fp:
        json.dump(json_dict, json_fp)
    print("------------create {} done--------------".format(json_file))
    print("find {} categories: {} -->>> your pre_define_categories {}: {}".format(len(all_categories), all_categories.keys(), len(predefined_categories), predefined_categories.keys()))
    print("category: id --> {}".format(categories))
    print(categories.keys())
    print(categories.values())
-
if __name__ == '__main__':
    classes = ['car', 'person', 'bicycle']
    # Category ids are 1-based and follow the order of `classes`.
    pre_define_categories = {}
    for i, cls in enumerate(classes):
        pre_define_categories[cls] = i + 1
    # pre_define_categories = {'a1': 1, 'a3': 2, 'a6': 3, 'a9': 4, "a10": 5}
    # True: drop objects whose class is not listed above;
    # False: give such classes new ids automatically.
    only_care_pre_define_categories = True
    # only_care_pre_define_categories = False

    train_ratio = 0.9
    save_json_train = './INF/annotations/INF_train.json'
    save_json_val = './INF/annotations/INF_test.json'
    xml_dir = "./INF/Annotations"
    img_dir = "./INF/JFPGImages"

    # All .xml annotation paths; sorted first so the seeded shuffle is reproducible.
    xml_list = glob.glob(xml_dir + "/*.xml")
    xml_list = np.sort(xml_list)
    np.random.seed(100)
    np.random.shuffle(xml_list)
    train_num = int(len(xml_list) * train_ratio)
    xml_list_train = xml_list[:train_num]
    xml_list_val = xml_list[train_num:]

    # Recreate the output directories from scratch (existing contents are deleted).
    for sub_dir in ("/annotations", "/images/train2019", "/images/val2019"):
        if os.path.exists(path2 + sub_dir):
            shutil.rmtree(path2 + sub_dir)
        os.makedirs(path2 + sub_dir)

    convert(xml_list_train, save_json_train)
    convert(xml_list_val, save_json_val)

    # Write each split's list file (one file stem per line) and copy its images.
    # The image path is derived from the XML base name instead of slicing a
    # fixed 17-character offset, which only worked for xml_dir == "./INF/Annotations".
    splits = (
        (xml_list_train, "./INF/train.txt", path2 + "/images/train2019/"),
        (xml_list_val, "./INF/test.txt", path2 + "/images/val2019/"),
    )
    for split_xmls, list_path, image_out_dir in splits:
        with open(list_path, "w") as list_file:
            for xml_path in split_xmls:
                stem = os.path.basename(xml_path)[:-4]
                img_path = os.path.join(img_dir, stem + ".jpg")
                list_file.write(stem + "\n")
                shutil.copyfile(img_path, image_out_dir + os.path.basename(img_path))

    print("-------------------------------")
    print("train number:", len(xml_list_train))
    print("val number:", len(xml_list_val))
- voc_INF|--annotations|--INF_test.json
- | |--INF_train.json
- |--images
- |--VOCdevkit
-
- PS:
- 1.'annotations'存放.json文件,如果你分了train,val,test三个部分,还要运行merge_pascal_json.py将train和val放在一个.json文件里
- 2.'images'存放所有的图片
- 3.'VOCdevkit'存放普通的VOC数据集包括JFPGImages、Annotations、ImageSets
- 4.上述存放的文件与后续修改路径有关
将pascal.py复制为pascal_INF.py,修改部分路径代码如下:
# Starting from line 13 of pascal_INF.py (a copy of pascal.py).
# The class name must match the file name / the import in dataset_factory.py.
class PascalINF(data.Dataset):
    """Custom INF detection dataset with 3 classes: car, person, bicycle.

    Adapted from CenterNet's ``PascalVOC`` (pascal.py); only the modified
    part of the class is shown here.
    """
    num_classes = 3
    default_resolution = [384, 384]
    # Per-channel normalization statistics, shaped (1, 1, 3) for broadcasting.
    mean = np.array([0.485, 0.456, 0.406],
                    dtype=np.float32).reshape(1, 1, 3)
    std = np.array([0.229, 0.224, 0.225],
                   dtype=np.float32).reshape(1, 1, 3)

    def __init__(self, opt, split):
        super(PascalINF, self).__init__()
        # Data lives in <opt.data_dir>/voc_INF/, e.g. ~/CenterNet/data/voc_INF/.
        self.data_dir = os.path.join(opt.data_dir, 'voc_INF')
        self.img_dir = os.path.join(self.data_dir, 'images')
        # Split name -> suffix of the JSON files in annotations/
        # (INF_train.json / INF_test.json; there is no separate val set).
        _ann_name = {'train': 'train', 'val': 'test'}
        # Format only the file name: formatting the whole joined path would
        # raise if data_dir itself contains '{' or '}'.
        self.annot_path = os.path.join(
            self.data_dir, 'annotations',
            'INF_{}.json'.format(_ann_name[split]))
        self.max_objs = 50
        # Must be in the same order used when generating the JSON files,
        # so that class index i corresponds to category id i.
        self.class_name = ['__background__', 'car', 'person', 'bicycle']
        # Category ids 1..num_classes (0 is background), mapped to 0-based indices.
        self._valid_ids = np.arange(1, self.num_classes + 1, dtype=np.int32)
        self.cat_ids = {v: i for i, v in enumerate(self._valid_ids)}
        self._data_rng = np.random.RandomState(123)
        self._eig_val = np.array([0.2141788, 0.01817699, 0.00341571],
                                 dtype=np.float32)
        self._eig_vec = np.array([
            [-0.58752847, -0.69563484, 0.41340352],
            [-0.5832747, 0.00994535, -0.81221408],
            [-0.56089297, 0.71832671, 0.41158938]
        ], dtype=np.float32)
        self.split = split
        self.opt = opt
        # ... the remainder of the class is unchanged from pascal.py.
# Add around line 10 of dataset_factory.py:
from .dataset.coco import COCO
from .dataset.pascal import PascalVOC
# Import your own dataset:
from .dataset.pascal_INF import PascalINF
from .dataset.kitti import KITTI
from .dataset.coco_hp import COCOHP

# Add around line 17:
dataset_factory = {
    'coco': COCO,
    'pascal': PascalVOC,
    'kitti': KITTI,
    'coco_hp': COCOHP,
    # Your own dataset. This note must be a '#' comment: a bare string literal
    # here is implicitly concatenated with the next key, silently registering
    # the dataset under a different name so that `--dataset inf` is not found.
    'inf': PascalINF,
}
- cd src
- # train
- sudo python3.6 main.py ctdet --exp_id INF --dataset inf --num_epochs 70 --lr_step 45,60 --batch_size 32 --master_batch 16 --lr 1.25e-4 --gpus 0,1
-
- """
- PS:
- 1.--exp_id INF 存放日志的文件名
- 2.--dataset inf 你的数据类型:刚刚更改的部分
- """
明天测试一下模型结果,就酱
- '第44-50行左右,添加elif dataset == "inf",如下:'
- elif num_classes == 80 or dataset == 'coco':
- self.names = coco_class_name
- elif num_classes == 20 or dataset == 'pascal':
- self.names = pascal_class_name
- elif dataset == 'inf':
- self.names = inf_class_name
-
- '第440-447行左右,添加inf_class_name,如下:'
gta_class_name = [
    'p', 'v'
]
# Class names of the custom INF dataset. Keep the same order as the category
# ids assigned when the JSON files were generated (car=1, person=2, bicycle=3).
inf_class_name = ["car", "person", "bicycle"]

# The 20 PASCAL VOC class names.
pascal_class_name = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus",
                     "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike",
                     "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
- 1.复制reval.py,并改名为reval_inf.py
- 2.根据需求修改代码:【3处】
- '修改引入的库'
- from datasets.pascal_inf import pascal_inf
-
- '#第33到37行左右,修改imdb的默认名称'
- parser.add_argument('--imdb', dest='imdb_name',
- help='dataset to re-evaluate',
- default='INF_test', type=str)
-
def from_dets(imdb_name, detection_file, args):
    """Re-evaluate saved detections on the INF dataset.

    Modified from reval.py: the imdb is built directly with ``pascal_inf``.

    Args:
        imdb_name: Dataset name such as ``'INF_test'`` (the ``--imdb`` flag).
            The part after the ``'INF_'`` prefix selects the image set; names
            without that prefix fall back to ``'test'``.
        detection_file: Path to the detections, either JSON (if the path
            contains ``'json'``) or a pickle file.
        args: Parsed CLI options; ``comp_mode``, ``matlab_eval`` and
            ``apply_nms`` are read.
    """
    # Previously imdb_name was ignored and 'test' was always evaluated.
    prefix = 'INF_'
    image_set = imdb_name[len(prefix):] if imdb_name.startswith(prefix) else 'test'
    imdb = pascal_inf(image_set)
    imdb.competition_mode(args.comp_mode)
    imdb.config['matlab_eval'] = args.matlab_eval
    with open(detection_file, 'rb') as f:
        if 'json' in detection_file:
            dets = json.load(f)
        else:
            # NOTE: pickle can execute arbitrary code; only load trusted files.
            dets = pickle.load(f, encoding='latin1')

    if args.apply_nms:
        print('Applying NMS to all detections')
        test_nms = 0.3
        nms_dets = apply_nms(dets, test_nms)
    else:
        nms_dets = dets

    print('Evaluating detections')
    imdb.evaluate_detections(nms_dets)
- 1.将pascal_voc.py复制,并改名为pascal_inf.py
- 2.按需求修改代码:
- 1). #-*-coding:utf-8-*-
- 2).'初始化部分'【5处】
# Rename the class:
class pascal_inf(imdb):
    """PASCAL VOC style imdb for the INF dataset (adapted from pascal_voc.py)."""

    # Changed signature: the `year` argument of pascal_voc was removed.
    def __init__(self, image_set, use_diff=False):
        """Create the imdb for *image_set*; its name is e.g. ``'INF_test'``."""
        name = 'INF_' + image_set
        if use_diff:
            name += '_diff'
        imdb.__init__(self, name)
        self._image_set = image_set
        self._devkit_path = self._get_default_path()
        # The data directory is the devkit directory itself
        # (cfg.DATA_DIR/voc_INF/VOCdevkit, see _get_default_path).
        self._data_path = self._devkit_path
        # Edit to match your own classes:
        self._classes = ('__background__',  # always index 0
                         'car', 'person', 'bicycle')
        self._class_to_ind = dict(list(zip(self.classes,
                                           list(range(self.num_classes)))))
        self._image_ext = '.jpg'
        self._image_index = self._load_image_set_index()
        # Default to roidb handler
        self._roidb_handler = self.gt_roidb
        self._salt = str(uuid.uuid4())
        self._comp_id = 'comp4'
        # ... remaining lines of __init__ are unchanged from pascal_voc.py.
-
- 3). '数据地址修改:~/CenterNet/data/voc_INF/VOCdevkit/'
def _get_default_path(self):  # modified (1 place)
    """Return the INF devkit directory: <cfg.DATA_DIR>/voc_INF/VOCdevkit."""
    return os.path.join(cfg.DATA_DIR, 'voc_INF', 'VOCdevkit')
-
- 4).修改模式,因为我不需要年份【1处】
def rpn_roidb(self):
    """Return the RPN roidb, merged with ground truth except for the 'test' set.

    Modified from pascal_voc.py: the year check was removed.
    """
    if self._image_set == 'test':
        return self._load_rpn_roidb(None)
    gt = self.gt_roidb()
    return imdb.merge_roidbs(gt, self._load_rpn_roidb(gt))
-
- 5).result的存储地址修改【1处】
def _get_voc_results_file_template(self):
    """Return the per-class detection results path, with a ``{:s}`` class slot.

    Example: <devkit>/results/comp4_det_test_{:s}.txt
    (modified: results go straight into <devkit>/results/, e.g.
    ~/CenterNet/data/voc_INF/VOCdevkit/results/). This method does not
    create that directory.
    """
    results_name = '{}_det_{}_{{:s}}.txt'.format(self._get_comp_id(), self._image_set)
    return os.path.join(self._devkit_path, 'results', results_name)
-
- 6)._do_python_eval参数修改【1处】
def _do_python_eval(self, output_dir=None):
    """Compute and print per-class AP and the mean AP using the Python VOC eval.

    Detections are read from the per-class result files; ground truth from
    ``<devkit>/Annotations/<id>.xml`` for the ids listed in
    ``<devkit>/ImageSets/Main/<image_set>.txt``. When *output_dir* is given,
    each class's recall/precision/AP is pickled there as ``<class>_pr.pkl``.
    """
    devkit = self._devkit_path
    annopath = os.path.join(devkit, 'Annotations', '{:s}.xml')
    imagesetfile = os.path.join(devkit, 'ImageSets', 'Main',
                                self._image_set + '.txt')
    cachedir = os.path.join(devkit, 'annotations_cache')
    # The PASCAL VOC metric changed in 2010. This dataset has no year, so the
    # VOC2007 metric is always used; change this flag to use the other mode.
    use_07_metric = True
    if output_dir is not None and not os.path.isdir(output_dir):
        os.mkdir(output_dir)
    aps = []
    for cls in self._classes:
        if cls == '__background__':
            continue
        det_file = self._get_voc_results_file_template().format(cls)
        rec, prec, ap = voc_eval(
            det_file, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
            use_07_metric=use_07_metric, use_diff=self.config['use_diff'])
        aps.append(ap)
        print('AP for {} = {:.4f}'.format(cls, ap))
        if output_dir is not None:
            with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f:
                pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
    print('Mean AP = {:.4f}'.format(np.mean(aps)))
    print('~~~~~~~~')
- 7).更改库名啦:【1处】
if __name__ == '__main__':
    from datasets.pascal_inf import pascal_inf

    # Use the renamed class; calling pascal_voc here raised NameError because
    # only pascal_inf is imported.
    # NOTE(review): 'trainval' needs ImageSets/Main/trainval.txt; this dataset
    # was split into train/test only — confirm which image set exists.
    d = pascal_inf('trainval')
    res = d.roidb
    from IPython import embed

    embed()
- 拉到文件末尾,修改使用的reval文件名:
- os.system('python tools/reval_inf.py ' + \
- '{}/results.json'.format(save_dir))
sudo python3.6 test.py ctdet --exp_id INF --dataset inf --load_model ../exp/ctdet/INF/model_last.pth --flip_test
同样地:修改demo.py和opts.py文件【将其复制,并改名】
- 1).opts_inf:
- default_dataset_info = {
- 'ctdet': {'default_resolution': [512, 512], 'num_classes': 80,
- 'mean': [0.408, 0.447, 0.470], 'std': [0.289, 0.274, 0.278],
- 'dataset': 'coco'},
-
- '改为'
-
- default_dataset_info = {
- 'ctdet': {'default_resolution': [512, 512], 'num_classes': 3,
- 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225],
- 'dataset': 'inf'},
- 2).demo_inf.py
- from opts import opts
- '改为'
- from opts_inf import opts
【我的数据集要保密,所以随便从百度上拿了几张图片给一下结果,以下图片来源于百度】
# Print the parameter names of the pretrained DLA-34 weights.
# map_location='cpu' lets this run on machines without CUDA.
model_pre = torch.load('dla34-ba72cf86.pth', map_location='cpu')
for name in model_pre:
    print(name)
输出:
base_layer.0.weight base_layer.1.weight base_layer.1.bias base_layer.1.running_mean base_layer.1.running_var level0.0.weight level0.1.weight level0.1.bias level0.1.running_mean level0.1.running_var level1.0.weight level1.1.weight level1.1.bias level1.1.running_mean level1.1.running_var level2.tree1.conv1.weight level2.tree1.bn1.weight level2.tree1.bn1.bias level2.tree1.bn1.running_mean level2.tree1.bn1.running_var level2.tree1.conv2.weight level2.tree1.bn2.weight level2.tree1.bn2.bias level2.tree1.bn2.running_mean level2.tree1.bn2.running_var level2.tree2.conv1.weight level2.tree2.bn1.weight level2.tree2.bn1.bias level2.tree2.bn1.running_mean level2.tree2.bn1.running_var level2.tree2.conv2.weight level2.tree2.bn2.weight level2.tree2.bn2.bias level2.tree2.bn2.running_mean level2.tree2.bn2.running_var level2.root.conv.weight level2.root.bn.weight level2.root.bn.bias level2.root.bn.running_mean level2.root.bn.running_var level2.project.0.weight level2.project.1.weight level2.project.1.bias level2.project.1.running_mean level2.project.1.running_var level3.tree1.tree1.conv1.weight level3.tree1.tree1.bn1.weight level3.tree1.tree1.bn1.bias level3.tree1.tree1.bn1.running_mean level3.tree1.tree1.bn1.running_var level3.tree1.tree1.conv2.weight level3.tree1.tree1.bn2.weight level3.tree1.tree1.bn2.bias level3.tree1.tree1.bn2.running_mean level3.tree1.tree1.bn2.running_var level3.tree1.tree2.conv1.weight level3.tree1.tree2.bn1.weight level3.tree1.tree2.bn1.bias level3.tree1.tree2.bn1.running_mean level3.tree1.tree2.bn1.running_var level3.tree1.tree2.conv2.weight level3.tree1.tree2.bn2.weight level3.tree1.tree2.bn2.bias level3.tree1.tree2.bn2.running_mean level3.tree1.tree2.bn2.running_var level3.tree1.root.conv.weight level3.tree1.root.bn.weight level3.tree1.root.bn.bias level3.tree1.root.bn.running_mean level3.tree1.root.bn.running_var level3.tree1.project.0.weight level3.tree1.project.1.weight level3.tree1.project.1.bias level3.tree1.project.1.running_mean 
level3.tree1.project.1.running_var level3.tree2.tree1.conv1.weight level3.tree2.tree1.bn1.weight level3.tree2.tree1.bn1.bias level3.tree2.tree1.bn1.running_mean level3.tree2.tree1.bn1.running_var level3.tree2.tree1.conv2.weight level3.tree2.tree1.bn2.weight level3.tree2.tree1.bn2.bias level3.tree2.tree1.bn2.running_mean level3.tree2.tree1.bn2.running_var level3.tree2.tree2.conv1.weight level3.tree2.tree2.bn1.weight level3.tree2.tree2.bn1.bias level3.tree2.tree2.bn1.running_mean level3.tree2.tree2.bn1.running_var level3.tree2.tree2.conv2.weight level3.tree2.tree2.bn2.weight level3.tree2.tree2.bn2.bias level3.tree2.tree2.bn2.running_mean level3.tree2.tree2.bn2.running_var level3.tree2.root.conv.weight level3.tree2.root.bn.weight level3.tree2.root.bn.bias level3.tree2.root.bn.running_mean level3.tree2.root.bn.running_var level3.project.0.weight level3.project.1.weight level3.project.1.bias level3.project.1.running_mean level3.project.1.running_var level4.tree1.tree1.conv1.weight level4.tree1.tree1.bn1.weight level4.tree1.tree1.bn1.bias level4.tree1.tree1.bn1.running_mean level4.tree1.tree1.bn1.running_var level4.tree1.tree1.conv2.weight level4.tree1.tree1.bn2.weight level4.tree1.tree1.bn2.bias level4.tree1.tree1.bn2.running_mean level4.tree1.tree1.bn2.running_var level4.tree1.tree2.conv1.weight level4.tree1.tree2.bn1.weight level4.tree1.tree2.bn1.bias level4.tree1.tree2.bn1.running_mean level4.tree1.tree2.bn1.running_var level4.tree1.tree2.conv2.weight level4.tree1.tree2.bn2.weight level4.tree1.tree2.bn2.bias level4.tree1.tree2.bn2.running_mean level4.tree1.tree2.bn2.running_var level4.tree1.root.conv.weight level4.tree1.root.bn.weight level4.tree1.root.bn.bias level4.tree1.root.bn.running_mean level4.tree1.root.bn.running_var level4.tree1.project.0.weight level4.tree1.project.1.weight level4.tree1.project.1.bias level4.tree1.project.1.running_mean level4.tree1.project.1.running_var level4.tree2.tree1.conv1.weight level4.tree2.tree1.bn1.weight 
level4.tree2.tree1.bn1.bias level4.tree2.tree1.bn1.running_mean level4.tree2.tree1.bn1.running_var level4.tree2.tree1.conv2.weight level4.tree2.tree1.bn2.weight level4.tree2.tree1.bn2.bias level4.tree2.tree1.bn2.running_mean level4.tree2.tree1.bn2.running_var level4.tree2.tree2.conv1.weight level4.tree2.tree2.bn1.weight level4.tree2.tree2.bn1.bias level4.tree2.tree2.bn1.running_mean level4.tree2.tree2.bn1.running_var level4.tree2.tree2.conv2.weight level4.tree2.tree2.bn2.weight level4.tree2.tree2.bn2.bias level4.tree2.tree2.bn2.running_mean level4.tree2.tree2.bn2.running_var level4.tree2.root.conv.weight level4.tree2.root.bn.weight level4.tree2.root.bn.bias level4.tree2.root.bn.running_mean level4.tree2.root.bn.running_var level4.project.0.weight level4.project.1.weight level4.project.1.bias level4.project.1.running_mean level4.project.1.running_var level5.tree1.conv1.weight level5.tree1.bn1.weight level5.tree1.bn1.bias level5.tree1.bn1.running_mean level5.tree1.bn1.running_var level5.tree1.conv2.weight level5.tree1.bn2.weight level5.tree1.bn2.bias level5.tree1.bn2.running_mean level5.tree1.bn2.running_var level5.tree2.conv1.weight level5.tree2.bn1.weight level5.tree2.bn1.bias level5.tree2.bn1.running_mean level5.tree2.bn1.running_var level5.tree2.conv2.weight level5.tree2.bn2.weight level5.tree2.bn2.bias level5.tree2.bn2.running_mean level5.tree2.bn2.running_var level5.root.conv.weight level5.root.bn.weight level5.root.bn.bias level5.root.bn.running_mean level5.root.bn.running_var level5.project.0.weight level5.project.1.weight level5.project.1.bias level5.project.1.running_mean level5.project.1.running_var fc.weight fc.bias
- models_weights|--epoch
- |--state_dict|--'与预训练模型类似的参数'
model_weights['state_dict']的内容如下:
# map_location='cpu' lets the checkpoint be inspected without a GPU.
model_weights = torch.load('ctdet_pascal_dla_384.pth', map_location='cpu')
# The network parameters are stored under the 'state_dict' key.
for name in model_weights['state_dict']:
    print(name)
可以看到输出的name与预训练模型相比,多了'base.',且多了后续层的参数比如'dla_up.'、'ida_up.'等开头以及'tracked'结尾的name参数:
base.base_layer.0.weight base.base_layer.1.weight base.base_layer.1.bias base.base_layer.1.running_mean base.base_layer.1.running_var base.base_layer.1.num_batches_tracked base.level0.0.weight base.level0.1.weight base.level0.1.bias base.level0.1.running_mean base.level0.1.running_var base.level0.1.num_batches_tracked base.level1.0.weight base.level1.1.weight base.level1.1.bias base.level1.1.running_mean base.level1.1.running_var base.level1.1.num_batches_tracked base.level2.tree1.conv1.weight base.level2.tree1.bn1.weight base.level2.tree1.bn1.bias base.level2.tree1.bn1.running_mean base.level2.tree1.bn1.running_var base.level2.tree1.bn1.num_batches_tracked base.level2.tree1.conv2.weight base.level2.tree1.bn2.weight base.level2.tree1.bn2.bias base.level2.tree1.bn2.running_mean base.level2.tree1.bn2.running_var base.level2.tree1.bn2.num_batches_tracked base.level2.tree2.conv1.weight base.level2.tree2.bn1.weight base.level2.tree2.bn1.bias base.level2.tree2.bn1.running_mean base.level2.tree2.bn1.running_var base.level2.tree2.bn1.num_batches_tracked base.level2.tree2.conv2.weight base.level2.tree2.bn2.weight base.level2.tree2.bn2.bias base.level2.tree2.bn2.running_mean base.level2.tree2.bn2.running_var base.level2.tree2.bn2.num_batches_tracked base.level2.root.conv.weight base.level2.root.bn.weight base.level2.root.bn.bias base.level2.root.bn.running_mean base.level2.root.bn.running_var base.level2.root.bn.num_batches_tracked base.level2.project.0.weight base.level2.project.1.weight base.level2.project.1.bias base.level2.project.1.running_mean base.level2.project.1.running_var base.level2.project.1.num_batches_tracked base.level3.tree1.tree1.conv1.weight base.level3.tree1.tree1.bn1.weight base.level3.tree1.tree1.bn1.bias base.level3.tree1.tree1.bn1.running_mean base.level3.tree1.tree1.bn1.running_var base.level3.tree1.tree1.bn1.num_batches_tracked base.level3.tree1.tree1.conv2.weight base.level3.tree1.tree1.bn2.weight base.level3.tree1.tree1.bn2.bias 
base.level3.tree1.tree1.bn2.running_mean base.level3.tree1.tree1.bn2.running_var base.level3.tree1.tree1.bn2.num_batches_tracked base.level3.tree1.tree2.conv1.weight base.level3.tree1.tree2.bn1.weight base.level3.tree1.tree2.bn1.bias base.level3.tree1.tree2.bn1.running_mean base.level3.tree1.tree2.bn1.running_var base.level3.tree1.tree2.bn1.num_batches_tracked base.level3.tree1.tree2.conv2.weight base.level3.tree1.tree2.bn2.weight base.level3.tree1.tree2.bn2.bias base.level3.tree1.tree2.bn2.running_mean base.level3.tree1.tree2.bn2.running_var base.level3.tree1.tree2.bn2.num_batches_tracked base.level3.tree1.root.conv.weight base.level3.tree1.root.bn.weight base.level3.tree1.root.bn.bias base.level3.tree1.root.bn.running_mean base.level3.tree1.root.bn.running_var base.level3.tree1.root.bn.num_batches_tracked base.level3.tree1.project.0.weight base.level3.tree1.project.1.weight base.level3.tree1.project.1.bias base.level3.tree1.project.1.running_mean base.level3.tree1.project.1.running_var base.level3.tree1.project.1.num_batches_tracked base.level3.tree2.tree1.conv1.weight base.level3.tree2.tree1.bn1.weight base.level3.tree2.tree1.bn1.bias base.level3.tree2.tree1.bn1.running_mean base.level3.tree2.tree1.bn1.running_var base.level3.tree2.tree1.bn1.num_batches_tracked base.level3.tree2.tree1.conv2.weight base.level3.tree2.tree1.bn2.weight base.level3.tree2.tree1.bn2.bias base.level3.tree2.tree1.bn2.running_mean base.level3.tree2.tree1.bn2.running_var base.level3.tree2.tree1.bn2.num_batches_tracked base.level3.tree2.tree2.conv1.weight base.level3.tree2.tree2.bn1.weight base.level3.tree2.tree2.bn1.bias base.level3.tree2.tree2.bn1.running_mean base.level3.tree2.tree2.bn1.running_var base.level3.tree2.tree2.bn1.num_batches_tracked base.level3.tree2.tree2.conv2.weight base.level3.tree2.tree2.bn2.weight base.level3.tree2.tree2.bn2.bias base.level3.tree2.tree2.bn2.running_mean base.level3.tree2.tree2.bn2.running_var base.level3.tree2.tree2.bn2.num_batches_tracked 
base.level3.tree2.root.conv.weight base.level3.tree2.root.bn.weight base.level3.tree2.root.bn.bias base.level3.tree2.root.bn.running_mean base.level3.tree2.root.bn.running_var base.level3.tree2.root.bn.num_batches_tracked base.level3.project.0.weight base.level3.project.1.weight base.level3.project.1.bias base.level3.project.1.running_mean base.level3.project.1.running_var base.level3.project.1.num_batches_tracked base.level4.tree1.tree1.conv1.weight base.level4.tree1.tree1.bn1.weight base.level4.tree1.tree1.bn1.bias base.level4.tree1.tree1.bn1.running_mean base.level4.tree1.tree1.bn1.running_var base.level4.tree1.tree1.bn1.num_batches_tracked base.level4.tree1.tree1.conv2.weight base.level4.tree1.tree1.bn2.weight base.level4.tree1.tree1.bn2.bias base.level4.tree1.tree1.bn2.running_mean base.level4.tree1.tree1.bn2.running_var base.level4.tree1.tree1.bn2.num_batches_tracked base.level4.tree1.tree2.conv1.weight base.level4.tree1.tree2.bn1.weight base.level4.tree1.tree2.bn1.bias base.level4.tree1.tree2.bn1.running_mean base.level4.tree1.tree2.bn1.running_var base.level4.tree1.tree2.bn1.num_batches_tracked base.level4.tree1.tree2.conv2.weight base.level4.tree1.tree2.bn2.weight base.level4.tree1.tree2.bn2.bias base.level4.tree1.tree2.bn2.running_mean base.level4.tree1.tree2.bn2.running_var base.level4.tree1.tree2.bn2.num_batches_tracked base.level4.tree1.root.conv.weight base.level4.tree1.root.bn.weight base.level4.tree1.root.bn.bias base.level4.tree1.root.bn.running_mean base.level4.tree1.root.bn.running_var base.level4.tree1.root.bn.num_batches_tracked base.level4.tree1.project.0.weight base.level4.tree1.project.1.weight base.level4.tree1.project.1.bias base.level4.tree1.project.1.running_mean base.level4.tree1.project.1.running_var base.level4.tree1.project.1.num_batches_tracked base.level4.tree2.tree1.conv1.weight base.level4.tree2.tree1.bn1.weight base.level4.tree2.tree1.bn1.bias base.level4.tree2.tree1.bn1.running_mean base.level4.tree2.tree1.bn1.running_var 
base.level4.tree2.tree1.bn1.num_batches_tracked base.level4.tree2.tree1.conv2.weight base.level4.tree2.tree1.bn2.weight base.level4.tree2.tree1.bn2.bias base.level4.tree2.tree1.bn2.running_mean base.level4.tree2.tree1.bn2.running_var base.level4.tree2.tree1.bn2.num_batches_tracked base.level4.tree2.tree2.conv1.weight base.level4.tree2.tree2.bn1.weight base.level4.tree2.tree2.bn1.bias base.level4.tree2.tree2.bn1.running_mean base.level4.tree2.tree2.bn1.running_var base.level4.tree2.tree2.bn1.num_batches_tracked base.level4.tree2.tree2.conv2.weight base.level4.tree2.tree2.bn2.weight base.level4.tree2.tree2.bn2.bias base.level4.tree2.tree2.bn2.running_mean base.level4.tree2.tree2.bn2.running_var base.level4.tree2.tree2.bn2.num_batches_tracked base.level4.tree2.root.conv.weight base.level4.tree2.root.bn.weight base.level4.tree2.root.bn.bias base.level4.tree2.root.bn.running_mean base.level4.tree2.root.bn.running_var base.level4.tree2.root.bn.num_batches_tracked base.level4.project.0.weight base.level4.project.1.weight base.level4.project.1.bias base.level4.project.1.running_mean base.level4.project.1.running_var base.level4.project.1.num_batches_tracked base.level5.tree1.conv1.weight base.level5.tree1.bn1.weight base.level5.tree1.bn1.bias base.level5.tree1.bn1.running_mean base.level5.tree1.bn1.running_var base.level5.tree1.bn1.num_batches_tracked base.level5.tree1.conv2.weight base.level5.tree1.bn2.weight base.level5.tree1.bn2.bias base.level5.tree1.bn2.running_mean base.level5.tree1.bn2.running_var base.level5.tree1.bn2.num_batches_tracked base.level5.tree2.conv1.weight base.level5.tree2.bn1.weight base.level5.tree2.bn1.bias base.level5.tree2.bn1.running_mean base.level5.tree2.bn1.running_var base.level5.tree2.bn1.num_batches_tracked base.level5.tree2.conv2.weight base.level5.tree2.bn2.weight base.level5.tree2.bn2.bias base.level5.tree2.bn2.running_mean base.level5.tree2.bn2.running_var base.level5.tree2.bn2.num_batches_tracked base.level5.root.conv.weight 
base.level5.root.bn.weight base.level5.root.bn.bias base.level5.root.bn.running_mean base.level5.root.bn.running_var base.level5.root.bn.num_batches_tracked base.level5.project.0.weight base.level5.project.1.weight base.level5.project.1.bias base.level5.project.1.running_mean base.level5.project.1.running_var base.level5.project.1.num_batches_tracked base.fc.weight base.fc.bias dla_up.ida_0.proj_1.actf.0.weight dla_up.ida_0.proj_1.actf.0.bias dla_up.ida_0.proj_1.actf.0.running_mean dla_up.ida_0.proj_1.actf.0.running_var dla_up.ida_0.proj_1.actf.0.num_batches_tracked dla_up.ida_0.proj_1.conv.weight dla_up.ida_0.proj_1.conv.bias dla_up.ida_0.proj_1.conv.conv_offset_mask.weight dla_up.ida_0.proj_1.conv.conv_offset_mask.bias dla_up.ida_0.up_1.weight dla_up.ida_0.node_1.actf.0.weight dla_up.ida_0.node_1.actf.0.bias dla_up.ida_0.node_1.actf.0.running_mean dla_up.ida_0.node_1.actf.0.running_var dla_up.ida_0.node_1.actf.0.num_batches_tracked dla_up.ida_0.node_1.conv.weight dla_up.ida_0.node_1.conv.bias dla_up.ida_0.node_1.conv.conv_offset_mask.weight dla_up.ida_0.node_1.conv.conv_offset_mask.bias dla_up.ida_1.proj_1.actf.0.weight dla_up.ida_1.proj_1.actf.0.bias dla_up.ida_1.proj_1.actf.0.running_mean dla_up.ida_1.proj_1.actf.0.running_var dla_up.ida_1.proj_1.actf.0.num_batches_tracked dla_up.ida_1.proj_1.conv.weight dla_up.ida_1.proj_1.conv.bias dla_up.ida_1.proj_1.conv.conv_offset_mask.weight dla_up.ida_1.proj_1.conv.conv_offset_mask.bias dla_up.ida_1.up_1.weight dla_up.ida_1.node_1.actf.0.weight dla_up.ida_1.node_1.actf.0.bias dla_up.ida_1.node_1.actf.0.running_mean dla_up.ida_1.node_1.actf.0.running_var dla_up.ida_1.node_1.actf.0.num_batches_tracked dla_up.ida_1.node_1.conv.weight dla_up.ida_1.node_1.conv.bias dla_up.ida_1.node_1.conv.conv_offset_mask.weight dla_up.ida_1.node_1.conv.conv_offset_mask.bias dla_up.ida_1.proj_2.actf.0.weight dla_up.ida_1.proj_2.actf.0.bias dla_up.ida_1.proj_2.actf.0.running_mean dla_up.ida_1.proj_2.actf.0.running_var 
dla_up.ida_1.proj_2.actf.0.num_batches_tracked dla_up.ida_1.proj_2.conv.weight dla_up.ida_1.proj_2.conv.bias dla_up.ida_1.proj_2.conv.conv_offset_mask.weight dla_up.ida_1.proj_2.conv.conv_offset_mask.bias dla_up.ida_1.up_2.weight dla_up.ida_1.node_2.actf.0.weight dla_up.ida_1.node_2.actf.0.bias dla_up.ida_1.node_2.actf.0.running_mean dla_up.ida_1.node_2.actf.0.running_var dla_up.ida_1.node_2.actf.0.num_batches_tracked dla_up.ida_1.node_2.conv.weight dla_up.ida_1.node_2.conv.bias dla_up.ida_1.node_2.conv.conv_offset_mask.weight dla_up.ida_1.node_2.conv.conv_offset_mask.bias dla_up.ida_2.proj_1.actf.0.weight dla_up.ida_2.proj_1.actf.0.bias dla_up.ida_2.proj_1.actf.0.running_mean dla_up.ida_2.proj_1.actf.0.running_var dla_up.ida_2.proj_1.actf.0.num_batches_tracked dla_up.ida_2.proj_1.conv.weight dla_up.ida_2.proj_1.conv.bias dla_up.ida_2.proj_1.conv.conv_offset_mask.weight dla_up.ida_2.proj_1.conv.conv_offset_mask.bias dla_up.ida_2.up_1.weight dla_up.ida_2.node_1.actf.0.weight dla_up.ida_2.node_1.actf.0.bias dla_up.ida_2.node_1.actf.0.running_mean dla_up.ida_2.node_1.actf.0.running_var dla_up.ida_2.node_1.actf.0.num_batches_tracked dla_up.ida_2.node_1.conv.weight dla_up.ida_2.node_1.conv.bias dla_up.ida_2.node_1.conv.conv_offset_mask.weight dla_up.ida_2.node_1.conv.conv_offset_mask.bias dla_up.ida_2.proj_2.actf.0.weight dla_up.ida_2.proj_2.actf.0.bias dla_up.ida_2.proj_2.actf.0.running_mean dla_up.ida_2.proj_2.actf.0.running_var dla_up.ida_2.proj_2.actf.0.num_batches_tracked dla_up.ida_2.proj_2.conv.weight dla_up.ida_2.proj_2.conv.bias dla_up.ida_2.proj_2.conv.conv_offset_mask.weight dla_up.ida_2.proj_2.conv.conv_offset_mask.bias dla_up.ida_2.up_2.weight dla_up.ida_2.node_2.actf.0.weight dla_up.ida_2.node_2.actf.0.bias dla_up.ida_2.node_2.actf.0.running_mean dla_up.ida_2.node_2.actf.0.running_var dla_up.ida_2.node_2.actf.0.num_batches_tracked dla_up.ida_2.node_2.conv.weight dla_up.ida_2.node_2.conv.bias dla_up.ida_2.node_2.conv.conv_offset_mask.weight 
dla_up.ida_2.node_2.conv.conv_offset_mask.bias dla_up.ida_2.proj_3.actf.0.weight dla_up.ida_2.proj_3.actf.0.bias dla_up.ida_2.proj_3.actf.0.running_mean dla_up.ida_2.proj_3.actf.0.running_var dla_up.ida_2.proj_3.actf.0.num_batches_tracked dla_up.ida_2.proj_3.conv.weight dla_up.ida_2.proj_3.conv.bias dla_up.ida_2.proj_3.conv.conv_offset_mask.weight dla_up.ida_2.proj_3.conv.conv_offset_mask.bias dla_up.ida_2.up_3.weight dla_up.ida_2.node_3.actf.0.weight dla_up.ida_2.node_3.actf.0.bias dla_up.ida_2.node_3.actf.0.running_mean dla_up.ida_2.node_3.actf.0.running_var dla_up.ida_2.node_3.actf.0.num_batches_tracked dla_up.ida_2.node_3.conv.weight dla_up.ida_2.node_3.conv.bias dla_up.ida_2.node_3.conv.conv_offset_mask.weight dla_up.ida_2.node_3.conv.conv_offset_mask.bias ida_up.proj_1.actf.0.weight ida_up.proj_1.actf.0.bias ida_up.proj_1.actf.0.running_mean ida_up.proj_1.actf.0.running_var ida_up.proj_1.actf.0.num_batches_tracked ida_up.proj_1.conv.weight ida_up.proj_1.conv.bias ida_up.proj_1.conv.conv_offset_mask.weight ida_up.proj_1.conv.conv_offset_mask.bias ida_up.up_1.weight ida_up.node_1.actf.0.weight ida_up.node_1.actf.0.bias ida_up.node_1.actf.0.running_mean ida_up.node_1.actf.0.running_var ida_up.node_1.actf.0.num_batches_tracked ida_up.node_1.conv.weight ida_up.node_1.conv.bias ida_up.node_1.conv.conv_offset_mask.weight ida_up.node_1.conv.conv_offset_mask.bias ida_up.proj_2.actf.0.weight ida_up.proj_2.actf.0.bias ida_up.proj_2.actf.0.running_mean ida_up.proj_2.actf.0.running_var ida_up.proj_2.actf.0.num_batches_tracked ida_up.proj_2.conv.weight ida_up.proj_2.conv.bias ida_up.proj_2.conv.conv_offset_mask.weight ida_up.proj_2.conv.conv_offset_mask.bias ida_up.up_2.weight ida_up.node_2.actf.0.weight ida_up.node_2.actf.0.bias ida_up.node_2.actf.0.running_mean ida_up.node_2.actf.0.running_var ida_up.node_2.actf.0.num_batches_tracked ida_up.node_2.conv.weight ida_up.node_2.conv.bias ida_up.node_2.conv.conv_offset_mask.weight ida_up.node_2.conv.conv_offset_mask.bias 
hm.0.weight hm.0.bias hm.2.weight hm.2.bias wh.0.weight wh.0.bias wh.2.weight wh.2.bias reg.0.weight reg.0.bias reg.2.weight reg.2.bias
- '所以要删除的名字有如下特点:'
- name[-7:]=='tracked'
- name[:6]=='dla_up'
- name[:6]=='ida_up'
- name[:2]=='hm'
- name[:2]=='wh'
- name[:3]=='reg'
最终生成与预训练模型结构相似的结构模型代码如下:
import torch
from torch import nn

# Name prefixes of CenterNet-only layers (upsampling necks and output heads)
# that are absent from the pretrained DLA-34 backbone.
_HEAD_PREFIXES = ('dla_up', 'ida_up', 'hm', 'wh', 'reg')


def strip_to_backbone(state_dict, prefix='base.'):
    """Return the backbone part of a CenterNet state_dict, renamed like dla34.

    Drops BatchNorm ``num_batches_tracked`` buffers and every entry whose name
    starts with one of ``_HEAD_PREFIXES``, then removes *prefix* from the
    remaining names (``base.level0.0.weight`` -> ``level0.0.weight``).
    Names without *prefix* are kept unchanged instead of being truncated.
    """
    backbone = {}
    for name, value in state_dict.items():
        if name.endswith('tracked') or name.startswith(_HEAD_PREFIXES):
            continue
        if name.startswith(prefix):
            name = name[len(prefix):]
        backbone[name] = value
    return backbone


if __name__ == '__main__':
    # map_location='cpu' lets this run on machines without CUDA.
    model_weights = torch.load('ctdet_pascal_dla_384.pth', map_location='cpu')
    for name in model_weights['state_dict']:
        print(name)
    new_model = strip_to_backbone(model_weights['state_dict'])
    for name in new_model:
        print(name)
    torch.save(new_model, 'pre_pascal_dla_384.pth')
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。