当前位置:   article > 正文

Saliency maps

saliency maps 博客

Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps

问题

这篇文章和ZFnet相似,旨在研究网络可视化的问题,根据分裂网络最后的向量来反推出最原始的图像,如果假设输入(input)是I, 而输入图像对应的标签是c, 而分类器的得分是Sc(I)(也就是第c个分量),那么我们希望找到一个I使得Sc(I)足够大,说明这个输入很有可能是这个类的:

argmaxISc(I),

不过,论文实际上是研究下面的问题:
argmaxISc(I)λI22.

其实就是加了一个正则化项,我想这应该是处于实际角度出发的,因为在处理图像的时候往往有一个Normlize的过程,所以如果I太过“巨大”那肯定是不合适的——起码它都不能称为一个图像.

细节

变量

需要注意的是,上面的问题是关于I,也就是图像来说的,如果有k个类,那么理论上应该有k张对应的图像(同一个λ).

然后论文的结果是这样的:

在这里插入图片描述

我的结果是这样的(CIFAR10):
在这里插入图片描述
相差甚远, 是λ=0.1不合适?

Sc(I)

需要一提的是,这个Sc(I)不是sigmoid后的值,而是之前的分数,作者是这么解释的,因为sigmoid:

Pc=Sccexp(Sc),

我们的目的是提高Sc,而如果是Pc, 那么我们可以通过降低别的Sc来间接提高Pc,而非提高Sc, 有点道理吧,试了一下,在原来的参数条件下几乎不学习了...

扩展

作者提到这个方案可以用于定位, 首先要说明的是,通过这种方法,我们可以“定位”(虽然可能是臆想)敏感地带.

输入一张图片,计算

Sc(I)I,

结果是一个“矩阵”(张量?), 其中的元素的绝对值大小可以衡量对类别判断的重要,即越大越是敏感地带.

在这里插入图片描述
那个简单例子,感觉没能和好的说服我. 如果网络就是一个线性判别器,那么照此思路,其敏感程度就是权重,直观上这样似乎如此,但是感觉就像是抛开了数据本身...但的确是有道理的. 还有一个问题是,对于一张图片,如果它被误判了, 那么是选择其本身的标签,还是网络所判断的那个c呢?
在我的实验中,二者似乎没有太大的差别.

回到定位的话题,计算出梯度的矩阵后,如果有C个通道,C个通道的每个元素的绝对值的最大作为那个位置的敏感程度,如此,如果图片是(C,H,W), 那么最后会得到一个(1,H,W)的矩阵,其中的元素则反应了敏感程度.

但是,其中的敏感程度指示反应了物体所在的大概位置,作者说还要通过一种颜色的连续来更为细致地框定范围,那种技术我不知道,就简单地做个实验:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

细看,我觉得还是有那么点感觉的.

代码

I的时候,不知道怎么利用已有的梯度方法,就自己写了一个. 网络的测试成功率为60%,因为是一个比较简单的网络,大的网络实在难以下手.

  1. import torch
  2. import torchvision
  3. import torchvision.transforms as transforms
  4. import torch.nn as nn
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. class Net(nn.Module):
  8. def __init__(self, num):
  9. super(Net, self).__init__()
  10. self.conv = nn.Sequential(
  11. nn.Conv2d(3, 16, 4, 2), #3x32x32 --> 8x15x15
  12. nn.ReLU(),
  13. nn.MaxPool2d(2, 2), # 15 --> 7
  14. nn.Conv2d(16, 64, 3, 1, 1), #16x7x7 --> 64x7x7
  15. nn.ReLU(),
  16. nn.MaxPool2d(2, 1) #7-->6
  17. )
  18. self.dense = nn.Sequential(
  19. nn.Linear(64 * 6 * 6, 256),
  20. nn.ReLU(),
  21. nn.Linear(256, num)
  22. )
  23. def forward(self, x):
  24. x = self.conv(x)
  25. x = x.view(x.size(0), -1)
  26. out = self.dense(x)
  27. return out
  28. class SGD:
  29. def __init__(self, lr=1e-3, momentum=0.9):
  30. self.v = 0
  31. self.lr = lr
  32. self.momentum = momentum
  33. def step(self, x, grad):
  34. self.v = self.momentum * self.v + grad
  35. return x + self.lr * self.v
  36. class Train:
  37. def __init__(self, trainset, num=10, lr=1e-4, momentum=0.9,loss_function=nn.CrossEntropyLoss()):
  38. self.net = Net(num)
  39. self.trainset = trainset
  40. self.criterion = loss_function
  41. self.opti = torch.optim.SGD(self.net.parameters(), lr=lr, momentum=momentum)
  42. def trainnet(self, iterations, path):
  43. running_loss = 0.0
  44. for epoch in range(iterations):
  45. for i, data in enumerate(self.trainset):
  46. imgs, labels = data
  47. output = self.net(imgs)
  48. loss = self.criterion(output, labels)
  49. self.opti.zero_grad()
  50. loss.backward()
  51. self.opti.step()
  52. running_loss += loss
  53. if i % 10 == 9:
  54. print("[epoch: {} loss: {:.7f}]".format(
  55. epoch,
  56. running_loss / 10
  57. ))
  58. running_loss = 0.0
  59. torch.save(self.net.state_dict(), path)
  60. def loading(self, path):
  61. self.net.load_state_dict(torch.load(path))
  62. self.net.eval()
  63. def visual(self, iterations=100, digit=0, gamma=0.1, lr=1e-3, momentum=0.9):
  64. def criterion(out, x, digit, gamma=0.1):
  65. return out[0][digit] - gamma * torch.norm(x, 2) ** 2
  66. opti = SGD(lr, momentum)
  67. x = torch.zeros((1, 3, 32, 32), requires_grad=True, dtype=torch.float)
  68. for i in range(iterations):
  69. output = self.net(x)
  70. loss = criterion(output, x, digit, gamma)
  71. print(loss.item())
  72. loss.backward()
  73. x = torch.tensor(opti.step(x, x.grad), requires_grad=True)
  74. img = x[0].detach()
  75. img = img / 2 + 0.5
  76. img = img / torch.max(img.abs())
  77. img = np.transpose(img, (1, 2, 0))
  78. print(img[0])
  79. plt.imshow(img)
  80. plt.title(classes[digit])
  81. plt.show()
  82. return x
  83. def local(self, img, label):
  84. cimg = img.view(1, 3, 32, 32).detach()
  85. cimg.requires_grad = True
  86. output = self.net(cimg)
  87. print(output)
  88. print(label)
  89. s = output[0][label]
  90. s.backward()
  91. with torch.no_grad():
  92. grad = cimg.grad.data[0]
  93. graph = torch.max(torch.abs(grad), 0)[0]
  94. saliency = graph.detach().numpy()
  95. print(np.max(saliency))
  96. img = img.detach().numpy()
  97. img = img / 2 + 0.5
  98. img = np.transpose(img, (1, 2, 0))
  99. fig, ax = plt.subplots(1, 2)
  100. ax[0].set_title(classes[label])
  101. ax[0].imshow(img)
  102. ax[1].imshow(saliency, cmap=plt.cm.hot)
  103. plt.show()
  104. def testing(self, testloader):
  105. correct = 0
  106. total = 0
  107. with torch.no_grad():
  108. for data in testloader:
  109. images, labels = data
  110. outputs = self.net(images)
  111. _, predicted = torch.max(outputs.data, 1)
  112. total += labels.size(0)
  113. correct += (predicted == labels).sum().item()
  114. print('Accuracy of the network on the 10000 test images: %d %%' % (
  115. 100 * correct / total))
  116. root = "C:/Users/pkavs/1jupiterdata/data"
  117. #准备训练集
  118. trainset = torchvision.datasets.CIFAR10(root=root, train=True,
  119. download=False,
  120. transform=transforms.Compose(
  121. [transforms.ToTensor(),
  122. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
  123. ))
  124. train_loader = torch.utils.data.DataLoader(trainset, batch_size=64,
  125. shuffle=True, num_workers=0)
  126. testset = torchvision.datasets.CIFAR10(root=root, train=False,
  127. download=False,
  128. transform=transforms.Compose(
  129. [transforms.ToTensor(),
  130. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
  131. ))
  132. testloader = torch.utils.data.DataLoader(testset, batch_size=64,
  133. shuffle=False, num_workers=0)
  134. classes = ('plane', 'car', 'bird', 'cat',
  135. 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  136. path = root + "/visual1.pt"
  137. test = Train(train_loader, lr=1e-4)
  138. test.loading(path)
  139. #test.testing(testloader) 60%
  140. data = next(iter(train_loader))
  141. imgs, labels = data
  142. img = imgs[0]
  143. label = labels[0]
  144. test.local(img, label)
  145. #test.visual(1000, digit=3)

转载于:https://www.cnblogs.com/MTandHJ/p/11355180.html

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/秋刀鱼在做梦/article/detail/1010579
推荐阅读
相关标签
  

闽ICP备14008679号