当前位置:   article > 正文

细粒度识别 DCL 论文及代码学习笔记_细粒度目标识别

细粒度目标识别

论文部分

论文链接CVPR 2019 Open Access Repository

动机

在过去几年,通用目标识别在大规模注释数据集和复杂模型的帮助下取得了重大进展。然而,识别诸如鸟类和汽车等细粒度的目标类别仍然是一项具有挑战性的任务。粗略一看,细粒度的目标看起来是相似的,但它们可以通过有区别的局部细节来进行区分。
从具有区分度的目标部分学习判别特征表示在细粒度图像识别中起着关键作用。现有的细粒度识别方法可分为两类 (如下图所示):1)首先定位可区分的目标部分,然后对可区分区域进行分类。这类方法需要对目标或目标部分进行额外的边界框注释,成本高;2)通过注意力机制以无监督的方式自动定位判别区域,因此不需要额外的注释。然而,这些方法通常需要额外网络结构(如注意力机制)的辅助,从而增加了计算开销。

创新点

提出了“破坏和构建学习 (Destruction and Construction Learning,DCL)”框架用于细粒度识别。 对于 destruction,区域混淆机制 (region confusion mechanism,RCM) 迫使分类网络从判别区域中学习,对抗性损失可防止网络过拟合 RCM 引起的噪声。对于 construction,区域对齐网络通过对区域之间的语义相关性进行建模来恢复原始区域布局。此外,DCL 不需要额外的注释以及辅助网络。

方法论

DCL 的整个框架如下图所示,由四部分组成 (RCM、分类网络、对抗学习网络、区域对齐网络),但在推理阶段只需要使用分类网络即可。

(1) Destruction Learning

对于细粒度的图像识别,局部细节比全局信息更重要。在大多数情况下,不同的细粒度类别通常具有相似的全局结构,仅在某些局部细节上有所不同。因此作者提出了 RCM,以便更好地识别判别区域和学习判别特征。为了防止网络学习从破坏学习中引入的噪声模式,作者还提出了对抗学习来拒绝与细粒度分类无关的 RCM-induced 模式。

RCM:如下图所示,RCM 可以破坏局部图像区域的空间布局。给定输入图像 I,首先将图像均分成 N×N 个区域,记为 R_{i,j},其中 i 和 j 分别是水平和垂直的索引。对于 R 的第 j 行,先生成一个大小为 N 的向量 q_{j},其中第 i 个元素以下面的规则进行移动: q_{j,i}=i+r,r 服从均匀分布 U(-k,k),k 的范围是 [1,N)。列元素移动同理。移动行和列的元素后可以得到一个新的区域组合:

将原始图像 I,打乱后的图像 \phi(I) 以及标签 l 组成一个三元组 (I,\phi (i),l) 作为训练的输入。分类网络将输入图片映射为一个概率分布向量 C(I,\theta _{cls})\theta _{cls} 是网络中可训练的参数。分类网络的损失函数如下:

由于全局结构已被破坏,为了识别这些随机打乱的图像,分类网络必须找到判别区域并学习它们之间的差异。

对抗学习:由 RCM 打乱的图片不一定都有助于细粒度识别,RCM 会引入噪声,作者因此提出对抗学习。使用 one-hot 码标记每张图片,0 代表打乱后的图片,1 代表原图。作者在框架中添加一个判别器作为一个新的分支来判断图像 I 是否被破坏,公式如下:

其中,C(I,\theta _{cls}^{[1,m]}) 表示从骨干网络中第 m 层的输出中提取的特征向量,另一个参数为判别器中的可训练的参数。该网络的损失函数如下:

(2) Construction Learning

作者提出了提出了一个带有区域重构损失的区域对齐网络,它通过度量图像中不同区域的定位精度来诱导分类网络对区域间的语义关联进行建模。

公式如下:

 (3) Destruction and Construction Learning

在 DCL 框架中,分类、对抗和区域对齐损失以端到端的方式进行训练,其中网络可以利用增强的局部细节和良好建模的目标部分相关性来进行细粒度识别。具体来说,希望最小化以下目标:

破坏学习主要帮助模型从判别区域学习,而构建学习根据区域之间的语义相关性重新排列学习到的局部细节。因此,DCL 基于来自判别区域的细节特征产生了一组复杂多样的视觉表示。

实验结果

 

作者提出了 DCL 框架来进行细粒度的图片识别。其中的 destruction 部分提高了网络学习判别区域特征的能力,construction 部分构建了各部分的空间语义关联信息,模型不再需要额外的监督信息即可端到端训练。此外,模型参数较小,容易训练和应用。

代码

代码链接:GitHub - JDAI-CV/DCL: Destruction and Construction Learning for Fine-grained Image Recognition

环境配置

环境配置参考GitHub,

  1. conda create --name DCL file conda_list.txt
  2. pip install pretrainedmodels

克隆下来后,把在imagenet上预训练好的模型放到./models/pretrained目录下,预训练模型的下载链接:https://download.pytorch.org/models/resnet50-19c8e357.pth

接下来就是数据集下载,可以参考:细粒度数据集:CUB-200-2011 CUB,百度云下载_画外人易朽的博客-CSDN博客_cub数据集下载

下载好数据集后,为了和代码中的路径对齐,可以把全部图片copy到data文件夹中,参考代码如下:

  1. import os
  2. import sys
  3. import shutil
  4. dir = r'D:\Projects\Data_Augmentation\CUB_200_2011\images'
  5. for i in os.listdir(dir):
  6. path1 = os.path.join(dir, i)
  7. for j in os.listdir(path1):
  8. path2 = os.path.join(path1, j)
  9. # print(path2)
  10. # sys.exit()
  11. shutil.copy(path2, os.path.join(r'D:\Projects\DCL\datasets\CUB\data', j))

由于克隆下来的项目中只有train.txt文件而缺少ct_train.txt等文件,所以我就直接读取train.txt文件,并对源码做了修改,

首先是配置脚本的修改,主要修改路径位置,修改后的config.py脚本如下:

  1. import os
  2. import pandas as pd
  3. import torch
  4. from transforms import transforms
  5. from utils.autoaugment import ImageNetPolicy
  6. # pretrained model checkpoints
  7. pretrained_model = {'resnet50' : './models/pretrained/resnet50-19c8e357.pth',}
  8. # transforms dict
  9. def load_data_transformers(resize_reso=512, crop_reso=448, swap_num=[7, 7]):
  10. center_resize = 600
  11. Normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  12. data_transforms = {
  13. 'swap': transforms.Compose([
  14. transforms.Randomswap((swap_num[0], swap_num[1])),
  15. ]),
  16. 'common_aug': transforms.Compose([
  17. transforms.Resize((resize_reso, resize_reso)),
  18. transforms.RandomRotation(degrees=15),
  19. transforms.RandomCrop((crop_reso,crop_reso)),
  20. transforms.RandomHorizontalFlip(),
  21. ]),
  22. 'train_totensor': transforms.Compose([
  23. transforms.Resize((crop_reso, crop_reso)),
  24. # ImageNetPolicy(),
  25. transforms.ToTensor(),
  26. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
  27. ]),
  28. 'val_totensor': transforms.Compose([
  29. transforms.Resize((crop_reso, crop_reso)),
  30. transforms.ToTensor(),
  31. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
  32. ]),
  33. 'test_totensor': transforms.Compose([
  34. transforms.Resize((resize_reso, resize_reso)),
  35. transforms.CenterCrop((crop_reso, crop_reso)),
  36. transforms.ToTensor(),
  37. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
  38. ]),
  39. 'None': None,
  40. }
  41. return data_transforms
  42. class LoadConfig(object):
  43. def __init__(self, args, version):
  44. if version == 'train':
  45. get_list = ['train', 'val']
  46. elif version == 'val':
  47. get_list = ['val']
  48. elif version == 'test':
  49. get_list = ['test']
  50. else:
  51. raise Exception("train/val/test ???\n")
  52. ###############################
  53. #### add dataset info here ####
  54. ###############################
  55. # put image data in $PATH/data
  56. # put annotation txt file in $PATH/anno
  57. if args.dataset == 'product':
  58. self.dataset = args.dataset
  59. self.rawdata_root = './../FGVC_product/data'
  60. self.anno_root = './../FGVC_product/anno'
  61. self.numcls = 2019
  62. elif args.dataset == 'CUB':
  63. self.dataset = args.dataset
  64. self.rawdata_root = '/root/projects/DCL/datasets/CUB/data'
  65. self.anno_root = '/root/projects/DCL/datasets/CUB/anno'
  66. self.numcls = 200
  67. elif args.dataset == 'STCAR':
  68. self.dataset = args.dataset
  69. self.rawdata_root = './dataset/st_car/data'
  70. self.anno_root = './dataset/st_car/anno'
  71. self.numcls = 196
  72. elif args.dataset == 'AIR':
  73. self.dataset = args.dataset
  74. self.rawdata_root = './dataset/aircraft/data'
  75. self.anno_root = './dataset/aircraft/anno'
  76. self.numcls = 100
  77. else:
  78. raise Exception('dataset not defined ???')
  79. # annotation file organized as :
  80. # path/image_name cls_num\n
  81. if 'train' in get_list:
  82. self.train_anno = '/root/projects/DCL/datasets/CUB/anno/train.txt'
  83. if 'val' in get_list:
  84. self.val_anno = '/root/projects/DCL/datasets/CUB/anno/test.txt'
  85. if 'test' in get_list:
  86. self.test_anno = '/root/projects/DCL/datasets/CUB/anno/test.txt'
  87. self.swap_num = args.swap_num
  88. self.save_dir = './net_model'
  89. if not os.path.exists(self.save_dir):
  90. os.mkdir(self.save_dir)
  91. self.backbone = args.backbone
  92. self.use_dcl = True
  93. self.use_backbone = False if self.use_dcl else True
  94. self.use_Asoftmax = False
  95. self.use_focal_loss = False
  96. self.use_fpn = False
  97. self.use_hier = False
  98. self.weighted_sample = False
  99. self.cls_2 = True
  100. self.cls_2xmul = False
  101. self.log_folder = './logs'
  102. if not os.path.exists(self.log_folder):
  103. os.mkdir(self.log_folder)

接下来就对数据读取脚本进行修改,主要修改数据集的读取方式,修改后的dataset_DCL.py脚本如下:

  1. # coding=utf8
  2. from __future__ import division
  3. import os
  4. import torch
  5. import torch.utils.data as data
  6. import pandas
  7. import random
  8. import PIL.Image as Image
  9. from PIL import ImageStat
  10. import sys
  11. import pdb
  12. def random_sample(img_names, labels):
  13. anno_dict = {}
  14. img_list = []
  15. anno_list = []
  16. for img, anno in zip(img_names, labels):
  17. if not anno in anno_dict:
  18. anno_dict[anno] = [img]
  19. else:
  20. anno_dict[anno].append(img)
  21. for anno in anno_dict.keys():
  22. anno_len = len(anno_dict[anno])
  23. fetch_keys = random.sample(list(range(anno_len)), anno_len//10)
  24. img_list.extend([anno_dict[anno][x] for x in fetch_keys])
  25. anno_list.extend([anno for x in fetch_keys])
  26. return img_list, anno_list
  27. class dataset(data.Dataset):
  28. def __init__(self, Config, anno, swap_size=[7,7], common_aug=None, swap=None, totensor=None, train=False, train_val=False, test=False):
  29. self.root_path = Config.rawdata_root
  30. self.numcls = Config.numcls
  31. self.dataset = Config.dataset
  32. self.use_cls_2 = Config.cls_2
  33. self.use_cls_mul = Config.cls_2xmul
  34. # if isinstance(anno, pandas.core.frame.DataFrame):
  35. # self.paths = anno['ImageName'].tolist()
  36. # self.labels = anno['label'].tolist()
  37. # elif isinstance(anno, dict):
  38. # self.paths = anno['img_name']
  39. # self.labels = anno['label']
  40. f = open(anno)
  41. self.data_lists = f.readlines()
  42. # if train_val:
  43. # self.paths, self.labels = random_sample(self.paths, self.labels)
  44. self.common_aug = common_aug
  45. self.swap = swap
  46. self.totensor = totensor
  47. self.cfg = Config
  48. self.train = train
  49. self.swap_size = swap_size
  50. self.test = test
  51. def __len__(self):
  52. return len(self.data_lists)
  53. def __getitem__(self, item):
  54. # print(self.data_lists)
  55. # sys.exit()
  56. img_label = self.data_lists[item].strip('\n').split(' ')
  57. img_path = os.path.join(self.root_path, img_label[0])
  58. img = self.pil_loader(img_path)
  59. if self.test:
  60. img = self.totensor(img)
  61. label = int(img_label[1]) - 1
  62. return img, label, img_label[0]
  63. img_unswap = self.common_aug(img) if not self.common_aug is None else img
  64. image_unswap_list = self.crop_image(img_unswap, self.swap_size)
  65. swap_range = self.swap_size[0] * self.swap_size[1]
  66. swap_law1 = [(i-(swap_range//2))/swap_range for i in range(swap_range)]
  67. if self.train:
  68. img_swap = self.swap(img_unswap)
  69. image_swap_list = self.crop_image(img_swap, self.swap_size)
  70. unswap_stats = [sum(ImageStat.Stat(im).mean) for im in image_unswap_list]
  71. swap_stats = [sum(ImageStat.Stat(im).mean) for im in image_swap_list]
  72. swap_law2 = []
  73. for swap_im in swap_stats:
  74. distance = [abs(swap_im - unswap_im) for unswap_im in unswap_stats]
  75. index = distance.index(min(distance))
  76. swap_law2.append((index-(swap_range//2))/swap_range)
  77. img_swap = self.totensor(img_swap)
  78. # one-hot编码从0开始
  79. label = int(img_label[1]) - 1
  80. if self.use_cls_mul:
  81. label_swap = label + self.numcls
  82. if self.use_cls_2:
  83. label_swap = -1
  84. img_unswap = self.totensor(img_unswap)
  85. return img_unswap, img_swap, label, label_swap, swap_law1, swap_law2, img_label[0]
  86. else:
  87. label = int(img_label[1]) - 1
  88. swap_law2 = [(i-(swap_range//2))/swap_range for i in range(swap_range)]
  89. label_swap = label
  90. img_unswap = self.totensor(img_unswap)
  91. return img_unswap, label, label_swap, swap_law1, swap_law2, img_label[0]
  92. def pil_loader(self,imgpath):
  93. with open(imgpath, 'rb') as f:
  94. with Image.open(f) as img:
  95. return img.convert('RGB')
  96. def crop_image(self, image, cropnum):
  97. width, high = image.size
  98. crop_x = [int((width / cropnum[0]) * i) for i in range(cropnum[0] + 1)]
  99. crop_y = [int((high / cropnum[1]) * i) for i in range(cropnum[1] + 1)]
  100. im_list = []
  101. for j in range(len(crop_y) - 1):
  102. for i in range(len(crop_x) - 1):
  103. im_list.append(image.crop((crop_x[i], crop_y[j], min(crop_x[i + 1], width), min(crop_y[j + 1], high))))
  104. return im_list
  105. def get_weighted_sampler(self):
  106. img_nums = len(self.data_lists)
  107. l = []
  108. for i in range(img_nums):
  109. l.append(int(self.data_lists[i].strip('\n').split(' ')[-1]))
  110. weights = [l.count(x) for x in range(self.numcls)]
  111. return torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples=img_nums)
  112. def collate_fn4train(batch):
  113. imgs = []
  114. label = []
  115. label_swap = []
  116. law_swap = []
  117. img_name = []
  118. for sample in batch:
  119. imgs.append(sample[0])
  120. imgs.append(sample[1])
  121. label.append(sample[2])
  122. label.append(sample[2])
  123. if sample[3] == -1:
  124. label_swap.append(1)
  125. label_swap.append(0)
  126. else:
  127. label_swap.append(sample[2])
  128. label_swap.append(sample[3])
  129. law_swap.append(sample[4])
  130. law_swap.append(sample[5])
  131. img_name.append(sample[-1])
  132. return torch.stack(imgs, 0), label, label_swap, law_swap, img_name
  133. def collate_fn4val(batch):
  134. imgs = []
  135. label = []
  136. label_swap = []
  137. law_swap = []
  138. img_name = []
  139. for sample in batch:
  140. imgs.append(sample[0])
  141. label.append(sample[1])
  142. if sample[3] == -1:
  143. label_swap.append(1)
  144. else:
  145. label_swap.append(sample[2])
  146. law_swap.append(sample[3])
  147. img_name.append(sample[-1])
  148. return torch.stack(imgs, 0), label, label_swap, law_swap, img_name
  149. def collate_fn4backbone(batch):
  150. imgs = []
  151. label = []
  152. img_name = []
  153. for sample in batch:
  154. imgs.append(sample[0])
  155. if len(sample) == 7:
  156. label.append(sample[2])
  157. else:
  158. label.append(sample[1])
  159. img_name.append(sample[-1])
  160. return torch.stack(imgs, 0), label, img_name
  161. def collate_fn4test(batch):
  162. imgs = []
  163. label = []
  164. img_name = []
  165. for sample in batch:
  166. imgs.append(sample[0])
  167. label.append(sample[1])
  168. img_name.append(sample[-1])
  169. return torch.stack(imgs, 0), label, img_name

接下里便可进行训练和测试了,

python train.py --tb 16 --tnw 16 --vb 16 --vnw 16 --detail training_descibe --cls_mul --swap_num 7 7

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/311800
推荐阅读
相关标签
  

闽ICP备14008679号