
Hand-written ViT (Transformer) code in PyTorch

ViT Transformer code
#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2023/5/29 19:50
# @Author : Denxun
# @FileName: pixel_atten.py
# @Software: PyCharm
import torch.nn as nn
import torch
import random
import skimage
import math
import os
import numpy as np
from torchsummary import summary

seed = 888
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)  # disable hash randomization so runs are reproducible
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


class pixel_att(nn.Module):
    def __init__(self, in_channels, out_channel, img_size, heads, ker_size, stride):
        super(pixel_att, self).__init__()
        # patch embedding: a conv whose kernel picks out the patches
        self.patch_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channel, kernel_size=ker_size, stride=stride),
            nn.BatchNorm2d(out_channel),
            nn.ReLU())
        # patches per image; assumes stride == ker_size (non-overlapping patches)
        self.token_size = (img_size // ker_size) ** 2
        self.layer_norm = nn.LayerNorm(self.token_size)  # LayerNorm over the hidden dimension
        self.flatten = nn.Flatten(2)  # flatten everything from dim 2 onward
        # learnable classification token
        self.class_token = nn.Parameter(torch.zeros(1, 1, self.token_size), requires_grad=True)
        # learnable positional-encoding matrix
        self.pos_embedding = nn.Parameter(torch.randn(1, out_channel + 1, self.token_size), requires_grad=True)
        self.dropout = nn.Dropout(0.5)
        self.att = self_att(self.token_size, heads)  # self-attention block (see below)
        self.mlp = mlp_block(self.token_size)        # MLP block (see below)
        self.layer_norm1 = nn.LayerNorm(self.token_size)

    def forward(self, x):
        batch, channel, w, h = x.size()  # read off the batch size
        patch_x = self.patch_conv(x)  # pixel-level patch extraction
        token = self.flatten(patch_x)  # flatten the last two dims, e.g. (4, 1, 64, 64) -> (4, 1, 4096)
        token_layer_norm = self.layer_norm(token)  # LayerNorm over the 4096-dim tokens
        # replicate the class token to (batch, 1, token_size)
        clas_token = torch.repeat_interleave(self.class_token, dim=0, repeats=batch)
        token_cat_class = torch.cat([token_layer_norm, clas_token], dim=1)  # append the classification token
        token_cat_class += self.pos_embedding  # add the positional encoding
        last_token = self.dropout(token_cat_class)
        atten_token = self.att(last_token)
        atten_token = atten_token + last_token  # residual connection
        # the original listing is truncated here; the lines below complete the
        # standard LayerNorm -> MLP -> residual pattern the module sets up
        atten_token = self.layer_norm1(atten_token)
        mlp_token = self.mlp(atten_token)
        return mlp_token + atten_token
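The forward pass calls self_att and mlp_block, which the excerpt above does not show. Below is a minimal sketch of what those two modules could look like, assuming a standard multi-head self-attention and a two-layer GELU MLP operating over the last (token_size) dimension; the internal names, the hidden_ratio parameter, and the dropout rates are illustrative assumptions, not taken from the original article.

class self_att(nn.Module):
    # hypothetical reconstruction: standard multi-head self-attention over
    # the last dimension, matching the call self_att(self.token_size, heads)
    def __init__(self, dim, heads):
        super(self_att, self).__init__()
        assert dim % heads == 0, 'dim must be divisible by heads'
        self.heads = heads
        self.head_dim = dim // heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        b, n, d = x.size()
        qkv = self.qkv(x).reshape(b, n, 3, self.heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each (b, heads, n, head_dim)
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # scaled dot-product
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, n, d)  # merge heads
        return self.proj(out)


class mlp_block(nn.Module):
    # hypothetical reconstruction: the usual two-layer Transformer MLP,
    # matching the call mlp_block(self.token_size)
    def __init__(self, dim, hidden_ratio=4):
        super(mlp_block, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * hidden_ratio),
            nn.GELU(),
            nn.Dropout(0.5),
            nn.Linear(dim * hidden_ratio, dim),
            nn.Dropout(0.5))

    def forward(self, x):
        return self.net(x)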
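With those sketches in place, the module can be smoke-tested. The settings below are illustrative, not from the original article: a 64x64 single-channel input with 8x8 non-overlapping patches gives token_size = (64 // 8)**2 = 64, and the sequence length is out_channel + 1 = 17.

if __name__ == '__main__':
    model = pixel_att(in_channels=1, out_channel=16, img_size=64,
                      heads=8, ker_size=8, stride=8)
    x = torch.randn(4, 1, 64, 64)
    print(model(x).shape)  # expected: torch.Size([4, 17, 64])
    summary(model, (1, 64, 64), device='cpu')  # per-layer shapes via torchsummary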