当前位置:   article > 正文

pytorch中如何将CPU上运行的数据模型转到GPU上运行(mnist举例)_cpu训练的模型能在gpu上跑吗

cpu训练的模型能在gpu上跑吗

要在GPU上运行数据需要把一些相关的参数和模型转到GPU上

需要转换的有:model,数据,criterion,loss函数; 其中optimizer不需要转换
以下需要转换的部分均用**###标记**
首先定义

// An highlighted block
1 device = t.device('cuda:0')
2 model = model.to(device)
3 criterion = criterion.to(device)
  • 1
  • 2
  • 3
  • 4

训练部分

 1 def train(epoch):
 2     running_loss = 0.0
 3     for batch_idx, data in enumerate(train_loader, 0):
 4         inputs, target = data
 5         #cuda inputs and target
 6         inputs = inputs.to(device) ###
 7         target = target.to(device) ###
 8         optimizer.zero_grad() 
 9 
10         # forward + backward +update
11         outputs = model(inputs)
12         outputs = outputs.to(device) ###
13         loss = criterion(outputs, target)
14         loss.backward()
15         optimizer.step()
16 
17         running_loss += loss.item()
18         if batch_idx % 300 == 299:
19             print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
20             running_loss = 0.0
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

测试部分

复制代码
 1 def test():
 2     correct = 0
 3     total = 0
 4     with t.no_grad():  # ensuring grad can not updating
 5         for data in test_loader:
 6             images, label = data
 7             #cuda images, label
 8             images = images.to(device) ###
 9             label = label.to(device) ###
10             outputs = model(images)
11             outputs = outputs.to(device) ###
12             _, predicted = t.max(outputs.data, dim=1) 
13             total += label.size(0)
14      
15             correct += (predicted == label).sum().item()
16     print('Accurary on test set: %d %%' % (100 * correct / total))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/851423
推荐阅读
相关标签
  

闽ICP备14008679号