赞
踩
该篇论文主要有两个任务,第一个任务是提供了一个large-scale数据集, 包含80万张图片,包含不同角度,不同场景,买家秀,买家秀图片。数据集分为4个子数据集,用于4个主要任务,分别是服装类别与属性预测、店内衣服检索、买家秀卖家秀衣服检索、关键点和目标检测。第二个任务是提出了一个深度学习网络FashionNet,该博客主要用于记录fashionnet部分的内容,具体deepfashion数据集介绍请见其它博客。
1.服装间由于款式、质地和裁剪方式的差异,会导致模型识别服装结果发生混淆,比如同一服装间的裁剪方式不一致导致模型认为不是同一服装,或不同服装由于款式,质地十分相近会使得模型认为是同一服装。
2.服装在展示时发生形变或部分被遮挡是不可避免的,这给模型在识别效果上制造了困难。
3.在不同场景下拍摄出来的服装图像通常表现出很严重的差异,例如:买家秀和买家秀的不同。
1.提出大型时装数据集DeepFashion
2.提出了FashionNet 进行DeepFashion 数据集的衣服属性的预测和分类
3.定义了多种任务的评价标准
如上图,文章认为在衣服识别中,若一个数据集同时拥有类别、属性和关键点的标签,会得到两个优势,第一个是给定关键点的位置有助于服装识别精度的提高,第二个是服装的大量属性的引入导致了服装的特征空间能够得到更好的划分,从而能够促进跨域服装图像的识别和检索。
FashionNet 建立了三个benchmark,以供后续研究作为效果评价的基线。它们分别是:服装特征分类、店内服装检索(网站中同一商品展示图的匹配),和跨领域的服务检索(买家秀卖家秀衣服检索)。具体代码可以在mmfashion中查看。
以下内容来自:DeepFashion服装检索及代码实现
论文提到的FashionNet方法采用的网络backbone为VGG16,在VGG16网络中,stage4之前的网络结构权值都是共享的,在这一基础上发展了三个网络分支,如上图所示。上图的橙色部分、绿色部分和蓝色部分分别代表了全局特征提取网络、局部特征提取网络和特征点回归网络。
- 橙色部分(全局分支):基于共享的VGG16前4个stage,增加了自己的conv5卷积结构,然后接一个全连接层用于特征编码,输出特征作为最终检索特征的一部分。
- 蓝色部分(关键点回归分支):与VGG16网络结构类似,在stage5卷积结构后面使用两个全连接层进行特征编码,根据输出的特征分别再使用两个全连接层输出服装关键点和关键点是否可见的标签。
- 绿色部分(局部分支):借助于蓝色分支输出的服装关键点,论文作者提出了一个新的层,landmark pooling layer,与Faster RCNN中的ROI Pooling层类似,目的在于对提取特征图上的某一块区域进行统一的编码工作。这部分的特征称之为局部特征,将橙色部分和绿色部分的特征按通道组合在一起之后经过一个全连接层再编码一次,然后基于这个特征对属性、类别标签进行分类学习,同时使用triplet loss进行辅助学习。
训练过程中,首先加大蓝色模块的权值,即增大 L l a n d m a r k s L_{landmarks} Llandmarks和 L v i s i b i l i t y L_{visibility} Lvisibility的权重,其他权重保持相对较小,先把蓝色模块训练收敛,然后再减少权值,整体训练
1.landmark回归loss为加权的L2 loss,其中 V j V_j Vj为其权值,代表landmark的可见性,对于不可见的就不进行梯度的回传。
L l a n d m a r k s = ∑ j = 1 ∣ D ∣ ∣ ∣ V j ⋅ ( l ^ j − l j ) ∣ ∣ 2 2 L_{landmarks} =Llandmarks=∑j=1∣D∣∣∣Vj⋅(l^j−lj)∣∣22∑|D|j=1||Vj⋅(l^j−lj)||22
2.衣服类别分类和landmark是否可见分类,采用传统的softmax crossentrop loss, 表示为 L v i s i b i l i t y L_{visibility} Lvisibility 和 L c a t e g o r y L_{category} Lcategory
3.衣服属性分类采用加权的sigmoid crossentrop loss, X j X_j Xj代表第j个衣服, a j a_j aj代表第j个衣服的属性, W p o s W_{pos} Wpos和 W n e g W_{neg} Wneg代表正负样本的权值
L a t t r i b u t e s = ∑ j = 1 ∣ D ∣ ( w p o s ⋅ a j l o g p ( a j ∣ x j ) + w n e g ⋅ ( 1 − a j ) l o g ( 1 − p ( a j ∣ x j ) ) ) L_{attributes} = \sum_{j=1}^{|D|} (w_{pos} \cdot a_jlogp(a_j|x_j) + w_{neg} \cdot (1 - a_j)log(1-p(a_j|x_j))) Lattributes=j=1∑∣D∣(wpos⋅ajlogp(aj∣xj)+wneg⋅(1−aj)log(1−p(aj∣xj)))
4.衣服服装对的loss采用triplet loss。 ( x , x + , x − )(x,x+,x−) 表示三元组,m表示margin,d表示距离函数。(x,x+,x−)
L t r i p l e t = ∑ j = 1 ∣ D ∣ m a x { 0 , m + d ( x j , x j + ) − d ( x j , x j − ) } L_{triplet} = \sum_{j=1}^{|D|} maxLtriplet=j=1∑∣D∣max{0,m+d(xj,xj+)−d(xj,xj−)}{0,m+d(xj,x+j)−d(xj,x−j)}
对于一张图片来说,从最初的(224,224,3)经过vgg16的stage4输出一个(7,7,512)feature map, 这个feature map连接着2个分支,包括全局分支、局部分支;其中全局分支先AdaptiveAvgPool2d池化为(7,7,512),如果是用户自定义的图像输入的化,vgg输出的不一定是(7,7,512),然后将该feature map展平后通过几个全连接层得到图片的全局表示为(4096,)的维度;而对于局部分支,由于输入为feature map以及landmarks,经过仿射变换可以得到局部关键点向量,然后经过全连接层可得到局部向量表示(4096,);全局局部拼接在一起后可以得到最终的向量表示。
基本思路就是通过仿射变换将关键点放到最中间,同时根据一定比例放大,这样对原来的tensor进行卷积就相当于局部卷积了,推荐一个讲的比较好的博客,Pytorch中的仿射变换(affine_grid)
已知deepfashion数据集有四个子任务,分别是服装类别与属性预测、店内衣服检索、买家秀卖家秀衣服检索、关键点和目标检测。而以上的FashionNet为一个整体的框架,在观察mmfashion中的子任务的代码后发现,其实每个任务并不一定需要同时预测category、attributes、triplet、landmarks和landmark visibility,可以只用部分分支就可以进行预测,比如类别或属性预测没有必要使用triplet loss,同时我们数据集也可能没有服装对的数据。在mmfashion中可以看到,FashionNet的使用还需要对具体数据情况来实现模型。
PS:以下内容为mmfashion中买家秀卖家秀衣服检索模型的部分代码, 主要是便于本人理解和记忆。其他部分的训练过程和数据准备可参考GETTING_STARTED和DATA_PREPARATION
|---- Anno 标注文件文件夹
|-------- list_attr_cloth.txt 衣服款式中英文对照表
|-------- list_attr_items.txt 衣服属性标注,1000个属性
|-------- list_attr_type.txt 衣服属性类型中英文对照表
|-------- list_bbox_consumer2shop.txt 图片中服装的bbox框标注
|-------- list_item_consumer2shop.txt 图片中商品编号,每种商品可以用多个商品对
|-------- list_landmarks_consumer2shop.txt 图片中商品的landmarks标注
|---- Eval 数据集划分文件
|-------- list_eval_partition.txt 训练集测试集划分,商品对的形式
|---- Img 图片文件夹
数据准备运行代码为: python prepare_consumer_to_shop.py
PREFIX = 'Consumer_to_shop/Anno' def split_img(): fn = open('Consumer_to_shop/Eval/list_eval_partition.txt').readlines() # train dataset train_consumer2shop = open( os.path.join(PREFIX, 'train_consumer2shop.txt'), 'w') train_imgs = [] # test dataset test_consumer = open(os.path.join(PREFIX, 'consumer.txt'), 'w') test_shop = open(os.path.join(PREFIX, 'shop.txt'), 'w') consumer_imgs, shop_imgs = [], [] for i, line in enumerate(fn[2:]): aline = line.strip('\n').split() consumer, shop, _, cate = aline[0], aline[1], aline[2], aline[3] if cate == 'train': newline = consumer + ' ' + shop + '\n' train_consumer2shop.write(newline) train_imgs.append(consumer) train_imgs.append(shop) else: newline = consumer + '\n' test_consumer.write(newline) newline = shop + '\n' test_shop.write(newline) consumer_imgs.append(consumer) shop_imgs.append(shop) train_consumer2shop.close() test_consumer.close() test_shop.close() return train_imgs, consumer_imgs, shop_imgs
该部分获取训练集测试集的list,其中train_imgs为consumer图像和shop图像交替存储的list, consumer_imgs为测试集中consumer图像存储的list,shop_imgs为测试集合中shop图像存储的list,该函数split_img()还生成了对应与于train_imgs, consumer_imgs, shop_imgs的txt文件。
def split_bbox(train_set, consumer_set, shop_set): rf = open(os.path.join(PREFIX, 'list_bbox_consumer2shop.txt')).readlines() img2bbox = {} for i, line in enumerate(rf[2:]): aline = line.strip('\n').split() img = aline[0] bbox = aline[-4:] img2bbox[img] = bbox wf_train = open(os.path.join(PREFIX, 'list_bbox_train.txt'), 'w') wf_consumer = open(os.path.join(PREFIX, 'list_bbox_consumer.txt'), 'w') wf_shop = open(os.path.join(PREFIX, 'list_bbox_shop.txt'), 'w') for i, img in enumerate(train_set): bbox = img2bbox[img] newline = img + ' ' + bbox[0] + ' ' + bbox[1] + ' ' + bbox[ 2] + ' ' + bbox[3] + '\n' wf_train.write(newline) for i, img in enumerate(consumer_set): bbox = img2bbox[img] newline = img + ' ' + bbox[0] + ' ' + bbox[1] + ' ' + bbox[ 2] + ' ' + bbox[3] + '\n' wf_consumer.write(newline) for i, img in enumerate(shop_set): bbox = img2bbox[img] newline = img + ' ' + bbox[0] + ' ' + bbox[1] + ' ' + bbox[ 2] + ' ' + bbox[3] + '\n' wf_shop.write(newline) wf_train.close() wf_consumer.close() wf_shop.close()
该函数生成了对应train_imgs, consumer_imgs, shop_imgs图片的bbox文件,文件每一行为: img bbox[0] bbox[1] bbox[2] bbox[3]
def split_ids(train_set, consumer_set, shop_set): id2label = dict() rf = open(os.path.join(PREFIX, 'list_item_consumer2shop.txt')).readlines() for i, line in enumerate(rf[1:]): id2label[line.strip('\n')] = i def write_id(cloth, wf): for i, line in enumerate(cloth): id = line.strip('\n').split('/')[3] label = id2label[id] wf.write('%s\n' % str(label)) wf.close() train_id = open(os.path.join(PREFIX, 'train_id.txt'), 'w') consumer_id = open(os.path.join(PREFIX, 'consumer_id.txt'), 'w') shop_id = open(os.path.join(PREFIX, 'shop_id.txt'), 'w') write_id(train_set, train_id) write_id(consumer_set, consumer_id) write_id(shop_set, shop_id)
首先对于所有商品id进行编码,同时对于train_imgs, consumer_imgs, shop_imgs中元素对于进行编码并且生成相应的文件。
def split_lms(train_set, consumer_set, shop_set): rf = open(os.path.join(PREFIX, 'list_landmarks_consumer2shop.txt')).readlines() img2landmarks = {} for i, line in enumerate(rf[2:]): aline = line.strip('\n').split() img = aline[0] landmarks = aline[3:] img2landmarks[img] = landmarks wf_train = open(os.path.join(PREFIX, 'list_landmarks_train.txt'), 'w') wf_consumer = open( os.path.join(PREFIX, 'list_landmarks_consumer.txt'), 'w') wf_shop = open(os.path.join(PREFIX, 'list_landmarks_shop.txt'), 'w') def write_landmarks(img_set, wf): for i, img in enumerate(img_set): landmarks = img2landmarks[img] one_lms = [] for j, lm in enumerate(landmarks): if j % 3 == 0: # visibility if lm == '0': # visible one_lms.append(landmarks[j + 1]) one_lms.append(landmarks[j + 2]) else: # invisible or truncated one_lms.append('000') one_lms.append('000') while len(one_lms) < 16: # 8 pairs one_lms.append('000') wf.write(img) wf.write(' ') for lm in one_lms: wf.write(lm) wf.write(' ') wf.write('\n') wf.close() write_landmarks(train_set, wf_train) write_landmarks(consumer_set, wf_consumer) write_landmarks(shop_set, wf_shop)
处理照片对于的关键点,保存其信息,并且对于不可见的关键点坐标置为000
总结:所需要准备的数据有包括训练集的图片对,分开来的训练集的query和gallery图片,以及图片所对应的编码后的商品编码,以及bbox信息和landmark信息,以及其共同拥有的list_attr_items.txt文件,为图片属性信息。
训练代码:
python tools/train_retriever.py --config configs/retriever_in_shop/roi_retriever_vgg.py
模型框架结构代码:
from .. import builder from ..registry import RETRIEVER from .base import BaseRetriever @RETRIEVER.register_module class RoIRetriever(BaseRetriever): def __init__(self, backbone, global_pool, roi_pool, concat, embed_extractor, attr_predictor=None, pretrained=None): super(RoIRetriever, self).__init__() self.backbone = builder.build_backbone(backbone) self.global_pool = builder.build_global_pool(global_pool) if roi_pool is not None: self.roi_pool = builder.build_roi_pool(roi_pool) else: self.roi_pool = None self.concat = builder.build_concat(concat) self.embed_extractor = builder.build_embed_extractor(embed_extractor) if attr_predictor is not None: self.attr_predictor = builder.build_attr_predictor(attr_predictor) else: self.attr_predictor = None self.init_weights(pretrained=pretrained) def extract_feat(self, x, landmarks): x = self.backbone(x) global_x = self.global_pool(x) global_x = global_x.view(global_x.size(0), -1) if landmarks is not None: local_x = self.roi_pool(x, landmarks) else: local_x = None x = self.concat(global_x, local_x) return x def forward_train(self, anchor, id, attr=None, pos=None, neg=None, anchor_lm=None, pos_lm=None, neg_lm=None, triplet_pos_label=None, triplet_neg_label=None): losses = dict() # extract features anchor_feat = self.extract_feat(anchor, anchor_lm) if pos is not None: pos_feat = self.extract_feat(pos, pos_lm) neg_feat = self.extract_feat(neg, neg_lm) losses['loss_id'] = self.embed_extractor( anchor_feat, id, return_loss=True, triplet=True, pos=pos_feat, neg=neg_feat, triplet_pos_label=triplet_pos_label, triplet_neg_label=triplet_neg_label) else: losses['loss_id'] = self.embed_extractor( anchor_feat, id, return_loss=True) if self.attr_predictor is not None: losses['loss_attr'] = self.attr_predictor( anchor_feat, attr, return_loss=True) return losses def simple_test(self, x, landmarks=None): """Test single image""" x = x.unsqueeze(0) landmarks = landmarks.unsqueeze(0) feat = self.extract_feat(x, landmarks) embed = self.embed_extractor.forward_test(feat)[0] return embed def aug_test(self, x, landmarks=None): """Test batch of images""" feat = self.extract_feat(x, landmarks) embed = self.embed_extractor.forward_test(feat) return embed def init_weights(self, pretrained=None): super(RoIRetriever, self).init_weights(pretrained) self.backbone.init_weights(pretrained=pretrained) self.global_pool.init_weights() if self.roi_pool is not None: self.roi_pool.init_weights() self.concat.init_weights() self.embed_extractor.init_weights() if self.attr_predictor is not None: self.attr_predictor.init_weights()
以上为mmfashion中买家秀买家秀检索的模型框架代码,基本与deepfashion的论文框架是一致的,框架里的每个部件的具体内容可以看相应的config文件以及对应部件的代码,整个框架原理和mmdetection的框架调用是一致的。
数据准备运行代码为: python prepare_consumer_to_shop.py,则会得到以下内容
data = dict( imgs_per_gpu=8, workers_per_gpu=2, train=dict( type=dataset_type, img_path=os.path.join(data_root, 'Img'), img_file=os.path.join(data_root, 'Anno/train_consumer2shop.txt'), id_file=os.path.join(data_root, 'Anno/train_id.txt'), label_file=os.path.join(data_root, 'Anno/list_attr_items.txt'), bbox_file=os.path.join(data_root, 'Anno/list_bbox_train.txt'), landmark_file=os.path.join(data_root, 'Anno/list_landmarks_train.txt'), img_size=img_size, find_three=True), query=dict( type=dataset_type, img_path=os.path.join(data_root, 'Img'), img_file=os.path.join(data_root, 'Anno/consumer.txt'), id_file=os.path.join(data_root, 'Anno/consumer_id.txt'), label_file=os.path.join(data_root, 'Anno/list_attr_items.txt'), bbox_file=os.path.join(data_root, 'Anno/list_bbox_consumer.txt'), landmark_file=os.path.join(data_root, 'Anno/list_landmarks_consumer.txt'), img_size=img_size, find_three=True), gallery=dict( type=dataset_type, img_path=os.path.join(data_root, 'Img'), img_file=os.path.join(data_root, 'Anno/shop.txt'), id_file=os.path.join(data_root, 'Anno/shop_id.txt'), label_file=os.path.join(data_root, 'Anno/list_attr_items.txt'), bbox_file=os.path.join(data_root, 'Anno/list_bbox_shop.txt'), landmark_file=os.path.join(data_root, 'Anno/list_landmarks_shop.txt'), img_size=img_size, find_three=True))
以上代码是从mmfashion的configs/retriever_consumer_to_shop/roi_retriever_vgg.py中data部分的截取的,表示模型的输入。
1.修改配置文件mmfashion的configs/retriever_consumer_to_shop/roi_retriever_vgg.py,以及提前下载vgg16的权重模型。
2.运行代码训练模型: python tools/train_retriever.py --config configs/retriever_consumer_to_shop/roi_retriever_vgg.py,以下是整体模型结构
RoIRetriever( (backbone): Vgg( (features): Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU(inplace=True) (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): ReLU(inplace=True) (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (6): ReLU(inplace=True) (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (8): ReLU(inplace=True) (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (11): ReLU(inplace=True) (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (13): ReLU(inplace=True) (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (15): ReLU(inplace=True) (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (18): ReLU(inplace=True) (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (20): ReLU(inplace=True) (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (22): ReLU(inplace=True) (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (25): ReLU(inplace=True) (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (27): ReLU(inplace=True) (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (29): ReLU(inplace=True) (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) ) (global_pool): GlobalPooling( (avgpool): AdaptiveAvgPool2d(output_size=(7, 7)) (global_layers): Sequential( (0): Linear(in_features=25088, out_features=4096, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.5, inplace=False) (3): Linear(in_features=4096, out_features=4096, bias=True) (4): ReLU(inplace=True) (5): Dropout(p=0.5, inplace=False) ) ) (roi_pool): RoIPooling( (maxpool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False) (linear): Sequential( (0): Linear(in_features=4096, out_features=4096, bias=True) (1): ReLU(inplace=True) (2): Dropout(p=0.5, inplace=False) ) ) (concat): Concat( (fc_fusion): Linear(in_features=8192, out_features=4096, bias=True) ) (embed_extractor): EmbedExtractor( (embed_linear): Linear(in_features=4096, out_features=256, bias=True) (bn): BatchNorm1d(256, eps=33881, momentum=0.1, affine=True, track_running_stats=True) (id_linear): Linear(in_features=256, out_features=33881, bias=True) (loss_id): CELoss() (loss_triplet): TripletLoss() ) (attr_predictor): AttrPredictor( (linear_attr): Linear(in_features=4096, out_features=303, bias=True) (loss_attr): BCEWithLogitsLoss() ) )
测试集输出代码:
python tools/test_retriever.py \
--config configs/retriever_consumer_to_shop/roi_retriever_vgg.py \
--checkpoint checkpoint/Retrieve_consumer_to_shop/vgg/latest.pth
测试输出以及模型评估步骤如下:
1.通过模型我们可以分别得到query和gallery的向量表示,即query_embeds,gallery_embeds。
2.准备query中index对应的item_id(商品)字典,以及item_id对应index组成的list,gallery也是一样的操作。
3.遍历query_embeds中的每一个元素,对于每一个query都计算其与gallery集合中的元素的余弦距离,然后进行排序,得到topk的一个召回率,由此可以得到query的评价指标。
以下是一个query的召回代码片段
def single_query(self, query_id, query_feat, gallery_embeds, query_idx): query_dist = [] for j, feat in enumerate(gallery_embeds): cosine_dist = cosine( feat.reshape(1, -1), query_feat.reshape(1, -1)) query_dist.append(cosine_dist) query_dist = np.array(query_dist) order = np.argsort(query_dist) single_recall = dict() print(self.query_id2idx[query_id]) for k in self.topks: retrieved_idxes = order[:k] tp = 0 relevant_num = len(self.gallery_id2idx[query_id]) for idx in retrieved_idxes: retrieved_id = self.gallery_dict[idx] if query_id == retrieved_id: tp += 1 single_recall[k] = float(tp) / relevant_num return single_recall
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。