赞
踩
在Transformer中,位置编码是为了引入位置信息,而位置编码的形式通常是一个正弦函数和一个余弦函数的组合,公式如下:
其中,PE(pos,i)表示位置编码矩阵中第 pos 个位置,第 i 个维度的值;dmodel表示模型嵌入向量的维度;i表示位置编码矩阵中第 i 个维度的值。这种位置编码方式可以引入位置信息,使得Transformer模型可以处理序列数据。
假设序列长度为4,位置编码维度为6,则位置编码矩阵如下:
其中三角函数括号中的部分可以由*号拆分成两部分,第一部分可以理解为x,第二部分可以理解为周期(普通的三角函数sin(2ΠX)的周期T为2Π,X为因变量)。
按列分析:如dim0这一列周期T为
X为0~3的一个周期为定值的三角函数;
按行分析:
如pos0这一行中,周期每两个元素变化一次,X为递增数列;所以按行看每个pos的位置编码是一个变周期(T)的三角函数;
代码如下(示例):
1、实现上表中的矩阵:
import torch def creat_pe_absolute_sincos_embedding(n_pos_vec, dim): assert dim % 2 == 0, "wrong dim" position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float) omega = torch.arange(dim//2, dtype=torch.float) omega /= dim/2. omega = 1./(10000**omega) sita = n_pos_vec[:,None] @ omega[None,:] emb_sin = torch.sin(sita) emb_cos = torch.cos(sita) position_embedding[:,0::2] = emb_sin position_embedding[:,1::2] = emb_cos return position_embedding
2、初始化序列长度和位置编码的维度,并计算位置编码矩阵:
n_pos = 512
dim = 768
n_pos_vec = torch.arange(n_pos, dtype=torch.float)
pe = creat_pe_absolute_sincos_embedding(n_pos_vec, dim)
print(pe)
tensor([[ 0.0000e+00, 1.0000e+00, 0.0000e+00, ..., 1.0000e+00,
0.0000e+00, 1.0000e+00],
[ 8.4147e-01, 5.4030e-01, 8.2843e-01, ..., 1.0000e+00,
1.0243e-04, 1.0000e+00],
[ 9.0930e-01, -4.1615e-01, 9.2799e-01, ..., 1.0000e+00,
2.0486e-04, 1.0000e+00],
...,
[ 6.1950e-02, 9.9808e-01, 5.3552e-01, ..., 9.9857e-01,
5.2112e-02, 9.9864e-01],
[ 8.7333e-01, 4.8714e-01, 9.9957e-01, ..., 9.9857e-01,
5.2214e-02, 9.9864e-01],
[ 8.8177e-01, -4.7168e-01, 5.8417e-01, ..., 9.9856e-01,
5.2317e-02, 9.9863e-01]])
3、按行对位置编码矩阵进行可视化:
# 不同pos
import matplotlib.pyplot as plt
x = [i for i in range(dim)]
for index, item in enumerate(pe):
if index % 50 != 1:
continue
y = item.tolist()
plt.plot(x, y, label=f"数据 {index}")
plt.show()
以50为间隔打印,由于序列长度为512,所以可以打印出11个pos位置的曲线,下图为pos0,pos250,pos500处的位置编码曲线:
4、按列对位置编码矩阵进行可视化:
# 不同dim
x = [i for i in range(n_pos)]
for index, item in enumerate(pe.transpose(0, 1)):
if index % 50 != 1:
continue
y = item.tolist()
plt.plot(x, y, label=f"数据 {index}")
plt.show()
以50为间隔打印,由于序列长度为768,所以可以打印出16个pos位置的曲线,下图为dim0,dim350,dim750处的位置编码曲线:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。