
Getting Started with the Distributed Execution Engine Ray -- (3) Ray Train

Ray Train consists of four parts:

1. Training function: the function containing the model-training logic
2. Worker: the processes that actually run the training
3. Scaling configuration: configures the number of workers and their resources (e.g. GPUs)
4. Trainer: coordinates the three parts above
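As a mental model, the four parts fit together roughly like this. This is a hypothetical pure-Python sketch using threads as stand-in workers; the class names mirror Ray's but none of this is Ray's real implementation:

```python
from concurrent.futures import ThreadPoolExecutor


def train_func(config):                        # 1. the training function
    return f"trained with lr={config['lr']}"


class ScalingConfig:                           # 3. the scaling configuration
    def __init__(self, num_workers):
        self.num_workers = num_workers


class Trainer:                                 # 4. the trainer coordinates everything
    def __init__(self, train_func, scaling_config):
        self.train_func = train_func
        self.scaling_config = scaling_config

    def fit(self, config):
        # 2. spawn one "worker" per num_workers and run train_func on each
        with ThreadPoolExecutor(self.scaling_config.num_workers) as pool:
            futures = [pool.submit(self.train_func, config)
                       for _ in range(self.scaling_config.num_workers)]
            return [f.result() for f in futures]


results = Trainer(train_func, ScalingConfig(num_workers=2)).fit({"lr": 0.001})
print(results)  # one result per worker
```

In real Ray Train the workers are Ray actors (possibly on different nodes) rather than threads, but the division of responsibilities is the same.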

Ray Train+PyTorch

For this part, it is best to look at the diff directly in the official docs; the color-coded blocks there make the PyTorch-to-Ray changes very clear and intuitive.

```python
import os
import tempfile

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

import ray.train.torch


def train_func(config):
    # Model, Loss, Optimizer
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    # model.to("cuda")  # This is done by `prepare_model`
    # [1] Prepare model.
    model = ray.train.torch.prepare_model(model)
    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    # Data
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    data_dir = os.path.join(tempfile.gettempdir(), "data")
    train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
    # [2] Prepare dataloader.
    train_loader = ray.train.torch.prepare_data_loader(train_loader)

    # Training
    for epoch in range(10):
        for images, labels in train_loader:
            # This is done by `prepare_data_loader`!
            # images, labels = images.to("cuda"), labels.to("cuda")
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # [3] Report metrics and checkpoint.
        metrics = {"loss": loss.item(), "epoch": epoch}
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            torch.save(
                model.module.state_dict(),
                os.path.join(temp_checkpoint_dir, "model.pt")
            )
            ray.train.report(
                metrics,
                checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
            )
        if ray.train.get_context().get_world_rank() == 0:
            print(metrics)


# [4] Configure scaling and resource requirements.
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)

# [5] Launch distributed training job.
trainer = ray.train.torch.TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    # [5a] If running in a multi-node cluster, this is where you
    # should configure the run's persistent storage that is accessible
    # across all worker nodes.
    # run_config=ray.train.RunConfig(storage_path="s3://..."),
)
result = trainer.fit()

# [6] Load the trained model.
with result.checkpoint.as_directory() as checkpoint_dir:
    model_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    model.load_state_dict(model_state_dict)
```

Model

ray.train.torch.prepare_model()

```python
model = ray.train.torch.prepare_model(model)
```

This is roughly equivalent to `model.to(device_id or "cpu")` followed by `DistributedDataParallel(model, device_ids=[device_id])`: it moves the model to the appropriate device and wraps it for distributed training.
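The per-worker decisions can be sketched as follows. These are hypothetical stand-in classes, not Ray's or PyTorch's real implementations; the point is only the shape of the logic:

```python
class FakeModel:
    """Stand-in for a torch.nn.Module that records where it was moved."""
    def __init__(self):
        self.device = None

    def to(self, device):
        self.device = device
        return self


class DDPWrapper:
    """Stand-in for DistributedDataParallel: keeps the model under `.module`."""
    def __init__(self, module):
        self.module = module


def prepare_model_sketch(model, use_gpu, local_rank, world_size):
    # Each worker gets its own device based on its local rank.
    device = f"cuda:{local_rank}" if use_gpu else "cpu"
    model = model.to(device)              # the model.to(...) step
    if world_size > 1:
        model = DDPWrapper(model)         # wrap for gradient synchronization
    return model


m = prepare_model_sketch(FakeModel(), use_gpu=True, local_rank=0, world_size=2)
print(type(m).__name__, m.module.device)  # DDPWrapper cuda:0
```

This is also why the full example above saves `model.module.state_dict()`: after wrapping, the original model lives under the wrapper's `.module` attribute.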

Data

ray.train.torch.prepare_data_loader()

```python
data_loader = ray.train.torch.prepare_data_loader(data_loader)
```

This moves each batch to the appropriate device and installs a distributed sampler, so every worker sees its own shard of the data.

Report checkpoints and metrics

```diff
+import ray.train
+from ray.train import Checkpoint

 def train_func(config):
     ...
     torch.save(model.state_dict(), f"{checkpoint_dir}/model.pth")
+    metrics = {"loss": loss.item()}  # Training/validation metrics.
+    checkpoint = Checkpoint.from_directory(checkpoint_dir)  # Build a Ray Train checkpoint from a directory
+    ray.train.report(metrics=metrics, checkpoint=checkpoint)
     ...
```
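The distributed sampler that `prepare_data_loader` installs essentially shards the dataset's indices across workers. A simplified sketch (the real `DistributedSampler` also shuffles per epoch and pads so every shard has equal length):

```python
def shard_indices(num_samples, world_size, rank):
    """Give worker `rank` every world_size-th index, starting at `rank`."""
    return list(range(num_samples))[rank::world_size]


# With 2 workers, each sees half the dataset and no sample is shared:
print(shard_indices(10, 2, 0))  # [0, 2, 4, 6, 8]
print(shard_indices(10, 2, 1))  # [1, 3, 5, 7, 9]
```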

Configure scale and GPUs

```python
from ray.train import ScalingConfig

scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
```
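As a rough rule of thumb, the cluster resources such a config claims can be sketched like this. This assumes Ray Train's defaults of 1 CPU per worker plus 1 GPU per worker when `use_gpu=True`; `total_resources` is a hypothetical helper, not a Ray API:

```python
def total_resources(num_workers, use_gpu):
    """Resources the job requests under the 1-CPU/1-GPU-per-worker defaults."""
    return {"CPU": num_workers, "GPU": num_workers if use_gpu else 0}


print(total_resources(2, True))   # {'CPU': 2, 'GPU': 2}
print(total_resources(4, False))  # {'CPU': 4, 'GPU': 0}
```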

Configure persistent storage

This must be specified for multi-node distributed training; a plain local path causes problems because workers on other nodes cannot reach it.

```python
from ray.train import RunConfig

# Local path (/some/local/path/unique_run_name)
run_config = RunConfig(storage_path="/some/local/path", name="unique_run_name")

# Shared cloud storage URI (s3://bucket/unique_run_name)
run_config = RunConfig(storage_path="s3://bucket", name="unique_run_name")

# Shared NFS path (/mnt/nfs/unique_run_name)
run_config = RunConfig(storage_path="/mnt/nfs", name="unique_run_name")
```
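The comments above all follow one pattern: the run's persistent directory is simply `storage_path` joined with `name`. A tiny sketch with a hypothetical helper:

```python
import posixpath


def run_dir(storage_path, name):
    """Where a run's checkpoints and results end up (storage_path/name)."""
    return posixpath.join(storage_path, name)


print(run_dir("s3://bucket", "unique_run_name"))  # s3://bucket/unique_run_name
print(run_dir("/mnt/nfs", "unique_run_name"))     # /mnt/nfs/unique_run_name
```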

Launch the training job

```python
from ray.train.torch import TorchTrainer

trainer = TorchTrainer(
    train_func, scaling_config=scaling_config, run_config=run_config
)
result = trainer.fit()
```
