当前位置:   article > 正文

哪些变量可放到GPU上-to(device)_什么样的变量可以放到gpu中

什么样的变量可以放到gpu中
  1. # 参考资料:https://blog.csdn.net/qq_38101208/article/details/110481390
  2. #%%
  3. import torch
  4. import torch.nn as nn
  5. device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  6. class CNN(nn.Module):
  7. def __init__(self):
  8. super(CNN,self).__init__()
  9. pass
  10. #网络
  11. model = CNN().to(device)
  12. #训练时(输入训练数据,标签)
  13. x = torch.arange(4)
  14. y = torch.arange(4)
  15. x,y = x.to(device),y.to(device)
  16. #预测时(输入训练数据)
  17. #输出结果如果需要用numpy 进行处理,需要讲结果载入到CPU上
  18. out = model(x).cpu().numpy()

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/715993
推荐阅读
相关标签
  

闽ICP备14008679号