Multi-head self-attention is finally implemented. Below is code you can run directly.
import tensorflow as tf


# Multi-head attention layer
class MutilHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MutilHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        # d_model must divide evenly across the heads
        assert d_model % num_heads == 0
        # dimension of each head after the split
        self.depth = d_model // num_heads

        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        # split into heads and move the num_heads dimension in front of seq_len
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        batch_size = tf.shape(q)[0]

        # linear projections before the split, producing the q, k, v representations
        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)
        v = self.wv(v)

        # split into heads
        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = scaled_dot_product_attention(
            q, k, v, mask)

        # move the heads dimension back behind seq_len
        scaled_attention = tf.transpose(scaled_attention, [0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

        # concatenate the heads
        concat_attention = tf.reshape(scaled_attention,
                                      (batch_size, -1, self.d_model))

        # final dense layer
        output = self.dense(concat_attention)
        return output, attention_weights
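As a quick aside (my own addition, not part of the tutorial code), the snippet below, run after the class definition above, shows what split_heads does to the shapes: a (batch, seq_len, d_model) tensor is split into num_heads slices of size depth.

# Quick shape check for split_heads (my own addition, not from the tutorial):
# a (batch, seq_len, d_model) tensor becomes (batch, num_heads, seq_len, depth).
mha = MutilHeadAttention(d_model=300, num_heads=6)
x = tf.random.uniform((2, 10, 300))
print(mha.split_heads(x, batch_size=2).shape)  # (2, 6, 10, 50)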

Next is the function used inside the class above; it computes scaled dot-product attention, softmax(QKᵀ / sqrt(dk)) · V.
def scaled_dot_product_attention(q, k, v, mask):
    # multiply query and key to get the matching scores
    matmul_qk = tf.matmul(q, k, transpose_b=True)

    # scale by sqrt(dk)
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    # apply the mask
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)

    # softmax to get the attention weights
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)

    # multiply the attention weights by value
    output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth)

    return output, attention_weights
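Here is a small sanity check I added (again, not from the tutorial): with random inputs, the output keeps the query's sequence length and the attention weights sum to 1 along the last axis.

# Sanity check (my own addition): shapes and row-stochastic attention weights.
q = tf.random.uniform((1, 3, 4))   # (batch, seq_len_q, depth)
k = tf.random.uniform((1, 3, 4))   # (batch, seq_len_k, depth)
v = tf.random.uniform((1, 3, 4))   # (batch, seq_len_v, depth)

out, w = scaled_dot_product_attention(q, k, v, mask=None)
print(out.shape, w.shape)                 # (1, 3, 4) (1, 3, 3)
print(tf.reduce_sum(w, axis=-1).numpy())  # each entry is approximately 1.0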

Finally, here is a small calling function (this part I wrote myself; the class and function above come from the official TensorFlow website, so adapt them to your own needs):
def attention(y):
    temp_mha = MutilHeadAttention(d_model=300, num_heads=6)
    output, att = temp_mha(y, k=y, q=y, mask=None)
    print(output.shape, att.shape)  # (20, 80, 300) (20, 6, 80, 80)


y = tf.random.uniform((20, 80, 300))
attention(y)
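The example above passes mask=None. If the input contains padding, a padding mask can be built the way the official TensorFlow tutorial does and passed in its place; the sketch below assumes token ID 0 marks padding, which is my assumption rather than something from the original post.

# Sketch: padding mask for integer token sequences where 0 is the pad ID
# (the pad-ID choice is an assumption; adapt it to your own vocabulary).
def create_padding_mask(seq):
    # 1.0 where the token is padding, 0.0 elsewhere
    mask = tf.cast(tf.math.equal(seq, 0), tf.float32)
    # add dimensions so the mask broadcasts over (batch, num_heads, seq_len_q, seq_len_k)
    return mask[:, tf.newaxis, tf.newaxis, :]

tokens = tf.constant([[7, 6, 0, 0], [1, 2, 3, 0]])
mask = create_padding_mask(tokens)  # (2, 1, 1, 4)

This mask would then be passed instead of mask=None; positions where the mask is 1 get -1e9 added to their logits, so softmax assigns them near-zero attention weight.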