SENet [1] was the champion model of ImageNet 2017; after SENet, the ImageNet challenge was no longer held. Like ResNet before it, SENet introduces a simple, reusable architectural trick, one that adds only a small parameter and computation overhead while noticeably improving accuracy.
SENet stands for Squeeze-and-Excitation Networks. It took first place at ImageNet 2017, bringing the top-5 error rate down to 2.251%; the official models and code can be found in the GitHub repository [2].
The motivation behind SENet is to take the relationships between channels into account. This leads to the Squeeze-and-Excitation (SE) block [1], whose goal is to improve the quality of the network's representations by explicitly modelling the interdependencies between the channels of its convolutional features. The SE block's mechanism can also be described as selectively emphasizing informative features and suppressing less useful ones by learning from global information. The SE block is shown in Fig. 1.
Consider a computational unit $F_{tr}$ (here, a convolution) that maps an input $\mathbf{X} \in \mathbb{R}^{H' \times W' \times C'}$ to feature maps $\mathbf{U} \in \mathbb{R}^{H \times W \times C}$. In the notation below, $\mathbf{V} = [V_1, V_2, \dots, V_C]$ denotes the learned set of filter kernels, where $V_c$ refers to the parameters of the $c$-th filter, so the output can be written as $\mathbf{U} = [U_1, U_2, \dots, U_C]$, where:
$$U_c = V_c * \mathbf{X} = \sum_{s=1}^{C'} V_c^s * X^s \tag{1}$$
In Eq. (1), $*$ denotes convolution, $V_c = [V_c^1, V_c^2, \dots, V_c^{C'}]$, $\mathbf{X} = [X^1, X^2, \dots, X^{C'}]$, and $u_c \in \mathbb{R}^{H \times W}$; $V_c^s$ is the 2D spatial kernel acting on the corresponding channel $X^s$ of $\mathbf{X}$.
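Equation (1) simply spells out how a standard multi-channel convolution works: each filter's output is the sum of $C'$ single-channel 2D convolutions. A short check (illustrative tensor sizes of my own choosing) confirms this in PyTorch:

import torch
import torch.nn.functional as F

X = torch.randn(1, 4, 8, 8)    # input with C' = 4 channels
V = torch.randn(6, 4, 3, 3)    # C = 6 filters, each holding one 2D kernel per input channel
U = F.conv2d(X, V, padding=1)  # the usual multi-channel convolution

# Sum of C' single-channel 2D convolutions, one per input channel s
U_sum = sum(F.conv2d(X[:, s:s+1], V[:, s:s+1], padding=1) for s in range(4))
print(torch.allclose(U, U_sum, atol=1e-5))  # True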
A few remarks on this transformation:
To tackle the issue of channel dependencies, we first consider the signal associated with each channel of the output features. Each learned filter operates with a local receptive field, so each unit of the transformed output cannot make good use of contextual information outside of that region.
To mitigate this, SENet squeezes global spatial information into a channel descriptor, achieved by using global average pooling to generate channel-wise statistics. Formally, a statistic $z \in \mathbb{R}^C$ is generated by shrinking $\mathbf{U}$ through its spatial dimensions $H \times W$, so that the $c$-th element of $z$ is computed as:
$$z_c = F_{sq}(u_c) = \frac{1}{H \times W} \sum_{i=1}^{H} \sum_{j=1}^{W} u_c(i, j) \tag{2}$$
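In code, the squeeze step is simply a spatial mean. As a small sketch (assuming a feature map u of shape (N, C, H, W); sizes are my own), both of the following produce the statistic $z$:

import torch
import torch.nn.functional as F

u = torch.randn(2, 64, 7, 7)                 # assumed (N, C, H, W) feature map
z1 = u.mean(dim=(2, 3))                      # Eq. (2): average over H and W
z2 = F.adaptive_avg_pool2d(u, 1).flatten(1)  # equivalent, via global average pooling
print(torch.allclose(z1, z2))                # True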
To make use of the information aggregated in the squeeze operation, a second operation follows, whose aim is to fully capture channel-wise dependencies. To fulfil this objective, the function must meet two criteria: first, it must be flexible (in particular, capable of learning a nonlinear interaction between channels); second, it must learn a non-mutually-exclusive relationship, so that multiple channels can be emphasized simultaneously.
To meet these criteria, a simple gating mechanism with a sigmoid activation is chosen:
$$s = F_{ex}(z, \mathbf{W}) = \sigma(g(z, \mathbf{W})) = \sigma(\mathbf{W}_2 \, \delta(\mathbf{W}_1 z)) \tag{3}$$
where $\delta$ denotes the ReLU function, $\mathbf{W}_1 \in \mathbb{R}^{\frac{C}{r} \times C}$ and $\mathbf{W}_2 \in \mathbb{R}^{C \times \frac{C}{r}}$. To limit model complexity and aid generalization, the gate is parameterized as a bottleneck of two fully connected layers: the first reduces the dimension by a reduction ratio $r$ (a hyperparameter) and is followed by a ReLU, and the second restores the original dimension. Finally, the learned per-channel activations (sigmoid outputs in the range 0 to 1) rescale the original features $\mathbf{U}$:
$$\tilde{x}_c = F_{scale}(u_c, s_c) = s_c \cdot u_c \tag{4}$$
where $\widetilde{\mathbf{X}} = [\tilde{x}_1, \tilde{x}_2, \dots, \tilde{x}_C]$ and $F_{scale}(u_c, s_c)$ denotes channel-wise multiplication of the scalar $s_c$ with the feature map $u_c \in \mathbb{R}^{H \times W}$.
In effect, the whole operation learns a weight for each channel, making the model more discriminative across channel features; it can also be regarded as a form of attention mechanism [3].
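As a reference point, here is a minimal sketch of an SE block transcribed directly from Eqs. (2)-(4) (module and variable names are my own; the notebook implementation further below uses a slightly different, gate-plus-bias variant):

import torch
from torch import nn

class SqueezeExcitation(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc1 = nn.Linear(channels, channels // reduction)  # W1: reduce by ratio r
        self.fc2 = nn.Linear(channels // reduction, channels)  # W2: restore the dimension

    def forward(self, x):
        n, c, _, _ = x.shape
        z = x.mean(dim=(2, 3))                                # squeeze, Eq. (2)
        s = torch.sigmoid(self.fc2(torch.relu(self.fc1(z))))  # excitation, Eq. (3)
        return x * s.view(n, c, 1, 1)                         # scale, Eq. (4)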
The SE module is very flexible and can be plugged directly into existing network architectures such as GoogLeNet and ResNet, as shown in Fig. 2 and Fig. 3.
Likewise, the SE module can be applied to other network structures. The original table from the paper, giving the exact structures of SE-ResNet-50 and SE-ResNeXt-50, is reproduced as Table 1.
With SE modules added, the model's parameter count and computational cost both increase, but the extra parameters come solely from the two fully connected layers of the gating mechanism and thus make up only a small fraction of the network's capacity. The additional parameter count is given by Eq. (5):
$$\frac{2}{r} \sum_{s=1}^{S} N_s \cdot C_s^2 \tag{5}$$
where $r$ denotes the reduction ratio, $S$ the number of stages (a stage being the collection of blocks operating on feature maps of a common spatial dimension), $C_s$ the dimension of the output channels, and $N_s$ the number of repeated blocks in stage $s$.
With $r = 16$, SE-ResNet-50 adds only about 10% to the parameter count, while the added computation is below 1%.
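A quick back-of-the-envelope check of Eq. (5), assuming the standard ResNet-50 stage layout (3, 4, 6 and 3 blocks with 256, 512, 1024 and 2048 output channels):

r = 16
stages = [(3, 256), (4, 512), (6, 1024), (3, 2048)]  # (N_s, C_s) per stage
extra = (2 / r) * sum(N * C ** 2 for N, C in stages)
print('%.2fM additional parameters' % (extra / 1e6))  # ~2.51M, roughly 10% of ResNet-50's ~25M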
The SE module can easily be introduced into other networks as well. To verify its effect, SE modules were added to popular mainstream networks and compared on ImageNet, as shown in Table 2:
All of the networks show some gain in classification accuracy after the SE module is added. To get a practical feel for the SE module, the rest of this post implements and trains one, for a closer look at the architecture and its effect.
The following code is adapted from a GitHub implementation [4].
import os
import time

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms
device = ('cuda' if torch.cuda.is_available() else 'cpu')
device
'cpu'
# Hyperparameters
EPOCHS = 40
BATCH_SIZE = 128
LEARNING_RATE = 1e-1
WEIGHT_DECAY = 1e-4
Fetch the data with torchvision.datasets
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]))

test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]))
Files already downloaded and verified
Files already downloaded and verified
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
# Squeeze-and-Excitation block. This variant differs slightly from the paper:
# the 1x1-conv gate produces 2*C channels, split into a multiplicative weight w
# and an additive bias b.
class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channels // reduction, channels * 2, 1, bias=False),
        )

    def forward(self, x):
        w = F.adaptive_avg_pool2d(x, 1)             # Squeeze: global average pooling to 1x1
        w = self.fc(w)                              # Excitation: bottleneck gate on the pooled descriptor
        w, b = w.split(w.data.size(1) // 2, dim=1)  # split into gate w and bias b
        w = torch.sigmoid(w)

        return x * w + b  # Scale and add bias
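A quick shape check (my own addition, not in the referenced code): since the block only rescales and shifts channels, it must preserve the input's shape.

blk = SEBlock(128)
print(blk(torch.randn(4, 128, 32, 32)).shape)  # torch.Size([4, 128, 32, 32])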
# Residual block with an SEBlock on the residual branch
class ResBlock(nn.Module):
    def __init__(self, channels):
        super(ResBlock, self).__init__()
        self.conv_lower = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )
        self.conv_upper = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels)
        )
        self.se_block = SEBlock(channels)

    def forward(self, x):
        path = self.conv_lower(x)
        path = self.conv_upper(path)
        path = self.se_block(path)  # recalibrate the residual branch
        path = x + path             # identity shortcut

        return F.relu(path)
# Network module: stem conv -> stack of SE residual blocks -> 1x1 conv -> classifier
class Network(nn.Module):
    def __init__(self, in_channel, filters, blocks, num_classes):
        super(Network, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channel, filters, 3, padding=1, bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU()
        )
        self.res_blocks = nn.Sequential(*[ResBlock(filters) for _ in range(blocks - 1)])
        self.out_conv = nn.Sequential(
            nn.Conv2d(filters, 128, 1, padding=0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.conv_block(x)
        x = self.res_blocks(x)
        x = self.out_conv(x)
        x = F.adaptive_avg_pool2d(x, 1)  # global average pooling
        x = x.view(x.data.size(0), -1)
        x = self.fc(x)

        return F.log_softmax(x, dim=1)
net = Network(3, 128, 10, 10).to(device)
ACE = nn.CrossEntropyLoss().to(device)  # the network outputs log_softmax; as log_softmax is idempotent, this matches NLLLoss on those outputs
opt = optim.SGD(net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, momentum=.9, nesterov=True)
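Before training, an optional sanity check (not part of the referenced code) prints the total number of trainable parameters:

n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print('trainable parameters: %.2fM' % (n_params / 1e6))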
os.makedirs('./data/checkpoint', exist_ok=True)  # torch.save below needs this directory to exist

for epoch in range(1, EPOCHS + 1):
    print('[Epoch %d]' % epoch)

    train_loss = 0
    train_correct, train_total = 0, 0
    start_point = time.time()

    net.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        opt.zero_grad()
        preds = net(inputs)
        loss = ACE(preds, labels)
        loss.backward()
        opt.step()

        train_loss += loss.item()
        train_correct += (preds.argmax(dim=1) == labels).sum().item()
        train_total += len(preds)

    print('train-acc : %.4f%% train-loss : %.5f' % (100 * train_correct / train_total, train_loss / len(train_loader)))
    print('elapsed time: %ds' % (time.time() - start_point))

    test_loss = 0
    test_correct, test_total = 0, 0

    net.eval()  # BatchNorm should use running statistics during evaluation
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            preds = net(inputs)
            test_loss += ACE(preds, labels).item()
            test_correct += (preds.argmax(dim=1) == labels).sum().item()
            test_total += len(preds)

    print('test-acc : %.4f%% test-loss : %.5f' % (100 * test_correct / test_total, test_loss / len(test_loader)))

    torch.save(net.state_dict(), './data/checkpoint/checkpoint-%04d.bin' % epoch)
[Epoch 1] train-acc : 62.9240% train-loss : 1.02725 elapsed time: 167s test-acc : 59.9800% test-loss : 1.13711
[Epoch 2] train-acc : 69.3160% train-loss : 0.85710 elapsed time: 170s test-acc : 67.6300% test-loss : 0.92139
[Epoch 3] train-acc : 73.9000% train-loss : 0.74356 elapsed time: 171s test-acc : 70.7700% test-loss : 0.84002
[Epoch 4] train-acc : 77.2340% train-loss : 0.65098 elapsed time: 171s test-acc : 74.3400% test-loss : 0.75001
[Epoch 5] train-acc : 79.7560% train-loss : 0.58424 elapsed time: 171s test-acc : 74.8000% test-loss : 0.71813
[Epoch 6] train-acc : 81.8820% train-loss : 0.52713 elapsed time: 171s test-acc : 77.7400% test-loss : 0.66449
[Epoch 7] train-acc : 83.0260% train-loss : 0.49098 elapsed time: 171s test-acc : 79.3000% test-loss : 0.60599
[Epoch 8] train-acc : 84.2880% train-loss : 0.45633 elapsed time: 171s test-acc : 78.0500% test-loss : 0.64819
[Epoch 9] train-acc : 85.2660% train-loss : 0.43147 elapsed time: 171s test-acc : 80.7400% test-loss : 0.57734
[Epoch 10] train-acc : 86.2080% train-loss : 0.39924 elapsed time: 171s test-acc : 81.9000% test-loss : 0.53836
[Epoch 11] train-acc : 86.9320% train-loss : 0.38040 elapsed time: 171s test-acc : 82.7100% test-loss : 0.51160
[Epoch 12] train-acc : 87.4740% train-loss : 0.36286 elapsed time: 170s test-acc : 81.8500% test-loss : 0.54868
[Epoch 13] train-acc : 88.1580% train-loss : 0.34673 elapsed time: 171s test-acc : 83.0700% test-loss : 0.49779
[Epoch 14] train-acc : 88.9260% train-loss : 0.31996 elapsed time: 171s test-acc : 83.8900% test-loss : 0.48193
[Epoch 15] train-acc : 89.1380% train-loss : 0.31583 elapsed time: 171s test-acc : 83.9900% test-loss : 0.49245
[Epoch 16] train-acc : 89.5460% train-loss : 0.30087 elapsed time: 170s test-acc : 84.0100% test-loss : 0.49648
[Epoch 17] train-acc : 90.0420% train-loss : 0.29067 elapsed time: 171s test-acc : 85.2700% test-loss : 0.44473
[Epoch 18] train-acc : 90.3720% train-loss : 0.28137 elapsed time: 171s test-acc : 83.8900% test-loss : 0.49883
[Epoch 19] train-acc : 90.6020% train-loss : 0.26961 elapsed time: 171s test-acc : 84.4700% test-loss : 0.47203
[Epoch 20] train-acc : 91.1460% train-loss : 0.25927 elapsed time: 170s test-acc : 84.4200% test-loss : 0.49412
[Epoch 21] train-acc : 91.1540% train-loss : 0.25661 elapsed time: 170s test-acc : 85.3500% test-loss : 0.43626
[Epoch 22] train-acc : 91.3620% train-loss : 0.24741 elapsed time: 171s test-acc : 86.2200% test-loss : 0.41310
[Epoch 23] train-acc : 91.9760% train-loss : 0.23271 elapsed time: 171s test-acc : 86.5600% test-loss : 0.40795
[Epoch 24] train-acc : 92.0000% train-loss : 0.23080 elapsed time: 171s test-acc : 84.8000% test-loss : 0.46834
[Epoch 25] train-acc : 92.1460% train-loss : 0.22744 elapsed time: 171s test-acc : 85.4300% test-loss : 0.44402
[Epoch 26] train-acc : 92.2120% train-loss : 0.22320 elapsed time: 170s test-acc : 86.3300% test-loss : 0.41405
[Epoch 27] train-acc : 92.3740% train-loss : 0.21625 elapsed time: 170s test-acc : 87.3800% test-loss : 0.38440
[Epoch 28] train-acc : 92.6960% train-loss : 0.21098 elapsed time: 171s test-acc : 84.9300% test-loss : 0.46326
[Epoch 29] train-acc : 92.8700% train-loss : 0.20541 elapsed time: 171s test-acc : 86.5900% test-loss : 0.41840
[Epoch 30] train-acc : 93.0700% train-loss : 0.20067 elapsed time: 170s test-acc : 86.8400% test-loss : 0.42302
[Epoch 31] train-acc : 93.2300% train-loss : 0.19319 elapsed time: 171s test-acc : 87.1700% test-loss : 0.39542
[Epoch 32] train-acc : 93.2280% train-loss : 0.19576 elapsed time: 171s test-acc : 86.6500% test-loss : 0.43697
[Epoch 33] train-acc : 93.5900% train-loss : 0.18686 elapsed time: 170s test-acc : 86.8300% test-loss : 0.40863
[Epoch 34] train-acc : 93.5820% train-loss : 0.18315 elapsed time: 170s test-acc : 86.8200% test-loss : 0.42321
[Epoch 35] train-acc : 93.6140% train-loss : 0.18232 elapsed time: 170s test-acc : 86.1700% test-loss : 0.43491
[Epoch 36] train-acc : 93.9620% train-loss : 0.17560 elapsed time: 170s test-acc : 86.9100% test-loss : 0.41068
[Epoch 37] train-acc : 93.9920% train-loss : 0.17193 elapsed time: 170s test-acc : 87.0600% test-loss : 0.41822
[Epoch 38] train-acc : 93.8620% train-loss : 0.17253 elapsed time: 170s test-acc : 88.0500% test-loss : 0.38560
[Epoch 39] train-acc : 94.2040% train-loss : 0.16850 elapsed time: 170s test-acc : 86.7000% test-loss : 0.42949
[Epoch 40] train-acc : 94.2940% train-loss : 0.16422 elapsed time: 170s test-acc : 87.2100% test-loss : 0.39914
net.load_state_dict(torch.load('./data/checkpoint/checkpoint-0040.bin', map_location=torch.device('cpu')))
net.eval()
Network(
  (conv_block): Sequential(
    (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (res_blocks): Sequential(
    (0): ResBlock(
      (conv_lower): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv_upper): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (se_block): SEBlock(
        (fc): Sequential(
          (0): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): ReLU()
          (2): Conv2d(8, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
      )
    )
    (1): ResBlock(
      (conv_lower): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv_upper): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (se_block): SEBlock(
        (fc): Sequential(
          (0): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): ReLU()
          (2): Conv2d(8, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
      )
    )
    (2): ResBlock(
      (conv_lower): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv_upper): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (se_block): SEBlock(
        (fc): Sequential(
          (0): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): ReLU()
          (2): Conv2d(8, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
      )
    )
    (3): ResBlock(
      (conv_lower): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv_upper): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (se_block): SEBlock(
        (fc): Sequential(
          (0): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): ReLU()
          (2): Conv2d(8, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
      )
    )
    (4): ResBlock(
      (conv_lower): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv_upper): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (se_block): SEBlock(
        (fc): Sequential(
          (0): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): ReLU()
          (2): Conv2d(8, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
      )
    )
    (5): ResBlock(
      (conv_lower): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv_upper): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (se_block): SEBlock(
        (fc): Sequential(
          (0): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): ReLU()
          (2): Conv2d(8, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
      )
    )
    (6): ResBlock(
      (conv_lower): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv_upper): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (se_block): SEBlock(
        (fc): Sequential(
          (0): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): ReLU()
          (2): Conv2d(8, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
      )
    )
    (7): ResBlock(
      (conv_lower): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv_upper): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (se_block): SEBlock(
        (fc): Sequential(
          (0): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): ReLU()
          (2): Conv2d(8, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
      )
    )
    (8): ResBlock(
      (conv_lower): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv_upper): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (se_block): SEBlock(
        (fc): Sequential(
          (0): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): ReLU()
          (2): Conv2d(8, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
      )
    )
  )
  (out_conv): Sequential(
    (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (fc): Linear(in_features=128, out_features=10, bias=True)
)
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, f1_score
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline
sns.set()
with torch.no_grad():  # inference only
    for images, labels in test_loader:
        pred = torch.argmax(net(images), axis=1)
        # Note: sklearn expects (y_true, y_pred); passing pred first transposes the
        # matrix, while accuracy and the micro-averaged scores are symmetric anyway.
        print('confusion_matrix: \n', confusion_matrix(pred, labels))
        print('accuracy_score:', accuracy_score(pred, labels))
        print('precision_score:', precision_score(pred, labels, average='micro'))
        print('f1-score:', f1_score(pred, labels, average='micro'))
        break  # evaluate a single batch only
confusion_matrix:
[[11 0 0 0 0 0 0 0 1 0]
[ 0 9 0 0 0 0 0 0 0 0]
[ 0 0 10 0 1 2 0 0 0 0]
[ 0 0 1 11 0 0 0 1 0 0]
[ 0 0 0 1 9 0 0 0 0 0]
[ 0 0 0 1 0 7 1 0 0 0]
[ 0 0 0 0 0 0 18 0 0 0]
[ 1 0 0 2 0 0 0 12 0 0]
[ 1 0 0 0 0 0 0 0 16 0]
[ 0 1 0 0 0 0 0 0 0 11]]
accuracy_score: 0.890625
precision_score: 0.890625
f1-score: 0.890625
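It is no coincidence that accuracy, micro-averaged precision, and the micro F1-score all equal 0.890625: in single-label multiclass classification, micro-averaging counts every prediction exactly once, so all three metrics reduce to the fraction of correctly classified samples.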
pred
tensor([3, 8, 8, 8, 6, 6, 1, 6, 3, 1, 0, 9, 5, 7, 9, 8, 5, 7, 0, 6, 7, 0, 4, 9,
2, 2, 4, 0, 9, 6, 6, 5, 4, 5, 9, 3, 4, 9, 9, 5, 4, 6, 5, 6, 0, 9, 4, 9,
7, 6, 9, 8, 7, 3, 8, 8, 7, 3, 2, 5, 7, 5, 6, 3, 6, 2, 1, 2, 7, 7, 2, 6,
8, 8, 0, 2, 9, 3, 7, 8, 8, 1, 1, 7, 2, 2, 2, 7, 8, 9, 0, 3, 8, 6, 4, 6,
6, 0, 0, 7, 4, 5, 6, 3, 1, 1, 3, 6, 8, 7, 4, 0, 6, 2, 1, 3, 0, 4, 2, 7,
8, 3, 1, 2, 8, 0, 8, 3])
conf_mat = confusion_matrix(labels, pred)
conf_mat
array([[11, 0, 0, 0, 0, 0, 0, 1, 1, 0],
[ 0, 9, 0, 0, 0, 0, 0, 0, 0, 1],
[ 0, 0, 10, 1, 0, 0, 0, 0, 0, 0],
[ 0, 0, 0, 11, 1, 1, 0, 2, 0, 0],
[ 0, 0, 1, 0, 9, 0, 0, 0, 0, 0],
[ 0, 0, 2, 0, 0, 7, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0, 1, 18, 0, 0, 0],
[ 0, 0, 0, 1, 0, 0, 0, 12, 0, 0],
[ 1, 0, 0, 0, 0, 0, 0, 0, 16, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 11]], dtype=int64)
df = pd.DataFrame(conf_mat, index=test_dataset.classes, columns=test_dataset.classes)
df
|            | airplane | automobile | bird | cat | deer | dog | frog | horse | ship | truck |
|------------|----------|------------|------|-----|------|-----|------|-------|------|-------|
| airplane   | 11 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1 | 0 |
| automobile | 0 | 9 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 |
| bird       | 0 | 0 | 10 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| cat        | 0 | 0 | 0 | 11 | 1 | 1 | 0 | 2 | 0 | 0 |
| deer       | 0 | 0 | 1 | 0 | 9 | 0 | 0 | 0 | 0 | 0 |
| dog        | 0 | 0 | 2 | 0 | 0 | 7 | 0 | 0 | 0 | 0 |
| frog       | 0 | 0 | 0 | 0 | 0 | 1 | 18 | 0 | 0 | 0 |
| horse      | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 12 | 0 | 0 |
| ship       | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 16 | 0 |
| truck      | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 11 |
# Plot the confusion matrix
plt.figure(figsize=(12, 12))
plt.rcParams['font.sans-serif'] = ['SimHei']  # CJK-capable font; optional, since the labels here are English
sns.heatmap(df, annot=True, cbar=None, cmap="Blues")
plt.title("Confusion Matrix")
plt.ylabel("True Class")
plt.xlabel("Predicted Class")
plt.show()
Code for this article: madao33/computer-vision-learning
Personal blog: madao33 blog