当前位置:   article > 正文

深度学习之交通标志识别_深度学习交通标志识别

深度学习交通标志识别

前言

交通标志识别是一项重要的任务。YOLOv8是前沿的目标检测技术,它基于先前 YOLO 版本在目标检测任务上的成功,进一步提升性能和灵活性。

我们将使用YOLOv8训练中国交通标志数据集,完成一个多目标检测实战项目。可实时检测图像标志,并提供可视化演示界面 。

YOLOv8具体改进如下:

  1. Backbone:使用的依旧是CSP的思想,不过YOLOv5中的C3模块被替换成了C2f模块,实现了进一步的轻量化,同时YOLOv8依旧使用了YOLOv5等架构中使用的SPPF模块;
  2. PAN-FPN:毫无疑问YOLOv8依旧使用了PAN的思想,不过通过对比YOLOv5与YOLOv8的结构图可以看到,YOLOv8将YOLOv5中PAN-FPN上采样阶段中的卷积结构删除了,同时也将C3模块替换为了C2f模块;
  3. Decoupled-Head:是不是嗅到了不一样的味道?是的,YOLOv8走向了Decoupled-Head;
  4. Anchor-Free:YOLOv8抛弃了以往的Anchor-Base,使用了Anchor-Free的思想;
  5. 损失函数:YOLOv8使用VFL Loss作为分类损失,使用DFL Loss+CIOU Loss作为分类损失;
  6. 样本匹配:YOLOv8抛弃了以往的IOU匹配或者单边比例的分配方式,而是使用了Task-Aligned Assigner匹配方式

数据集来源:https://github.com/csust7zhangjm/CCTSDB

在这里插入图片描述

加载模型

from ultralytics import YOLO

# model = YOLO('yolov8n.yaml')  # 从YAML创建一个新模型
model = YOLO('weights/yolov8s.pt')  # 加载预训练模型(推荐用于训练)
# model = YOLO('yolov8s.yaml').load('weights/yolov8s.pt')  # 从YAML构建并转移权重
# 训练模型
model.train(data="../datasets-train/sign.yaml",
            imgsz=640,     # 输入图像的大小为整数或 w,h
            epochs=50,    # 要训练的次数
            batch=16,      # 每批次的图像数量(AutoBatch-1)
            device=0,      # 要运行的设备,即 cuda device=0 或 device=0,1,2,3 或 device=cpu
            workers=0,     # 用于数据加载的工作线程数(如果是 DDP,则为每个 RANK)
            resume=False)  # True的时候则从上一个检查点恢复训练
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

主窗口向yolo实例发送执行信号

class MainWindow(QMainWindow, Ui_MainWindow):
    main2yolo_begin_sgl = Signal()  

    def __init__(self, parent=None):
        super(MainWindow, self).__init__()

        self.setupUi(self)
        self.setAttribute(Qt.WA_TranslucentBackground)
        self.setWindowFlags(Qt.FramelessWindowHint)
        UIFuncitons.uiDefinitions(self)

        UIFuncitons.shadow_style(self, self.Class_QF, QColor(0, 205, 102))
        UIFuncitons.shadow_style(self, self.Target_QF, QColor(123, 104, 238))
        UIFuncitons.shadow_style(self, self.Fps_QF, QColor(0, 205, 102))
        UIFuncitons.shadow_style(self, self.Model_QF, QColor(123, 104, 238))

        self.model_box.clear()
        self.pt_list = os.listdir('./models')
        self.pt_list = [file for file in self.pt_list if file.endswith('.pt') or file.endswith('.engine')]
        self.pt_list.sort(key=lambda x: os.path.getsize('./models/' + x))  # 按文件大小排序
        self.model_box.clear()
        self.model_box.addItems(self.pt_list)
        self.Qtimer_ModelBox = QTimer(self)  # 计时器:每2秒监视模型文件更改一次
        self.Qtimer_ModelBox.timeout.connect(self.ModelBoxRefre)
        self.Qtimer_ModelBox.start(2000)

        # Yolo-v8 thread
        self.yolo_predict = YoloPredictor()  # 实例化yolo检测
        self.select_model = self.model_box.currentText()
        self.yolo_predict.new_model_name = "./models/%s" % self.select_model
        self.yolo_thread = QThread()
        self.yolo_predict.yolo2main_trail_img.connect(lambda x: self.show_image(x, self.pre_video))
        self.yolo_predict.yolo2main_box_img.connect(lambda x: self.show_image(x, self.res_video))
        self.yolo_predict.yolo2main_status_msg.connect(lambda x: self.show_status(x))
        self.yolo_predict.yolo2main_fps.connect(lambda x: self.fps_label.setText(x))

        self.yolo_predict.yolo2main_class_num.connect(lambda x: self.Class_num.setText(str(x)))
        self.yolo_predict.yolo2main_target_num.connect(lambda x: self.Target_num.setText(str(x)))
        self.yolo_predict.yolo2main_progress.connect(lambda x: self.progress_bar.setValue(x))
        self.main2yolo_begin_sgl.connect(self.yolo_predict.run)
        self.yolo_predict.moveToThread(self.yolo_thread)

        # 模型参数
        self.model_box.currentTextChanged.connect(self.change_model)
        self.iou_spinbox.valueChanged.connect(lambda x: self.change_val(x, 'iou_spinbox'))  # iou box
        self.iou_slider.valueChanged.connect(lambda x: self.change_val(x, 'iou_slider'))  # iou scroll bar
        self.conf_spinbox.valueChanged.connect(lambda x: self.change_val(x, 'conf_spinbox'))  # conf box
        self.conf_slider.valueChanged.connect(lambda x: self.change_val(x, 'conf_slider'))  # conf scroll bar
        self.speed_spinbox.valueChanged.connect(lambda x: self.change_val(x, 'speed_spinbox'))  # speed box
        self.speed_slider.valueChanged.connect(lambda x: self.change_val(x, 'speed_slider'))  # speed scroll bar

        self.Class_num.setText('--')
        self.Target_num.setText('--')
        self.fps_label.setText('--')
        self.Model_name.setText(self.select_model)

        self.src_file_button.clicked.connect(self.open_src_file)
        self.src_cam_button.clicked.connect(self.camera_select)
        self.src_rtsp_button.clicked.connect(self.rtsp_seletction)

        self.run_button.clicked.connect(self.run_or_continue)
        self.stop_button.clicked.connect(self.stop)

        self.save_res_button.toggled.connect(self.is_save_res)
        self.save_txt_button.toggled.connect(self.is_save_txt)
        self.ToggleBotton.clicked.connect(lambda: UIFuncitons.toggleMenu(self, True))
        self.settings_button.clicked.connect(lambda: UIFuncitons.settingBox(self, True))

        self.load_config()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70

主窗口显示轨迹图像和检测图像

常见的两阶段检测首先是使用候选区域生成器生成的候选区集合,并从每个候选区中提取特征,然后使用区域分类器预测候选区域的类别。而YOLO作为单阶段检测器,则不用生成候选区域,直接对特征图的每个位置上的对象进行分类预测,效率更高。

    @staticmethod
    def show_image(img_src, label):
        try:
            if len(img_src.shape) == 3:
                ih, iw, _ = img_src.shape
            if len(img_src.shape) == 2:
                ih, iw = img_src.shape
            w = label.geometry().width()
            h = label.geometry().height()

            if iw / w > ih / h:
                scal = w / iw
                nw = w
                nh = int(scal * ih)
                img_src_ = cv2.resize(img_src, (nw, nh))

            else:
                scal = h / ih
                nw = int(scal * iw)
                nh = h
                img_src_ = cv2.resize(img_src, (nw, nh))

            frame = cv2.cvtColor(img_src_, cv2.COLOR_BGR2RGB)
            img = QImage(frame.data, frame.shape[1], frame.shape[0], frame.shape[2] * frame.shape[1],
                         QImage.Format_RGB888)
            label.setPixmap(QPixmap.fromImage(img))

        except Exception as e:
            print(repr(e))

    def set_lock_id(self, lock_id):
        self.yolo_predict.lock_id = None
        self.yolo_predict.lock_id = lock_id
        new_config = {"id": lock_id}
        new_json = json.dumps(new_config, ensure_ascii=False, indent=2)
        with open('config/id.json', 'w', encoding='utf-8') as f:
            f.write(new_json)
        self.show_status('加载ID:{}'.format(lock_id))
        self.id_window.close()

    # 控制开始|暂停
    def run_or_continue(self):
        if self.yolo_predict.source == '' or self.yolo_predict.source == None:
            self.show_status('请在检测前选择输入源...')
            self.run_button.setChecked(False)
        else:
            self.yolo_predict.stop_dtc = False
            if self.run_button.isChecked():
                self.run_button.setChecked(True)
                self.save_txt_button.setEnabled(False)
                self.save_res_button.setEnabled(False)
                self.conf_slider.setEnabled(False)
                self.iou_slider.setEnabled(False)
                self.speed_slider.setEnabled(False)

                self.show_status('检测中...')
                if '0' in self.yolo_predict.source or 'rtsp' in self.yolo_predict.source:
                    self.progress_bar.setFormat('实时视频流检测中...')
                if 'avi' in self.yolo_predict.source or 'mp4' in self.yolo_predict.source:
                    self.progress_bar.setFormat("当前检测进度:%p%")
                self.yolo_predict.continue_dtc = True
                if not self.yolo_thread.isRunning():
                    self.yolo_thread.start()
                    self.main2yolo_begin_sgl.emit()
            else:
                self.yolo_predict.continue_dtc = False
                self.show_status("暂停...")
                self.run_button.setChecked(False)

    def show_status(self, msg):
        self.status_bar.setText(msg)
        if msg == 'Detection completed' or msg == '检测完成':
            self.save_res_button.setEnabled(True)
            self.save_txt_button.setEnabled(True)
            self.run_button.setChecked(False)
            self.progress_bar.setValue(0)
            if self.yolo_thread.isRunning():
                self.yolo_thread.quit()  # 终止线程
        elif msg == 'Detection terminated!' or msg == '检测终止':
            self.save_res_button.setEnabled(True)
            self.save_txt_button.setEnabled(True)
            self.run_button.setChecked(False)
            self.progress_bar.setValue(0)
            if self.yolo_thread.isRunning():
                self.yolo_thread.quit()  # 终止线程
            self.pre_video.clear()
            self.res_video.clear()
            self.Class_num.setText('--')
            self.Target_num.setText('--')
            self.fps_label.setText('--')

    def open_src_file(self):
        config_file = 'config/fold.json'
        config = json.load(open(config_file, 'r', encoding='utf-8'))
        open_fold = config['open_fold']
        if not os.path.exists(open_fold):
            open_fold = os.getcwd()
        name, _ = QFileDialog.getOpenFileName(self, 'Video/image', open_fold,
                                              "Pic File(*.mp4 *.mkv *.avi *.flv *.jpg *.png)")
        if name:
            self.yolo_predict.source = name
            self.show_status('加载文件:{}'.format(os.path.basename(name)))
            config['open_fold'] = os.path.dirname(name)
            config_json = json.dumps(config, ensure_ascii=False, indent=2)
            with open(config_file, 'w', encoding='utf-8') as f:
                f.write(config_json)
            self.stop()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107

选择rtsp

RTSP(Real-Time Stream Protocol)协议是一个基于文本的多媒体播放控制协议,属于应用层。RTSP以客户端方式工作,对流媒体提供播放、暂停、后退、前进等操作

    def rtsp_seletction(self):
        self.rtsp_window = Window()
        config_file = 'config/ip.json'
        if not os.path.exists(config_file):
            ip = "rtsp://admin:admin@10.98.43.107:8554/live"
            new_config = {"ip": ip}
            new_json = json.dumps(new_config, ensure_ascii=False, indent=2)
            with open(config_file, 'w', encoding='utf-8') as f:
                f.write(new_json)
        else:
            config = json.load(open(config_file, 'r', encoding='utf-8'))
            ip = config['ip']
        self.rtsp_window.rtspEdit.setText(ip)
        self.rtsp_window.show()
        self.rtsp_window.rtspButton.clicked.connect(lambda: self.load_rtsp(self.rtsp_window.rtspEdit.text()))

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

加载RTSP

    def load_rtsp(self, ip):
        # try:
        self.stop()
        MessageBox(
            self.close_button, title='提示', text='加载 rtsp...', time=1000, auto=True).exec()
        self.yolo_predict.source = ip
        new_config = {"ip": ip}
        new_json = json.dumps(new_config, ensure_ascii=False, indent=2)
        with open('config/ip.json', 'w', encoding='utf-8') as f:
            f.write(new_json)
        self.show_status('加载rtsp地址:{}'.format(ip))
        self.rtsp_window.close()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

保存提示和JSON配置文件初始化

# 保存提示(txt)
    def is_save_txt(self):
        if self.save_txt_button.checkState() == Qt.CheckState.Unchecked:
            self.show_status('提示: 标签信息不会被保存')
            self.yolo_predict.save_txt = False
        elif self.save_txt_button.checkState() == Qt.CheckState.Checked:
            self.show_status('提示: 标签信息将会被保存')
            self.yolo_predict.save_txt = True

    # JSON配置文件初始化
    def load_config(self):
        config_file = 'config/setting.json'
        if not os.path.exists(config_file):
            iou = 0.26
            conf = 0.33
            rate = 10
            save_res = 0
            save_txt = 0
            new_config = {"iou": iou,
                          "conf": conf,
                          "rate": rate,
                          "save_res": save_res,
                          "save_txt": save_txt
                          }
            new_json = json.dumps(new_config, ensure_ascii=False, indent=2)
            with open(config_file, 'w', encoding='utf-8') as f:
                f.write(new_json)
        else:
            config = json.load(open(config_file, 'r', encoding='utf-8'))
            if len(config) != 5:
                iou = 0.26
                conf = 0.33
                rate = 10
                save_res = 0
                save_txt = 0
            else:
                iou = config['iou']
                conf = config['conf']
                rate = config['rate']
                save_res = config['save_res']
                save_txt = config['save_txt']
        self.save_res_button.setCheckState(Qt.CheckState(save_res))
        self.yolo_predict.save_res = (False if save_res == 0 else True)
        self.save_txt_button.setCheckState(Qt.CheckState(save_txt))
        self.yolo_predict.save_txt = (False if save_txt == 0 else True)
        self.run_button.setChecked(False)
        self.show_status("欢迎使用YOLOv8目标检测系统")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47

训练效果

在这里插入图片描述

停止事件(按下停止按钮)

    def stop(self):
        try:
            self.yolo_thread.quit()  # 结束线程
        except:
            pass
        self.yolo_predict.stop_dtc = True
        self.run_button.setChecked(False)  # 恢复按钮状态
        self.save_res_button.setEnabled(True)  # 把保存按钮设置为可用
        self.save_txt_button.setEnabled(True)  # 把保存按钮设置为可用
        self.iou_slider.setEnabled(True)  # 把滑块设置为可用
        self.conf_slider.setEnabled(True)  # 把滑块设置为可用
        self.speed_slider.setEnabled(True)  # 把滑块设置为可用
        self.pre_video.clear()  # 清空视频显示
        self.res_video.clear()  # 清空视频显示
        self.progress_bar.setValue(0)  # 进度条清零
        self.Class_num.setText('--')
        self.Target_num.setText('--')
        self.fps_label.setText('--')

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

检测参数设置

    def change_val(self, x, flag):
        if flag == 'iou_spinbox':
            self.iou_slider.setValue(int(x * 100))
        elif flag == 'iou_slider':
            self.iou_spinbox.setValue(x / 100)
            self.show_status('IOU Threshold: %s' % str(x / 100))
            self.yolo_predict.iou_thres = x / 100
        elif flag == 'conf_spinbox':
            self.conf_slider.setValue(int(x * 100))
        elif flag == 'conf_slider':
            self.conf_spinbox.setValue(x / 100)
            self.show_status('Conf Threshold: %s' % str(x / 100))
            self.yolo_predict.conf_thres = x / 100
        elif flag == 'speed_spinbox':
            self.speed_slider.setValue(x)
        elif flag == 'speed_slider':
            self.speed_spinbox.setValue(x)
            self.show_status('Delay: %s ms' % str(x))
            self.yolo_predict.speed_thres = x  # ms
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

模型更换

    def change_model(self, x):
        self.select_model = self.model_box.currentText()
        self.yolo_predict.new_model_name = "./models/%s" % self.select_model
        self.show_status('更改模型:%s' % self.select_model)
        self.Model_name.setText(self.select_model)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

训练结果分析

confusion_matrix.png :列代表预测的类别,行代表实际的类别。其对角线上的值表示预测正确的数量比例,非对角线元素则是预测错误的部分。混淆矩阵的对角线值越高越好,这表明许多预测是正确的。

在这里插入图片描述
上图是道路破损检测训练,有图可以看出 ,分别是破损和background FP。该图在每列上进行归一化处理。则可以看出破损检测预测正确的概率为91%。

F1_curve.png:F1分数与置信度(x轴)之间的关系。F1分数是分类的一个衡量标准,是精确率和召回率的调和平均函数,介于0,1之间。越大越好。

TP:真实为真,预测为真;

FN:真实为真,预测为假;

FP:真实为假,预测为真;

TN:真实为假,预测为假;

精确率(precision)=TP/(TP+FP)

召回率(Recall)=TP/(TP+FN)

F1=2*(精确率*召回率)/(精确率+召回率)
在这里插入图片描述

循环监测文件夹的文件变化

    def ModelBoxRefre(self):
        pt_list = os.listdir('./models')
        pt_list = [file for file in pt_list if file.endswith('.pt') or file.endswith('.engine')]
        pt_list.sort(key=lambda x: os.path.getsize('./models/' + x))
        # 必须排序后再比较,否则列表会一直刷新
        if pt_list != self.pt_list:
            self.pt_list = pt_list
            self.model_box.clear()
            self.model_box.addItems(self.pt_list)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

窗口优化

 # 获取鼠标位置(用于按住标题栏拖动窗口)
    def mousePressEvent(self, event):
        p = event.globalPosition()
        globalPos = p.toPoint()
        self.dragPos = globalPos

    # 拖动窗口大小时优化调整
    def resizeEvent(self, event):
        # Update Size Grips
        UIFuncitons.resize_grips(self)

    # 退出时退出线程,保存设置
    def closeEvent(self, event):
        config_file = 'config/setting.json'
        config = dict()
        config['iou'] = self.iou_spinbox.value()
        config['conf'] = self.conf_spinbox.value()
        config['rate'] = self.speed_spinbox.value()
        config['save_res'] = (0 if self.save_res_button.checkState() == Qt.Unchecked else 2)
        config['save_txt'] = (0 if self.save_txt_button.checkState() == Qt.Unchecked else 2)
        config_json = json.dumps(config, ensure_ascii=False, indent=2)
        with open(config_file, 'w', encoding='utf-8') as f:
            f.write(config_json)

        if self.yolo_thread.isRunning():
            self.yolo_predict.stop_dtc = True
            self.yolo_thread.quit()
            MessageBox(
                self.close_button, title='Note', text='退出中,请等待...', time=2000, auto=True).exec()
            sys.exit(0)
        else:
            sys.exit(0)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

最后

if __name__ == "__main__":
    if not os.path.exists('models'):
        os.mkdir('models')
    app = QApplication(sys.argv)
    Home = MainWindow()
    Home.show()
    sys.exit(app.exec())

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

图片检查结果

在这里插入图片描述
在这里插入图片描述

总结

基于深度学习的交通标志识别在近年来取得了显著的进展,本文通过YOLOV8框架来训练交通标志数据集,通过python构建一个支持图像的交通标志识别系统。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/253925
推荐阅读
相关标签
  

闽ICP备14008679号