Ray Train has four main components:
Training function: a user-defined function that contains the model-training logic
Worker: the process that actually runs the training function
Scaling configuration: specifies the number of workers and the compute resources (CPU/GPU) to use
Trainer: coordinates the three parts above and launches the distributed training job
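As a rough sketch of how these pieces map onto the API (illustrative values only; the full, annotated example follows below):

```python
import ray.train
from ray.train.torch import TorchTrainer

def train_func(config):
    # Training function: the per-worker model-training logic goes here.
    ...

# Scaling configuration: run 2 workers, each with a GPU (illustrative values).
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)

# Trainer: ties the training function, workers, and scaling config together.
trainer = TorchTrainer(train_func, scaling_config=scaling_config)
result = trainer.fit()  # starts the workers and runs train_func on each of them
```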
For this part it is best to look directly at the diff on the official docs page; the color-coded highlighting there is very clear and intuitive.
```python
import os
import tempfile

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

import ray.train.torch


def train_func(config):
    # Model, Loss, Optimizer
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    # model.to("cuda")  # This is done by `prepare_model`
    # [1] Prepare model.
    model = ray.train.torch.prepare_model(model)
    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    # Data
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    data_dir = os.path.join(tempfile.gettempdir(), "data")
    train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
    # [2] Prepare dataloader.
    train_loader = ray.train.torch.prepare_data_loader(train_loader)

    # Training
    for epoch in range(10):
        for images, labels in train_loader:
            # This is done by `prepare_data_loader`!
            # images, labels = images.to("cuda"), labels.to("cuda")
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # [3] Report metrics and checkpoint.
        metrics = {"loss": loss.item(), "epoch": epoch}
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            torch.save(
                model.module.state_dict(),
                os.path.join(temp_checkpoint_dir, "model.pt")
            )
            ray.train.report(
                metrics,
                checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
            )
        if ray.train.get_context().get_world_rank() == 0:
            print(metrics)


# [4] Configure scaling and resource requirements.
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)

# [5] Launch distributed training job.
trainer = ray.train.torch.TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    # [5a] If running in a multi-node cluster, this is where you
    # should configure the run's persistent storage that is accessible
    # across all worker nodes.
    # run_config=ray.train.RunConfig(storage_path="s3://..."),
)
result = trainer.fit()

# [6] Load the trained model.
with result.checkpoint.as_directory() as checkpoint_dir:
    model_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    model.load_state_dict(model_state_dict)
```
ray.train.torch.prepare_model()
```python
model = ray.train.torch.prepare_model(model)
```
Roughly equivalent to `model.to(device_id or "cpu")` followed by `DistributedDataParallel(model, device_ids=[device_id])`: it moves the model to the appropriate device and wraps it for distributed data-parallel training.
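For intuition, a hand-rolled sketch of roughly what that amounts to in plain PyTorch (simplified and hypothetical; the real `prepare_model` handles more configurations, such as CPU-only and multi-GPU setups):

```python
import torch
from torch.nn.parallel import DistributedDataParallel

def manual_prepare_model(model):  # hypothetical helper, not part of Ray
    # Pick this worker's device (assumes one GPU per worker when CUDA is available).
    if torch.cuda.is_available():
        device = torch.device("cuda", torch.cuda.current_device())
    else:
        device = torch.device("cpu")
    model = model.to(device)
    # Wrap with DDP once the process group has been initialized.
    if torch.distributed.is_initialized():
        device_ids = [device.index] if device.type == "cuda" else None
        model = DistributedDataParallel(model, device_ids=device_ids)
    return model
```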
ray.train.torch.prepare_data_loader()
```python
data_loader = ray.train.torch.prepare_data_loader(data_loader)
```
Moves each batch to the appropriate device and adds a distributed sampler, so that the dataset is sharded across the workers.
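By hand, roughly the same effect is obtained with a `DistributedSampler` (a simplified sketch reusing `train_data` from the example above; it assumes the `torch.distributed` process group is already initialized):

```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Shard the dataset across workers so each worker sees a distinct subset.
sampler = DistributedSampler(train_data, shuffle=True)
train_loader = DataLoader(train_data, batch_size=128, sampler=sampler)

for epoch in range(10):
    sampler.set_epoch(epoch)  # reshuffle the shards each epoch
    for images, labels in train_loader:
        # Without prepare_data_loader, each batch is moved to the device by hand.
        images, labels = images.to(device), labels.to(device)
```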
ray.train.report()
Inside the training function, report metrics and save checkpoints with `ray.train.report()`. The lines marked with `+` are what gets added on top of a plain PyTorch training loop:
```diff
+import ray.train
+from ray.train import Checkpoint

 def train_func(config):

     ...
     torch.save(model.state_dict(), f"{checkpoint_dir}/model.pth")
+    metrics = {"loss": loss.item()}  # Training/validation metrics.
+    checkpoint = Checkpoint.from_directory(checkpoint_dir)  # Build a Ray Train checkpoint from a directory.
+    ray.train.report(metrics=metrics, checkpoint=checkpoint)

     ...
```
ray.train.ScalingConfig
```python
from ray.train import ScalingConfig

scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
```
ray.train.RunConfig
For multi-node distributed training, a persistent storage path that all worker nodes can access must be specified; a node-local path will cause problems.
```python
from ray.train import RunConfig

# Local path (/some/local/path/unique_run_name)
run_config = RunConfig(storage_path="/some/local/path", name="unique_run_name")

# Shared cloud storage URI (s3://bucket/unique_run_name)
run_config = RunConfig(storage_path="s3://bucket", name="unique_run_name")

# Shared NFS path (/mnt/nfs/unique_run_name)
run_config = RunConfig(storage_path="/mnt/nfs", name="unique_run_name")
```
ray.train.torch.TorchTrainer
```python
from ray.train.torch import TorchTrainer

trainer = TorchTrainer(
    train_func, scaling_config=scaling_config, run_config=run_config
)
result = trainer.fit()
```
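Once `fit()` returns, the result object exposes the last reported metrics and checkpoint; a short usage sketch (the `model.pt` file name matches the full example above):

```python
import os
import torch

# Inspect the outcome of the run.
print(result.metrics)           # latest metrics dict passed to ray.train.report()
checkpoint = result.checkpoint  # latest reported ray.train.Checkpoint, or None

if checkpoint is not None:
    with checkpoint.as_directory() as checkpoint_dir:
        state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))
```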