当前位置:   article > 正文

实习日志8.2——transformer结构理解_transformer日志

transformer日志

注意力机制理解及代码:

理解:

因此,每个q向量都会对应一个v向量,同时,不同的评分函数对于注意力权重有不同的影响。

缩放点积注意力:

参考链接:10.3. 注意力评分函数 — 动手学深度学习 2.0.0 documentation 

注意,这种评分函数要求q和k拥有相同的长度d,对于小批量(有了批量,向量dim=3),给出公式如上。

对于点积可以衡量向量相似度的理解:向量点积越大,方向越接近,一定程度上刻画了两个向量的相似度。

代码:
  1. import math
  2. import torch
  3. from torch import nn
  4. from d2l import torch as d2l
  5. #掩码softmax,用于指定有效序列长度
  6. def masked_softmax(X, valid_lens):
  7. """通过在最后一个轴上掩蔽元素来执行softmax操作"""
  8. # X:3D张量(batch_size,sequence_length,embedding_dim),valid_lens:1D或2D张量
  9. if valid_lens is None:
  10. return nn.functional.softmax(X, dim=-1)
  11. else:
  12. shape = X.shape
  13. if valid_lens.dim() == 1:
  14. valid_lens = torch.repeat_interleave(valid_lens, shape[1])
  15. '''repeat_interleave(input:tensor,repeats:int or tensor, dim)用于指定维度上重复张量中的元素,将有效长度扩充到句子长度'''
  16. else:
  17. valid_lens = valid_lens.reshape(-1)
  18. # 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
  19. X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
  20. value=-1e6)
  21. return nn.functional.softmax(X.reshape(shape), dim=-1)
  22. class DotProductAttention(nn.Module):
  23. """缩放点积注意力"""
  24. def __init__(self, dropout, **kwargs):
  25. super(DotProductAttention, self).__init__(**kwargs)
  26. self.dropout = nn.Dropout(dropout)
  27. # queries的形状:(batch_size,查询的个数,d)
  28. # keys的形状:(batch_size,“键-值”对的个数,d)
  29. # values的形状:(batch_size,“键-值”对的个数,值的维度)
  30. # valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)
  31. def forward(self, queries, keys, values, valid_lens=None):
  32. d = queries.shape[-1]
  33. # 设置transpose_b=True为了交换keys的最后两个维度
  34. #torch.bmm()用于批量矩阵相乘,transpose对矩阵K转置
  35. scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d) #注意力分数
  36. self.attention_weights = masked_softmax(scores, valid_lens)
  37. # 由于矩阵乘法的性质,可以用torch.bmm()进行注意力公式最后的加权求和
  38. return torch.bmm(self.dropout(self.attention_weights), values)

 自注意力与位置编码

自注意力:

对self-attention而言,q和k的维度相同,事实上,在transformer中,Q,K,V都相同

代码: 
  1. from d2l import torch as d2l
  2. num_hiddens, num_heads = 10, 5
  3. '''class MultiHeadAttention(nn.Module): # 多头注意力
  4. def __init__(self, key_size, query_size, value_size, num_hiddens,
  5. num_heads, dropout, bias=False, **kwargs):
  6. super(MultiHeadAttention, self).__init__(**kwargs)
  7. self.num_heads = num_heads
  8. self.attention = d2l.DotProductAttention(dropout)
  9. self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
  10. self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
  11. self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
  12. self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
  13. '''主要看forward函数'''
  14. def forward(self, queries, keys, values, valid_lens):
  15. # Shape of `queries`, `keys`, or `values`:
  16. # (`batch_size`, no. of queries or key-value pairs, `num_hiddens`)
  17. # Shape of `valid_lens`:
  18. # (`batch_size`,) or (`batch_size`, no. of queries)
  19. # After transposing, shape of output `queries`, `keys`, or `values`:
  20. # (`batch_size` * `num_heads`, no. of queries or key-value pairs,
  21. # `num_hiddens` / `num_heads`)
  22. queries = transpose_qkv(self.W_q(queries), self.num_heads)
  23. keys = transpose_qkv(self.W_k(keys), self.num_heads)
  24. values = transpose_qkv(self.W_v(values), self.num_heads)
  25. if valid_lens is not None:
  26. # On axis 0, copy the first item (scalar or vector) for
  27. # `num_heads` times, then copy the next item, and so on
  28. valid_lens = torch.repeat_interleave(
  29. valid_lens, repeats=self.num_heads, dim=0)
  30. # Shape of `output`: (`batch_size` * `num_heads`, no. of queries,
  31. # `num_hiddens` / `num_heads`)
  32. output = self.attention(queries, keys, values, valid_lens)
  33. # Shape of `output_concat`:
  34. # (`batch_size`, no. of queries, `num_hiddens`)
  35. output_concat = transpose_output(output, self.num_heads)
  36. return self.W_o(output_concat)
  37. '''
  38. attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
  39. num_hiddens, num_heads, 0.5)
  40. attention.eval()
  41. batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
  42. X = torch.ones((batch_size, num_queries, num_hiddens))
  43. print(attention(X, X, X, valid_lens).shape) #torch.size([2,4,10])
  44. print(X)
  45. print(X.shape) # torch.size([2,4,10])

位置编码:

在自然语言处理(NLP)中,绝对位置(Absolute Position)和相对位置(Relative Position)通常用来描述序列中不同元素之间的位置关系。

绝对位置是指序列中每个元素的具体位置,通常通过元素在序列中的索引或位置来表示。

相对位置是指序列中不同元素之间的相对距离或位置关系。相对位置通常以相对于某个参考元素的距离来表示。例如,在一个句子中,一个单词相对于另一个单词的位置可以用距离来表示,例如距离为1表示紧邻相邻,距离为2表示间隔一个单词,依此类推。相对位置可以更好地捕捉序列中元素之间的上下文关系和依赖关系。

以下为transformer中所使用的位置编码方法:

 其中所蕴含的绝对位置信息:这种编码方式使得编码维度频率逐渐降低,而编码维度单调降低的频率与绝对位置信息的关系是,维度越高,频率越低,与这种三角函数编码方式相吻合。

相对位置信息:利用三角函数周期性的性质,如下公式可以证明其中蕴含相对位置信息。

其中,\delta为任意位置i的偏移量。可以看到,投影矩阵与i无关,只与 \delta有关。

代码:
  1. class PositionalEncoding(nn.Module):
  2. """位置编码"""
  3. def __init__(self, num_hiddens, dropout, max_len=1000):
  4. super(PositionalEncoding, self).__init__()
  5. self.dropout = nn.Dropout(dropout)
  6. # 创建一个足够长的P,用来存放编码矩阵
  7. self.P = torch.zeros((1, max_len, num_hiddens))
  8. #利用torch.pow()计算幂次方以及广播机制,计算位置编码矩阵中的数值
  9. X = torch.arange(max_len, dtype=torch.float32).reshape(
  10. -1, 1) / torch.pow(10000, torch.arange(
  11. 0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
  12. #维度方向上,从0开始,每隔两个位置取一个,也就是奇数位
  13. self.P[:, :, 0::2] = torch.sin(X)
  14. self.P[:, :, 1::2] = torch.cos(X)
  15. def forward(self, X):
  16. X = X + self.P[:, :X.shape[1], :].to(X.device)
  17. return self.dropout(X)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/909617
推荐阅读
相关标签
  

闽ICP备14008679号