当前位置:   article > 正文

深度学习笔记--Transformer中position encoding的源码理解与实现_positionalencoding代码

positionalencoding代码

1--源码

  1. import torch
  2. import math
  3. import numpy as np
  4. import torch.nn as nn
  5. class Pos_Embed(nn.Module):
  6. def __init__(self, channels, num_frames, num_joints):
  7. super().__init__()
  8. # 根据帧序和节点序生成位置向量
  9. pos_list = []
  10. for tk in range(num_frames):
  11. for st in range(num_joints):
  12. pos_list.append(st)
  13. position = torch.from_numpy(np.array(pos_list)).unsqueeze(1).float() # num_frames*num_joints, 1
  14. pe = torch.zeros(num_frames * num_joints, channels) # T*N, C
  15. div_term = torch.exp(torch.arange(0, channels, 2).float() * -(math.log(10000.0) / channels))
  16. pe[:, 0::2] = torch.sin(position * div_term) # 偶数列 # 偶数C维度sin
  17. pe[:, 1::2] = torch.cos(position * div_term) # 奇数列 # 奇数C维度cos
  18. pe = pe.view(num_frames, num_joints, channels).permute(2, 0, 1).unsqueeze(0) # T N C -> C T N -> 1 C T N
  19. self.register_buffer('pe', pe)
  20. def forward(self, x): # nctv # BCTN
  21. x = self.pe[:, :, :x.size(2)]
  22. return x
  23. if __name__ == "__main__":
  24. B = 2
  25. C = 4
  26. T = 120
  27. N = 25
  28. x = torch.rand((B, C, T, N))
  29. Pos_embed_1 = Pos_Embed(C, T, N)
  30. PE = Pos_embed_1(x)
  31. # print(PE.shape) # 1 C T N
  32. x = x + PE
  33. print("All Done !")

2--源码分析与理解

原理理解:Positional Encoding(位置编码)

推荐视频:位置编码

代码解释:

①代码 div_term = torch.exp(torch.arange(0, channels, 2).float() * -(math.log(10000.0) / channels)):

令:channels = C, torch.arange(0, channels, 2).float() = k(则k = 0, 2, ..., C-2);

-(math.log(10000.0) / channels)  =loge1000C

则:torch.arange(0, channels, 2).float() * -(math.log(10000.0) / channels)=kloge10000C

torch.exp(torch.arange(0, channels, 2).float() * -(math.log(10000.0) / channels))=ekloge10000C=eloge10000kC=10000kC;

②代码:pe[:, 0::2] = torch.sin(position * div_term)  pe[:, 1::2] = torch.cos(position * div_term):

令:position = p,则position * div_term=p10000kC=p10000kc;

k等价为2ipe[:, 0::2]pe[:, 1::2]分别取维度C的偶数列和奇数列,就可以得到上图绿框所示的公式。

3--参考

参考1

参考2

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/371451?site
推荐阅读
相关标签
  

闽ICP备14008679号