赞
踩
Transformer的核心是 “自注意力” 机制。
论文地址:https://arxiv.org/pdf/2010.11929.pdf
自注意力(self-attention)相比 卷积神经网络 和 循环神经网络 同时具有并行计算和最短的最大路径⻓度这两个优势。因此,使用自注意力来设计深度架构是很有吸引力的。对比之前仍然依赖循环神经网络实现输入表示的自注意力模型 [Cheng et al., 2016,Lin et al., 2017b, Paulus et al., 2017],transformer模型完全基于注意力机制,没有任何卷积层或循环神经网络层 [Vaswani et al., 2017]。尽管transformer最初是应用于在文本数据上的序列到序列学习,但现在已经推广到各种现代的深度学习中,例如语言、视觉、语音和强化学习领域。
17年发布时主要应用于不同语言之间翻译功能的实现。而在后来,有关研究发现Transformer应用于计算机视觉CV方面有着不输于卷积神经网络的强劲性能,一定程度上甚至比卷积神经网络更强。于是,初代Vision Transformer诞生了, 简称Vit。
Vision Transformer和Transformer区别是什么?用最最最简单的理解方式来看,Transformer的工作就是把一句话从一种语言翻译成另一种语言。主要是通过是将待翻译的一句话拆分为 多个单词 或者 多个模块,进行编码和解码训练,再评估那个单词对应的意思得分高就是相应的翻译结果。
而Vision Transformer则是将一个图片抽象地看做翻译中一个句子,通过图像分割将其拆分为多个模块,再进行编码和解码训练,评估中得分高的选项便是预测的结果。(纯属个人理解,如有错误,欢迎批评指正)
我的数据集为植物叶片病害的无标注数据集,共有三种类型。
- {
- "0": "Huanglong_disease",
- "1": "Magnesium_deficiency",
- "2": "Normal"
- }
其中train : val : test = 8 : 1 : 1,种类都是三种,只是数量不一样。
- train
- ├── Huanglong_disease
- │ ├── 000000.jpg
- │ ├── 000001.jpg
- │ ├── 000002.jpg
- │ ├── .............
- │ ├── 000607.jpg
- ├── Magnesium_deficiency
- └── Normal
大概长这样:
- """
- original code from rwightman:
- https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
- """
- from functools import partial
- from collections import OrderedDict
-
- import torch
- import torch.nn as nn
-
-
- def drop_path(x, drop_prob: float = 0., training: bool = False):
- """
- Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
- This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
- the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
- See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
- changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
- 'survival rate' as the argument.
- """
- if drop_prob == 0. or not training:
- return x
- keep_prob = 1 - drop_prob
- shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
- random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
- random_tensor.floor_() # binarize
- output = x.div(keep_prob) * random_tensor
- return output
-
-
- class DropPath(nn.Module):
- """
- Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
- """
- def __init__(self, drop_prob=None):
- super(DropPath, self).__init__()
- self.drop_prob = drop_prob
-
- def forward(self, x):
- return drop_path(x, self.drop_prob, self.training)
-
-
- class PatchEmbed(nn.Module):
- """
- 2D Image to Patch Embedding
- """
- def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
- super().__init__()
- img_size = (img_size, img_size)
- patch_size = (patch_size, patch_size)
- self.img_size = img_size
- self.patch_size = patch_size
- self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
- self.num_patches = self.grid_size[0] * self.grid_size[1]
-
- self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
-
- def forward(self, x):
- B, C, H, W = x.shape
- assert H == self.img_size[0] and W == self.img_size[1], \
- f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
-
- # flatten: [B, C, H, W] -> [B, C, HW]
- # transpose: [B, C, HW] -> [B, HW, C]
- x = self.proj(x).flatten(2).transpose(1, 2)
- x = self.norm(x)
- return x
-
-
- class Attention(nn.Module):
- def __init__(self,
- dim, # 输入token的dim
- num_heads=8,
- qkv_bias=False,
- qk_scale=None,
- attn_drop_ratio=0.,
- proj_drop_ratio=0.):
- super(Attention, self).__init__()
- self.num_heads = num_heads
- head_dim = dim // num_heads
- self.scale = qk_scale or head_dim ** -0.5
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.attn_drop = nn.Dropout(attn_drop_ratio)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop_ratio)
-
- def forward(self, x):
- # [batch_size, num_patches + 1, total_embed_dim]
- B, N, C = x.shape
-
- # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
- # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
- # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
- # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
-
- # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
- # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
- attn = (q @ k.transpose(-2, -1)) * self.scale
- attn = attn.softmax(dim=-1)
- attn = self.attn_drop(attn)
-
- # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
- # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
- # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
- class Mlp(nn.Module):
- """
- MLP as used in Vision Transformer, MLP-Mixer and related networks
- """
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
- class Block(nn.Module):
- def __init__(self,
- dim,
- num_heads,
- mlp_ratio=4.,
- qkv_bias=False,
- qk_scale=None,
- drop_ratio=0.,
- attn_drop_ratio=0.,
- drop_path_ratio=0.,
- act_layer=nn.GELU,
- norm_layer=nn.LayerNorm):
- super(Block, self).__init__()
- self.norm1 = norm_layer(dim)
- self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
- attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
- # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
- self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)
-
- def forward(self, x):
- x = x + self.drop_path(self.attn(self.norm1(x)))
- x = x + self.drop_path(self.mlp(self.norm2(x)))
- return x
-
-
- class VisionTransformer(nn.Module):
- def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
- embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
- qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
- attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
- act_layer=None):
- """
- Args:
- img_size (int, tuple): input image size
- patch_size (int, tuple): patch size
- in_c (int): number of input channels
- num_classes (int): number of classes for classification head
- embed_dim (int): embedding dimension
- depth (int): depth of transformer
- num_heads (int): number of attention heads
- mlp_ratio (int): ratio of mlp hidden dim to embedding dim
- qkv_bias (bool): enable bias for qkv if True
- qk_scale (float): override default qk scale of head_dim ** -0.5 if set
- representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
- distilled (bool): model includes a distillation token and head as in DeiT models
- drop_ratio (float): dropout rate
- attn_drop_ratio (float): attention dropout rate
- drop_path_ratio (float): stochastic depth rate
- embed_layer (nn.Module): patch embedding layer
- norm_layer: (nn.Module): normalization layer
- """
- super(VisionTransformer, self).__init__()
- self.num_classes = num_classes
- self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
- self.num_tokens = 2 if distilled else 1
- norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
- act_layer = act_layer or nn.GELU
-
- self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
- num_patches = self.patch_embed.num_patches
-
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
- self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
- self.pos_drop = nn.Dropout(p=drop_ratio)
-
- dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] # stochastic depth decay rule
- self.blocks = nn.Sequential(*[
- Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
- norm_layer=norm_layer, act_layer=act_layer)
- for i in range(depth)
- ])
- self.norm = norm_layer(embed_dim)
-
- # Representation layer
- if representation_size and not distilled:
- self.has_logits = True
- self.num_features = representation_size
- self.pre_logits = nn.Sequential(OrderedDict([
- ("fc", nn.Linear(embed_dim, representation_size)),
- ("act", nn.Tanh())
- ]))
- else:
- self.has_logits = False
- self.pre_logits = nn.Identity()
-
- # Classifier head(s)
- self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
- self.head_dist = None
- if distilled:
- self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
-
- # Weight init
- nn.init.trunc_normal_(self.pos_embed, std=0.02)
- if self.dist_token is not None:
- nn.init.trunc_normal_(self.dist_token, std=0.02)
-
- nn.init.trunc_normal_(self.cls_token, std=0.02)
- self.apply(_init_vit_weights)
-
- def forward_features(self, x):
- # [B, C, H, W] -> [B, num_patches, embed_dim]
- x = self.patch_embed(x) # [B, 196, 768]
- # [1, 1, 768] -> [B, 1, 768]
- cls_token = self.cls_token.expand(x.shape[0], -1, -1)
- if self.dist_token is None:
- x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
- else:
- x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
-
- x = self.pos_drop(x + self.pos_embed)
- x = self.blocks(x)
- x = self.norm(x)
- if self.dist_token is None:
- return self.pre_logits(x[:, 0])
- else:
- return x[:, 0], x[:, 1]
-
- def forward(self, x):
- x = self.forward_features(x)
- if self.head_dist is not None:
- x, x_dist = self.head(x[0]), self.head_dist(x[1])
- if self.training and not torch.jit.is_scripting():
- # during inference, return the average of both classifier predictions
- return x, x_dist
- else:
- return (x + x_dist) / 2
- else:
- x = self.head(x)
- return x
-
-
- def _init_vit_weights(m):
- """
- ViT weight initialization
- :param m: module
- """
- if isinstance(m, nn.Linear):
- nn.init.trunc_normal_(m.weight, std=.01)
- if m.bias is not None:
- nn.init.zeros_(m.bias)
- elif isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight, mode="fan_out")
- if m.bias is not None:
- nn.init.zeros_(m.bias)
- elif isinstance(m, nn.LayerNorm):
- nn.init.zeros_(m.bias)
- nn.init.ones_(m.weight)
-
-
- def vit_base_patch16_224(num_classes: int = 1000):
- """
- ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
- ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
- weights ported from official Google JAX impl:
- 链接: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA 密码: eu9f
- """
- model = VisionTransformer(img_size=224,
- patch_size=16,
- embed_dim=768,
- depth=12,
- num_heads=12,
- representation_size=None,
- num_classes=num_classes)
- return model
-
-
- def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
- """
- ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
- ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
- weights ported from official Google JAX impl:
- https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
- """
- model = VisionTransformer(img_size=224,
- patch_size=16,
- embed_dim=768,
- depth=12,
- num_heads=12,
- representation_size=768 if has_logits else None,
- num_classes=num_classes)
- return model
-
-
- def vit_base_patch32_224(num_classes: int = 1000):
- """
- ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
- ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
- weights ported from official Google JAX impl:
- 链接: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg 密码: s5hl
- """
- model = VisionTransformer(img_size=224,
- patch_size=32,
- embed_dim=768,
- depth=12,
- num_heads=12,
- representation_size=None,
- num_classes=num_classes)
- return model
-
-
- def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
- """
- ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
- ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
- weights ported from official Google JAX impl:
- https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth
- """
- model = VisionTransformer(img_size=224,
- patch_size=32,
- embed_dim=768,
- depth=12,
- num_heads=12,
- representation_size=768 if has_logits else None,
- num_classes=num_classes)
- return model
-
-
- def vit_large_patch16_224(num_classes: int = 1000):
- """
- ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
- ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
- weights ported from official Google JAX impl:
- 链接: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ 密码: qqt8
- """
- model = VisionTransformer(img_size=224,
- patch_size=16,
- embed_dim=1024,
- depth=24,
- num_heads=16,
- representation_size=None,
- num_classes=num_classes)
- return model
-
-
- def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
- """
- ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
- ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
- weights ported from official Google JAX impl:
- https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth
- """
- model = VisionTransformer(img_size=224,
- patch_size=16,
- embed_dim=1024,
- depth=24,
- num_heads=16,
- representation_size=1024 if has_logits else None,
- num_classes=num_classes)
- return model
-
-
- def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
- """
- ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
- ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
- weights ported from official Google JAX impl:
- https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth
- """
- model = VisionTransformer(img_size=224,
- patch_size=32,
- embed_dim=1024,
- depth=24,
- num_heads=16,
- representation_size=1024 if has_logits else None,
- num_classes=num_classes)
- return model
-
-
- def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):
- """
- ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
- ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
- NOTE: converted weights not currently available, too large for github release hosting.
- """
- model = VisionTransformer(img_size=224,
- patch_size=14,
- embed_dim=1280,
- depth=32,
- num_heads=16,
- representation_size=1280 if has_logits else None,
- num_classes=num_classes)
- return model
- import os
- import sys
- import json
- import pickle
- import random
-
- import torch
- from tqdm import tqdm
-
- import matplotlib.pyplot as plt
-
-
- def read_split_data(root: str, val_rate: float = 0.2):
- random.seed(0) # 保证随机结果可复现
- assert os.path.exists(root), "dataset root: {} does not exist.".format(root)
-
- # 遍历文件夹,一个文件夹对应一个类别
- flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
- # 排序,保证顺序一致
- flower_class.sort()
- # 生成类别名称以及对应的数字索引
- class_indices = dict((k, v) for v, k in enumerate(flower_class))
- json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
- with open('class_indices.json', 'w') as json_file:
- json_file.write(json_str)
-
- train_images_path = [] # 存储训练集的所有图片路径
- train_images_label = [] # 存储训练集图片对应索引信息
- val_images_path = [] # 存储验证集的所有图片路径
- val_images_label = [] # 存储验证集图片对应索引信息
- every_class_num = [] # 存储每个类别的样本总数
- supported = [".jpg", ".JPG", ".png", ".PNG"] # 支持的文件后缀类型
- # 遍历每个文件夹下的文件
- for cla in flower_class:
- cla_path = os.path.join(root, cla)
- # 遍历获取supported支持的所有文件路径
- images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
- if os.path.splitext(i)[-1] in supported]
- # 获取该类别对应的索引
- image_class = class_indices[cla]
- # 记录该类别的样本数量
- every_class_num.append(len(images))
- # 按比例随机采样验证样本
- val_path = random.sample(images, k=int(len(images) * val_rate))
-
- for img_path in images:
- if img_path in val_path: # 如果该路径在采样的验证集样本中则存入验证集
- val_images_path.append(img_path)
- val_images_label.append(image_class)
- else: # 否则存入训练集
- train_images_path.append(img_path)
- train_images_label.append(image_class)
-
- print("{} images were found in the dataset.".format(sum(every_class_num)))
- print("{} images for training.".format(len(train_images_path)))
- print("{} images for validation.".format(len(val_images_path)))
-
- plot_image = False
- if plot_image:
- # 绘制每种类别个数柱状图
- plt.bar(range(len(flower_class)), every_class_num, align='center')
- # 将横坐标0,1,2,3,4替换为相应的类别名称
- plt.xticks(range(len(flower_class)), flower_class)
- # 在柱状图上添加数值标签
- for i, v in enumerate(every_class_num):
- plt.text(x=i, y=v + 5, s=str(v), ha='center')
- # 设置x坐标
- plt.xlabel('image class')
- # 设置y坐标
- plt.ylabel('number of images')
- # 设置柱状图的标题
- plt.title('flower class distribution')
- plt.show()
-
- return train_images_path, train_images_label, val_images_path, val_images_label
-
-
- def plot_data_loader_image(data_loader):
- batch_size = data_loader.batch_size
- plot_num = min(batch_size, 4)
-
- json_path = './class_indices.json'
- assert os.path.exists(json_path), json_path + " does not exist."
- json_file = open(json_path, 'r')
- class_indices = json.load(json_file)
-
- for data in data_loader:
- images, labels = data
- for i in range(plot_num):
- # [C, H, W] -> [H, W, C]
- img = images[i].numpy().transpose(1, 2, 0)
- # 反Normalize操作
- img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
- label = labels[i].item()
- plt.subplot(1, plot_num, i+1)
- plt.xlabel(class_indices[str(label)])
- plt.xticks([]) # 去掉x轴的刻度
- plt.yticks([]) # 去掉y轴的刻度
- plt.imshow(img.astype('uint8'))
- plt.show()
-
-
- def write_pickle(list_info: list, file_name: str):
- with open(file_name, 'wb') as f:
- pickle.dump(list_info, f)
-
-
- def read_pickle(file_name: str) -> list:
- with open(file_name, 'rb') as f:
- info_list = pickle.load(f)
- return info_list
-
-
- def train_one_epoch(model, optimizer, data_loader, device, epoch):
- model.train()
- loss_function = torch.nn.CrossEntropyLoss()
- accu_loss = torch.zeros(1).to(device) # 累计损失
- accu_num = torch.zeros(1).to(device) # 累计预测正确的样本数
- optimizer.zero_grad()
-
- sample_num = 0
- data_loader = tqdm(data_loader, file=sys.stdout)
- for step, data in enumerate(data_loader):
- images, labels = data
- sample_num += images.shape[0]
-
- pred = model(images.to(device))
- pred_classes = torch.max(pred, dim=1)[1]
- accu_num += torch.eq(pred_classes, labels.to(device)).sum()
-
- loss = loss_function(pred, labels.to(device))
- loss.backward()
- accu_loss += loss.detach()
-
- data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
- accu_loss.item() / (step + 1),
- accu_num.item() / sample_num)
-
- if not torch.isfinite(loss):
- print('WARNING: non-finite loss, ending training ', loss)
- sys.exit(1)
-
- optimizer.step()
- optimizer.zero_grad()
-
- return accu_loss.item() / (step + 1), accu_num.item() / sample_num
-
-
- @torch.no_grad()
- def evaluate(model, data_loader, device, epoch):
- loss_function = torch.nn.CrossEntropyLoss()
-
- model.eval()
-
- accu_num = torch.zeros(1).to(device) # 累计预测正确的样本数
- accu_loss = torch.zeros(1).to(device) # 累计损失
-
- sample_num = 0
- data_loader = tqdm(data_loader, file=sys.stdout)
- for step, data in enumerate(data_loader):
- images, labels = data
- sample_num += images.shape[0]
-
- pred = model(images.to(device))
- pred_classes = torch.max(pred, dim=1)[1]
- accu_num += torch.eq(pred_classes, labels.to(device)).sum()
-
- loss = loss_function(pred, labels.to(device))
- accu_loss += loss
-
- data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
- accu_loss.item() / (step + 1),
- accu_num.item() / sample_num)
-
- return accu_loss.item() / (step + 1), accu_num.item() / sample_num
- from PIL import Image
- import torch
- from torch.utils.data import Dataset
-
-
- class MyDataSet(Dataset):
- """自定义数据集"""
-
- def __init__(self, images_path: list, images_class: list, transform=None):
- self.images_path = images_path
- self.images_class = images_class
- self.transform = transform
-
- def __len__(self):
- return len(self.images_path)
-
- def __getitem__(self, item):
- img = Image.open(self.images_path[item])
- # RGB为彩色图片,L为灰度图片
- if img.mode != 'RGB':
- raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
- label = self.images_class[item]
-
- if self.transform is not None:
- img = self.transform(img)
-
- return img, label
-
- @staticmethod
- def collate_fn(batch):
- # 官方实现的default_collate可以参考
- # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
- images, labels = tuple(zip(*batch))
-
- images = torch.stack(images, dim=0)
- labels = torch.as_tensor(labels)
- return images, labels
其中若使用预训练模型需要提前下载,下载地址在 utils.py 处有标明,代码默认是使用预训练模型的。下载后,预训练模型放入项目的根目录即可。我训练的数据集种类有三种,于是我将网络的全连接层的输出改成了 3 ,各位需要依据自己数据集不同来进行调整。
若下载不方便,也可以下载我上传的资源:
vit_base_patch16_224_in21k.zip-深度学习文档类资源-CSDN下载
- import os
- import math
- import argparse
-
- import torch
- import torch.optim as optim
- import torch.optim.lr_scheduler as lr_scheduler
- from torch.utils.tensorboard import SummaryWriter
- from torchvision import transforms
-
-
- from my_dataset import MyDataSet
- from vit_model import vit_base_patch16_224_in21k as create_model
- from utils import read_split_data, train_one_epoch, evaluate
-
- import xlwt
-
- book = xlwt.Workbook(encoding='utf-8') #创建Workbook,相当于创建Excel
- # 创建sheet,Sheet1为表的名字,cell_overwrite_ok为是否覆盖单元格
- sheet1 = book.add_sheet(u'Train_data', cell_overwrite_ok=True)
- # 向表中添加数据
- sheet1.write(0, 0, 'epoch')
- sheet1.write(0, 1, 'Train_Loss')
- sheet1.write(0, 2, 'Train_Acc')
- sheet1.write(0, 3, 'Val_Loss')
- sheet1.write(0, 4, 'Val_Acc')
- sheet1.write(0, 5, 'lr')
- sheet1.write(0, 6, 'Best val Acc')
-
- def main(args):
- best_acc = 0
-
- device = torch.device(args.device if torch.cuda.is_available() else "cpu")
-
- if os.path.exists("./weights") is False:
- os.makedirs("./weights")
-
- tb_writer = SummaryWriter()
-
- train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)
-
- data_transform = {
- "train": transforms.Compose([transforms.RandomResizedCrop(224),
- transforms.RandomHorizontalFlip(),
- transforms.ToTensor(),
- transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
- "val": transforms.Compose([transforms.Resize(256),
- transforms.CenterCrop(224),
- transforms.ToTensor(),
- transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}
-
- # 实例化训练数据集
- train_dataset = MyDataSet(images_path=train_images_path,
- images_class=train_images_label,
- transform=data_transform["train"])
-
- # 实例化验证数据集
- val_dataset = MyDataSet(images_path=val_images_path,
- images_class=val_images_label,
- transform=data_transform["val"])
-
- batch_size = args.batch_size
- nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
- print('Using {} dataloader workers every process'.format(nw))
- train_loader = torch.utils.data.DataLoader(train_dataset,
- batch_size=batch_size,
- shuffle=True,
- pin_memory=True,
- num_workers=nw,
- collate_fn=train_dataset.collate_fn)
-
- val_loader = torch.utils.data.DataLoader(val_dataset,
- batch_size=batch_size,
- shuffle=False,
- pin_memory=True,
- num_workers=nw,
- collate_fn=val_dataset.collate_fn)
-
- model = create_model(num_classes=3, has_logits=False).to(device)
-
- images = torch.zeros(1, 3, 224, 224).to(device)#要求大小与输入图片的大小一致
- tb_writer.add_graph(model, images, verbose=False)
-
- if args.weights != "":
- assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
- weights_dict = torch.load(args.weights, map_location=device)
- # 删除不需要的权重
- del_keys = ['head.weight', 'head.bias'] if model.has_logits \
- else ['pre_logits.fc.weight', 'pre_logits.fc.bias', 'head.weight', 'head.bias']
- for k in del_keys:
- del weights_dict[k]
- print(model.load_state_dict(weights_dict, strict=False))
-
- if args.freeze_layers:
- for name, para in model.named_parameters():
- # 除head, pre_logits外,其他权重全部冻结
- if "head" not in name and "pre_logits" not in name:
- para.requires_grad_(False)
- else:
- print("training {}".format(name))
-
- pg = [p for p in model.parameters() if p.requires_grad]
-
- optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=5E-5)
- # Scheduler https://arxiv.org/pdf/1812.01187.pdf
- lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf # cosine
- scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
-
- for epoch in range(args.epochs):
-
- sheet1.write(epoch+1, 0, epoch+1)
- sheet1.write(epoch + 1, 5, str(optimizer.state_dict()['param_groups'][0]['lr']))
-
- # train
- train_loss, train_acc = train_one_epoch(model=model,
- optimizer=optimizer,
- data_loader=train_loader,
- device=device,
- epoch=epoch)
-
- scheduler.step()
-
- sheet1.write(epoch + 1, 1, str(train_loss))
- sheet1.write(epoch + 1, 2, str(train_acc))
-
- # validate
- val_loss, val_acc = evaluate(model=model,
- data_loader=val_loader,
- device=device,
- epoch=epoch)
-
- sheet1.write(epoch + 1, 3, str(val_loss))
- sheet1.write(epoch + 1, 4, str(val_acc))
-
- tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
- tb_writer.add_scalar(tags[0], train_loss, epoch)
- tb_writer.add_scalar(tags[1], train_acc, epoch)
- tb_writer.add_scalar(tags[2], val_loss, epoch)
- tb_writer.add_scalar(tags[3], val_acc, epoch)
- tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)
-
-
- if val_acc > best_acc:
- best_acc = val_acc
- torch.save(model.state_dict(), "./weights/best_model.pth")
- #torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))
-
- sheet1.write(1, 6, str(best_acc))
- book.save('.\Train_data.xlsx')
- print("The Best Acc = : {:.4f}".format(best_acc))
-
-
-
- if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument('--num_classes', type=int, default=3)
- parser.add_argument('--epochs', type=int, default=100)
- parser.add_argument('--batch-size', type=int, default=8)
- parser.add_argument('--lr', type=float, default=0.001)
- parser.add_argument('--lrf', type=float, default=0.01)
-
- # 数据集所在根目录
- parser.add_argument('--data-path', type=str,
- default=r"D:\pyCharmdata\resnet50_plant_3\datasets\train")
- parser.add_argument('--model-name', default='', help='create model name')
-
- # 预训练权重路径,如果不想载入就设置为空字符
- parser.add_argument('--weights', type=str, default='./vit_base_patch16_224_in21k.pth',
- help='initial weights path')
-
- # 是否冻结权重
- parser.add_argument('--freeze-layers', type=bool, default=False)
- parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')
-
- opt = parser.parse_args()
-
- main(opt)
可以实现单张图片的种类预测,得分最高的便是模型预测种类。
- import os
- import json
-
- import torch
- from PIL import Image
- from torchvision import transforms
- import matplotlib.pyplot as plt
-
- from vit_model import vit_base_patch16_224_in21k as create_model
-
-
- def main():
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
- data_transform = transforms.Compose(
- [transforms.Resize(256),
- transforms.CenterCrop(224),
- transforms.ToTensor(),
- transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
-
- # load image
- img_path = r"D:\pyCharmdata\resnet50_plant_3\datasets\test\Huanglong_disease\000000.jpg"
- assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
- img = Image.open(img_path)
- plt.imshow(img)
- # [N, C, H, W]
- img = data_transform(img)
- # expand batch dimension
- img = torch.unsqueeze(img, dim=0)
-
- # read class_indict
- json_path = './class_indices.json'
- assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
-
- with open(json_path, "r") as f:
- class_indict = json.load(f)
-
- # create model
- model = create_model(num_classes=3, has_logits=False).to(device)
- # load model weights
- model_weight_path = "./weights/best_model.pth"
- model.load_state_dict(torch.load(model_weight_path, map_location=device))
- model.eval()
- with torch.no_grad():
- # predict class
- output = torch.squeeze(model(img.to(device))).cpu()
- predict = torch.softmax(output, dim=0)
- predict_cla = torch.argmax(predict).numpy()
-
- print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
- predict[predict_cla].numpy())
- plt.title(print_res)
- for i in range(len(predict)):
- print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
- predict[i].numpy()))
- plt.show()
-
-
- if __name__ == '__main__':
- main()
预测结果展示:
在配置好环境和数据集、预训练模型的路径后,即可运行 train.py 开始训练,默认是训练100轮。
训练使用的是SGDM优化器,初始学习率为0.001,使用LambdaLR自定义学习率调整策略,导入预训练模型但不冻结网络层和参数。
训练过程中可以在项目路径下的终端 输入:
tensorboard --logdir=runs/
进行实时监控训练进程,也可以查看 Vision Transformer 的网络可视化结构。
Vision Transformer 的网络可视化 :
我简单训练了100轮后,最高 val_acc 准确率为 0.9976。
训练结束后,会在项目根目录生成一个Excel文件,里面记载了训练全过程的数据,你也可以在通过 Matlab 来获得高度自定义化的可视化对比图片,堪称 论文人 的福音。
我这里只展示前10轮的训练数据。
我的完整项目框架,有需要的自取:
Vit_myself.zip-深度学习文档类资源-CSDN下载
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。