当前位置:   article > 正文

【目标检测】YOLOv8算法实现(一):模型搭建_yolov8模型

yolov8模型

  本系列文章记录本人硕士阶段YOLO系列目标检测算法自学及其代码实现的过程。其中算法具体实现借鉴于ultralytics YOLO源码Github,删减了源码中部分内容,满足个人科研需求。
  本篇文章在YOLOv5算法实现的基础上,进一步完成YOLOv8算法的实现。YOLOv8相比于YOLOv5,最主要的不同之处如下:

  • 模型结构:将YOLOv5中的CSP模块替换为C2f模块,将Detect(耦合头 + Anchor-based)模块替换为Detect模块(解耦头 + Anchor-free + DFL)
  • 正样本匹配:采用TaskAlignedAssigner分配策略
  • 损失计算
    • 类别损失:二值交叉熵损失
    • 位置损失:Distribution Focal Loss(DFL) + CIOU Loss
    • 置信度损失:YOLOv8不预测模型的目标置信度,不再使用该损失

文章地址:
YOLOv8算法实现(一):模型搭建
YOLOv8算法实现(二):正样本匹配(TaskAlignedAssigner)与损失计算

1 模型结构

  YOLOv8的模型结构如图1所示,其包含以下几个模块:

  • CBS:卷积层、批标准化(BN)和SiLU激活函数
  • C2f:多梯度融合特征提取模块
  • SPPF:快速金字塔池化特征层
  • Detect:检测头(解耦头 + Anchor-free + Distribution)

在这里插入图片描述

2 模型模块实现(common.py)

2.1 C2f模块

class Bottleneck(nn.Module):
    '''
    残差连接瓶颈层, Residual block
    '''
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, k=1):
        '''
        :param c1: 输入通道
        :param c2: 输出通道
        :param shortcut: 为True时采用残差连接
        :param g: groups 在输出通道上分组, c2 // g 分组后不同组之间的卷积核参数不同
        :param e: 中间层的通道数
        '''
        super(Bottleneck, self).__init__()
        c_ = int(c2 * e)  # 中间层的通道
        self.cv1 = Conv(c1, c_, k, 1)  # ch_in, ch_out, kereal_size, stride
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2
    def forward(self, x):
        out = self.cv2(self.cv1(x))
        return x + out if self.add else out
        
class C2f(nn.Module):
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, e=1.0, k=3) for _ in range(n))

    def forward(self, x):
        y = list(self.cv1(x).split((self.c, self.c), 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33

2.2 Detect模块(解耦头 + Anchor-free + Distribution)

  YOLOv8对结果的预测有如下特征:

  • 基于不同分辨率的特征图实现对不同大小的目标预测;
  • 每张特征图以像素为单位为单位,对中心点落在该像素单位的目标进行预测,每个单位负责得到一个预测结果;

  假设特征图数量为 n l nl nl,特征图中的分辨率为 ( g r i d _ x i , g r i d _ y i ) (grid\_xi,grid\_yi) (grid_xi,grid_yi),则一张图片可得到的预测结果数量 n p np np为:
n p = ∑ i = 1 n l ( g r i d _ x i × g r i d _ y i ) np = \sum\limits_{i = 1}^{nl} {( grid\_xi \times grid\_yi} ) np=i=1nl(grid_xi×grid_yi)

  模型预测的边框信息最终表示为 ( l e f t , t o p , r i g h t , b o t t o m ) (left,top,right,bottom) (left,top,right,bottom)

  • l e f t left left:中心点距离边框左侧距离
  • t o p top top:中心点距离边框上侧距离
  • r i g h t right right:中心点距离边框右侧距离
  • b o t t o m bottom bottom:中心点距离边框下侧距离

  模型的边框信息输出形式为一序列,如图2所示。假设某目标 l e f t left left预测结果为序列 { y 0 , y 1 , y 2 , . . . , y n − 1 } , y i ⊆ [ 0 , 1.0 ] \{y_0,y_1,y_2,...,y_{n-1}\},y_i\subseteq [0,1.0] {y0,y1,y2,...,yn1},yi[0,1.0],满足:
l e f t = ∑ i = 0 n − 1 i y i left = \sum\limits_{i = 0}^{n-1} {i{y_i}} left=i=0n1iyi
在这里插入图片描述

图2 模型预测边框序列
class Detect(nn.Module):
    # YOLOv8 Detect head for detection models
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=()):  # detection layer
        super().__init__()
        self.nc = nc  # 类别数
        self.nl = len(ch)  # 检测层数(feature_map)
        self.reg_max = 16  # DFL channels(通过卷积实现预测序列面积的计算)
        self.no = nc + self.reg_max * 4  # 每一个预测单元点的输出通道 
        self.stride = torch.zeros(self.nl)  # strides computed during build

        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], self.nc)  # 中间层通道
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

    def forward(self, x):
        shape = x[0].shape  # BCHW
        for i in range(self.nl):
            # shape->(bs, 4*reg_max+num_cls, H, W)
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:
            return x
        elif self.shape != shape:
			# anchors:所有预测单元中心点坐标; strides:所有预测单元相对于输入图像大小的尺度
            self.anchors, self.strides = (x.transpose(0, 1) for x in self.make_anchors(x, self.stride, 0.5))
            self.shape = shape
        # [bs, no, ny, nx] -> box:[bs, 4 * reg_max, (20^2+40^2+80^2))] cls:[bs, num_cls, 20^2+40^2+80^]
        box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
        # 将预测结果(l,t,r,b)(不同特征图上)转换为(x,y,x,y)(原图绝对坐标)
        dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
        y = torch.cat((dbox, cls.sigmoid()), 1)  # shape [1, 4+num_cls, (20^2+40^2+80^2)] 4->(x,y,x,y)输入图绝对坐标
        return y, x

    def make_anchors(self, feats, strides, grid_cell_offset=0.5):
        """Generate anchors from features."""
        anchor_points, stride_tensor = [], []
        assert feats is not None
        dtype, device = feats[0].dtype, feats[0].device
        for i, stride in enumerate(strides):
            _, _, h, w = feats[i].shape  # bs, channel, h, w
            sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # x方向网格中心点
            sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # y方向网格中心点
            sy, sx = torch.meshgrid(sy, sx)
            anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))  # 所有网格中心点
            stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
        return torch.cat(anchor_points), torch.cat(stride_tensor)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52

2.3 其他

其余模块的实现方式与YOLOv5中一致,具体可参考文章YOLOv5算法实现(二):模型搭建

3 模型配置文件构建(model.yaml)

  基于图1所示的模型结构和模型模块所需的参数,构建模型配置文件。其中结构解析包含四个参数[from,number,module,args]:

  • from:当前层的输入来自于哪一层
  • number:当前层数量
  • module:当前层所有模块(在common.py中实现,需与类名对应)
  • args:第一个参数为当前层输出通道数,其余参数为模块特有参数;当前层的输入通道数由“from”参数指向的层决定,在结构解析时加入该参数。
# Parameters
nc: 80  # number of classes
depth_multiple: 1.00  # 模型深度(模块个数系数)
width_multiple: 1.00  # 模型宽度(模块通道数系数)

# YOLOv8.0l backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [512, True]]
  - [-1, 1, SPPF, [512, 5]]  # 9

# YOLOv8.0l head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 13

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 17 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 20 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [512]]  # 23 (P5/32-large)

  - [[15, 18, 21], 1, Detect_v8, [nc]]  # Detect(P3, P4, P5)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38

4 模型搭建(yolo.py)

模型搭建的具体实现方法可见文章YOLOv5算法实现(二):模型搭建

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/177706
推荐阅读
相关标签
  

闽ICP备14008679号