1. Schematic
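In outline: the `embed_size`-dimensional input is projected into queries, keys, and values, split into `heads` independent heads of size `head_dim = embed_size // heads`, and each head computes scaled dot-product attention:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V, \qquad d_k = \text{head\_dim}$$

The per-head outputs are then concatenated back to `embed_size` and passed through a final linear projection (`fc_out` in the code below).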
2. Code
```python
import torch
import torch.nn as nn


class Multi_Head_Self_Attention(nn.Module):
    def __init__(self, embed_size, heads):
        super(Multi_Head_Self_Attention, self).__init__()
        assert embed_size % heads == 0, "embed_size must be divisible by heads"
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.fc_out = nn.Linear(self.embed_size, self.embed_size, bias=False)

    def forward(self, queries, keys, values, mask):
        N = queries.shape[0]          # batch_size
        query_len = queries.shape[1]  # sequence_length
        key_len = keys.shape[1]       # sequence_length
        value_len = values.shape[1]   # sequence_length

        queries = self.queries(queries)
        keys = self.keys(keys)
        values = self.values(values)

        # Split the embedding into self.heads pieces:
        # batch_size, sequence_length, embed_size(512) -->
        # batch_size, sequence_length, heads(8), head_dim(64)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        values = values.reshape(N, value_len, self.heads, self.head_dim)

        # batch_size, sequence_length, heads(8), head_dim(64) -->
        # batch_size, heads(8), sequence_length, head_dim(64)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Scaled dot-product attention
        score = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if mask is not None:
            score = score.masked_fill(mask == 0, float("-inf"))
        # batch_size, heads(8), sequence_length, sequence_length
        attention = torch.softmax(score, dim=-1)

        out = torch.matmul(attention, values)
        # batch_size, heads(8), sequence_length, head_dim(64) -->
        # batch_size, sequence_length, heads(8), head_dim(64) -->
        # batch_size, sequence_length, embed_size(512)
        # Merge the heads back so the output can be fed to the next layer.
        out = out.transpose(1, 2).contiguous().reshape(N, query_len, self.embed_size)
        out = self.fc_out(out)

        return out


batch_size = 64
sequence_length = 10
embed_size = 512
heads = 8
mask = None

Q = torch.randn(batch_size, sequence_length, embed_size)
K = torch.randn(batch_size, sequence_length, embed_size)
V = torch.randn(batch_size, sequence_length, embed_size)

model = Multi_Head_Self_Attention(embed_size, heads)
output = model(Q, K, V, mask)
print(output.shape)  # torch.Size([64, 10, 512])
```
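The demo passes `mask=None`, so nothing is masked. As a hypothetical usage sketch (continuing the script above; the `(1, 1, seq, seq)` mask shape is an assumption, any tensor that broadcasts against the `(batch_size, heads, query_len, key_len)` score tensor works), a causal mask that lets position `i` attend only to positions `<= i` could be built like this:

```python
# Hypothetical causal mask (not part of the original post): a lower-triangular
# matrix whose zeros mark the positions that masked_fill replaces with -inf.
causal_mask = torch.tril(torch.ones(sequence_length, sequence_length))
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)  # (1, 1, 10, 10), broadcasts over batch and heads

masked_output = model(Q, K, V, causal_mask)
print(masked_output.shape)  # torch.Size([64, 10, 512])
```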
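As a sanity check on shapes, the same layout can be run through PyTorch's built-in `nn.MultiheadAttention` (a minimal sketch continuing the script above; the built-in layer uses bias terms and its own initialization, so the values will differ from the hand-rolled module, only the shapes should match):

```python
# Built-in equivalent for comparison; batch_first=True matches the
# (batch, seq, embed) layout used throughout this post.
mha = nn.MultiheadAttention(embed_dim=embed_size, num_heads=heads, batch_first=True)
builtin_out, attn_weights = mha(Q, K, V)
print(builtin_out.shape)   # torch.Size([64, 10, 512])
print(attn_weights.shape)  # torch.Size([64, 10, 10]); averaged over heads by default
```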