赞
踩
目录
- import torch
- from torch.nn.parallel import DistributedDataParallel as DDP
- from torch.utils.data import DataLoader, RandomSampler
- from accelerate import Accelerator
-
-
- # 模拟数据集
- class RandomDataset(torch.utils.data.Dataset):
- def __init__(self, size=100):
- self.data = torch.randn(size, 3)
-
- def __getitem__(self, index):
- re
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。