Training your own dataset from scratch with the PIDNet semantic segmentation model.
PIDNet paper: https://arxiv.org/pdf/2206.02066.pdf
PIDNet project: GitHub - XuJiacong/PIDNet: This is the official repository for our recent work: PIDNet
First, a word on the dataset format PIDNet needs: the semantic label images must be 8-bit grayscale (the same training format required by the BiSeNet tutorial we wrote earlier), with each segmentation class encoded as a gray value. The overall dataset folder layout looks like this:
PV is the name of our dataset. What we need to prepare are the four .lst mapping files under the list folder, plus the image (original images) and label (segmentation maps) folders under PV.
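In text form, the layout used throughout this tutorial is the following (these paths match the list-generation script below; substitute your own dataset name for PV):

data/
├── list/
│   └── PV/
│       ├── train.lst
│       ├── test.lst
│       ├── val.lst
│       └── trainval.lst
└── PV/
    ├── image/
    │   ├── train/
    │   ├── test/
    │   └── val/
    └── label/
        ├── train/
        └── val/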
1. First, annotate your images with labelme, convert the JSON files into segmentation images, and convert those into 8-bit grayscale images. These steps were covered in detail in my earlier post:
Tutorial: Training Your Own Dataset from Scratch with the BiSeNet (Semantic Segmentation) Network (CSDN blog)
Follow that method to obtain the 8-bit grayscale labels.
2. Place the original images and the 8-bit grayscale labels from step 1 into the folder layout shown above, then generate the .lst mapping files with the following script. Remember to change the paths to your own dataset:
import os


def write_list(list_path, *columns):
    # One line per sample; columns are space-separated dataset-relative paths.
    with open(list_path, 'w', encoding='utf-8') as f:
        for row in zip(*columns):
            f.write(' '.join(row) + '\n')


def op_file():
    # train: pair every image with its label (sorted so the two lists line up)
    train_image_list = ['image/train/' + n for n in sorted(os.listdir('data/PV/image/train'))]
    train_label_list = ['label/train/' + n for n in sorted(os.listdir('data/PV/label/train'))]
    write_list('data/list/PV/train.lst', train_image_list, train_label_list)

    # test: images only, no labels
    test_image_list = ['image/test/' + n for n in sorted(os.listdir('data/PV/image/test'))]
    write_list('data/list/PV/test.lst', test_image_list)

    # val
    val_image_list = ['image/val/' + n for n in sorted(os.listdir('data/PV/image/val'))]
    val_label_list = ['label/val/' + n for n in sorted(os.listdir('data/PV/label/val'))]
    write_list('data/list/PV/val.lst', val_image_list, val_label_list)

    # trainval = train followed by val
    write_list('data/list/PV/trainval.lst',
               train_image_list + val_image_list,
               train_label_list + val_label_list)


if __name__ == '__main__':
    op_file()
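Each line of train.lst, val.lst, and trainval.lst pairs an image path with its label path, for example (filenames here are made up): image/train/0001.png label/train/0001.png. test.lst holds one image path per line, with no labels.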
1. In the datasets folder, copy cityscapes.py (in the same directory) and rename it after your dataset, here PV.py:
Open PV.py and change every occurrence of Cityscapes to PV (your dataset's name); set num_classes=3 (your class count, background included; I have three classes here); update mean and std; update label_mapping (one entry per class); and update class_weights (how to compute these is shown below).
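For orientation, here is an illustrative sketch of the lines to touch; the structure follows the Cityscapes class this file was copied from, and every number below is a placeholder to be replaced with your own values (the mean/std/weights come from the script below):

# inside datasets/PV.py -- placeholder values only:
self.num_classes = 3                          # classes, background included
self.label_mapping = {0: 0, 1: 1, 2: 2}       # one entry per class ID
self.class_weights = torch.FloatTensor(
    [1.0, 1.0, 1.0]).cuda()                   # from the stats script below
# mean / std are constructor defaults, e.g.:
#   mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]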
The mean, std, and class_weights above must be computed from your own dataset; run the following script to get them:
import numpy as np
import os
import cv2


def get_weight(class_num, pixel_count):
    # Inverse-log frequency weighting, normalized to sum to class_num.
    W = 1 / np.log(pixel_count)
    W = class_num * W / np.sum(W)
    return W


def get_MeanStdWeight(class_num=3, size=(1080, 700)):
    image_path = "data/PV/image/train/"
    label_path = "data/PV/label/train/"

    namelist = os.listdir(image_path)
    """======== if your training-set names are instead stored in a txt file ========"""
    # file_name = "../datasets/train.txt"
    # with open(file_name, "r") as f:
    #     namelist = f.readlines()
    #     namelist = [file[:-1].split(",") for file in namelist]
    """=============================================================================="""

    MEAN = []
    STD = []
    pixel_count = np.zeros((class_num, 1))

    for i in range(len(namelist)):
        print(i, os.path.join(image_path, namelist[i]))

        # BGR -> RGB, then per-channel statistics
        image = cv2.imread(os.path.join(image_path, namelist[i]))[:, :, ::-1]
        image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)

        MEAN.append(np.mean(image, axis=(0, 1)))
        STD.append(np.std(image, axis=(0, 1)))

        # labels must use nearest-neighbor resizing so no new class IDs appear
        label = cv2.imread(os.path.join(label_path, namelist[i]), 0)
        label = cv2.resize(label, size, interpolation=cv2.INTER_NEAREST)

        for m in np.unique(label):
            pixel_count[m] += np.sum(label == m)

    MEAN = np.mean(MEAN, axis=0) / 255.0
    STD = np.mean(STD, axis=0) / 255.0

    weight = get_weight(class_num, pixel_count.T)
    print(MEAN)
    print(STD)
    print(weight)

    return MEAN, STD, weight


if __name__ == '__main__':
    get_MeanStdWeight()
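Copy the three printed arrays into PV.py: MEAN and STD go into the mean/std constructor defaults, and the weight vector goes into class_weights.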
2. In datasets/__init__.py, import the dataset we just created:
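A one-line sketch, mirroring how the repo imports cityscapes (this assumes the class inside PV.py is named PV; the alias after `as` must match the DATASET value you set in the yaml):

# datasets/__init__.py -- add next to the existing cityscapes/camvid imports
from .PV import PV as PV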
3. Open configs/cityscapes/pidnet_small_cityscapes.yaml (I picked the smallest model; choose whichever you like) and change the dataset name, dataset paths, class count, and pretrained-model path:
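An illustrative excerpt of the fields involved; the field names follow the repo's Cityscapes yaml, while the values are this tutorial's PV setup and are assumptions to adapt:

DATASET:
  DATASET: PV
  ROOT: data/
  NUM_CLASSES: 3
  TRAIN_SET: 'list/PV/train.lst'
  TEST_SET: 'list/PV/val.lst'
MODEL:
  NAME: pidnet_s
  PRETRAINED: 'pretrained_models/imagenet/PIDNet_S_ImageNet.pth.tar'
GPUS: (0,)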
4. Open models/pidnet.py and change num_classes in the PIDNet constructor calls to your class count:
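For example, something like the following; the constructor arguments here are my reading of the repo's small-variant instantiation, so verify them against your copy of models/pidnet.py:

# models/pidnet.py -- e.g. where the small variant is built:
model = PIDNet(m=2, n=3, num_classes=3, planes=32, ppm_planes=96,
               head_planes=128, augment=False)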
I trained on a single GPU, so remember to adjust the GPU count in the yaml. Start training with:
python tools/train.py --cfg configs/cityscapes/pidnet_small_cityscapes.yaml
I didn't run into any errors; if you hit one, leave a comment and I'll answer them one by one.
Note that PIDNet downsamples the input several times, so the training images must have suitable dimensions, or you will get dimension-mismatch errors; my training images are 1080x640. You can change the training-set size by cropping (or padding, sketched below); the cropping code is also in my earlier post (Tutorial: Training Your Own Dataset from Scratch with the BiSeNet (Semantic Segmentation) Network, CSDN blog).
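If you would rather pad than crop, here is a minimal sketch; it assumes a multiple of 32 is safe, matching the divisible-by-32 note in the video demo later, and pad_to_multiple is a hypothetical helper, not part of the repo:

import cv2

def pad_to_multiple(img, m=32):
    # zero-pad bottom/right so height and width become multiples of m;
    # pad label maps the same way, but fill with your ignore_label value
    h, w = img.shape[:2]
    return cv2.copyMakeBorder(img, 0, (-h) % m, 0, (-w) % m,
                              cv2.BORDER_CONSTANT, value=0)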
1. Image testing:
Before testing, point the yaml at the trained model you want to load:
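Concretely, set TEST.MODEL_FILE in the yaml, e.g. to 'output/PV/pidnet_small_cityscapes/best.pt' (my path; the video demo below defaults to this same checkpoint).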
Then run the test:
python tools/eval.py --cfg configs/cityscapes/pidnet_small_cityscapes.yaml
The test results are written to the output folder:
Note: at this point you will find the resulting images look black. The saved results are 8-bit grayscale, while what we want are 24-bit RGB images. The fix:
Open datasets/PV.py again (the file where we defined our own dataset) and add a color_list attribute.
I have three classes, so I just picked three colors (background included); choose whatever suits your needs. Then add a label2color function, shown after the snippet below.
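For example (the color values are arbitrary; the attribute name color_list is what the label2color code below expects):

# inside the PV class __init__ -- one RGB triple per class ID, background first
self.color_list = [[0, 0, 0], [0, 255, 0], [255, 0, 0]]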
The code:
# both methods belong to the PV class in datasets/PV.py
def label2color(self, label):
    # map each class ID in the label map to its RGB triple from color_list
    color_map = np.zeros(label.shape + (3,))
    for i, color in enumerate(self.color_list):
        color_map[label == i] = color
    return color_map.astype(np.uint8)


def save_pred(self, preds, sv_path, name):
    # take the argmax over class scores, colorize, and save each image as PNG
    preds = np.asarray(np.argmax(preds.cpu(), axis=1), dtype=np.uint8)
    for i in range(preds.shape[0]):
        pred = self.label2color(preds[i])
        save_img = Image.fromarray(pred)
        save_img.save(os.path.join(sv_path, name[i] + '.png'))
Re-run the test and the outputs will be RGB images.
2. Video testing:
The original repo does not provide a video test, so here is one I wrote:
import os
import pprint
import sys
sys.path.insert(0, '.')
import argparse
import torch
import torch.nn.functional as F
import torch.multiprocessing as mp
import time
import numpy as np
import cv2
import lib.data.transform_cv2 as T
from utils.utils import create_logger
from configs import config
from configs import update_config
torch.set_grad_enabled(False)
import torch.backends.cudnn as cudnn
import models

# args

parser = argparse.ArgumentParser()
parser.add_argument('--cfg', type=str, default='/home/zhangwei/Prg/Pycharm/FIles/PIDNet-main/configs/cityscapes/pidnet_small_cityscapes.yaml')
parser.add_argument('--weight-path', type=str, default='/home/zhangwei/Prg/Pycharm/FIles/PIDNet-main/output/PV/pidnet_small_cityscapes/best.pt')
parser.add_argument('--input', type=str, default='/home/zhangwei/Prg/Pycharm/FIles/PIDNet-main/test_dataset/video.avi')
parser.add_argument('--output', type=str, default='/home/zhangwei/Prg/Pycharm/FIles/PIDNet-main/test_dataset/PIDNet.mp4')
parser.add_argument('opts',
                    help="Modify config options using the command-line",
                    default=None,
                    nargs=argparse.REMAINDER)
args = parser.parse_args()
update_config(config, args)


# fetch frames from the input video and feed them to the inference queue
def get_func(inpth, in_q, done):
    cap = cv2.VideoCapture(inpth)

    to_tensor = T.ToTensor(
        mean=(0.3257, 0.3690, 0.3223),  # cityscapes rgb stats; replace with your dataset's
        std=(0.2112, 0.2148, 0.2115),
    )

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frame = frame[:, :, ::-1]  # BGR -> RGB
        frame = to_tensor(dict(im=frame, lb=None))['im'].unsqueeze(0)
        in_q.put(frame)

    in_q.put('quit')
    done.wait()

    cap.release()
    time.sleep(1)
    print('input queue done')


# colorize predictions and write them to the output video
def save_func(inpth, outpth, out_q):
    np.random.seed(123)
    palette = np.random.randint(0, 256, (256, 3), dtype=np.uint8)

    cap = cv2.VideoCapture(inpth)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)    # type is float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # type is float
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()

    video_writer = cv2.VideoWriter(outpth,
                                   cv2.VideoWriter_fourcc(*"mp4v"),
                                   fps, (int(width), int(height)))

    while True:
        out = out_q.get()
        if isinstance(out, str) and out == 'quit':
            break
        preds = palette[out.numpy()]
        for pred in preds:
            video_writer.write(pred)
    video_writer.release()
    print('output queue done')


# run inference on a list of frames as one batch
def infer_batch(frames):
    frames = torch.cat(frames, dim=0).cuda()
    H, W = frames.size()[2:]
    frames = F.interpolate(frames, size=(768, 768), mode='bilinear',
                           align_corners=False)  # must be divisible by 32
    out = model(frames)[0]
    out = F.interpolate(out, size=(H, W), mode='bilinear',
                        align_corners=False).argmax(dim=1).detach().cpu()
    out_q.put(out)


if __name__ == '__main__':

    logger, final_output_dir, _ = create_logger(
        config, args.cfg, 'test')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related settings
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # build model
    model = models.pidnet.get_seg_model(config, imgnet_pretrained=True)

    if config.TEST.MODEL_FILE:
        model_state_file = config.TEST.MODEL_FILE
    else:
        model_state_file = os.path.join(final_output_dir, 'best.pt')

    logger.info('=> loading model from {}'.format(model_state_file))

    # strip the 'model.' prefix the training wrapper adds to parameter names
    pretrained_dict = torch.load(model_state_file)
    if 'state_dict' in pretrained_dict:
        pretrained_dict = pretrained_dict['state_dict']
    model_dict = model.state_dict()
    pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()
                       if k[6:] in model_dict.keys()}
    for k, _ in pretrained_dict.items():
        logger.info('=> loading {} from pretrained model'.format(k))
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    mp.set_start_method('spawn')

    in_q = mp.Queue(1024)
    out_q = mp.Queue(1024)
    done = mp.Event()

    in_worker = mp.Process(target=get_func,
                           args=(args.input, in_q, done))
    out_worker = mp.Process(target=save_func,
                            args=(args.input, args.output, out_q))

    in_worker.start()
    out_worker.start()
    model.eval()
    model = model.cuda()

    # batch incoming frames in groups of 8
    frames = []
    while True:
        frame = in_q.get()
        if isinstance(frame, str) and frame == 'quit':
            break
        frames.append(frame)
        if len(frames) == 8:
            infer_batch(frames)
            frames = []
    if len(frames) > 0:
        infer_batch(frames)

    out_q.put('quit')
    done.set()

    out_worker.join()
    in_worker.join()
After changing the various file paths to your own, run:
python demo_video.py --cfg /home/zhangwei/Prg/Pycharm/FIles/PIDNet-main/configs/cityscapes/pidnet_small_cityscapes.yaml
Wait a little while, and the generated mp4 file will appear in the test_dataset directory.
That concludes the PIDNet training tutorial. If you have any questions, leave a comment and I'll answer them one by one.
If you get a module-not-found error when running demo_video (the missing module is transform_cv2), you can download it from the link below:
Link: https://pan.baidu.com/s/1vBHFSiGb_vXANGGX1NTI3A
Extraction code: i10j