
[PyTorch] Vision Transformer for Image Classification + Visualization + Saving Training Data


I. Introduction to Vision Transformer

The core of the Transformer is the "self-attention" mechanism.

Paper: https://arxiv.org/pdf/2010.11929.pdf

Compared with convolutional and recurrent neural networks, self-attention offers both parallel computation and the shortest maximum path length, which makes it an attractive building block for deep architectures. Unlike earlier self-attention models that still relied on recurrent networks to construct the input representation [Cheng et al., 2016, Lin et al., 2017b, Paulus et al., 2017], the Transformer is based entirely on attention, with no convolutional or recurrent layers at all [Vaswani et al., 2017]. Although the Transformer was originally proposed for sequence-to-sequence learning on text, it has since spread across modern deep learning, including language, vision, speech, and reinforcement learning.
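As a quick refresher (illustration only, not part of the project code), scaled dot-product self-attention reduces to a few tensor operations. The sketch below uses a single head and omits the learned linear projections that normally produce the queries, keys, and values:

import torch

x = torch.randn(1, 196, 768)      # 196 tokens of dimension 768, e.g. the patches of one image
q = k = v = x                     # simplified: real attention first projects x into q, k, v
attn = torch.softmax(q @ k.transpose(-2, -1) / (768 ** 0.5), dim=-1)   # token-to-token weights
out = attn @ v                    # every token becomes a weighted mixture of all tokens
print(out.shape)                  # torch.Size([1, 196, 768])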

When the Transformer was released in 2017, it was mainly used for translation between languages. Later research found that, applied to computer vision, the Transformer matches the performance of convolutional neural networks and in some respects even exceeds it. This gave rise to the original Vision Transformer, or ViT for short.

So what is the difference between the Vision Transformer and the Transformer? Put as simply as possible, the Transformer's job is to translate a sentence from one language into another: the sentence is split into words or sub-units, these are encoded and decoded during training, and the candidate word with the highest score is taken as the translation.

The Vision Transformer treats an image as such a "sentence": the image is cut into patches, the patches are encoded and processed in the same way, and the class with the highest score is the prediction. (This is just my own intuitive reading; corrections are welcome.)
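To make the patch analogy concrete, here is a minimal sketch (separate from the project code; the PatchEmbed module below implements the same idea) of how a 224x224 image becomes a sequence of patch tokens: a 16x16 convolution with stride 16 yields a 14x14 grid of 768-dimensional vectors, i.e. 196 "words" for the Transformer encoder.

import torch
import torch.nn as nn

img = torch.randn(1, 3, 224, 224)                    # one RGB image
proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)  # one 768-dim vector per 16x16 patch
tokens = proj(img).flatten(2).transpose(1, 2)        # [1, 196, 768]
print(tokens.shape)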

II. Dataset

My dataset consists of plant-leaf disease images, organized only into class folders (no further annotations), with three classes in total:

{
    "0": "Huanglong_disease",
    "1": "Magnesium_deficiency",
    "2": "Normal"
}

The data are split into train : val : test = 8 : 1 : 1; every split contains the same three classes, only the number of images differs. (If you still need to produce such a split, see the sketch after the directory tree below.)

train
├── Huanglong_disease
│   ├── 000000.jpg
│   ├── 000001.jpg
│   ├── 000002.jpg
│   ├── ...
│   ├── 000607.jpg
├── Magnesium_deficiency
└── Normal
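If your images are not split yet, a sketch along the following lines produces such an 8 : 1 : 1 layout. The folder names datasets/raw and datasets are assumptions for illustration; adjust them to your own project:

import os
import random
import shutil

random.seed(0)
src_root = "datasets/raw"   # assumed: one sub-folder per class, not yet split
dst_root = "datasets"       # train/ val/ test/ will be created here
ratios = {"train": 0.8, "val": 0.1, "test": 0.1}

for cla in sorted(os.listdir(src_root)):
    files = sorted(os.listdir(os.path.join(src_root, cla)))
    random.shuffle(files)
    n = len(files)
    n_train, n_val = int(n * ratios["train"]), int(n * ratios["val"])
    splits = {"train": files[:n_train],
              "val": files[n_train:n_train + n_val],
              "test": files[n_train + n_val:]}
    for split, names in splits.items():
        out_dir = os.path.join(dst_root, split, cla)
        os.makedirs(out_dir, exist_ok=True)
        for name in names:
            shutil.copy(os.path.join(src_root, cla, name), os.path.join(out_dir, name))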


III. Code

1.vit_model.py

  1. """
  2. original code from rwightman:
  3. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  4. """
  5. from functools import partial
  6. from collections import OrderedDict
  7. import torch
  8. import torch.nn as nn
  9. def drop_path(x, drop_prob: float = 0., training: bool = False):
  10. """
  11. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  12. This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
  13. the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  14. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
  15. changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
  16. 'survival rate' as the argument.
  17. """
  18. if drop_prob == 0. or not training:
  19. return x
  20. keep_prob = 1 - drop_prob
  21. shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  22. random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
  23. random_tensor.floor_() # binarize
  24. output = x.div(keep_prob) * random_tensor
  25. return output
  26. class DropPath(nn.Module):
  27. """
  28. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  29. """
  30. def __init__(self, drop_prob=None):
  31. super(DropPath, self).__init__()
  32. self.drop_prob = drop_prob
  33. def forward(self, x):
  34. return drop_path(x, self.drop_prob, self.training)
  35. class PatchEmbed(nn.Module):
  36. """
  37. 2D Image to Patch Embedding
  38. """
  39. def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
  40. super().__init__()
  41. img_size = (img_size, img_size)
  42. patch_size = (patch_size, patch_size)
  43. self.img_size = img_size
  44. self.patch_size = patch_size
  45. self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
  46. self.num_patches = self.grid_size[0] * self.grid_size[1]
  47. self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
  48. self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
  49. def forward(self, x):
  50. B, C, H, W = x.shape
  51. assert H == self.img_size[0] and W == self.img_size[1], \
  52. f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
  53. # flatten: [B, C, H, W] -> [B, C, HW]
  54. # transpose: [B, C, HW] -> [B, HW, C]
  55. x = self.proj(x).flatten(2).transpose(1, 2)
  56. x = self.norm(x)
  57. return x
  58. class Attention(nn.Module):
  59. def __init__(self,
  60. dim, # 输入token的dim
  61. num_heads=8,
  62. qkv_bias=False,
  63. qk_scale=None,
  64. attn_drop_ratio=0.,
  65. proj_drop_ratio=0.):
  66. super(Attention, self).__init__()
  67. self.num_heads = num_heads
  68. head_dim = dim // num_heads
  69. self.scale = qk_scale or head_dim ** -0.5
  70. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
  71. self.attn_drop = nn.Dropout(attn_drop_ratio)
  72. self.proj = nn.Linear(dim, dim)
  73. self.proj_drop = nn.Dropout(proj_drop_ratio)
  74. def forward(self, x):
  75. # [batch_size, num_patches + 1, total_embed_dim]
  76. B, N, C = x.shape
  77. # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
  78. # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
  79. # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
  80. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  81. # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
  82. q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
  83. # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
  84. # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
  85. attn = (q @ k.transpose(-2, -1)) * self.scale
  86. attn = attn.softmax(dim=-1)
  87. attn = self.attn_drop(attn)
  88. # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
  89. # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
  90. # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
  91. x = (attn @ v).transpose(1, 2).reshape(B, N, C)
  92. x = self.proj(x)
  93. x = self.proj_drop(x)
  94. return x
  95. class Mlp(nn.Module):
  96. """
  97. MLP as used in Vision Transformer, MLP-Mixer and related networks
  98. """
  99. def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
  100. super().__init__()
  101. out_features = out_features or in_features
  102. hidden_features = hidden_features or in_features
  103. self.fc1 = nn.Linear(in_features, hidden_features)
  104. self.act = act_layer()
  105. self.fc2 = nn.Linear(hidden_features, out_features)
  106. self.drop = nn.Dropout(drop)
  107. def forward(self, x):
  108. x = self.fc1(x)
  109. x = self.act(x)
  110. x = self.drop(x)
  111. x = self.fc2(x)
  112. x = self.drop(x)
  113. return x
  114. class Block(nn.Module):
  115. def __init__(self,
  116. dim,
  117. num_heads,
  118. mlp_ratio=4.,
  119. qkv_bias=False,
  120. qk_scale=None,
  121. drop_ratio=0.,
  122. attn_drop_ratio=0.,
  123. drop_path_ratio=0.,
  124. act_layer=nn.GELU,
  125. norm_layer=nn.LayerNorm):
  126. super(Block, self).__init__()
  127. self.norm1 = norm_layer(dim)
  128. self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
  129. attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
  130. # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
  131. self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
  132. self.norm2 = norm_layer(dim)
  133. mlp_hidden_dim = int(dim * mlp_ratio)
  134. self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)
  135. def forward(self, x):
  136. x = x + self.drop_path(self.attn(self.norm1(x)))
  137. x = x + self.drop_path(self.mlp(self.norm2(x)))
  138. return x
  139. class VisionTransformer(nn.Module):
  140. def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
  141. embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
  142. qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
  143. attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
  144. act_layer=None):
  145. """
  146. Args:
  147. img_size (int, tuple): input image size
  148. patch_size (int, tuple): patch size
  149. in_c (int): number of input channels
  150. num_classes (int): number of classes for classification head
  151. embed_dim (int): embedding dimension
  152. depth (int): depth of transformer
  153. num_heads (int): number of attention heads
  154. mlp_ratio (int): ratio of mlp hidden dim to embedding dim
  155. qkv_bias (bool): enable bias for qkv if True
  156. qk_scale (float): override default qk scale of head_dim ** -0.5 if set
  157. representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
  158. distilled (bool): model includes a distillation token and head as in DeiT models
  159. drop_ratio (float): dropout rate
  160. attn_drop_ratio (float): attention dropout rate
  161. drop_path_ratio (float): stochastic depth rate
  162. embed_layer (nn.Module): patch embedding layer
  163. norm_layer: (nn.Module): normalization layer
  164. """
  165. super(VisionTransformer, self).__init__()
  166. self.num_classes = num_classes
  167. self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
  168. self.num_tokens = 2 if distilled else 1
  169. norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
  170. act_layer = act_layer or nn.GELU
  171. self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
  172. num_patches = self.patch_embed.num_patches
  173. self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
  174. self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
  175. self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
  176. self.pos_drop = nn.Dropout(p=drop_ratio)
  177. dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] # stochastic depth decay rule
  178. self.blocks = nn.Sequential(*[
  179. Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
  180. drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
  181. norm_layer=norm_layer, act_layer=act_layer)
  182. for i in range(depth)
  183. ])
  184. self.norm = norm_layer(embed_dim)
  185. # Representation layer
  186. if representation_size and not distilled:
  187. self.has_logits = True
  188. self.num_features = representation_size
  189. self.pre_logits = nn.Sequential(OrderedDict([
  190. ("fc", nn.Linear(embed_dim, representation_size)),
  191. ("act", nn.Tanh())
  192. ]))
  193. else:
  194. self.has_logits = False
  195. self.pre_logits = nn.Identity()
  196. # Classifier head(s)
  197. self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  198. self.head_dist = None
  199. if distilled:
  200. self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
  201. # Weight init
  202. nn.init.trunc_normal_(self.pos_embed, std=0.02)
  203. if self.dist_token is not None:
  204. nn.init.trunc_normal_(self.dist_token, std=0.02)
  205. nn.init.trunc_normal_(self.cls_token, std=0.02)
  206. self.apply(_init_vit_weights)
  207. def forward_features(self, x):
  208. # [B, C, H, W] -> [B, num_patches, embed_dim]
  209. x = self.patch_embed(x) # [B, 196, 768]
  210. # [1, 1, 768] -> [B, 1, 768]
  211. cls_token = self.cls_token.expand(x.shape[0], -1, -1)
  212. if self.dist_token is None:
  213. x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
  214. else:
  215. x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
  216. x = self.pos_drop(x + self.pos_embed)
  217. x = self.blocks(x)
  218. x = self.norm(x)
  219. if self.dist_token is None:
  220. return self.pre_logits(x[:, 0])
  221. else:
  222. return x[:, 0], x[:, 1]
  223. def forward(self, x):
  224. x = self.forward_features(x)
  225. if self.head_dist is not None:
  226. x, x_dist = self.head(x[0]), self.head_dist(x[1])
  227. if self.training and not torch.jit.is_scripting():
  228. # during inference, return the average of both classifier predictions
  229. return x, x_dist
  230. else:
  231. return (x + x_dist) / 2
  232. else:
  233. x = self.head(x)
  234. return x
  235. def _init_vit_weights(m):
  236. """
  237. ViT weight initialization
  238. :param m: module
  239. """
  240. if isinstance(m, nn.Linear):
  241. nn.init.trunc_normal_(m.weight, std=.01)
  242. if m.bias is not None:
  243. nn.init.zeros_(m.bias)
  244. elif isinstance(m, nn.Conv2d):
  245. nn.init.kaiming_normal_(m.weight, mode="fan_out")
  246. if m.bias is not None:
  247. nn.init.zeros_(m.bias)
  248. elif isinstance(m, nn.LayerNorm):
  249. nn.init.zeros_(m.bias)
  250. nn.init.ones_(m.weight)
  251. def vit_base_patch16_224(num_classes: int = 1000):
  252. """
  253. ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
  254. ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
  255. weights ported from official Google JAX impl:
  256. 链接: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA 密码: eu9f
  257. """
  258. model = VisionTransformer(img_size=224,
  259. patch_size=16,
  260. embed_dim=768,
  261. depth=12,
  262. num_heads=12,
  263. representation_size=None,
  264. num_classes=num_classes)
  265. return model
  266. def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
  267. """
  268. ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
  269. ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
  270. weights ported from official Google JAX impl:
  271. https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
  272. """
  273. model = VisionTransformer(img_size=224,
  274. patch_size=16,
  275. embed_dim=768,
  276. depth=12,
  277. num_heads=12,
  278. representation_size=768 if has_logits else None,
  279. num_classes=num_classes)
  280. return model
  281. def vit_base_patch32_224(num_classes: int = 1000):
  282. """
  283. ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
  284. ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
  285. weights ported from official Google JAX impl:
  286. 链接: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg 密码: s5hl
  287. """
  288. model = VisionTransformer(img_size=224,
  289. patch_size=32,
  290. embed_dim=768,
  291. depth=12,
  292. num_heads=12,
  293. representation_size=None,
  294. num_classes=num_classes)
  295. return model
  296. def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
  297. """
  298. ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
  299. ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
  300. weights ported from official Google JAX impl:
  301. https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth
  302. """
  303. model = VisionTransformer(img_size=224,
  304. patch_size=32,
  305. embed_dim=768,
  306. depth=12,
  307. num_heads=12,
  308. representation_size=768 if has_logits else None,
  309. num_classes=num_classes)
  310. return model
  311. def vit_large_patch16_224(num_classes: int = 1000):
  312. """
  313. ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
  314. ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
  315. weights ported from official Google JAX impl:
  316. 链接: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ 密码: qqt8
  317. """
  318. model = VisionTransformer(img_size=224,
  319. patch_size=16,
  320. embed_dim=1024,
  321. depth=24,
  322. num_heads=16,
  323. representation_size=None,
  324. num_classes=num_classes)
  325. return model
  326. def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
  327. """
  328. ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
  329. ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
  330. weights ported from official Google JAX impl:
  331. https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth
  332. """
  333. model = VisionTransformer(img_size=224,
  334. patch_size=16,
  335. embed_dim=1024,
  336. depth=24,
  337. num_heads=16,
  338. representation_size=1024 if has_logits else None,
  339. num_classes=num_classes)
  340. return model
  341. def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
  342. """
  343. ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
  344. ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
  345. weights ported from official Google JAX impl:
  346. https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth
  347. """
  348. model = VisionTransformer(img_size=224,
  349. patch_size=32,
  350. embed_dim=1024,
  351. depth=24,
  352. num_heads=16,
  353. representation_size=1024 if has_logits else None,
  354. num_classes=num_classes)
  355. return model
  356. def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):
  357. """
  358. ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
  359. ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
  360. NOTE: converted weights not currently available, too large for github release hosting.
  361. """
  362. model = VisionTransformer(img_size=224,
  363. patch_size=14,
  364. embed_dim=1280,
  365. depth=32,
  366. num_heads=16,
  367. representation_size=1280 if has_logits else None,
  368. num_classes=num_classes)
  369. return model
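As a quick sanity check (not one of the project files), you can instantiate the model and confirm that the output has one logit per class; the three-class setting mirrors my dataset:

import torch
from vit_model import vit_base_patch16_224_in21k as create_model

model = create_model(num_classes=3, has_logits=False)
x = torch.randn(2, 3, 224, 224)   # a dummy batch of two 224x224 RGB images
with torch.no_grad():
    out = model(x)
print(out.shape)                  # torch.Size([2, 3]) -> one logit per class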

2.utils.py

import os
import sys
import json
import pickle
import random

import torch
from tqdm import tqdm
import matplotlib.pyplot as plt


def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # make the random split reproducible
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # walk the root directory; each sub-folder corresponds to one class
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # sort so the class order is always the same
    flower_class.sort()
    # build the mapping from class name to numeric index
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []   # paths of all training images
    train_images_label = []  # class index of each training image
    val_images_path = []     # paths of all validation images
    val_images_label = []    # class index of each validation image
    every_class_num = []     # number of samples per class
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # supported file extensions
    # iterate over the files in each class folder
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # collect the paths of all supported files
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # index of this class
        image_class = class_indices[cla]
        # record how many samples this class has
        every_class_num.append(len(images))
        # randomly sample the validation subset at the given ratio
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # paths sampled for validation go to the validation set
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # everything else goes to the training set
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    plot_image = False
    if plot_image:
        # bar chart of the number of samples per class
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # replace the x ticks 0,1,2,... with the class names
        plt.xticks(range(len(flower_class)), flower_class)
        # add the count on top of each bar
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        plt.xlabel('image class')
        plt.ylabel('number of images')
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label


def plot_data_loader_image(data_loader):
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    json_file = open(json_path, 'r')
    class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        for i in range(plot_num):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # undo the Normalize transform
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i + 1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # remove x-axis ticks
            plt.yticks([])  # remove y-axis ticks
            plt.imshow(img.astype('uint8'))
        plt.show()


def write_pickle(list_info: list, file_name: str):
    with open(file_name, 'wb') as f:
        pickle.dump(list_info, f)


def read_pickle(file_name: str) -> list:
    with open(file_name, 'rb') as f:
        info_list = pickle.load(f)
    return info_list


def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    accu_loss = torch.zeros(1).to(device)  # accumulated loss
    accu_num = torch.zeros(1).to(device)   # accumulated number of correct predictions
    optimizer.zero_grad()

    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        loss.backward()
        accu_loss += loss.detach()

        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / (step + 1),
                                                                               accu_num.item() / sample_num)

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num


@torch.no_grad()
def evaluate(model, data_loader, device, epoch):
    loss_function = torch.nn.CrossEntropyLoss()

    model.eval()

    accu_num = torch.zeros(1).to(device)   # accumulated number of correct predictions
    accu_loss = torch.zeros(1).to(device)  # accumulated loss

    sample_num = 0
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]

        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()

        loss = loss_function(pred, labels.to(device))
        accu_loss += loss

        data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
                                                                               accu_loss.item() / (step + 1),
                                                                               accu_num.item() / sample_num)

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num
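To try the split helper on its own, a minimal sketch looks like this (the dataset root is just an example path, use your own):

from utils import read_split_data

root = r"D:\pyCharmdata\resnet50_plant_3\datasets\train"   # example path
train_paths, train_labels, val_paths, val_labels = read_split_data(root, val_rate=0.2)
print(len(train_paths), len(val_paths))   # number of training / validation images
print(train_paths[0], train_labels[0])    # first image path and its class index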

3.my_dataset.py

from PIL import Image
import torch
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    """Custom dataset"""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # 'RGB' means a color image, 'L' a grayscale image
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # the official default_collate implementation can be found at
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))
        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
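The custom collate_fn simply stacks the individual (image, label) pairs returned by __getitem__ into batch tensors. A self-contained sketch of plugging MyDataSet into a DataLoader (it writes two dummy images first so it can run anywhere):

import torch
from PIL import Image
from torchvision import transforms
from my_dataset import MyDataSet

# create two tiny dummy images so the example is self-contained
for i in range(2):
    Image.new("RGB", (256, 256), color=(i * 100, 50, 50)).save("dummy_{}.jpg".format(i))

dataset = MyDataSet(images_path=["dummy_0.jpg", "dummy_1.jpg"],
                    images_class=[0, 2],
                    transform=transforms.Compose([transforms.Resize((224, 224)),
                                                  transforms.ToTensor()]))
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=MyDataSet.collate_fn)

for images, labels in loader:
    print(images.shape, labels)   # torch.Size([2, 3, 224, 224]) tensor([0, 2])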

4.train.py

If you want to use the pre-trained model, you need to download it first; the download links are given in the docstrings of vit_model.py, and the code loads pre-trained weights by default. After downloading, put the weight file in the project root. My dataset has three classes, so I set the output of the network's fully connected head to 3; adjust this to match your own dataset.

If the download is inconvenient, you can also use the copy I uploaded:

vit_base_patch16_224_in21k.zip (CSDN deep-learning resource download)

import os
import math
import argparse

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import xlwt

from my_dataset import MyDataSet
from vit_model import vit_base_patch16_224_in21k as create_model
from utils import read_split_data, train_one_epoch, evaluate


book = xlwt.Workbook(encoding='utf-8')  # create the workbook, i.e. the Excel file
# create a sheet named 'Train_data'; cell_overwrite_ok allows cells to be rewritten
sheet1 = book.add_sheet(u'Train_data', cell_overwrite_ok=True)
# write the header row
sheet1.write(0, 0, 'epoch')
sheet1.write(0, 1, 'Train_Loss')
sheet1.write(0, 2, 'Train_Acc')
sheet1.write(0, 3, 'Val_Loss')
sheet1.write(0, 4, 'Val_Acc')
sheet1.write(0, 5, 'lr')
sheet1.write(0, 6, 'Best val Acc')


def main(args):
    best_acc = 0
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    tb_writer = SummaryWriter()

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}

    # instantiate the training dataset
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # instantiate the validation dataset
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    model = create_model(num_classes=3, has_logits=False).to(device)

    images = torch.zeros(1, 3, 224, 224).to(device)  # must match the size of the input images
    tb_writer.add_graph(model, images, verbose=False)

    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
        weights_dict = torch.load(args.weights, map_location=device)
        # drop the weights we do not need (the heads are re-initialized for our own classes)
        del_keys = ['head.weight', 'head.bias'] if model.has_logits \
            else ['pre_logits.fc.weight', 'pre_logits.fc.bias', 'head.weight', 'head.bias']
        for k in del_keys:
            del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))

    if args.freeze_layers:
        for name, para in model.named_parameters():
            # freeze everything except head and pre_logits
            if "head" not in name and "pre_logits" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=5E-5)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        sheet1.write(epoch + 1, 0, epoch + 1)
        sheet1.write(epoch + 1, 5, str(optimizer.state_dict()['param_groups'][0]['lr']))

        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)

        scheduler.step()
        sheet1.write(epoch + 1, 1, str(train_loss))
        sheet1.write(epoch + 1, 2, str(train_acc))

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)
        sheet1.write(epoch + 1, 3, str(val_loss))
        sheet1.write(epoch + 1, 4, str(val_acc))

        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "./weights/best_model.pth")
        # torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))

        sheet1.write(1, 6, str(best_acc))
        book.save('Train_data.xls')  # xlwt can only write the legacy .xls format

    print("The Best Acc = : {:.4f}".format(best_acc))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=3)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--lrf', type=float, default=0.01)
    # root directory of the dataset
    parser.add_argument('--data-path', type=str,
                        default=r"D:\pyCharmdata\resnet50_plant_3\datasets\train")
    parser.add_argument('--model-name', default='', help='create model name')
    # path of the pre-trained weights; set it to an empty string if you do not want to load them
    parser.add_argument('--weights', type=str, default='./vit_base_patch16_224_in21k.pth',
                        help='initial weights path')
    # whether to freeze the backbone weights
    parser.add_argument('--freeze-layers', type=bool, default=False)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()
    main(opt)

5.predict.py

It predicts the class of a single image; the class with the highest score is the model's prediction.

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from vit_model import vit_base_patch16_224_in21k as create_model


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    # load image
    img_path = r"D:\pyCharmdata\resnet50_plant_3\datasets\test\Huanglong_disease\000000.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = create_model(num_classes=3, has_logits=False).to(device)
    # load model weights
    model_weight_path = "./weights/best_model.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
                                               predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
                                                predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()
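predict.py handles one image at a time; if you want to sweep a whole folder, a minimal sketch (reusing the same transform and weights; the folder path is just an example) could look like this:

import os
import json
import torch
from PIL import Image
from torchvision import transforms
from vit_model import vit_base_patch16_224_in21k as create_model

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tf = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224),
                         transforms.ToTensor(),
                         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
class_indict = json.load(open('./class_indices.json'))

model = create_model(num_classes=3, has_logits=False).to(device)
model.load_state_dict(torch.load("./weights/best_model.pth", map_location=device))
model.eval()

test_dir = r"D:\pyCharmdata\resnet50_plant_3\datasets\test\Huanglong_disease"  # example folder
with torch.no_grad():
    for name in sorted(os.listdir(test_dir)):
        img = tf(Image.open(os.path.join(test_dir, name)).convert('RGB')).unsqueeze(0)
        probs = torch.softmax(model(img.to(device)).squeeze(0).cpu(), dim=0)
        cls = torch.argmax(probs).item()
        print("{}: {} ({:.3f})".format(name, class_indict[str(cls)], probs[cls].item()))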

Prediction result:

IV. Training

Once the environment, the dataset, and the path to the pre-trained model are configured, run train.py to start training; by default it trains for 100 epochs.

Training uses the SGD optimizer with momentum, an initial learning rate of 0.001, and a custom cosine learning-rate schedule built with LambdaLR; the pre-trained weights are loaded, but no layers or parameters are frozen.
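For reference, this is what the LambdaLR cosine rule in train.py does to the learning rate over the run; the sketch below just evaluates the same formula standalone with the default epochs=100, lr=0.001, lrf=0.01:

import math

epochs, lr, lrf = 100, 0.001, 0.01
# same lambda as train.py: cosine decay from lr down to lr * lrf
lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf

for epoch in [0, 25, 50, 75, 99]:
    print("epoch {:3d}: lr = {:.6f}".format(epoch, lr * lf(epoch)))
# sample output:
# epoch   0: lr = 0.001000
# epoch  50: lr = 0.000505
# epoch  99: lr = 0.000010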

While training is running, you can open a terminal in the project directory and enter:

tensorboard --logdir=runs/

to monitor training progress in real time; you can also inspect the visualized network structure of the Vision Transformer there.

Network graph of the Vision Transformer:

After a simple 100-epoch run, the best val_acc I reached was 0.9976.

When training finishes, an Excel file recording the whole training run is written to the project root; you can also load it in Matlab to produce highly customized comparison plots, which is a real gift for anyone writing a paper.
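If you prefer to stay in Python rather than Matlab, a minimal plotting sketch could look like the following; it assumes the workbook written by train.py (Train_data.xls) and needs pandas plus the xlrd engine, since xlwt produces the legacy .xls format:

import pandas as pd
import matplotlib.pyplot as plt

# the values were written as strings, so cast them back to float before plotting
df = pd.read_excel("Train_data.xls", sheet_name="Train_data", engine="xlrd")

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.plot(df["epoch"], df["Train_Loss"].astype(float), label="train loss")
plt.plot(df["epoch"], df["Val_Loss"].astype(float), label="val loss")
plt.xlabel("epoch")
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(df["epoch"], df["Train_Acc"].astype(float), label="train acc")
plt.plot(df["epoch"], df["Val_Acc"].astype(float), label="val acc")
plt.xlabel("epoch")
plt.legend()
plt.tight_layout()
plt.show()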

Only the first 10 epochs of the training log are shown here.

My complete project is available here if you need it:

Vit_myself.zip (CSDN deep-learning resource download)

 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If this article helped you, a like, favorite, and follow would be much appreciated!
