赞
踩
注:本系列博客在于汇总CSDN的精华帖,类似自用笔记,不做学习交流,方便以后的复习回顾,博文中的引用都注明出处,并点赞收藏原博主.
目录
- # 导入所需的库和模块
- import os # 导入操作系统相关的库
- import json # 导入处理json数据的库
-
- import torch # 导入PyTorch库
- from PIL import Image # 导入处理图像数据的库
- from torchvision import transforms # 导入PyTorch的图像预处理库
- import matplotlib.pyplot as plt # 导入matplotlib库用于图像显示
-
- from model import vgg # 从model模块中导入vgg模型
- # 判断是否有GPU可用,并设置device变量
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- # 定义图像预处理流程
- data_transform = transforms.Compose(
- [transforms.Resize((224, 224)), # 将图像尺寸调整为224x224
- transforms.ToTensor(), # 将图像转换为Tensor格式
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # 对图像进行标准化处理
这里我是用的是绝对路径,可以改成基于上级文件的路径。
- # 加载图像
- img_path = "F:/code/Python/pytorch/VGG_image_classifcation/tulip.jpg" # 定义图像路径
- assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path) # 断言图像文件存在
- img = Image.open(img_path) # 使用PIL库打开图像文件
- plt.imshow(img) # 使用matplotlib显示图像
- # 对图像进行预处理
- img = data_transform(img) # 应用预处理流程
- # 扩展图像数据的batch维度
- img = torch.unsqueeze(img, dim=0) # 将图像数据扩展为batch维度为1的张量
- # 读取类别索引字典
- json_path = './class_indices.json' # 定义json文件路径
- assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path) # 断言json文件存在
-
- with open(json_path, "r") as f: # 打开json文件
- class_indict = json.load(f) # 读取json文件内容到class_indict变量中
- # 创建模型
- model = vgg(model_name="vgg16", num_classes=4).to(device) # 创建vgg16模型,并指定输出类别数为4,然后移动到指定的设备上
- # 加载模型权重
- weights_path = "./vgg16Net.pth" # 定义模型权重文件路径
- assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path) # 断言权重文件存在
- model.load_state_dict(torch.load(weights_path, map_location=device)) # 加载模型权重
- model.eval() # 将模型设置为评估模式
- with torch.no_grad(): # 不计算梯度,节省计算资源
- # 预测类别
- output = torch.squeeze(model(img.to(device))).cpu() # 对图像进行预测,并去除batch维度,然后将结果移动到CPU上
- predict = torch.softmax(output, dim=0) # 对预测结果进行softmax计算,得到每个类别的概率
- predict_cla = torch.argmax(predict).numpy() # 找到概率最大的类别的索引
- # 打印预测结果
- print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
- predict[predict_cla].numpy()) # 格式化预测结果
- plt.title(print_res) # 设置图像标题为预测结果
- for i in range(len(predict)): # 遍历每个类别的概率
- print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
- predict[i].numpy())) # 打印每个类别的名称和概率
- plt.show() # 显示图像
- # 如果当前脚本被直接运行(而不是被其他脚本导入),则执行main函数
- if __name__ == '__main__':
- main()
1.记得导入VGG模型
2.结果进行可视化处理
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。