
YOLOv9 Improvement Strategies | Low-Light Images | Improving YOLOv9 Dark-Scene Detection with the Lightweight Low-Light Image Enhancement Network IAT (Exclusive First Release)


1. Introduction

The improvement presented in this post is a lightweight transformer model, the Illumination Adaptive Transformer (IAT), for image enhancement and exposure correction. Its basic idea is to decompose the image signal processor (ISP) pipeline into local and global image components so as to recover a normally lit sRGB image from low-light or over-/under-exposed inputs. Concretely, IAT uses attention queries to represent and adjust ISP-related parameters such as color correction and gamma correction. With only about 90k parameters and roughly 0.004 s per image, it consistently outperforms state-of-the-art (SOTA) methods on low-light enhancement and exposure-correction benchmarks. Here we attach it to YOLOv9 to improve the model's detection in dark scenes; this change does not conflict with other module improvements.

Welcome to subscribe to my column and study YOLO together!

Column address: YOLOv9 Effective Improvement Column - continuously reproducing top-conference work for effective accuracy gains - the most complete YOLO improvement column on the web

Contents

1. Introduction

2. Basic Principles

2.1 How IAT Works

2.2 IAT's Core Modules

3. Core Code

4. Step-by-Step: Adding the IAT Low-Light Image Enhancement Network

4.1 Modification One

4.2 Modification Two

4.3 Modification Three

4.4 Modification Four

5. YAML File and Training Log

5.1 YAML File

5.2 Training Screenshot

6. Summary


2. Basic Principles

Paper address: click here for the official paper

Code address: click here for the official code


2.1 How IAT Works

The paper proposes a lightweight transformer model, the Illumination Adaptive Transformer (IAT), for image enhancement and exposure correction. The basic idea is to decompose the image signal processor (ISP) pipeline into local and global image components in order to recover a normally lit sRGB image from low-light or over-/under-exposed inputs. Concretely, IAT uses attention queries to represent and adjust ISP-related parameters such as color correction and gamma correction. With only about 90k parameters and about 0.004 s per image, the model consistently outperforms state-of-the-art (SOTA) methods on low-light enhancement and exposure-correction benchmarks.

The basic principles of the Illumination Adaptive Transformer (IAT) are as follows:

1. Lightweight transformer architecture: IAT is designed as a lightweight model with roughly 90,000 parameters, focused on image enhancement and exposure correction. This makes it very efficient in speed and resource consumption, suitable for real-time or resource-constrained applications.

2. ISP pipeline decomposition: the core idea of IAT is to emulate and improve the traditional ISP pipeline. By decomposing the ISP process into local and global image components, IAT can adapt the visual appearance of the image to the specific lighting condition.

3. Adaptive illumination adjustment: IAT dynamically adjusts its processing according to the lighting of the input image, handling low light, over-exposure and under-exposure effectively and recovering a normally lit sRGB image.

The Illumination Adaptive Transformer (IAT) architecture is divided into two main parts, a local branch and a global branch:

1. Local branch: processes local image features. This branch applies the Pixel-wise Enhancement Module (PEM) several times to extract local features, which are then further processed by convolution layers.

2. Global branch: processes global image information. It likewise contains several PEMs and convolution layers, but operates on the global image content.

3. Parameter generation (black lines in the architecture diagram): the path along which the network produces the parameters needed by the ISP pipeline, such as the color matrix and the gamma value.

4. Image processing (yellow lines): the actual image-processing path. After the image passes through the local and global branches, the resulting features are used to adjust its color and exposure.

5. Cross attention: located in the global branch, this component integrates information from the local and global branches so that the color matrix and gamma value can be adjusted more accurately.

6. Final output: the processed features are reshaped and passed through convolution layers, and the local and global adjustments are applied to the original input image to produce the enhanced output (a minimal sketch of this final application step follows this list).
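
To make the data flow concrete, here is a minimal PyTorch sketch of how the two branches' outputs are combined. It mirrors IAT.forward from the core code in Section 3; the function name apply_iat and its argument names are purely illustrative.

import torch

# Minimal sketch: combine the local branch outputs (mul, add) with the global
# branch outputs (color matrix, gamma), in the spirit of IAT.forward below.
def apply_iat(img, mul, add, color, gamma):
    # img, mul, add: (B, 3, H, W); color: (B, 3, 3); gamma: (B, 1)
    out = img * mul + add                              # local pixel-wise enhancement
    out = out.permute(0, 2, 3, 1)                      # (B, H, W, 3) for the color matrix
    out = torch.stack([
        torch.clamp(out[i].reshape(-1, 3) @ color[i].t(), 1e-8, 1.0)
        .reshape(out[i].shape) ** gamma[i]             # global color correction + gamma
        for i in range(img.shape[0])])
    return out.permute(0, 3, 1, 2)                     # back to (B, 3, H, W)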


2.2 IAT's Core Modules

The paper's module diagram illustrates the two core modules of the Illumination Adaptive Transformer (IAT): the Pixel-wise Enhancement Module (PEM) and the Global Prediction Module (GPM).

(a) Pixel-wise Enhancement Module (PEM):
Input: a feature map of size B×C×H×W, where B is the batch size, C the number of channels, and H×W the spatial height and width.
Pipeline:
    1. A series of 1×1 convolution layers applies point-wise linear transforms to the feature map, enhancing or adjusting the response of individual pixels.
    2. After each 1×1 convolution, an element-wise multiplication is applied (shown as yellow circles with multiplication signs in the diagram).
    3. At the end, the feature map is reshaped back to its original B×C×H×W shape.

(b) Global Prediction Module (GPM):
Pipeline:
    1. The feature map first passes through a fully connected (FC) layer to produce V, the value vectors carrying global information.
    2. Another FC layer produces K, the key vectors.
    3. K and V are combined with the query Q through a cross-attention mechanism; Q is typically derived from local features.
    4. The result is reshaped into the color correction matrix and the gamma correction value.

The two modules work together: PEM enhances local feature detail, while GPM generates the global adjustment parameters, giving finer-grained control over the enhancement. In this way IAT can make delicate adjustments to images captured under different lighting conditions and achieve excellent enhancement results. A simplified sketch of the GPM idea follows.
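
For readers who want the GPM mechanics in isolation, here is a simplified single-head sketch (the class name TinyGPM and the shapes are illustrative; the full multi-head version is query_Attention/Global_pred in Section 3). Ten learnable queries attend over the flattened image features; query 0 is mapped to a gamma value and queries 1-9 to the 3×3 color correction matrix.

import torch
import torch.nn as nn

class TinyGPM(nn.Module):
    # Simplified single-head sketch of the Global Prediction Module (GPM).
    def __init__(self, dim=64):
        super().__init__()
        self.q = nn.Parameter(torch.ones(1, 10, dim))   # 10 learnable queries
        self.k = nn.Linear(dim, dim)                    # keys from image features
        self.v = nn.Linear(dim, dim)                    # values from image features
        self.gamma_head = nn.Linear(dim, 1)
        self.color_head = nn.Linear(dim, 1)

    def forward(self, feats):                           # feats: (B, N, dim) flattened features
        B = feats.shape[0]
        k, v = self.k(feats), self.v(feats)
        attn = (self.q.expand(B, -1, -1) @ k.transpose(-2, -1)).softmax(dim=-1)
        x = attn @ v                                    # (B, 10, dim) query outputs
        gamma = self.gamma_head(x[:, :1]).squeeze(-1) + 1.0                        # (B, 1)
        color = self.color_head(x[:, 1:]).squeeze(-1).view(B, 3, 3) + torch.eye(3)  # (B, 3, 3)
        return gamma, color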


3. Core Code

See Section 4 for how to use the core code! Note that it depends on the timm package (for trunc_normal_, DropPath and to_2tuple), so make sure timm is installed in your environment.

import math
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_, DropPath, to_2tuple

__all__ = ['IAT']


class query_Attention(nn.Module):
    def __init__(self, dim, num_heads=2, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Parameter(torch.ones((1, 10, dim)), requires_grad=True)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        q = self.q.expand(B, -1, -1).view(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, 10, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class query_SABlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.attn = query_Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.pos_embed(x)
        x = x.flatten(2).transpose(1, 2)
        x = self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class conv_embedding(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(conv_embedding, self).__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(out_channels // 2),
            nn.GELU(),
            # nn.Conv2d(out_channels // 2, out_channels // 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # nn.BatchNorm2d(out_channels // 2),
            # nn.GELU(),
            nn.Conv2d(out_channels // 2, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        x = self.proj(x)
        return x


class Global_pred(nn.Module):
    def __init__(self, in_channels=3, out_channels=64, num_heads=4, type='exp'):
        super(Global_pred, self).__init__()
        if type == 'exp':
            self.gamma_base = nn.Parameter(torch.ones((1)), requires_grad=False)  # False in exposure correction
        else:
            self.gamma_base = nn.Parameter(torch.ones((1)), requires_grad=True)
        self.color_base = nn.Parameter(torch.eye((3)), requires_grad=True)  # basic color matrix
        # main blocks
        self.conv_large = conv_embedding(in_channels, out_channels)
        self.generator = query_SABlock(dim=out_channels, num_heads=num_heads)
        self.gamma_linear = nn.Linear(out_channels, 1)
        self.color_linear = nn.Linear(out_channels, 1)

        self.apply(self._init_weights)
        for name, p in self.named_parameters():
            if name == 'generator.attn.v.weight':
                nn.init.constant_(p, 0)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        # print(self.gamma_base)
        x = self.conv_large(x)
        x = self.generator(x)
        gamma, color = x[:, 0].unsqueeze(1), x[:, 1:]
        gamma = self.gamma_linear(gamma).squeeze(-1) + self.gamma_base
        # print(self.gamma_base, self.gamma_linear(gamma))
        color = self.color_linear(color).squeeze(-1).view(-1, 3, 3) + self.color_base
        return gamma, color


# ResMLP's normalization
class Aff(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # learnable
        self.alpha = nn.Parameter(torch.ones([1, 1, dim]))
        self.beta = nn.Parameter(torch.zeros([1, 1, dim]))

    def forward(self, x):
        x = x * self.alpha + self.beta
        return x


# Color Normalization
class Aff_channel(nn.Module):
    def __init__(self, dim, channel_first=True):
        super().__init__()
        # learnable
        self.alpha = nn.Parameter(torch.ones([1, 1, dim]))
        self.beta = nn.Parameter(torch.zeros([1, 1, dim]))
        self.color = nn.Parameter(torch.eye(dim))
        self.channel_first = channel_first

    def forward(self, x):
        if self.channel_first:
            x1 = torch.tensordot(x, self.color, dims=[[-1], [-1]])
            x2 = x1 * self.alpha + self.beta
        else:
            x1 = x * self.alpha + self.beta
            x2 = torch.tensordot(x1, self.color, dims=[[-1], [-1]])
        return x2


class Mlp(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class CMlp(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class CBlock_ln(nn.Module):
    def __init__(self, dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=Aff_channel, init_values=1e-4):
        super().__init__()
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        # self.norm1 = Aff_channel(dim)
        self.norm1 = norm_layer(dim)
        self.conv1 = nn.Conv2d(dim, dim, 1)
        self.conv2 = nn.Conv2d(dim, dim, 1)
        self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        # self.norm2 = Aff_channel(dim)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.gamma_1 = nn.Parameter(init_values * torch.ones((1, dim, 1, 1)), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones((1, dim, 1, 1)), requires_grad=True)
        self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.pos_embed(x)
        B, C, H, W = x.shape
        # print(x.shape)
        norm_x = x.flatten(2).transpose(1, 2)
        # print(norm_x.shape)
        norm_x = self.norm1(norm_x)
        norm_x = norm_x.view(B, H, W, C).permute(0, 3, 1, 2)
        x = x + self.drop_path(self.gamma_1 * self.conv2(self.attn(self.conv1(norm_x))))

        norm_x = x.flatten(2).transpose(1, 2)
        norm_x = self.norm2(norm_x)
        norm_x = norm_x.view(B, H, W, C).permute(0, 3, 1, 2)
        x = x + self.drop_path(self.gamma_2 * self.mlp(norm_x))
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    # print(x.shape)
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


## Layer_norm, Aff_norm, Aff_channel_norm
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads=2, window_size=8, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=Aff_channel):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.pos_embed(x)
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        x = x.transpose(1, 2).reshape(B, C, H, W)
        return x


class Local_pred(nn.Module):
    def __init__(self, dim=16, number=4, type='ccc'):
        super(Local_pred, self).__init__()
        # initial convolution
        self.conv1 = nn.Conv2d(3, dim, 3, padding=1, groups=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # main blocks
        block = CBlock_ln(dim)
        block_t = SwinTransformerBlock(dim)  # head number
        if type == 'ccc':
            # blocks1, blocks2 = [block for _ in range(number)], [block for _ in range(number)]
            blocks1 = [CBlock_ln(16, drop_path=0.01), CBlock_ln(16, drop_path=0.05), CBlock_ln(16, drop_path=0.1)]
            blocks2 = [CBlock_ln(16, drop_path=0.01), CBlock_ln(16, drop_path=0.05), CBlock_ln(16, drop_path=0.1)]
        elif type == 'ttt':
            blocks1, blocks2 = [block_t for _ in range(number)], [block_t for _ in range(number)]
        elif type == 'cct':
            blocks1, blocks2 = [block, block, block_t], [block, block, block_t]
        # block1 = [CBlock_ln(16), nn.Conv2d(16, 24, 3, 1, 1)]
        self.mul_blocks = nn.Sequential(*blocks1, nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
        self.add_blocks = nn.Sequential(*blocks2, nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())

    def forward(self, img):
        img1 = self.relu(self.conv1(img))
        mul = self.mul_blocks(img1)
        add = self.add_blocks(img1)
        return mul, add


# Short Cut Connection on Final Layer
class Local_pred_S(nn.Module):
    def __init__(self, in_dim=3, dim=16, number=4, type='ccc'):
        super(Local_pred_S, self).__init__()
        # initial convolution
        self.conv1 = nn.Conv2d(in_dim, dim, 3, padding=1, groups=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # main blocks
        block = CBlock_ln(dim)
        block_t = SwinTransformerBlock(dim)  # head number
        if type == 'ccc':
            blocks1 = [CBlock_ln(16, drop_path=0.01), CBlock_ln(16, drop_path=0.05), CBlock_ln(16, drop_path=0.1)]
            blocks2 = [CBlock_ln(16, drop_path=0.01), CBlock_ln(16, drop_path=0.05), CBlock_ln(16, drop_path=0.1)]
        elif type == 'ttt':
            blocks1, blocks2 = [block_t for _ in range(number)], [block_t for _ in range(number)]
        elif type == 'cct':
            blocks1, blocks2 = [block, block, block_t], [block, block, block_t]
        # block1 = [CBlock_ln(16), nn.Conv2d(16, 24, 3, 1, 1)]
        self.mul_blocks = nn.Sequential(*blocks1)
        self.add_blocks = nn.Sequential(*blocks2)

        self.mul_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
        self.add_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, img):
        img1 = self.relu(self.conv1(img))
        # short cut connection
        mul = self.mul_blocks(img1) + img1
        add = self.add_blocks(img1) + img1
        mul = self.mul_end(mul)
        add = self.add_end(add)
        return mul, add


class IAT(nn.Module):
    def __init__(self, in_dim=3, with_global=True, type='lol'):
        super(IAT, self).__init__()
        # self.local_net = Local_pred()
        self.local_net = Local_pred_S(in_dim=in_dim)
        self.with_global = with_global
        if self.with_global:
            self.global_net = Global_pred(in_channels=in_dim, type=type)

    def apply_color(self, image, ccm):
        # apply the 3x3 color correction matrix to an (H, W, 3) image
        shape = image.shape
        image = image.view(-1, 3)
        image = torch.tensordot(image, ccm, dims=[[-1], [-1]])
        image = image.view(shape)
        return torch.clamp(image, 1e-8, 1.0)

    def forward(self, img_low):
        # local branch: pixel-wise multiply-add enhancement
        mul, add = self.local_net(img_low)
        img_high = (img_low.mul(mul)).add(add)

        if not self.with_global:
            return img_high
        else:
            # global branch: per-image color matrix and gamma correction
            gamma, color = self.global_net(img_low)
            b = img_high.shape[0]
            img_high = img_high.permute(0, 2, 3, 1)  # (B,C,H,W) -> (B,H,W,C)
            img_high = torch.stack(
                [self.apply_color(img_high[i, :, :, :], color[i, :, :]) ** gamma[i, :] for i in range(b)], dim=0)
            img_high = img_high.permute(0, 3, 1, 2)  # (B,H,W,C) -> (B,C,H,W)
            return img_high


if __name__ == "__main__":
    img = torch.randn(1, 3, 640, 640)  # random test image (torch.Tensor(...) would be uninitialized memory)
    net = IAT()
    imghigh = net(img)
    print(imghigh.size())
    print('total parameters:', sum(param.numel() for param in net.parameters()))


4. Step-by-Step: Adding the IAT Low-Light Image Enhancement Network

4.1 Modification One

Step one, as always, is to set up the file. Go to the yolov9-main/models folder and create a directory named 'modules' (if you are using the files from my group it already exists, so there is no need to create it). Then create a new .py file inside it and paste the core code from Section 3 into it.

 


4.2 Modification Two

Step two: create a new .py file named '__init__.py' in the same directory (again, it already exists if you are using the files from my group), and then import our module inside it, as shown below.
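
A minimal sketch of that import, assuming the core code from Section 3 was saved as models/modules/IAT.py (the file name is an assumption; use whatever name you chose):

# models/modules/__init__.py
# (assumes the core code from Section 3 was saved as models/modules/IAT.py)
from .IAT import IAT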


4.3 Modification Three

Step three: open 'yolov9-main/models/yolo.py' and import and register our module there (if you are using the files from my group it is already imported, so you can skip straight to Modification Four), as sketched below.
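
A hedged sketch of the import; the exact placement is up to you, and it assumes the package layout from Modifications One and Two:

# In yolov9-main/models/yolo.py, near the existing imports
# (for example next to "from models.common import *"):
from models.modules import IAT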

From today onward the tutorials will all follow this format, because I assume everyone is working from the files shared in my group!!


4.4 Modification Four

Add the module inside parse_model following the example below.
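
Below is a minimal sketch of one workable parse_model branch; it is an assumption about how to register the module rather than the exact content of the files shared in my group. The key point is that IAT takes the raw 3-channel image and returns a 3-channel enhanced image, so its output channel count is fixed at 3 and it needs no channel arguments from the yaml.

# Inside parse_model() in models/yolo.py, next to the existing
# "elif m is ..." branches that set the output channel count c2:
elif m is IAT:
    c2 = 3   # IAT returns a 3-channel enhanced image; args stays [] from the yaml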

That completes the modifications; you can now copy the yaml file below and run it.


5. YAML File and Training Log

5.1 YAML File

The yaml file below places IAT as the very first layer of the network (before Silence), so the input image is enhanced before it reaches the backbone; note that all subsequent layer indices therefore shift by one relative to the stock yolov9.yaml.

# YOLOv9

# parameters
nc: 80  # number of classes
depth_multiple: 1  # model depth multiple
width_multiple: 1  # layer channel multiple
#activation: nn.LeakyReLU(0.1)
#activation: nn.ReLU()

# anchors
anchors: 3

# YOLOv9 backbone
backbone:
  [
   [-1, 1, IAT, []],  # 0 - low-light enhancement pre-processing
   [-1, 1, Silence, []],  # 1

   # conv down
   [-1, 1, Conv, [64, 3, 2]],  # 2-P1/2

   # conv down
   [-1, 1, Conv, [128, 3, 2]],  # 3-P2/4

   # elan-1 block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 4

   # conv down
   [-1, 1, Conv, [256, 3, 2]],  # 5-P3/8

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 6

   # conv down
   [-1, 1, Conv, [512, 3, 2]],  # 7-P4/16

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 8

   # conv down
   [-1, 1, Conv, [512, 3, 2]],  # 9-P5/32

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 10
  ]

# YOLOv9 head
head:
  [
   # elan-spp block
   [-1, 1, SPPELAN, [512, 256]],  # 11

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 8], 1, Concat, [1]],  # cat backbone P4

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 14

   # up-concat merge
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P3

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]],  # 17 (P3/8-small)

   # conv-down merge
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 20 (P4/16-medium)

   # conv-down merge
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 11], 1, Concat, [1]],  # cat head P5

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 23 (P5/32-large)

   # routing
   [6, 1, CBLinear, [[256]]],  # 24
   [8, 1, CBLinear, [[256, 512]]],  # 25
   [10, 1, CBLinear, [[256, 512, 512]]],  # 26

   # conv down
   [0, 1, Conv, [64, 3, 2]],  # 27-P1/2

   # conv down
   [-1, 1, Conv, [128, 3, 2]],  # 28-P2/4

   # elan-1 block
   [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 29

   # conv down fuse
   [-1, 1, Conv, [256, 3, 2]],  # 30-P3/8
   [[24, 25, 26, -1], 1, CBFuse, [[0, 0, 0]]],  # 31

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 32

   # conv down fuse
   [-1, 1, Conv, [512, 3, 2]],  # 33-P4/16
   [[25, 26, -1], 1, CBFuse, [[1, 1]]],  # 34

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 35

   # conv down fuse
   [-1, 1, Conv, [512, 3, 2]],  # 36-P5/32
   [[26, -1], 1, CBFuse, [[2]]],  # 37

   # elan-2 block
   [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 38

   # detect
   [[32, 35, 38, 17, 20, 23], 1, DualDDetect, [nc]],  # DualDDetect(A3, A4, A5, P3, P4, P5)
  ]
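
Once everything is registered you can sanity-check that the yaml parses and a forward pass runs. The snippet below is a sketch under two assumptions: the yaml above has been saved as models/detect/yolov9-iat.yaml inside yolov9-main, and the repo's usual Model/DetectionModel class in models/yolo.py is used to build it.

# Quick structural check, run from the yolov9-main root (path and file name are assumptions):
import torch
from models.yolo import Model

model = Model('models/detect/yolov9-iat.yaml', ch=3, nc=80)
out = model(torch.randn(1, 3, 640, 640))  # IAT enhances the image before it reaches the backbone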


5.2 Training Screenshot


6. Summary

That concludes the main content of this post. Here I would like to recommend my YOLOv9 effective-improvement column, which is newly opened and currently has an average quality score of 98. Going forward I will reproduce papers from the latest top conferences and also add some older improvement mechanisms. If this post helped you, please subscribe to the column and follow for more updates~

Column address: YOLOv9 Effective Improvement Column - continuously reproducing top-conference work for effective accuracy gains - the most complete YOLO improvement column on the web
