赞
踩
- import torch
- import torch.nn as nn
-
- class AttentionModule(nn.Module):
- def __init__(self, input_channels):
- super(AttentionModule, self).__init__()
- self.conv = nn.Conv2d(input_channels, 1, kernel_size=1) # 用于学习注意力权重的卷积层
- self.sigmoid = nn.Sigmoid() # Sigmoid函数用于将注意力权重限制在0到1之间
-
- def forward(self, x):
- # 将输入x的尺寸从[batch_size, input_channels, height, width]变为[batch_size, input_channels, height, width]
- # x = x.unsqueeze(2) # 在第3维上增加一个维度
- attention_weights = self.conv(x) # 应用卷积层获取注意力权重,尺寸为[batch_size, 1, height, width]
- attention_weights = self.sigmoid(attention_weights) # 使用Sigmoid函数将注意力权重限制在0到1之间
- attended_features = x * attention_weights # 将注意力权重应用到输入特征上
- return attended_features # 去除增加的维度
-
- # 创建一个尺寸为[16, 32, 1, 5]的随机输入张量
- input_tensor = torch.randn(16, 32, 8, 3)
-
- # 创建注意力模块实例
- attention_module = AttentionModule(input_channels=32)
-
- # 应用注意力模块到输入张量上
- output_tensor = attention_module(input_tensor)
-
- # 打印输出张量的尺寸
- print(output_tensor.size())
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。