当前位置:   article > 正文

VGG网络在CIFAR_10和GID数据集上的Pytorch实现_gaofen image dataset

gaofen image dataset

一、VGG简介

在这里插入图片描述

\quad \quad VGGNet由牛津大学的视觉几何组(Visual Geometry Group)提出,它的主要贡献是使用非常小的**( 3 × 3 3×3 3×3)卷积滤波器架构对网络深度的增加进行了全面评估,这表明通过将深度推到16-19加权层可以实现对现有技术配置的显著改进**。这些发现让VGG团队在ILSVRC-2014的**定位任务(localisation)**取得第一、**分类任务(classification)**取得第二(第一名是GoogLeNet)。并且作者还表明其提出的ConvNet对于其他数据集泛化的很好,在其它数据集上也取得了最好的结果。

\quad \quad 关于VGG网络详解见我另一篇博文:https://blog.csdn.net/Bobodareng/article/details/117599525

二、VGG-16在CIFAR_10数据集上的实现

2.1 CIFAR_10数据集简介

\quad \quad CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。该数据集共有60000张32*32彩色图像,一共包含 10 个类别的 RGB 彩色图 片,每类6000张:

飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )

卡车( truck )

\quad \quad 图片的尺寸为 32×32 ,数据集中一共有 50000 张训练图片和 10000 张测试图片。 CIFAR-10 的图片样例如下图所示:

在这里插入图片描述

\quad \quad 上图列举了CIFAR_10中的10种类别,每一类随机展示了10张图片。

数据集下载
​ 官方下载地址:(很慢)
(共有三个版本:python,matlab,binary version 适用于C语言),我在训练时直接利用代码从Pytorch.datasets()中下载。
http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

​ http://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz

http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz

2.2 加载数据集(Dataset)

\quad \quad 我们可以手动下载CIFAR_10数据集,也可以用Pytorch的torchvision.datasets模块加载一些经典的数据集,比如:Imagenet, CIFAR10, MNIST都可以通过torchvision来获取,并且torchvision还提供了transforms类可以用来预处理数据。

   import torchvision
    ...
    #利用 torchvision.datasets来下载CIFAR_10数据集到根目录的data文件夹中
    # 50000张训练图片
    # 第一次使用时要将download设置为True才会自动去下载数据集
    train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                             download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=128,
                                               shuffle=True, num_workers=0)
     # 10000张验证图片
    第一次使用时要将download设置为True才会自动去下载数据集
    val_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=128,
                                             shuffle=True, num_workers=0)
    ...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

下载后结果如下图:

在这里插入图片描述

(1)数据集划分
\quad \quad 一般来讲数据集可分为训练集(train)验证集(val)测试集(test),训练集用于训练,验证集用于验证训练期间的模型精度,测试集用于测试最终模型的表现。通常利用验证集可用来设计一些交叉验证方法,在数据量较少的情况下能够提高模型的鲁棒性,在我的这次任务中只将数据集分为trainval,只是为了观察模型的训练过程,训练完成后测试单张图片。

(2)数据预处理。
\quad \quad 常用数据预处理方法可概述为2类,数据标准化(Normalize)处理和数据增广(Augmentation)。最常用的数据标准化处理就是数据的归一化,原数据可能数据量很大,维数很多,计算机处理起来时间复杂度很高,预处理可以降低数据维度。同时,把数据都规范到(0-1),这样使得它们对模型的影响具有同样的尺度。
\quad \quad 我通过torchvision.transforms中的各类数据处理函数对图像数据进行预处理:
\quad \quad 关于torchvision.transforms的各种算法实现,见我的另一篇博文:
\quad \quad https://blog.csdn.net/Bobodareng/article/details/117597673

import torchvision.transforms as transforms
#数据预处理
transform=transforms.Compose([
        transforms.RandomCrop(32, padding=4), #上下左右填充4个像素后随机裁剪,由于CIFAR_10数据图片均为32×32,故将随机裁剪后的尺度也设为32
        transforms.RandomHorizontalFlip(),#随机水平翻转处理
        transforms.ToTensor(),  #转换为张量Tensor
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])#归一化处理
        ...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

(3)随机可视化数据集

​ 我们通过imshow()函数对加载的验证集(val)数据进行随机可视化,输出随机的8张图片和对应标签。

 import matplotlib.pyplot as plt
 import numpy as np
 from torchvision import datasets, transforms,utils
 
    def main():
    ...
    classes = ('airplane', 'automobile', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    # 随机显示val图片
    def imshow(img):
        img = img / 2 + 0.5  # unnormalize,逆标准化还原原始图像
        npimg = img.numpy()  #将图像转化为numpy.andgrry格式
        plt.imshow(np.transpose(npimg, (1, 2, 0)))   
        plt.show()
    #输出图像对应的标签
    print(' '.join('%5s' % classes[val_label[j].item()] for j in range(8)))
    imshow(utils.make_grid(val_image))
    #...
    
    if __name__ == '__main__':
       main()
    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

输出结果如下图所示:
在这里插入图片描述

2.3 模型搭建(Model)

\quad \quad VGG16包含了16个隐藏层(13个卷积层和3个全连接层),在VGG-16的基础上搭建模型,特征提取层沿用原网络结构(13个卷积层),由于我使用的CIFAR_10数据集的图片是3x32x32的图片,所以这里面有一些通道是和3x224x224图片是不一样,尺寸小了7x7倍,需要进行调整,比如在后面全连接层,我的是256的输出通道,最后是10个类,因为VGG网络参数量在全连接层最多,这样可以减少参数,而且参考过的几篇文章说即便去掉几个全连接层性能也不会受太大影像。

\quad \quad 模型完整代码(model.py):

#!/usr/bin/env python 
# -*- coding:utf-8 -*-
import torch.nn as nn
import torch

class VGG(nn.Module):
    def __init__(self, features, num_classes=10, init_weights=False):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512*1*1, 256),           #针对CIFAR_10 的input 32*32进行了修改
            nn.ReLU(True),                     #激活函数ReLU
            nn.Dropout(p=0.5),                 #随机失活,神经元失活率为50%
            nn.Linear(256, 256),               #为了减少参数,输出通道设为256
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(256, num_classes)        #num_classes=10
        )
        if init_weights:
            self._initialize_weights()
        
    def forward(self, x):
        # N x 3 x 32 x 32
        x = self.features(x)
        # N x 512 x 1 x 1
        x = torch.flatten(x, start_dim=1)      #将卷积层输出结果展开为一维向量
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


    def make_features(cfg: list): #自定义特征提取块,通过for循环搭建vgg16的13个卷积层(原始结构不做改动)
       layers = []
       in_channels = 3            # 原始图片输入channels=3
       for v in cfg:              # 利用传入参数列表cfg搭建
           if v == "M":           # M代表最大池化层,卷积核2*2,步长为2
               layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
           else:                  #卷积层,卷积核3*3,上下左右均填充一个为单位像素
               conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
               layers += [conv2d, nn.ReLU(inplace=True)]
               in_channels = v
      return nn.Sequential(*layers)


    cfgs = {                              #设置参数字典方便变换模型,可以实验vgg19和其他vgg结构网络
        'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
        'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
   }


    def vgg(model_name="vgg16", **kwargs):# 断言输入的模型名称在设置好的字典cfgs中,否则弹出警告
        assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)
        cfg = cfgs[model_name]            # 取模型对应的参数列表,作为自定义特侦提取块make_feature的参数

        model = VGG(make_features(cfg), **kwargs)
        return model
        
    if __name__ == '__main__':
        net = vgg("vgg16")   #打印出修改后的vgg16网络结构
        print(net)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71

\quad \quad 模型搭建完成后保存为model.py文件,点击运行打印出网络结构:
在这里插入图片描述

VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace=True)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace=True)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace=True)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(classifier): Sequential(
(0): Linear(in_features=512, out_features=256, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=256, out_features=256, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=256, out_features=10, bias=True)
)
)

Process finished with exit code 0

\quad \quad 说明模型是正确的,接下来准备进行训练。

2.4 训练(Train)

\quad \quad 训练设置 train_loader 的batch_size=200,val_loader的batch_size=50,每个epoch有50000/200=250个train_iteration,10000/50=200个val_iteration,总共训练50个epoch。对数据进行预处理,选择GPU进行运算,实例化之前搭建的vgg模型, 选用交叉熵损失函数(CrossEntropyLoss),优化器选用Adam,lr=0.0002。

\quad \quad Adam优化器主要包含以下几个显著的优点::

\quad \quad 1. 实现简单,计算高效,对内存需求少

\quad \quad 2. 参数的更新不受梯度的伸缩变换影响

\quad \quad 3. 超参数具有很好的解释性,且通常无需调整或仅需很少的微调

\quad \quad 4. 更新的步长能够被限制在大致的范围内(初始学习率)

\quad \quad 5. 能自然地实现步长退火过程(自动调整学习率)

\quad \quad 6. 很适合应用于大规模的数据及参数的场景

\quad \quad 7. 适用于不稳定目标函数

\quad \quad 8. 适用于梯度稀疏或梯度存在很大噪声的问题

\quad \quad 综合Adam在很多情况下算作默认工作性能比较优秀的优化器。详解见:简单认识Adam优化器 - 简书 (jianshu.com)

\quad \quad 完整的训练代码 (cifar_train.py):

#!/usr/bin/env python 
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
from model import vgg                       
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms,utils
import matplotlib.pyplot as plt
import numpy as np
import hiddenlayer as hl                                                   
import json

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 设置运行设备,如果GPU可用使用GPU,否则使用CPU
    print("using {} device.".format(device))                               # 输出训练的设备名称

    # 数据预处理
    transform=transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
                                                                              
    # 加载数据集,50000张训练图片                                                                          
    train_set = torchvision.datasets.CIFAR10(root='./data', train=True,   # 第一次使用时要将download设置为True才会自动去下载数据集
                                             download=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=200,
                                               shuffle=True, num_workers=0)

    cifar10_list = train_set.class_to_idx
    cla_dict = dict((val, key) for key, val in cifar10_list.items())
    json_str = json.dumps(cla_dict, indent=9)                            # 将字典写入 json file
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
    '''   
     {"0": "airplane","1": "automobile", "2": "bird", "3": "cat","4": "deer",
     "5": "dog","6": "frog","7": "horse","8": "ship","9": "truck"}
    '''

    # 10000张验证图片
    val_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=False, transform=transform)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=50,
                                             shuffle=True, num_workers=0)
    val_num = len(val_set)
                                             
    # 注释段用来随机可视化数据集                                         
    # val_data_iter = iter(val_loader)
    # val_image, val_label = val_data_iter.next()

    # classes = ('airplane', 'automobile', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # 用来随机可视化数据集的函数
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # print(' '.join('%5s' % classes[val_label[j].item()] for j in range(8)))
    # imshow(utils.make_grid(val_image))
    
    model_name = "vgg16"                                                # Net实例化
    net = vgg(model_name=model_name, num_classes=10, init_weights=True) #没有使用预训练模型参数,初始化权重从头开始训练
    net.to(device)
    loss_function = nn.CrossEntropyLoss()                               # 选用交叉熵损失函数
    optimizer = optim.Adam(net.parameters(), lr=0.0002)                 # 选用Adam优化器
    history1 = hl.History()                                             # 用history记录训练过程指标
    canvas1 = hl.Canvas()                                               # 用Canvas绘制曲线图
    
    epochs=50                                                           # 由于第一次训练较深的神经网络,epoch采用50次
    train_steps=len(train_loader)
    save_path = './vgg16.pth' 
    best_acc = 0.0
    #train   
    for epoch in range(epochs):
        net.train()
        running_loss = 0.0                                              # 每个epoch的loss归零,重新累加
        train_bar = tqdm(train_loader)                                  # Tqdm 是一个快速可扩展的Python进度条,显示每个epoch训练进度
        for step, data in enumerate(train_bar):
            images, labels = data                                       # 获得输入信息[inputs,lables]
            optimizer.zero_grad()                                       # 梯度清零,避免梯度累加
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))            # 计算每个iteration的loss
            loss.backward()                                             # loss后向传播
            optimizer.step()                                            # 优化参数

            running_loss += loss.item()                                 # 每个epoch的loss累加,用于输出本次epoch平均loss
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,  # 进度条信息epoch[n/50]
                                                                     epochs,
                                                                     loss )
                                                                        
        net.eval()
        acc = 0.0                                                      
        with torch.no_grad():                                           # 计算精度过程不求梯度
            val_bar = tqdm(val_loader)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1                 # 计算验证精度 correct_number
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num                                    # 预测正确率(correct_number/val_num(验证集样本数))
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))    # 打印该次epoch的train_loss和val_acc
                                                                        # train_loss=running_loss / train_steps(itreation数)
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)                     # 保存训练好的模型到根目录
        history1.log((epoch, step), train_loss=running_loss / train_steps, val_acc=best_acc)

    print('Finished Training')
    with canvas1:
        canvas1.draw_plot(history1['train_loss'])                       # canvas绘图
        canvas1.draw_plot(history1['val_acc'])
                                              

if __name__ == '__main__':
    main()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
'''
            if step % 50 == 49:    # print every 500 mini-batches
                with torch.no_grad():
                    outputs = net(val_image.to(device))  # [batch, 10]
                    predict_y = torch.max(outputs, dim=1)[1]
                    accuracy = torch.eq(predict_y, val_label.to(device)).sum().item() / val_label.size(0)
              
                    print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %
                          (epoch + 1, step + 1, running_loss / 50, accuracy))
                    running_loss = 0.0
 '''
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

\quad \quad 开始训练,50个epoch大概持续了不到 1h,电脑也是跑的发烫,输出训练过程如下:

C:\Users\wxyz1\anaconda3\envs\pytorch\python.exe “F:/Deep learning (compute vision)/Code master/deep-learning-for-image-processing-master/pytorch_classification/vggnet/cifar_train.py”
using cuda:0 device.
train epoch[1/50] loss:1.857: 100%|██████████| 250/250 [00:56<00:00, 4.44it/s]
100%|██████████| 200/200 [00:06<00:00, 30.05it/s]
[epoch 1] train_loss: 2.058 val_accuracy: 0.319
train epoch[2/50] loss:1.555: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.22it/s]
[epoch 2] train_loss: 1.684 val_accuracy: 0.414
train epoch[3/50] loss:1.456: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.30it/s]
[epoch 3] train_loss: 1.489 val_accuracy: 0.494
train epoch[4/50] loss:1.106: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.08it/s]
[epoch 4] train_loss: 1.296 val_accuracy: 0.599
train epoch[5/50] loss:1.179: 100%|██████████| 250/250 [00:55<00:00, 4.53it/s]
100%|██████████| 200/200 [00:06<00:00, 30.22it/s]
[epoch 5] train_loss: 1.133 val_accuracy: 0.626
train epoch[6/50] loss:0.854: 100%|██████████| 250/250 [00:55<00:00, 4.51it/s]
100%|██████████| 200/200 [00:06<00:00, 30.09it/s]
[epoch 6] train_loss: 0.985 val_accuracy: 0.662
train epoch[7/50] loss:0.843: 100%|██████████| 250/250 [00:55<00:00, 4.51it/s]
100%|██████████| 200/200 [00:06<00:00, 30.11it/s]
[epoch 7] train_loss: 0.888 val_accuracy: 0.711
train epoch[8/50] loss:0.841: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.07it/s]
[epoch 8] train_loss: 0.802 val_accuracy: 0.740
train epoch[9/50] loss:0.776: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.05it/s]
[epoch 9] train_loss: 0.739 val_accuracy: 0.748
train epoch[10/50] loss:0.805: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.22it/s]
[epoch 10] train_loss: 0.689 val_accuracy: 0.738
train epoch[11/50] loss:0.521: 100%|██████████| 250/250 [00:55<00:00, 4.53it/s]
100%|██████████| 200/200 [00:06<00:00, 30.17it/s]
[epoch 11] train_loss: 0.639 val_accuracy: 0.778
train epoch[12/50] loss:0.687: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.18it/s]
[epoch 12] train_loss: 0.601 val_accuracy: 0.790
train epoch[13/50] loss:0.521: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 29.89it/s]
[epoch 13] train_loss: 0.561 val_accuracy: 0.787
train epoch[14/50] loss:0.620: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.15it/s]
[epoch 14] train_loss: 0.533 val_accuracy: 0.788
train epoch[15/50] loss:0.590: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.14it/s]
[epoch 15] train_loss: 0.495 val_accuracy: 0.812
train epoch[16/50] loss:0.431: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.15it/s]
[epoch 16] train_loss: 0.467 val_accuracy: 0.813
train epoch[17/50] loss:0.463: 100%|██████████| 250/250 [00:55<00:00, 4.53it/s]
100%|██████████| 200/200 [00:06<00:00, 30.23it/s]
[epoch 17] train_loss: 0.450 val_accuracy: 0.819
train epoch[18/50] loss:0.349: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.23it/s]
[epoch 18] train_loss: 0.421 val_accuracy: 0.827
train epoch[19/50] loss:0.346: 100%|██████████| 250/250 [00:55<00:00, 4.53it/s]
100%|██████████| 200/200 [00:06<00:00, 30.20it/s]
[epoch 19] train_loss: 0.402 val_accuracy: 0.827
train epoch[20/50] loss:0.492: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.25it/s]
[epoch 20] train_loss: 0.383 val_accuracy: 0.825
train epoch[21/50] loss:0.536: 100%|██████████| 250/250 [00:55<00:00, 4.49it/s]
100%|██████████| 200/200 [00:06<00:00, 30.03it/s]
[epoch 21] train_loss: 0.355 val_accuracy: 0.833
train epoch[22/50] loss:0.439: 100%|██████████| 250/250 [00:55<00:00, 4.50it/s]
100%|██████████| 200/200 [00:06<00:00, 30.17it/s]
[epoch 22] train_loss: 0.350 val_accuracy: 0.835
train epoch[23/50] loss:0.301: 100%|██████████| 250/250 [00:55<00:00, 4.51it/s]
100%|██████████| 200/200 [00:06<00:00, 30.07it/s]
[epoch 23] train_loss: 0.332 val_accuracy: 0.840
train epoch[24/50] loss:0.407: 100%|██████████| 250/250 [00:55<00:00, 4.51it/s]
100%|██████████| 200/200 [00:06<00:00, 29.96it/s]
[epoch 24] train_loss: 0.320 val_accuracy: 0.836
train epoch[25/50] loss:0.344: 100%|██████████| 250/250 [00:55<00:00, 4.51it/s]
100%|██████████| 200/200 [00:06<00:00, 29.81it/s]
[epoch 25] train_loss: 0.314 val_accuracy: 0.834
train epoch[26/50] loss:0.264: 100%|██████████| 250/250 [00:55<00:00, 4.51it/s]
100%|██████████| 200/200 [00:06<00:00, 29.92it/s]
[epoch 26] train_loss: 0.284 val_accuracy: 0.836
train epoch[27/50] loss:0.346: 100%|██████████| 250/250 [00:55<00:00, 4.49it/s]
100%|██████████| 200/200 [00:06<00:00, 30.00it/s]
[epoch 27] train_loss: 0.270 val_accuracy: 0.843
train epoch[28/50] loss:0.320: 100%|██████████| 250/250 [00:55<00:00, 4.51it/s]
100%|██████████| 200/200 [00:06<00:00, 30.14it/s]
[epoch 28] train_loss: 0.263 val_accuracy: 0.852
train epoch[29/50] loss:0.272: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.20it/s]
[epoch 29] train_loss: 0.245 val_accuracy: 0.852
train epoch[30/50] loss:0.201: 100%|██████████| 250/250 [00:55<00:00, 4.53it/s]
100%|██████████| 200/200 [00:06<00:00, 30.19it/s]
[epoch 30] train_loss: 0.238 val_accuracy: 0.853
train epoch[31/50] loss:0.191: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.19it/s]
[epoch 31] train_loss: 0.232 val_accuracy: 0.854
train epoch[32/50] loss:0.212: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.25it/s]
[epoch 32] train_loss: 0.225 val_accuracy: 0.843
train epoch[33/50] loss:0.324: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.16it/s]
[epoch 33] train_loss: 0.214 val_accuracy: 0.848
train epoch[34/50] loss:0.224: 100%|██████████| 250/250 [00:55<00:00, 4.53it/s]
100%|██████████| 200/200 [00:06<00:00, 30.16it/s]
[epoch 34] train_loss: 0.204 val_accuracy: 0.846
train epoch[35/50] loss:0.251: 100%|██████████| 250/250 [00:55<00:00, 4.53it/s]
100%|██████████| 200/200 [00:06<00:00, 30.18it/s]
[epoch 35] train_loss: 0.193 val_accuracy: 0.845
train epoch[36/50] loss:0.258: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.18it/s]
[epoch 36] train_loss: 0.195 val_accuracy: 0.852
train epoch[37/50] loss:0.120: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.21it/s]
[epoch 37] train_loss: 0.176 val_accuracy: 0.851
train epoch[38/50] loss:0.103: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.25it/s]
[epoch 38] train_loss: 0.181 val_accuracy: 0.854
train epoch[39/50] loss:0.237: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.14it/s]
[epoch 39] train_loss: 0.172 val_accuracy: 0.849
train epoch[40/50] loss:0.249: 100%|██████████| 250/250 [00:55<00:00, 4.53it/s]
100%|██████████| 200/200 [00:06<00:00, 30.09it/s]
[epoch 40] train_loss: 0.163 val_accuracy: 0.852
train epoch[41/50] loss:0.167: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.09it/s]
[epoch 41] train_loss: 0.159 val_accuracy: 0.857
train epoch[42/50] loss:0.237: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.25it/s]
[epoch 42] train_loss: 0.149 val_accuracy: 0.852
train epoch[43/50] loss:0.137: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.17it/s]
[epoch 43] train_loss: 0.152 val_accuracy: 0.856
train epoch[44/50] loss:0.157: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 29.91it/s]
[epoch 44] train_loss: 0.145 val_accuracy: 0.860
train epoch[45/50] loss:0.209: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.16it/s]
[epoch 45] train_loss: 0.135 val_accuracy: 0.854
train epoch[46/50] loss:0.080: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.19it/s]
[epoch 46] train_loss: 0.138 val_accuracy: 0.859
train epoch[47/50] loss:0.153: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.13it/s]
[epoch 47] train_loss: 0.129 val_accuracy: 0.855
train epoch[48/50] loss:0.143: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.15it/s]
[epoch 48] train_loss: 0.127 val_accuracy: 0.861
train epoch[49/50] loss:0.139: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.20it/s]
[epoch 49] train_loss: 0.120 val_accuracy: 0.860
train epoch[50/50] loss:0.065: 100%|██████████| 250/250 [00:55<00:00, 4.52it/s]
100%|██████████| 200/200 [00:06<00:00, 30.25it/s]
[epoch 50] train_loss: 0.125 val_accuracy: 0.856
Finished Training

Process finished with exit code 0

\quad \quad 得出train_loss和val_acc曲线图如下:

\quad \quad 可见train_loss逐渐收敛到0.1左右(估计再有20个epoch可以收敛到0.0x),验证精度达到86%左右,图中 x x x轴单位为 epoch:iteration。
\quad \quad 事实上上图是我多次尝试最终的训练结果,刚开始由于尝试对每个epoch的每隔50个iteration 的train_loss 都进行可视化输出,由于单位换算不准确得到下面的错误train_loss图(batch_size均为128,50个epoch, 390个iteration/epoch,每50个iteration打印输出信息),可以看出由于换算错误一开始的loss值就很小(异常),但是val_acc公式不受影响,显示了50个epoch下每50个iteration的val_acc波动情况,最高达到89%,train_loss的曲线形态也依然说明了网络学习的趋势。

\quad \quad 之后又尝试调整,将batch_size改为300,150 iteration/epoch,train_loss曲线值依然没有反映真实值(换算仍不正确),于是想到以epoch为单位对tran_loss进行输出(最终实验表明是正确的,train_loss显示为正常值)省去很多麻烦,由于iteration间隔扩大,val_acc曲线变得平滑,基本趋势没有改变。

\quad \quad 之后想验证batch_size=500,epoch=90对于val_cc收敛精度的影响,跑了2个小时,得出的曲线如下,精度提升不明显,抖动比较明显,大概收敛在87%左右,说明调大batch_size和epoch对val_cc精读提升不大。

\quad \quad 最终选择 train_loader 的batch_size=200,val_acc的batch_size=50,以epoch为单位输出train_loss,总共训练50个epoch,得到顶图曲线作为最终结果。

2.5 预测(Predict)

\quad \quad 在根目录新建文件夹sample存放测试图片,任何尺度均可。将预测的结果图片保存在predict文件夹中。预测程序保存为predict.py:

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import vgg


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # 对输入的测试图片进行一个预处理,裁剪为32×32
    data_transform = transforms.Compose(
        [transforms.Resize((32, 32)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # 加载测试图片
    img_path = "./sample/01.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # 读取lable字典文件
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    json_file = open(json_path, "r")
    class_indict = json.load(json_file)
    
    # 实例化模型
    model = vgg(model_name="vgg16", num_classes=10).to(device)
    # 加载训练好的模型文件
    weights_path = "./vgg16.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    model.eval()
    with torch.no_grad():
        # 预测图片所属类别
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    # 显示并保存预测结果                                             
    plt.title(print_res)
    print(print_res)
    plt.savefig('./predict/01.jpg', bbox_inches=None)
    plt.show()


if __name__ == '__main__':
    main()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62

\quad \quad 预测结果如下:

\quad \quad 通过预测可以看出,模型学习的效果还是不错的,识别正确类的prob都很高,尤其是ship、frog、deer和bird类,但是也有错分,比如horse,说明模型在识别对象和背景对比差异不明显的object时还有待改进。

在这里插入图片描述

\quad \quad 在CIFAR_10验证集上预测结果的混淆矩阵如上图所示,the model accuracy is 0.8546,下表为各类别的预测准确度(Precision )、召回率(Recall)和特异度(Specificity)。

Classes\MericsPrecisionRecallSpecificity
airplane0.8890.8690.988
automobile0.960.9030.996
bird0.880.740.989
cat0.7510.6530.976
deer0.8070.8650.977
dog0.780.8210.974
frog0.8590.9140.983
horse0.8420.930.981
ship0.9240.9180.992
truck0.860.9330.983

三、VGG-16在GID数据集上的实现

3.1 Gaofen Image Dataset(GID)数据集简介

\quad \quad Gaofen Image Dataset(GID)是一个用于土地利用和土地覆盖(LULC)分类的大型数据集。它包含来自中国60多个不同城市的150幅高质量高分二号(GF-2)图像,这些图像覆盖的地理区域超过了5万km²。GID图像具有较高的类内多样性和较低的类间可分离性。GF-2是高清晰度地球观测系统(HDEOS)的第二颗卫星。GF-2卫星包括了空间分辨率为1 m的全色图像和4 m的多光谱图像,图像大小为6908×7300像素。多光谱提供了蓝色、绿色、红色和近红外波段的图像。自2014年启动以来,GF-2已被用于土地调查、环境监测、作物估算、建设规划等重要应用。
在这里插入图片描述

\quad \quad 本次任务选用的是GID中遥感场景分类训练数据集—SecenClass Training Set,其中包含了15个场景类别,每个类别有2000张56×56的影像,总共30K遥感场景影像,训练完全够用。

类别包括:

  • industrial land(工业用地)
    在这里插入图片描述
  • shrub land(灌木地 )
    在这里插入图片描述
  • natural grassland(自然草地)
    在这里插入图片描述
  • artificial grassland(人工草地)
    在这里插入图片描述
  • river(河流)
    在这里插入图片描述
  • lake(湖泊)
    在这里插入图片描述
  • pond(池塘)
    在这里插入图片描述
  • urban residential(城市住宅 )
    在这里插入图片描述
  • rural residential(农村住宅 )
    在这里插入图片描述
  • traffic land(交通用地)
    在这里插入图片描述
  • paddy field(稻田)
    在这里插入图片描述
  • irrigated land(灌溉用地)
    在这里插入图片描述
  • dry cropland(旱地 )
    在这里插入图片描述
  • garden plot(园地 )
    在这里插入图片描述
  • arbor woodland(林地 )
    在这里插入图片描述
    数据集下载地址http://captain.whu.edu.cn/GID/

    相关参考文献:
    【Tong X Y, Xia G S, Lu Q, et al. Learning Transferable Deep Models for Land-Use Classification with High-Resolution Remote Sensing Images[J]. arXiv preprint arXiv:1807.05713, 2018.】

(1)数据集划分
\quad \quad 利用数据集分割程序split_data.py将SecenClass Training Set划分为train(27000张)和val(3000张),观察模型的训练过程,训练完成后测试单张图片。划分好的data文件如下图:

在这里插入图片描述

完整的split_data.py:

#!/usr/bin/env python 
# -*- coding:utf-8 -*-
import os
from shutil import copy, rmtree
import random

def mk_file(file_path: str):
    if os.path.exists(file_path):
        # 如果文件夹存在,则先删除原文件夹在重新创建
        rmtree(file_path)
    os.makedirs(file_path)

def main():
    # 保证随机可复现
    random.seed(0)

    # 将数据集中10%的数据划分到验证集中
    split_rate = 0.1

    # 指向data_set文件夹,data文件夹是其子文件夹
    cwd = os.getcwd()
    data_root = os.path.join(cwd, "data")
    origin_flower_path = os.path.join(data_root, "data_set")
    assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path)

    flower_class = [cla for cla in os.listdir(origin_flower_path)
                    if os.path.isdir(os.path.join(origin_flower_path, cla))]

    # 建立保存训练集的文件夹
    train_root = os.path.join(data_root, "train")
    mk_file(train_root)
    for cla in flower_class:
        # 建立每个类别对应的文件夹
        mk_file(os.path.join(train_root, cla))

    # 建立保存验证集的文件夹
    val_root = os.path.join(data_root, "val")
    mk_file(val_root)
    for cla in flower_class:
        # 建立每个类别对应的文件夹
        mk_file(os.path.join(val_root, cla))

    for cla in flower_class:
        cla_path = os.path.join(origin_flower_path, cla)
        images = os.listdir(cla_path)
        num = len(images)
        # 随机采样验证集的索引
        eval_index = random.sample(images, k=int(num*split_rate))
        for index, image in enumerate(images):
            if image in eval_index:
                # 将分配至验证集中的文件复制到相应目录
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(val_root, cla)
                copy(image_path, new_path)
            else:
                # 将分配至训练集中的文件复制到相应目录
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(train_root, cla)
                copy(image_path, new_path)
            print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
        print()

    print("processing done!")


if __name__ == '__main__':
    main()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68

(2)数据预处理。
\quad \quad 和前面一样,通过torchvision.transforms中的各类数据处理函数对图像数据进行预处理:

 ...
 data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(56),   #随机裁剪的尺寸依然为56
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((56, 56)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    }
 ...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

(3)随机可视化数据集

\quad \quad 我们通过imshow()函数对加载的验证集(val)数据进行随机可视化,输出随机的8张图片和对应标签如下:

在这里插入图片描述

3.1 模型搭建(Model)

\quad \quad 和之前在CIFAR_10数据集上应用的模型相同,不做改变,只是稍作调整,将num_classes变为15。搭建完保存为model.py

#!/usr/bin/env python 
# -*- coding:utf-8 -*-
import torch.nn as nn
import torch

class VGG(nn.Module):
    def __init__(self, features, num_classes=15, init_weights=False):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Linear(512*1*1, 256),           #针对GID 的input 56*56进行了修改
            nn.ReLU(True),                     #激活函数ReLU
            nn.Dropout(p=0.5),                 #随机失活,神经元失活率为50%
            nn.Linear(256, 256),               #为了减少参数,输出通道设为256
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(256, num_classes)        #num_classes=15
        )
        if init_weights:
            self._initialize_weights()
        
    def forward(self, x):
        # N x 3 x 56 x 56
        x = self.features(x)
        # N x 512 x 1 x 1
        x = torch.flatten(x, start_dim=1)      #将卷积层输出结果展开为一维向量
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


    def make_features(cfg: list): #自定义特征提取块,通过for循环搭建vgg16的13个卷积层(原始结构不做改动)
       layers = []
       in_channels = 3            # 原始图片输入channels=3
       for v in cfg:              # 利用传入参数列表cfg搭建
           if v == "M":           # M代表最大池化层,卷积核2*2,步长为2
               layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
           else:                  #卷积层,卷积核3*3,上下左右均填充一个为单位像素
               conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
               layers += [conv2d, nn.ReLU(inplace=True)]
               in_channels = v
      return nn.Sequential(*layers)


    cfgs = {                      #设置参数字典方便变换模型,可以实验vgg19和其他vgg结构网络
        'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
        'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
   }


    def vgg(model_name="vgg16", **kwargs):# 断言输入的模型名称在设置好的字典cfgs中,否则弹出警告
        assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)
        cfg = cfgs[model_name]            # 取模型对应的参数列表,作为自定义特侦提取块make_feature的参数

        model = VGG(make_features(cfg), **kwargs)
        return model
    if __name__ == '__main__':
        net = vgg("vgg16")   #打印出修改后的vgg16网络结构
        print(net)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70

\quad \quad 打印出网络结构:

VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace=True)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace=True)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace=True)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(classifier): Sequential(
(0): Linear(in_features=512, out_features=256, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=256, out_features=256, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=256, out_features=15, bias=True)
)
)

Process finished with exit code 0

3.2 训练(Train)

\quad \quad 由于刚开始尝试在遥感数据集上训练VGG(较深的网络),之前训练LeNet和AlexNet(均在CIFAR_10数据集上)收敛速度还可以,鉴于自己电脑的算力(GTX 1050Ti 4G),分三个阶段进行训练,每个阶段训练30个epoch,总共训练90个epochbatch_size设为200,每个阶段训练完毕分别保存模型为: vgg16Net.pth、 vgg16Net2.pth和vgg16Net3pth,后一个阶段加载上一个阶段的预训练模型继续训练直到Train_loss趋近于收敛。

\quad \quad 完整训练代码train.py:

import os
import json
import torch
import torch.nn as nn
from torchvision import datasets, transforms, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm
import hiddenlayer as hl
from model import vgg


def main():
    #指定训练设备
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")      
    print("using {} device.".format(device))
    #数据预处理
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(56),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((56, 56)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    }
    #加载数据
    data_root = os.path.abspath(os.path.join(os.getcwd(), "./.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "data")  # data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    scene_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in scene_list.items())
    # 将标签分类写入字典文件 json file
    json_str = json.dumps(cla_dict, indent=14)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
    '''
    {              
              "0": "arbor woodland",
              "1": "artificial grassland",
              "2": "dry cropland",
              "3": "garden plot",
              "4": "industrial land",
              "5": "irrigated land",
              "6": "lake",
              "7": "natural grassland",
              "8": "paddy field",
              "9": "pond",
              "10": "river",
              "11": "rural residential",
              "12": "shrub land",
              "13": "traffic land",
              "14": "urban residential"
     }
   '''
    batch_size =200
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers,单处理器一般设置为0
    print('Using {} dataloader workers every process'.format(nw))
    #加载训练集
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    #加载验证集
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=True, #shufle=Ture,打乱数据集,防止模型训练在某一类
                                                  num_workers=nw)                      #上过度训练
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # classes = ('arbor woodland', 'artificial grassland','dry cropland', 'garden plot','industrial land', 'irrigated land',
    #            'lake','natural grassland','paddy field','pond','river', 'rural residential','shrub land','traffic land',
    #            'urban residential')
    # val_data_iter = iter(validate_loader)
    # val_image, val_label = val_data_iter.next()
    #
    # 随机显示val图片
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize,逆标准化还原原始图像
    #     npimg = img.numpy()  # 将图像转化为numpy.andgrry格式
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # # 输出图像对应的标签
    # print(' '.join('%5s' % classes[val_label[j].item()] for j in range(8)))
    # imshow(utils.make_grid(val_image))

    model_name = "vgg16"
    net = vgg(model_name=model_name, num_classes=15, init_weights=True)
    net.load_state_dict(torch.load("./vgg16Net2.pth"))                               #加载第二阶段训练模型
    net.to(device)
    history1 = hl.History()
    canvas1 = hl.Canvas()
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0002)                              #学习率初始为0.0002

    epochs = 30                                                                      #每次epoch=30,训练3次
    best_acc = 0.0
    save_path = './{}Net3.pth'.format(model_name)
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss )

        # validate
        net.eval()
        acc = 0.0                                                                 # 计算 val_acc/ epoch
        with torch.no_grad():                                                          
            val_bar = tqdm(validate_loader)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        history1.log((epoch, step), train_loss=running_loss / train_steps, val_acc=best_acc)

    print('Finished Training')
    with canvas1:
        canvas1.draw_plot(history1['train_loss'])
        canvas1.draw_plot(history1['val_acc'])

if __name__ == '__main__':
    main()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155

\quad \quad 训练过程如下所示,刚开始每个epoch大约花费5—9 min,随着训练加快,每个epoch 花费的时间稳定在1.25 min左右。由最后30个epoch的训练过程可以看出train_loss的收敛速度逐渐放缓,但并没有完全收敛,验证精度val_acc也一直在保持上升状态,说明网络依然在学习,目测还需要至少40个epoch才能完全收敛。训练全程LR都为0.0002没有变动,网上看过的几篇调参博文说后期LR增大或许可以加速收敛速度,有待尝试。

\quad \quad 参考博文连接:train loss相关问题 - 吱吱了了 - 博客园 (cnblogs.com)

C:\Users\wxyz1\anaconda3\envs\pytorch\python.exe “F:/Deep learning (compute vision)/Code master/deep-learning-for-image-processing-master/Remote sensing scene classification/train.py”
using cuda:0 device.
Using 8 dataloader workers every process
using 27000 images for training, 3000 images for validation.
train epoch[1/30] loss:0.428: 100%|██████████| 135/135 [03:48<00:00, 1.69s/it]
100%|██████████| 15/15 [00:46<00:00, 3.07s/it]
[epoch 1] train_loss: 0.424 val_accuracy: 0.845
train epoch[2/30] loss:0.365: 100%|██████████| 135/135 [04:48<00:00, 2.13s/it]
100%|██████████| 15/15 [00:45<00:00, 3.05s/it]
[epoch 2] train_loss: 0.413 val_accuracy: 0.858
train epoch[3/30] loss:0.386: 100%|██████████| 135/135 [03:27<00:00, 1.54s/it]
100%|██████████| 15/15 [00:59<00:00, 3.98s/it]
[epoch 3] train_loss: 0.411 val_accuracy: 0.867
train epoch[4/30] loss:0.297: 100%|██████████| 135/135 [04:41<00:00, 2.08s/it]
100%|██████████| 15/15 [00:44<00:00, 2.98s/it]
[epoch 4] train_loss: 0.392 val_accuracy: 0.863
train epoch[5/30] loss:0.358: 100%|██████████| 135/135 [02:15<00:00, 1.00s/it]
100%|██████████| 15/15 [00:17<00:00, 1.14s/it]
[epoch 5] train_loss: 0.396 val_accuracy: 0.858
train epoch[6/30] loss:0.462: 100%|██████████| 135/135 [01:15<00:00, 1.78it/s]
100%|██████████| 15/15 [00:16<00:00, 1.09s/it]
[epoch 6] train_loss: 0.379 val_accuracy: 0.861
train epoch[7/30] loss:0.299: 100%|██████████| 135/135 [01:15<00:00, 1.78it/s]
100%|██████████| 15/15 [00:16<00:00, 1.09s/it]
[epoch 7] train_loss: 0.394 val_accuracy: 0.872
train epoch[8/30] loss:0.414: 100%|██████████| 135/135 [01:17<00:00, 1.75it/s]
100%|██████████| 15/15 [00:16<00:00, 1.09s/it]
[epoch 8] train_loss: 0.375 val_accuracy: 0.860
train epoch[9/30] loss:0.305: 100%|██████████| 135/135 [01:15<00:00, 1.78it/s]
100%|██████████| 15/15 [00:16<00:00, 1.09s/it]
[epoch 9]train_loss: 0.365 val_accuracy: 0.866
train epoch[10/30] loss:0.416: 100%|██████████| 135/135 [01:15<00:00, 1.78it/s]
100%|██████████| 15/15 [00:16<00:00, 1.09s/it]
[epoch 10] train_loss: 0.367 val_accuracy: 0.862
train epoch[11/30] loss:0.306: 100%|██████████| 135/135 [01:15<00:00, 1.78it/s]
100%|██████████| 15/15 [00:16<00:00, 1.08s/it]
[epoch 11] train_loss: 0.346 val_accuracy: 0.867
train epoch[12/30] loss:0.301: 100%|██████████| 135/135 [01:15<00:00, 1.78it/s]
100%|██████████| 15/15 [00:16<00:00, 1.09s/it]
[epoch 12] train_loss: 0.348 val_accuracy: 0.868
train epoch[13/30] loss:0.270: 100%|██████████| 135/135 [01:15<00:00, 1.78it/s]
100%|██████████| 15/15 [00:16<00:00, 1.09s/it]
[epoch 13] train_loss: 0.351 val_accuracy: 0.870
train epoch[14/30] loss:0.269: 100%|██████████| 135/135 [01:15<00:00, 1.78it/s]
100%|██████████| 15/15 [00:16<00:00, 1.09s/it]
[epoch 14] train_loss: 0.339 val_accuracy: 0.878
train epoch[15/30] loss:0.295: 100%|██████████| 135/135 [01:15<00:00, 1.78it/s]
100%|██████████| 15/15 [00:16<00:00, 1.09s/it]
[epoch 15] train_loss: 0.352 val_accuracy: 0.866
train epoch[16/30] loss:0.232: 100%|██████████| 135/135 [01:15<00:00, 1.78it/s]
100%|██████████| 15/15 [00:16<00:00, 1.09s/it]
[epoch 16] train_loss: 0.347 val_accuracy: 0.876
train epoch[17/30] loss:0.322: 100%|██████████| 135/135 [01:15<00:00, 1.78it/s]
100%|██████████| 15/15 [00:16<00:00, 1.08s/it]
[epoch 17] train_loss: 0.330 val_accuracy: 0.870
train epoch[18/30] loss:0.426: 100%|██████████| 135/135 [01:15<00:00, 1.78it/s]
100%|██████████| 15/15 [00:16<00:00, 1.09s/it]
[epoch 18] train_loss: 0.338 val_accuracy: 0.870
train epoch[19/30] loss:0.292: 100%|██████████| 135/135 [01:15<00:00, 1.78it/s]
100%|██████████| 15/15 [00:16<00:00, 1.08s/it]
[epoch 19] train_loss: 0.325 val_accuracy: 0.872
train epoch[20/30] loss:0.272: 100%|██████████| 135/135 [01:16<00:00, 1.77it/s]
100%|██████████| 15/15 [00:16<00:00, 1.09s/it]
[epoch 20] train_loss: 0.326 val_accuracy: 0.860
train epoch[21/30] loss:0.277: 100%|██████████| 135/135 [01:16<00:00, 1.76it/s]
100%|██████████| 15/15 [00:16<00:00, 1.08s/it]
[epoch 21] train_loss: 0.328 val_accuracy: 0.875
train epoch[22/30] loss:0.232: 100%|██████████| 135/135 [01:17<00:00, 1.75it/s]
100%|██████████| 15/15 [00:16<00:00, 1.08s/it]
[epoch 22] train_loss: 0.319 val_accuracy: 0.880
train epoch[23/30] loss:0.338: 100%|██████████| 135/135 [01:16<00:00, 1.76it/s]
100%|██████████| 15/15 [00:16<00:00, 1.12s/it]
[epoch 23] train_loss: 0.313 val_accuracy: 0.876
train epoch[24/30] loss:0.306: 100%|██████████| 135/135 [01:17<00:00, 1.75it/s]
100%|██████████| 15/15 [00:18<00:00, 1.24s/it]
[epoch 24] train_loss: 0.322 val_accuracy: 0.875
train epoch[25/30] loss:0.315: 100%|██████████| 135/135 [01:16<00:00, 1.77it/s]
100%|██████████| 15/15 [00:16<00:00, 1.11s/it]
[epoch 25] train_loss: 0.313 val_accuracy: 0.872
train epoch[26/30] loss:0.291: 100%|██████████| 135/135 [01:15<00:00, 1.78it/s]
100%|██████████| 15/15 [00:16<00:00, 1.08s/it]
[epoch 26] train_loss: 0.301 val_accuracy: 0.886
train epoch[27/30] loss:0.356: 100%|██████████| 135/135 [01:16<00:00, 1.77it/s]
100%|██████████| 15/15 [00:16<00:00, 1.08s/it]
[epoch 27] train_loss: 0.312 val_accuracy: 0.869
train epoch[28/30] loss:0.295: 100%|██████████| 135/135 [01:16<00:00, 1.78it/s]
100%|██████████| 15/15 [00:16<00:00, 1.08s/it]
[epoch 28] train_loss: 0.300 val_accuracy: 0.886
train epoch[29/30] loss:0.224: 100%|██████████| 135/135 [01:15<00:00, 1.78it/s]
100%|██████████| 15/15 [00:16<00:00, 1.08s/it]
[epoch 29] train_loss: 0.299 val_accuracy: 0.884
train epoch[30/30] loss:0.348: 100%|██████████| 135/135 [01:15<00:00, 1.78it/s]
100%|██████████| 15/15 [00:16<00:00, 1.08s/it]
[epoch 30] train_loss: 0.293 val_accuracy: 0.884
Finished Training

Process finished with exit code 0

\quad \quad HiddenLayer是一个小型库。它覆盖基础元素,但你可能需要为自己的用例进行扩展。如果要跟踪您的训练过程,您需要使用两个类:History 存储指标,Canvas 进行绘制。 利用HiddenLayer模块绘制train_loss和val_acc曲线图如下:

\quad \quad 关于 HiddenLayer的详细用法参见:hiddenlayer/pytorch_train.ipynb at master · waleedka/hiddenlayer · GitHub

3.3 预测(Predict)

\quad \quad 虽然90个epoch后train_loss仍然没有完全收敛,但是val_acc已经可以达到89%,效果还不错。在根目录新建文件夹sample存放测试图片,任何尺度均可。将预测的结果图片保存在predict文件夹中。预测程序保存为predict.py:

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import vgg


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # 对输入的测试图片进行一个预处理,裁剪为56×56
    data_transform = transforms.Compose(
        [transforms.Resize((56, 56)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # 加载图片
    img_path = "./sample/01.tif"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # 增加batch维
    img = torch.unsqueeze(img, dim=0)

    # 读取之前写的json标签文件
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    json_file = open(json_path, "r")
    class_indict = json.load(json_file)
    
    # 实例化模型
    model = vgg(model_name="vgg16", num_classes=15).to(device)
    # 加载训练好的模型文件
    weights_path = "./vgg16Net3.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    model.eval()
    with torch.no_grad():
        # 预测图片分类
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    print(print_res)
    plt.savefig('./predict/01.jpg', bbox_inches=None)  #保存预测图片结果
    plt.show()


if __name__ == '__main__':
    main()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61

\quad \quad \quad \quad \quad \quad \quad \quad 错误 (urban residential) \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad 错误(lake)

\quad \quad \quad \quad \quad \quad \quad \quad \quad \quad 正确 \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad 错误(urban residential)

思考:

\quad \quad 可以看出,模型对于差异比较大的异质场景分类效果还可以,如:交通用地(traffic land),但是对于类间差异不明显的场景分类效果还有待提高,例如:城市住宅(urban residential)和农村住宅(rural residential)。究其原因:

(1)是数据集划分太过细致,许多类间差异不明显的对象容易混淆,例如将水体分为:河流(river)、湖泊(lake)、池塘(pond);将草地划分为人工草地(artificial grassland)和自然草地(nature grassland)。由于图片尺寸为56×56,场景分割不太宏观,部分图片缺失有效的上下文信息从而引起模型的错分。一般来说,对于遥感影像的宏观分类效果比较好。

(2)模型有待改进,可以添加一些对于光谱信息识别较好的tricks来提升模型的效果。通过扩展其他的遥感数据进行训练来提高模型的泛化性。

\quad \quad 之后将train_loader 的batch_size改为300,val_loader的batch_size改为30,训练100个epoch,观察模型的train_loss和val_acc都最终收敛,训练过程如下:

using cuda:0 device.
Using 8 dataloader workers every process
using 27000 images for training, 3000 images for validation.
train epoch[1/100] loss:2.436: 100%|██████████| 90/90 [07:41<00:00, 5.12s/it]
100%|██████████| 100/100 [01:11<00:00, 1.40it/s]
[epoch 1] train_loss: 2.542 val_accuracy: 0.188
train epoch[2/100] loss:2.149: 100%|██████████| 90/90 [04:56<00:00, 3.30s/it]
100%|██████████| 100/100 [00:44<00:00, 2.27it/s]
[epoch 2] train_loss: 2.233 val_accuracy: 0.186
train epoch[3/100] loss:2.039: 100%|██████████| 90/90 [01:26<00:00, 1.04it/s]
100%|██████████| 100/100 [00:17<00:00, 5.84it/s]
[epoch 3] train_loss: 2.112 val_accuracy: 0.260
train epoch[4/100] loss:1.914: 100%|██████████| 90/90 [01:14<00:00, 1.22it/s]
100%|██████████| 100/100 [00:16<00:00, 6.04it/s]
[epoch 4] train_loss: 1.985 val_accuracy: 0.327
train epoch[5/100] loss:1.698: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 5.95it/s]
[epoch 5] train_loss: 1.808 val_accuracy: 0.382
train epoch[6/100] loss:1.754: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 5.93it/s]
[epoch 6] train_loss: 1.710 val_accuracy: 0.442
train epoch[7/100] loss:1.584: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 5.94it/s]
[epoch 7] train_loss: 1.626 val_accuracy: 0.458
train epoch[8/100] loss:1.435: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.04it/s]
[epoch 8] train_loss: 1.498 val_accuracy: 0.494
train epoch[9/100] loss:1.417: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.03it/s]
[epoch 9] train_loss: 1.451 val_accuracy: 0.500
train epoch[10/100] loss:1.282: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.00it/s]
[epoch 10] train_loss: 1.361 val_accuracy: 0.541
train epoch[11/100] loss:1.268: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 5.99it/s]
[epoch 11] train_loss: 1.319 val_accuracy: 0.533
train epoch[12/100] loss:1.179: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 5.96it/s]
[epoch 12] train_loss: 1.252 val_accuracy: 0.579
train epoch[13/100] loss:1.116: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 5.99it/s]
[epoch 13] train_loss: 1.194 val_accuracy: 0.600
train epoch[14/100] loss:1.234: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 5.97it/s]
[epoch 14] train_loss: 1.158 val_accuracy: 0.598
train epoch[15/100] loss:0.975: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.01it/s]
[epoch 15] train_loss: 1.109 val_accuracy: 0.629
train epoch[16/100] loss:0.972: 100%|██████████| 90/90 [01:13<00:00, 1.22it/s]
100%|██████████| 100/100 [00:16<00:00, 6.00it/s]
[epoch 16] train_loss: 1.063 val_accuracy: 0.645
train epoch[17/100] loss:1.038: 100%|██████████| 90/90 [01:14<00:00, 1.22it/s]
100%|██████████| 100/100 [00:16<00:00, 5.99it/s]
[epoch 17] train_loss: 1.025 val_accuracy: 0.679
train epoch[18/100] loss:0.939: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.01it/s]
[epoch 18] train_loss: 0.986 val_accuracy: 0.681
train epoch[19/100] loss:0.939: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 5.98it/s]
[epoch 19] train_loss: 0.953 val_accuracy: 0.676
train epoch[20/100] loss:0.798: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.04it/s]
[epoch 20] train_loss: 0.914 val_accuracy: 0.688
train epoch[21/100] loss:0.876: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.03it/s]
[epoch 21] train_loss: 0.899 val_accuracy: 0.709
train epoch[22/100] loss:0.916: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.02it/s]
[epoch 22] train_loss: 0.872 val_accuracy: 0.717
train epoch[23/100] loss:0.826: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.02it/s]
[epoch 23] train_loss: 0.861 val_accuracy: 0.712
train epoch[24/100] loss:0.841: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 5.99it/s]
[epoch 24] train_loss: 0.823 val_accuracy: 0.744
train epoch[25/100] loss:0.860: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.04it/s]
[epoch 25] train_loss: 0.795 val_accuracy: 0.743
train epoch[26/100] loss:0.820: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.02it/s]
[epoch 26] train_loss: 0.781 val_accuracy: 0.751
train epoch[27/100] loss:0.828: 100%|██████████| 90/90 [01:15<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.02it/s]
[epoch 27] train_loss: 0.753 val_accuracy: 0.761
train epoch[28/100] loss:0.667: 100%|██████████| 90/90 [01:15<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 5.89it/s]
[epoch 28] train_loss: 0.738 val_accuracy: 0.759
train epoch[29/100] loss:0.685: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.00it/s]
[epoch 29] train_loss: 0.732 val_accuracy: 0.757
train epoch[30/100] loss:0.709: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 5.99it/s]
[epoch 30] train_loss: 0.717 val_accuracy: 0.759
train epoch[31/100] loss:0.844: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.01it/s]
[epoch 31] train_loss: 0.704 val_accuracy: 0.771
train epoch[32/100] loss:0.746: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 5.98it/s]
[epoch 32] train_loss: 0.677 val_accuracy: 0.795
train epoch[33/100] loss:0.607: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 5.97it/s]
[epoch 33] train_loss: 0.675 val_accuracy: 0.804
train epoch[34/100] loss:0.700: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 5.97it/s]
[epoch 34] train_loss: 0.670 val_accuracy: 0.797
train epoch[35/100] loss:0.567: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 5.99it/s]
[epoch 35] train_loss: 0.632 val_accuracy: 0.801
train epoch[36/100] loss:0.683: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.05it/s]
[epoch 36] train_loss: 0.608 val_accuracy: 0.807
train epoch[37/100] loss:0.502: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.02it/s]
[epoch 37] train_loss: 0.615 val_accuracy: 0.819
train epoch[38/100] loss:0.579: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.00it/s]
[epoch 38] train_loss: 0.598 val_accuracy: 0.812
train epoch[39/100] loss:0.598: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.03it/s]
[epoch 39] train_loss: 0.581 val_accuracy: 0.823
train epoch[40/100] loss:0.517: 100%|██████████| 90/90 [01:16<00:00, 1.18it/s]
100%|██████████| 100/100 [00:16<00:00, 6.03it/s]
[epoch 40] train_loss: 0.574 val_accuracy: 0.808
train epoch[41/100] loss:0.514: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.06it/s]
[epoch 41] train_loss: 0.578 val_accuracy: 0.821
train epoch[42/100] loss:0.514: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 5.91it/s]
[epoch 42] train_loss: 0.567 val_accuracy: 0.832
train epoch[43/100] loss:0.506: 100%|██████████| 90/90 [01:19<00:00, 1.13it/s]
100%|██████████| 100/100 [00:18<00:00, 5.50it/s]
[epoch 43] train_loss: 0.544 val_accuracy: 0.831
train epoch[44/100] loss:0.516: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.09it/s]
[epoch 44] train_loss: 0.535 val_accuracy: 0.820
train epoch[45/100] loss:0.521: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.00it/s]
[epoch 45] train_loss: 0.533 val_accuracy: 0.823
train epoch[46/100] loss:0.492: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.03it/s]
[epoch 46] train_loss: 0.516 val_accuracy: 0.848
train epoch[47/100] loss:0.520: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.06it/s]
[epoch 47] train_loss: 0.501 val_accuracy: 0.839
train epoch[48/100] loss:0.438: 100%|██████████| 90/90 [01:13<00:00, 1.22it/s]
100%|██████████| 100/100 [00:16<00:00, 6.07it/s]
[epoch 48] train_loss: 0.503 val_accuracy: 0.841
train epoch[49/100] loss:0.521: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.07it/s]
[epoch 49] train_loss: 0.487 val_accuracy: 0.836
train epoch[50/100] loss:0.483: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.06it/s]
[epoch 50] train_loss: 0.501 val_accuracy: 0.850
train epoch[51/100] loss:0.548: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.08it/s]
[epoch 51] train_loss: 0.474 val_accuracy: 0.841
train epoch[52/100] loss:0.403: 100%|██████████| 90/90 [01:13<00:00, 1.22it/s]
100%|██████████| 100/100 [00:16<00:00, 6.03it/s]
[epoch 52] train_loss: 0.470 val_accuracy: 0.847
train epoch[53/100] loss:0.337: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 5.96it/s]
[epoch 53] train_loss: 0.452 val_accuracy: 0.850
train epoch[54/100] loss:0.463: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.09it/s]
[epoch 54] train_loss: 0.454 val_accuracy: 0.838
train epoch[55/100] loss:0.461: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.08it/s]
[epoch 55] train_loss: 0.435 val_accuracy: 0.859
train epoch[56/100] loss:0.416: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.03it/s]
[epoch 56] train_loss: 0.443 val_accuracy: 0.858
train epoch[57/100] loss:0.497: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.03it/s]
[epoch 57] train_loss: 0.433 val_accuracy: 0.858
train epoch[58/100] loss:0.480: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.08it/s]
[epoch 58] train_loss: 0.424 val_accuracy: 0.853
train epoch[59/100] loss:0.420: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.06it/s]
[epoch 59] train_loss: 0.436 val_accuracy: 0.861
train epoch[60/100] loss:0.392: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.06it/s]
[epoch 60] train_loss: 0.416 val_accuracy: 0.861
train epoch[61/100] loss:0.379: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.08it/s]
[epoch 61] train_loss: 0.411 val_accuracy: 0.840
train epoch[62/100] loss:0.400: 100%|██████████| 90/90 [01:13<00:00, 1.22it/s]
100%|██████████| 100/100 [00:16<00:00, 6.06it/s]
[epoch 62] train_loss: 0.396 val_accuracy: 0.856
train epoch[63/100] loss:0.390: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.05it/s]
[epoch 63] train_loss: 0.406 val_accuracy: 0.864
train epoch[64/100] loss:0.347: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.08it/s]
[epoch 64] train_loss: 0.393 val_accuracy: 0.866
train epoch[65/100] loss:0.275: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.05it/s]
[epoch 65] train_loss: 0.376 val_accuracy: 0.870
train epoch[66/100] loss:0.411: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.04it/s]
[epoch 66] train_loss: 0.381 val_accuracy: 0.852
train epoch[67/100] loss:0.287: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.07it/s]
[epoch 67] train_loss: 0.385 val_accuracy: 0.868
train epoch[68/100] loss:0.350: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.03it/s]
[epoch 68] train_loss: 0.381 val_accuracy: 0.861
train epoch[69/100] loss:0.254: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.02it/s]
[epoch 69] train_loss: 0.370 val_accuracy: 0.877
train epoch[70/100] loss:0.319: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.06it/s]
[epoch 70] train_loss: 0.361 val_accuracy: 0.871
train epoch[71/100] loss:0.406: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.05it/s]
[epoch 71] train_loss: 0.373 val_accuracy: 0.870
train epoch[72/100] loss:0.280: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.01it/s]
[epoch 72] train_loss: 0.335 val_accuracy: 0.869
train epoch[73/100] loss:0.360: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.05it/s]
[epoch 73] train_loss: 0.349 val_accuracy: 0.877
train epoch[74/100] loss:0.289: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.01it/s]
[epoch 74] train_loss: 0.361 val_accuracy: 0.876
train epoch[75/100] loss:0.365: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.07it/s]
[epoch 75] train_loss: 0.343 val_accuracy: 0.869
train epoch[76/100] loss:0.246: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.05it/s]
[epoch 76] train_loss: 0.349 val_accuracy: 0.869
train epoch[77/100] loss:0.255: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.04it/s]
[epoch 77] train_loss: 0.328 val_accuracy: 0.881
train epoch[78/100] loss:0.339: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.09it/s]
[epoch 78] train_loss: 0.346 val_accuracy: 0.871
train epoch[79/100] loss:0.331: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.08it/s]
[epoch 79] train_loss: 0.329 val_accuracy: 0.875
train epoch[80/100] loss:0.344: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.08it/s]
[epoch 80] train_loss: 0.319 val_accuracy: 0.877
train epoch[81/100] loss:0.327: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.08it/s]
[epoch 81] train_loss: 0.324 val_accuracy: 0.874
train epoch[82/100] loss:0.326: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.04it/s]
[epoch 82] train_loss: 0.314 val_accuracy: 0.882
train epoch[83/100] loss:0.320: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.04it/s]
[epoch 83] train_loss: 0.324 val_accuracy: 0.863
train epoch[84/100] loss:0.251: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.08it/s]
[epoch 84] train_loss: 0.320 val_accuracy: 0.882
train epoch[85/100] loss:0.266: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.06it/s]
[epoch 85] train_loss: 0.302 val_accuracy: 0.869
train epoch[86/100] loss:0.249: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.04it/s]
[epoch 86] train_loss: 0.297 val_accuracy: 0.874
train epoch[87/100] loss:0.304: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.07it/s]
[epoch 87] train_loss: 0.300 val_accuracy: 0.893
train epoch[88/100] loss:0.187: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.06it/s]
[epoch 89] train_loss: 0.305 val_accuracy: 0.875
train epoch[89/100] loss:0.336: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.05it/s]
[epoch 89] train_loss: 0.297 val_accuracy: 0.881
train epoch[90/100] loss:0.295: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.08it/s]
[epoch 90] train_loss: 0.293 val_accuracy: 0.881
train epoch[91/100] loss:0.307: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.07it/s]
[epoch 91] train_loss: 0.305 val_accuracy: 0.884
train epoch[92/100] loss:0.210: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.04it/s]
[epoch 92] train_loss: 0.292 val_accuracy: 0.882
train epoch[93/100] loss:0.207: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 5.95it/s]
[epoch 93] train_loss: 0.287 val_accuracy: 0.889
train epoch[94/100] loss:0.190: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.10it/s]
[epoch 94] train_loss: 0.271 val_accuracy: 0.893
train epoch[95/100] loss:0.304: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.08it/s]
[epoch 95] train_loss: 0.281 val_accuracy: 0.893
train epoch[96/100] loss:0.308: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.06it/s]
[epoch 96] train_loss: 0.260 val_accuracy: 0.890
train epoch[97/100] loss:0.297: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.06it/s]
[epoch 97] train_loss: 0.272 val_accuracy: 0.875
train epoch[98/100] loss:0.333: 100%|██████████| 90/90 [01:14<00:00, 1.20it/s]
100%|██████████| 100/100 [00:16<00:00, 6.06it/s]
[epoch 98] train_loss: 0.282 val_accuracy: 0.879
train epoch[99/100] loss:0.298: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.03it/s]
[epoch 99] train_loss: 0.272 val_accuracy: 0.886
train epoch[100/100] loss:0.278: 100%|██████████| 90/90 [01:14<00:00, 1.21it/s]
100%|██████████| 100/100 [00:16<00:00, 6.06it/s]
[epoch 100] train_loss: 0.268 val_accuracy: 0.879
Finished Training

Process finished with exit code 0

\quad \quad 得到loss曲线图如下:

在这里插入图片描述

\quad \quad train_loss并没有像预想的一样降到0.0x,val_acc也没有得到提高,只是更加收敛。下面对scene image进行类别预测:

\quad \quad \quad \quad \quad \quad \quad \quad \quad \quad 错误(urban residential) \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad 正确


\quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad 正确 \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad 正确


\quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad 正确 \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad 正确


\quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad 正确 \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad 正确


\quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad 正确 \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad 错误(dry cropland)

\quad \quad 通过预测发现准确率提升了,之前错分的image也正确分类,如:lake、urban residential,并且正确分类的prob都挺高,说明epoch提升可以提高test accurancy。

在这里插入图片描述

\quad \quad 在GID验证集上预测结果的混淆矩阵如上图所示,the model accuracy is 0.8933,下表为各类别的预测准确度(Precision )、召回率(Recall)和特异度(Specificity)。

Classes\MericsPrecisionRecallSpecificity
arbor woodland0.9050.9550.993
artificial grassland0.950.960.996
dry cropland0.9220.890.995
garden plot0.9070.930.993
industrial land0.8220.8750.986
irrigated land0.8590.730.991
lake0.860.890.99
natural grassland0.9570.990.997
paddy field0.9250.980.994
pond0.9020.870.993
river0.9050.860.994
rural residential0.8810.7750.992
shrub land0.920.9250.994
traffic land0.9360.9450.995
urban residential0.7570.8250.981
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/运维做开发/article/detail/957573
推荐阅读
相关标签
  

闽ICP备14008679号