当前位置:   article > 正文

cross attention输入不同维度的矩阵_cross attention代码

cross attention代码

一.问题背景

在学习使用cross attention的时候我查阅了很多资料,发现里面说的都是cross attention的输入需要是相同维度的矩阵,但是我所需要的是可以处理不同维度数据的cross attention。
cross attention

二.cross attention的代码

看了关于cross attention的一些介绍和代码,发现大多都是这样

import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(CrossAttention, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.query = nn.Linear(in_dim, out_dim, bias=False)
        self.key = nn.Linear(in_dim, out_dim, bias=False)
        self.value = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, x, y):
        batch_size = x.shape[0]
        num_queries = x.shape[1]
        num_keys = y.shape[1]
        x = self.query(x)
        y = self.key(y)
        # 计算注意力分数
        attn_scores = torch.matmul(x, y.transpose(-2, -1)) / (self.out_dim ** 0.5)
        attn_weights = F.softmax(attn_scores, dim=-1)
        # 计算加权和
        V = self.value(y)
        output = torch.bmm(attn_weights, V)
        
        return output
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27

这里的x和y所输入的维度需要一致,那么从代码上看好像不太好分析如何进行改变,我们先看看cross attention的公式:

Cross-Attention ( Q , K , V ) = softmax ( ( W Q S 2 ) ( W K S 1 ) T ) W V S 1 \text{Cross-Attention}(Q,K,V) = \text{softmax}\left((W_{Q}S2)(W_{K}S1)^T\right)W_{V}S1 Cross-Attention(Q,K,V)=softmax((WQS2)WKS1T)WVS1

其中, Q Q Q为查询向量, K K K为编码器的键向量, V V V为编码器的值向量, d k d_k dk为编码器键向量的维度。

所以,当输入的维度不同的时候,我们可以对从W入手进行维度的变换和适配。

那W从哪儿来呢?

注意看代码中,QKV均过了一个线性层,所以,我们将线性层的输出改为我们所需要的输出,就可以完成不同维度的输入了。

代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    def __init__(self, in_dim, out_dim,in_q_dim,hid_q_dim):
        super(CrossAttention, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.in_q_dim = in_q_dim #新增
        self.hid_q_dim = hid_q_dim #新增
        # 定义查询、键、值三个线性变换
        self.query = nn.Linear(in_q_dim, hid_q_dim, bias=False) #变化
        self.key = nn.Linear(in_dim, out_dim, bias=False)
        self.value = nn.Linear(in_dim, out_dim, bias=False)
        
    def forward(self, x, y):
        # 对输入进行维度变换,为了方便后面计算注意力分数
        batch_size = x.shape[0]   # batch size
        num_queries = x.shape[1]  # 查询矩阵中的元素个数
        num_keys = y.shape[1]     # 键值矩阵中的元素个数
        x = self.query(x)  # 查询矩阵
        y = self.key(y)    # 键值矩阵
        # 计算注意力分数
        attn_scores = torch.matmul(x, y.transpose(-2, -1)) / (self.out_dim ** 0.5)  # 计算注意力分数,注意力分数矩阵的大小为 batch_size x num_queries x num_keys x num_keys
        attn_weights = F.softmax(attn_scores, dim=-1)  # 对注意力分数进行 softmax 归一化
        # 计算加权和
        V = self.value(y)  # 通过值变换得到值矩阵 V
        output = torch.bmm(attn_weights, V)  # 计算加权和,output 的大小为 batch_size x num_queries x num_keys x out_dim
       
        return output

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

例如:输入的两个矩阵分别为x=[batch, 1024, 512] y=[batch, 1024, 1024],其中X作为被用作查询,那么下面是一个实例:

# 定义输入矩阵 x 和 y,其大小分别为 1024, 512 和 1024, 1024
x = torch.randn(1, 1024, 512)
y = torch.randn(1, 1024, 1024)

# 创建 CrossAttention 模型,并对输入进行前向传播
cross_attn = CrossAttention(in_dim=1024, out_dim=1024,in_q_dim=512,hid_q_dim=1024)
output = cross_attn(x=x, y=y)

# 输出新的矩阵大小
print(output.shape) # (1, 1024,1024,1024)
print(output)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/357302
推荐阅读
相关标签
  

闽ICP备14008679号