
Multi-head self-attention, the final form of attention: code implementation (copy, paste, and run directly)

How to use tf.keras.layers.MultiHeadAttention

Multi-head self-attention is finally implemented; the code below can be run directly.

import tensorflow as tf

# Build the multi-head attention layer
class MutilHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MutilHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        # d_model must divide evenly across the heads
        assert d_model % num_heads == 0
        # dimension of each head after the split
        self.depth = d_model // num_heads

        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        # split into heads and move the head dimension in front of seq_len
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        batch_size = tf.shape(q)[0]

        # linear projections before the split: obtain the q, k, v representations
        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)
        v = self.wv(v)

        # split into heads
        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        # run the scaled dot-product attention
        scaled_attention, attention_weights = scaled_dot_product_attention(
            q, k, v, mask)

        # move the head dimension back behind seq_len
        scaled_attention = tf.transpose(scaled_attention, [0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)
        # concatenate the heads
        concat_attention = tf.reshape(scaled_attention,
                                      (batch_size, -1, self.d_model))
        # final fully connected projection
        output = self.dense(concat_attention)
        return output, attention_weights
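To make the head split concrete, here is a minimal standalone shape check (my own addition, not part of the original post) that mirrors the reshape/transpose inside split_heads with d_model=300 and num_heads=6, so depth = 300 // 6 = 50:

import tensorflow as tf

# standalone check of the split_heads logic (illustrative values only)
batch_size, seq_len, d_model, num_heads = 2, 10, 300, 6
depth = d_model // num_heads  # 50

x = tf.random.uniform((batch_size, seq_len, d_model))   # (2, 10, 300)
x = tf.reshape(x, (batch_size, -1, num_heads, depth))   # (2, 10, 6, 50)
x = tf.transpose(x, perm=[0, 2, 1, 3])                  # (2, 6, 10, 50)
print(x.shape)  # each head now attends over the full sequence with depth 50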

Next comes a function used inside this class:

def scaled_dot_product_attention(q, k, v, mask):
    # multiply query and key to get the matching scores
    matmul_qk = tf.matmul(q, k, transpose_b=True)

    # scale by dk
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    # apply the mask
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)

    # softmax to obtain the attention weights
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)

    # multiply the attention weights by value
    output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth)
    return output, attention_weights
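This is the standard formula softmax(Q·K^T / sqrt(d_k))·V. As a quick sanity check (my addition, assuming the function above has been defined), tiny random tensors give the expected shapes, and each row of the attention weights sums to 1:

import tensorflow as tf

# sanity check with tiny tensors (assumes scaled_dot_product_attention from above)
q = tf.random.uniform((1, 3, 4))   # (batch, seq_len_q, depth)
k = tf.random.uniform((1, 5, 4))   # (batch, seq_len_k, depth)
v = tf.random.uniform((1, 5, 4))   # (batch, seq_len_v, depth), seq_len_v == seq_len_k
out, weights = scaled_dot_product_attention(q, k, v, mask=None)
print(out.shape)      # (1, 3, 4)
print(weights.shape)  # (1, 3, 5)
print(tf.reduce_sum(weights, axis=-1))  # all ones after the softmax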

Finally, a small driver function (this part is my own; the class and the function above come from the official TensorFlow tutorial, so adapt them to your own needs):

def attention(y):
    temp_mha = MutilHeadAttention(d_model=300, num_heads=6)
    # self-attention: q, k and v are all the same tensor y
    output, att = temp_mha(y, k=y, q=y, mask=None)
    print(output.shape, att.shape)  # (20, 80, 300) (20, 6, 80, 80)

y = tf.random.uniform((20, 80, 300))
attention(y)
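The title also mentions tf.keras.layers.MultiHeadAttention: TensorFlow 2.4 and later ship a built-in layer that does essentially what the hand-written class above does. A minimal self-attention sketch, assuming TF >= 2.4 (the num_heads and key_dim values here just mirror the example above):

import tensorflow as tf

# built-in alternative to the MutilHeadAttention class above (requires TF >= 2.4)
mha = tf.keras.layers.MultiHeadAttention(num_heads=6, key_dim=50)
y = tf.random.uniform((20, 80, 300))
# self-attention: query and value (and key) are all y
output, att = mha(y, y, return_attention_scores=True)
print(output.shape, att.shape)  # (20, 80, 300) (20, 6, 80, 80)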

 
