
Transformer: Patch Embedding Code

Patch Embedding

A simplified ViT (without the attention part)

This post mainly documents how Patch Embedding is handled and sketches the basic ViT framework; the next post will cover the complete ViT framework.


How does a Transformer process an image? As shown in the figure:
image -> split into patches -> learnable projection -> features
[figure: splitting an image into patches and projecting them to features]
Overall network structure:
[figure: overall ViT network structure]
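
To make the "split into patches + learnable projection" idea concrete, here is a minimal sketch (my own illustration, not part of the code below) that does the patching by hand with unfold and a shared nn.Linear; the Conv2d trick used in PatchEmbedding further down computes the same thing in a single call. The patch size 7 and embed_dim 16 are simply the settings used later in this post.

import torch
import torch.nn as nn

img = torch.randn(1, 3, 224, 224)              # [n, c, h, w]
p = 7                                          # patch size (assumed, matches this post)
patches = img.unfold(2, p, p).unfold(3, p, p)  # [1, 3, 32, 32, 7, 7]
patches = patches.permute(0, 2, 3, 1, 4, 5)    # [1, 32, 32, 3, 7, 7]
patches = patches.flatten(1, 2).flatten(2)     # [1, 1024, 147], 3*7*7 = 147 values per patch
proj = nn.Linear(3 * p * p, 16)                # learnable mapping to embed_dim = 16
tokens = proj(patches)                         # [1, 1024, 16], i.e. 1024 patch tokens
print(tokens.shape)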

Hands-on part:

Patch Embedding converts the original 2D image into a sequence of 1D patch embeddings.
Patch Embedding code:

class PatchEmbedding(nn.Module):
    def __init__(self, image_size, in_channels, patch_size, embed_dim, dropout=0.):
        super(PatchEmbedding, self).__init__()
        # the patch embedding is just a convolution with kernel_size = stride = patch_size,
        # so each patch is projected to embed_dim independently
        self.patch_embed = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size,
                                     stride=patch_size, bias=False)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        # x: [4, 3, 224, 224]
        x = self.patch_embed(x)
        # x: [n, embed_dim, h', w'] = [4, 16, 32, 32]
        x = x.flatten(2)        # merge h' and w' into one dim: [n, embed_dim, h'*w'] = [4, 16, 1024]
        x = x.permute(0, 2, 1)  # [n, h'*w', embed_dim] = [4, 1024, 16]
        x = self.drop(x)
        print(x.shape)          # [4, 1024, 16], i.e. [batch_size, num_patch, embed_dim]
        return x
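As a quick shape check (my own snippet, reusing the settings from this post): with image_size 224 and patch_size 7, the conv output is 224 / 7 = 32 along each side, so each image yields 32 * 32 = 1024 patch tokens of dimension embed_dim.

pe = PatchEmbedding(image_size=224, in_channels=3, patch_size=7, embed_dim=16)
x = torch.randn(4, 3, 224, 224)   # a batch of 4 RGB images
out = pe(x)                       # prints torch.Size([4, 1024, 16])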

ViT code:
The attention part is omitted.

class Vit(nn.Module):
    def __init__(self):
        super(Vit, self).__init__()
        self.patch_embed = PatchEmbedding(224, 3, 7, 16)   # image tokens
        layer_list = [Encoder(16) for i in range(5)]       # assume 5 encoder layers with embed_dim 16
        self.encoders = nn.Sequential(*layer_list)
        self.head = nn.Linear(16, 10)       # encoder output dim is 16; classify into num_classes = 10
        self.avg = nn.AdaptiveAvgPool1d(1)  # average over all patch tokens

    def forward(self, x):
        x = self.patch_embed(x)    # x: [4, 1024, 16]
        for i in self.encoders:
            x = i(x)
        # [n, h*w, c]
        x = x.permute((0, 2, 1))   # [4, 16, 1024]
        # [n, c, h*w]
        x = self.avg(x)            # [n, c, 1] = [4, 16, 1]
        x = x.flatten(1)           # [n, c] = [4, 16]
        x = self.head(x)
        return x
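One design note: this ViT has no class token; it averages over all patch tokens before the classification head. The permute + AdaptiveAvgPool1d(1) + flatten sequence is just a mean over the patch dimension, as this small check (my own, not from the post) shows:

import torch
import torch.nn as nn

x = torch.randn(4, 1024, 16)                                       # [n, num_patch, embed_dim]
pooled_a = nn.AdaptiveAvgPool1d(1)(x.permute(0, 2, 1)).flatten(1)  # [4, 16]
pooled_b = x.mean(dim=1)                                           # [4, 16]
print(torch.allclose(pooled_a, pooled_b, atol=1e-6))               # True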

Full code:

from PIL import Image
import numpy as np
import torch
import torch.nn as nn

# Identity: does nothing; used here as a placeholder for the attention block
class Identity(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

# The Mlp is just two fully connected layers, usually placed right after the attention layer.
# It first expands the channels 4x (e.g. 16 -> 64), then shrinks them back, so the channel count stays the same.
class Mlp(nn.Module):
    def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.):           # mlp_ratio is the expansion factor
        super(Mlp, self).__init__()
        self.fc1 = nn.Linear(embed_dim, int(embed_dim * mlp_ratio))      # expand
        self.fc2 = nn.Linear(int(embed_dim * mlp_ratio), embed_dim)      # project back to embed_dim
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self,x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class PatchEmbedding(nn.Module):
    def __init__(self, image_size, in_channels, patch_size, embed_dim, dropout=0.):
        super(PatchEmbedding, self).__init__()
        # the patch embedding is just a convolution with kernel_size = stride = patch_size,
        # so each patch is projected to embed_dim independently
        self.patch_embed = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size,
                                     stride=patch_size, bias=False)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        # x: [4, 3, 224, 224]
        x = self.patch_embed(x)
        # x: [n, embed_dim, h', w'] = [4, 16, 32, 32]
        x = x.flatten(2)        # merge h' and w' into one dim: [n, embed_dim, h'*w'] = [4, 16, 1024]
        x = x.permute(0, 2, 1)  # [n, h'*w', embed_dim] = [4, 1024, 16]
        x = self.drop(x)
        print(x.shape)          # [4, 1024, 16], i.e. [batch_size, num_patch, embed_dim]
        return x

class Encoder(nn.Module):
    def __init__(self, embed_dim):
        super(Encoder, self).__init__()
        self.atten = Identity()                      # the self-attention part is left unimplemented for now
        self.layer_nomer = nn.LayerNorm(embed_dim)   # LayerNorm for the attention branch
        self.mlp = Mlp(embed_dim)
        self.mlp_nomer = nn.LayerNorm(embed_dim)     # LayerNorm for the MLP branch

    def forward(self, x):
        # residual structure
        h = x
        x = self.atten(x)        # self-attention first
        x = self.layer_nomer(x)  # then LayerNorm
        x = h + x

        h = x
        x = self.mlp(x)          # MLP (two FC layers) first
        x = self.mlp_nomer(x)    # then LayerNorm
        x = h + x

        return x



class Vit(nn.Module):
    def __init__(self):
        super(Vit, self).__init__()
        self.patch_embed = PatchEmbedding(224, 3, 7, 16)   # image tokens
        layer_list = [Encoder(16) for i in range(5)]       # assume 5 encoder layers with embed_dim 16
        self.encoders = nn.Sequential(*layer_list)
        self.head = nn.Linear(16, 10)       # encoder output dim is 16; classify into num_classes = 10
        self.avg = nn.AdaptiveAvgPool1d(1)  # average over all patch tokens

    def forward(self, x):
        x = self.patch_embed(x)    # x: [4, 1024, 16]
        for i in self.encoders:
            x = i(x)
        # [n, h*w, c]
        x = x.permute((0, 2, 1))   # [4, 16, 1024]
        # [n, c, h*w]
        x = self.avg(x)            # [n, c, 1] = [4, 16, 1]
        x = x.flatten(1)           # [n, c] = [4, 16]
        x = self.head(x)
        return x


def test():
    # 1. create an image
    img = np.array(Image.open('test.jpg'))    # 224 x 224
    t = torch.tensor(img, dtype=torch.float32)
    print(t.shape)                            # [224, 224, 3]
    sample = t.permute(2, 0, 1).unsqueeze(0)  # [224, 224, 3] -> [1, 3, 224, 224]
    print(sample.shape)

    # 2. patch embedding: converts the original 2D image into a sequence of 1D patch embeddings
    # patch_size is the patch size: the 224 x 224 x 3 image first becomes 32 x 32 x embed_dim
    # in_channels is 3 for an RGB image
    # embed_dim is the dimension each patch is mapped to
    patch_embedding = PatchEmbedding(image_size=224, patch_size=7, in_channels=3, embed_dim=1)
    # forward pass
    out = patch_embedding(sample)
    print(out)
    # print(out.shape)

    mlp = Mlp(embed_dim=1)
    out = mlp(out)
    print(out.shape)

def main():
    t = torch.randn([4,3,224,224])
    model=Vit()
    out=model(t)
    print(out.shape)


if __name__ == "__main__":
    main()

With batch size 4, the pooled feature is [4, 16] and the classification head maps it to the final output [4, 10].
The next post will cover the complete ViT code.
