赞
踩
# Start from ResNet-50 with its default pretrained weights; the original
# classifier head is replaced below.
resnet50 = torchvision.models.resnet50(
    weights=torchvision.models.ResNet50_Weights.DEFAULT,
)
# Replace the final fully connected layer with a 10-class output layer,
# keeping the backbone's feature width as the new layer's input size.
in_features = resnet50.fc.in_features
resnet50.fc = nn.Linear(in_features, out_features=10)
resnet50 = resnet50.to(device)
由于 ResNet-50 网络要求输入 224×224×3（3 通道）的图像，而 MNIST 图像是 28×28 的单通道灰度图，
因此需要对网络接收的输入做相应的调整（放大尺寸，并把灰度通道复制为 3 个通道）：
# Preprocessing for the pretrained ResNet-50: upscale each image to 224x224,
# expand it to 3 identical gray channels, convert it to a tensor, then
# normalize each channel with the per-channel mean/std listed below
# (the usual ImageNet statistics).
tf = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(size=(224, 224)),
        # Replicate the single gray channel so the image has 3 channels.
        torchvision.transforms.Grayscale(num_output_channels=3),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Freeze the pretrained backbone so that only the new fc head is trained.
# Passing only fc's parameters to the optimizer is not enough on its own:
# autograd would still backpropagate through the whole backbone and keep
# accumulating its .grad tensors (optim.zero_grad() only clears fc's grads).
# Disabling requires_grad avoids that work entirely.
# NOTE: BatchNorm running statistics are buffers, not parameters, so they
# still update whenever the model is in train() mode.
resnet50.requires_grad_(False)
resnet50.fc.requires_grad_(True)
optim = torch.optim.Adam(resnet50.fc.parameters(), lr=0.001)
完整代码:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader

# Preprocessing for the pretrained ResNet-50: upscale each MNIST digit to
# 224x224, replicate its gray channel into 3 channels, convert to a tensor,
# and normalize with the per-channel mean/std below (the usual ImageNet
# statistics).
tf = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(size=(224, 224)),
        torchvision.transforms.Grayscale(num_output_channels=3),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        ),
    ]
)

# Load the MNIST train/test splits (downloaded into ./dataset if missing).
train_data = torchvision.datasets.MNIST(
    root='./dataset', train=True, transform=tf, download=True
)
test_data = torchvision.datasets.MNIST(
    root='./dataset', train=False, transform=tf, download=True
)
test_size = len(test_data)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 128
# Reshuffle the training set every epoch (DataLoader defaults to a fixed
# order); evaluation order does not matter, so the test loader is unshuffled.
trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
testlooader = DataLoader(test_data, batch_size=batch_size)


class LeNet(nn.Module):
    """LeNet-style CNN mapping 1-channel 28x28 images to 10 class logits."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            # MNIST images are 28x28; padding=2 gives the first convolution
            # the 32x32 input size LeNet was designed for.
            nn.Conv2d(1, 6, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            # 28 -> conv(pad 2) 28 -> pool 14 -> conv 10 -> pool 5
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        """Return the (batch, 10) logits for input batch ``x``."""
        return self.model(x)


# Pretrained ResNet-50 whose final fully connected layer is replaced with a
# 10-class output layer for the MNIST digits.
resnet50 = torchvision.models.resnet50(
    weights=torchvision.models.ResNet50_Weights.DEFAULT
)
in_features = resnet50.fc.in_features
resnet50.fc = nn.Linear(in_features, 10)
resnet50 = resnet50.to(device)
print(resnet50)

epochs = 30
# The LeNet model is not instantiated here: tf produces 3x224x224 inputs,
# which LeNet's first Conv2d(1, ...) layer cannot accept; only resnet50 is
# trained and evaluated.
loss_fn = nn.CrossEntropyLoss().to(device)

# Freeze the pretrained backbone so only the new fc head is trained.
# Disabling requires_grad (not just omitting the backbone from the optimizer)
# stops autograd from backpropagating through the whole backbone and from
# accumulating its .grad tensors, which optim.zero_grad() would never clear.
# NOTE: BatchNorm running statistics are buffers, not parameters, so they
# still update while the model is in train() mode.
resnet50.requires_grad_(False)
resnet50.fc.requires_grad_(True)
optim = torch.optim.Adam(resnet50.fc.parameters(), lr=0.001)

for epoch in range(epochs):
    # One pass over the training set.
    resnet50.train()
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)
        output = resnet50(images)
        loss = loss_fn(output, labels)
        optim.zero_grad()
        loss.backward()
        optim.step()

    # Evaluate test-set accuracy without tracking gradients.
    resnet50.eval()
    with torch.no_grad():
        num_correct = 0
        for images, labels in testlooader:
            images, labels = images.to(device), labels.to(device)
            output = resnet50(images)
            # .item() turns the count into a Python int, so the printed
            # accuracy is a plain number rather than a device tensor repr.
            num_correct += (output.argmax(1) == labels).sum().item()
    print("第{}轮中,测试集上的准确率为:{}".format(epoch + 1, num_correct / test_size))
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。