

Vision Transformer Model Explained

The model consists of three modules:

Linear Projection of Flattened Patches (the Embedding layer)

Transformer Encoder (the encoder layers)

MLP Head (the final layers used for classification)


Embedding Layer in Detail

A standard Transformer module expects a sequence of tokens (vectors) as input, i.e. a 2-D matrix [num_token, token_dim]. In the figure below, token0 through token9 are each vectors; for ViT-B/16, each token vector has a length of 768.

Image data, however, is a 3-D matrix of shape [H, W, C], which is clearly not what the Transformer expects, so it first has to be transformed by an Embedding layer. As shown in the figure below, the image is first split into patches of a given size. For ViT-B/16, the 224x224 input image is divided into 16x16 patches, giving (224/16)^2 = 196 patches. A linear projection then maps each patch to a 1-D vector. For ViT-B/16, each patch of shape [16, 16, 3] is projected to a vector of length 768 (simply called a token from here on): [16, 16, 3] -> [768].

In code this is implemented directly with a convolution layer. For ViT-B/16, a convolution with a 16x16 kernel, stride 16 and 768 output channels is used: [224, 224, 3] -> [14, 14, 768]. Flattening the H and W dimensions then gives [14, 14, 768] -> [196, 768], which is exactly the 2-D matrix the Transformer expects.
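A minimal shape-check of this convolution-based patch embedding in plain PyTorch (a standalone sketch; the full PatchEmbed module from the reference implementation appears in the code listing further down):

import torch
import torch.nn as nn

# ViT-B/16 patch embedding sketch: 16x16 kernel, stride 16, 768 output channels
proj = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16)

img = torch.randn(1, 3, 224, 224)         # [B, C, H, W]
feat = proj(img)                          # [1, 768, 14, 14]
tokens = feat.flatten(2).transpose(1, 2)  # flatten H, W and transpose -> [1, 196, 768]
print(tokens.shape)                       # torch.Size([1, 196, 768])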

Before the tokens enter the Transformer Encoder, the [class] token and the Position Embedding have to be added. In the original paper the authors follow BERT and prepend a dedicated classification [class] token to the tokens just obtained. This [class] token is a trainable parameter with the same format as the other tokens, i.e. a vector; for ViT-B/16 it has length 768. It is concatenated with the tokens generated from the image: Cat([1, 768], [196, 768]) -> [197, 768]. The Position Embedding plays the role of the Positional Encoding discussed earlier for the Transformer; here it is a trainable 1D Pos. Emb. that is added directly onto the tokens (element-wise), so its shape must match. For ViT-B/16, the shape after concatenating the [class] token is [197, 768], so the Position Embedding also has shape [197, 768].
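The concatenation and addition can be sketched as follows (ViT-B/16 shapes; the variable names here are illustrative, not taken from the reference code):

import torch
import torch.nn as nn

tokens = torch.randn(1, 196, 768)                    # patch tokens from the embedding layer
cls_token = nn.Parameter(torch.zeros(1, 1, 768))     # trainable [class] token
pos_embed = nn.Parameter(torch.zeros(1, 197, 768))   # trainable 1D position embedding

x = torch.cat((cls_token.expand(tokens.shape[0], -1, -1), tokens), dim=1)  # [1, 197, 768]
x = x + pos_embed                                                          # added element-wise
print(x.shape)                                                             # torch.Size([1, 197, 768])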

The authors also ran a series of ablations on the Position Embedding. The source code uses the 1D Pos. Emb. by default; compared with using no Position Embedding at all it improves accuracy by roughly 3 points, and the difference to a 2D Pos. Emb. is negligible.


Transformer Encoder in Detail

The Transformer Encoder simply stacks the Encoder Block L times. Each block consists of the following parts:

Layer Norm: a normalization method originally proposed for NLP; here each token is normalized individually (Layer Norm was covered earlier).

Multi-Head Attention

Dropout/DropPath: the code accompanying the original paper simply uses a Dropout layer here.

MLP Block: as shown on the right of the figure, a fully connected layer + GELU activation + Dropout. Note that the first fully connected layer expands the number of nodes by a factor of 4, [197, 768] -> [197, 3072], and the second fully connected layer restores the original number, [197, 3072] -> [197, 768]; a minimal sketch follows this list.
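A minimal sketch of the MLP Block for ViT-B/16 (illustrative only; the reference Mlp class appears in the code listing below):

import torch
import torch.nn as nn

# Linear(768 -> 3072) + GELU + Dropout + Linear(3072 -> 768) + Dropout
mlp_block = nn.Sequential(
    nn.Linear(768, 3072),   # expand: 768 * 4 = 3072
    nn.GELU(),
    nn.Dropout(0.1),        # the dropout rate here is just an example value
    nn.Linear(3072, 768),   # restore the original node count
    nn.Dropout(0.1),
)

x = torch.randn(1, 197, 768)
print(mlp_block(x).shape)   # torch.Size([1, 197, 768])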


MLP Head in Detail

The output of the Transformer Encoder has the same shape as its input; for ViT-B/16 the input is [197, 768] and the output is still [197, 768]. Note that there is actually one more Layer Norm after the Transformer Encoder that is not drawn here; the detailed ViT structure is shown later. Since only the classification information is needed, we just extract the result corresponding to the [class] token, i.e. the [1, 768] slice taken from the [197, 768] output.

The final classification result is then obtained through the MLP Head. According to the original paper, the MLP Head consists of Linear + tanh activation + Linear when pre-training on ImageNet-21K; when transferring to ImageNet-1K or to your own dataset, a single Linear layer is enough.
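Sketched in code (ViT-B/16 shapes with 1000 classes; the names are illustrative), the pre-logits part is only present when pre-training on ImageNet-21K and is replaced by an identity mapping when fine-tuning:

import torch
import torch.nn as nn

encoder_out = torch.randn(1, 197, 768)      # output of the Transformer Encoder (after the final Layer Norm)
cls = encoder_out[:, 0]                     # extract the [class] token -> [1, 768]

pre_logits = nn.Sequential(nn.Linear(768, 768), nn.Tanh())  # Linear + tanh, used for ImageNet-21K pre-training
head = nn.Linear(768, 1000)                                 # a single Linear is enough when fine-tuning

logits = head(pre_logits(cls))              # [1, 1000]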


ViT Network Structure


Hybrid Model in Detail

The Hybrid model combines conventional CNN feature extraction with the Transformer. The figure below shows the hybrid model with ResNet-50 as the feature extractor, but this ResNet differs from the standard one. First, the convolutions in this R50 use StdConv2d instead of the usual Conv2d, and all BatchNorm layers are replaced with GroupNorm layers. In the original ResNet-50, stage1 stacks 3 blocks, stage2 stacks 4, stage3 stacks 6 and stage4 stacks 3; in this R50, the 3 blocks of stage4 are moved into stage3, so stage3 stacks 9 blocks in total.

After feature extraction with the R50 backbone, the feature map has shape [14, 14, 1024] and is then fed into the Patch Embedding layer. Note that the Conv2d inside the Patch Embedding now has kernel_size and stride both set to 1 and is only used to adjust the number of channels. Everything after that is exactly the same as in the ViT described above.
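A sketch of this hybrid patch embedding (illustrative; it only shows the shape flow, not the modified R50 backbone itself):

import torch
import torch.nn as nn

# 1x1 conv with stride 1: only adjusts the channel count, 1024 -> 768
proj = nn.Conv2d(1024, 768, kernel_size=1, stride=1)

feat = torch.randn(1, 1024, 14, 14)             # feature map from the modified R50 backbone
tokens = proj(feat).flatten(2).transpose(1, 2)  # [1, 196, 768], same token layout as plain ViT
print(tokens.shape)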

The table in the paper compares ViT, ResNet (modified as just described, with the convolution and normalization layers changed) and the Hybrid model. The comparison shows that with few training epochs the Hybrid model outperforms ViT, but as the number of epochs increases ViT overtakes the Hybrid model.


ViT Model Configuration Parameters

Table 1 of the paper lists the parameters of each model (Base/Large/Huge); besides the 16x16 Patch Size, the source code also provides a 32x32 version. Layers is the number of times the Encoder Block is stacked in the Transformer Encoder, Hidden Size is the dimension (vector length) of each token after the Embedding layer, MLP Size is the number of nodes in the first fully connected layer of the MLP Block (four times the Hidden Size), and Heads is the number of heads in Multi-Head Attention.

| Model     | Patch Size | Layers | Hidden Size D | MLP size | Heads | Params |
|-----------|------------|--------|---------------|----------|-------|--------|
| ViT-Base  | 16x16      | 12     | 768           | 3072     | 12    | 86M    |
| ViT-Large | 16x16      | 24     | 1024          | 4096     | 16    | 307M   |
| ViT-Huge  | 14x14      | 32     | 1280          | 5120     | 16    | 632M   |

ViT Model Implementation

  1. """
  2. original code from rwightman:
  3. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  4. """
  5. from functools import partial
  6. from collections import OrderedDict
  7. import torch
  8. import torch.nn as nn
  9. # 随机深度
  10. def drop_path(x, drop_prob: float = 0., training: bool = False):
  11. """
  12. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  13. This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
  14. the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  15. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
  16. changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
  17. 'survival rate' as the argument.
  18. """
  19. if drop_prob == 0. or not training:
  20. return x
  21. keep_prob = 1 - drop_prob
  22. shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  23. random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
  24. random_tensor.floor_() # binarize
  25. output = x.div(keep_prob) * random_tensor
  26. return output
class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    # converts the input image into the token format expected by self-attention
    # input image 224*224*3 (RGB, 3 channels), split into 16*16 patches
    # 224,224,3 --> 14,14,768
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        # 224 // 16, 224 // 16
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        # number of patches: 14 * 14
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        # conv layer: kernel 16*16, stride 16, o = [(i + 2p - k) / s] + 1 = [(224 - 16) / 16] + 1 = 14 -> 14*14
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        # nn.Identity() is a no-op
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    # forward pass; the input image size of the ViT model is fixed and cannot be changed
    def forward(self, x):
        B, C, H, W = x.shape
        # check that the input image size matches the preset size
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # flatten: [B, C, H, W] -> [B, C, HW]   [B, 768, 14*14]
        # transpose: [B, C, HW] -> [B, HW, C]   [B, 196, 768]
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x
# multi-head self-attention
class Attention(nn.Module):
    def __init__(self,
                 dim,  # dim of the input tokens, 768
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,  # corresponds to 1 / sqrt(d)
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        # dim that each head receives for q, k and v
        head_dim = dim // num_heads
        # corresponds to 1 / sqrt(d)
        self.scale = qk_scale or head_dim ** -0.5
        # a single linear layer produces q, k and v; strictly speaking these are three separate
        # projections, but fusing them improves parallelism
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)  # concatenate the heads and apply the output projection Wo via a Linear
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        # [batch_size, num_patches + 1, total_embed_dim]
        B, N, C = x.shape
        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        # reshape to split: -> [batch_size, num_patches + 1, 3 (q, k, v), num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head] for convenient indexing
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
        # @ is matrix multiplication
        # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
        attn = (q @ k.transpose(-2, -1)) * self.scale  # scale the attention scores
        # dim=-1: softmax over each row
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size, num_patches + 1, total_embed_dim]  concatenate the last two dims
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Mlp(nn.Module):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)  # expand the channel dim
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)  # restore the channel dim
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
# the Transformer Encoder stacks this block L times
class Block(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,  # the first fully connected layer has 4x the input nodes
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_c (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer (number of stacked encoder blocks)
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
        """
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)

    # forward pass
    def forward_features(self, x):
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # [B, 196, 768]
        # [1, 1, 768] -> [B, 1, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)

        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        x = self.forward_features(x)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])
            if self.training and not torch.jit.is_scripting():
                # during training, return both classifier predictions
                return x, x_dist
            else:
                # during inference, return the average of both classifier predictions
                return (x + x_dist) / 2
        else:
            x = self.head(x)
        return x
def _init_vit_weights(m):
    """
    ViT weight initialization
    :param m: module
    """
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=.01)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)
def vit_base_patch16_224(num_classes: int = 1000):
    """
    ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA  password: eu9f
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=None,
                              num_classes=num_classes)
    return model


def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=768 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_base_patch32_224(num_classes: int = 1000):
    """
    ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    link: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg  password: s5hl
    """
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=None,
                              num_classes=num_classes)
    return model


def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=768 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_large_patch16_224(num_classes: int = 1000):
    """
    ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    link: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ  password: qqt8
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=None,
                              num_classes=num_classes)
    return model


def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=1024 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=1024 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: converted weights not currently available, too large for github release hosting.
    """
    model = VisionTransformer(img_size=224,
                              patch_size=14,
                              embed_dim=1280,
                              depth=32,
                              num_heads=16,
                              representation_size=1280 if has_logits else None,
                              num_classes=num_classes)
    return model
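As a quick sanity check, the sketch below builds ViT-B/16 for a hypothetical 5-class dataset and runs one random image through it (it assumes the definitions above are in scope):

import torch

model = vit_base_patch16_224(num_classes=5)  # hypothetical 5-class dataset
x = torch.randn(1, 3, 224, 224)              # one 224x224 RGB image
with torch.no_grad():
    logits = model(x)
print(logits.shape)                          # torch.Size([1, 5])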
