赞
踩
在看trm源码的时候关于transformer position encoding的部分不能理解,记录一下以免以后要用到。(比较浅显,基本根据论文公式反推的)
- # For positional encoding
- num_timescales = self.hidden_size // 2##一半余弦,一半正弦
- max_timescale = 10000.0
- min_timescale = 1.0##max_timescale min_timescale是时间尺度的上下界
- ##以上:计算时间尺度
-
- log_timescale_increment = (
- math.log(float(max_timescale) / float(min_timescale)) /
- max(num_timescales - 1, 1))##感觉是(max_timescale-min_timescale)取对数。
- ##在对数空间中相邻时间尺度之间的增量
- ##计算时间尺度的增量
-
- inv_timescales = min_timescale * torch.exp(
- torch.arange(num_timescales, dtype=torch.float32) *
- -log_timescale_increment)
- ##计算时间尺度的值 ##inv_timescales是时间尺度的倒数
-
- self.register_buffer('inv_timescales', inv_timescales)
-
- def get_position_encoding(self, x):
- max_length = x.size()[1]
- position = torch.arange(max_length, dtype=torch.float32,
- device=x.device)
- ##position=[0.,1.,2.,……,max_length-1.]
- scaled_time = position.unsqueeze(1) * self.inv_timescales.unsqueeze(0)
- signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)],
- dim=1)
- signal = F.pad(signal, (0, 0, 0, self.hidden_size % 2))
- signal = signal.view(1, max_length, self.hidden_size)
- return signal
首先看下面这张图,便于我们直观地理解pe的构成:
可以自己输出一下代码,发现torch.sin(scaled_time), torch.cos(scaled_time)的维度都是N*D/2(N是序列的最长长度,D是hidden_size),所以signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)],dim=1)后维度就是N*D,也就是对每个pos的token的每一个隐藏层维度都进行一个位置编码(其实位置编码要和词嵌入向量做加法来加入位置信息,我们知道词嵌入向量的维度是B*N*D,那么位置编码1*N*D才能做加法)
然后,我不太理解为什么下面这段代码要取个指数函数。
- inv_timescales = min_timescale * torch.exp(
- torch.arange(num_timescales, dtype=torch.float32) *
- -log_timescale_increment)
其实有下面这段推导:
下面介绍一下各参数含义:
- num_timescales=256 ##也就是d/2
- max_timescale=10000.0 ##也就是公式中那个10000
- log_timescale_increment ##公式中的(2/d)*ln10000
- max(num_timescales - 1, 1) ##防止分母为0
- inv_timescales ##exp((2*i/d)*ln10000)
-
- def get_position_encoding(self, x):
- max_length = x.size()[1] ##x传入的参数是input B*N*D,所以这里取的是N
- position = torch.arange(max_length, dtype=torch.float32,
- device=x.device)
- ##position=tensor([0.,1.,2.,……,max_length-1.]),维度是N
- scaled_time = position.unsqueeze(1) * self.inv_timescales.unsqueeze(0)
- ##position.unsqueeze(1)后维度是N*1,inv_timescales.unsqueeze(0)后维度是1*D/2
- ##所以scaled_time维度是N*D/2
- signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)],
- dim=1)
- signal = F.pad(signal, (0, 0, 0, self.hidden_size % 2))
- ##signal维度是N*D
- signal = signal.view(1, max_length, self.hidden_size)
- ##把signal拉成(1*N*D)
- return signal
另外举两个例子来理解一下torch.nn.functional.pad():
直接看图
对于一个二维的矩阵 pad中的tuple参数(a,b,c,d)代表
(a,b)表示对矩阵的倒数第一个维度做padding操作,在原tensor的左侧pad a个,右侧pad b个。
(c,d)表示对矩阵的倒数第二个维度做padding操作,在原tensor的左侧pad c个,右侧pad d个。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。