A detailed look at the AttentionBlock used in Diffusion's UNet

import torch
import torch.nn as nn

# get_norm is a helper defined elsewhere in the original repo; per the docstring it returns
# the chosen normalization layer (instance, group, batch, or none).

class AttentionBlock(nn.Module):
    __doc__ = r"""Applies QKV self-attention with a residual connection.

    Input:
        x: tensor of shape (N, in_channels, H, W)
    Output:
        tensor of shape (N, in_channels, H, W)
    Args:
        in_channels (int): number of input channels
        norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"
        num_groups (int): number of groups used in group normalization. Default: 32
    """
    def __init__(self, in_channels, norm="gn", num_groups=32):
        super().__init__()
        
        self.in_channels = in_channels
        self.norm = get_norm(norm, in_channels, num_groups)
        # Why isn't a single projection shared for Q, K and V? One 1x1 conv produces all three
        # at once by tripling the channel count; the output is split into Q, K, V in forward().
        self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)
        self.to_out = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        q, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)

        q = q.permute(0, 2, 3, 1).view(b, h * w, c)
        k = k.view(b, c, h * w)
        v = v.permute(0, 2, 3, 1).view(b, h * w, c)

        dot_products = torch.bmm(q, k) * (c ** (-0.5))
        assert dot_products.shape == (b, h * w, h * w)

        attention = torch.softmax(dot_products, dim=-1)
        out = torch.bmm(attention, v)
        assert out.shape == (b, h * w, c)
        out = out.view(b, h, w, c).permute(0, 3, 1, 2)

        return self.to_out(out) + x
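
The block depends on the get_norm helper mentioned above. A minimal usage sketch, under the assumption that get_norm("gn", C, 32) simply returns nn.GroupNorm(32, C) (the stub below is a hypothetical stand-in, not the repo's implementation):

# Hypothetical stand-in for the repo's get_norm helper, only so the block runs on its own.
def get_norm(norm, in_channels, num_groups):
    if norm == "gn":
        return nn.GroupNorm(num_groups, in_channels)
    return nn.Identity()

block = AttentionBlock(in_channels=64)
x = torch.randn(2, 64, 16, 16)   # (N, C, H, W)
out = block(x)
print(out.shape)                 # torch.Size([2, 64, 16, 16]) -- same shape as the input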

x: (batch, channel, h, w)
After the to_qkv operation, the shape becomes (batch, channel*3, h, w).
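
A quick shape check of this step (a channel count of 4 here is just an arbitrary example):

import torch
import torch.nn as nn

to_qkv = nn.Conv2d(4, 4 * 3, 1)     # in_channels=4, out_channels=channel*3=12, kernel_size=1
x = torch.randn(2, 4, 8, 8)         # (batch, channel, h, w)
print(to_qkv(x).shape)              # torch.Size([2, 12, 8, 8]) -> (batch, channel*3, h, w)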

torch.split

torch.split(tensor, split_size_or_sections, dim=0)
# Purpose: split a tensor into chunks
'''
split_size_or_sections: the size of each chunk along the split dimension
dim: the dimension along which to split
'''

e.g.:
q, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)
This splits the tensor of shape (batch, channel*3, h, w) along dim=1 into chunks of channel channels each,
so q, k and v all have shape (batch, channel, h, w).
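
A standalone version of the same split, with made-up sizes:

import torch

qkv = torch.randn(2, 12, 8, 8)           # pretend this is the to_qkv output, channel*3 = 12
q, k, v = torch.split(qkv, 4, dim=1)     # chunks of 4 channels along dim=1
print(q.shape, k.shape, v.shape)         # each is torch.Size([2, 4, 8, 8])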


Usage of permute in torch

Purpose: permute reorders (transposes) the dimensions of a tensor.

import torch
import torch.nn as nn

x = torch.randn(1, 2, 3, 4)
print(x.size())    # torch.Size([1, 2, 3, 4])   
print(x.permute(2, 1, 0, 3).size())# torch.Size([3, 2, 1, 4])   

torch.transpose()

torch.transpose swaps only two dimensions per call, so a permutation involving more than two dimensions needs several transpose() calls. For example, to turn the tensor of shape [1, 2, 3, 4] above into shape [3, 4, 1, 2], transpose has to be chained:

x.transpose(0,2).transpose(1,3)
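
A quick check that the chained transposes match a single permute, in both shape and values:

import torch

x = torch.randn(1, 2, 3, 4)
a = x.permute(2, 3, 0, 1)                 # reorder all four dimensions at once
b = x.transpose(0, 2).transpose(1, 3)     # the same reordering done in two swaps
print(a.shape, b.shape)                   # torch.Size([3, 4, 1, 2]) for both
print(torch.equal(a, b))                  # True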

view()

view() requires the tensor's memory layout to be compatible with the requested shape. If the operand is not stored contiguously, call contiguous() first to lay the tensor out contiguously in memory. view works much like reshape: it gives the tensor a new shape without copying data (a failing-vs-working comparison is sketched after the output below).

import torch
import torch.nn as nn
import numpy as np

y = np.array([[[1, 2, 3], [4, 5, 6]]]) # 1X2X3
y_tensor = torch.tensor(y)
y_tensor_trans = y_tensor.permute(2, 0, 1) # 3X1X2
print(y_tensor.size())
print(y_tensor_trans.size())

print(y_tensor)
print(y_tensor_trans)
print(y_tensor.view(1, 3, 2)) 


torch.Size([1, 2, 3])
torch.Size([3, 1, 2])
tensor([[[1, 2, 3],
         [4, 5, 6]]])
tensor([[[1, 4]],

        [[2, 5]],

        [[3, 6]]])
tensor([[[1, 2],
         [3, 4],
         [5, 6]]])
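
As noted above, view() fails when the memory layout is incompatible with the requested shape; a minimal sketch of the contiguous() workaround:

import torch

x = torch.arange(6).reshape(2, 3)
y = x.t()                         # a transposed view, no longer contiguous
# y.view(6)                       # RuntimeError: view size is not compatible with input tensor's size and stride
print(y.contiguous().view(6))     # tensor([0, 3, 1, 4, 2, 5])
print(y.reshape(6))               # reshape copies when necessary, so it also works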



torch.bmm

Purpose:
torch.bmm(a, b) computes batched matrix multiplication. Tensor a has size (B, h, w) and tensor b has size (B, w, m): the batch dimensions (the first dimension of each) must be equal, and the third dimension of a must match the second dimension of b; the remaining dimensions are unconstrained. The output has size (B, h, m).

torch.bmm requires both inputs to be 3-D; 2-D or 4-D tensors are not accepted.

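
This is exactly how the attention scores are formed in the forward pass above: q of shape (b, h*w, c) times k of shape (b, c, h*w) gives (b, h*w, h*w). A small sketch with toy sizes:

import torch

a = torch.randn(2, 5, 3)          # (B, h, w)
b = torch.randn(2, 3, 4)          # (B, w, m)
print(torch.bmm(a, b).shape)      # torch.Size([2, 5, 4]) -> (B, h, m)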

softmax(x, dim=-1)

import torch

a = torch.randn(2, 3)
print(a)

tensor([[-8.2976e-01,  5.8105e-04,  1.2218e+00],
        [ 1.9745e-01,  1.2727e+00,  5.9587e-01]])

b = torch.softmax(a, dim=-1)
print(b)

tensor([[0.0903, 0.2072, 0.7025],
        [0.1845, 0.5407, 0.2748]])

With dim=-1 the softmax is applied over the last dimension, so every row of b sums to 1. In the attention block this normalizes each query position's weights over all h*w key positions.

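
Putting bmm and softmax together, the core of the forward pass above can be sketched with random data (toy sizes, no learned weights):

import torch

b, c, h, w = 2, 8, 4, 4
q = torch.randn(b, h * w, c)
k = torch.randn(b, c, h * w)
v = torch.randn(b, h * w, c)

scores = torch.bmm(q, k) * (c ** (-0.5))      # (b, h*w, h*w), scaled dot products
attention = torch.softmax(scores, dim=-1)     # each row sums to 1
out = torch.bmm(attention, v)                 # weighted sum of values, (b, h*w, c)
print(out.shape)                              # torch.Size([2, 16, 8])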
