
A Quick-Start Tutorial for Semantic Segmentation (on the mmsegmentation platform)

Introduction

A recent project of mine needed semantic segmentation. As with object detection, the workflow comes down to annotating a dataset, preparing training data, training a model, and running inference. I found a good platform for all of this: mmsegmentation, from MMLab (the joint lab of the Chinese University of Hong Kong and SenseTime), a framework that covers object detection, semantic segmentation, and other deep learning tasks, and lets even beginners quickly reproduce the models and network architectures from papers and put them to use. As the saying goes, to do good work one must first sharpen one's tools, so let's start by setting up the environment!

Contents

Introduction

I. Environment Setup

1. Install CUDA and cuDNN

2. Install PyTorch (GPU)

3. Install MMCV

4. Install mmsegmentation

5. Verify the installation

II. Building the Dataset

1. Set up the annotation environment

2. Convert annotations and split the dataset

III. Model Training

1. Configuration files

2. Notes and pitfalls

IV. Inference and Prediction

1. Single-image inference

2. Video inference

3. Camera inference

Acknowledgements


I. Environment Setup

Reference: Get started: Install and Run MMSeg

1. Install CUDA and cuDNN

I have two laptops: a Tianxuan 2 with an RTX 3060 (6 GB) and a Dell with a GTX 1050 Ti. After installing the GPU driver, nvidia-smi shows the highest CUDA version the driver supports. On my machine that is CUDA 12.2, while the current pytorch-gpu builds only go up to CUDA 12.1, so for compatibility I recommend installing a CUDA version no higher than 12.1.

There is a table mapping driver versions to the CUDA versions they support; to install any CUDA 12.x release, the driver must be no older than 525.60.13.
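
A quick way to check is shown below; the header line is only illustrative (your driver and CUDA numbers will differ):

    nvidia-smi
    # the first line of the header reports the driver and the highest supported CUDA version, e.g.:
    # NVIDIA-SMI 535.104.05    Driver Version: 535.104.05    CUDA Version: 12.2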

Download the matching packages from the official CUDA Toolkit and cuDNN Archive pages.

I installed CUDA 12.1 and cuDNN 8.9.2. The Tianxuan 2 previously had CUDA 12.0 on it, and setting up the mmsegmentation environment kept failing; nothing worked until I uninstalled the old CUDA and did a clean reinstall. I hope you can avoid this pitfall!

    wget https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run
    sudo sh cuda_12.1.0_530.30.02_linux.run

Everything except the bundled driver can be installed as usual. After installation, append the CUDA environment variables to the end of ~/.zshrc or ~/.bashrc and reload the file:

    sudo gedit ~/.zshrc   # or ~/.bashrc
    # append the following lines at the end of the file:
    export PATH=/usr/local/cuda-12.1/bin:$PATH
    export LD_LIBRARY_PATH=/usr/local/cuda-12.1/lib64:$LD_LIBRARY_PATH
    export CUDA_HOME=/usr/local/cuda
    # then reload it:
    source ~/.zshrc

Unpack the downloaded cuDNN archive, copy its library files into the matching CUDA directories, and make them readable (see this installation guide for details):

    tar -xf cudnn-linux-x86_64-8.9.2.26_cuda12-archive.tar.xz
    cd cudnn-linux-x86_64-8.9.2.26_cuda12-archive
    sudo cp -d lib/* /usr/local/cuda-12.1/lib64/
    sudo cp include/* /usr/local/cuda-12.1/include/
    sudo chmod a+r /usr/local/cuda-12.1/include/cudnn.h /usr/local/cuda-12.1/lib64/libcudnn*

Once everything is installed, check the installed versions with:

    cat /usr/local/cuda-12.1/include/cudnn_version.h | grep CUDNN_MAJOR -A 2
    nvcc --version

2. Install PyTorch (GPU)

After creating a virtual environment, pick the pytorch-gpu build that matches your machine. My two laptops ended up with slightly different environments: the Tianxuan 2 runs pytorch-gpu 2.2, while the Dell runs pytorch-gpu 2.1.0. When I tried to replicate the Tianxuan 2 setup on the Dell it failed (details below), so I had to fall back to an older PyTorch.

    conda create --name mmsegmentation python=3.8
    conda activate mmsegmentation
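
The exact install command depends on your CUDA version. As a sketch, the cu121 builds listed for the Dell laptop below can be installed like this (use the selector on the official PyTorch "Get Started" page to generate the right command for your own setup):

    pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121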

Environment on the RTX 3060 laptop (Tianxuan 2):

    torch           2.2.2
    torchaudio      2.2.2
    torchvision     0.17.2
    mmcv            2.1.0
    mmengine        0.10.3
    mmsegmentation  1.2.2    /home/mmsegmentation
    numpy           1.24.4
    onnxruntime     1.15.1
    opencv-python   4.9.0.80
    openmim         0.3.9

Environment on the GTX 1050 Ti Dell laptop:

    torch           2.1.0+cu121
    torchaudio      2.1.0+cu121
    torchvision     0.16.0+cu121
    mmcv            2.1.0
    mmengine        0.10.3
    mmsegmentation  1.2.2    /home/mmsegmentation
    numpy           1.23.5
    onnx            1.4.1
    onnxruntime     1.18.1
    opencv-python   4.7.0.72
    openmim         0.3.9

3. Install MMCV

See the official MMCV installation guide.

    pip install -U openmim
    mim install mmcv==2.1.0

When installing mmcv, pin the version to 2.1.0; do not run a bare mim install mmcv (pitfall!). A bare install pulls the latest release, which may be incompatible with the rest of the environment. When I set up the Tianxuan 2 at the end of April, mmcv 2.1.0 installed without any problems. But by the time I set up the Dell, mmcv had moved on to 2.2.0, and since numpy defaulted to 1.24.x I hit the module 'numpy' has no attribute 'object' [closed] error. Downgrading numpy to 1.23.5 still left the error shown below, apparently an mmcv 2.2.0 incompatibility. I tried removing and recreating the conda environment several times without success. In the end I downgraded PyTorch to 2.1.0 and redid the whole installation, following this guide, before it finally worked:

pip install mmcv==2.1.0 -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.1/index.html

    Traceback (most recent call last):
      File "demo/image_demo.py", line 6, in <module>
        from mmseg.apis import inference_model, init_model, show_result_pyplot
      File "/root/mmsegmentation/mmseg/__init__.py", line 61, in <module>
        assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \
    AssertionError: MMCV==2.2.0 is used but incompatible. Please install mmcv>=2.0.0rc4.
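
After the reinstall it is worth confirming that the pinned versions actually took effect before going further:

    python -c "import torch, mmcv; print(torch.__version__, mmcv.__version__)"
    # expected here: 2.1.0+cu121 2.1.0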

4. Install mmsegmentation

    git clone -b main https://github.com/open-mmlab/mmsegmentation.git
    cd mmsegmentation
    pip install -v -e .
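
Before downloading a model for the full demo in the next step, a quick import check confirms the editable install works:

    python -c "import mmseg; print(mmseg.__version__)"
    # should print 1.2.2 for the checkout used in this post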

5. Verify the installation

    mim download mmsegmentation --config pspnet_r50-d8_4xb2-40k_cityscapes-512x1024 --dest .
    python demo/image_demo.py demo/demo.png configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth --device cuda:0 --out-file result.jpg

II. Building the Dataset

1. Set up the annotation environment

For annotation I recommend X-AnyLabeling, which bundles object detection and semantic segmentation models for fast image labeling. It is best to give it its own conda virtual environment; install and configure it as follows:

    git clone https://github.com/CVHub520/X-AnyLabeling.git
    cd X-AnyLabeling
    # upgrade pip to its latest version
    pip install -U pip
    pip install -r requirements-gpu-dev.txt
    python anylabeling/app.py

You can find the model weight packages under X-AnyLabeling v0.2.0 or in the X-AnyLabeling model zoo; download them in advance and follow the "load built-in models" guide to set up the related config files. With SAM (Segment Anything Model, Meta's segment-anything model) you can then auto-annotate most scenes.

Download the weight file and its matching .yaml config, place them under the model directory, point encoder_model_path and decoder_model_path to your local weight paths, then choose "Load Custom Model" to use it.

The UI can be launched at any time with:

python3 anylabeling/app.py 

A few reminders when annotating:

1. Turn off Save With Image Data, so the image bytes are not embedded in the .json files.

2. Turn on Save Automatically.

3. Keep the generated .json files in the same directory as the images.

A generated .json file looks like this:

    {
      "version": "2.3.5",
      "flags": {},
      "shapes": [
        {
          "label": "watermelon",
          "points": [
            [329.0, 12.0], [329.0, 31.0], [330.0, 32.0], [330.0, 33.0], [329.0, 34.0],
            [329.0, 36.0], [330.0, 37.0], [330.0, 58.0], [331.0, 59.0], [331.0, 64.0],
            [332.0, 65.0], [348.0, 65.0], [349.0, 64.0], [350.0, 64.0], [351.0, 65.0],
            [359.0, 65.0], [360.0, 64.0], [363.0, 64.0], [364.0, 65.0], [370.0, 65.0],
            [371.0, 64.0], [373.0, 64.0], [374.0, 65.0], [376.0, 65.0], [377.0, 64.0],
            [378.0, 65.0], [392.0, 65.0], [393.0, 66.0], [394.0, 66.0], [396.0, 64.0],
            [396.0, 62.0], [397.0, 61.0], [397.0, 54.0], [398.0, 53.0], [398.0, 48.0],
            [399.0, 47.0], [399.0, 43.0], [400.0, 42.0], [400.0, 38.0], [401.0, 37.0],
            [401.0, 29.0], [404.0, 26.0], [404.0, 25.0], [405.0, 24.0], [405.0, 19.0],
            [404.0, 18.0], [404.0, 17.0], [403.0, 16.0], [403.0, 15.0], [402.0, 14.0],
            [402.0, 13.0], [400.0, 11.0], [400.0, 10.0], [399.0, 10.0], [398.0, 9.0],
            [391.0, 9.0], [390.0, 8.0], [382.0, 8.0], [381.0, 9.0], [379.0, 9.0],
            [378.0, 8.0], [376.0, 8.0], [375.0, 9.0], [374.0, 9.0], [373.0, 8.0],
            [371.0, 8.0], [370.0, 9.0], [368.0, 9.0], [367.0, 8.0], [364.0, 8.0],
            [363.0, 9.0], [362.0, 8.0], [360.0, 8.0], [359.0, 9.0], [356.0, 9.0],
            [355.0, 8.0], [354.0, 9.0], [348.0, 9.0], [347.0, 10.0], [345.0, 10.0],
            [344.0, 9.0], [343.0, 9.0], [342.0, 10.0], [337.0, 10.0], [336.0, 11.0],
            [334.0, 11.0], [333.0, 10.0], [332.0, 10.0], [330.0, 12.0]
          ],
          "group_id": null,
          "description": "",
          "difficult": false,
          "shape_type": "polygon",
          "flags": {},
          "attributes": {}
        },
        {
          "label": "lawn",
          "points": [
            [0.0, 0.0], [0.0, 115.0], [2.0, 116.0], [13.0, 138.0], [24.0, 150.0],
            [35.0, 157.0], [52.0, 160.0], [76.0, 159.0], [83.0, 152.0], [89.0, 143.0],
            [93.0, 130.0], [92.0, 128.0], [93.0, 120.0], [95.0, 118.0], [100.0, 118.0],
            [109.0, 122.0], [123.0, 122.0], [138.0, 132.0], [150.0, 131.0], [161.0, 124.0],
            [164.0, 125.0], [211.0, 124.0], [218.0, 126.0], [226.0, 134.0], [229.0, 135.0],
            [232.0, 139.0], [237.0, 142.0], [248.0, 143.0], [256.0, 140.0], [267.0, 130.0],
            [270.0, 120.0], [274.0, 115.0], [279.0, 112.0], [286.0, 111.0], [288.0, 109.0],
            [293.0, 109.0], [294.0, 108.0], [292.0, 104.0], [293.0, 100.0], [298.0, 101.0],
            [297.0, 105.0], [298.0, 106.0], [311.0, 102.0], [311.0, 101.0], [304.0, 101.0],
            [301.0, 96.0], [293.0, 98.0], [290.0, 95.0], [290.0, 92.0], [288.0, 89.0],
            [289.0, 86.0], [288.0, 84.0], [289.0, 81.0], [288.0, 51.0], [284.0, 46.0],
            [232.0, 22.0], [227.0, 21.0], [208.0, 11.0], [203.0, 10.0], [194.0, 5.0],
            [182.0, 2.0], [180.0, 0.0]
          ],
          "group_id": null,
          "description": "",
          "difficult": false,
          "shape_type": "polygon",
          "flags": {},
          "attributes": {}
        }
      ],
      "imagePath": "2.jpg",
      "imageData": null,
      "imageHeight": 480,
      "imageWidth": 640,
      "text": ""
    }

2. Convert annotations and split the dataset

Next, convert the annotations into integer-mask format. You can start from Zihao's Label2Everything code and the companion Bilibili tutorial; I modified the code slightly, and you can run my version below. Dataset_Path is the folder holding the annotation output: the script expects an img_dir/ subfolder with the images and a labelme_jsons/ subfolder with the .json files. After adjusting the path and the class list, it batch-converts everything to masks and splits them into training and validation sets.

    import os
    import json
    import random
    import shutil

    import numpy as np
    import cv2
    from tqdm import tqdm

    Dataset_Path = '/home/labelme/examples/garden'

    # The classes and the order in which their masks are drawn
    # (from large to small, from coarse to fine).
    # Classes of type 'line'/'linestrip'/'circle' also need a 'thickness' key.
    class_info = [
        {'label': 'dog', 'type': 'polygon', 'color': 1},     # polygon annotation
        {'label': 'person', 'type': 'polygon', 'color': 2},
    ]

    def labelme2mask_single_img(img_path, labelme_json_path):
        '''Take an image path and its labelme JSON path, return the mask.'''
        img_bgr = cv2.imread(img_path)
        img_mask = np.zeros(img_bgr.shape[:2])  # blank image; 0 = background
        with open(labelme_json_path, 'r', encoding='utf-8') as f:
            labelme = json.load(f)
        for one_class in class_info:        # iterate over classes in drawing order
            for each in labelme['shapes']:  # find the annotations of the current class
                if each['label'] == one_class['label']:
                    if one_class['type'] == 'polygon':  # polygon annotation
                        points = [np.array(each['points'], dtype=np.int32).reshape((-1, 1, 2))]
                        # draw the mask on the blank image (closed region)
                        img_mask = cv2.fillPoly(img_mask, points, color=one_class['color'])
                    elif one_class['type'] in ('line', 'linestrip'):  # line / polyline annotation
                        points = [np.array(each['points'], dtype=np.int32).reshape((-1, 1, 2))]
                        # draw the mask on the blank image (open polyline)
                        img_mask = cv2.polylines(img_mask, points, isClosed=False,
                                                 color=one_class['color'],
                                                 thickness=one_class['thickness'])
                    elif one_class['type'] == 'circle':  # circle annotation
                        points = np.array(each['points'], dtype=np.int32)
                        center_x, center_y = points[0][0], points[0][1]  # circle center
                        edge_x, edge_y = points[1][0], points[1][1]      # a point on the circle
                        radius = np.linalg.norm(np.array([center_x, center_y]) -
                                                np.array([edge_x, edge_y])).astype('uint32')
                        img_mask = cv2.circle(img_mask, (center_x, center_y), radius,
                                              one_class['color'], one_class['thickness'])
                    else:
                        print('unknown annotation type', one_class['type'])
        return img_mask

    # convert every image
    os.chdir(Dataset_Path)
    os.mkdir('ann_dir')
    os.chdir('img_dir')
    for img_path in tqdm(os.listdir()):
        try:
            labelme_json_path = os.path.join('../', 'labelme_jsons',
                                             '.'.join(img_path.split('.')[:-1]) + '.json')
            img_mask = labelme2mask_single_img(img_path, labelme_json_path)
            mask_path = img_path.split('.')[0] + '.png'
            cv2.imwrite(os.path.join('../', 'ann_dir', mask_path), img_mask)
        except Exception as E:
            print(img_path, 'conversion failed', E)

    # split into train / val sets
    os.chdir(Dataset_Path)
    os.mkdir('train')
    os.mkdir('val')
    test_frac = 0.2   # fraction of the data used for validation
    random.seed(123)  # fixed seed so the split is reproducible
    folder = 'img_dir'
    img_paths = os.listdir(folder)
    random.shuffle(img_paths)                     # shuffle
    val_number = int(len(img_paths) * test_frac)  # number of validation files
    train_files = img_paths[val_number:]          # training file names
    val_files = img_paths[:val_number]            # validation file names
    print('total files:', len(img_paths))
    print('training files:', len(train_files))
    print('validation files:', len(val_files))
    for each in tqdm(train_files):
        shutil.move(os.path.join(folder, each), os.path.join('train', each))
    for each in tqdm(val_files):
        shutil.move(os.path.join(folder, each), os.path.join('val', each))
    shutil.move('train', 'img_dir/train')
    shutil.move('val', 'img_dir/val')

    # move the masks the same way
    folder = 'ann_dir'
    os.mkdir('train')
    os.mkdir('val')
    for each in tqdm(train_files):
        mask_name = each.split('.')[0] + '.png'
        shutil.move(os.path.join(folder, mask_name), os.path.join('train', mask_name))
    for each in tqdm(val_files):
        mask_name = each.split('.')[0] + '.png'
        shutil.move(os.path.join(folder, mask_name), os.path.join('val', mask_name))
    shutil.move('train', 'ann_dir/train')
    shutil.move('val', 'ann_dir/val')
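
For reference: assuming Dataset_Path initially contains an img_dir/ folder with the images and a labelme_jsons/ folder with the .json annotations, the script leaves behind exactly the layout mmsegmentation expects:

    Dataset_Path/
    ├── img_dir/
    │   ├── train/   # training images (.jpg)
    │   └── val/     # validation images (.jpg)
    └── ann_dir/
        ├── train/   # training masks (.png, integer class IDs)
        └── val/     # validation masks (.png, integer class IDs)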

III. Model Training

Before starting to train, I carefully and repeatedly read a few articles and configured the parameter files following their steps:

"Super detailed! A hands-on walkthrough of running a semantic segmentation dataset with MMSegmentation" (超详细!手把手带你轻松用 MMSegmentation 跑语义分割数据集)

"[Python] A tutorial on the mmSegmentation semantic segmentation framework (1.x)" (【Python】mmSegmentation语义分割框架教程(1.x版本))

"Training your own dataset with mmsegmentation" (mmsegmentation 训练自己的数据集)

But training still failed with a strange error (KeyError: 'dataset_type is not in the mmseg::dataset registry'); I reported the details in a GitHub issue.

In the end I followed 同济子豪兄 (Tongji Zihao)'s tutorial "Finish your AI graduation project in two days: semantic segmentation" and trained successfully. Many thanks for his open-source work (code link); it saves a lot of detours.

1. Configuration files

Pick the network architecture that suits your task and configure the corresponding files. My example uses these four:

  • mmsegmentation/mmseg/datasets/watermelon_dataset.py
  • mmsegmentation/mmseg/datasets/__init__.py
  • mmsegmentation/configs/_base_/datasets/watermelon_segmentation_pipeline.py
  • mmsegmentation/configs/pspnet/pspnet_r50-d8_4xb2-40k_watermelon_segmen-512x1024.py

The first file, mmsegmentation/mmseg/datasets/watermelon_dataset.py:

You can name the dataset class whatever you like; define its classes and set the palette color for each class:

    import mmengine.fileio as fileio

    from mmseg.registry import DATASETS
    from .basesegdataset import BaseSegDataset


    @DATASETS.register_module()
    class WatermelonSegmentationDataset(BaseSegDataset):
        METAINFO = dict(
            classes=('background', 'red', 'green', 'white', 'seed-black', 'seed-white'),
            palette=[[127, 127, 127], [200, 0, 0], [0, 200, 0],
                     [144, 238, 144], [30, 30, 30], [251, 189, 8]])

        def __init__(self,
                     img_suffix='.jpg',
                     seg_map_suffix='.png',
                     reduce_zero_label=False,
                     **kwargs) -> None:
            super().__init__(
                img_suffix=img_suffix,
                seg_map_suffix=seg_map_suffix,
                reduce_zero_label=reduce_zero_label,
                **kwargs)
            assert fileio.exists(
                self.data_prefix['img_path'], backend_args=self.backend_args)

In the second file, mmsegmentation/mmseg/datasets/__init__.py, import the new class at the top of the file and append its name to the __all__ list at the end:

    __all__ = [
        'BaseSegDataset', 'BioMedical3DRandomCrop', 'BioMedical3DRandomFlip',
        'CityscapesDataset', 'PascalVOCDataset', 'ADE20KDataset',
        'PascalContextDataset', 'PascalContextDataset59', 'ChaseDB1Dataset',
        'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'DarkZurichDataset',
        'NightDrivingDataset', 'COCOStuffDataset', 'LoveDADataset',
        'MultiImageMixDataset', 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset',
        'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
        'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
        'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
        'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
        'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
        'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
        'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
        'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
        'SynapseDataset', 'REFUGEDataset', 'MapillaryDataset_v1',
        'MapillaryDataset_v2', 'Albu', 'LEVIRCDDataset',
        'LoadMultipleRSImageFromFile', 'LoadSingleRSImageFromFile',
        'ConcatCDInput', 'BaseCDDataset', 'DSDLSegDataset', 'BDD100KDataset',
        'NYUDataset', 'HSIDrive20Dataset', 'WatermelonSegmentationDataset'
    ]
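
To confirm the new class is actually registered, a quick check I find useful is shown below; module_dict is an mmengine Registry internal, so treat this as a sanity sketch rather than an official API:

    python -c "from mmseg.registry import DATASETS; print('WatermelonSegmentationDataset' in DATASETS.module_dict)"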

The third file, mmsegmentation/configs/_base_/datasets/watermelon_segmentation_pipeline.py, holds the dataset and preprocessing settings. Adjust data_root and crop_size to your case; the rest can stay at the defaults.

    # dataset settings
    dataset_type = 'WatermelonSegmentationDataset'
    # dataset root (relative to the mmsegmentation main directory)
    data_root = '/home/deep_learning_collection/mmsegmentation/data/watermelon/'
    # crop size fed to the model, usually a multiple of 128; smaller crops need less GPU memory
    crop_size = (512, 512)
    train_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations'),
        dict(
            type='RandomResize',  # ratio_range is a RandomResize argument (cf. the dumped config in III.2)
            # scale=(720, 1280),
            scale=(2048, 1024),
            ratio_range=(0.5, 2.0),
            keep_ratio=True),
        dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
        dict(type='RandomFlip', prob=0.5),
        dict(type='PhotoMetricDistortion'),
        dict(type='PackSegInputs')
    ]
    test_pipeline = [
        dict(type='LoadImageFromFile'),
        # dict(type='Resize', scale=(720, 1280), keep_ratio=True),
        dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
        # load annotations after ``Resize`` because the ground truth
        # does not need the resize transform
        dict(type='LoadAnnotations'),
        dict(type='PackSegInputs')
    ]
    img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
    tta_pipeline = [
        dict(type='LoadImageFromFile', backend_args=None),
        dict(
            type='TestTimeAug',
            transforms=[
                [
                    dict(type='Resize', scale_factor=r, keep_ratio=True)
                    for r in img_ratios
                ],
                [
                    dict(type='RandomFlip', prob=0., direction='horizontal'),
                    dict(type='RandomFlip', prob=1., direction='horizontal')
                ],
                [dict(type='LoadAnnotations')],
                [dict(type='PackSegInputs')]
            ])
    ]
    train_dataloader = dict(
        batch_size=4,
        num_workers=4,
        persistent_workers=True,
        sampler=dict(type='InfiniteSampler', shuffle=True),
        dataset=dict(
            type=dataset_type,   # the variable, NOT the string 'dataset_type'
            data_root=data_root,
            data_prefix=dict(
                img_path='img_dir/train', seg_map_path='ann_dir/train'),
            pipeline=train_pipeline))
    val_dataloader = dict(
        batch_size=1,
        num_workers=4,
        persistent_workers=True,
        sampler=dict(type='DefaultSampler', shuffle=False),
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            data_prefix=dict(
                img_path='img_dir/val', seg_map_path='ann_dir/val'),
            pipeline=test_pipeline))
    test_dataloader = val_dataloader
    # val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'], ignore_index=2)
    val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU', 'mDice', 'mFscore'])
    test_evaluator = val_evaluator

The fourth file, mmsegmentation/configs/pspnet/pspnet_r50-d8_4xb2-40k_watermelon_segmen-512x1024.py, pulls in the network model and the files configured above:

    # _base_ = [
    #     '../_base_/models/pspnet_r50-d8.py', '../_base_/datasets/lawn_segmentation_pipeline.py',
    #     '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
    # ]
    _base_ = [
        '/home/deep_learning_collection/mmsegmentation/configs/_base_/models/pspnet_r50-d8.py',
        '/home/deep_learning_collection/mmsegmentation/configs/_base_/datasets/watermelon_segmentation_pipeline.py',
        '/home/deep_learning_collection/mmsegmentation/configs/_base_/default_runtime.py',
        '/home/deep_learning_collection/mmsegmentation/configs/_base_/schedules/schedule_40k.py'
    ]
    crop_size = (512, 1024)
    data_preprocessor = dict(size=crop_size)
    model = dict(data_preprocessor=data_preprocessor)

Running the command below resolves everything into a single full config file with all training parameters (saved into the work dir) and starts training:

python3 tools/train.py configs/pspnet/pspnet_r50-d8_4xb2-40k_watermelon_segmen-512x1024.py
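
Once training finishes, the checkpoints land in the work dir (by default work_dirs/<config name>). mmsegmentation also ships tools/test.py for evaluating a trained checkpoint on the validation set; a minimal sketch, assuming the default work-dir layout and the best-mIoU checkpoint name produced by the CheckpointHook (your iteration number will differ):

    python3 tools/test.py configs/pspnet/pspnet_r50-d8_4xb2-40k_watermelon_segmen-512x1024.py \
        work_dirs/pspnet_r50-d8_4xb2-40k_watermelon_segmen-512x1024/best_mIoU_iter_30000.pth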

2. Notes and pitfalls

For reasons I couldn't pin down at the time, training on my machine still raised the registry error mentioned at the beginning. (In hindsight, the likely culprit is quoting variable names in the dataloader settings, e.g. type='dataset_type' instead of type=dataset_type, which makes MMEngine look up the literal string 'dataset_type' in the registry, exactly what the KeyError complains about.) My workaround was to write the entire resolved configuration into one file and run that directly, which trained fine. You only need to set data_root, crop_size, and the dataset_type name, make the type entries of the train/val/test dataloaders match it, and adjust the iteration budget to your case via train_cfg = dict(max_iters=30000, type='IterBasedTrainLoop', val_interval=1000); everything else can stay at the defaults.

    crop_size = (512, 512)
    data_preprocessor = dict(
        bgr_to_rgb=True,
        mean=[123.675, 116.28, 103.53],
        pad_val=0,
        seg_pad_val=255,
        size=(512, 1024),
        std=[58.395, 57.12, 57.375],
        type='SegDataPreProcessor')
    data_root = '/home/deep_learning_collection/mmsegmentation/data/watermelon/'
    dataset_type = 'WatermelonSegmentationDataset'
    default_hooks = dict(
        checkpoint=dict(
            by_epoch=False,
            interval=2500,
            max_keep_ckpts=2,
            save_best='mIoU',
            type='CheckpointHook'),
        logger=dict(interval=100, log_metric_by_epoch=False, type='LoggerHook'),
        param_scheduler=dict(type='ParamSchedulerHook'),
        sampler_seed=dict(type='DistSamplerSeedHook'),
        timer=dict(type='IterTimerHook'),
        visualization=dict(type='SegVisualizationHook'))
    default_scope = 'mmseg'
    env_cfg = dict(
        cudnn_benchmark=True,
        dist_cfg=dict(backend='nccl'),
        mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
    img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
    load_from = None
    log_level = 'INFO'
    log_processor = dict(by_epoch=False)
    model = dict(
        auxiliary_head=[
            dict(
                align_corners=False,
                channels=32,
                concat_input=False,
                in_channels=128,
                in_index=-2,
                loss_decode=dict(
                    loss_weight=0.4, type='CrossEntropyLoss', use_sigmoid=True),
                norm_cfg=dict(requires_grad=True, type='BN'),
                num_classes=2,
                num_convs=1,
                type='FCNHead'),
            dict(
                align_corners=False,
                channels=32,
                concat_input=False,
                in_channels=64,
                in_index=-3,
                loss_decode=dict(
                    loss_weight=0.4, type='CrossEntropyLoss', use_sigmoid=True),
                norm_cfg=dict(requires_grad=True, type='BN'),
                num_classes=2,
                num_convs=1,
                type='FCNHead'),
        ],
        backbone=dict(
            align_corners=False,
            downsample_dw_channels=(32, 48),
            fusion_out_channels=128,
            global_block_channels=(64, 96, 128),
            global_block_strides=(2, 2, 1),
            global_in_channels=64,
            global_out_channels=128,
            higher_in_channels=64,
            lower_in_channels=128,
            norm_cfg=dict(requires_grad=True, type='BN'),
            out_indices=(0, 1, 2),
            type='FastSCNN'),
        data_preprocessor=dict(
            bgr_to_rgb=True,
            mean=[123.675, 116.28, 103.53],
            pad_val=0,
            seg_pad_val=255,
            size=(512, 1024),
            std=[58.395, 57.12, 57.375],
            type='SegDataPreProcessor'),
        decode_head=dict(
            align_corners=False,
            channels=128,
            concat_input=False,
            in_channels=128,
            in_index=-1,
            loss_decode=dict(
                loss_weight=1, type='CrossEntropyLoss', use_sigmoid=True),
            norm_cfg=dict(requires_grad=True, type='BN'),
            num_classes=2,
            type='DepthwiseSeparableFCNHead'),
        test_cfg=dict(mode='whole'),
        train_cfg=dict(),
        type='EncoderDecoder')
    norm_cfg = dict(requires_grad=True, type='BN')
    optim_wrapper = dict(
        clip_grad=None,
        optimizer=dict(lr=0.12, momentum=0.9, type='SGD', weight_decay=4e-05),
        type='OptimWrapper')
    optimizer = dict(lr=0.12, momentum=0.9, type='SGD', weight_decay=4e-05)
    param_scheduler = [
        dict(
            begin=0,
            by_epoch=False,
            end=160000,
            eta_min=0.0001,
            power=0.9,
            type='PolyLR'),
    ]
    randomness = dict(seed=0)
    resume = False
    test_cfg = dict(type='TestLoop')
    test_dataloader = dict(
        batch_size=8,
        dataset=dict(
            data_prefix=dict(img_path='img_dir/val', seg_map_path='ann_dir/val'),
            data_root='/home/deep_learning_collection/mmsegmentation/data/watermelon/',
            pipeline=[
                dict(type='LoadImageFromFile'),
                dict(keep_ratio=True, scale=(2048, 1024), type='Resize'),
                dict(type='LoadAnnotations'),
                dict(type='PackSegInputs'),
            ],
            type='WatermelonSegmentationDataset'),
        num_workers=4,
        persistent_workers=True,
        sampler=dict(shuffle=False, type='DefaultSampler'))
    test_evaluator = dict(
        iou_metrics=['mIoU', 'mDice', 'mFscore'], type='IoUMetric')
    test_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(keep_ratio=True, scale=(2048, 1024), type='Resize'),
        dict(type='LoadAnnotations'),
        dict(type='PackSegInputs'),
    ]
    train_cfg = dict(max_iters=30000, type='IterBasedTrainLoop', val_interval=1000)
    train_dataloader = dict(
        batch_size=16,
        dataset=dict(
            data_prefix=dict(
                img_path='img_dir/train', seg_map_path='ann_dir/train'),
            data_root='/home/deep_learning_collection/mmsegmentation/data/watermelon/',
            pipeline=[
                dict(type='LoadImageFromFile'),
                dict(type='LoadAnnotations'),
                dict(
                    keep_ratio=True,
                    ratio_range=(0.5, 2.0),
                    scale=(2048, 1024),
                    type='RandomResize'),
                dict(cat_max_ratio=0.75, crop_size=(512, 512), type='RandomCrop'),
                dict(prob=0.5, type='RandomFlip'),
                dict(type='PhotoMetricDistortion'),
                dict(type='PackSegInputs'),
            ],
            type='WatermelonSegmentationDataset'),
        num_workers=8,
        persistent_workers=True,
        sampler=dict(shuffle=True, type='InfiniteSampler'))
    train_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations'),
        dict(
            keep_ratio=True,
            ratio_range=(0.5, 2.0),
            scale=(2048, 1024),
            type='RandomResize'),
        dict(cat_max_ratio=0.75, crop_size=(512, 512), type='RandomCrop'),
        dict(prob=0.5, type='RandomFlip'),
        dict(type='PhotoMetricDistortion'),
        dict(type='PackSegInputs'),
    ]
    tta_model = dict(type='SegTTAModel')
    tta_pipeline = [
        dict(file_client_args=dict(backend='disk'), type='LoadImageFromFile'),
        dict(
            transforms=[
                [
                    dict(keep_ratio=True, scale_factor=0.5, type='Resize'),
                    dict(keep_ratio=True, scale_factor=0.75, type='Resize'),
                    dict(keep_ratio=True, scale_factor=1.0, type='Resize'),
                    dict(keep_ratio=True, scale_factor=1.25, type='Resize'),
                    dict(keep_ratio=True, scale_factor=1.5, type='Resize'),
                    dict(keep_ratio=True, scale_factor=1.75, type='Resize'),
                ],
                [
                    dict(direction='horizontal', prob=0.0, type='RandomFlip'),
                    dict(direction='horizontal', prob=1.0, type='RandomFlip'),
                ],
                [dict(type='LoadAnnotations')],
                [dict(type='PackSegInputs')],
            ],
            type='TestTimeAug'),
    ]
    val_cfg = dict(type='ValLoop')
    val_dataloader = dict(
        batch_size=8,
        dataset=dict(
            data_prefix=dict(img_path='img_dir/val', seg_map_path='ann_dir/val'),
            data_root='/home/deep_learning_collection/mmsegmentation/data/watermelon/',
            pipeline=[
                dict(type='LoadImageFromFile'),
                dict(keep_ratio=True, scale=(2048, 1024), type='Resize'),
                dict(type='LoadAnnotations'),
                dict(type='PackSegInputs'),
            ],
            type='WatermelonSegmentationDataset'),
        num_workers=4,
        persistent_workers=True,
        sampler=dict(shuffle=False, type='DefaultSampler'))
    val_evaluator = dict(
        iou_metrics=['mIoU', 'mDice', 'mFscore'], type='IoUMetric')
    vis_backends = [dict(type='LocalVisBackend')]
    visualizer = dict(
        name='visualizer',
        type='SegLocalVisualizer',
        vis_backends=[dict(type='LocalVisBackend')])
    work_dir = '/home/deep_learning_collection/mmsegmentation/outputs/watermenlon_FastSCNN'

IV. Inference and Prediction

This section provides inference code for a single image, a video file, and a live camera feed. Put the config file from the previous step and the trained weights where the scripts expect them, and they should work out of the box.

1. Single-image inference

    import numpy as np
    import matplotlib.pyplot as plt
    import cv2
    from mmseg.apis import init_model, inference_model, show_result_pyplot

    # model config file
    config_file = '/home/mmsegmentation/Zihao-Configs/ZihaoDataset_FastSCNN_20230818.py'
    # model checkpoint file
    checkpoint_file = '/home/mmsegmentation/outputs/20240425_211259/best_mIoU_iter_30000.pth'
    # device = 'cpu'
    device = 'cuda:0'
    model = init_model(config_file, checkpoint_file, device=device)

    img_path = '/home/mmsegmentation/data/Watermelon87_Semantic_Seg_Mask/img_dir/val/watermelon-medium.jpg'
    img_bgr = cv2.imread(img_path)

    result = inference_model(model, img_bgr)
    # result is a SegDataSample; result.keys() lists its fields
    pred_mask = result.pred_sem_seg.data[0].cpu().numpy()
    print(pred_mask.shape)       # same height/width as the input image
    print(np.unique(pred_mask))  # class IDs present in the prediction

    plt.figure(figsize=(8, 8))
    plt.imshow(pred_mask)
    plt.savefig('outputs/K1-1.jpg')
    plt.show()
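
The script above plots the raw class-ID mask with matplotlib. For a color overlay on the original image, you can also use the show_result_pyplot helper already imported above; a minimal sketch (opacity, show and out_file are optional arguments of the mmseg 1.x API, and the output filename is made up):

    # draw the prediction over the input image and save it to disk
    vis_img = show_result_pyplot(model, img_path, result,
                                 opacity=0.5, show=False,
                                 out_file='outputs/K1-1-overlay.jpg')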

2. Video inference

    import numpy as np
    from tqdm import tqdm
    import cv2
    from mmseg.apis import init_model, inference_model


    def init():
        config_file = '/home/mmsegmentation/Zihao-Configs/WatermelonDataset_FastSCNN.py'
        checkpoint_file = '/home/mmsegmentation/checkpoint/WatermelonDataset_FastSCNN.pth'
        # compute device
        # device = 'cpu'
        device = 'cuda:0'
        global model
        model = init_model(config_file, checkpoint_file, device=device)

        palette = [
            ['background', [127, 127, 127]],
            ['red', [200, 0, 0]],
            ['green', [0, 200, 0]],
            ['white', [144, 238, 144]],
            ['seed-black', [30, 30, 30]],
            ['seed-white', [251, 189, 8]]
        ]
        global palette_dict
        palette_dict = {}
        for idx, each in enumerate(palette):
            palette_dict[idx] = each[1]

        global opacity
        opacity = 0.4  # overlay opacity; larger values look closer to the original frame


    def process_frame(img_bgr):
        # semantic segmentation prediction
        result = inference_model(model, img_bgr)
        pred_mask = result.pred_sem_seg.data[0].cpu().numpy()
        # map the predicted integer class IDs to their class colors
        pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3))
        for idx in palette_dict.keys():
            pred_mask_bgr[np.where(pred_mask == idx)] = palette_dict[idx]
        pred_mask_bgr = pred_mask_bgr.astype('uint8')
        # blend the segmentation map with the original frame
        pred_viz = cv2.addWeighted(img_bgr, opacity, pred_mask_bgr, 1 - opacity, 0)
        return pred_viz


    def generate_video(input_path='videos/robot.mp4'):
        filehead = input_path.split('/')[-1]
        output_path = "/home/Video/watermelon/out-" + filehead
        print('processing video', input_path)

        # count the total number of frames
        cap = cv2.VideoCapture(input_path)
        frame_count = 0
        while cap.isOpened():
            success, frame = cap.read()
            frame_count += 1
            if not success:
                break
        cap.release()
        print('total frames:', frame_count)

        cap = cv2.VideoCapture(input_path)
        frame_size = (cap.get(cv2.CAP_PROP_FRAME_WIDTH), cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        fps = cap.get(cv2.CAP_PROP_FPS)
        out = cv2.VideoWriter(output_path, fourcc, fps, (int(frame_size[0]), int(frame_size[1])))

        # bind the progress bar to the total frame count
        with tqdm(total=frame_count - 1) as pbar:
            try:
                while cap.isOpened():
                    success, frame = cap.read()
                    if not success:
                        break
                    try:
                        frame = process_frame(frame)
                    except Exception as e:
                        print('frame processing failed:', e)
                    out.write(frame)
                    pbar.update(1)  # advance the progress bar by one frame
            except:
                print('interrupted midway')
        cv2.destroyAllWindows()
        out.release()
        cap.release()
        print('video saved to', output_path)


    def main():
        init()
        generate_video(input_path='/home/Video/watermelon_seg.mp4')


    if __name__ == "__main__":
        main()

3. Camera inference

(Here I add edge detection on top of the segmentation output so contours can be extracted.)

    import time
    import numpy as np
    import cv2
    from mmseg.apis import init_model, inference_model


    # load the trained model
    def init():
        config_file = '/home/mmsegmentation/Zihao-Configs/WatermelonDataset_FastSCNN.py'
        checkpoint_file = '/home/mmsegmentation/checkpoint/WatermelonDataset_FastSCNN.pth'
        # device = 'cpu'
        device = 'cuda:0'
        global model
        model = init_model(config_file, checkpoint_file, device=device)

        palette = [
            ['background', [127, 127, 127]],
            ['red', [200, 0, 0]],
            ['green', [0, 200, 0]],
            ['white', [144, 238, 144]],
            ['seed-black', [30, 30, 30]],
            ['seed-white', [251, 189, 8]]
        ]
        global palette_dict
        palette_dict = {}
        for idx, each in enumerate(palette):
            palette_dict[idx] = each[1]

        global opacity
        opacity = 0.4  # overlay opacity; larger values look closer to the original frame


    class Canny:
        def gaussian_blur(self, image, kernel_size):
            return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)

        def erode(self, image, kernel_size, iterations=1):
            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
            return cv2.erode(image, kernel, iterations=iterations)

        def dilate(self, image, kernel_size, iterations=1):
            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
            return cv2.dilate(image, kernel, iterations=iterations)

        def opening(self, image, kernel_size):
            return cv2.morphologyEx(image, cv2.MORPH_OPEN,
                                    cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size)))

        def closing(self, image, kernel_size):
            return cv2.morphologyEx(image, cv2.MORPH_CLOSE,
                                    cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size)))

        def canny_edge_detection(self, image, threshold1, threshold2):
            return cv2.Canny(image, threshold1, threshold2)


    canny = Canny()


    def Canny_detect(seg_image):
        # convert the image to RGB before edge detection
        seg_image = cv2.cvtColor(seg_image, cv2.COLOR_BGR2RGB)
        blurred = canny.gaussian_blur(seg_image, 9)
        eroded = canny.erode(blurred, 9, 2)
        dilated = canny.dilate(eroded, 9, 2)
        opened = canny.opening(dilated, 9)
        closed = canny.closing(opened, 9)
        # Canny edge detection
        edges = canny.canny_edge_detection(closed, 100, 200)
        return edges


    # per-frame processing function
    def process_frame(img_bgr):
        # record when this frame started processing
        start_time = time.time()
        # semantic segmentation prediction
        result = inference_model(model, img_bgr)
        # pred_mask is a 2-D array holding the predicted class ID of every pixel
        pred_mask = result.pred_sem_seg.data[0].cpu().numpy()
        # blank array of the same size to hold the colorized mask
        pred_mask_bgr = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3))
        # map the predicted integer IDs to class colors: wherever pred_mask equals idx,
        # set the corresponding pixels of pred_mask_bgr to palette_dict[idx]
        for idx in palette_dict.keys():
            pred_mask_bgr[np.where(pred_mask == idx)] = palette_dict[idx]
        # cast to uint8 so the pixel values fall in the valid image range
        pred_mask_bgr = pred_mask_bgr.astype('uint8')
        # filter the segmentation map and run edge detection (Canny output is a binary image)
        canny_viz = Canny_detect(pred_mask_bgr)
        # blend the segmentation map with the original frame
        pred_viz = cv2.addWeighted(img_bgr, opacity, pred_mask_bgr, 1 - opacity, 0)
        # resize so both images share the same dimensions
        canny_viz = cv2.resize(canny_viz, (pred_viz.shape[1], pred_viz.shape[0]))
        # convert color space so both images share the same channel count
        canny_viz = cv2.cvtColor(canny_viz, cv2.COLOR_GRAY2RGB)
        # show the segmentation overlay and the Canny edges side by side
        merged_image_horizontal = cv2.hconcat([pred_viz, canny_viz])

        end_time = time.time()
        FPS = 1 / (end_time - start_time)
        # draw text: image, string, top-left corner, font, size, color, thickness
        scaler = 1  # text size
        FPS_string = 'FPS {:.2f}'.format(FPS)
        img_bgr = cv2.putText(merged_image_horizontal, FPS_string, (10 * scaler, 20 * scaler),
                              cv2.FONT_HERSHEY_SIMPLEX, 0.75 * scaler, (255, 0, 255), 2 * scaler)
        return img_bgr


    def main():
        init()
        # open the camera; 0 means the system default camera
        cap = cv2.VideoCapture(0)
        try:
            # loop until break is triggered
            while cap.isOpened():
                success, frame = cap.read()
                if not success:  # exit if no frame could be grabbed
                    print('failed to grab frame, exiting')
                    break
                frame = process_frame(frame)
                cv2.namedWindow('my_window', cv2.WINDOW_NORMAL)
                cv2.resizeWindow('my_window', int(frame.shape[1] * 1.4), int(frame.shape[0] * 1.4))
                cv2.imshow('my_window', frame)
                key_pressed = cv2.waitKey(60)  # poll the keyboard every 60 ms
                if key_pressed in [ord('q'), 27]:  # press q or Esc to quit (with English input method)
                    break
        finally:
            cap.release()            # release the camera
            cv2.destroyAllWindows()  # close the display window


    if __name__ == "__main__":
        main()

Acknowledgements

Finally, thanks to Zihao for his open-source work and to the MMLab team for open-sourcing the framework and platform, letting even beginners get hands-on quickly and feel the power of deep learning and the convenience AI brings. If this helped you, please give it a like; I will keep posting good articles from time to time.
