当前位置:   article > 正文

Transformer位置编码代码讲解_transformer时间戳编码

transformer时间戳编码

在看trm源码的时候关于transformer position encoding的部分不能理解,记录一下以免以后要用到。(比较浅显,基本根据论文公式反推的)

  1. # For positional encoding
  2. num_timescales = self.hidden_size // 2##一半余弦,一半正弦
  3. max_timescale = 10000.0
  4. min_timescale = 1.0##max_timescale min_timescale是时间尺度的上下界
  5. ##以上:计算时间尺度
  6. log_timescale_increment = (
  7. math.log(float(max_timescale) / float(min_timescale)) /
  8. max(num_timescales - 1, 1))##感觉是(max_timescale-min_timescale)取对数。
  9. ##在对数空间中相邻时间尺度之间的增量
  10. ##计算时间尺度的增量
  11. inv_timescales = min_timescale * torch.exp(
  12. torch.arange(num_timescales, dtype=torch.float32) *
  13. -log_timescale_increment)
  14. ##计算时间尺度的值 ##inv_timescales是时间尺度的倒数
  15. self.register_buffer('inv_timescales', inv_timescales)
  16. def get_position_encoding(self, x):
  17. max_length = x.size()[1]
  18. position = torch.arange(max_length, dtype=torch.float32,
  19. device=x.device)
  20. ##position=[0.,1.,2.,……,max_length-1.]
  21. scaled_time = position.unsqueeze(1) * self.inv_timescales.unsqueeze(0)
  22. signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)],
  23. dim=1)
  24. signal = F.pad(signal, (0, 0, 0, self.hidden_size % 2))
  25. signal = signal.view(1, max_length, self.hidden_size)
  26. return signal

首先看下面这张图,便于我们直观地理解pe的构成:

可以自己输出一下代码,发现torch.sin(scaled_time), torch.cos(scaled_time)的维度都是N*D/2(N是序列的最长长度,D是hidden_size),所以signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)],dim=1)后维度就是N*D,也就是对每个pos的token的每一个隐藏层维度都进行一个位置编码(其实位置编码要和词嵌入向量做加法来加入位置信息,我们知道词嵌入向量的维度是B*N*D,那么位置编码1*N*D才能做加法)

然后,我不太理解为什么下面这段代码要取个指数函数。

  1. inv_timescales = min_timescale * torch.exp(
  2. torch.arange(num_timescales, dtype=torch.float32) *
  3. -log_timescale_increment)

其实有下面这段推导:

下面介绍一下各参数含义:

  1. num_timescales=256 ##也就是d/2
  2. max_timescale=10000.0 ##也就是公式中那个10000
  3. log_timescale_increment ##公式中的(2/d)*ln10000
  4. max(num_timescales - 1, 1) ##防止分母为0
  5. inv_timescales ##exp((2*i/d)*ln10000)
  6. def get_position_encoding(self, x):
  7. max_length = x.size()[1] ##x传入的参数是input B*N*D,所以这里取的是N
  8. position = torch.arange(max_length, dtype=torch.float32,
  9. device=x.device)
  10. ##position=tensor([0.,1.,2.,……,max_length-1.]),维度是N
  11. scaled_time = position.unsqueeze(1) * self.inv_timescales.unsqueeze(0)
  12. ##position.unsqueeze(1)后维度是N*1,inv_timescales.unsqueeze(0)后维度是1*D/2
  13. ##所以scaled_time维度是N*D/2
  14. signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)],
  15. dim=1)
  16. signal = F.pad(signal, (0, 0, 0, self.hidden_size % 2))
  17. ##signal维度是N*D
  18. signal = signal.view(1, max_length, self.hidden_size)
  19. ##把signal拉成(1*N*D)
  20. return signal

另外举两个例子来理解一下torch.nn.functional.pad():

直接看图

对于一个二维的矩阵 pad中的tuple参数(a,b,c,d)代表

(a,b)表示对矩阵的倒数第一个维度做padding操作,在原tensor的左侧pad a个,右侧pad b个。

(c,d)表示对矩阵的倒数第二个维度做padding操作,在原tensor的左侧pad c个,右侧pad d个。

本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/371445?site
推荐阅读
相关标签
  

闽ICP备14008679号