赞
踩
python3.6 + pytorch1.8.0
import torch
print(torch.__version__)
1.8.0
卷积操作是指两个函数f和g之间的一种数学运算,它在信号处理、图像处理、机器学习等领域中广泛应用。在离散情况下,卷积操作可以表示为:
( f ∗ g ) [ n ] = ∑ m = − ∞ ∞ f [ m ] g [ n − m ] (f * g)[n] = \sum_{m=-\infty}^{\infty}f[m]g[n-m] (f∗g)[n]=m=−∞∑∞f[m]g[n−m]
其中, f f f和 g g g是离散函数, ∗ * ∗表示卷积操作, n n n是离散的变量。卷积操作可以看作是将函数 g g g沿着 n n n轴翻转,然后平移,每次和函数 f f f相乘并求和,最后得到一个新的函数。这种操作可以实现信号的滤波、特征提取等功能,是数字信号处理中非常重要的基础操作。
PyTorch的unfold
函数用于对张量进行展开操作。torch.unfold()
可以理解为将一个高维的张量展开成一个二维矩阵的操作。即将原来的张量沿着指定的维度展开成一个二维矩阵,其中第一维对应原来张量的维度,第二维对应展开的位置。
函数原型如下:
torch.unfold(input, dimension, size, step)
参数说明:
import torch
a = torch.arange(16).view(4, 4)
a
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
b = a.unfold(0, 3, 1)
b
tensor([[[ 0, 4, 8],
[ 1, 5, 9],
[ 2, 6, 10],
[ 3, 7, 11]],
[[ 4, 8, 12],
[ 5, 9, 13],
[ 6, 10, 14],
[ 7, 11, 15]]])
b.shape
torch.Size([2, 4, 3])
c = b.unfold(1, 3, 1)
c
tensor([[[[ 0, 1, 2],
[ 4, 5, 6],
[ 8, 9, 10]],
[[ 1, 2, 3],
[ 5, 6, 7],
[ 9, 10, 11]]],
[[[ 4, 5, 6],
[ 8, 9, 10],
[12, 13, 14]],
[[ 5, 6, 7],
[ 9, 10, 11],
[13, 14, 15]]]])
c.shape
torch.Size([2, 2, 3, 3])
完整程序
import torch
a = torch.arange(16).view(4, 4)
b = a.unfold(0, 3, 1)
c = b.unfold(1, 3, 1)
c.shape
torch.Size([2, 2, 3, 3])
这段代码定义了三个变量。假设我们将其分别命名为a
,b
和c
,则:
a
是一个4x4的张量,其中包含了0到15的整数值,它通过torch.arange(16)
和.view(4, 4)
两个函数调用来实现。b
是通过对变量a
进行折叠操作得到的一个张量,具体来说,它是将变量a
沿着第0维(即行)展开,并取窗口大小为3,步长为1的子张量所得到的结果。因此,如果我们将张量b
打印出来,会得到:tensor([[[ 0, 1, 2],
[ 4, 5, 6],
[ 8, 9, 10],
[12, 13, 14]],
[[ 1, 2, 3],
[ 5, 6, 7],
[ 9, 10, 11],
[13, 14, 15]]])
其中,第一个子张量的值为[[0, 1, 2], [4, 5, 6], [8, 9, 10], [12, 13, 14]]
,第二个子张量的值为[[1, 2, 3], [5, 6, 7], [9, 10, 11], [13, 14, 15]]
。注意,这个张量的形状为(2, 4, 3)
,即它包含2个子张量,每个子张量的形状为(4, 3)
。
c
是对变量b
进行类似的操作得到的,但是它是在第1维(即列)上展开并取子张量。具体来说,它是将变量b
沿着第1维(即列)展开,并取窗口大小为3,步长为1的子张量所得到的结果。因此,如果我们将张量c
打印出来,会得到:tensor([[[[ 0, 1, 2],
[ 4, 5, 6],
[ 8, 9, 10]],
[[ 1, 2, 3],
[ 5, 6, 7],
[ 9, 10, 11]]],
[[[ 4, 5, 6],
[ 8, 9, 10],
[12, 13, 14]],
[[ 5, 6, 7],
[ 9, 10, 11],
[13, 14, 15]]]])
其中,第一个子张量的值为[[[0, 1, 2], [4, 5, 6], [8, 9, 10]], [[1, 2, 3], [5, 6, 7], [9, 10, 11]]]
,第二个子张量的值为[[[4, 5, 6], [8, 9, 10], [12, 13, 14]], [[5, 6, 7], [9, 10, 11], [13, 14, 15]]]
。注意,这个张量的形状为(2, 2, 3, 3)
,即它包含2个子张量,每个子张量的形状为(2, 3, 3)
。
完整程序
import torch
def conv2d(x, weight, bias, stride, pad):
n, c, h, w = x.shape
d, c, k, j = weight.shape
x_pad = torch.zeros(n, c, h+2*pad, w+2*pad).to(x.device)
x_pad[:, :, pad:-pad, pad:-pad] = x
x_pad = x_pad.unfold(2, k, stride)
x_pad = x_pad.unfold(3, j, stride)
out = torch.einsum('nchwkj,dckj->ndhw', x_pad, weight)
out = out + bias.view(1, -1, 1, 1)
return out
该函数实现了二维卷积操作。下面对函数进行详细分析:
在新的张量 x_pad 上使用 einsum 函数对卷积核进行卷积操作。
einsum 的第一个参数表示操作的规则,其中 ndhw 表示最终输出的张量的维度为(batch_size,out_channels,output_height,output_width),nchw 和 dckj 表示两个输入张量 x_pad 和 weight 的维度,其中 c k j 分别表示 input_channels、kernel_height 和 kernel_width。
最终得到的输出张量形状为(batch_size,out_channels,output_height,output_width),并在每个位置上加上偏置项 bias。
torch
: PyTorch库
einsum
: Einstein summation notation,爱因斯坦求和约定,一种张量求和的简便表示法。
'nchwkj,dckj->ndhw'
: 爱因斯坦求和符号,左侧的张量为x_pad
,右侧的张量为weight
。在左侧张量中,n
, c
, h
, w
分别表示batch size、通道数、高度、宽度。在右侧张量中,d
, c
, k
, j
分别表示输出通道数、输入通道数、卷积核高度、卷积核宽度。这个式子的意义是将x_pad
和weight
执行卷积操作,并输出结果张量,其形状为(batch_size, output_channels, height, width)
。
x_pad
: 输入的张量,形状为 (batch_size, input_channels, input_height, input_width)
。
weight
: 卷积核张量,形状为 (output_channels, input_channels, kernel_height, kernel_width)
。
# 设置测试数据
x = torch.randn(2, 3, 5, 5, requires_grad=True)
weight = torch.randn(4, 3, 3, 3, requires_grad=True)
bias = torch.randn(4, requires_grad = True)
stride = 2
pad = 2
x, weight, bias
(tensor([[[[-0.4888, 1.0257, 0.0312, -0.9026, -0.9060],
[ 0.2071, -0.4962, -0.1658, 1.0919, 0.3785],
[-0.4654, 1.5442, 0.6005, 0.3594, -2.6207],
[ 0.5830, 0.0533, 0.5719, 1.5413, 0.5949],
[-0.9152, -0.2114, -0.4888, -0.0065, -0.9767]],
[[ 0.4706, -0.1108, -0.1563, -1.7946, -0.8533],
[-0.2119, 0.3165, -2.2668, -0.8956, 1.0617],
[-0.7809, -0.2120, -0.8592, -0.5057, 0.7954],
[-2.8820, -0.6888, 0.4450, -0.3586, -0.9477],
[ 0.6244, 0.4303, 1.4739, 0.2740, 1.6605]],
[[-0.1501, 0.6234, -1.6086, 0.1693, 0.4932],
[ 1.0611, -1.0938, 0.1695, 1.0193, 0.4263],
[ 1.4681, -0.1552, -0.0667, -0.7293, 1.0816],
[ 0.8972, 1.1683, -1.4757, 0.4421, -0.0355],
[-2.1331, 1.4847, 0.1378, -1.6907, -0.1350]]],
[[[-1.3853, 1.6396, 0.3436, 0.3841, 0.2355],
[-0.2206, -0.5087, -1.6956, 1.3205, 0.7058],
[ 0.0993, 0.3533, -0.2086, 0.2969, 0.2627],
[ 0.3752, 0.0304, 1.2487, 1.3963, -0.0063],
[-1.3758, 0.5088, -1.3849, 1.3050, 0.4150]],
[[ 0.2824, -2.8634, -0.1016, -0.1627, 1.7081],
[ 0.1406, 0.2220, -0.6005, 0.2997, -0.1846],
[ 1.6700, 0.5787, 0.6561, -0.0236, 1.7743],
[ 2.1429, -0.2838, -0.0527, 0.3504, -0.3444],
[-0.9409, -0.4734, -0.4060, -0.5088, -1.8518]],
[[-2.2152, 0.2104, -0.3302, 0.2036, -0.9443],
[-0.6576, -0.4455, 0.5117, -2.0058, -1.3985],
[-0.5688, 1.2338, -0.1832, 0.1760, 0.4506],
[-0.6563, 0.4021, -1.6210, 0.5582, -0.9238],
[-1.0506, -0.9638, 0.7453, -0.3535, -0.3536]]]], requires_grad=True),
tensor([[[[ 0.3069, 0.2079, -0.2952],
[ 1.7681, 1.1056, -1.0555],
[ 1.5845, 0.8294, 0.6588]],
[[ 0.2574, 0.5007, 0.2912],
[-0.0210, 0.6593, -0.9691],
[-0.2918, 0.5695, -1.1242]],
[[ 0.7327, -0.3453, 0.7041],
[-0.2236, -1.7762, 0.0190],
[-1.0927, -2.9369, 0.1768]]],
[[[-2.3830, -1.4807, 1.8573],
[ 1.0097, -0.9640, 1.0361],
[-0.5222, -1.0386, -0.4016]],
[[ 0.5071, 1.1433, -0.1194],
[-0.0133, -0.3878, -0.1853],
[ 0.3456, -0.6502, 0.2221]],
[[-1.7672, -0.0469, -0.5996],
[-0.2080, -1.6209, 0.4120],
[ 0.8404, -1.6748, -0.7170]]],
[[[ 0.2850, 0.1691, -0.9228],
[ 0.7234, 0.5582, -0.4327],
[ 0.6563, 0.2941, 1.5549]],
[[ 0.2642, -1.9061, 1.6212],
[-0.5276, -0.5608, 0.3824],
[ 0.4452, -2.5152, 0.4490]],
[[-0.1276, 0.7784, 0.7998],
[-0.3030, -0.9776, 0.9681],
[ 1.0225, 0.8946, -0.8084]]],
[[[-0.5087, -0.8345, -1.4763],
[-0.4938, 1.1979, -0.1335],
[ 0.5010, 0.2865, 0.0728]],
[[-0.3177, -0.6937, -1.0327],
[ 0.8147, -1.7101, -1.8257],
[-0.1593, -1.3855, -0.0885]],
[[-0.4687, -1.6307, 1.5791],
[-1.3030, 0.2004, -0.7055],
[ 0.0674, -0.8772, 0.1586]]]], requires_grad=True),
tensor([ 1.5349, -0.5608, 0.5182, 0.3328], requires_grad=True))
n, c, h, w = x.shape
d, c, k, j = weight.shape
n, c, h, w
(2, 3, 5, 5)
d, c, k, j
(4, 3, 3, 3)
# 补零
x_pad = torch.zeros(n, c, h+2*pad, w+2*pad).to(x.device)
x_pad
tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
[[[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.]]]])
x_pad.shape
torch.Size([2, 3, 9, 9])
x_pad[:, :, pad:-pad, pad:-pad] = x
x_pad
tensor([[[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, -0.4888, 1.0257, 0.0312, -0.9026, -0.9060,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.2071, -0.4962, -0.1658, 1.0919, 0.3785,
0.0000, 0.0000],
[ 0.0000, 0.0000, -0.4654, 1.5442, 0.6005, 0.3594, -2.6207,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.5830, 0.0533, 0.5719, 1.5413, 0.5949,
0.0000, 0.0000],
[ 0.0000, 0.0000, -0.9152, -0.2114, -0.4888, -0.0065, -0.9767,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000]],
[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.4706, -0.1108, -0.1563, -1.7946, -0.8533,
0.0000, 0.0000],
[ 0.0000, 0.0000, -0.2119, 0.3165, -2.2668, -0.8956, 1.0617,
0.0000, 0.0000],
[ 0.0000, 0.0000, -0.7809, -0.2120, -0.8592, -0.5057, 0.7954,
0.0000, 0.0000],
[ 0.0000, 0.0000, -2.8820, -0.6888, 0.4450, -0.3586, -0.9477,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.6244, 0.4303, 1.4739, 0.2740, 1.6605,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000]],
[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, -0.1501, 0.6234, -1.6086, 0.1693, 0.4932,
0.0000, 0.0000],
[ 0.0000, 0.0000, 1.0611, -1.0938, 0.1695, 1.0193, 0.4263,
0.0000, 0.0000],
[ 0.0000, 0.0000, 1.4681, -0.1552, -0.0667, -0.7293, 1.0816,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.8972, 1.1683, -1.4757, 0.4421, -0.0355,
0.0000, 0.0000],
[ 0.0000, 0.0000, -2.1331, 1.4847, 0.1378, -1.6907, -0.1350,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000]]],
[[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, -1.3853, 1.6396, 0.3436, 0.3841, 0.2355,
0.0000, 0.0000],
[ 0.0000, 0.0000, -0.2206, -0.5087, -1.6956, 1.3205, 0.7058,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0993, 0.3533, -0.2086, 0.2969, 0.2627,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.3752, 0.0304, 1.2487, 1.3963, -0.0063,
0.0000, 0.0000],
[ 0.0000, 0.0000, -1.3758, 0.5088, -1.3849, 1.3050, 0.4150,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000]],
[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.2824, -2.8634, -0.1016, -0.1627, 1.7081,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.1406, 0.2220, -0.6005, 0.2997, -0.1846,
0.0000, 0.0000],
[ 0.0000, 0.0000, 1.6700, 0.5787, 0.6561, -0.0236, 1.7743,
0.0000, 0.0000],
[ 0.0000, 0.0000, 2.1429, -0.2838, -0.0527, 0.3504, -0.3444,
0.0000, 0.0000],
[ 0.0000, 0.0000, -0.9409, -0.4734, -0.4060, -0.5088, -1.8518,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000]],
[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, -2.2152, 0.2104, -0.3302, 0.2036, -0.9443,
0.0000, 0.0000],
[ 0.0000, 0.0000, -0.6576, -0.4455, 0.5117, -2.0058, -1.3985,
0.0000, 0.0000],
[ 0.0000, 0.0000, -0.5688, 1.2338, -0.1832, 0.1760, 0.4506,
0.0000, 0.0000],
[ 0.0000, 0.0000, -0.6563, 0.4021, -1.6210, 0.5582, -0.9238,
0.0000, 0.0000],
[ 0.0000, 0.0000, -1.0506, -0.9638, 0.7453, -0.3535, -0.3536,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000]]]], grad_fn=<CopySlices>)
# 卷积
x_pad = x_pad.unfold(2, k, stride)
x_pad.shape
torch.Size([2, 3, 4, 9, 3])
x_pad = x_pad.unfold(3, j, stride)
x_pad.shape
torch.Size([2, 3, 4, 4, 3, 3])
out = torch.einsum('nchwkj,dckj->ndhw', x_pad, weight)
out.shape
torch.Size([2, 4, 4, 4])
bias.view(1, -1, 1, 1).shape
torch.Size([1, 4, 1, 1])
# 偏置
out = out + bias.view(1, -1, 1, 1)
out
tensor([[[[ 0.6573, -0.3444, 1.5693, -0.1906],
[ 2.5483, 5.1142, -2.3528, -3.6162],
[ 2.9913, -6.1289, 6.8200, 0.9229],
[ 0.4849, 0.1813, 3.2616, 1.5637]],
[[-0.1524, -1.2003, -0.3415, 0.0318],
[-1.7830, 2.5286, -1.6660, 3.1253],
[ 1.3314, -8.2623, -5.0055, 5.7671],
[-1.0563, 5.2751, -0.4214, 2.8473]],
[[ 0.0908, 2.6704, 1.0336, 0.0481],
[ 0.2077, 2.0459, 1.8095, -0.7039],
[ 0.9519, -4.5551, 3.7108, 0.7446],
[ 0.6689, 3.9448, 2.3968, 0.6958]],
[[ 0.2318, -0.3356, 2.4320, 0.0480],
[ 0.2101, -1.7177, 6.3956, -0.4108],
[ 8.2352, -5.8456, 12.9459, -0.8763],
[-2.3292, -1.5263, 2.1349, 0.3653]]],
[[[-0.0869, 1.0713, 0.1655, 2.4414],
[-1.3623, -3.0759, 0.2430, 2.3259],
[-0.9281, 2.1402, 7.1618, 4.1895],
[ 0.9273, 1.1176, 0.7792, 0.9265]],
[[ 1.6465, -1.7187, -0.7251, -0.8871],
[-1.6260, -0.8628, -1.0122, 3.2737],
[ 0.5831, 2.1665, -0.5353, -2.0468],
[-2.3738, -0.1232, -0.0771, -1.8642]],
[[ 0.2819, 6.0978, 2.9618, 0.4676],
[ 1.3592, 6.7231, 3.8100, 3.6118],
[ 0.9885, -5.7760, 5.4375, 0.5480],
[-0.5778, 1.4657, -2.8315, 0.1923]],
[[-0.1444, 3.6788, 0.3721, 0.1150],
[-1.4057, 0.1613, -2.5436, 1.3156],
[-6.1195, 1.8325, 3.1565, 0.8296],
[ 1.6766, 6.9403, 1.3986, 0.8758]]]], grad_fn=<AddBackward0>)
import torch.nn.functional as F
x = torch.randn(2, 3, 5, 5, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)
b = torch.randn(4, requires_grad = True)
stride = 2
pad = 2
torch_out = F.conv2d(x, w, b, stride, pad)
my_out = conv2d(x, w, b, stride, pad)
torch_out == my_out
tensor([[[[ True, True, True, True],
[ True, False, False, True],
[ True, True, False, True],
[ True, True, False, True]],
[[ True, True, True, True],
[ True, False, True, True],
[False, True, True, True],
[ True, True, True, True]],
[[ True, False, False, True],
[False, False, True, True],
[ True, False, False, True],
[ True, True, False, True]],
[[ True, True, False, True],
[False, False, False, True],
[ True, False, False, False],
[ True, True, True, True]]],
[[[ True, True, True, True],
[False, True, False, True],
[ True, True, False, True],
[ True, False, True, True]],
[[ True, True, False, True],
[ True, False, False, False],
[ True, False, False, False],
[ True, False, True, True]],
[[ True, True, False, True],
[False, False, False, True],
[ True, False, False, False],
[ True, False, True, True]],
[[ True, False, True, True],
[ True, False, False, True],
[False, False, False, False],
[ True, True, False, True]]]])
torch.allclose(torch_out, my_out, atol=1e-5)
True
torch.allclose是用于检查两个张量之间的数值是否相等的函数。
在使用时,需要将第一个张量作为第一个参数传入(即torch_out),将第二个张量作为第二个参数传入(即my_out),并将允许的绝对误差(atol)作为第三个参数传入(默认值为1e-8)。
函数将返回一个布尔值,表示两个张量是否具有相近的数值。如果返回True,则表示两个张量具有相近的数值,否则表示它们之间存在数值差异。
grad_out = torch.randn(*torch_out.shape)
grad_x = torch.autograd.grad(torch_out, x, grad_out, retain_graph=True)
my_grad_x = torch.autograd.grad(my_out, x, grad_out, retain_graph=True)
torch.allclose(grad_x[0], my_grad_x[0], atol=1e-5)
True
grad_w = torch.autograd.grad(torch_out, w, grad_out, retain_graph=True)
my_grad_w = torch.autograd.grad(my_out, w, grad_out, retain_graph=True)
torch.allclose(grad_w[0], my_grad_w[0], atol=1e-5)
True
grad_b = torch.autograd.grad(torch_out, b, grad_out, retain_graph=True)
my_grad_b = torch.autograd.grad(my_out, b, grad_out, retain_graph=True)
torch.allclose(grad_b[0], my_grad_b[0], atol=1e-5)
True
全是True,表明编写的卷积函数在一定范围内与PyTorch内置的Conv2d函数结果相近,说明了实现的正确性
序号 | 文章目录 | 直达链接 |
---|---|---|
1 | PyTorch应用实战一:实现卷积操作 | https://want595.blog.csdn.net/article/details/132575530 |
2 | PyTorch应用实战二:实现卷积神经网络进行图像分类 | https://want595.blog.csdn.net/article/details/132575702 |
3 | PyTorch应用实战三:构建神经网络 | https://want595.blog.csdn.net/article/details/132575758 |
4 | PyTorch应用实战四:基于PyTorch构建复杂应用 | https://want595.blog.csdn.net/article/details/132625270 |
5 | PyTorch应用实战五:实现二值化神经网络 | https://want595.blog.csdn.net/article/details/132625348 |
6 | PyTorch应用实战六:利用LSTM实现文本情感分类 | https://want595.blog.csdn.net/article/details/132625382 |
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。