当前位置:   article > 正文

Transformer模型中应用的各类位置编码_旋转位置编码

旋转位置编码

六种位置编码的代码实现及性能实验 

1、位置编码的意义

对于序列数据,目前存在三种主流的建模方式:卷积操作、循环操作和自注意力。其中,卷积和循环操作都具有局部性,即只作用目标元素的若干邻居上,而自注意力则是一种全局操作。具有局部性的操作,可以天然地注意到了元素间的相对位置;而注意力机制则是位置不敏感的·,即使调换序列中两个元素的位置对编码后的结果也不会产生影响。

因此,有必要将元素对应的位置信息添加到表示中,或者在计算注意力得分时考虑两个元素之间的相对位置。这些方法统称为位置编码,可以分为绝对位置编码和相对位置编码。

2、绝对位置编码

最为经典的位置编码莫过于 BERT[1] 模型所使用的,即直接将位置的表示加到token的表示上,而每个位置的表示则为一个可学习的向量。这种编码方式,据我所知最早是由ConvS2S[2]提出,被BERT、GPT2[3]、ERNIE[4]、ALBERT[5]、electra[6] 等模型所采用。

以上的位置编码被称为learnable绝对位置编码,存在着两个问题:(1) 位置编码本身通过大量数据才能学习到;(2) 位置向量之间的相对关系没有被利用到,如位置1和位置2之间的相似性应比位置1和位置9之间的相似性高。当然这些问题都可以通过大规模语料上的预训练来缓解。与learnable绝对位置编码相对的则是fixed绝对位置编码,以三角式位置编码[7]为代表。

 

 3、相对位置编码

绝对位置编码是将位置编码直接嵌入到序列的表示中,而相对位置编码则是指在计算注意力分数的时候,直接考虑两个token之间的相对位置,即

 4、绝对位置 v.s. 相对位置

绝对位置编码具有实施简单、计算速度快的优点。而其缺点也是明显的,因为真正重要的往往不是绝对位置,而是token之间的相对位置。在下面三个句子中,东西的含义和东西与鱼的相对位置有关,而与东西本身的绝对位置无关。

有个东西在吃鱼
小明放眼望去,看到有个东西在吃鱼
有条鱼在吃东西

虽然三角式位置编码,作为一种绝对位置编码,包含了一定相对位置信息,但这种相对位置信息仅仅包含在位置编码内部。当添加位置编码的表示在计算自注意力的时候,表示中的相对位置信息是否仍然保留就是个未知数了。

此外,对于线性attention而言,相对位置编码无法直接得到应用。因此,沿着三角式位置编码的思路,进一步发展绝对位置编码是有必要的。

5、旋转式位置编码 RoPE

 

 

 

 5、 代码及实验部分

绝对位置编码

可学习的绝对位置编码

  1. class LearnableAbsolutePositionEmbedding(nn.Module):
  2. def __init__(self, max_position_embeddings, hidden_size):
  3. super().__init__()
  4. self.is_absolute = True
  5. self.embeddings = nn.Embedding(max_position_embeddings, hidden_size)
  6. self.register_buffer('position_ids', torch.arange(max_position_embeddings))
  7. def forward(self, x):
  8. """
  9. return (b l d) / (b h l d)
  10. """
  11. position_ids = self.position_ids[:x.size(-2)]
  12. if x.dim() == 3:
  13. return x + self.embeddings(position_ids)[None, :, :]
  14. elif x.dim() == 4:
  15. h = x.size(1)
  16. x = rearrange(x, 'b h l d -> b l (h d)')
  17. x = x + self.embeddings(position_ids)[None, :, :]
  18. x = rearrange(x, 'b l (h d) -> b h l d', h=h)
  19. return x

三角式绝对位置编码

  1. class FixedAbsolutePositionEmbedding(nn.Module):
  2. def __init__(self, max_position_embeddings, hidden_size, position_embedding_type):
  3. super().__init__()
  4. self.position_embedding_type = position_embedding_type
  5. self.is_absolute = True
  6. inv_freq = 1. / (10000 ** (torch.arange(0, hidden_size, 2, dtype=torch.float) / hidden_size))
  7. position = torch.arange(max_position_embeddings, dtype=torch.float)
  8. sinusoid_inp = torch.einsum('i,j -> ij', position, inv_freq)
  9. embeddings = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
  10. self.register_buffer('embeddings', embeddings)
  11. def forward_fixed(self, x):
  12. """
  13. return (b l d)
  14. """
  15. return x + self.embeddings[None, :x.size(1), :]
  16. def forward_rope(self, x):
  17. """
  18. return (b l d)
  19. """
  20. embeddings = self.embeddings[None, :x.size(1), :] # b l d
  21. embeddings = rearrange(embeddings, 'b l (j d) -> b l j d', j=2)
  22. sin, cos = embeddings.unbind(dim=-2) # b l d//2
  23. sin, cos = map(lambda t: repeat(t, '... d -> ... (d 2)'), (sin, cos)) # b l d
  24. return x * cos + self.rotate_every_two(x) * sin
  25. @staticmethod
  26. def rotate_every_two(x):
  27. x = rearrange(x, '... (d j) -> ... d j', j=2)
  28. x1, x2 = x.unbind(dim=-1)
  29. x = torch.stack((-x2, x1), dim=-1)
  30. return rearrange(x, '... d j -> ... (d j)')
  31. def _forward(self, x):
  32. if self.position_embedding_type == 'fixed':
  33. return self.forward_fixed(x)
  34. elif self.position_embedding_type == 'rope':
  35. return self.forward_rope(x)
  36. def forward(self, x):
  37. if x.dim() == 3:
  38. return self._forward(x)
  39. elif x.dim() == 4:
  40. h = x.size(1)
  41. x = rearrange(x, 'b h l d -> (b h) l d')
  42. x = self._forward(x)
  43. x = rearrange(x, '(b h) l d -> b h l d', h=h)
  44. return x

相对位置编码

  1. class RelativePositionEmbedding(nn.Module):
  2. def __init__(self,
  3. relative_attention_num_buckets, num_attention_heads,
  4. hidden_size, position_embedding_type):
  5. super().__init__()
  6. self.relative_attention_num_buckets = relative_attention_num_buckets
  7. self.position_embedding_type = position_embedding_type
  8. self.num_attention_heads = num_attention_heads
  9. self.is_absolute = False
  10. if position_embedding_type == 'bias':
  11. self.embeddings = nn.Embedding(relative_attention_num_buckets, num_attention_heads)
  12. elif position_embedding_type == 'contextual(1)':
  13. self.embeddings = nn.Embedding(relative_attention_num_buckets, hidden_size)
  14. self.to_r = nn.Linear(hidden_size, hidden_size, bias=False)
  15. elif position_embedding_type == 'contextual(2)':
  16. self.embeddings = nn.Embedding(relative_attention_num_buckets, hidden_size)
  17. def compute_bias(self, q, k, to_q=None, to_k=None):
  18. """
  19. q, k: [b h l d]
  20. return [b h l l]
  21. """
  22. h = self.num_attention_heads
  23. query_position = torch.arange(q.size(2), dtype=torch.long, device=self.embeddings.weight.device)[:, None]
  24. key_position = torch.arange(k.size(2), dtype=torch.long, device=self.embeddings.weight.device)[None, :]
  25. relative_position = query_position - key_position
  26. relative_position_bucket = self._relative_position_bucket(
  27. relative_position,
  28. num_buckets=self.relative_attention_num_buckets
  29. )
  30. if self.position_embedding_type == 'bias':
  31. bias = self.embeddings(relative_position_bucket)
  32. bias = rearrange(bias, 'm n h -> 1 h m n')
  33. elif self.position_embedding_type == 'contextual(1)':
  34. r = self.embeddings(relative_position_bucket)
  35. r = self.to_r(r)
  36. r = rearrange(r, 'm n (h d) -> h m n d', h=h)
  37. bias = torch.einsum('b h m d, h m n d -> b h m n', q, r)
  38. elif self.position_embedding_type == 'contextual(2)':
  39. r = self.embeddings(relative_position_bucket)
  40. kr = to_k(r)
  41. qr = to_q(r)
  42. kr = rearrange(kr, 'm n (h d) -> h m n d', h=h)
  43. qr = rearrange(qr, 'm n (h d) -> h m n d', h=h)
  44. bias1 = torch.einsum('b h m d, h m n d -> b h m n', q, kr)
  45. bias2 = torch.einsum('b h n d, h m n d -> b h m n', k, qr)
  46. bias = bias1 + bias2
  47. return bias
  48. @staticmethod
  49. def _relative_position_bucket(relative_position, num_buckets, max_distance=128):
  50. """
  51. relative_position: [m n]
  52. """
  53. num_buckets //= 2
  54. relative_buckets = (relative_position > 0).to(torch.long) * num_buckets
  55. relative_position = torch.abs(relative_position)
  56. max_exact = num_buckets // 2
  57. is_small = relative_position < max_exact
  58. relative_position_if_large = max_exact + (
  59. torch.log(relative_position.float() / max_exact)
  60. / math.log(max_distance / max_exact)
  61. * (num_buckets - max_exact)
  62. ).to(torch.long)
  63. relative_position_if_large = torch.min(
  64. relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
  65. )
  66. relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
  67. return relative_buckets

Embedding层

  1. class Embedddings(nn.Module):
  2. def __init__(self, config):
  3. super().__init__()
  4. self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
  5. # self.dropout = nn.Dropout(config.hidden_dropout_prob)
  6. self.dropout = StableDropout(config.hidden_dropout_prob)
  7. self.dense = nn.Linear(config.embedding_size, config.hidden_size)
  8. if config.position_embedding_type == 'learnable':
  9. self.position_embeddings = LearnableAbsolutePositionEmbedding(
  10. max_position_embeddings=config.max_position_embeddings,
  11. hidden_size=config.hidden_size
  12. )
  13. elif config.position_embedding_type in ('fixed', 'rope'):
  14. self.position_embeddings = FixedAbsolutePositionEmbedding(
  15. max_position_embeddings=config.max_position_embeddings,
  16. hidden_size=config.hidden_size,
  17. position_embedding_type=config.position_embedding_type
  18. )
  19. def forward(self, input_ids):
  20. embeds = self.word_embeddings(input_ids)
  21. embeds = self.dropout(embeds)
  22. embeds = self.dense(embeds)
  23. if hasattr(self, 'position_embeddings'):
  24. embeds = self.position_embeddings(embeds)
  25. return embeds

注意力

  1. class Attention(nn.Module):
  2. def __init__(self, config):
  3. super().__init__()
  4. self.n_heads = config.num_attention_heads
  5. dim_heads = config.hidden_size // config.num_attention_heads
  6. self.to_q = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
  7. self.to_k = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
  8. self.to_v = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
  9. self.to_out = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
  10. self.dropout = StableDropout(config.hidden_dropout_prob)
  11. if config.encoder_layer == 'transformer':
  12. self.attn_fn = TransformerAttention(config)
  13. elif config.encoder_layer == 'performer':
  14. self.attn_fn = PerformerAttention(config)
  15. else:
  16. raise NotImplementedError
  17. def forward(self, x, mask, pos_emb):
  18. h = self.n_heads
  19. q = self.to_q(x)
  20. k = self.to_k(x)
  21. v = self.to_v(x)
  22. q, k, v = map(lambda t: rearrange(t, 'b l (h d) -> b h l d', h=h), (q, k, v))
  23. context = self.attn_fn(q, k, v, mask, pos_emb, to_q=self.to_q, to_k=self.to_k)
  24. out = self.to_out(context)
  25. out = self.dropout(out)
  26. return out

自注意力

  1. class TransformerAttention(nn.Module):
  2. def __init__(self, config):
  3. super().__init__()
  4. attention_head_size = config.hidden_size // config.num_attention_heads
  5. self.scale = attention_head_size ** -0.5
  6. # self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  7. self.dropout = StableDropout(config.attention_probs_dropout_prob)
  8. def forward(self, q, k, v, mask, pos_emb, to_q, to_k):
  9. """
  10. q, k, v: [b h l d]
  11. mask: [b l]
  12. """
  13. if pos_emb is not None and pos_emb.is_absolute is True:
  14. q = pos_emb(q)
  15. k = pos_emb(k)
  16. dots = torch.einsum('b h m d, b h n d -> b h m n', q, k)
  17. if pos_emb is not None and pos_emb.is_absolute is False:
  18. bias = pos_emb.compute_bias(q, k, to_q, to_k)
  19. dots = dots + bias
  20. # assert mask is not None
  21. # if mask is not None:
  22. mask = mask[:, None, None, :] & mask[:, None, :, None]
  23. # dots = dots.masked_fill(~mask, -10000.)
  24. # probs = dots.softmax(dim=-1)
  25. probs = XSoftmax.apply(dots, mask, -1)
  26. probs = self.dropout(probs)
  27. context = torch.einsum('b h m n, b h n d -> b h m d', probs, v)
  28. context = rearrange(context, 'b h m d -> b m (h d)')
  29. return context

线性注意力

  1. class PerformerAttention(nn.Module):
  2. def __init__(self, config):
  3. super().__init__()
  4. attention_head_size = config.hidden_size // config.num_attention_heads
  5. self.attn = FastAttention(dim_heads=attention_head_size, causal=False)
  6. def forward(self, q, k, v, mask, pos_emb, **kwargs):
  7. """
  8. q, k, v: [b h l d]
  9. mask: [b l]
  10. """
  11. if pos_emb is not None:
  12. assert pos_emb.is_absolute is True
  13. q = pos_emb(q)
  14. k = pos_emb(k)
  15. mask = mask[:, None, :, None]
  16. v = v.masked_fill(~mask, 0.)
  17. context = self.attn(q, k, v)
  18. context = rearrange(context, 'b h l d -> b l (h d)')
  19. return context

Transformer Encoder

  1. class Encoder(nn.Module):
  2. def __init__(self, config):
  3. super().__init__()
  4. dim_heads = config.hidden_size // config.num_attention_heads
  5. if config.position_embedding_type == 'layerwise_learnable':
  6. self.position_embeddings = LearnableAbsolutePositionEmbedding(
  7. max_position_embeddings=config.max_position_embeddings,
  8. hidden_size=config.hidden_size
  9. )
  10. elif config.position_embedding_type in ('layerwise_fixed', 'layerwise_rope'):
  11. self.position_embeddings = FixedAbsolutePositionEmbedding(
  12. max_position_embeddings=config.max_position_embeddings,
  13. hidden_size=dim_heads,
  14. position_embedding_type=config.position_embedding_type.split('_')[-1],
  15. )
  16. elif config.position_embedding_type in ('layerwise_bias', 'layerwise_contextual(1)', 'layerwise_contextual(2)'):
  17. self.position_embeddings = RelativePositionEmbedding(
  18. config.relative_attention_num_buckets,
  19. config.num_attention_heads,
  20. config.hidden_size,
  21. position_embedding_type=config.position_embedding_type.split('_')[-1]
  22. )
  23. else:
  24. self.position_embeddings = None
  25. self.layer = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_hidden_layers)])
  26. def forward(self, x, mask):
  27. for layer_module in self.layer:
  28. x = layer_module(x, mask, self.position_embeddings)
  29. return x

完整模型

  1. class TDLMModel(TDLMPreTrainedModel):
  2. def __init__(self, config):
  3. super().__init__(config)
  4. self.embeddings = Embedddings(config)
  5. self.encoder = Encoder(config)
  6. self.init_weights()
  7. def forward(self, input_ids, attention_mask):
  8. """
  9. input_ids: [b, l]
  10. attention_mask: [b, l]
  11. """
  12. attention_mask = attention_mask.bool()
  13. x = self.embeddings(input_ids)
  14. x = self.encoder(x, attention_mask)
  15. return x

实验及结果

实验设置

  • 参数量 由于GPU资源的限制,实验中所使用的Transformer模型(29M)在参数量上,要比BERT-base(110M)小很多。

  • 训练语料 模型在英语维基百科语料(13G文本)上训练,batch_size通过梯度累计的方式设置为2048,一共训练了20K步,这相当于全部语料的1/3左右。

  • 评价指标 选择bpd(Bits Per Dimension)作为语言模型的评价指标,bpd=loss/ln(2)。

  1. TDLMConfig { [0/1855]
  2. "attention_probs_dropout_prob": 0.1,
  3. "embedding_size": 128,
  4. "encoder_layer": "transformer",
  5. "glu": false,
  6. "hidden_act": "gelu",
  7. "hidden_dropout_prob": 0.1,
  8. "hidden_size": 512,
  9. "initializer_range": 0.02,
  10. "layer_norm_eps": 1e-12,
  11. "max_position_embeddings": 512,
  12. "num_attention_heads": 8,
  13. "num_hidden_layers": 8,
  14. "pad_token_id": 0,
  15. "position_embedding_type": "layerwise_rope",
  16. "pre_norm": true,
  17. "relative_attention_num_buckets": 32,
  18. "transformers_version": "4.6.1",
  19. "vocab_size": 30522
  20. }

首先,进行了简单的超参数调节,最后发现初始学习率设为2e-4比较好(使用了linear_schedule_with_warmup)。

在embedding层应用绝对位置编码,如下图,可以发现RoPE优于三角式位置编码和可学习的位置编码,bqd最低为3.05。

将RoPE与相对位置编码进行比较时,可以发现contextual模式的相对位置编码还是优于RoPE的。但是相对于相对位置编码,RoPE仍然有以下优势

  • RoPE本身不会给模型引入额外的参数

  • RoPE是直接作用于q,k上,因而无需修改注意力的计算过程。进一步说,RoPE可以直接且方便的作用在Transformer变体上,如Performer[14]、Reformer[15]等

 

Pre-Norm v.s. Post-Norm

此外,这里对比了transformer模型中的pre-norm和post-norm。如下图所示,对于语言模型而言,post-norm还是更好。也许pre-norm只适合CV领域上的任务吧。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/366112
推荐阅读
相关标签
  

闽ICP备14008679号