赞
踩
最近在学习使用 Ray 框架进行分布式模型训练。Ray 框架下的 PyTorch 模型写法与普通 PyTorch 还是有一定区别，在此记录一下，留作笔记。
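这里没有使用官网文档给出的数据集（上一篇文章介绍了如何加载自己的 PyTorch 数据集）。定义训练流程时，不在 TorchTrainer 中指定 datasets 参数，而是在每个 worker 执行的训练函数中直接加载自己的数据集，即可用自己的数据集进行训练。注意：训练函数会在每个 worker 上各执行一次，DataLoader 应先经过 train.torch.prepare_data_loader 包装（会添加 DistributedSampler），否则每个 worker 都会遍历完整的数据集。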
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from ray import train
from ray.ml.train.integrations.torch.torch_trainer import TorchTrainer
from torch.utils.data import DataLoader

from picread_data import generate_map, MyDatasets


class Net(nn.Module):
    """Small LeNet-style CNN that outputs 10 class logits.

    The ``16 * 5 * 5`` input size of ``fc1`` works out only for 3-channel
    32x32 inputs (32 -> conv5 -> 28 -> pool -> 14 -> conv5 -> 10 -> pool -> 5).
    NOTE(review): confirm that ``MyDatasets`` yields images of that size.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return unnormalized class scores for a batch of images ``x``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def train_loop_per_worker(config):
    """Train ``Net`` on the local image dataset; executed on every Ray worker.

    Args:
        config: the ``train_loop_config`` dict given to ``TorchTrainer``.
            Required keys:
              * ``"batch_size"`` -- per-worker mini-batch size.
              * ``"put_in"`` -- image directory; also the root passed to
                ``MyDatasets`` together with the generated ``trainmap.txt``.
            Optional keys (defaults reproduce the original hard-coded values):
              * ``"epochs"`` (2), ``"lr"`` (0.001),
              * ``"log_every"`` -- mini-batches between loss prints (2000).
    """
    data_dir = config["put_in"]
    epochs = config.get("epochs", 2)
    lr = config.get("lr", 0.001)
    log_every = config.get("log_every", 2000)

    # NOTE(review): every worker runs this and may rewrite the same map files
    # concurrently -- confirm generate_map is safe to call in parallel, or
    # run it once on the driver before trainer.fit().
    generate_map(data_dir, 2)

    train_loader = DataLoader(
        MyDatasets(data_dir, "trainmap.txt"),
        batch_size=config["batch_size"],
        shuffle=True,
    )
    # Without this, every worker would iterate the *entire* dataset.
    # prepare_data_loader adds a DistributedSampler (one shard per worker)
    # and moves each batch to the worker's device.
    train_loader = train.torch.prepare_data_loader(train_loader)

    model = train.torch.prepare_model(Net())
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print the mean loss over the last `log_every` mini-batches
            running_loss += loss.item()
            if i % log_every == log_every - 1:
                print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_every:.3f}")
                running_loss = 0.0

        # With multiple workers prepare_model wraps the network in
        # DistributedDataParallel (original under `.module`); otherwise the
        # model may be unwrapped, so fall back to it directly.
        unwrapped = getattr(model, "module", model)
        train.save_checkpoint(model=unwrapped.state_dict())


if __name__ == "__main__":
    trainer = TorchTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"batch_size": 2, "put_in": "D:/tmp/photo"},
        # No `datasets` argument: each worker loads its own data inside
        # train_loop_per_worker. scaling_config sets the training resources.
        scaling_config={"num_workers": 2},
    )
    result = trainer.fit()
    latest_checkpoint = result.checkpoint
通过 scaling_config={"num_workers": 2} 可以指定训练资源，这里表示使用 2 个 worker；如需使用 GPU，可再加上 "use_gpu": True。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。