当前位置:   article > 正文

图像分类识别入门训练模型(PyTorch)_pytorch训练图像识别模型

pytorch训练图像识别模型

ResNet_Flower

数据集路径:https://download.csdn.net/download/Ji_HON/88590044

本人常用该工程测试GPU Pytorch环境的搭建

包含了一个训练预测模型的完整流程

train.py:

1、数据集加载(自定义加载数据集的方法,并分为训练集和测试集)

  1. dataset =torchvision.datasets.ImageFolder(root='G:/LJH/DATASETS/flower_photos',transform=train_transform)
  2. train_loader =DataLoader(train_dataset,batch_size=4, shuffle=True,num_workers=0)#Batch Size定义:一次训练所选取的样本数。 Batch Size的大小影响模型的优化程度和速度。
  3. valid_loader =DataLoader(valid_dataset,batch_size=4, shuffle=True,num_workers=0)#Batch Size定义:一次训练所选取的样本数。 Batch Size的大小影响模型的优化程度和速度。

2、网络加载(加载网络并加载预训练权重)

  1. model = resnet50()
  2. model.load_state_dict(torch.load('weigths/resnet50.pth'))

3、网络训练与结果保存

  1. for epoch in range(1, 9):
  2. train(model, DEVICE, train_loader, optimizer, epoch)
  3. test(model, DEVICE, valid_loader)
  4. torch.save(model, 'weigths/ResNetFlowermodel-epoch8.pth')

其中包含了训练准确率的可视化显示

flower_predict.py:

1、预测图片的读取

img = Image.open("test.jpg")

2、模型的加载与预测

  1. model=torch.load('weigths/ResNetFlowermodel-epoch8.pth',map_location='cpu')
  2. #model.to(DEVICE)
  3. flowers=['雏菊','蒲公英','玫瑰','向日葵','郁金香']
  4. with torch.no_grad():
  5. output = torch.squeeze(model(img))
  6. print(output)
  7. predict = torch.softmax(output, dim=0)
  8. predict_cla = torch.argmax(predict).numpy()
  9. print(flowers[predict_cla])
  10. plt.show()

声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号