赞
踩
- # 参考资料:https://blog.csdn.net/qq_38101208/article/details/110481390
- #%%
- import torch
- import torch.nn as nn
-
- device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
-
- class CNN(nn.Module):
- def __init__(self):
- super(CNN,self).__init__()
- pass
-
- #网络
- model = CNN().to(device)
-
- #训练时(输入训练数据,标签)
- x = torch.arange(4)
- y = torch.arange(4)
-
- x,y = x.to(device),y.to(device)
-
- #预测时(输入训练数据)
- #输出结果如果需要用numpy 进行处理,需要讲结果载入到CPU上
- out = model(x).cpu().numpy()
-
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。