赞
踩
将数据与整个网络都集成到GPU进行运算。
- #gpu方式的mnist手写数字识别
- import torch
- import torch.nn as nn
- import torch.optim as optim
- import torch.nn.functional as F
- from torch.autograd import Variable
- import torch.utils.data as Data
- import torchvision
- import matplotlib.pyplot as plt
- import numpy as np
-
- EPOCH = 1
- BATCH_SIZE = 50
- LR = 0.001
- DOWNLOAD_MNIST = False
-
- train_data = torchvision.datasets.MNIST(
- root='./mnist',
- train=True,
- transform=torchvision.transforms.ToTensor(),
- download=DOWNLOAD_MNIST
- )
- train_loader = Data.DataLoader(
- dataset=train_data,
- batch_size=BATCH_SIZE,
- shuffle=True,
- num_workers=2
- )
- test_data = torchvision.datasets.MNIST(
- root='./mnist',
- train=False
- )
- #change in here
- test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1)).type(torch.FloatTensor)[:2000].cuda()/255. # Tensor on GPU
- test_y = test_data.test_labels[:2000].cuda()
-
- class CNN(nn.Module):
- def __init__(self):
- super(CNN,self).__init__()
- self.conv1 = nn.Sequential(
- nn.Conv2d(1,16,5,1,2),
- nn.ReLU(),
- nn.MaxPool2d(kernel_size=2)
- )
- self.conv2 = nn.Sequential(
- nn.Conv2d(16,32,5,1,2),
- nn.ReLU(),
- nn.MaxPool2d(kernel_size=2)
- )
- self.out = nn.Linear(32 * 7 * 7,10) #10分类的问题
-
- def forward(self,x):
- x = self.conv1(x)
- x = self.conv2(x)
- x = x.view(x.size(0),-1)
- x = self.out(x)
- return x
- def main():
- cnn = CNN()
- cnn.cuda()
-
- optimizer = optim.Adam(cnn.parameters(),lr=LR)
- loss_func = nn.CrossEntropyLoss()
-
- for epoch in range(EPOCH):
- for step,(x,y) in enumerate(train_loader):
- b_x = Variable(x).cuda()
- b_y = Variable(y).cuda()
- output = cnn(b_x)
- loss = loss_func(output,b_y)
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
- if step % 50 == 0:
- test_output = cnn(test_x)
-
- # !!!!!!!! Change in here !!!!!!!!! #
- pred_y = torch.max(test_output, 1)[1].cuda().data.squeeze() # move the computation in GPU
-
- accuracy = torch.sum(pred_y == test_y).type(torch.FloatTensor) / test_y.size(0)
- print('Epoch: ', epoch, '| train loss: %.4f' % loss.item(), '| test accuracy: %.2f' % accuracy)
- if __name__ == '__main__':
- main()
- Epoch: 0 | train loss: 2.3063 | test accuracy: 0.10
- Epoch: 0 | train loss: 0.5096 | test accuracy: 0.83
- Epoch: 0 | train loss: 0.3086 | test accuracy: 0.91
- Epoch: 0 | train loss: 0.5883 | test accuracy: 0.91
- Epoch: 0 | train loss: 0.1175 | test accuracy: 0.93
- Epoch: 0 | train loss: 0.1339 | test accuracy: 0.94
- Epoch: 0 | train loss: 0.1315 | test accuracy: 0.95
- Epoch: 0 | train loss: 0.1148 | test accuracy: 0.96
- Epoch: 0 | train loss: 0.0298 | test accuracy: 0.96
- Epoch: 0 | train loss: 0.3159 | test accuracy: 0.97
- Epoch: 0 | train loss: 0.0756 | test accuracy: 0.97
- Epoch: 0 | train loss: 0.2151 | test accuracy: 0.97
- Epoch: 0 | train loss: 0.1290 | test accuracy: 0.96
- Epoch: 0 | train loss: 0.0385 | test accuracy: 0.97
- Epoch: 0 | train loss: 0.0358 | test accuracy: 0.97
- Epoch: 0 | train loss: 0.0849 | test accuracy: 0.97
- Epoch: 0 | train loss: 0.0173 | test accuracy: 0.97
- Epoch: 0 | train loss: 0.0424 | test accuracy: 0.98
- Epoch: 0 | train loss: 0.1845 | test accuracy: 0.97
- Epoch: 0 | train loss: 0.1270 | test accuracy: 0.97
- Epoch: 0 | train loss: 0.0263 | test accuracy: 0.98
- Epoch: 0 | train loss: 0.0845 | test accuracy: 0.98
- Epoch: 0 | train loss: 0.0639 | test accuracy: 0.98
- Epoch: 0 | train loss: 0.0223 | test accuracy: 0.98
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。