赞
踩
数据集路径:https://download.csdn.net/download/Ji_HON/88590044
本人常用该工程测试GPU Pytorch环境的搭建
包含了一个训练预测模型的完整流程
train.py:
1、数据集加载(自定义加载数据集的方法,并分为训练集和测试集)
- dataset =torchvision.datasets.ImageFolder(root='G:/LJH/DATASETS/flower_photos',transform=train_transform)
- train_loader =DataLoader(train_dataset,batch_size=4, shuffle=True,num_workers=0)#Batch Size定义:一次训练所选取的样本数。 Batch Size的大小影响模型的优化程度和速度。
- valid_loader =DataLoader(valid_dataset,batch_size=4, shuffle=True,num_workers=0)#Batch Size定义:一次训练所选取的样本数。 Batch Size的大小影响模型的优化程度和速度。
2、网络加载(加载网络并加载预训练权重)
- model = resnet50()
- model.load_state_dict(torch.load('weigths/resnet50.pth'))
3、网络训练与结果保存
- for epoch in range(1, 9):
- train(model, DEVICE, train_loader, optimizer, epoch)
- test(model, DEVICE, valid_loader)
- torch.save(model, 'weigths/ResNetFlowermodel-epoch8.pth')
其中包含了训练准确率的可视化显示
flower_predict.py:
1、预测图片的读取
img = Image.open("test.jpg")
2、模型的加载与预测
- model=torch.load('weigths/ResNetFlowermodel-epoch8.pth',map_location='cpu')
- #model.to(DEVICE)
- flowers=['雏菊','蒲公英','玫瑰','向日葵','郁金香']
- with torch.no_grad():
- output = torch.squeeze(model(img))
- print(output)
- predict = torch.softmax(output, dim=0)
- predict_cla = torch.argmax(predict).numpy()
- print(flowers[predict_cla])
- plt.show()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。