当前位置:   article > 正文

PyTorch应用实战一:实现卷积操作_卷积操作实现

卷积操作实现

实验环境

python3.6 + pytorch1.8.0

import torch
print(torch.__version__)
  • 1
  • 2
1.8.0
  • 1

0.卷积定义

卷积操作是指两个函数f和g之间的一种数学运算,它在信号处理、图像处理、机器学习等领域中广泛应用。在离散情况下,卷积操作可以表示为:

( f ∗ g ) [ n ] = ∑ m = − ∞ ∞ f [ m ] g [ n − m ] (f * g)[n] = \sum_{m=-\infty}^{\infty}f[m]g[n-m] (fg)[n]=m=f[m]g[nm]

其中, f f f g g g是离散函数, ∗ * 表示卷积操作, n n n是离散的变量。卷积操作可以看作是将函数 g g g沿着 n n n轴翻转,然后平移,每次和函数 f f f相乘并求和,最后得到一个新的函数。这种操作可以实现信号的滤波、特征提取等功能,是数字信号处理中非常重要的基础操作。

1.利用张量操作实现卷积

1.1 unfold函数

PyTorchunfold函数用于对张量进行展开操作。torch.unfold()可以理解为将一个高维的张量展开成一个二维矩阵的操作。即将原来的张量沿着指定的维度展开成一个二维矩阵,其中第一维对应原来张量的维度,第二维对应展开的位置。

函数原型如下:

torch.unfold(input, dimension, size, step)
  • 1

参数说明:

  • input (Tensor) – 要展开的张量
  • dimension (int) – 沿着哪个维度展开
  • size (int) – 展开窗口的大小
  • step (int) – 两个相邻窗口之间的步长

1.2 张量分片

import torch
  • 1
a = torch.arange(16).view(4, 4)
a
  • 1
  • 2
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
  • 1
  • 2
  • 3
  • 4
b = a.unfold(0, 3, 1)
b
  • 1
  • 2
tensor([[[ 0,  4,  8],
         [ 1,  5,  9],
         [ 2,  6, 10],
         [ 3,  7, 11]],

        [[ 4,  8, 12],
         [ 5,  9, 13],
         [ 6, 10, 14],
         [ 7, 11, 15]]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
b.shape
  • 1
torch.Size([2, 4, 3])
  • 1
c = b.unfold(1, 3, 1)
c
  • 1
  • 2
tensor([[[[ 0,  1,  2],
          [ 4,  5,  6],
          [ 8,  9, 10]],

         [[ 1,  2,  3],
          [ 5,  6,  7],
          [ 9, 10, 11]]],


        [[[ 4,  5,  6],
          [ 8,  9, 10],
          [12, 13, 14]],

         [[ 5,  6,  7],
          [ 9, 10, 11],
          [13, 14, 15]]]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
c.shape
  • 1
torch.Size([2, 2, 3, 3])
  • 1

完整程序

import torch
a = torch.arange(16).view(4, 4)
b = a.unfold(0, 3, 1)
c = b.unfold(1, 3, 1)
c.shape
  • 1
  • 2
  • 3
  • 4
  • 5
torch.Size([2, 2, 3, 3])
  • 1

这段代码定义了三个变量。假设我们将其分别命名为abc,则:

  • 变量a是一个4x4的张量,其中包含了0到15的整数值,它通过torch.arange(16).view(4, 4)两个函数调用来实现。
  • 变量b是通过对变量a进行折叠操作得到的一个张量,具体来说,它是将变量a沿着第0维(即行)展开,并取窗口大小为3,步长为1的子张量所得到的结果。因此,如果我们将张量b打印出来,会得到:
tensor([[[ 0,  1,  2],
         [ 4,  5,  6],
         [ 8,  9, 10],
         [12, 13, 14]],

        [[ 1,  2,  3],
         [ 5,  6,  7],
         [ 9, 10, 11],
         [13, 14, 15]]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

其中,第一个子张量的值为[[0, 1, 2], [4, 5, 6], [8, 9, 10], [12, 13, 14]],第二个子张量的值为[[1, 2, 3], [5, 6, 7], [9, 10, 11], [13, 14, 15]]。注意,这个张量的形状为(2, 4, 3),即它包含2个子张量,每个子张量的形状为(4, 3)

  • 变量c是对变量b进行类似的操作得到的,但是它是在第1维(即列)上展开并取子张量。具体来说,它是将变量b沿着第1维(即列)展开,并取窗口大小为3,步长为1的子张量所得到的结果。因此,如果我们将张量c打印出来,会得到:
tensor([[[[ 0,  1,  2],
          [ 4,  5,  6],
          [ 8,  9, 10]],

         [[ 1,  2,  3],
          [ 5,  6,  7],
          [ 9, 10, 11]]],


        [[[ 4,  5,  6],
          [ 8,  9, 10],
          [12, 13, 14]],

         [[ 5,  6,  7],
          [ 9, 10, 11],
          [13, 14, 15]]]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

其中,第一个子张量的值为[[[0, 1, 2], [4, 5, 6], [8, 9, 10]], [[1, 2, 3], [5, 6, 7], [9, 10, 11]]],第二个子张量的值为[[[4, 5, 6], [8, 9, 10], [12, 13, 14]], [[5, 6, 7], [9, 10, 11], [13, 14, 15]]]。注意,这个张量的形状为(2, 2, 3, 3),即它包含2个子张量,每个子张量的形状为(2, 3, 3)

2.实现卷积操作

2.1 编写卷积函数

完整程序

import torch
def conv2d(x, weight, bias, stride, pad):
    n, c, h, w = x.shape
    d, c, k, j = weight.shape
    
    x_pad = torch.zeros(n, c, h+2*pad, w+2*pad).to(x.device)
    x_pad[:, :, pad:-pad, pad:-pad] = x
    
    x_pad = x_pad.unfold(2, k, stride)
    x_pad = x_pad.unfold(3, j, stride)
    
    out = torch.einsum('nchwkj,dckj->ndhw', x_pad, weight)
    out = out + bias.view(1, -1, 1, 1)
    return out
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

该函数实现了二维卷积操作。下面对函数进行详细分析:

  1. 输入参数:
  • x: 输入张量,维度为(batch_size,in_channels,input_height,input_width)。
  • weight: 卷积核张量,维度为(out_channels,in_channels,kernel_height,kernel_width)。
  • bias: 偏置项张量,维度为(out_channels,)。
  • stride: 卷积核移动的步长,可以是一个数或是一个长度为 2 的元组,分别表示水平方向和竖直方向的步长。
  • pad: 输入张量周围要填充的零的数量。
  1. 局部填充:
  • 在进行卷积操作之前,需要在输入张量的周围按照给定的 pad 进行填充,以避免卷积核在张量边缘处超出范围的情况发生。
  • 在函数中使用 x_pad 表示经过填充后的输入张量。
  • 具体实现:将输入张量 x 在第 2 和第 3 个维度(height 和 width 维度)上分别拆分成若干个形状为(kernel_height,kernel_width)的张量,每个张量之间的跳跃长度由 stride 决定,然后在第 2 和第 3 个维度上分别进行展开。这样每个展开后的张量就可以看作一个二维卷积核作用在 x 上的局部卷积结果,这些局部结果被按照第 2 和第 3 个维度重新拼接起来,得到新的张量 x_pad。
  1. 卷积操作:
  • 在新的张量 x_pad 上使用 einsum 函数对卷积核进行卷积操作。

  • einsum 的第一个参数表示操作的规则,其中 ndhw 表示最终输出的张量的维度为(batch_size,out_channels,output_height,output_width),nchw 和 dckj 表示两个输入张量 x_pad 和 weight 的维度,其中 c k j 分别表示 input_channels、kernel_height 和 kernel_width。

  • 最终得到的输出张量形状为(batch_size,out_channels,output_height,output_width),并在每个位置上加上偏置项 bias。

  • torch: PyTorch库

  • einsum: Einstein summation notation,爱因斯坦求和约定,一种张量求和的简便表示法。

  • 'nchwkj,dckj->ndhw': 爱因斯坦求和符号,左侧的张量为x_pad,右侧的张量为weight。在左侧张量中,n, c, h, w分别表示batch size、通道数、高度、宽度。在右侧张量中,d, c, k, j分别表示输出通道数、输入通道数、卷积核高度、卷积核宽度。这个式子的意义是将x_padweight执行卷积操作,并输出结果张量,其形状为(batch_size, output_channels, height, width)

  • x_pad: 输入的张量,形状为 (batch_size, input_channels, input_height, input_width)

  • weight: 卷积核张量,形状为 (output_channels, input_channels, kernel_height, kernel_width)

  1. 返回结果:
  • 返回卷积操作后得到的输出张量。

2.2 对编写的卷积函数举例分析

# 设置测试数据
x = torch.randn(2, 3, 5, 5, requires_grad=True)
weight = torch.randn(4, 3, 3, 3, requires_grad=True)
bias = torch.randn(4, requires_grad = True)
stride = 2
pad = 2
x, weight, bias
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
(tensor([[[[-0.4888,  1.0257,  0.0312, -0.9026, -0.9060],
           [ 0.2071, -0.4962, -0.1658,  1.0919,  0.3785],
           [-0.4654,  1.5442,  0.6005,  0.3594, -2.6207],
           [ 0.5830,  0.0533,  0.5719,  1.5413,  0.5949],
           [-0.9152, -0.2114, -0.4888, -0.0065, -0.9767]],
 
          [[ 0.4706, -0.1108, -0.1563, -1.7946, -0.8533],
           [-0.2119,  0.3165, -2.2668, -0.8956,  1.0617],
           [-0.7809, -0.2120, -0.8592, -0.5057,  0.7954],
           [-2.8820, -0.6888,  0.4450, -0.3586, -0.9477],
           [ 0.6244,  0.4303,  1.4739,  0.2740,  1.6605]],
 
          [[-0.1501,  0.6234, -1.6086,  0.1693,  0.4932],
           [ 1.0611, -1.0938,  0.1695,  1.0193,  0.4263],
           [ 1.4681, -0.1552, -0.0667, -0.7293,  1.0816],
           [ 0.8972,  1.1683, -1.4757,  0.4421, -0.0355],
           [-2.1331,  1.4847,  0.1378, -1.6907, -0.1350]]],
 
 
         [[[-1.3853,  1.6396,  0.3436,  0.3841,  0.2355],
           [-0.2206, -0.5087, -1.6956,  1.3205,  0.7058],
           [ 0.0993,  0.3533, -0.2086,  0.2969,  0.2627],
           [ 0.3752,  0.0304,  1.2487,  1.3963, -0.0063],
           [-1.3758,  0.5088, -1.3849,  1.3050,  0.4150]],
 
          [[ 0.2824, -2.8634, -0.1016, -0.1627,  1.7081],
           [ 0.1406,  0.2220, -0.6005,  0.2997, -0.1846],
           [ 1.6700,  0.5787,  0.6561, -0.0236,  1.7743],
           [ 2.1429, -0.2838, -0.0527,  0.3504, -0.3444],
           [-0.9409, -0.4734, -0.4060, -0.5088, -1.8518]],
 
          [[-2.2152,  0.2104, -0.3302,  0.2036, -0.9443],
           [-0.6576, -0.4455,  0.5117, -2.0058, -1.3985],
           [-0.5688,  1.2338, -0.1832,  0.1760,  0.4506],
           [-0.6563,  0.4021, -1.6210,  0.5582, -0.9238],
           [-1.0506, -0.9638,  0.7453, -0.3535, -0.3536]]]], requires_grad=True),
 tensor([[[[ 0.3069,  0.2079, -0.2952],
           [ 1.7681,  1.1056, -1.0555],
           [ 1.5845,  0.8294,  0.6588]],
 
          [[ 0.2574,  0.5007,  0.2912],
           [-0.0210,  0.6593, -0.9691],
           [-0.2918,  0.5695, -1.1242]],
 
          [[ 0.7327, -0.3453,  0.7041],
           [-0.2236, -1.7762,  0.0190],
           [-1.0927, -2.9369,  0.1768]]],
 
 
         [[[-2.3830, -1.4807,  1.8573],
           [ 1.0097, -0.9640,  1.0361],
           [-0.5222, -1.0386, -0.4016]],
 
          [[ 0.5071,  1.1433, -0.1194],
           [-0.0133, -0.3878, -0.1853],
           [ 0.3456, -0.6502,  0.2221]],
 
          [[-1.7672, -0.0469, -0.5996],
           [-0.2080, -1.6209,  0.4120],
           [ 0.8404, -1.6748, -0.7170]]],
 
 
         [[[ 0.2850,  0.1691, -0.9228],
           [ 0.7234,  0.5582, -0.4327],
           [ 0.6563,  0.2941,  1.5549]],
 
          [[ 0.2642, -1.9061,  1.6212],
           [-0.5276, -0.5608,  0.3824],
           [ 0.4452, -2.5152,  0.4490]],
 
          [[-0.1276,  0.7784,  0.7998],
           [-0.3030, -0.9776,  0.9681],
           [ 1.0225,  0.8946, -0.8084]]],
 
 
         [[[-0.5087, -0.8345, -1.4763],
           [-0.4938,  1.1979, -0.1335],
           [ 0.5010,  0.2865,  0.0728]],
 
          [[-0.3177, -0.6937, -1.0327],
           [ 0.8147, -1.7101, -1.8257],
           [-0.1593, -1.3855, -0.0885]],
 
          [[-0.4687, -1.6307,  1.5791],
           [-1.3030,  0.2004, -0.7055],
           [ 0.0674, -0.8772,  0.1586]]]], requires_grad=True),
 tensor([ 1.5349, -0.5608,  0.5182,  0.3328], requires_grad=True))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
n, c, h, w = x.shape
d, c, k, j = weight.shape
  • 1
  • 2
n, c, h, w
  • 1
(2, 3, 5, 5)
  • 1
d, c, k, j
  • 1
(4, 3, 3, 3)
  • 1
# 补零
x_pad = torch.zeros(n, c, h+2*pad, w+2*pad).to(x.device)
x_pad
  • 1
  • 2
  • 3
tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]],

         [[0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0.]]]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
x_pad.shape
  • 1
torch.Size([2, 3, 9, 9])
  • 1
x_pad[:, :, pad:-pad, pad:-pad] = x
x_pad
  • 1
  • 2
tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.4888,  1.0257,  0.0312, -0.9026, -0.9060,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.2071, -0.4962, -0.1658,  1.0919,  0.3785,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.4654,  1.5442,  0.6005,  0.3594, -2.6207,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.5830,  0.0533,  0.5719,  1.5413,  0.5949,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.9152, -0.2114, -0.4888, -0.0065, -0.9767,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.4706, -0.1108, -0.1563, -1.7946, -0.8533,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.2119,  0.3165, -2.2668, -0.8956,  1.0617,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.7809, -0.2120, -0.8592, -0.5057,  0.7954,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -2.8820, -0.6888,  0.4450, -0.3586, -0.9477,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.6244,  0.4303,  1.4739,  0.2740,  1.6605,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.1501,  0.6234, -1.6086,  0.1693,  0.4932,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  1.0611, -1.0938,  0.1695,  1.0193,  0.4263,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  1.4681, -0.1552, -0.0667, -0.7293,  1.0816,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.8972,  1.1683, -1.4757,  0.4421, -0.0355,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -2.1331,  1.4847,  0.1378, -1.6907, -0.1350,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000]]],


        [[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -1.3853,  1.6396,  0.3436,  0.3841,  0.2355,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.2206, -0.5087, -1.6956,  1.3205,  0.7058,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0993,  0.3533, -0.2086,  0.2969,  0.2627,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.3752,  0.0304,  1.2487,  1.3963, -0.0063,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -1.3758,  0.5088, -1.3849,  1.3050,  0.4150,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.2824, -2.8634, -0.1016, -0.1627,  1.7081,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.1406,  0.2220, -0.6005,  0.2997, -0.1846,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  1.6700,  0.5787,  0.6561, -0.0236,  1.7743,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  2.1429, -0.2838, -0.0527,  0.3504, -0.3444,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.9409, -0.4734, -0.4060, -0.5088, -1.8518,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -2.2152,  0.2104, -0.3302,  0.2036, -0.9443,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.6576, -0.4455,  0.5117, -2.0058, -1.3985,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.5688,  1.2338, -0.1832,  0.1760,  0.4506,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.6563,  0.4021, -1.6210,  0.5582, -0.9238,
            0.0000,  0.0000],
          [ 0.0000,  0.0000, -1.0506, -0.9638,  0.7453, -0.3535, -0.3536,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000]]]], grad_fn=<CopySlices>)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
# 卷积
x_pad = x_pad.unfold(2, k, stride)
x_pad.shape
  • 1
  • 2
  • 3
torch.Size([2, 3, 4, 9, 3])
  • 1
x_pad = x_pad.unfold(3, j, stride)
x_pad.shape
  • 1
  • 2
torch.Size([2, 3, 4, 4, 3, 3])
  • 1
out = torch.einsum('nchwkj,dckj->ndhw', x_pad, weight)
out.shape
  • 1
  • 2
torch.Size([2, 4, 4, 4])
  • 1
bias.view(1, -1, 1, 1).shape
  • 1
torch.Size([1, 4, 1, 1])
  • 1
# 偏置
out = out + bias.view(1, -1, 1, 1)
out
  • 1
  • 2
  • 3
tensor([[[[ 0.6573, -0.3444,  1.5693, -0.1906],
          [ 2.5483,  5.1142, -2.3528, -3.6162],
          [ 2.9913, -6.1289,  6.8200,  0.9229],
          [ 0.4849,  0.1813,  3.2616,  1.5637]],

         [[-0.1524, -1.2003, -0.3415,  0.0318],
          [-1.7830,  2.5286, -1.6660,  3.1253],
          [ 1.3314, -8.2623, -5.0055,  5.7671],
          [-1.0563,  5.2751, -0.4214,  2.8473]],

         [[ 0.0908,  2.6704,  1.0336,  0.0481],
          [ 0.2077,  2.0459,  1.8095, -0.7039],
          [ 0.9519, -4.5551,  3.7108,  0.7446],
          [ 0.6689,  3.9448,  2.3968,  0.6958]],

         [[ 0.2318, -0.3356,  2.4320,  0.0480],
          [ 0.2101, -1.7177,  6.3956, -0.4108],
          [ 8.2352, -5.8456, 12.9459, -0.8763],
          [-2.3292, -1.5263,  2.1349,  0.3653]]],


        [[[-0.0869,  1.0713,  0.1655,  2.4414],
          [-1.3623, -3.0759,  0.2430,  2.3259],
          [-0.9281,  2.1402,  7.1618,  4.1895],
          [ 0.9273,  1.1176,  0.7792,  0.9265]],

         [[ 1.6465, -1.7187, -0.7251, -0.8871],
          [-1.6260, -0.8628, -1.0122,  3.2737],
          [ 0.5831,  2.1665, -0.5353, -2.0468],
          [-2.3738, -0.1232, -0.0771, -1.8642]],

         [[ 0.2819,  6.0978,  2.9618,  0.4676],
          [ 1.3592,  6.7231,  3.8100,  3.6118],
          [ 0.9885, -5.7760,  5.4375,  0.5480],
          [-0.5778,  1.4657, -2.8315,  0.1923]],

         [[-0.1444,  3.6788,  0.3721,  0.1150],
          [-1.4057,  0.1613, -2.5436,  1.3156],
          [-6.1195,  1.8325,  3.1565,  0.8296],
          [ 1.6766,  6.9403,  1.3986,  0.8758]]]], grad_fn=<AddBackward0>)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40

2.3 验证编写卷积函数的正确性

import torch.nn.functional as F
x = torch.randn(2, 3, 5, 5, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)
b = torch.randn(4, requires_grad = True)
stride = 2
pad = 2
torch_out = F.conv2d(x, w, b, stride, pad)
my_out = conv2d(x, w, b, stride, pad)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
torch_out == my_out
  • 1
tensor([[[[ True,  True,  True,  True],
          [ True, False, False,  True],
          [ True,  True, False,  True],
          [ True,  True, False,  True]],

         [[ True,  True,  True,  True],
          [ True, False,  True,  True],
          [False,  True,  True,  True],
          [ True,  True,  True,  True]],

         [[ True, False, False,  True],
          [False, False,  True,  True],
          [ True, False, False,  True],
          [ True,  True, False,  True]],

         [[ True,  True, False,  True],
          [False, False, False,  True],
          [ True, False, False, False],
          [ True,  True,  True,  True]]],


        [[[ True,  True,  True,  True],
          [False,  True, False,  True],
          [ True,  True, False,  True],
          [ True, False,  True,  True]],

         [[ True,  True, False,  True],
          [ True, False, False, False],
          [ True, False, False, False],
          [ True, False,  True,  True]],

         [[ True,  True, False,  True],
          [False, False, False,  True],
          [ True, False, False, False],
          [ True, False,  True,  True]],

         [[ True, False,  True,  True],
          [ True, False, False,  True],
          [False, False, False, False],
          [ True,  True, False,  True]]]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
torch.allclose(torch_out, my_out, atol=1e-5)
  • 1
True
  • 1
  • torch.allclose是用于检查两个张量之间的数值是否相等的函数。

  • 在使用时,需要将第一个张量作为第一个参数传入(即torch_out),将第二个张量作为第二个参数传入(即my_out),并将允许的绝对误差(atol)作为第三个参数传入(默认值为1e-8)。

  • 函数将返回一个布尔值,表示两个张量是否具有相近的数值。如果返回True,则表示两个张量具有相近的数值,否则表示它们之间存在数值差异。

grad_out = torch.randn(*torch_out.shape)
grad_x = torch.autograd.grad(torch_out, x, grad_out, retain_graph=True)
my_grad_x = torch.autograd.grad(my_out, x, grad_out, retain_graph=True)
  • 1
  • 2
  • 3
torch.allclose(grad_x[0], my_grad_x[0], atol=1e-5)
  • 1
True
  • 1
grad_w = torch.autograd.grad(torch_out, w, grad_out, retain_graph=True)
my_grad_w = torch.autograd.grad(my_out, w, grad_out, retain_graph=True)
  • 1
  • 2
torch.allclose(grad_w[0], my_grad_w[0], atol=1e-5)
  • 1
True
  • 1
grad_b = torch.autograd.grad(torch_out, b, grad_out, retain_graph=True)
my_grad_b = torch.autograd.grad(my_out, b, grad_out, retain_graph=True)
  • 1
  • 2
torch.allclose(grad_b[0], my_grad_b[0], atol=1e-5)
  • 1
True
  • 1

全是True,表明编写的卷积函数在一定范围内与PyTorch内置的Conv2d函数结果相近,说明了实现的正确性

附:系列文章

序号文章目录直达链接
1PyTorch应用实战一:实现卷积操作https://want595.blog.csdn.net/article/details/132575530
2PyTorch应用实战二:实现卷积神经网络进行图像分类https://want595.blog.csdn.net/article/details/132575702
3PyTorch应用实战三:构建神经网络https://want595.blog.csdn.net/article/details/132575758
4PyTorch应用实战四:基于PyTorch构建复杂应用https://want595.blog.csdn.net/article/details/132625270
5PyTorch应用实战五:实现二值化神经网络https://want595.blog.csdn.net/article/details/132625348
6PyTorch应用实战六:利用LSTM实现文本情感分类https://want595.blog.csdn.net/article/details/132625382
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/194227
推荐阅读
相关标签
  

闽ICP备14008679号