赞
踩
1.代码
"""Train GoogLeNet (Inception v1) on Fashion-MNIST (d2l-style example)."""
import time

import torch
import torch.nn.functional as F
from torch import nn, optim

# Run on a GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class GlobalAvgPool2d(nn.Module):
    """Average-pool each channel over its full spatial extent -> (N, C, 1, 1)."""

    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, x):
        # kernel_size equals the input's (H, W): one pooled value per channel.
        return F.avg_pool2d(x, kernel_size=x.size()[2:])


class FlattenLayer(nn.Module):
    """Flatten all non-batch dimensions: (N, C, H, W) -> (N, C*H*W)."""

    def __init__(self):
        super(FlattenLayer, self).__init__()

    def forward(self, x):
        return x.view(x.shape[0], -1)


class Inception(nn.Module):
    """Inception block: four parallel paths concatenated along channels.

    Args:
        in_c: number of input channels.
        c1: output channels of the 1x1-conv path.
        c2: (reduce, out) channels of the 1x1 -> 3x3 path.
        c3: (reduce, out) channels of the 1x1 -> 5x5 path.
        c4: output channels of the 3x3-maxpool -> 1x1 path.
    """

    def __init__(self, in_c, c1, c2, c3, c4):
        super(Inception, self).__init__()
        # Path 1: 1x1 convolution only.
        self.p1_1 = nn.Conv2d(in_c, c1, kernel_size=1)
        # Path 2: 1x1 reduction, then 3x3 convolution.
        self.p2_1 = nn.Conv2d(in_c, c2[0], kernel_size=1)
        self.p2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)
        # Path 3: 1x1 reduction, then 5x5 convolution.
        self.p3_1 = nn.Conv2d(in_c, c3[0], kernel_size=1)
        self.p3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)
        # Path 4: 3x3 max pooling, then 1x1 convolution.
        self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.p4_2 = nn.Conv2d(in_c, c4, kernel_size=1)

    def forward(self, x):
        p1 = F.relu(self.p1_1(x))
        p2 = F.relu(self.p2_2(F.relu(self.p2_1(x))))
        p3 = F.relu(self.p3_2(F.relu(self.p3_1(x))))
        p4 = F.relu(self.p4_2(self.p4_1(x)))
        # Every path preserves H and W, so outputs stack on the channel dim.
        return torch.cat((p1, p2, p3, p4), dim=1)


# GoogLeNet body: five stages; stages 1-4 each end in a stride-2 max pool
# that halves the spatial size, stage 5 ends in global average pooling.
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
b2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1),
                   nn.Conv2d(64, 192, kernel_size=3, padding=1),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
b3 = nn.Sequential(Inception(192, 64, (96, 128), (16, 32), 32),
                   Inception(256, 128, (128, 192), (32, 96), 64),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
b4 = nn.Sequential(Inception(480, 192, (96, 208), (16, 48), 64),
                   Inception(512, 160, (112, 224), (24, 64), 64),
                   Inception(512, 128, (128, 256), (24, 64), 64),
                   Inception(512, 112, (144, 288), (32, 64), 64),
                   Inception(528, 256, (160, 320), (32, 128), 128),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
b5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),
                   Inception(832, 384, (192, 384), (48, 128), 128),
                   GlobalAvgPool2d())
# Final feature width: 384 + 384 + 128 + 128 = 1024 channels.
net = nn.Sequential(b1, b2, b3, b4, b5, FlattenLayer(), nn.Linear(1024, 10))


def evaluate_accuracy(data_iter, net,
                      device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
    """Return the classification accuracy of `net` over `data_iter`.

    Handles both nn.Module models (toggling eval/train mode around the
    forward pass) and plain functions, honouring an optional
    `is_training` keyword when the function declares one.
    """
    acc_sum, n = 0.0, 0
    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(net, torch.nn.Module):
                net.eval()  # freeze dropout / batch-norm statistics for eval
                acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
                net.train()
            else:
                # Hand-written model function: disable training-only behavior
                # when the function exposes an `is_training` parameter.
                if 'is_training' in net.__code__.co_varnames:
                    acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item()
                else:
                    acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
            n += y.shape[0]
    return acc_sum / n


def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNIST'):
    """Download Fashion-MNIST and return (train_iter, test_iter) DataLoaders.

    Args:
        batch_size: samples per mini-batch.
        resize: optional square size images are resized to before ToTensor.
        root: dataset cache directory.
    """
    # Imported locally so the model definitions above remain usable in
    # environments without torchvision installed.
    import torchvision

    trans = []
    if resize:
        trans.append(torchvision.transforms.Resize(size=resize))
    trans.append(torchvision.transforms.ToTensor())
    transform = torchvision.transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)
    mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)
    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=4)
    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=4)
    return train_iter, test_iter


def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):
    """Train `net` with cross-entropy loss, printing per-epoch metrics."""
    net = net.to(device)
    print("training on ", device)
    loss = torch.nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        # BUG FIX: batch_count is reset every epoch. The original accumulated
        # it across epochs while resetting train_l_sum, so the reported loss
        # was divided by the cumulative batch count and shrank spuriously
        # from epoch 2 onward.
        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec' % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))


if __name__ == '__main__':
    # Guarded entry point: importing this module no longer kicks off a
    # dataset download and training run, and the num_workers>0 DataLoaders
    # work under the Windows/macOS 'spawn' start method.
    batch_size = 128
    train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=96)
    lr, num_epochs = 0.001, 5
    optimizer = optim.Adam(net.parameters(), lr=lr)
    train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)
2.结果
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。