当前位置:   article > 正文

基于DANN的图像分类任务迁移学习_dann图像分类

dann图像分类

注:本博客的数据和任务来自NTU-ML2020作业,Kaggle网址为Kaggle.

数据预处理

我们要进行迁移学习的对象是10000张32x32x3的有标签正常照片,共有10类,和另外100000张人类画的手绘图,28x28x1黑白照片,类别也是10类但无标签。我们希望做到,让模型从有标签的原始分布数据中学到的知识能应用于无标签的,相似但与原始分布不相同的目标分布中,并提高黑白手绘图的正确率。
为此,训练前还要对数据做预处理。首先让原始分布的图像和目标分布的图像尽可能相似,我们要做有色图转灰度图,然后做边缘检测。为了模型的输入维度相同,要把28x28转为32x32.此外还可以增加一些平移旋转来让学习更鲁棒。

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import cv2
import matplotlib.pyplot as plt

# 在transform中使用转灰度-canny边缘提取-水平移动-小幅度旋转-转张量操作

source_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Lambda(lambda x: cv2.Canny(np.array(x), 170, 300)),
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15, fill=(0,)),
    transforms.ToTensor(),
])
target_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((32, 32)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15, fill=(0,)),
    transforms.ToTensor(),
])

# 读取数据集,分为source和target两部分

source_dataset = ImageFolder('E:/real_or_drawing/train_data', transform=source_transform)
target_dataset = ImageFolder('E:/real_or_drawing/test_data', transform=target_transform)

source_dataloader = DataLoader(source_dataset, batch_size=32, shuffle=True)
target_dataloader = DataLoader(target_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(target_dataset, batch_size=128, shuffle=False)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38

DANN

Domain-Adversarial Training of NNs,值域对抗学习。这种算法是我们这里将要用的迁移学习方法,它被提出的起因是让CNN能够同时用于不同分布的数据,如果模型直接接收原值域的数据分布进行训练,即使原分布和目标分布有类似的地方,在接收目标值域的数据时,也会出现相当异常的特征提取和分类结果。我们可以理解为是模型在源数据分布上出现了过拟合(并不是对数据的过拟合),在接收一些没有见到过的数据时自然会表现不佳。
在这里插入图片描述
解决这个问题最好的办法就是让模型在训练时也接收目标数据分布的数据。但是目标数据分布是无标签的,我们要用什么标准来训练模型呢?回忆CNN的架构,CNN使用卷积-池化的特征提取层来提取图片特征,后接全连接层进行预测。我们只需要让特征提取层既能提取原数据分布的特征,又能提取目标数据分布的特征,这样全连接层就能对两种值域但具有相同特征的数据进行同样的分类,从而目标数据分布的输入也很有可能被正确分类。
在这里插入图片描述
那么问题就变成了如何训练输入两个不同分布的数据,输出却是同种分布的特征提取层。回忆GAN的架构,我们让分布朝着源数据分布发展的方法是建立判别器,让判别器能分辨两种数据,而让生成器改变参数骗过判别器。这里也可以用同样的思想,我们建立能分辨原始分布和目标分布的二分类判别器,把特征提取层和二分类判别层接在一起。首先训练判别器,让判别器能分辨两类数据分布。然后训练特征提取层,逆梯度更新让特征提取层生成能骗过判别器的数据(目标输出0.5).如此训练多次直到特征提取层能把两种值域的输入变成同种分布的输出。
在这里插入图片描述
但是只是用GAN方法train特征提取层并不明智,因为我们的目标输出只有0-1的二分类,训练很有可能只是让特征提取层提取到一些没有用的特征。因此我们要一边训练正常的标签预测任务,一边训练判别器的判别任务和混淆两类输入的任务。这可能需要自己定义特殊的loss function

最后,我们就获得了能同时提取两个值域的特征的特征提取层,它后

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/532129
推荐阅读
相关标签
  

闽ICP备14008679号