
Unsupervised Image Classification: Code Analysis Notes on SCAN (Learning to Classify Images without Labels), Part 1: simclr


Preface

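  1. SCAN consists of several steps. This post analyzes the code of the first step, simclr.py.
  2. According to the paper, the first step is a pretext task that trains the feature extraction network.
  3. The core idea is as follows. Transform the same image $P$ twice to obtain $P_1$ and $P_2$. Pass both through the feature extraction network to get the features $T_1$ and $T_2$. Then minimize the distance between $T_1$ and $T_2$, so that it is smaller than their distance to the features of other images.
  4. At the end, the code uses the faiss library to mine the top-k nearest neighbors needed by the later steps, so it has to be run on Linux.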
simclr.py
# Environment config (output paths)
--config_env configs/env.yml
# Experiment config (network and training settings)
--config_exp configs/pretext/simclr_cifar10.yml
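The two flags above are command-line arguments of simclr.py. Below is a minimal sketch of how such flags can be parsed with the standard `argparse` module. It assumes the script reads them as `args.config_env` and `args.config_exp`, the names used in `create_config(args.config_env, args.config_exp)` below. The helper name `build_parser` and the help strings are illustrative, not taken from the repository.

```python
import argparse

def build_parser():
    # Two required paths: the environment file (where results are written)
    # and the experiment file (network / training hyperparameters).
    parser = argparse.ArgumentParser(description='SimCLR pretext task')
    parser.add_argument('--config_env', required=True,
                        help='Config file for the environment')
    parser.add_argument('--config_exp', required=True,
                        help='Config file for the experiment')
    return parser
```

A typical invocation would then be `python simclr.py --config_env configs/env.yml --config_exp configs/pretext/simclr_cifar10.yml`.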

0. Configuration

# utils/config.py
p = create_config(args.config_env, args.config_exp)

p = 
{'setup': 'simclr', 
'backbone': 'resnet18', 
'model_kwargs': {'head': 'mlp', 'features_dim': 128}, 
'train_db_name': 'cifar-10', 
'val_db_name': 'cifar-10', 
'num_classes': 10, 
'criterion': 'simclr', 
'criterion_kwargs': {'temperature': 0.1}, 
'epochs': 500, 
'optimizer': 'sgd', 
'optimizer_kwargs': {'nesterov': False, 'weight_decay': 0.0001, 'momentum': 0.9, 'lr': 0.4}, 
'scheduler': 'cosine', 
'scheduler_kwargs': {'lr_decay_rate': 0.1}, 
'batch_size': 128, 
'num_workers': 8, 
'augmentation_strategy': 'simclr', 
'augmentation_kwargs': {'random_resized_crop': {'size': 32, 'scale': [0.2, 1.0]}, 
						'color_jitter_random_apply': {'p': 0.8}, 
						'color_jitter': {'brightness': 0.4, 'contrast': 0.4, 'saturation': 0.4, 'hue': 0.1}, 
						'random_grayscale': {'p': 0.2}, 
						'normalize': {'mean': [0.4914, 0.4822, 0.4465], 'std': [0.2023, 0.1994, 0.201]}}, 
'transformation_kwargs': {'crop_size': 32, 'normalize': {'mean': [0.4914, 0.4822, 0.4465], 'std': [0.2023, 0.1994, 0.201]}}, 
'pretext_dir': '/path/where/to/store/results/cifar-10\\pretext', 
'pretext_checkpoint': '/path/where/to/store/results/cifar-10\\pretext\\checkpoint.pth.tar', 
'pretext_model': '/path/where/to/store/results/cifar-10\\pretext\\model.pth.tar', 
'topk_neighbors_train_path': '/path/where/to/store/results/cifar-10\\pretext\\topk-train-neighbors.npy', 
'topk_neighbors_val_path': '/path/where/to/store/results/cifar-10\\pretext\\topk-val-neighbors.npy'}
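The path entries at the end of `p` are derived from the root directory in `configs/env.yml` plus the dataset name. The sketch below reproduces that derivation, assuming the root directory and the file names visible in the printout. `build_pretext_paths` is a hypothetical helper, not the repository's function. Because it uses `os.path.join`, running it on Windows inserts `\\` separators. That is presumably why the printed paths mix `/` and `\\`; the object addresses printed in section 2 also suggest a Windows machine.

```python
import os

def build_pretext_paths(root_dir, db_name):
    # Hypothetical helper: rebuilds the path keys shown in the printed p.
    pretext_dir = os.path.join(root_dir, db_name, 'pretext')
    return {
        'pretext_dir': pretext_dir,
        'pretext_checkpoint': os.path.join(pretext_dir, 'checkpoint.pth.tar'),
        'pretext_model': os.path.join(pretext_dir, 'model.pth.tar'),
        'topk_neighbors_train_path': os.path.join(pretext_dir, 'topk-train-neighbors.npy'),
        'topk_neighbors_val_path': os.path.join(pretext_dir, 'topk-val-neighbors.npy'),
    }

paths = build_pretext_paths('/path/where/to/store/results/', 'cifar-10')
```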

1. Model

(Figure: model structure; the original image was not preserved.)
In the figure, normalize denotes L2 normalization.

model = get_model(p)  # utils/common_config.py 44

# get_model(p) builds the network in two steps
from models.resnet_cifar import resnet18
backbone = resnet18()
from models.models import ContrastiveModel
model = ContrastiveModel(backbone, **p['model_kwargs'])
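Here is a minimal sketch of what `ContrastiveModel` is described as doing with `head='mlp'` and `features_dim=128`. The backbone produces features, a two-layer MLP head projects them, and L2 normalization (the `normalize` in the figure) is applied last. `ToyBackbone` is a hypothetical stand-in with a 512-d output (ResNet-18's final feature width), used only so the example runs without the CIFAR ResNet-18. `ContrastiveModelSketch` is an illustration, not the repository class.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyBackbone(nn.Module):
    # Hypothetical stand-in for resnet18(): maps [B, 3, 32, 32] to [B, 512].
    def __init__(self, dim=512):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, dim))

    def forward(self, x):
        return self.net(x)

class ContrastiveModelSketch(nn.Module):
    def __init__(self, backbone, backbone_dim=512, features_dim=128):
        super().__init__()
        self.backbone = backbone
        # 'mlp' head: Linear -> ReLU -> Linear
        self.contrastive_head = nn.Sequential(
            nn.Linear(backbone_dim, backbone_dim),
            nn.ReLU(),
            nn.Linear(backbone_dim, features_dim),
        )

    def forward(self, x):
        features = self.contrastive_head(self.backbone(x))
        # L2-normalize each row so that dot products are cosine similarities
        return F.normalize(features, dim=1)

model = ContrastiveModelSketch(ToyBackbone())
out = model(torch.randn(4, 3, 32, 32))
print(out.shape)  # torch.Size([4, 128])
```

The final normalization is what later allows the loss and the kNN evaluation to treat a plain dot product as cosine similarity.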

2. Dataset

CIFAR-10 overview
Number of images: 60000
Image size: 32×32
Image format: RGB
Number of classes: 10
Training set: 50000
Test set: 10000

train_transforms = get_train_transformations(p)  # utils/common_config.py 207
val_transforms = get_val_transformations(p)  # utils/common_config.py 247

# utils/common_config.py 120
train_dataset = get_train_dataset(p, train_transforms, to_augmented_dataset=True, split='train+unlabeled') # Split is for stl-10
# utils/common_config.py 160
val_dataset = get_val_dataset(p, val_transforms)

train_dataloader = get_train_dataloader(p, train_dataset)  # utils/common_config.py 195
val_dataloader = get_val_dataloader(p, val_dataset)  # utils/common_config.py 201

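The training dataset holds two identical random transforms, image_transform and augmentation_transform. Applying each of them to an original image p yields two different views, p1 and p2, because every call samples its own random parameters. The network takes p1 and p2 as input and produces two features. It learns by reducing the difference between these two features.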
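The two-view mechanism can be sketched in plain Python. This is a simplified illustration, not the repository's `AugmentedDataset`. `ToyDataset` and `random_jitter` are hypothetical stand-ins for CIFAR-10 and the SimCLR augmentation pipeline. The point is that the same random pipeline is called twice on the same raw image, and the two outputs differ only because each call draws fresh random parameters. The printout below shows `'transform': None` on the wrapped CIFAR10 object, which is consistent with the wrapper taking over the transform from the base dataset.

```python
import random

class ToyDataset:
    # Hypothetical stand-in for CIFAR10: returns raw images without transforms.
    def __init__(self, images, targets):
        self.images, self.targets = images, targets

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        return {'image': self.images[index], 'target': self.targets[index]}

def random_jitter(image):
    # Hypothetical random augmentation: each call samples a new brightness factor.
    factor = random.uniform(0.6, 1.4)
    return [pixel * factor for pixel in image]

class AugmentedDatasetSketch:
    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.image_transform = transform         # same pipeline ...
        self.augmentation_transform = transform  # ... applied independently

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample = self.dataset[index]
        raw = sample['image']
        sample['image'] = self.image_transform(raw)                   # view p1
        sample['image_augmented'] = self.augmentation_transform(raw)  # view p2
        return sample

random.seed(0)
ds = AugmentedDatasetSketch(ToyDataset([[0.5, 0.25, 1.0]], [3]), random_jitter)
sample = ds[0]
```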
Image transformations:

vars(train_dataloader)
{'dataset': <data.custom_dataset.AugmentedDataset object at 0x000001A883E57CC8>, 
'num_workers': 8, 
'pin_memory': True, 
'timeout': 0, 
'worker_init_fn': None, 
'_DataLoader__multiprocessing_context': None, 
'_dataset_kind': 0, 
'batch_size': 128, 
'drop_last': True, 
'sampler': <torch.utils.data.sampler.RandomSampler object at 0x000001A883CB1808>, 
'batch_sampler': <torch.utils.data.sampler.BatchSampler object at 0x000001A885466C48>, 
'collate_fn': <function collate_custom at 0x000001A8827C4168>, 
'_DataLoader__initialized': True}
----------------------------------------------------------------
vars(train_dataloader.dataset)
{
'dataset': <data.cifar.CIFAR10 object at 0x000001A8854503C8>, 
'image_transform': Compose(
    RandomResizedCrop(
        size=(32, 32), 
        scale=(0.2, 1.0), 
        ratio=(0.75, 1.3333), 
        interpolation=PIL.Image.BILINEAR)
    RandomHorizontalFlip(p=0.5)
    RandomApply(
        p=0.8
        ColorJitter(brightness=[0.6, 1.4], contrast=[0.6, 1.4], saturation=[0.6, 1.4], hue=[-0.1, 0.1])
    )
    RandomGrayscale(p=0.2)
    ToTensor()
    Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.201])
), 
'augmentation_transform': Compose(
    RandomResizedCrop(
        size=(32, 32), 
        scale=(0.2, 1.0), 
        ratio=(0.75, 1.3333), 
        interpolation=PIL.Image.BILINEAR)
    RandomHorizontalFlip(p=0.5)
    RandomApply(
        p=0.8
        ColorJitter(brightness=[0.6, 1.4], contrast=[0.6, 1.4], saturation=[0.6, 1.4], hue=[-0.1, 0.1])
    )
    RandomGrayscale(p=0.2)
    ToTensor()
    Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.201])
)
}
----------------------------------------------------------------
vars(train_dataloader.dataset.dataset)
{'root': '/path/to/cifar-10/', 
'transform': None, 
'train': True, 
'classes': ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'], 
'data': array([50000, 32, 32, 3]),
'targets':[50000],
'class_to_idx': {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}}
vars(val_dataloader)
{'dataset': <data.cifar.CIFAR10 object at 0x00000207077BEA88>, 
'num_workers': 8, 
'pin_memory': True, 
'timeout': 0, 
'worker_init_fn': None, 
'_DataLoader__multiprocessing_context': None, 
'_dataset_kind': 0, 
'batch_size': 128, 
'drop_last': False, 
'sampler': <torch.utils.data.sampler.SequentialSampler object at 0x00000207077C5508>, 
'batch_sampler': <torch.utils.data.sampler.BatchSampler object at 0x00000207077C50C8>, 
'collate_fn': <function collate_custom at 0x000002070278ADC8>, 
'_DataLoader__initialized': True}
----------------------------------------------------------------
vars(val_dataloader.dataset)
{'root': '/path/to/cifar-10/', 
'transform': Compose(
    CenterCrop(size=(32, 32))
    ToTensor()
    Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.201])
), 
'train': False, 
'classes': ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'], 
'data': array([10000, 32, 32, 3]),
'targets': [10000], 
'class_to_idx': {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}}

3. Memory Bank

base_dataset = get_train_dataset(p, val_transforms, split='train') # Dataset w/o augs for knn eval
base_dataloader = get_val_dataloader(p, base_dataset)
# 50000, 128, 10, 0.1
memory_bank_base = MemoryBank(len(base_dataset), 
                             p['model_kwargs']['features_dim'],
                             p['num_classes'], p['criterion_kwargs']['temperature'])
memory_bank_base.cuda()
memory_bank_val = MemoryBank(len(val_dataset),
                             p['model_kwargs']['features_dim'],
                             p['num_classes'], p['criterion_kwargs']['temperature'])
memory_bank_val.cuda()
----------------------------------------------------------------
vars(memory_bank_base)
{'n': 50000, 
'dim': 128, 
'features': [50000,128] tensor,
'targets': [50000] tensor,
'ptr': 0, 
'device': 'cuda:0', 
'K': 100, 
'temperature': 0.1, 
'C': 10}
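The memory bank is essentially a preallocated feature/label buffer, filled batch by batch through the write pointer `ptr`. Here is a numpy sketch under that assumption. The field names mirror the printout (`n`, `dim`, `features`, `targets`, `ptr`, `K`, `temperature`, `C`). `MemoryBankSketch` is hypothetical and is not the repository's class, which stores torch tensors and can move them to CUDA.

```python
import numpy as np

class MemoryBankSketch:
    def __init__(self, n, dim, num_classes, temperature, K=100):
        self.n, self.dim = n, dim
        self.features = np.zeros((n, dim), dtype=np.float32)
        self.targets = np.zeros(n, dtype=np.int64)
        self.ptr = 0                 # next row to write
        self.K = K                   # neighbours used by the weighted kNN
        self.temperature = temperature
        self.C = num_classes

    def reset(self):
        self.ptr = 0

    def update(self, features, targets):
        # Write one batch at the current pointer, then advance it
        b = features.shape[0]
        assert self.ptr + b <= self.n, 'memory bank overflow'
        self.features[self.ptr:self.ptr + b] = features
        self.targets[self.ptr:self.ptr + b] = targets
        self.ptr += b

bank = MemoryBankSketch(n=6, dim=4, num_classes=10, temperature=0.1)
bank.update(np.ones((4, 4)), np.arange(4))
bank.update(np.zeros((2, 4)), np.array([7, 8]))
```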

4. Criterion

This part computes the loss. You can skip it for now and come back to it after reading section 6.2, Training.

(1) Walking through the loss code
mask: $[bs, bs]$, an identity matrix
contrast_features: $[2bs, 128]$, the features of the two views concatenated
anchor: $[bs, 128]$, the features of the first view
dot_product: $[bs, 2bs]$, torch.matmul(anchor, contrast_features.T) / 0.1
logits: $[bs, 2bs]$, dot_product with each row's maximum subtracted; that maximum is the entry on the main diagonal of the left half (each anchor's similarity with itself)
mask: $[bs, 2bs]$, becomes two identity matrices side by side
logits_mask: $[bs, 2bs]$, all ones except the main diagonal of the left half, which is 0
mask: $[bs, 2bs]$, becomes a zero matrix on the left half and an identity matrix on the right half
exp_logits: $[bs, 2bs]$, exp of logits with the main diagonal of the left half set to 0
log_prob: $[bs, 2bs]$, each row of logits minus the log of the row sum of exp_logits
loss: the negative mean of the main diagonal of the right half of log_prob
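The numpy sketch below rebuilds the three masks for a hypothetical batch of `b = 3` (the real batch is 128) to make the bookkeeping concrete. `np.tile` plays the role of `mask.repeat(1, 2)`. Zeroing the diagonal plays the role of the `torch.scatter` call.

```python
import numpy as np

b = 3
eye = np.eye(b)
mask = np.tile(eye, (1, 2))                  # [b, 2b]: two identities side by side
logits_mask = np.ones((b, 2 * b))
logits_mask[np.arange(b), np.arange(b)] = 0  # drop self-similarity (left-half diagonal)
positive_mask = mask * logits_mask           # left half 0, right half identity

print(positive_mask)
```

Each row of `positive_mask` has exactly one 1, at column `b + i`, which marks the other view of the same image.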

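(2) Analyzing the loss
At its core, the loss is a negative log-softmax.
The loss decreases in two ways: by shrinking the feature distance between $P$ and $P^T$ (the other transformed view of the same image), and by enlarging the feature distance between $P$ and every image other than $P^T$.
Open question: does the loss push apart the features of same-class images within a batch? Or does it only require that, overall, $P^T$ be closer to $P$ than the other images are?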

(3) Proof that the maximum dot product lies on the main diagonal of the left half
Let $F_1, F_2$ before normalization be $[x_1, x_2, \cdots, x_n]$ and $[y_1, y_2, \cdots, y_n]$. Then
$$F_1 \cdot F_2=\frac{x_1y_1+x_2y_2+\cdots+x_ny_n}{\sqrt{x_1^2+x_2^2+\cdots+x_n^2}\sqrt{y_1^2+y_2^2+\cdots+y_n^2}}$$
Squaring the denominator and the numerator gives the sums of all entries of the following $n \times n$ matrices:
$$\text{denominator}^2=\sum_{i,j} x_i^2y_j^2:\quad \begin{bmatrix} x_1^2y_1^2 & x_1^2y_2^2 & \cdots & x_1^2y_n^2 \\ x_2^2y_1^2 & x_2^2y_2^2 & \cdots & x_2^2y_n^2 \\ \vdots & \vdots & \ddots & \vdots \\ x_n^2y_1^2 & x_n^2y_2^2 & \cdots & x_n^2y_n^2 \end{bmatrix}$$
$$\text{numerator}^2=\sum_{i,j} x_iy_ix_jy_j:\quad \begin{bmatrix} x_1^2y_1^2 & x_1y_1x_2y_2 & \cdots & x_1y_1x_ny_n \\ x_2y_2x_1y_1 & x_2^2y_2^2 & \cdots & x_2y_2x_ny_n \\ \vdots & \vdots & \ddots & \vdots \\ x_ny_nx_1y_1 & x_ny_nx_2y_2 & \cdots & x_n^2y_n^2 \end{bmatrix}$$
$\because$ the diagonal entries of the two matrices are equal, and pairing each entry $(i,j)$ with its mirror $(j,i)$ across the main diagonal gives a perfect square:
$$x_i^2y_j^2 + x_j^2y_i^2 - 2x_iy_ix_jy_j = (x_iy_j - x_jy_i)^2$$
$\therefore \text{denominator}^2 - \text{numerator}^2 = \sum_{i<j}(x_iy_j - x_jy_i)^2 \ge 0$
$\therefore F_1 \cdot F_2 \le 1$, and the maximum $F_1 \cdot F_2 = 1$ is reached only when $F_1 = F_2$.
In each row of dot_product, the entry on the main diagonal of the left half is the anchor's dot product with itself. It equals $1/\tau$ ($\tau$ = temperature = 0.1), so it is the maximum of that row.

criterion = get_criterion(p)
criterion = criterion.cuda()

# utils/common_config.py 14
def get_criterion(p):
    if p['criterion'] == 'simclr':
        from losses.losses import SimCLRLoss
        criterion = SimCLRLoss(**p['criterion_kwargs'])
    else:
        raise ValueError('Invalid criterion {}'.format(p['criterion']))
    return criterion

# losses/losses.py
import torch
import torch.nn as nn

class SimCLRLoss(nn.Module):
    # Based on the implementation of SupContrast
    def __init__(self, temperature):
        super(SimCLRLoss, self).__init__()
        self.temperature = temperature

    def forward(self, features):
        """
        input:
            - features: hidden feature representation of shape [b, 2, dim]
        output:
            - loss: loss computed according to SimCLR 
        """
        b, n, dim = features.size()  # [128, 2, 128]
        assert(n == 2)
        mask = torch.eye(b, dtype=torch.float32).cuda()
        # torch.unbind() removes the given dimension and returns a tuple, here ([128,128], [128,128])
        # torch.cat() concatenates along the given dimension, here giving [256,128]
        contrast_features = torch.cat(torch.unbind(features, dim=1), dim=0)
        anchor = features[:, 0]  # anchor.size() = [128,128]

        # Dot product
        dot_product = torch.matmul(anchor, contrast_features.T) / self.temperature  # [128,256]

        # Log-sum trick for numerical stability
        logits_max, _ = torch.max(dot_product, dim=1, keepdim=True)  # [128,1]
        logits = dot_product - logits_max.detach()  # subtract each row's maximum

        # repeat(*sizes): number of copies along each dimension; (1, 2) tiles twice along dim 1
        mask = mask.repeat(1, 2)  # mask.size() = [128,256]
        # logits_mask: left half is the identity with 0 and 1 swapped, right half is all ones
        logits_mask = torch.scatter(torch.ones_like(mask), 1, torch.arange(b).view(-1, 1).cuda(), 0)
        mask = mask * logits_mask  # left half of mask becomes zeros, right half stays the identity

        # Log-softmax
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # Mean log-likelihood for positive
        # i.e. the negated mean of the identity diagonal (top-left to bottom-right) of log_prob's right half
        loss = - ((mask * log_prob).sum(1) / mask.sum(1)).mean()

        return loss
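Point (2) above says the loss is a negative log-softmax. A quick way to check that claim on CPU is to recompute the same loss with `F.cross_entropy`. Treat column `b + i` (the right-half diagonal) as the target class of row `i`, and mask out self-similarity with `-inf`. In this sketch, `simclr_loss_masked` is a CPU transcription of `SimCLRLoss.forward` with `.cuda()` removed. `simclr_loss_ce` is an alternative formulation written only for this comparison.

```python
import torch
import torch.nn.functional as F

def simclr_loss_masked(features, temperature=0.1):
    # CPU transcription of SimCLRLoss.forward
    b, n, dim = features.size()
    mask = torch.eye(b, dtype=torch.float32)
    contrast_features = torch.cat(torch.unbind(features, dim=1), dim=0)
    anchor = features[:, 0]
    dot_product = torch.matmul(anchor, contrast_features.T) / temperature
    logits_max, _ = torch.max(dot_product, dim=1, keepdim=True)
    logits = dot_product - logits_max.detach()
    mask = mask.repeat(1, 2)
    logits_mask = torch.scatter(torch.ones_like(mask), 1, torch.arange(b).view(-1, 1), 0)
    mask = mask * logits_mask
    exp_logits = torch.exp(logits) * logits_mask
    log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
    return -((mask * log_prob).sum(1) / mask.sum(1)).mean()

def simclr_loss_ce(features, temperature=0.1):
    # Same objective as a cross-entropy: row i's correct "class" is column b + i
    b = features.size(0)
    anchor = features[:, 0]
    contrast = torch.cat([features[:, 0], features[:, 1]], dim=0)
    logits = anchor @ contrast.T / temperature
    self_mask = torch.zeros_like(logits, dtype=torch.bool)
    self_mask[torch.arange(b), torch.arange(b)] = True
    logits = logits.masked_fill(self_mask, float('-inf'))  # exclude self-similarity
    return F.cross_entropy(logits, torch.arange(b) + b)

torch.manual_seed(0)
feats = F.normalize(torch.randn(8, 2, 16), dim=2)
print(simclr_loss_masked(feats).item(), simclr_loss_ce(feats).item())
```

The two numbers agree up to floating-point error. Subtracting the row maximum only shifts every logit in a row by the same constant, which cancels inside the softmax.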

5. Optimizer

optimizer = get_optimizer(p, model)

optimizer = 
SGD (
Parameter Group 0
    dampening: 0
    lr: 0.4
    momentum: 0.9
    nesterov: False
    weight_decay: 0.0001
)
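The printout corresponds to building `torch.optim.SGD` from the model parameters, with `p['optimizer_kwargs']` unpacked as keyword arguments. `dampening: 0` is simply SGD's default. The sketch below shows that construction, assuming a placeholder `nn.Linear` model. It is not the repository's `get_optimizer`.

```python
import torch
import torch.nn as nn

optimizer_kwargs = {'nesterov': False, 'weight_decay': 0.0001, 'momentum': 0.9, 'lr': 0.4}
model = nn.Linear(8, 2)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), **optimizer_kwargs)
print(optimizer)
```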

6. Train

for epoch in range(start_epoch, p['epochs']):
    # Adjust lr
    lr = adjust_learning_rate(p, optimizer, epoch)
    # Train
    simclr_train(train_dataloader, model, criterion, optimizer, epoch)
    # Fill memory bank
    fill_memory_bank(base_dataloader, model, memory_bank_base)
    # Evaluate (To monitor progress - Not for validation)
    top1 = contrastive_evaluate(val_dataloader, model, memory_bank_base)

6.1 Adjusting the learning rate

# Adjust lr
lr = adjust_learning_rate(p, optimizer, epoch)

# utils/common_config.py 280
def adjust_learning_rate(p, optimizer, epoch):
    lr = p['optimizer_kwargs']['lr']  # 0.4
    
    if p['scheduler'] == 'cosine':
        eta_min = lr * (p['scheduler_kwargs']['lr_decay_rate'] ** 3)  # 0.4 * (0.1 ** 3)
        lr = eta_min + (lr - eta_min) * (1 + math.cos(math.pi * epoch / p['epochs'])) / 2

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return lr
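Plugging in the config values (`lr = 0.4`, `lr_decay_rate = 0.1`, `epochs = 500`) gives `eta_min = 0.4 × 0.1³ = 0.0004`. The learning rate therefore follows half a cosine from 0.4 down toward 0.0004. Below is a standalone transcription of the formula; `cosine_lr` is a hypothetical helper name.

```python
import math

def cosine_lr(epoch, base_lr=0.4, decay_rate=0.1, epochs=500):
    # Same formula as adjust_learning_rate with scheduler == 'cosine'
    eta_min = base_lr * (decay_rate ** 3)
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / epochs)) / 2

for epoch in (0, 250, 499):
    print(epoch, round(cosine_lr(epoch), 6))
```

The training loop runs epochs 0 to 499, so the schedule reaches 0.2002 halfway through but never exactly touches `eta_min`.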

6.2 Training

A batch of images has shape $[bs, 3, 32, 32]$,
but the actual network input has shape $[2bs, 3, 32, 32]$ (the original view and the augmented view of each image),
so the network output has shape $[2bs, 128]$, which is reshaped to $[bs, 2, 128]$.
For the loss computation, see 4. Criterion.

simclr_train(train_dataloader, model, criterion, optimizer, epoch)

# utils/train_utils.py
def simclr_train(train_loader, model, criterion, optimizer, epoch):
    losses = AverageMeter('Loss', ':.4e')
    progress = ProgressMeter(len(train_loader),
        [losses],
        prefix="Epoch: [{}]".format(epoch))

    model.train()

    for i, batch in enumerate(train_loader):
        images = batch['image']
        images_augmented = batch['image_augmented']
        b, c, h, w = images.size()  # images.size() = [128,3,32,32]

        # Add a view dimension, then concatenate: input_.size() = [128,2,3,32,32]
        input_ = torch.cat([images.unsqueeze(1), images_augmented.unsqueeze(1)], dim=1)
        input_ = input_.view(-1, c, h, w)  # input_.size() = [256,3,32,32]
        input_ = input_.cuda(non_blocking=True)
        targets = batch['target'].cuda(non_blocking=True)
        output = model(input_).view(b, 2, -1)  # output.size() = [128,2,128]
        loss = criterion(output)
        losses.update(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 25 == 0:
            progress.display(i)

batch
{'image':[128,3,32,32],
'target':[128],
'meta':{'im_size':[2,32], 'index':[128], 'class_name':[128]},
'image_augmented':[128,3,32,32]}
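The reshaping relies on memory layout. After `torch.cat` on the new dimension 1, `view(-1, c, h, w)` interleaves the views as `[P1_0, P2_0, P1_1, P2_1, ...]`. `view(b, 2, -1)` on the output undoes exactly that interleaving, so `output[:, 0]` belongs to `image` and `output[:, 1]` to `image_augmented`. The small check below uses toy tensors, with flattening as a hypothetical stand-in for the network.

```python
import torch

b, c, h, w = 4, 3, 2, 2
images = torch.randn(b, c, h, w)
images_augmented = torch.randn(b, c, h, w)

input_ = torch.cat([images.unsqueeze(1), images_augmented.unsqueeze(1)], dim=1)  # [b, 2, c, h, w]
input_ = input_.view(-1, c, h, w)                                                 # [2b, c, h, w]

output = input_.flatten(1)      # stand-in "model": [2b, c*h*w]
output = output.view(b, 2, -1)  # [b, 2, c*h*w]
```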

6.3 Fill memory bank

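Compute the network's output features for the training set (using the val transforms, i.e. without augmentation), together with their labels.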

fill_memory_bank(base_dataloader, model, memory_bank_base)
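Conceptually, filling the bank is one pass over the un-augmented training set that stores each batch's normalized features and labels in order. The order matters: row i of the bank should correspond to sample i. That is consistent with `base_dataloader` being built by `get_val_dataloader`, whose printout above shows a `SequentialSampler`. The sketch below is framework-free. `fill_bank` and `embed` (a random linear projection) are hypothetical stand-ins for `fill_memory_bank` and the trained model in eval mode.

```python
import numpy as np

def fill_bank(batches, embed):
    # batches: iterable of (images, targets); embed: images -> [B, dim] features
    feats, labels = [], []
    for images, targets in batches:
        out = embed(images)
        # The real model already L2-normalizes; done here because the toy embed does not
        out = out / np.linalg.norm(out, axis=1, keepdims=True)
        feats.append(out)
        labels.append(targets)
    return np.concatenate(feats), np.concatenate(labels)

rng = np.random.default_rng(0)
W = rng.normal(size=(12, 4))  # hypothetical projection weights (3*2*2 inputs -> 4 dims)

def embed(x):
    return x.reshape(len(x), -1) @ W

batches = [(rng.normal(size=(5, 3, 2, 2)), rng.integers(0, 10, 5)) for _ in range(3)]
features, targets = fill_bank(batches, embed)
```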

6.4 Evaluate

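Take the dot product between the validation image features $F_{\mathrm{val}}$ and the features of all training images $F_{\mathrm{train}}$, and keep the 100 largest values. Accumulate these neighbors by the class index of their training labels, weighting each neighbor by $\exp(\text{similarity}/\tau)$. The class with the highest score becomes the predicted class of $P_{\mathrm{val}}$. Finally, compare the prediction with the true label of $P_{\mathrm{val}}$ to compute the accuracy.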

top1 = contrastive_evaluate(val_dataloader, model, memory_bank_base)

# utils/evaluate_utils.py
@torch.no_grad()
def contrastive_evaluate(val_loader, model, memory_bank):
    top1 = AverageMeter('Acc@1', ':6.2f')
    model.eval()

    for batch in val_loader:
        images = batch['image'].cuda(non_blocking=True)
        target = batch['target'].cuda(non_blocking=True)

        output = model(images)
        output = memory_bank.weighted_knn(output) 

        acc1 = 100*torch.mean(torch.eq(output, target).float())
        top1.update(acc1.item(), images.size(0))

    return top1.avg

class MemoryBank(object):
    def weighted_knn(self, predictions):
        # perform weighted knn
        retrieval_one_hot = torch.zeros(self.K, self.C).to(self.device)  # [100,10]
        batchSize = predictions.shape[0]
        correlation = torch.matmul(predictions, self.features.t())  # [128,128] x [50000,128].T -> [128,50000]
        yd, yi = correlation.topk(self.K, dim=1, largest=True, sorted=True)  # [128,100]: the 100 largest dot products
        candidates = self.targets.view(1,-1).expand(batchSize, -1)  # [128,50000]
        retrieval = torch.gather(candidates, 1, yi)  # [128,100]; torch.gather(input, dim, index)
        retrieval_one_hot.resize_(batchSize * self.K, self.C).zero_()  # [12800,10]
        retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1)  # [12800,10]; scatter_(dim, index, value)
        yd_transform = yd.clone().div_(self.temperature).exp_()
        probs = torch.sum(torch.mul(retrieval_one_hot.view(batchSize, -1 , self.C), 
                          yd_transform.view(batchSize, -1, 1)), 1)  # [128,100,10] * [128,100,1], summed -> [128,10]
        _, class_preds = probs.sort(1, True)
        class_pred = class_preds[:, 0]
        # Dot product with the training set, keep the 100 largest, and tally their labels
        # weighted by exp(sim / temperature); the top class is the prediction for the val image
        return class_pred
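The weighted kNN vote can be written in a few numpy lines. This sketch follows the same steps as `MemoryBank.weighted_knn`: compute similarities, take the top K, weight each by `exp(sim / temperature)`, sum per class, and take the argmax. `weighted_knn_np` is a hypothetical name, and the 2-D features and `K=3` are toy values chosen for illustration.

```python
import numpy as np

def weighted_knn_np(queries, bank_feats, bank_targets, num_classes, K=3, temperature=0.1):
    sims = queries @ bank_feats.T                          # [Q, N] cosine similarities
    topk_idx = np.argsort(-sims, axis=1)[:, :K]            # K most similar bank rows
    topk_sim = np.take_along_axis(sims, topk_idx, axis=1)  # [Q, K]
    topk_lbl = bank_targets[topk_idx]                      # [Q, K] neighbour labels
    weights = np.exp(topk_sim / temperature)               # [Q, K]
    probs = np.zeros((len(queries), num_classes))
    for q in range(len(queries)):
        np.add.at(probs[q], topk_lbl[q], weights[q])       # accumulate weights per class
    return probs.argmax(axis=1)

bank = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
bank /= np.linalg.norm(bank, axis=1, keepdims=True)
labels = np.array([0, 0, 1, 1])
queries = np.array([[1.0, 0.05], [0.05, 1.0]])
queries /= np.linalg.norm(queries, axis=1, keepdims=True)
pred = weighted_knn_np(queries, bank, labels, num_classes=2, K=3)
```

With a temperature of 0.1, a neighbor with similarity 0.99 gets weight of about e^9.9, while one with similarity 0.16 gets about e^1.6. The closest neighbors therefore dominate the vote.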

7. Saving the model and the top-k neighbors

# Save final model
torch.save(model.state_dict(), p['pretext_model'])

# Mine the topk nearest neighbors at the very end (Train) 
# These will be served as input to the SCAN loss.
print(colored('Fill memory bank for mining the nearest neighbors (train) ...', 'blue'))
fill_memory_bank(base_dataloader, model, memory_bank_base)
topk = 20
print('Mine the nearest neighbors (Top-%d)' %(topk)) 
indices, acc = memory_bank_base.mine_nearest_neighbors(topk)
print('Accuracy of top-%d nearest neighbors on train set is %.2f' %(topk, 100*acc))
np.save(p['topk_neighbors_train_path'], indices)   

# Mine the topk nearest neighbors at the very end (Val)
# These will be used for validation.
print(colored('Fill memory bank for mining the nearest neighbors (val) ...', 'blue'))
fill_memory_bank(val_dataloader, model, memory_bank_val)
topk = 5
print('Mine the nearest neighbors (Top-%d)' %(topk)) 
indices, acc = memory_bank_val.mine_nearest_neighbors(topk)
print('Accuracy of top-%d nearest neighbors on val set is %.2f' %(topk, 100*acc))
np.save(p['topk_neighbors_val_path'], indices)   
class MemoryBank(object):
    def mine_nearest_neighbors(self, topk, calculate_accuracy=True):
        # mine the topk nearest neighbors for every sample
        import faiss
        features = self.features.cpu().numpy()
        n, dim = features.shape[0], features.shape[1]
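        index = faiss.IndexFlatIP(dim)  # inner product; for L2-normalized vectors this is cosine similarity (larger is more similar)
        index = faiss.index_cpu_to_all_gpus(index)
        index.add(features)  # add the samples stored in the memory bank
        # indices holds the indices of the most similar vectors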
        distances, indices = index.search(features, topk+1) # Sample itself is included
        
        # evaluate 
        if calculate_accuracy:
            targets = self.targets.cpu().numpy()
            neighbor_targets = np.take(targets, indices[:,1:], axis=0) # Exclude sample itself for eval
            anchor_targets = np.repeat(targets.reshape(-1,1), topk, axis=1)
            accuracy = np.mean(neighbor_targets == anchor_targets)
            return indices, accuracy
        
        else:
            return indices
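If faiss is unavailable (for example on a Windows machine, given the Linux requirement noted in the preface), an exact inner-product search can be brute-forced with numpy for small sets. This is a substitute technique, not what the repository uses, and `mine_nn_numpy` is a hypothetical name. The full `[N, N]` similarity matrix for the CIFAR-10 training set (50000 × 50000) needs about 10 GB even in float32, so in practice it would have to be computed in chunks.

```python
import numpy as np

def mine_nn_numpy(features, targets, topk):
    sims = features @ features.T                       # inner product, like IndexFlatIP
    indices = np.argsort(-sims, axis=1)[:, :topk + 1]  # column 0 is normally the sample itself
    neighbor_targets = targets[indices[:, 1:]]         # exclude the sample itself for eval
    anchor_targets = np.repeat(targets.reshape(-1, 1), topk, axis=1)
    accuracy = np.mean(neighbor_targets == anchor_targets)
    return indices, accuracy

feats = np.array([[1.0, 0.0], [0.98, 0.2], [0.0, 1.0], [0.2, 0.98]])
feats /= np.linalg.norm(feats, axis=1, keepdims=True)
labels = np.array([0, 0, 1, 1])
indices, acc = mine_nn_numpy(feats, labels, topk=1)
```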
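Disclaimer: This article was contributed by a user and does not represent the position of the wpsshop blog; copyright belongs to the original author, and this site assumes no legal liability. If you find infringing content, please contact us. When reposting, please cite the source: https://www.wpsshop.cn/w/凡人多烦事01/article/detail/116524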