import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # single GPU or CPU
model.to(device)

# if there are multiple GPUs, wrap the model in DataParallel
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model, device_ids=[0, 1, 2])
model.to(device)
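For context, a minimal end-to-end sketch of the pattern above (ToyModel and the tensor shapes are made up for illustration; everything else is standard PyTorch):

import torch
import torch.nn as nn

class ToyModel(nn.Module):  # hypothetical toy network
    def __init__(self):
        super(ToyModel, self).__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = ToyModel()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model, device_ids=[0, 1, 2])
model.to(device)

# inputs only need to go to the master device; DataParallel scatters
# the batch along dim 0 across the GPUs and gathers the outputs back
x = torch.randn(8, 10).to(device)
out = model(x)  # shape: (8, 2)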
# specify a particular GPU
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # note the trailing S; 'CUDA_VISIBLE_DEVICE' is silently ignored
model.cuda()
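One caveat worth showing: CUDA_VISIBLE_DEVICES only takes effect if it is set before PyTorch first initializes CUDA, so set it at the top of the script. A small sketch (the printed values assume a machine with at least two GPUs):

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # must come before any CUDA call

import torch
print(torch.cuda.device_count())  # 1 -- only physical GPU 1 is visible
x = torch.randn(4).cuda()         # lands on physical GPU 1, renumbered as cuda:0
print(x.device)                   # cuda:0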
# if there are multiple GPUs
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'  # os.environ, not os.environment
device_ids = [0, 1, 2, 3]
net = torch.nn.DataParallel(net, device_ids=device_ids)
# or, equivalently, omit device_ids to use all visible GPUs by default
net = torch.nn.DataParallel(net)
net = net.cuda()
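To see how these pieces fit together in practice, here is a hedged training-step sketch; the nn.Linear stand-in, the loss, the optimizer, and the batch shapes are placeholders, not part of the original recipe:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'

import torch
import torch.nn as nn

net = nn.Linear(10, 2)        # stand-in for a real network
net = nn.DataParallel(net)    # replicas on all visible GPUs
net = net.cuda()

criterion = nn.MSELoss().cuda()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

inputs = torch.randn(32, 10).cuda()   # batch is split along dim 0 across GPUs
targets = torch.randn(32, 2).cuda()

optimizer.zero_grad()
loss = criterion(net(inputs), targets)
loss.backward()   # gradients are reduced back onto GPU 0
optimizer.step()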
For reference, the relevant part of the DataParallel constructor in the PyTorch source:

class DataParallel(Module):
    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallel, self).__init__()

        if not torch.cuda.is_available():
            self.module = module
            self.device_ids = []
            return

        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]
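As the constructor shows, DataParallel degrades to a plain pass-through wrapper when no GPU is available, defaults device_ids to every visible GPU, and gathers outputs on device_ids[0]. One practical consequence, sketched below under the assumption that you later want to load the weights without the wrapper: the original model lives in the .module attribute, and parameter names gain a "module." prefix in the wrapped state_dict.

import torch
import torch.nn as nn

net = nn.Linear(10, 2)
if torch.cuda.is_available():
    net = nn.DataParallel(net).cuda()
    # save the inner module so the checkpoint loads without DataParallel
    torch.save(net.module.state_dict(), 'net.pth')
else:
    torch.save(net.state_dict(), 'net.pth')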