
self_attention Python code

self_attention interview code
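The snippet below is a minimal single-head self-attention module of the kind often asked for in coding interviews. Given an input x of shape (batch, n, dim), it projects x to queries, keys, and values and computes scaled dot-product attention:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V, \qquad Q = XW_Q,\; K = XW_K,\; V = XW_V$$

where d_k is the dimension over which the Q·K dot product is taken (the projected dimension dimV in the code).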

from math import sqrt
import torch
import torch.nn as nn

class SA(nn.Module):
    def __init__(self, dimQ, dimK, dimV):
        super(SA, self).__init__()

        self.dimQ = dimQ
        self.dimK = dimK
        self.dimV = dimV

        # With these projection shapes, applying the module to a single input x
        # (self-attention) requires dimQ == dimK == dimV == the input feature dim.
        self.linearQ = nn.Linear(self.dimQ, self.dimV, bias=False)
        self.linearK = nn.Linear(self.dimK, self.dimV, bias=False)
        self.linearV = nn.Linear(self.dimV, self.dimV, bias=False)

        # Scale by 1/sqrt(d_k); the Q·K dot product runs over the projected
        # dimension dimV, so that is the d_k used here.
        self.scale = 1 / sqrt(dimV)

    def forward(self, x):
        # x: (batch, n, dim)
        batch, n, dim = x.shape
        assert dim == self.dimQ

        Q = self.linearQ(x)  # (batch, n, dimV)
        K = self.linearK(x)  # (batch, n, dimV)
        V = self.linearV(x)  # (batch, n, dimV)

        # Attention scores and weights: (batch, n, n)
        dist = torch.bmm(Q, K.transpose(1, 2)) * self.scale
        W = torch.softmax(dist, dim=-1)

        # Weighted sum of the values: (batch, n, dimV)
        output = torch.bmm(W, V)
        return output

if __name__ == "__main__":
    x = torch.tensor([[[1,2,3],[2,3,4],[3,4,5],[4,5,6]],
                     [[1,2,3],[2,3,4],[3,4,5],[4,5,6]]], dtype = torch.float)
    print(x.shape)

    saModel = SA(3, 3, 3)
    Output = saModel(x)
    print(Output)
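
A common interview follow-up is to make the attention causal, so that position i can only attend to positions j <= i. The sketch below is not part of the original post; it is one possible extension of the SA class above (reusing the linearQ/linearK/linearV projections and the scale factor defined there) that fills the strict upper triangle of the score matrix with -inf before the softmax.

class MaskedSA(SA):
    # Causal (decoder-style) self-attention: a hypothetical extension, not from the original post.
    def forward(self, x):
        batch, n, dim = x.shape
        assert dim == self.dimQ

        Q = self.linearQ(x)
        K = self.linearK(x)
        V = self.linearV(x)

        dist = torch.bmm(Q, K.transpose(1, 2)) * self.scale  # (batch, n, n)

        # Mark "future" positions (strict upper triangle) and exclude them from the softmax.
        mask = torch.triu(torch.ones(n, n, dtype=torch.bool, device=x.device), diagonal=1)
        dist = dist.masked_fill(mask, float("-inf"))

        W = torch.softmax(dist, dim=-1)
        return torch.bmm(W, V)

Running MaskedSA(3, 3, 3) on the same x produces an output of the same shape (2, 4, 3), but the first query now places all of its attention weight on the first position.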

