
【Transformer】Self-Attention with Relative Position Representations and a PyTorch Implementation

An implementation of "Self-Attention with Relative Position Representations"

The Transformer is augmented with trainable embeddings so that the output representation can capture the sequence order / position of the inputs. These embedding vectors are added into the key and value terms when computing attention between any two tokens i and j of the input sequence. Because each embedding vector encodes the distance between tokens i and j, they are called "Relative Position Representations".

 

Self-Attention

Each input element x_i is transformed by self-attention into an output z_i, where z_i is a weighted sum of all the value-projected (W^V) elements. The weight between any two elements is computed by Eq. (2) below and normalized with a softmax.
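For reference, the two equations referenced above, re-typeset from Shaw et al. (2018); the numbering follows this post, and the original post showed them as images:

    z_i = \sum_{j=1}^{n} \alpha_{ij} (x_j W^V)                                          (1)

    e_{ij} = \frac{(x_i W^Q)(x_j W^K)^T}{\sqrt{d_z}}, \quad
    \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{n} \exp(e_{ik})}                       (2)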

Relation-aware Self-Attention

Relation-aware self-attention models the relation between pairs of input elements. For inputs x_i and x_j, two extra representations a_ij^V and a_ij^K are added. This does not require additional linear layers; the representations are shared across attention heads. Eq. (1) is modified into Eq. (3), and Eq. (2) into Eq. (4), shown below.
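The relation-aware versions (again re-typeset from the paper, with numbering as in this post):

    z_i = \sum_{j=1}^{n} \alpha_{ij} (x_j W^V + a_{ij}^V)                                (3)

    e_{ij} = \frac{(x_i W^Q)(x_j W^K + a_{ij}^K)^T}{\sqrt{d_z}}                          (4)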

For relative positions, only elements up to a maximum distance k are distinguished: starting from the current position, at most k elements to the left (-k) and at most k elements to the right (+k); any distance beyond that range is clipped to ±k. Consequently w^V and w^K each contain 2k+1 vectors (k left + k right + the position itself), and each vector a_ij^K or a_ij^V has dimension d_a = d_z. The clipping is written out below (the paper also illustrates it with a figure, not reproduced here).
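In formulas, the clipping described above is:

    a_{ij}^K = w^K_{\mathrm{clip}(j-i,\,k)}, \quad
    a_{ij}^V = w^V_{\mathrm{clip}(j-i,\,k)}, \quad
    \mathrm{clip}(x, k) = \max(-k, \min(k, x))

For example, with k = 2 the relative offsets -5, -2, 0, 3 select the vectors w_{-2}, w_{-2}, w_0, w_2 respectively.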

To simplify the computation, Eq. (4) is split into two terms that are computed separately, as written below.
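The split moves the relative-position term out so that the first term can reuse the standard batched matrix multiplication (this is the "efficient implementation" from the paper):

    e_{ij} = \frac{x_i W^Q (x_j W^K)^T + x_i W^Q (a_{ij}^K)^T}{\sqrt{d_z}}

The code below computes these two terms as attn1 and attn2.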

In the experiments, the paper uses 6 encoder and 6 decoder layers with d_x = 512, d_z = 64, 8 attention heads, and clipping distance k = 16, trained for 100,000 steps on 8 K40 GPUs. The paper also ablates the choice of k (that table is not reproduced here).

 

Code

Reference: GitHub - evelinehong/Transformer_Relative_Position_PyTorch: Implement the paper "Self-Attention with Relative Position Representations". The snippet below is adapted from that repository, with the missing imports added and the hard-coded .cuda() calls and batch size removed.

import copy
import math

import torch
import torch.nn as nn


def _get_clones(module, n):
    # n independent deep copies of a module (same helper PyTorch's Transformer uses internally)
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])


class RelativePosition(nn.Module):
    """Lookup table of relative position embeddings, clipped to [-max_relative_position, max_relative_position]."""

    def __init__(self, num_units, max_relative_position):
        super().__init__()
        self.num_units = num_units
        self.max_relative_position = max_relative_position
        # 2k + 1 embedding vectors: k to the left, k to the right, plus the position itself
        self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))
        nn.init.xavier_uniform_(self.embeddings_table)

    def forward(self, length_q, length_k):
        range_vec_q = torch.arange(length_q)
        range_vec_k = torch.arange(length_k)
        # distance_mat[i, j] = j - i
        distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
        # clip distances to [-k, k], then shift to [0, 2k] to index the table
        distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
        # index on the same device as the table (the original snippet hard-coded .cuda())
        final_mat = (distance_mat_clipped + self.max_relative_position).long().to(self.embeddings_table.device)
        embeddings = self.embeddings_table[final_mat]
        # [length_q, length_k, num_units]
        return embeddings


class RelativeMultiHeadAttention(nn.Module):
    """Multi-head self-attention with relative position representations (Shaw et al., 2018)."""

    def __init__(self, d_model, n_heads, dropout=0.1, max_relative_position=16):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # query, key, value and output projections
        self.linears = _get_clones(nn.Linear(d_model, d_model), 4)
        self.dropout = nn.Dropout(p=dropout)
        self.relative_position_k = RelativePosition(self.head_dim, max_relative_position)
        self.relative_position_v = RelativePosition(self.head_dim, max_relative_position)
        self.scale = math.sqrt(self.head_dim)

    def forward(self, query, key, value):
        # query, key, value: [batch_size, len, d_model]
        batch_size = query.size(0)
        query, key, value = [l(x) for l, x in zip(self.linears, (query, key, value))]
        len_q, len_k, len_v = query.size(1), key.size(1), value.size(1)

        # first term of Eq. (4): (x_i W^Q)(x_j W^K)^T
        # r_q1, r_k1: [batch_size, n_heads, len, head_dim]
        r_q1 = query.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        r_k1 = key.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        attn1 = torch.matmul(r_q1, r_k1.permute(0, 1, 3, 2))

        # second term of Eq. (4): (x_i W^Q)(a_ij^K)^T
        r_q2 = query.permute(1, 0, 2).contiguous().view(len_q, batch_size * self.n_heads, self.head_dim)
        r_k2 = self.relative_position_k(len_q, len_k)            # [len_q, len_k, head_dim]
        attn2 = torch.matmul(r_q2, r_k2.transpose(1, 2)).transpose(0, 1)
        attn2 = attn2.contiguous().view(batch_size, self.n_heads, len_q, len_k)

        attn = (attn1 + attn2) / self.scale
        attn = self.dropout(torch.softmax(attn, dim=-1))
        # attn: [batch_size, n_heads, len_q, len_k]

        # first term of Eq. (3): sum_j alpha_ij (x_j W^V)
        r_v1 = value.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        weight1 = torch.matmul(attn, r_v1)

        # second term of Eq. (3): sum_j alpha_ij a_ij^V
        r_v2 = self.relative_position_v(len_q, len_v)            # [len_q, len_v, head_dim]
        weight2 = attn.permute(2, 0, 1, 3).contiguous().view(len_q, batch_size * self.n_heads, len_k)
        weight2 = torch.matmul(weight2, r_v2)
        weight2 = weight2.transpose(0, 1).contiguous().view(batch_size, self.n_heads, len_q, self.head_dim)

        x = weight1 + weight2
        # [batch_size, n_heads, len_q, head_dim] -> [batch_size, len_q, d_model]
        x = x.permute(0, 2, 1, 3).contiguous().view(batch_size, len_q, self.d_model)
        return self.linears[-1](x)
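A minimal smoke test of the module above. The batch size, sequence length, and hyperparameters here are arbitrary illustration values, not the paper's training setup:

mha = RelativeMultiHeadAttention(d_model=512, n_heads=8, dropout=0.1)
x = torch.randn(2, 10, 512)   # [batch_size, seq_len, d_model]
out = mha(x, x, x)            # self-attention: query = key = value
print(out.shape)              # torch.Size([2, 10, 512])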

 
