需要转移到 GPU 的有:model、数据(输入和标签)、criterion(即 loss 函数);其中 optimizer 不需要转移。
# A highlighted block
# Select the compute device. Fall back to the CPU when CUDA is unavailable so
# the same script still runs on machines without a GPU.
device = t.device('cuda:0' if t.cuda.is_available() else 'cpu')
# Move the model onto the selected device.
# NOTE(review): assumes `model` is a torch.nn.Module defined earlier -- confirm.
model = model.to(device)
# Move the loss module as well, so any tensors it holds live on the same
# device as the model outputs and targets.
criterion = criterion.to(device)
def train(epoch):
    """Run one training pass over ``train_loader``.

    Uses the module-level ``train_loader``, ``model``, ``criterion``,
    ``optimizer`` and ``device``. Every 300 batches the mean loss of those
    300 batches is printed and the running total is reset.

    Args:
        epoch: Zero-based epoch index; printed as ``epoch + 1``.
    """
    # Make sure layers such as dropout / batch norm are in training mode.
    # NOTE(review): assumes `model` is a torch.nn.Module -- confirm.
    model.train()

    running_loss = 0.0
    for batch_idx, (inputs, target) in enumerate(train_loader):
        # Inputs and targets must live on the same device as the model.
        inputs = inputs.to(device)
        target = target.to(device)

        optimizer.zero_grad()

        # forward + backward + update
        # The outputs are produced from on-device inputs by the model, so no
        # additional transfer of `outputs` is performed here.
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        # .item() extracts a Python number so the graph is not kept alive.
        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0
def test():
    """Evaluate ``model`` on ``test_loader`` and print the accuracy.

    Uses the module-level ``test_loader``, ``model`` and ``device``. The
    prediction for each sample is the index of the largest output along
    dimension 1. The accuracy is printed as a whole-number percentage; if
    the loader yields no samples a notice is printed instead.
    """
    correct = 0
    total = 0

    # Evaluate in eval mode, then restore whatever mode the model was in so
    # that a following training pass is unaffected.
    # NOTE(review): assumes `model` is a torch.nn.Module -- confirm.
    was_training = model.training
    model.eval()
    try:
        with t.no_grad():  # no gradients are needed for evaluation
            for images, label in test_loader:
                # Inputs and labels must live on the same device as the model.
                images = images.to(device)
                label = label.to(device)

                outputs = model(images)
                # Under no_grad the outputs carry no graph, so `.data` is
                # not needed before taking the arg-max.
                _, predicted = t.max(outputs, dim=1)

                total += label.size(0)
                correct += (predicted == label).sum().item()
    finally:
        model.train(was_training)

    if total == 0:
        # Avoid dividing by zero when the loader is empty.
        print('No samples in test set; accuracy is undefined')
        return

    print('Accuracy on test set: %d %%' % (100 * correct / total))
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。