
[Tutorial] Training Your Own Dataset with the PIDNet Semantic Segmentation Model, from Scratch

Introduction

This post walks through training the semantic segmentation model PIDNet on your own dataset, from scratch.

PIDNet paper: https://arxiv.org/pdf/2206.02066.pdf

PIDNet project: https://github.com/XuJiacong/PIDNet (the official repository)

I. Dataset Preparation

First, a word on what kind of dataset PIDNet expects: its semantic label images are 8-bit grayscale maps (the same training format as the BiSeNet tutorial I wrote earlier), with each class encoded as a gray value. The overall dataset folder structure is described below.

PV is the name of my dataset. What we need to prepare are the four .lst mapping files under the list folder, plus the image (original images) and label (segmentation maps) folders under the PV folder.
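Concretely, the layout looks like this (reconstructed from the paths used by the list-generation script in step 2 below):

```
data
├── PV
│   ├── image
│   │   ├── train
│   │   ├── val
│   │   └── test
│   └── label
│       ├── train
│       ├── val
│       └── test
└── list
    └── PV
        ├── train.lst
        ├── val.lst
        ├── test.lst
        └── trainval.lst
```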

1. First, annotate the images with labelme, convert the json files to segmentation images, and convert those to 8-bit grayscale maps. These steps were covered in an earlier post; for details see:

教程--从零开始使用BiSeNet(语义分割)网络训练自己的数据集_计算机幻觉的博客-CSDN博客

Obtain the 8-bit grayscale maps following that method.

2. Place the original images and the 8-bit grayscale maps from step 1 into the folder structure above. Then generate the .lst mapping files with the following script (remember to adjust the paths to your own dataset):

```python
import os


def op_file():
    # train: pair each training image with its label
    train_image_root = 'image/train/'
    train_label_root = 'label/train/'
    train_image_path = 'data/PV/image/train'
    train_label_path = 'data/PV/label/train'
    # sort both listings so image/label pairs line up reliably
    trainImageList = sorted(os.listdir(train_image_path))
    trainLabelList = sorted(os.listdir(train_label_path))
    train_image_list = []
    for image in trainImageList:
        train_image_list.append(train_image_root + image)
    train_label_list = []
    for label in trainLabelList:
        train_label_list.append(train_label_root + label)
    train_list_path = 'data/list/PV/train.lst'
    with open(train_list_path, 'w', encoding='utf-8') as f:
        for i1, i2 in zip(train_image_list, train_label_list):
            print(i1, i2)
            f.write(i1 + " " + i2 + "\n")

    # test: images only, no labels
    test_image_root = 'image/test/'
    test_image_path = 'data/PV/image/test'
    testImageList = sorted(os.listdir(test_image_path))
    test_image_list = []
    for image in testImageList:
        test_image_list.append(test_image_root + image)
    test_list_path = 'data/list/PV/test.lst'
    with open(test_list_path, 'w', encoding='utf-8') as f:
        for i1 in test_image_list:
            f.write(i1 + "\n")

    # val
    val_image_root = 'image/val/'
    val_label_root = 'label/val/'
    val_image_path = 'data/PV/image/val'
    val_label_path = 'data/PV/label/val'
    valImageList = sorted(os.listdir(val_image_path))
    valLabelList = sorted(os.listdir(val_label_path))
    val_image_list = []
    for image in valImageList:
        val_image_list.append(val_image_root + image)
    val_label_list = []
    for label in valLabelList:
        val_label_list.append(val_label_root + label)
    val_list_path = 'data/list/PV/val.lst'
    with open(val_list_path, 'w', encoding='utf-8') as f:
        for i1, i2 in zip(val_image_list, val_label_list):
            f.write(i1 + " " + i2 + "\n")

    # trainval: train list followed by val list
    trainval_list_path = 'data/list/PV/trainval.lst'
    with open(trainval_list_path, 'w', encoding='utf-8') as f:
        for i1, i2 in zip(train_image_list, train_label_list):
            f.write(i1 + " " + i2 + "\n")
        for i1, i2 in zip(val_image_list, val_label_list):
            f.write(i1 + " " + i2 + "\n")


if __name__ == '__main__':
    op_file()
```
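After running it, each line of train.lst, val.lst, and trainval.lst pairs an image path with its label path, separated by a space, while test.lst contains image paths only. A couple of example lines from train.lst (file names hypothetical):

```
image/train/0001.png label/train/0001.png
image/train/0002.png label/train/0002.png
```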

II. Code Modifications

1. In the datasets folder, make a copy of cityscapes.py (same directory) and rename it after your dataset, here PV.py.

Open PV.py and replace every occurrence of Cityscapes with PV (your dataset's name); set num_classes = 3 (your class count, background included; mine is three); update mean and std; update label_mapping (one entry per class); and update class_weights (computed as described below).
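For orientation, here is a rough sketch of the edited pieces, assuming PV.py keeps the structure of cityscapes.py (all numeric values below are placeholders; the real ones come from the script that follows):

```python
# datasets/PV.py: sketch of the edited parts only; everything else stays as
# copied from cityscapes.py. All numeric values here are placeholders.
import torch

from .base_dataset import BaseDataset  # the same base class cityscapes.py uses


class PV(BaseDataset):
    def __init__(self, root, list_path, num_classes=3,  # classes incl. background
                 mean=[0.41, 0.43, 0.44],                # placeholder: script output
                 std=[0.21, 0.22, 0.22],                 # placeholder: script output
                 **kwargs):
        super().__init__(mean=mean, std=std, **kwargs)
        self.num_classes = num_classes
        # with 3 classes the mapping is simply the identity: gray value -> train id
        self.label_mapping = {0: 0, 1: 1, 2: 2}
        # placeholder weights; paste the script's output here
        self.class_weights = torch.FloatTensor([1.0, 1.0, 1.0]).cuda()
```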

Filling in these values requires your own dataset's mean, std, and class_weights; run the following script to compute them:

```python
import os

import cv2
import numpy as np


def get_weight(class_num, pixel_count):
    # inverse-log frequency weighting, normalized to sum to class_num
    W = 1 / np.log(pixel_count)
    W = class_num * W / np.sum(W)
    return W


def get_MeanStdWeight(class_num=3, size=(1080, 700)):
    image_path = "data/PV/image/train/"
    label_path = "data/PV/label/train/"
    namelist = os.listdir(image_path)

    # If your train split is stored in a txt file instead, build namelist from it:
    # with open("../datasets/train.txt", "r") as f:
    #     namelist = [line[:-1].split(",") for line in f.readlines()]

    MEAN = []
    STD = []
    pixel_count = np.zeros((class_num, 1))
    for i in range(len(namelist)):
        print(i, os.path.join(image_path, namelist[i]))
        # read as BGR, flip to RGB
        image = cv2.imread(os.path.join(image_path, namelist[i]))[:, :, ::-1]
        image = cv2.resize(image, size, interpolation=cv2.INTER_NEAREST)
        mean = np.mean(image, axis=(0, 1))
        std = np.std(image, axis=(0, 1))
        MEAN.append(mean)
        STD.append(std)
        # read the label as single-channel grayscale; resize with nearest
        # neighbour so no new (invalid) label values are interpolated in
        label = cv2.imread(os.path.join(label_path, namelist[i]), 0)
        label = cv2.resize(label, size, interpolation=cv2.INTER_NEAREST)
        for m in np.unique(label):
            pixel_count[m] += np.sum(label == m)

    MEAN = np.mean(MEAN, axis=0) / 255.0
    STD = np.mean(STD, axis=0) / 255.0
    weight = get_weight(class_num, pixel_count.T)
    print(MEAN)
    print(STD)
    print(weight)
    return MEAN, STD, weight


if __name__ == '__main__':
    get_MeanStdWeight()
```

2. In datasets/__init__.py, import the dataset we just created:
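Mirroring how the existing datasets are registered there (the exact import style may differ in your copy), the added line would be something like:

```python
# datasets/__init__.py: register the new dataset alongside cityscapes
from .PV import PV as PV
```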

3. Open configs/cityscapes/pidnet_small_cityscapes.yaml (I chose the smallest model; pick whichever you like) and change the dataset name, dataset paths, number of classes, and pretrained-model path:
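The relevant keys look roughly like this (a sketch following the structure of the repo's Cityscapes config; the values are this tutorial's examples):

```yaml
DATASET:
  DATASET: PV                    # must match the name registered in datasets/__init__.py
  ROOT: data/                    # directory containing PV/ and list/
  NUM_CLASSES: 3
  TRAIN_SET: list/PV/train.lst
  TEST_SET: list/PV/val.lst
MODEL:
  PRETRAINED: pretrained_models/imagenet/PIDNet_S_ImageNet.pth.tar  # pretrained weights
GPUS: (0,)                       # single GPU (see the training section)
```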

4. Open models/pidnet.py and change PIDNet's num_classes to your class count:
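In the copy I worked from, this meant editing the constructor default, roughly as below (signature abridged; check your own file, as it may differ between repo versions):

```python
import torch.nn as nn


class PIDNet(nn.Module):
    # change num_classes from the Cityscapes default (19) to your own count
    def __init__(self, m=2, n=3, num_classes=3, planes=64, ppm_planes=96,
                 head_planes=128, augment=True):
        ...
```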

III. Training

I trained on a single GPU; remember to set the GPU count in the yaml file accordingly. Start training with:

python tools/train.py --cfg configs/cityscapes/pidnet_small_cityscapes.yaml

I didn't run into any errors; if you hit one, leave a comment and I'll answer.

Note that PIDNet downsamples the input several times, so the training images must have suitable dimensions, or you'll get dimension-mismatch errors; my training images are 1080x640. You can crop the training set to an appropriate size; cropping code is also in my earlier post (教程--从零开始使用BiSeNet(语义分割)网络训练自己的数据集_计算机幻觉的博客-CSDN博客).
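Alternatively, a small helper like the following (my own sketch, not part of the repo) center-crops an image so both sides become multiples of 32, which sidesteps the dimension mismatch caused by repeated downsampling:

```python
import cv2


def crop_to_multiple(img, multiple=32):
    """Center-crop img so height and width are divisible by `multiple`."""
    h, w = img.shape[:2]
    new_h, new_w = h - h % multiple, w - w % multiple
    top, left = (h - new_h) // 2, (w - new_w) // 2
    return img[top:top + new_h, left:left + new_w]


# usage:
# img = crop_to_multiple(cv2.imread('some_training_image.png'))
```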

IV. Testing

1. Image testing:

Before testing, specify the trained model to load by editing the yaml file (the TEST.MODEL_FILE entry):

Run the evaluation:

python tools/eval.py --cfg configs/cityscapes/pidnet_small_cityscapes.yaml

The results are saved under the output folder.

Note: at this point you'll find that the resulting images look black. That's because the saved predictions are 8-bit grayscale label maps, while what we want to view are 24-bit RGB images. The fix:

Open datasets/PV.py again (the file where we defined our dataset) and add a color_list attribute. I have three classes here (background included), so I just picked three colors; choose whatever suits your needs.
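For example, a three-class color_list could look like this (one RGB triple per class id; the values are arbitrary):

```python
# inside PV.__init__: one color per class id, index 0 being the background
self.color_list = [[0, 0, 0], [0, 255, 0], [255, 0, 0]]
```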

Then add a label2color method, and update save_pred to use it:

```python
def label2color(self, label):
    # map each class id in the label map to its RGB color
    color_map = np.zeros(label.shape + (3,))
    for i, color in enumerate(self.color_list):
        color_map[label == i] = color
    return color_map.astype(np.uint8)

def save_pred(self, preds, sv_path, name):
    preds = np.asarray(np.argmax(preds.cpu(), axis=1), dtype=np.uint8)
    for i in range(preds.shape[0]):
        pred = self.label2color(preds[i])
        save_img = Image.fromarray(pred)
        save_img.save(os.path.join(sv_path, name[i] + '.png'))
```

Re-run the test and the saved outputs will be RGB images.

2. Video testing:

The source repository does not provide a video test script, so here is one I put together:

```python
import os
import pprint
import sys
sys.path.insert(0, '.')
import argparse
import time

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn

import lib.data.transform_cv2 as T
from utils.utils import create_logger
from configs import config
from configs import update_config
import models

torch.set_grad_enabled(False)

# args
parser = argparse.ArgumentParser()
parser.add_argument('--cfg', type=str, default='/home/zhangwei/Prg/Pycharm/FIles/PIDNet-main/configs/cityscapes/pidnet_small_cityscapes.yaml')
parser.add_argument('--weight-path', type=str, default='/home/zhangwei/Prg/Pycharm/FIles/PIDNet-main/output/PV/pidnet_small_cityscapes/best.pt')
parser.add_argument('--input', type=str, default='/home/zhangwei/Prg/Pycharm/FIles/PIDNet-main/test_dataset/video.avi')
parser.add_argument('--output', type=str, default='/home/zhangwei/Prg/Pycharm/FIles/PIDNet-main/test_dataset/PIDNet.mp4')
parser.add_argument('opts',
                    help="Modify config options using the command-line",
                    default=None,
                    nargs=argparse.REMAINDER)
args = parser.parse_args()
update_config(config, args)


# producer: fetch frames from the input video and feed them into in_q
def get_func(inpth, in_q, done):
    cap = cv2.VideoCapture(inpth)
    to_tensor = T.ToTensor(
        mean=(0.3257, 0.3690, 0.3223),  # city, rgb
        std=(0.2112, 0.2148, 0.2115),
    )
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frame = frame[:, :, ::-1]  # BGR -> RGB
        frame = to_tensor(dict(im=frame, lb=None))['im'].unsqueeze(0)
        in_q.put(frame)
    in_q.put('quit')
    done.wait()
    cap.release()
    time.sleep(1)
    print('input queue done')


# consumer: colorize predictions from out_q and write them to a video
def save_func(inpth, outpth, out_q):
    np.random.seed(123)
    palette = np.random.randint(0, 256, (256, 3), dtype=np.uint8)
    cap = cv2.VideoCapture(inpth)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)    # type is float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # type is float
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    video_writer = cv2.VideoWriter(outpth,
                                   cv2.VideoWriter_fourcc(*"mp4v"),
                                   fps, (int(width), int(height)))
    while True:
        out = out_q.get()
        if isinstance(out, str):  # 'quit' sentinel
            break
        out = out.numpy()
        preds = palette[out]
        for pred in preds:
            video_writer.write(pred)
    video_writer.release()
    print('output queue done')


# run inference on a list of frames
def infer_batch(frames):
    frames = torch.cat(frames, dim=0).cuda()
    H, W = frames.size()[2:]
    frames = F.interpolate(frames, size=(768, 768), mode='bilinear',
                           align_corners=False)  # must be divisible by 32
    out = model(frames)[0]
    out = F.interpolate(out, size=(H, W), mode='bilinear',
                        align_corners=False).argmax(dim=1).detach().cpu()
    out_q.put(out)


if __name__ == '__main__':
    logger, final_output_dir, _ = create_logger(config, args.cfg, 'test')
    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related settings
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    # build the model and load the trained weights
    model = models.pidnet.get_seg_model(config, imgnet_pretrained=True)
    if config.TEST.MODEL_FILE:
        model_state_file = config.TEST.MODEL_FILE
    else:
        model_state_file = os.path.join(final_output_dir, 'best.pt')
    logger.info('=> loading model from {}'.format(model_state_file))
    pretrained_dict = torch.load(model_state_file)
    if 'state_dict' in pretrained_dict:
        pretrained_dict = pretrained_dict['state_dict']
    model_dict = model.state_dict()
    # strip the 'model.' prefix saved by the training wrapper
    pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()
                       if k[6:] in model_dict.keys()}
    for k, _ in pretrained_dict.items():
        logger.info('=> loading {} from pretrained model'.format(k))
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    mp.set_start_method('spawn')
    in_q = mp.Queue(1024)
    out_q = mp.Queue(1024)
    done = mp.Event()
    in_worker = mp.Process(target=get_func, args=(args.input, in_q, done))
    out_worker = mp.Process(target=save_func, args=(args.input, args.output, out_q))
    in_worker.start()
    out_worker.start()

    model.eval()
    model = model.cuda()
    frames = []
    while True:
        frame = in_q.get()
        if isinstance(frame, str):  # 'quit' sentinel from the producer
            break
        frames.append(frame)
        if len(frames) == 8:
            infer_batch(frames)
            frames = []
    if len(frames) > 0:
        infer_batch(frames)
    out_q.put('quit')
    done.set()
    out_worker.join()
    in_worker.join()
```

Adjust the various file paths for your own setup, then run:

python demo_video.py --cfg /home/zhangwei/Prg/Pycharm/FIles/PIDNet-main/configs/cityscapes/pidnet_small_cityscapes.yaml

After a short wait, the generated mp4 file appears in the test_dataset directory.

That concludes the PIDNet training tutorial. If you have any questions, leave a comment and I'll answer them one by one.

Update

If demo_video.py fails with a module-not-found error (the missing module is lib.data.transform_cv2), you can download it from the link below:

Link: https://pan.baidu.com/s/1vBHFSiGb_vXANGGX1NTI3A
Extraction code: i10j
