赞
踩
1.代码
import time import torch from torch import nn,optim import torch.nn.functional as F import torchvision device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') def conv_block(in_channels,out_channels): blk = nn.Sequential(nn.BatchNorm2d(in_channels), nn.ReLU(), nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1)) return blk class DenseBlock(nn.Module): def __init__(self,num_convs,in_channels,out_channels): super(DenseBlock,self).__init__() net = [] for i in range(num_convs): in_c = in_channels + i*out_channels net.append(conv_block(in_c,out_channels)) self.net = nn.ModuleList(net) self.out_channels = in_channels + num_convs * out_channels def forward(self,X): for blk in self.net: Y = blk(X) X = torch.cat((X,Y),dim=1) return X """ blk = DenseBlock(2,3,10) X = torch.rand(4,3,8,8) Y = blk(X) """ def transition_block(in_channels,out_channels): blk = nn.Sequential(nn.BatchNorm2d(in_channels), nn.ReLU(), nn.Conv2d(in_channels,out_channels,kernel_size=1), nn.AvgPool2d(kernel_size=2,stride=2)) return blk """ blk = transition_block(23,10) print(blk(Y).shape) """ net = nn.Sequential(nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(kernel_size=3,stride=2,padding=1)) num_channels,growth_rate = 64,32 num_convs_in_dense_blocks = [4,4,4,4] for i,num_convs in enumerate(num_convs_in_dense_blocks): DB = DenseBlock(num_convs,num_channels,growth_rate) net.add_module("DenseBlock_%d" %i,DB) num_channels = DB.out_channels if i != len(num_convs_in_dense_blocks) - 1: net.add_module("transition_block_%d" %i, transition_block(num_channels,num_channels//2)) num_channels = num_channels // 2 class GlobalAvgPool2d(nn.Module): def __init__(self): super(GlobalAvgPool2d,self).__init__() def forward(self,x): return F.avg_pool2d(x,kernel_size=x.size()[2:]) class FlattenLayer(nn.Module): def __init__(self): super(FlattenLayer,self).__init__() def forward(self,x): return x.view(x.shape[0],-1) 
# Classification head: final BatchNorm/ReLU, global average pooling, then a
# linear layer producing 10 outputs.
net.add_module("BN", nn.BatchNorm2d(num_channels))
net.add_module("relu", nn.ReLU())
net.add_module("global_avg_pool", GlobalAvgPool2d())
net.add_module("fc", nn.Sequential(FlattenLayer(), nn.Linear(num_channels, 10)))


def evaluate_accuracy(data_iter, net, device=None):
    """Return the fraction of samples in ``data_iter`` that ``net`` classifies correctly.

    Args:
        data_iter: Iterable yielding ``(X, y)`` batches of tensors.
        net: Either a ``torch.nn.Module`` or a plain callable. A plain
            callable is invoked as ``net(X, is_training=False)`` when
            ``'is_training'`` appears in ``net.__code__.co_varnames`` and as
            ``net(X)`` otherwise; its inputs are not moved between devices.
        device: Device the batches are moved to before calling a module.
            When ``None`` (the default) the device of the module's first
            parameter is used; a module without parameters receives the
            batches where they already are.

    Returns:
        Number of correct predictions divided by the number of samples.

    Raises:
        ZeroDivisionError: If ``data_iter`` yields no samples.
    """
    is_module = isinstance(net, torch.nn.Module)

    if is_module and device is None:
        # Evaluate wherever the model already lives instead of guessing.
        first_param = next(net.parameters(), None)
        if first_param is not None:
            device = first_param.device

    # Plain callables: decide once whether they take an `is_training` flag.
    accepts_flag = False
    if not is_module:
        code = getattr(net, '__code__', None)
        accepts_flag = code is not None and 'is_training' in code.co_varnames

    acc_sum, n = 0.0, 0
    was_training = net.training if is_module else False
    if is_module:
        # Evaluation mode for the whole pass rather than toggling per batch.
        net.eval()
    try:
        with torch.no_grad():
            for X, y in data_iter:
                if is_module:
                    if device is not None:
                        X, y = X.to(device), y.to(device)
                    acc_sum += (net(X).argmax(dim=1) == y).float().sum().cpu().item()
                elif accepts_flag:
                    acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item()
                else:
                    acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
                n += y.shape[0]
    finally:
        if is_module:
            # Put the module back into whichever mode the caller had it in.
            net.train(was_training)
    return acc_sum / n


def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNIST',
                            num_workers=4):
    """Create training and test ``DataLoader`` objects for Fashion-MNIST.

    The dataset is downloaded into ``root`` if it is not already there.

    Args:
        batch_size: Batch size used by both loaders.
        resize: Optional size passed to ``torchvision.transforms.Resize``;
            when falsy the images are not resized.
        root: Directory holding (or receiving) the dataset files.
        num_workers: Number of worker processes per loader. Use ``0`` to
            load data in the main process.

    Returns:
        A ``(train_iter, test_iter)`` tuple; only the training loader shuffles.
    """
    trans = []
    if resize:
        trans.append(torchvision.transforms.Resize(size=resize))
    trans.append(torchvision.transforms.ToTensor())
    transform = torchvision.transforms.Compose(trans)

    mnist_train = torchvision.datasets.FashionMNIST(
        root=root, train=True, download=True, transform=transform)
    mnist_test = torchvision.datasets.FashionMNIST(
        root=root, train=False, download=True, transform=transform)

    train_iter = torch.utils.data.DataLoader(
        mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(
        mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_iter, test_iter


def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):
    """Train ``net`` with cross-entropy loss and print per-epoch statistics.

    After every epoch one line is printed with the mean training loss per
    batch, the training accuracy, the test accuracy and the elapsed time.

    Args:
        net: Model to train; it is moved to ``device``.
        train_iter: Iterable of ``(X, y)`` training batches.
        test_iter: Iterable of ``(X, y)`` batches used for the test accuracy.
        batch_size: Unused by this function; kept in the signature for
            existing callers.
        optimizer: Optimizer already bound to ``net``'s parameters.
        device: Device used for both training and evaluation.
        num_epochs: Number of passes over ``train_iter``.
    """
    net = net.to(device)
    print("training on ", device)
    loss = torch.nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        # All running statistics, including the batch counter, are per epoch
        # so the reported loss is the mean over this epoch's batches only.
        train_l_sum, train_acc_sum, n, batch_count = 0.0, 0.0, 0, 0
        start = time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()

            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        # Evaluate on the same device the model was trained on.
        test_acc = evaluate_accuracy(test_iter, net, device)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n,
                 test_acc, time.time() - start))


# The guard keeps the probe and the training run from being re-executed when
# DataLoader worker processes re-import this module (spawn start method).
if __name__ == "__main__":
    # Print the output shape of every top-level layer for a 1x1x96x96 input.
    # Run in eval mode without autograd so the random probe does not update
    # the BatchNorm running statistics.
    X = torch.rand((1, 1, 96, 96))
    net.eval()
    with torch.no_grad():
        for name, layer in net.named_children():
            X = layer(X)
            print(name, 'output shape:\t', X.shape)
    net.train()

    batch_size = 256
    train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=96)

    lr, num_epochs = 0.001, 5
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)
2.结果
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。