赞
踩
对于序列数据,目前存在三种主流的建模方式:卷积操作、循环操作和自注意力。其中,卷积和循环操作都具有局部性,即只作用目标元素的若干邻居上,而自注意力则是一种全局操作。具有局部性的操作,可以天然地注意到了元素间的相对位置;而注意力机制则是位置不敏感的·,即使调换序列中两个元素的位置对编码后的结果也不会产生影响。
因此,有必要将元素对应的位置信息添加到表示中,或者在计算注意力得分时考虑两个元素之间的相对位置。这些方法统称为位置编码,可以分为绝对位置编码和相对位置编码。
最为经典的位置编码莫过于 BERT[1] 模型所使用的,即直接将位置的表示加到token的表示上,而每个位置的表示则为一个可学习的向量。这种编码方式,据我所知最早是由ConvS2S[2]提出,被BERT、GPT2[3]、ERNIE[4]、ALBERT[5]、electra[6] 等模型所采用。
以上的位置编码被称为learnable绝对位置编码,存在着两个问题:(1) 位置编码本身通过大量数据才能学习到;(2) 位置向量之间的相对关系没有被利用到,如位置1和位置2之间的相似性应比位置1和位置9之间的相似性高。当然这些问题都可以通过大规模语料上的预训练来缓解。与learnable绝对位置编码相对的则是fixed绝对位置编码,以三角式位置编码[7]为代表。
绝对位置编码是将位置编码直接嵌入到序列的表示中,而相对位置编码则是指在计算注意力分数的时候,直接考虑两个token之间的相对位置,即
绝对位置编码具有实施简单、计算速度快的优点。而其缺点也是明显的,因为真正重要的往往不是绝对位置,而是token之间的相对位置。在下面三个句子中,东西的含义和东西与鱼的相对位置有关,而与东西本身的绝对位置无关。
有个东西在吃鱼
小明放眼望去,看到有个东西在吃鱼
有条鱼在吃东西
虽然三角式位置编码,作为一种绝对位置编码,包含了一定相对位置信息,但这种相对位置信息仅仅包含在位置编码内部。当添加位置编码的表示在计算自注意力的时候,表示中的相对位置信息是否仍然保留就是个未知数了。
此外,对于线性attention而言,相对位置编码无法直接得到应用。因此,沿着三角式位置编码的思路,进一步发展绝对位置编码是有必要的。
绝对位置编码
可学习的绝对位置编码
- class LearnableAbsolutePositionEmbedding(nn.Module):
- def __init__(self, max_position_embeddings, hidden_size):
- super().__init__()
- self.is_absolute = True
- self.embeddings = nn.Embedding(max_position_embeddings, hidden_size)
- self.register_buffer('position_ids', torch.arange(max_position_embeddings))
-
- def forward(self, x):
- """
- return (b l d) / (b h l d)
- """
- position_ids = self.position_ids[:x.size(-2)]
-
- if x.dim() == 3:
- return x + self.embeddings(position_ids)[None, :, :]
-
- elif x.dim() == 4:
- h = x.size(1)
- x = rearrange(x, 'b h l d -> b l (h d)')
- x = x + self.embeddings(position_ids)[None, :, :]
- x = rearrange(x, 'b l (h d) -> b h l d', h=h)
- return x
三角式绝对位置编码
- class FixedAbsolutePositionEmbedding(nn.Module):
- def __init__(self, max_position_embeddings, hidden_size, position_embedding_type):
- super().__init__()
-
- self.position_embedding_type = position_embedding_type
- self.is_absolute = True
-
- inv_freq = 1. / (10000 ** (torch.arange(0, hidden_size, 2, dtype=torch.float) / hidden_size))
- position = torch.arange(max_position_embeddings, dtype=torch.float)
- sinusoid_inp = torch.einsum('i,j -> ij', position, inv_freq)
- embeddings = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
- self.register_buffer('embeddings', embeddings)
-
- def forward_fixed(self, x):
- """
- return (b l d)
- """
- return x + self.embeddings[None, :x.size(1), :]
-
- def forward_rope(self, x):
- """
- return (b l d)
- """
- embeddings = self.embeddings[None, :x.size(1), :] # b l d
- embeddings = rearrange(embeddings, 'b l (j d) -> b l j d', j=2)
- sin, cos = embeddings.unbind(dim=-2) # b l d//2
- sin, cos = map(lambda t: repeat(t, '... d -> ... (d 2)'), (sin, cos)) # b l d
- return x * cos + self.rotate_every_two(x) * sin
-
- @staticmethod
- def rotate_every_two(x):
- x = rearrange(x, '... (d j) -> ... d j', j=2)
- x1, x2 = x.unbind(dim=-1)
- x = torch.stack((-x2, x1), dim=-1)
- return rearrange(x, '... d j -> ... (d j)')
-
- def _forward(self, x):
- if self.position_embedding_type == 'fixed':
- return self.forward_fixed(x)
-
- elif self.position_embedding_type == 'rope':
- return self.forward_rope(x)
-
- def forward(self, x):
- if x.dim() == 3:
- return self._forward(x)
-
- elif x.dim() == 4:
- h = x.size(1)
- x = rearrange(x, 'b h l d -> (b h) l d')
- x = self._forward(x)
- x = rearrange(x, '(b h) l d -> b h l d', h=h)
- return x
相对位置编码
- class RelativePositionEmbedding(nn.Module):
- def __init__(self,
- relative_attention_num_buckets, num_attention_heads,
- hidden_size, position_embedding_type):
-
- super().__init__()
-
- self.relative_attention_num_buckets = relative_attention_num_buckets
- self.position_embedding_type = position_embedding_type
- self.num_attention_heads = num_attention_heads
- self.is_absolute = False
-
- if position_embedding_type == 'bias':
- self.embeddings = nn.Embedding(relative_attention_num_buckets, num_attention_heads)
-
- elif position_embedding_type == 'contextual(1)':
- self.embeddings = nn.Embedding(relative_attention_num_buckets, hidden_size)
- self.to_r = nn.Linear(hidden_size, hidden_size, bias=False)
-
- elif position_embedding_type == 'contextual(2)':
- self.embeddings = nn.Embedding(relative_attention_num_buckets, hidden_size)
-
- def compute_bias(self, q, k, to_q=None, to_k=None):
- """
- q, k: [b h l d]
- return [b h l l]
- """
- h = self.num_attention_heads
- query_position = torch.arange(q.size(2), dtype=torch.long, device=self.embeddings.weight.device)[:, None]
- key_position = torch.arange(k.size(2), dtype=torch.long, device=self.embeddings.weight.device)[None, :]
-
- relative_position = query_position - key_position
- relative_position_bucket = self._relative_position_bucket(
- relative_position,
- num_buckets=self.relative_attention_num_buckets
- )
-
- if self.position_embedding_type == 'bias':
- bias = self.embeddings(relative_position_bucket)
- bias = rearrange(bias, 'm n h -> 1 h m n')
-
- elif self.position_embedding_type == 'contextual(1)':
- r = self.embeddings(relative_position_bucket)
- r = self.to_r(r)
- r = rearrange(r, 'm n (h d) -> h m n d', h=h)
-
- bias = torch.einsum('b h m d, h m n d -> b h m n', q, r)
-
- elif self.position_embedding_type == 'contextual(2)':
- r = self.embeddings(relative_position_bucket)
-
- kr = to_k(r)
- qr = to_q(r)
-
- kr = rearrange(kr, 'm n (h d) -> h m n d', h=h)
- qr = rearrange(qr, 'm n (h d) -> h m n d', h=h)
-
- bias1 = torch.einsum('b h m d, h m n d -> b h m n', q, kr)
- bias2 = torch.einsum('b h n d, h m n d -> b h m n', k, qr)
-
- bias = bias1 + bias2
-
- return bias
-
- @staticmethod
- def _relative_position_bucket(relative_position, num_buckets, max_distance=128):
- """
- relative_position: [m n]
- """
-
- num_buckets //= 2
- relative_buckets = (relative_position > 0).to(torch.long) * num_buckets
- relative_position = torch.abs(relative_position)
-
- max_exact = num_buckets // 2
- is_small = relative_position < max_exact
-
- relative_position_if_large = max_exact + (
- torch.log(relative_position.float() / max_exact)
- / math.log(max_distance / max_exact)
- * (num_buckets - max_exact)
- ).to(torch.long)
-
- relative_position_if_large = torch.min(
- relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
- )
-
- relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
- return relative_buckets
Embedding层
- class Embedddings(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
- # self.dropout = nn.Dropout(config.hidden_dropout_prob)
- self.dropout = StableDropout(config.hidden_dropout_prob)
- self.dense = nn.Linear(config.embedding_size, config.hidden_size)
-
- if config.position_embedding_type == 'learnable':
- self.position_embeddings = LearnableAbsolutePositionEmbedding(
- max_position_embeddings=config.max_position_embeddings,
- hidden_size=config.hidden_size
- )
-
- elif config.position_embedding_type in ('fixed', 'rope'):
- self.position_embeddings = FixedAbsolutePositionEmbedding(
- max_position_embeddings=config.max_position_embeddings,
- hidden_size=config.hidden_size,
- position_embedding_type=config.position_embedding_type
- )
-
- def forward(self, input_ids):
- embeds = self.word_embeddings(input_ids)
- embeds = self.dropout(embeds)
- embeds = self.dense(embeds)
-
- if hasattr(self, 'position_embeddings'):
- embeds = self.position_embeddings(embeds)
-
- return embeds
注意力
- class Attention(nn.Module):
- def __init__(self, config):
- super().__init__()
-
- self.n_heads = config.num_attention_heads
- dim_heads = config.hidden_size // config.num_attention_heads
-
- self.to_q = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
- self.to_k = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
- self.to_v = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
-
- self.to_out = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
- self.dropout = StableDropout(config.hidden_dropout_prob)
-
- if config.encoder_layer == 'transformer':
- self.attn_fn = TransformerAttention(config)
-
- elif config.encoder_layer == 'performer':
- self.attn_fn = PerformerAttention(config)
-
- else:
- raise NotImplementedError
-
- def forward(self, x, mask, pos_emb):
- h = self.n_heads
-
- q = self.to_q(x)
- k = self.to_k(x)
- v = self.to_v(x)
-
- q, k, v = map(lambda t: rearrange(t, 'b l (h d) -> b h l d', h=h), (q, k, v))
-
- context = self.attn_fn(q, k, v, mask, pos_emb, to_q=self.to_q, to_k=self.to_k)
- out = self.to_out(context)
- out = self.dropout(out)
-
- return out
自注意力
- class TransformerAttention(nn.Module):
- def __init__(self, config):
- super().__init__()
-
- attention_head_size = config.hidden_size // config.num_attention_heads
- self.scale = attention_head_size ** -0.5
- # self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
- self.dropout = StableDropout(config.attention_probs_dropout_prob)
-
- def forward(self, q, k, v, mask, pos_emb, to_q, to_k):
- """
- q, k, v: [b h l d]
- mask: [b l]
- """
- if pos_emb is not None and pos_emb.is_absolute is True:
- q = pos_emb(q)
- k = pos_emb(k)
-
- dots = torch.einsum('b h m d, b h n d -> b h m n', q, k)
-
- if pos_emb is not None and pos_emb.is_absolute is False:
- bias = pos_emb.compute_bias(q, k, to_q, to_k)
- dots = dots + bias
-
- # assert mask is not None
- # if mask is not None:
- mask = mask[:, None, None, :] & mask[:, None, :, None]
- # dots = dots.masked_fill(~mask, -10000.)
- # probs = dots.softmax(dim=-1)
- probs = XSoftmax.apply(dots, mask, -1)
-
- probs = self.dropout(probs)
-
- context = torch.einsum('b h m n, b h n d -> b h m d', probs, v)
- context = rearrange(context, 'b h m d -> b m (h d)')
-
- return context
线性注意力
- class PerformerAttention(nn.Module):
- def __init__(self, config):
- super().__init__()
-
- attention_head_size = config.hidden_size // config.num_attention_heads
- self.attn = FastAttention(dim_heads=attention_head_size, causal=False)
-
- def forward(self, q, k, v, mask, pos_emb, **kwargs):
- """
- q, k, v: [b h l d]
- mask: [b l]
- """
- if pos_emb is not None:
- assert pos_emb.is_absolute is True
- q = pos_emb(q)
- k = pos_emb(k)
-
- mask = mask[:, None, :, None]
- v = v.masked_fill(~mask, 0.)
-
- context = self.attn(q, k, v)
- context = rearrange(context, 'b h l d -> b l (h d)')
- return context
Transformer Encoder
- class Encoder(nn.Module):
- def __init__(self, config):
- super().__init__()
-
- dim_heads = config.hidden_size // config.num_attention_heads
-
- if config.position_embedding_type == 'layerwise_learnable':
- self.position_embeddings = LearnableAbsolutePositionEmbedding(
- max_position_embeddings=config.max_position_embeddings,
- hidden_size=config.hidden_size
- )
-
- elif config.position_embedding_type in ('layerwise_fixed', 'layerwise_rope'):
- self.position_embeddings = FixedAbsolutePositionEmbedding(
- max_position_embeddings=config.max_position_embeddings,
- hidden_size=dim_heads,
- position_embedding_type=config.position_embedding_type.split('_')[-1],
- )
-
- elif config.position_embedding_type in ('layerwise_bias', 'layerwise_contextual(1)', 'layerwise_contextual(2)'):
- self.position_embeddings = RelativePositionEmbedding(
- config.relative_attention_num_buckets,
- config.num_attention_heads,
- config.hidden_size,
- position_embedding_type=config.position_embedding_type.split('_')[-1]
- )
-
- else:
- self.position_embeddings = None
-
- self.layer = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_hidden_layers)])
-
- def forward(self, x, mask):
- for layer_module in self.layer:
- x = layer_module(x, mask, self.position_embeddings)
-
- return x
完整模型
- class TDLMModel(TDLMPreTrainedModel):
- def __init__(self, config):
- super().__init__(config)
- self.embeddings = Embedddings(config)
- self.encoder = Encoder(config)
-
- self.init_weights()
-
- def forward(self, input_ids, attention_mask):
- """
- input_ids: [b, l]
- attention_mask: [b, l]
- """
- attention_mask = attention_mask.bool()
-
- x = self.embeddings(input_ids)
- x = self.encoder(x, attention_mask)
-
- return x
实验及结果
实验设置
参数量 由于GPU资源的限制,实验中所使用的Transformer模型(29M)在参数量上,要比BERT-base(110M)小很多。
训练语料 模型在英语维基百科语料(13G文本)上训练,batch_size通过梯度累计的方式设置为2048,一共训练了20K步,这相当于全部语料的1/3左右。
评价指标 选择bpd(Bits Per Dimension)作为语言模型的评价指标,bpd=loss/ln(2)。
- TDLMConfig { [0/1855]
- "attention_probs_dropout_prob": 0.1,
- "embedding_size": 128,
- "encoder_layer": "transformer",
- "glu": false,
- "hidden_act": "gelu",
- "hidden_dropout_prob": 0.1,
- "hidden_size": 512,
- "initializer_range": 0.02,
- "layer_norm_eps": 1e-12,
- "max_position_embeddings": 512,
- "num_attention_heads": 8,
- "num_hidden_layers": 8,
- "pad_token_id": 0,
- "position_embedding_type": "layerwise_rope",
- "pre_norm": true,
- "relative_attention_num_buckets": 32,
- "transformers_version": "4.6.1",
- "vocab_size": 30522
- }
首先,进行了简单的超参数调节,最后发现初始学习率设为2e-4比较好(使用了linear_schedule_with_warmup)。
在embedding层应用绝对位置编码,如下图,可以发现RoPE优于三角式位置编码和可学习的位置编码,bqd最低为3.05。
将RoPE与相对位置编码进行比较时,可以发现contextual模式的相对位置编码还是优于RoPE的。但是相对于相对位置编码,RoPE仍然有以下优势
RoPE本身不会给模型引入额外的参数
RoPE是直接作用于q,k上,因而无需修改注意力的计算过程。进一步说,RoPE可以直接且方便的作用在Transformer变体上,如Performer[14]、Reformer[15]等
Pre-Norm v.s. Post-Norm
此外,这里对比了transformer模型中的pre-norm和post-norm。如下图所示,对于语言模型而言,post-norm还是更好。也许pre-norm只适合CV领域上的任务吧。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。