赞
踩
包 | 版本 |
---|---|
mmcv-full | 1.4.2 |
mmdet | 2.19.1 |
torch | 1.10.0+cu113 |
torchvision | 0.11.1+cu113 |
打开mmdetection-master/tools
下面执行train.py
文件
其中配置文件--config
在 mmdetection-master/configs/deformable_detr/deformable_detr_r50_16x2_50e_coco.py
python train.py {path}/mmdetection-master/configs/deformable_detr/deformable_detr_r50_16x2_50e_coco.py
会报错,不用管(在work_dirs/deformable_detr_r50_16x2_50e_coco
生成需要的配置文件)
my_deformable_detr_r50_16x2_50e_coco.py文件修改
mmdetection-master/configs/deformable_detr/
并改名为my_deformable_detr_r50_16x2_50e_coco.py
修改mmdet源码
修改{path}/mmdet/core/evaluation/class_names.py
下面的coco_classes()
修改{path}/mmdet/datasets/coco.py
下面的CLASSES和PALETTE
有的源码里面没有PALETTE可不添加
再次进入mmdetection-master/tools
下面,执行下面代码
其中的my_deformable_detr_r50_16x2_50e_coco.py
是上面刚刚修改的文件名称
python train.py {path}/mmdetection-master/configs/deformable_detr/my_deformable_detr_r50_16x2_50e_coco.py
训练结束之后会生成相应的权重文件
打开{path}/mmdetection-master/demo/
文件夹执行image_demo.py
python image_demo.py 1.jpg {path}/mmdetection-master/configs/deformable_detr/my_deformable_detr_r50_16x2_50e_coco.py {path}/mmdetection-master/tools/work_dirs/deformable_detr_r50_16x2_50e_coco/latest.pth
由于我在ubuntu虚拟机上面进行的代码测试,无法使用 show_result_pyplot()
函数,稍作修改存储到对应的目录中
( 其中的第一个参数 img 修改成一个目录,可以直接进行对一个目录里面的文件读取并且处理后保存)
from argparse import ArgumentParser
from mmdet.apis import (inference_detector, init_detector)
import cv2
import os
def parse_args():
parser = ArgumentParser()
parser.add_argument('--img', default='img2', help='Image file')
parser.add_argument('--config', default='../configs/deformable_detr/my_deformable_detr_r50_16x2_50e_coco.py',help='Config file')
parser.add_argument('--checkpoint',default='../tools/work_dirs/deformable_detr_r50_16x2_50e_coco/latest.pth', help='Checkpoint file')
parser.add_argument('--device', default='cpu', help='Device used for inference')
parser.add_argument(
'--palette',
default='coco',
choices=['coco', 'voc', 'citys', 'random'],
help='Color palette used for visualization')
parser.add_argument(
'--score-thr', type=float, default=0.3, help='bbox score threshold')
parser.add_argument(
'--async-test',
action='store_true',
help='whether to set async options for async inference.')
args = parser.parse_args()
return args
def getfiles(file):
path_list = []
filenames = os.listdir(file)
print(filenames)
for filename in filenames:
a = os.path.join(file, filename)
# print(a)
path_list.append(a)
# print(path_list)
return path_list,filenames
def main(args):
model = init_detector(args.config, args.checkpoint, device=args.device)
# test a single image
path_list,filenames = getfiles(args.img)
for path,filename in zip(path_list,filenames):
result = inference_detector(model, path)
img = show_result_pyplot2(model, path, result, score_thr=0.8)
cv2.imwrite(args.img+"/out/out_"+filename, img)
def show_result_pyplot2(model, img, result, score_thr=0.3, fig_size=(15, 10)):
if hasattr(model, 'module'):
model = model.module
img = model.show_result(img, result, score_thr=score_thr, show=False)
return img
if __name__ == '__main__':
args = parse_args()
main(args)
预测结果还是比较准确的
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。