# 赞 / 踩 ("Like" / "Dislike" buttons captured from the source web page)
from math import sqrt
import torch
import torch.nn as nn
class SA(nn.Module):
    """Single-head scaled dot-product attention with learned projections.

    Query, key and value inputs are each mapped to ``dimV`` features by a
    bias-free ``nn.Linear``. The output is
    ``softmax(Q @ K^T / sqrt(dimV)) @ V``, with the softmax taken over the
    key axis.

    ``forward(x)`` performs self-attention: keys and values are ``x`` itself.
    That requires ``dimQ == dimK == dimV``. Passing separate ``key`` /
    ``value`` tensors lets the three widths differ.

    Args:
        dimQ: last-dimension size of the query input ``x``.
        dimK: last-dimension size of the key input.
        dimV: last-dimension size of the value input. This is also the width
            of every projection and of the output.
    """

    def __init__(self, dimQ, dimK, dimV):
        super().__init__()
        self.dimQ = dimQ
        self.dimK = dimK
        self.dimV = dimV
        self.linerQ = nn.Linear(self.dimQ, self.dimV, bias=False)
        self.linerK = nn.Linear(self.dimK, self.dimV, bias=False)
        self.linerV = nn.Linear(self.dimV, self.dimV, bias=False)
        # Scores are dot products of dimV-wide projected vectors, so the
        # scaling uses the projected width (dimV), not the raw query width.
        self.sqrtD = 1 / sqrt(dimV)

    def forward(self, x, key=None, value=None):
        """Attend from ``x`` over ``key`` / ``value`` (both default to ``x``).

        Args:
            x: query tensor of shape (batch, n, dimQ).
            key: optional key tensor of shape (batch, m, dimK); defaults to
                ``x``.
            value: optional value tensor of shape (batch, m, dimV); defaults
                to ``x``.

        Returns:
            Tensor of shape (batch, n, dimV).

        Raises:
            ValueError: if an input is not 3-D, or its last dimension does
                not match the width its projection layer expects.
        """
        key = x if key is None else key
        value = x if value is None else value
        # Validate explicitly: an `assert` would vanish under `python -O`,
        # leaving only an opaque matmul shape error from nn.Linear.
        self._check_input("x", x, self.dimQ)
        self._check_input("key", key, self.dimK)
        self._check_input("value", value, self.dimV)

        Q = self.linerQ(x)      # (batch, n, dimV)
        K = self.linerK(key)    # (batch, m, dimV)
        V = self.linerV(value)  # (batch, m, dimV)
        dist = torch.bmm(Q, K.transpose(1, 2)) * self.sqrtD  # (batch, n, m)
        W = torch.softmax(dist, dim=-1)  # per-query weights over the m keys
        return torch.bmm(W, V)

    @staticmethod
    def _check_input(name, t, expected_dim):
        """Raise ValueError unless ``t`` is 3-D with last dim ``expected_dim``."""
        if t.dim() != 3:
            raise ValueError(
                f"{name} must be 3-D (batch, seq, dim), got shape {tuple(t.shape)}"
            )
        if t.shape[-1] != expected_dim:
            raise ValueError(
                f"{name} last dim must be {expected_dim}, got {t.shape[-1]}"
            )
if __name__ == "__main__":
    # Demo: a batch of two identical 4-token sequences, 3 features per token.
    sequence = [[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6]]
    demo_input = torch.tensor([sequence, sequence], dtype=torch.float)
    print(demo_input.shape)

    # Self-attention where query, key and value widths all equal 3.
    model = SA(3, 3, 3)
    result = model(demo_input)
    print(result)
# References:
# https://zhuanlan.zhihu.com/p/338817680
# https://blog.csdn.net/weixin_44750512/article/details/124244915
# https://blog.csdn.net/qq_40178291/article/details/100302375
# Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。 (All rights reserved.)