当前位置:   article > 正文

transfomer中正余弦位置编码的源码实现_正余弦编码器软件

正余弦编码器软件

本专栏主要是深度学习/自动驾驶相关的源码实现,获取全套代码请参考

简介

Transformer模型抛弃了RNN、CNN作为序列学习的基本模型。循环神经网络本身就是一种顺序结构,天生就包含了词在序列中的位置信息。当抛弃循环神经网络结构,完全采用Attention取而代之,这些词序信息就会丢失,模型就没有办法知道每个词在句子中的相对和绝对的位置信息。因此,有必要把词序信号加到词向量上帮助模型学习这些信息,位置编码(Positional Encoding)就是用来解决这种问题的方法。
关于位置编码更多介绍参考bev感知专栏博客

源码实现:

import torch
import matplotlib.pyplot as plt


def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)


def posemb_sincos_1d(len, dim, temperature: int = 10000, dtype=torch.float32):
    x = torch.arange(len)
    assert (dim % 2) == 0, "feature dimension must be multiple of 2 for sincos emb"
    omega = torch.arange(dim // 2) / (dim // 2 - 1)
    omega = 1.0 / (temperature ** omega)

    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos()), dim=1)  # 这里不用担心,不交叉无所谓,
    return pe.type(dtype)


if __name__ == '__main__':
    pos = posemb_sincos_1d(200, 256)
    # pos = posemb_sincos_2d(20,20,256)

    # 创建一个热力图
    plt.imshow(pos, cmap='hot', interpolation='nearest')
    # 添加颜色条
    plt.colorbar()
    # 显示图形
    plt.show()
    pass
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38

可视化结果如下:
在这里插入图片描述
如需获取全套代码请参考

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/371417
推荐阅读
相关标签
  

闽ICP备14008679号