赞
踩
对于论文给出的模型架构,使用 PyTorch 分别实现各个部分。
引入的相关库函数:
import copy | |
import torch | |
import math | |
from torch import nn | |
from torch.nn.functional import log_softmax | |
# module: 需要深拷贝的模块 | |
# n: 拷贝的次数 | |
# return: 深拷贝后的模块列表 | |
def clones(module, n: int) -> list: | |
return [copy.deepcopy(module) for _ in range(n)] |
编码器由 N 个相同的编码层堆叠而成,每个编码层含两个子层:多头注意力层和前馈网络层。每个子层后跟着一层,用于残差连接与标准化。
对于上一层的结果:SubLayer(�)与输出上一层的变量:�做残差连接并进行标准化:LayerNorm(�+Sublayer(�))。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。