当前位置:   article > 正文

【yolov8】修改训练模型记录_yolov8修改网络模型

yolov8修改网络模型

写在前面,因为yolov8的集成度相较于yolov5来说更高,封装成了新的库,使用起来就没有那么方便,所以记录一下,使用和修改的过程。
知识来源:YOLOV8训练教程
这个up主更新了很多yolo相关的修改视频,但是博客发得不全,好像比较忙吧,但是很值得学习,然后自己记录一下,方便后续的使用。
但是现在yolov8仍处于一个改动的阶段,没有稳定版本,所以可能过一段时间以后,这些修改也会发生变动,这也是up主没有大规模更新v8的视频吧。

  1. 卸载ultralytics库,不然的话,运行程序都使用的是这个库里面的文件,修改项目文件就没有作用了。
  2. yolo/engine/model.py
    载入方式
    _new:通过读取配置文件来读取模型(我们修改文件需要用到的)
    _load:通过预训练权重来导入模型

所以这里主要修改的是_new函数里面的部分代码

# 新增的导入库
import torch
from ultralytics.yolo.utils.torch_utils import intersect_dicts

# _new函数部分新增

    def _new(self, cfg: str, verbose=True):
        """
        Initializes a new model and infers the task type from the model definitions.

        Args:
            cfg (str): model configuration file
            verbose (bool): display model info on load
        """
        cfg = check_yaml(cfg)  # check YAML
        cfg_dict = yaml_load(cfg, append_filename=True)  # model dict
        self.task = guess_model_task(cfg_dict)
        self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
            self._assign_ops_from_task(self.task)
        self.model = self.ModelClass(cfg_dict, verbose=verbose)  # initialize
        self.cfg = cfg

        ## newly appended
        ckpt = torch.load('yolov8n.pt')
        csd = ckpt['model'].float().state_dict()
        csd = intersect_dicts(csd, self.model.state_dict())
        self.model.load_state_dict(csd, strict = False)
        print(f'Transferred{len(csd)}/{len(self.model.state_dict())} items')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28

train函数 将append_filename = False

因为没有这个库了,所以不能用在命令行用yolo去运行了
所以 需要新建一个main.py 来运行相关的代码

from ultralytics import YOLO

# load a model
# 第一个参数一定要传配置文件,不能传pt
model = YOLO("ultralytics/models/v8/yolov8n.yaml")

model.train(**{'cfg':'ultralytics/yolo/cfg/default.yaml'}) 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

运行结果如下,需要特别关注红框里面的分子和分母的数值需要一致,如果不一致的话,可能是因为用的模型不匹配
例如我尝试了用yolov5su.pt 和v5/yolov5s.yaml的文件里面的分子分母是不一样的(不知道是改版问题还是什么)
而且现在yolov8改版太大了,发现前两天和今天又是一个新的版本。现在可能已经不太适用这个方法了。
在这里插入图片描述
在63行的地方已经调用了一次,然后后面再212-214又重新调用了一次,所以把后面的注释掉
不然就会导致再一次初始化,没有达成修改的目的

        # if not overrides.get("resume"):  # manually set model only if not resuming
        #     self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
        #     self.model = self.trainer.model
        self.trainer.model = self.model
  • 1
  • 2
  • 3
  • 4

例如添加注意力机制,需要注意,有些是需要添加通道数的,有些是不需要的。
如果需要接收通道数的话,就在task.py里面导入
具个例子
导入MHSA
先在nn目录下新建MHSA.py文件夹
然后因为需要通道数n_dims
所以需要在task.py里面导入

from ultralytics.nn.MHSA import MHSA
  • 1

然后在解析模型的代码里面加入

        elif m in {MHSA}:
            args = [ch[f], *args]
  • 1
  • 2

需要注意的是 在修改yaml文件的时候
每一层会有对应的层数,比如在从backbone数下来,加多了一层,那么下面的concat之类的层数也要发生改变,随着层数走,不然网络结构会发生很大的变化,concat的地方会放错的等等。

建议先对照着结构图来理解yaml文件,然后再进行修改

无参的

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/379068?site
推荐阅读
相关标签
  

闽ICP备14008679号