class AttentionBlock(nn.Module):
    __doc__ = r"""Applies QKV self-attention with a residual connection.

    Input:
        x: tensor of shape (N, in_channels, H, W)
        norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"
        num_groups (int): number of groups used in group normalization. Default: 32
    Output:
        tensor of shape (N, in_channels, H, W)
    Args:
        in_channels (int): number of input channels
    """

    def __init__(self, in_channels, norm="gn", num_groups=32):
        super().__init__()

        self.in_channels = in_channels
        self.norm = get_norm(norm, in_channels, num_groups)
        # Why aren't Q, K and V computed by identical layers here??? A single 1x1 conv
        # triples the channel count instead, and the result is split into Q, K, V in forward.
        self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)
        self.to_out = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        q, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)

        q = q.permute(0, 2, 3, 1).view(b, h * w, c)
        k = k.view(b, c, h * w)
        v = v.permute(0, 2, 3, 1).view(b, h * w, c)

        dot_products = torch.bmm(q, k) * (c ** (-0.5))
        assert dot_products.shape == (b, h * w, h * w)

        attention = torch.softmax(dot_products, dim=-1)
        out = torch.bmm(attention, v)
        assert out.shape == (b, h * w, c)

        out = out.view(b, h, w, c).permute(0, 3, 1, 2)
        return self.to_out(out) + x
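A minimal usage sketch of the block above. Note that get_norm is defined elsewhere in the repo; the stand-in below, which simply maps "gn" to nn.GroupNorm, is an assumption made only so the snippet runs:

import torch
import torch.nn as nn

def get_norm(norm, channels, num_groups):
    # Hypothetical stand-in for the repo's get_norm helper: "gn" -> GroupNorm, anything else -> Identity.
    if norm == "gn":
        return nn.GroupNorm(num_groups, channels)
    return nn.Identity()

block = AttentionBlock(in_channels=64)
x = torch.randn(2, 64, 16, 16)      # (N, in_channels, H, W)
print(block(x).shape)               # torch.Size([2, 64, 16, 16]) -- same shape, thanks to the residual connection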
x: (batch, channel, h, w)
After the to_qkv projection this becomes (batch, channel*3, h, w).
torch.split(tensor, split_size_or_sections, dim=0)
# Purpose: splits a tensor into chunks
'''
split_size_or_sections: the size of each chunk (how many entries per group)
dim: the dimension along which to split
'''
e.g.:
q, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)
That is, the tensor of shape (batch, channel*3, h, w) is split along dim=1 into groups of channel entries each,
so q, k and v all have shape (batch, channel, h, w).
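A quick shape check with made-up sizes (batch=4, channel=64, h=w=16):

import torch

qkv = torch.randn(4, 3 * 64, 16, 16)     # stands in for the output of to_qkv: (batch, channel*3, h, w)
q, k, v = torch.split(qkv, 64, dim=1)     # groups of 64 channels along dim=1
print(q.shape, k.shape, v.shape)          # three times torch.Size([4, 64, 16, 16])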
Purpose: permute transposes a tensor, i.e. reorders all of its dimensions in a single call.
import torch
import torch.nn as nn
x = torch.randn(1, 2, 3, 4)
print(x.size()) # torch.Size([1, 2, 3, 4])
print(x.permute(2, 1, 0, 3).size())# torch.Size([3, 2, 1, 4])
This matters because torch.transpose can only swap two dimensions per call; permuting several dimensions requires chaining multiple transpose() calls. For example, rearranging the (1, 2, 3, 4) tensor above into shape (3, 4, 1, 2) with transpose takes two calls (a single permute does the same, as the check below shows):
x.transpose(0, 2).transpose(1, 3)
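A small check that the chained transpose calls and a single permute agree in both shape and values:

import torch

x = torch.randn(1, 2, 3, 4)
a = x.transpose(0, 2).transpose(1, 3)   # (1,2,3,4) -> (3,2,1,4) -> (3,4,1,2)
b = x.permute(2, 3, 0, 1)               # same result in one call
print(a.shape, b.shape)                 # torch.Size([3, 4, 1, 2]) twice
print(torch.equal(a, b))                # True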
view() requires the tensor it operates on to be contiguous in memory. If the operand is not stored contiguously, you must call contiguous() first to lay it out contiguously; view works much like reshape and reshapes a tensor to a new shape.
import torch
import torch.nn as nn
import numpy as np
y = np.array([[[1, 2, 3], [4, 5, 6]]]) # 1X2X3
y_tensor = torch.tensor(y)
y_tensor_trans = y_tensor.permute(2, 0, 1) # 3X1X2
print(y_tensor.size())
print(y_tensor_trans.size())
print(y_tensor)
print(y_tensor_trans)
print(y_tensor.view(1, 3, 2))
torch.Size([1, 2, 3])
torch.Size([3, 1, 2])
tensor([[[1, 2, 3],
         [4, 5, 6]]])
tensor([[[1, 4]],

        [[2, 5]],

        [[3, 6]]])
tensor([[[1, 2],
         [3, 4],
         [5, 6]]])
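A quick demonstration of the contiguity requirement mentioned above: view() fails on a non-contiguous tensor (here produced by a transpose), while calling contiguous() first, or using reshape(), works:

import torch

x = torch.arange(6).view(2, 3)
y = x.t()                           # shape (3, 2), no longer contiguous
print(y.is_contiguous())            # False
# y.view(6)                         # would raise RuntimeError: view size is not compatible ...
print(y.contiguous().view(6))       # tensor([0, 3, 1, 4, 2, 5])
print(y.reshape(6))                 # reshape() copies when necessary, same result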
torch.bmm — purpose:
Computes a batched matrix multiplication of two tensors. In torch.bmm(a, b), tensor a has size (b, h, w) and tensor b has size (b, w, m): the first (batch) dimension of the two tensors must be equal, and the third dimension of a must match the second dimension of b; the remaining dimensions are unconstrained. The output has size (b, h, m).
torch.bmm requires a and b to be 3-dimensional; 2D or 4D tensors are not accepted.
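A shape check with made-up sizes (b=4, h=5, w=6, m=7):

import torch

a = torch.randn(4, 5, 6)    # (b, h, w)
b = torch.randn(4, 6, 7)    # (b, w, m): batch dims equal, a's third dim matches b's second dim
c = torch.bmm(a, b)
print(c.shape)              # torch.Size([4, 5, 7]), i.e. (b, h, m)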
torch.softmax normalizes values along the given dimension; this is what attention = torch.softmax(dot_products, dim=-1) uses to turn the scaled dot products into attention weights:
import torch
a = torch.randn(2,3)
print(a)
tensor([[-8.2976e-01,  5.8105e-04,  1.2218e+00],
        [ 1.9745e-01,  1.2727e+00,  5.9587e-01]])
b = torch.softmax(a, dim=-1)
print(b)
tensor([[0.0903, 0.2072, 0.7025],
        [0.1845, 0.5407, 0.2748]])
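Because the normalization runs along dim=-1, each row of the output sums to 1, which is why softmax is suitable for producing attention weights:

print(b.sum(dim=-1))    # tensor([1.0000, 1.0000]) -- every row sums to 1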