当前位置:   article > 正文

AlexNet小结_for layer in net

for layer in net

提示:本文是参考李沐老师和另一个B站up主代码以及讲解对自己所学东西的整理,具体资料连接会在文章中给出。且全部实验代码是在kaggle平台上验证过滴。


前言

建议先看我推荐的资料

李沐老师参考资料地址:link.
B站up主霹雳吧啦Wz:link.
注意:本文主要是对AlexNet网络的梳理,且主要是对代码的梳理,是Pytorch版本。视以后情况,可能会增加tensorflow版本代码。看懂改代码需要一定MLP、CNN和Pytorch基础知识,B站有相关up主讲解比较详细,在此我推荐几个up主吧,大家自行决定决定要不要看吧。
李沐老师主页:link.
B站up主刘二大人:link.
B站up主二次元的Datawhale:link.
其中二次元的Datawhale是一个开源组织,这个开源组织还有其他资料也比较好,pandas教程,西瓜书教程(偏理论教学),其中南瓜书就是由这个开源组织编写的。我觉得可能对刚入门的小伙伴比较友好一些。
还有请大家知晓一下啦,本博客基本是对自己所学知识整理,方便以后自己复习(主要是代码整理)。而且自己也还是学生,初学深度学习(但是不是人工智能方向相关专业学生哦,只是需要用到深度学习作为一个工具使用),有很多表述可能有不当和错误,希望大家可以指出来哦!谢谢大家。


一、网络架构

示例:pandas 是基于NumPy 的一种工具,该工具是为了解决数据分析任务而创建的。
这是李沐老师动手深度学习书上的图,自己比较懒,就不动手画了。这个就是AlexNet基本网络架构。

二、搭建AlexNet网络

2.1 版本一

参考地址:代码参考地址视屏参考地址

2.1.1 模型构建

代码如下(示例):

	net = nn.Sequential(
    # 这里,我们使用一个11*11的更大窗口来捕捉对象。
    # 同时,步幅为4,以减少输出的高度和宽度。
    # 另外,输出通道的数目远大于LeNet
    nn.Conv2d(1, 96, kernel_size=11, stride=4, padding=1), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    # 减小卷积窗口,使用填充为2来使得输入与输出的高和宽一致,且增大输出通道数
    nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    # 使用三个连续的卷积层和较小的卷积窗口。
    # 除了最后的卷积层,输出通道的数量进一步增加。
    # 在前两个卷积层之后,汇聚层不用于减少输入的高度和宽度
    nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Flatten(),
    # 这里,全连接层的输出数量是LeNet中的好几倍。使用dropout层来减轻过拟合
    nn.Linear(6400, 4096), nn.ReLU(),
    # 注意dropout层一般用于全连接层,且用于激活函数后。
    nn.Dropout(p=0.5),
    nn.Linear(4096, 4096), nn.ReLU(),
    nn.Dropout(p=0.5),
    # 最后是输出层。由于这里使用Fashion-MNIST,所以用类别数为10,而非论文中的1000
    nn.Linear(4096, 10))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

每一层输出维度:

X = torch.randn(1, 1, 224, 224)
for layer in net:
    X=layer(X)
    print(layer.__class__.__name__,'output shape:\t',X.shape)
  • 1
  • 2
  • 3
  • 4

输出结果:
在这里插入图片描述

2.1.2 数据加载以及模型训练

注意
下述代码中有相当一部分函数我没有列出来,具体函数见该链接LeNet小结,代码都是通用的,为了节约时间,我就没有一一列出了,缺啥就在该博客中找吧。
加载数据:

'''使用的Fashion_MNIST数据集,具体情况也可以见我在上面提供的博客连接'''
batch_size = 128
train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=224)
  • 1
  • 2
  • 3

训练模型:

lr, num_epochs = 0.01, 10
device = "cuda:0" if torch.cuda.is_available() else "cpu"
train_ch6(net, train_iter, test_iter, num_epochs, lr, device)
  • 1
  • 2
  • 3

输出结果:
在这里插入图片描述
我们看到训练数据集精度和验证集精度一样,实际上有可能是欠拟合情况发生,一般在真实情况中,我们都是会使模型先过拟合,然后使用权重衰退(weight_decay)、暂退法(Dropout)等来使缓解模型过拟合的情况。

2.2 版本二

参考地址:代码参考地址视屏参考地址
代码参考网址是一个github地址,需要大家合理翻墙。

2.2.1 模型定义

这部分是定义模型的代码。

import torch.nn as nn
import torch

class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        # 是否使用自定义初始化
        if init_weights:
            self._initialize_weights()
            
	# 前向传播
    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x
        
	# 自定义模型参数初始化函数
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51

查看模型模型输出形状。

'''注意通道数是3'''
X = torch.randn((1,3,224,224))
net = AlexNet(num_classes=5, init_weights=True)
for layer in net.modules():
	'''注意,大家可以先用这个for循环只打印 layer.__class__.__name__看看结果,就明白这样写的目的。这里应该使用了二叉树之类的数据结构'''
    if not layer.__class__.__name__ in ["AlexNet", "Sequential"]:
        X = layer(X)
        print(layer.__class__.__name__,'output shape: \t',X.shape)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

输出结果:
在这里插入图片描述

2.2.2 数据下载与处理

接下来这部分是关于数据下载与解压部分的程序。原作者是提供了网址需要下载,现在我直接写了一个简单程序完成这部分操作,程序如下。

import requests
import zipfile
import os
import tarfile

# tgz版本
def download_extract_tgz(url=None):
    # 检查url是否为空
    assert url is not None, "url is None"
    # 压缩包的绝对路径,解压缩后文件夹的绝对路径
    tgzPath, detgzPath = './data.tgz', './'
    # 判断当前路径下是否有文件 data.tgz,有我们就直接退出了。
    if not os.path.exists(tgzPath):
        print(f'正在从{url}下载资源...')
        # 获取内容
        response = requests.get(url)
        with open(tgzPath, 'wb') as file:
            # 通过二进制写文件的方式保存获取的内容
            file.write(response.content)
            # 无需等待缓冲区满,直接写入文件中
            file.flush()
        # 打开压缩文件
        tar = tarfile.open(tgzPath)
        # 解压
        tar.extractall(detgzPath)
        # 关闭文件
        tar.close()
        print("下载并解压完成!")
        # 删除原始压缩包,这句话可以屏蔽掉,否则检测是否if语句判断没啥用
        #os.remove(tgzPath)

# zip版本,这个没有尝试调用,应该是没问题的
def download_extract_zip(url=None):
    # 检查url是否为空
    assert url is not None, "url is None"
    # 压缩包的绝对路径,解压缩后文件夹的绝对路径
    zipPath, dezipPath = './data.zip', './'
    # 判断当前路径下是否有文件 data.zip,有我们就直接退出了。
    if not os.path.exists(zipPath):
        print(f'正在从{url}下载资源...')
        # 获取内容
        response = requests.get(url)
        with open(zipPath, 'wb') as file:
            # 通过二进制写文件的方式保存获取的内容
            file.write(response.content)
            # 无需等待缓冲区满,直接写入文件中
            file.flush()
        # 创建压缩包对象
        f = zipfile.ZipFile(zipPath)
        # 解压
        f.extractall(dezipPath)
        # 关闭文件
        f.close()
        print("下载并解压完成!")
        # 删除原始压缩包, 这句话可以屏蔽掉,否则检测是否if语句判断没啥用
        #os.remove(zipPath)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56

写了两个版本,一个是zip版本,一个是tgz版本。
调用上述函数:

# 压缩包网址
url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
download_extract_tgz(url)
  • 1
  • 2
  • 3

因为压缩包是以tgz结尾,因此我们调用download_extract_tgz函数。输出结果如下
在这里插入图片描述
不出意外的话,大家在kaggle上面会有我标红的两个文件夹,其余文件夹是我运行后面的程序得来的,所以目前大家是没有的,大家点击展开flower_photos文件夹会看到5个文件夹,分别存储了5中不同类型的花的照片。
接下来这部分代码,以9:1的比例分出训练集和验证集。这个具体可以看我提供的视屏连接地址,我基本没改,就是改了一下路径,并删除了一些多余的代码。

import os
from shutil import copy, rmtree
import random

# 创建文件夹
def mk_file(file_path: str):
    if os.path.exists(file_path):
        # 如果文件夹存在,则先删除原文件夹在重新创建
        rmtree(file_path)
    os.makedirs(file_path)
    
def main():
    # 保证随机可复现
    random.seed(0)

    # 将数据集中10%的数据划分到验证集中
    split_rate = 0.1

    # 得到当前目录
    cwd = os.getcwd()
    data_root = cwd
    origin_flower_path = os.path.join(data_root, "flower_photos")
    assert os.path.exists(origin_flower_path), "path {} does not exist.".format(origin_flower_path)
    
    # 花的种类
    flower_class = [cla for cla in os.listdir(origin_flower_path)
                    if os.path.isdir(os.path.join(origin_flower_path, cla))]

    # 建立保存训练集的文件夹
    train_root = os.path.join(data_root, "train")
    mk_file(train_root)
    for cla in flower_class:
        # 建立每个类别对应的文件夹
        mk_file(os.path.join(train_root, cla))

    # 建立保存验证集的文件夹
    val_root = os.path.join(data_root, "val")
    mk_file(val_root)
    for cla in flower_class:
        # 建立每个类别对应的文件夹
        mk_file(os.path.join(val_root, cla))

    for cla in flower_class:
        cla_path = os.path.join(origin_flower_path, cla) # 对应每一类花的文件夹路径
        images = os.listdir(cla_path) 
        num = len(images) # 该种类的数量
        # 随机采样验证集的索引
        eval_index = random.sample(images, k=int(num*split_rate))
        for index, image in enumerate(images):
            if image in eval_index:
                # 将分配至验证集中的文件复制到相应目录
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(val_root, cla)
                copy(image_path, new_path) # 复制文件
            else:
                # 将分配至训练集中的文件复制到相应目录
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(train_root, cla)
                copy(image_path, new_path)
            # 这里其实没有必要每轮都打印,可以固定多少次打印一次,防止数据量过大,影响程序效率
            print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
        print()
    print("processing done!")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63

调用上述代码:

# 进行训练集和验证集划分
main()
  • 1
  • 2

我们看下效果是啥
在这里插入图片描述
我们会在kaggle上看到又多了valtrain两个文件夹。接下来我们进行下一步操作。
这部分程序如下,这部分程序主要对我们之前分好类的数据集进行再次处理,使其可以喂入模型进行训练。关于json模块的使用,可参考该博客link

# 关于json的使用:https://blog.csdn.net/weixin_38842821/article/details/108359551
from torchvision import transforms, datasets, utils
import json
'''
    RandomResizedCrop(224):将图片(H, W)扩张或者压缩成(224, 224)
    RandomHorizontalFlip():以0.5的概率水平翻转图片
'''
data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

def data_process():
    # train_data处理
    train_dataset = datasets.ImageFolder(root="./train", transform=data_transform["train"])
    validate_dataset = datasets.ImageFolder(root = "./val", transform=data_transform["val"])
    return train_dataset, validate_dataset

# 提取所有文件夹中的图片数据
train_data, val_data = data_process()
# 训练集数量、验证集数量
train_num, val_num= len(train_data), len(val_data)
# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
class_2_idx = train_data.class_to_idx
# {0: 'daisy', 1: 'dandelion', 2: 'roses', 3: 'sunflowers', 4: 'tulips'}
idx_2_class = dict((value, key) for key, value in class_2_idx.items())
# 将idx_2_class转换成json形式的字符串,indent表示缩进设置吧
json_str = json.dumps(idx_2_class, indent=4)
'''
    原作者之所以这么写,是因为使用pycharm分文件写的。
    idx_2_class是在train.py中定义的,为了在predict.py也能使用
    idx_2_class但不进行变量传递话,就写成一个文件来达到共享的目的
'''
with open('idx_2_class.json', 'w') as json_file:
        json_file.write(json_str)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39

这部分函数运行之后,我们可以看到在kaggle上多了一个idx_2_class.json文件。
接下来我们来对图片进行显示看看。代码如下:

def load_data_iter(dataset, batch_size, is_train=True):
    # 使用nw个线程(还是进程来着)来加快数据读取
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))
    data_iter =  torch.utils.data.DataLoader(dataset,batch_size=batch_size, shuffle=is_train,num_workers=nw)
    return data_iter

# 加载数据
batch_size = 32
train_data_iter = load_data_iter(train_data, batch_size)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

画图部分代码如下:

import matplotlib.pyplot as plt
import numpy as np

# 得到对应文本标签
def get_fashion_mnist_labels(labels):
    """返回Fashion-MNIST数据集的文本标签"""
    text_labels = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
    return [text_labels[int(i)] for i in labels]

def show_images(imgs, num_rows, num_cols, titles=None, scale=3.0):
    # num_rows:行,num_cols:列,scale:设置图片大小
    figsize = (num_cols * scale, num_rows * scale) # 相当于画布大小
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # 图片张量
            ax.imshow(np.transpose(img.numpy(), (1,2,0)))
        else:
            # PIL图片
            ax.imshow(img)
        # 不显示X轴
        ax.axes.get_xaxis().set_visible(False)
        # 不显示Y轴
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            # 给每一个图片设置标题
            ax.set_title(titles[i])
    return axes

# 取出第一个batch
imgs = next(iter(train_data_iter))
imgs[0] = imgs[0] / 2 + 0.5  # unnormalize
show_images(imgs[0][:8], 2, 4, titles=get_fashion_mnist_labels(imgs[1][:8]))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

输出结果:
在这里插入图片描述

2.2.3 模型训练

在定义训练函数之前,先定义一个有意思的函数:

# Python设置文本文字颜色
class bcolors:
    HEADER = '\033[95m'       # pink
    OKBLUE = '\033[94m'       # blue
    OKGREEN = '\033[92m'      # green
    WARNING = '\033[93m'      # yellow
    FAIL = '\033[91m'         # red
    ENDC = '\033[0m'          # black
    BOLD = '\033[1m'          # black+bold
    UNDERLINE = '\033[4m'     # black+underline


print(bcolors.HEADER + "提示:此时文字颜色为pink")
print(bcolors.OKBLUE + "提示:此时文字颜色为blue")
print(bcolors.OKGREEN + "提示:此时文字颜色为green")
print(bcolors.WARNING + "提示:此时文字颜色为yellow")
print(bcolors.FAIL + "提示:此时文字颜色为red")
print(bcolors.ENDC + "提示:此时文字颜色为black")
print(bcolors.UNDERLINE + "提示:此时文字颜色为black+underline")
print(bcolors.BOLD + "提示:此时文字颜色为black+bold")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

运行之后,输出如下:
在这里插入图片描述
接下来定义我们的训练函数,代码如下:

from tqdm import tqdm
import torch.optim as optim
import sys

def train(net, train_data_iter, val_data_iter, epoch_nums, lr, device, save_path):
    # 放在GPU上计算
    net.to(device)
    # 使用交叉熵计算损失
    loss_function = nn.CrossEntropyLoss()
    # 定义优化器
    optimizer = optim.Adam(net.parameters(), lr=lr)
    best_acc = 0.0
    # 训练数据集总batch数和验证数据集总batch数
    train_num = len(train_data_iter)
    for epoch in range(epoch_nums):
        # 训练模式
        net.train()
        # 将每一个epoch的每个batch的loss加起来,然后算每个epoch的平均loss
        loss_sum = 0.0;
        train_bar = tqdm(train_data_iter, file=sys.stdout)
        for X, y in train_bar:
            # 这里要注意将数据集放在device上训练,否则会报错
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            # 梯度清零
            optimizer.zero_grad()
            loss = loss_function(y_hat, y)
            loss.backward() # 反向传播
            optimizer.step() # 参数更新
            loss_sum += loss.item()
            train_bar.desc = f"train epoch[{epoch + 1}/{epoch_nums}] loss:{loss:.3f}"
        # 评估模式
        net.eval()
        val_acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(val_data_iter, file=sys.stdout)
            for val_X, val_y in val_bar:
                val_X = val_X.to(device)
                val_y = val_y.to(device)
                outputs = net(val_X)
                predict_y = torch.argmax(outputs, dim=1)
                val_acc += torch.eq(predict_y, val_y).sum().item() # 预测正确数
            print(bcolors.HEADER +'[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, loss_sum / train_num, val_acc/val_num) + bcolors.ENDC)
            # 存储在验证集上表现最好的参数
            if val_acc > best_acc:
                best_acc = val_acc 
                torch.save(net.state_dict(), save_path)
    print('Finished Training')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50

开始训练:

batch_size = 128
train_data_iter = load_data_iter(train_data, batch_size)
val_data_iter = load_data_iter(val_data, batch_size, is_train=False)
net = AlexNet(num_classes=5, init_weights=True)
epoch_nums = 20
lr = 0.0002
device = "cuda:0" if torch.cuda.is_available() else "cpu"
save_path = './AlexNet.pth' # 保存模型参数文件名
train(net, train_data_iter, val_data_iter, epoch_nums, lr, device, save_path)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

输出如下,由于结果较长,我只截图了后面几个epoch:
在这里插入图片描述
这部分就是关于模型训练的代码。

2.2.4 模型预测

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
	
	'''我们进行预测的图片必须进行预处理来保证能喂入模型'''
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
	
	'''大家如果要进行预测的话要将图片路径改下'''
    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)

    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0) # 增加batch这一个维度

    # read class_indict
    json_path = './idx_2_class.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = AlexNet(num_classes=5).to(device)

    # load model weights
    weights_path = "./AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path))

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59

以上这部分就是预测代码。

总结

这部分就是关于AlexNet网络的小结。主要是对代码进行了一定程度的优化。有些写的不是特别详细,所以有疑问的话,欢迎提问。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/351078
推荐阅读
相关标签
  

闽ICP备14008679号