当前位置:   article > 正文

服装检索-DeepFashion_deepfashion数据集中的consumer-to-shop clothes retrieval

deepfashion数据集中的consumer-to-shop clothes retrieval benchmark应该怎么处

DeepFashion: Powering Robust Clothes Recognition and Retrieval with Rich Annotations

该篇论文主要有两个任务,第一个任务是提供了一个large-scale数据集, 包含80万张图片,包含不同角度,不同场景,买家秀,买家秀图片。数据集分为4个子数据集,用于4个主要任务,分别是服装类别与属性预测、店内衣服检索、买家秀卖家秀衣服检索、关键点和目标检测。第二个任务是提出了一个深度学习网络FashionNet,该博客主要用于记录fashionnet部分的内容,具体deepfashion数据集介绍请见其它博客。

1. 论文内容部分

1.1 服装识别算法常见挑战

1.服装间由于款式、质地和裁剪方式的差异,会导致模型识别服装结果发生混淆,比如同一服装间的裁剪方式不一致导致模型认为不是同一服装,或不同服装由于款式,质地十分相近会使得模型认为是同一服装。
2.服装在展示时发生形变或部分被遮挡是不可避免的,这给模型在识别效果上制造了困难。
3.在不同场景下拍摄出来的服装图像通常表现出很严重的差异,例如:买家秀和买家秀的不同。

1.2 论文创新点

1.提出大型时装数据集DeepFashion
2.提出了FashionNet 进行DeepFashion 数据集的衣服属性的预测和分类
3.定义了多种任务的评价标准

1.3 论文观点

在这里插入图片描述

如上图,文章认为在衣服识别中,若一个数据集同时拥有类别、属性和关键点的标签,会得到两个优势,第一个是给定关键点的位置有助于服装识别精度的提高,第二个是服装的大量属性的引入导致了服装的特征空间能够得到更好的划分,从而能够促进跨域服装图像的识别和检索。

1.4 模型结构

FashionNet 建立了三个benchmark,以供后续研究作为效果评价的基线。它们分别是:服装特征分类、店内服装检索(网站中同一商品展示图的匹配),和跨领域的服务检索(买家秀卖家秀衣服检索)。具体代码可以在mmfashion中查看。

1.4.1 FashionNet

在这里插入图片描述

以下内容来自:DeepFashion服装检索及代码实现

论文提到的FashionNet方法采用的网络backbone为VGG16,在VGG16网络中,stage4之前的网络结构权值都是共享的,在这一基础上发展了三个网络分支,如上图所示。上图的橙色部分、绿色部分和蓝色部分分别代表了全局特征提取网络、局部特征提取网络和特征点回归网络。

  • 橙色部分(全局分支):基于共享的VGG16前4个stage,增加了自己的conv5卷积结构,然后接一个全连接层用于特征编码,输出特征作为最终检索特征的一部分。
  • 蓝色部分(关键点回归分支):与VGG16网络结构类似,在stage5卷积结构后面使用两个全连接层进行特征编码,根据输出的特征分别再使用两个全连接层输出服装关键点和关键点是否可见的标签。
  • 绿色部分(局部分支):借助于蓝色分支输出的服装关键点,论文作者提出了一个新的层,landmark pooling layer,与Faster RCNN中的ROI Pooling层类似,目的在于对提取特征图上的某一块区域进行统一的编码工作。这部分的特征称之为局部特征,将橙色部分和绿色部分的特征按通道组合在一起之后经过一个全连接层再编码一次,然后基于这个特征对属性、类别标签进行分类学习,同时使用triplet loss进行辅助学习。
1.4.2 Loss部分

训练过程中,首先加大蓝色模块的权值,即增大 L l a n d m a r k s L_{landmarks} Llandmarks L v i s i b i l i t y L_{visibility} Lvisibility的权重,其他权重保持相对较小,先把蓝色模块训练收敛,然后再减少权值,整体训练
1.landmark回归loss为加权的L2 loss,其中 V j V_j Vj为其权值,代表landmark的可见性,对于不可见的就不进行梯度的回传。
L l a n d m a r k s = ∑ j = 1 ∣ D ∣ ∣ ∣ V j ⋅ ( l ^ j − l j ) ∣ ∣ 2 2 L_{landmarks} =

j=1|D|||Vj(l^jlj)||22
Llandmarks=j=1DVj(l^jlj)22
2.衣服类别分类和landmark是否可见分类,采用传统的softmax crossentrop loss, 表示为 L v i s i b i l i t y L_{visibility} Lvisibility L c a t e g o r y L_{category} Lcategory
3.衣服属性分类采用加权的sigmoid crossentrop loss, X j X_j Xj代表第j个衣服, a j a_j aj代表第j个衣服的属性, W p o s W_{pos} Wpos W n e g W_{neg} Wneg代表正负样本的权值
L a t t r i b u t e s = ∑ j = 1 ∣ D ∣ ( w p o s ⋅ a j l o g p ( a j ∣ x j ) + w n e g ⋅ ( 1 − a j ) l o g ( 1 − p ( a j ∣ x j ) ) ) L_{attributes} = \sum_{j=1}^{|D|} (w_{pos} \cdot a_jlogp(a_j|x_j) + w_{neg} \cdot (1 - a_j)log(1-p(a_j|x_j))) Lattributes=j=1D(wposajlogp(ajxj)+wneg(1aj)log(1p(ajxj)))
4.衣服服装对的loss采用triplet loss。 ( x , x + , x − )
(x,x+,x)
(x,x+,x)
表示三元组,m表示margin,d表示距离函数。
L t r i p l e t = ∑ j = 1 ∣ D ∣ m a x { 0 , m + d ( x j , x j + ) − d ( x j , x j − ) } L_{triplet} = \sum_{j=1}^{|D|} max
{0,m+d(xj,xj+)d(xj,xj)}
Ltriplet=j=1Dmax{0,m+d(xj,xj+)d(xj,xj)}

1.4.3 一小点总结

对于一张图片来说,从最初的(224,224,3)经过vgg16的stage4输出一个(7,7,512)feature map, 这个feature map连接着2个分支,包括全局分支、局部分支;其中全局分支先AdaptiveAvgPool2d池化为(7,7,512),如果是用户自定义的图像输入的化,vgg输出的不一定是(7,7,512),然后将该feature map展平后通过几个全连接层得到图片的全局表示为(4096,)的维度;而对于局部分支,由于输入为feature map以及landmarks,经过仿射变换可以得到局部关键点向量,然后经过全连接层可得到局部向量表示(4096,);全局局部拼接在一起后可以得到最终的向量表示。

1.4.4 局部分支的仿射变换

基本思路就是通过仿射变换将关键点放到最中间,同时根据一定比例放大,这样对原来的tensor进行卷积就相当于局部卷积了,推荐一个讲的比较好的博客,Pytorch中的仿射变换(affine_grid)

2. 代码部分

2.1 代码思考

已知deepfashion数据集有四个子任务,分别是服装类别与属性预测、店内衣服检索、买家秀卖家秀衣服检索、关键点和目标检测。而以上的FashionNet为一个整体的框架,在观察mmfashion中的子任务的代码后发现,其实每个任务并不一定需要同时预测category、attributes、triplet、landmarks和landmark visibility,可以只用部分分支就可以进行预测,比如类别或属性预测没有必要使用triplet loss,同时我们数据集也可能没有服装对的数据。在mmfashion中可以看到,FashionNet的使用还需要对具体数据情况来实现模型。

2.2 具体代码

PS:以下内容为mmfashion中买家秀卖家秀衣服检索模型的部分代码, 主要是便于本人理解和记忆。其他部分的训练过程和数据准备可参考GETTING_STARTEDDATA_PREPARATION

2.2.1 数据集组织方式:

|---- Anno 标注文件文件夹
|-------- list_attr_cloth.txt      衣服款式中英文对照表
|-------- list_attr_items.txt     衣服属性标注,1000个属性
|-------- list_attr_type.txt      衣服属性类型中英文对照表
|-------- list_bbox_consumer2shop.txt      图片中服装的bbox框标注
|-------- list_item_consumer2shop.txt      图片中商品编号,每种商品可以用多个商品对
|-------- list_landmarks_consumer2shop.txt      图片中商品的landmarks标注
|---- Eval      数据集划分文件
|-------- list_eval_partition.txt      训练集测试集划分,商品对的形式
|---- Img      图片文件夹

2.2.2 数据准备部分代码

数据准备运行代码为: python prepare_consumer_to_shop.py

2.2.2.1 训练集测试集合切分
PREFIX = 'Consumer_to_shop/Anno'


def split_img():
    fn = open('Consumer_to_shop/Eval/list_eval_partition.txt').readlines()

    # train dataset
    train_consumer2shop = open(
        os.path.join(PREFIX, 'train_consumer2shop.txt'), 'w')
    train_imgs = []

    # test dataset
    test_consumer = open(os.path.join(PREFIX, 'consumer.txt'), 'w')
    test_shop = open(os.path.join(PREFIX, 'shop.txt'), 'w')
    consumer_imgs, shop_imgs = [], []

    for i, line in enumerate(fn[2:]):
        aline = line.strip('\n').split()
        consumer, shop, _, cate = aline[0], aline[1], aline[2], aline[3]
        if cate == 'train':
            newline = consumer + ' ' + shop + '\n'
            train_consumer2shop.write(newline)
            train_imgs.append(consumer)
            train_imgs.append(shop)
        else:
            newline = consumer + '\n'
            test_consumer.write(newline)
            newline = shop + '\n'
            test_shop.write(newline)
            consumer_imgs.append(consumer)
            shop_imgs.append(shop)

    train_consumer2shop.close()
    test_consumer.close()
    test_shop.close()
    return train_imgs, consumer_imgs, shop_imgs
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36

该部分获取训练集测试集的list,其中train_imgs为consumer图像和shop图像交替存储的list, consumer_imgs为测试集中consumer图像存储的list,shop_imgs为测试集合中shop图像存储的list,该函数split_img()还生成了对应与于train_imgs, consumer_imgs, shop_imgs的txt文件。

2.2.2.2 bbox数据准备
def split_bbox(train_set, consumer_set, shop_set):
    rf = open(os.path.join(PREFIX, 'list_bbox_consumer2shop.txt')).readlines()
    img2bbox = {}
    for i, line in enumerate(rf[2:]):
        aline = line.strip('\n').split()
        img = aline[0]
        bbox = aline[-4:]
        img2bbox[img] = bbox

    wf_train = open(os.path.join(PREFIX, 'list_bbox_train.txt'), 'w')
    wf_consumer = open(os.path.join(PREFIX, 'list_bbox_consumer.txt'), 'w')
    wf_shop = open(os.path.join(PREFIX, 'list_bbox_shop.txt'), 'w')
    for i, img in enumerate(train_set):
        bbox = img2bbox[img]
        newline = img + ' ' + bbox[0] + ' ' + bbox[1] + ' ' + bbox[
            2] + ' ' + bbox[3] + '\n'
        wf_train.write(newline)

    for i, img in enumerate(consumer_set):
        bbox = img2bbox[img]
        newline = img + ' ' + bbox[0] + ' ' + bbox[1] + ' ' + bbox[
            2] + ' ' + bbox[3] + '\n'
        wf_consumer.write(newline)

    for i, img in enumerate(shop_set):
        bbox = img2bbox[img]
        newline = img + ' ' + bbox[0] + ' ' + bbox[1] + ' ' + bbox[
            2] + ' ' + bbox[3] + '\n'
        wf_shop.write(newline)

    wf_train.close()
    wf_consumer.close()
    wf_shop.close()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33

该函数生成了对应train_imgs, consumer_imgs, shop_imgs图片的bbox文件,文件每一行为: img   bbox[0]   bbox[1]   bbox[2]   bbox[3]

2.2.2.3 配对id准备
def split_ids(train_set, consumer_set, shop_set):
    id2label = dict()
    rf = open(os.path.join(PREFIX, 'list_item_consumer2shop.txt')).readlines()
    for i, line in enumerate(rf[1:]):
        id2label[line.strip('\n')] = i

    def write_id(cloth, wf):
        for i, line in enumerate(cloth):
            id = line.strip('\n').split('/')[3]
            label = id2label[id]
            wf.write('%s\n' % str(label))
        wf.close()

    train_id = open(os.path.join(PREFIX, 'train_id.txt'), 'w')
    consumer_id = open(os.path.join(PREFIX, 'consumer_id.txt'), 'w')
    shop_id = open(os.path.join(PREFIX, 'shop_id.txt'), 'w')
    write_id(train_set, train_id)
    write_id(consumer_set, consumer_id)
    write_id(shop_set, shop_id)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

首先对于所有商品id进行编码,同时对于train_imgs, consumer_imgs, shop_imgs中元素对于进行编码并且生成相应的文件。

2.2.2.4 landmarks数据准备
def split_lms(train_set, consumer_set, shop_set):
    rf = open(os.path.join(PREFIX,
                           'list_landmarks_consumer2shop.txt')).readlines()
    img2landmarks = {}
    for i, line in enumerate(rf[2:]):
        aline = line.strip('\n').split()
        img = aline[0]
        landmarks = aline[3:]
        img2landmarks[img] = landmarks

    wf_train = open(os.path.join(PREFIX, 'list_landmarks_train.txt'), 'w')
    wf_consumer = open(
        os.path.join(PREFIX, 'list_landmarks_consumer.txt'), 'w')
    wf_shop = open(os.path.join(PREFIX, 'list_landmarks_shop.txt'), 'w')

    def write_landmarks(img_set, wf):
        for i, img in enumerate(img_set):
            landmarks = img2landmarks[img]
            one_lms = []
            for j, lm in enumerate(landmarks):
                if j % 3 == 0:  # visibility
                    if lm == '0':  # visible
                        one_lms.append(landmarks[j + 1])
                        one_lms.append(landmarks[j + 2])
                    else:  # invisible or truncated
                        one_lms.append('000')
                        one_lms.append('000')

            while len(one_lms) < 16:  # 8 pairs
                one_lms.append('000')

            wf.write(img)
            wf.write(' ')
            for lm in one_lms:
                wf.write(lm)
                wf.write(' ')
            wf.write('\n')
        wf.close()

    write_landmarks(train_set, wf_train)
    write_landmarks(consumer_set, wf_consumer)
    write_landmarks(shop_set, wf_shop)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42

处理照片对于的关键点,保存其信息,并且对于不可见的关键点坐标置为000

总结:所需要准备的数据有包括训练集的图片对,分开来的训练集的query和gallery图片,以及图片所对应的编码后的商品编码,以及bbox信息和landmark信息,以及其共同拥有的list_attr_items.txt文件,为图片属性信息。

2.2.3 模型代码

训练代码:

python tools/train_retriever.py --config configs/retriever_in_shop/roi_retriever_vgg.py

模型框架结构代码:

from .. import builder
from ..registry import RETRIEVER
from .base import BaseRetriever


@RETRIEVER.register_module
class RoIRetriever(BaseRetriever):

    def __init__(self,
                 backbone,
                 global_pool,
                 roi_pool,
                 concat,
                 embed_extractor,
                 attr_predictor=None,
                 pretrained=None):
        super(RoIRetriever, self).__init__()

        self.backbone = builder.build_backbone(backbone)
        self.global_pool = builder.build_global_pool(global_pool)

        if roi_pool is not None:
            self.roi_pool = builder.build_roi_pool(roi_pool)
        else:
            self.roi_pool = None

        self.concat = builder.build_concat(concat)
        self.embed_extractor = builder.build_embed_extractor(embed_extractor)

        if attr_predictor is not None:
            self.attr_predictor = builder.build_attr_predictor(attr_predictor)
        else:
            self.attr_predictor = None

        self.init_weights(pretrained=pretrained)

    def extract_feat(self, x, landmarks):
        x = self.backbone(x)
        global_x = self.global_pool(x)
        global_x = global_x.view(global_x.size(0), -1)

        if landmarks is not None:
            local_x = self.roi_pool(x, landmarks)
        else:
            local_x = None

        x = self.concat(global_x, local_x)
        return x

    def forward_train(self,
                      anchor,
                      id,
                      attr=None,
                      pos=None,
                      neg=None,
                      anchor_lm=None,
                      pos_lm=None,
                      neg_lm=None,
                      triplet_pos_label=None,
                      triplet_neg_label=None):

        losses = dict()

        # extract features
        anchor_feat = self.extract_feat(anchor, anchor_lm)

        if pos is not None:
            pos_feat = self.extract_feat(pos, pos_lm)
            neg_feat = self.extract_feat(neg, neg_lm)

            losses['loss_id'] = self.embed_extractor(
                anchor_feat,
                id,
                return_loss=True,
                triplet=True,
                pos=pos_feat,
                neg=neg_feat,
                triplet_pos_label=triplet_pos_label,
                triplet_neg_label=triplet_neg_label)

        else:
            losses['loss_id'] = self.embed_extractor(
                anchor_feat, id, return_loss=True)

        if self.attr_predictor is not None:
            losses['loss_attr'] = self.attr_predictor(
                anchor_feat, attr, return_loss=True)
        return losses

    def simple_test(self, x, landmarks=None):
        """Test single image"""
        x = x.unsqueeze(0)
        landmarks = landmarks.unsqueeze(0)
        feat = self.extract_feat(x, landmarks)
        embed = self.embed_extractor.forward_test(feat)[0]
        return embed

    def aug_test(self, x, landmarks=None):
        """Test batch of images"""
        feat = self.extract_feat(x, landmarks)
        embed = self.embed_extractor.forward_test(feat)
        return embed

    def init_weights(self, pretrained=None):
        super(RoIRetriever, self).init_weights(pretrained)
        self.backbone.init_weights(pretrained=pretrained)
        self.global_pool.init_weights()

        if self.roi_pool is not None:
            self.roi_pool.init_weights()

        self.concat.init_weights()
        self.embed_extractor.init_weights()

        if self.attr_predictor is not None:
            self.attr_predictor.init_weights()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117

以上为mmfashion中买家秀买家秀检索的模型框架代码,基本与deepfashion的论文框架是一致的,框架里的每个部件的具体内容可以看相应的config文件以及对应部件的代码,整个框架原理和mmdetection的框架调用是一致的。

2.3 训练以及评估

2.3.1 数据准备

数据准备运行代码为: python prepare_consumer_to_shop.py,则会得到以下内容

data = dict(
    imgs_per_gpu=8,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        img_path=os.path.join(data_root, 'Img'),
        img_file=os.path.join(data_root, 'Anno/train_consumer2shop.txt'),
        id_file=os.path.join(data_root, 'Anno/train_id.txt'),
        label_file=os.path.join(data_root, 'Anno/list_attr_items.txt'),
        bbox_file=os.path.join(data_root, 'Anno/list_bbox_train.txt'),
        landmark_file=os.path.join(data_root, 'Anno/list_landmarks_train.txt'),
        img_size=img_size,
        find_three=True),
    query=dict(
        type=dataset_type,
        img_path=os.path.join(data_root, 'Img'),
        img_file=os.path.join(data_root, 'Anno/consumer.txt'),
        id_file=os.path.join(data_root, 'Anno/consumer_id.txt'),
        label_file=os.path.join(data_root, 'Anno/list_attr_items.txt'),
        bbox_file=os.path.join(data_root, 'Anno/list_bbox_consumer.txt'),
        landmark_file=os.path.join(data_root,
                                   'Anno/list_landmarks_consumer.txt'),
        img_size=img_size,
        find_three=True),
    gallery=dict(
        type=dataset_type,
        img_path=os.path.join(data_root, 'Img'),
        img_file=os.path.join(data_root, 'Anno/shop.txt'),
        id_file=os.path.join(data_root, 'Anno/shop_id.txt'),
        label_file=os.path.join(data_root, 'Anno/list_attr_items.txt'),
        bbox_file=os.path.join(data_root, 'Anno/list_bbox_shop.txt'),
        landmark_file=os.path.join(data_root, 'Anno/list_landmarks_shop.txt'),
        img_size=img_size,
        find_three=True))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

以上代码是从mmfashion的configs/retriever_consumer_to_shop/roi_retriever_vgg.py中data部分的截取的,表示模型的输入。

2.3.2 模型训练

1.修改配置文件mmfashion的configs/retriever_consumer_to_shop/roi_retriever_vgg.py,以及提前下载vgg16的权重模型。
2.运行代码训练模型: python tools/train_retriever.py --config configs/retriever_consumer_to_shop/roi_retriever_vgg.py,以下是整体模型结构

RoIRetriever(
  (backbone): Vgg(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): ReLU(inplace=True)
      (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (18): ReLU(inplace=True)
      (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (20): ReLU(inplace=True)
      (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (22): ReLU(inplace=True)
      (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (25): ReLU(inplace=True)
      (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (27): ReLU(inplace=True)
      (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (29): ReLU(inplace=True)
      (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
  )
  (global_pool): GlobalPooling(
    (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
    (global_layers): Sequential(
      (0): Linear(in_features=25088, out_features=4096, bias=True)
      (1): ReLU(inplace=True)
      (2): Dropout(p=0.5, inplace=False)
      (3): Linear(in_features=4096, out_features=4096, bias=True)
      (4): ReLU(inplace=True)
      (5): Dropout(p=0.5, inplace=False)
    )
  )
  (roi_pool): RoIPooling(
    (maxpool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (linear): Sequential(
      (0): Linear(in_features=4096, out_features=4096, bias=True)
      (1): ReLU(inplace=True)
      (2): Dropout(p=0.5, inplace=False)
    )
  )
  (concat): Concat(
    (fc_fusion): Linear(in_features=8192, out_features=4096, bias=True)
  )
  (embed_extractor): EmbedExtractor(
    (embed_linear): Linear(in_features=4096, out_features=256, bias=True)
    (bn): BatchNorm1d(256, eps=33881, momentum=0.1, affine=True, track_running_stats=True)
    (id_linear): Linear(in_features=256, out_features=33881, bias=True)
    (loss_id): CELoss()
    (loss_triplet): TripletLoss()
  )
  (attr_predictor): AttrPredictor(
    (linear_attr): Linear(in_features=4096, out_features=303, bias=True)
    (loss_attr): BCEWithLogitsLoss()
  )
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
2.3.4 模型预测以及模型评估

测试集输出代码:

python tools/test_retriever.py \
    --config configs/retriever_consumer_to_shop/roi_retriever_vgg.py \
    --checkpoint checkpoint/Retrieve_consumer_to_shop/vgg/latest.pth
  • 1
  • 2
  • 3

测试输出以及模型评估步骤如下:
1.通过模型我们可以分别得到query和gallery的向量表示,即query_embeds,gallery_embeds。
2.准备query中index对应的item_id(商品)字典,以及item_id对应index组成的list,gallery也是一样的操作。
3.遍历query_embeds中的每一个元素,对于每一个query都计算其与gallery集合中的元素的余弦距离,然后进行排序,得到topk的一个召回率,由此可以得到query的评价指标。

以下是一个query的召回代码片段

    def single_query(self, query_id, query_feat, gallery_embeds, query_idx):
        query_dist = []
        for j, feat in enumerate(gallery_embeds):
            cosine_dist = cosine(
                feat.reshape(1, -1), query_feat.reshape(1, -1))
            query_dist.append(cosine_dist)
        query_dist = np.array(query_dist)

        order = np.argsort(query_dist)
        single_recall = dict()

        print(self.query_id2idx[query_id])
        for k in self.topks:
            retrieved_idxes = order[:k]
            tp = 0
            relevant_num = len(self.gallery_id2idx[query_id])
            for idx in retrieved_idxes:
                retrieved_id = self.gallery_dict[idx]
                if query_id == retrieved_id:
                    tp += 1

            single_recall[k] = float(tp) / relevant_num
        return single_recall
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

在这里插入图片描述

参考

  1. 时装分类+检索之DeepFashion
  2. DeepFashion服装检索及代码实现
  3. open-mmlab/mmfashion
  4. 使用deepfashion实现自己的第一个分类网络
  5. Pytorch中的仿射变换(affine_grid)
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/木道寻08/article/detail/877980
推荐阅读
相关标签
  

闽ICP备14008679号