当前位置:   article > 正文

python自注意力模块_利用多头自注意力的模型,怎么打印训练中的attention权重

利用多头自注意力的模型,怎么打印训练中的attention权重
  1. import torch
  2. import torch.nn as nn
  3. class AttentionModule(nn.Module):
  4. def __init__(self, input_channels):
  5. super(AttentionModule, self).__init__()
  6. self.conv = nn.Conv2d(input_channels, 1, kernel_size=1) # 用于学习注意力权重的卷积层
  7. self.sigmoid = nn.Sigmoid() # Sigmoid函数用于将注意力权重限制在0到1之间
  8. def forward(self, x):
  9. # 将输入x的尺寸从[batch_size, input_channels, height, width]变为[batch_size, input_channels, height, width]
  10. # x = x.unsqueeze(2) # 在第3维上增加一个维度
  11. attention_weights = self.conv(x) # 应用卷积层获取注意力权重,尺寸为[batch_size, 1, height, width]
  12. attention_weights = self.sigmoid(attention_weights) # 使用Sigmoid函数将注意力权重限制在0到1之间
  13. attended_features = x * attention_weights # 将注意力权重应用到输入特征上
  14. return attended_features # 去除增加的维度
  15. # 创建一个尺寸为[16, 32, 1, 5]的随机输入张量
  16. input_tensor = torch.randn(16, 32, 8, 3)
  17. # 创建注意力模块实例
  18. attention_module = AttentionModule(input_channels=32)
  19. # 应用注意力模块到输入张量上
  20. output_tensor = attention_module(input_tensor)
  21. # 打印输出张量的尺寸
  22. print(output_tensor.size())

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/379853
推荐阅读
相关标签
  

闽ICP备14008679号