
The Transformer Model


The foundation of MHA: SDPA

MHA above is short for Multi-Head Attention, the multi-head attention mechanism; SDPA is short for Scaled Dot-Product Attention.

What SDPA does

Some MHA-related points need to be laid out first:

First, the tensor MHA receives has shape (batch_size, seq_len, embedding_dim). MHA splits this tensor into heads with a split operation, after which the shape is (batch_size, head, seq_len, embedding_dim // head). We abbreviate embedding_dim // head as splited_dim; note that it is the embedding_dim axis that gets split across heads. Writing the shape as (batch_size, head, seq_len, splited_dim), this is the input shape of SDPA, as sketched right below.
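As a rough illustration of this split (a minimal sketch with made-up sizes, not the article's code; the real implementation is MultiHeadAttention.split further down), the reshape can be done with view and transpose:

import torch

batch_size, seq_len, embedding_dim, head = 2, 16, 32, 8
splited_dim = embedding_dim // head  # 4

x = torch.randn(batch_size, seq_len, embedding_dim)
# split the embedding dimension into (head, splited_dim), then move the head axis forward
x_split = x.view(batch_size, seq_len, head, splited_dim).transpose(1, 2)
print(x_split.shape)  # torch.Size([2, 8, 16, 4]) == (batch_size, head, seq_len, splited_dim)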

Next, look at the SDPA formulas:

score = \frac{Q \cdot K^{T}}{\sqrt{D}}

ATT = \mathrm{softmax}(score) \cdot V

Here D is the splited_dim of the q/k/v matrices.

In terms of shapes: with Q, K and V of shape (batch_size, head, seq_len, splited_dim), score has shape (batch_size, head, seq_len, seq_len), and ATT goes back to (batch_size, head, seq_len, splited_dim), as the sketch below also shows.
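A quick shape check of these formulas (an illustrative sketch with made-up tensor sizes, not the article's class):

import math
import torch

batch_size, head, seq_len, splited_dim = 2, 8, 16, 4
q = torch.randn(batch_size, head, seq_len, splited_dim)
k = torch.randn(batch_size, head, seq_len, splited_dim)
v = torch.randn(batch_size, head, seq_len, splited_dim)

score = (q @ k.transpose(2, 3)) / math.sqrt(splited_dim)
print(score.shape)  # torch.Size([2, 8, 16, 16]): one score per (query word, key word) pair
att = torch.softmax(score, dim=-1) @ v
print(att.shape)    # torch.Size([2, 8, 16, 4]): back to (batch_size, head, seq_len, splited_dim)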

What SDPA means

We focus on how score is obtained, because score is what gives ATT its meaning.

As you can see, score is computed as each word's dot product with every other word, including the word itself.

The mask operation in SDPA

The mask operation that comes later acts on score. For example, in the Decoder, word 1 in the first row of score should only be able to see itself (or, in a stricter variant, nothing at all), and word 2 in the second row should only see itself and the words before it (or only the words before it). In other words, when computing the scores we have to manually block out words that should not be visible. This kind of blocking is called an upper-triangular mask. A side note: during parallel training, the Decoder's self-attention must be masked this way so that the Decoder's output is the complete translated sentence; this process will be described in detail later.

After masking, the score matrix only keeps entries on and below the diagonal (or strictly below it, in the stricter variant); in practice the version where each word can still see itself, i.e. the diagonal is kept, is the one generally used, as the sketch below illustrates.
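A minimal sketch of building such a mask and applying it the same way the code below does (the blocked upper triangle, i.e. future positions, gets a large negative score; the kept region is the lower triangle from torch.tril). The sizes here are made up:

import torch

seq_len = 4
score = torch.randn(1, 1, seq_len, seq_len)  # (batch_size, n_head, seq_len, seq_len)

# 1 where attention is allowed, 0 where it must be blocked (diagonal kept)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
print(causal_mask.int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)

masked_score = score.masked_fill(causal_mask == 0, -10000)
weights = torch.softmax(masked_score, dim=-1)  # blocked positions get ~0 attention weight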

SDPA code implementation

import math

import torch
from torch import nn


class ScaleDotProductAttention(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1).to(device)

    def forward(self, q, k, v, mask=None, e=1e-12):
        """
        :param q: (batch_size, n_head, seq_len, d_tensor)
        :param k: (batch_size, n_head, seq_len, d_tensor)
        :param v: (batch_size, n_head, seq_len, d_tensor)
        :param mask: (batch_size, n_head, seq_len, seq_len)
        :param e: small constant kept for interface compatibility (unused)
        :return: attention output and the attention weights
        """
        batch_size, head, length, d_tensor = k.size()
        # scaled dot product between Query and Key^T
        k_t = k.transpose(2, 3)
        score = (q @ k_t) / math.sqrt(d_tensor)
        # optional masking: blocked positions get a large negative score
        if mask is not None:
            score = score.masked_fill(mask == 0, -10000)
        # softmax turns scores into attention weights, then weight the Values
        score = self.softmax(score)
        v = score @ v
        return v, score
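A quick usage check of the class above (an illustrative sketch; the device and tensor sizes are made up):

import torch

device = torch.device('cpu')
sdpa = ScaleDotProductAttention(device)

q = torch.randn(2, 8, 16, 4)  # (batch_size, n_head, seq_len, d_tensor)
k = torch.randn(2, 8, 16, 4)
v = torch.randn(2, 8, 16, 4)

out, attn = sdpa(q, k, v, mask=None)
print(out.shape)   # torch.Size([2, 8, 16, 4])
print(attn.shape)  # torch.Size([2, 8, 16, 16]); each row sums to 1 after softmax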

MHA

MHA is easy to understand: split the input into heads, hand them to SDPA, and then merge the heads back together.

The code is shown below.

The Encoder part

import math

import torch
from torch import nn


class TransformerEmbedding(nn.Module):
    """
    token embedding + positional encoding (sinusoid)
    positional encoding can give positional information to the network
    """

    def __init__(
        self,
        vocabulary_size,
        embedding_dim,
        seq_len,
        dropout_prob,
        device
    ):
        """
        class for word embedding that includes positional information
        :param vocabulary_size: size of vocabulary
        :param embedding_dim: dimensions of model
        """
        super().__init__()
        # TokenEmbedding below is just nn.Embedding, so nn.Embedding is used directly here
        # self.tok_emb = TokenEmbedding(vocabulary_size, embedding_dim)
        self.tok_emb = nn.Embedding(vocabulary_size, embedding_dim).to(device)
        self.pos_emb = PositionalEncoding(embedding_dim, seq_len, device).to(device)
        self.drop_out = nn.Dropout(p=dropout_prob).to(device)

    def forward(self, x):
        """
        :param x: (batch_size, seq_len)
        :return: (batch_size, seq_len, embedding_dim)
        """
        tok_emb = self.tok_emb(x)
        # tok_emb: (batch_size, seq_len, embedding_dim)
        pos_emb = self.pos_emb(x)
        # pos_emb: (seq_len, embedding_dim)
        # broadcasting adds the positional encoding to every sequence in the batch
        return self.drop_out(tok_emb + pos_emb)


class PositionalEncoding(nn.Module):
    """
    compute sinusoid encoding.
    """

    def __init__(
        self,
        embedding_dim,
        seq_len,
        device
    ):
        """
        constructor of sinusoid encoding class
        :param embedding_dim: dimension of model
        :param seq_len: max sequence length
        :param device: hardware device setting
        """
        super().__init__()
        # same size as the input matrix (so it can be added to the input)
        self.encoding = torch.zeros(seq_len, embedding_dim, device=device)
        self.encoding.requires_grad = False  # we don't need to compute gradients
        # encoding: (seq_len, embedding_dim)
        pos = torch.arange(0, seq_len, device=device)
        # pos: (seq_len,)
        pos = pos.float().unsqueeze(dim=1)
        # pos: (seq_len, 1) -- 1D => 2D, to represent each word's position
        _2i = torch.arange(0, embedding_dim, step=2, device=device).float()
        # 'i' means index of embedding_dim (e.g. embedding size = 50, 'i' = [0, 50])
        # "step=2" means 'i' multiplied by two (same as 2 * i)
        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / embedding_dim)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / embedding_dim)))
        # self.encoding: (seq_len, embedding_dim)

    def forward(self, x):
        # self.encoding: (seq_len, embedding_dim)
        batch_size, seq_len = x.size()
        # x: (batch_size, seq_len)
        return self.encoding[:seq_len, :]


class TokenEmbedding(nn.Embedding):
    """
    Token Embedding using torch.nn
    they will be dense representations of words using a weight matrix
    """

    def __init__(self, vocabulary_size, embedding_dim):
        """
        class for token embedding that includes positional information
        :param vocabulary_size: size of vocabulary
        :param embedding_dim: dimensions of model
        """
        super().__init__(vocabulary_size, embedding_dim, padding_idx=1)


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        embedding_dim,
        n_head,
        device,
    ):
        super().__init__()
        self.n_head = n_head
        self.attention = ScaleDotProductAttention(device).to(device)
        self.w_q = nn.Linear(embedding_dim, embedding_dim).to(device)
        self.w_k = nn.Linear(embedding_dim, embedding_dim).to(device)
        self.w_v = nn.Linear(embedding_dim, embedding_dim).to(device)
        self.w_concat = nn.Linear(embedding_dim, embedding_dim).to(device)

    def forward(self, q, k, v, mask=None):
        """
        :param q: (batch_size, seq_len, embedding_dim)
        :param k: (batch_size, seq_len, embedding_dim)
        :param v: (batch_size, seq_len, embedding_dim)
        :param mask: passed through to ScaleDotProductAttention
        :return: (batch_size, seq_len, embedding_dim)
        """
        # 1. dot product with weight matrices (linear projections)
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)  # (batch_size, seq_len, embedding_dim)
        # 2. split tensor by number of heads
        q, k, v = self.split(q), self.split(k), self.split(v)
        # q, k, v: (batch_size, n_head, seq_len, embedding_dim // n_head)
        # 3. scaled dot product to compute similarity
        out, attention = self.attention(q, k, v, mask=mask)
        # 4. concat and pass to linear layer
        out = self.concat(out)
        # out: (batch_size, seq_len, embedding_dim)
        out = self.w_concat(out)
        # 5. visualize attention map
        # TODO : we should implement visualization
        return out

    def split(self, tensor):
        """
        split tensor by number of heads
        :param tensor: (batch_size, length, embedding_dim)
        :return: (batch_size, head, length, d_tensor)
        """
        batch_size, length, embedding_dim = tensor.size()
        d_tensor = embedding_dim // self.n_head
        tensor = tensor.view(batch_size, length, self.n_head, d_tensor).transpose(1, 2)
        # similar to group convolution (split by number of heads)
        return tensor

    def concat(self, tensor):
        """
        inverse function of self.split(tensor : torch.Tensor)
        :param tensor: (batch_size, head, length, d_tensor)
        :return: (batch_size, length, embedding_dim)
        """
        batch_size, head, length, d_tensor = tensor.size()
        embedding_dim = head * d_tensor
        tensor = tensor.transpose(1, 2).contiguous().view(batch_size, length, embedding_dim)
        return tensor


class ScaleDotProductAttention(nn.Module):
    """
    compute scaled dot product attention
    Query : given sentence that we focus on (decoder)
    Key : every sentence to check relationship with Query (encoder)
    Value : every sentence same with Key (encoder)
    """

    def __init__(
        self,
        device,
    ):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1).to(device)

    def forward(self, q, k, v, mask=None, e=1e-12):
        """
        :param q: (batch_size, n_head, seq_len, d_tensor)
        :param k: (batch_size, n_head, seq_len, d_tensor)
        :param v: (batch_size, n_head, seq_len, d_tensor)
        :param mask: (batch_size, n_head, seq_len, seq_len)
        :param e: small constant kept for interface compatibility (unused)
        :return: attention output and the attention weights
        """
        # input is a 4-dimensional tensor: (batch_size, head, length, d_tensor)
        batch_size, head, length, d_tensor = k.size()
        # 1. dot product Query with Key^T to compute similarity
        k_t = k.transpose(2, 3)  # transpose
        score = (q @ k_t) / math.sqrt(d_tensor)  # scaled dot product
        # 2. apply masking (optional)
        if mask is not None:
            score = score.masked_fill(mask == 0, -10000)
        # TODO: work out exactly how this mask operates
        # 3. pass through softmax to map scores into the [0, 1] range
        score = self.softmax(score)
        # 4. multiply with Value
        v = score @ v
        return v, score


class LayerNorm(nn.Module):
    def __init__(self, embedding_dim, eps=1e-12):
        """
        remember to move this layer to the GPU (.cuda() / .to(device)) when using it
        :param embedding_dim:
        :param eps:
        """
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(embedding_dim))
        self.beta = nn.Parameter(torch.zeros(embedding_dim))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, unbiased=False, keepdim=True)
        # '-1' means the last dimension
        out = (x - mean) / torch.sqrt(var + self.eps)
        out = self.gamma * out + self.beta
        return out


class PositionwiseFeedForward(nn.Module):
    def __init__(
        self,
        embedding_dim,
        hidden,
        dropout_prob,
        device
    ):
        super().__init__()
        self.linear1 = nn.Linear(embedding_dim, hidden).to(device)
        self.linear2 = nn.Linear(hidden, embedding_dim).to(device)
        self.relu = nn.ReLU().to(device)
        self.dropout = nn.Dropout(p=dropout_prob).to(device)

    def forward(self, x):
        """
        :param x: (batch_size, seq_len, embedding_dim)
        :return: (batch_size, seq_len, embedding_dim)
        """
        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x


class EncoderLayer(nn.Module):
    def __init__(
        self,
        embedding_dim,
        ffn_hidden,
        n_head,
        dropout_prob,
        device,
    ):
        super().__init__()
        self.attention = MultiHeadAttention(embedding_dim=embedding_dim, n_head=n_head, device=device).to(device)
        self.norm1 = LayerNorm(embedding_dim=embedding_dim).to(device)
        self.dropout1 = nn.Dropout(p=dropout_prob).to(device)
        self.ffn = PositionwiseFeedForward(embedding_dim=embedding_dim, hidden=ffn_hidden, dropout_prob=dropout_prob, device=device).to(device)
        self.norm2 = LayerNorm(embedding_dim=embedding_dim).to(device)
        self.dropout2 = nn.Dropout(p=dropout_prob).to(device)

    def forward(self, x, s_mask):
        """
        :param x: (batch_size, seq_len, embedding_dim)
        :param s_mask: mask passed to self-attention
        :return: (batch_size, seq_len, embedding_dim)
        """
        # 1. compute self attention
        _x = x
        x = self.attention(q=x, k=x, v=x, mask=s_mask)
        # 2. add and norm
        x = self.dropout1(x)
        x = self.norm1(x + _x)
        # 3. positionwise feed forward network
        _x = x
        x = self.ffn(x)
        # 4. add and norm
        x = self.dropout2(x)
        x = self.norm2(x + _x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        encoder_vocabulary_size,
        seq_len,
        embedding_dim,
        ffn_hidden,
        n_head,
        n_layers,
        dropout_prob,
        device
    ):
        super().__init__()
        self.embedding = TransformerEmbedding(
            embedding_dim=embedding_dim,
            seq_len=seq_len,
            vocabulary_size=encoder_vocabulary_size,
            dropout_prob=dropout_prob,
            device=device
        ).to(device)
        self.layers = nn.ModuleList(
            [
                EncoderLayer(
                    embedding_dim=embedding_dim,
                    ffn_hidden=ffn_hidden,
                    n_head=n_head,
                    dropout_prob=dropout_prob,
                    device=device
                )
                for _ in range(n_layers)
            ]
        ).to(device)

    def forward(self, x, s_mask):
        """
        :param x: (batch_size, seq_len)
        :param s_mask: source mask, or None
        :return: (batch_size, seq_len, embedding_dim)
        """
        x = self.embedding(x)
        # x: (batch_size, seq_len, embedding_dim)
        for layer in self.layers:
            x = layer(x, s_mask)
        return x


if __name__ == '__main__':
    BATCH_SIZE = 128
    SEQ_LEN = 16
    VOCABULARY_SIZE = 2500
    EMBEDDING_DIM = 32
    N_HEAD = 8
    N_LAYERS = 6
    FFN_HIDDEN = 64
    DEVICE = torch.device('cuda')
    DROPOUT_P = 0.2
    E = Encoder(
        encoder_vocabulary_size=VOCABULARY_SIZE,
        seq_len=SEQ_LEN,
        embedding_dim=EMBEDDING_DIM,
        ffn_hidden=FFN_HIDDEN,
        n_head=N_HEAD,
        n_layers=N_LAYERS,
        dropout_prob=DROPOUT_P,
        device=DEVICE
    ).cuda()
    input_tensor = torch.zeros(
        (BATCH_SIZE, SEQ_LEN),
        device=DEVICE,
    ).int()
    a = E(input_tensor, None)
    print(a.shape)
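The test above passes s_mask=None. For real padded batches, a source padding mask can be built so attention ignores PAD positions. Below is a minimal sketch of one way to do this; the helper make_src_mask is not part of the original code, and it assumes PAD has token id 1 (matching padding_idx=1 in TokenEmbedding above):

import torch

PAD_IDX = 1  # assumed PAD token id, matching padding_idx=1 above


def make_src_mask(src, pad_idx=PAD_IDX):
    """
    src: (batch_size, seq_len) of token ids
    returns: (batch_size, 1, 1, seq_len), 1 where the token is real, 0 where it is PAD;
    this broadcasts against score of shape (batch_size, n_head, seq_len, seq_len)
    inside ScaleDotProductAttention's masked_fill.
    """
    return (src != pad_idx).unsqueeze(1).unsqueeze(2).int()


src = torch.tensor([[5, 8, 3, 1, 1],
                    [7, 2, 1, 1, 1]])
s_mask = make_src_mask(src)
print(s_mask.shape)  # torch.Size([2, 1, 1, 5])
# out = E(src.to(DEVICE), s_mask.to(DEVICE))  # would run against the Encoder E defined above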

The Decoder part will be added in a later update.
