For pretraining and fine-tuning of the recently popular multimodal large models, the training data is usually very large, often on the order of 1M-100M samples. Such a dataset is typically stored as a JSONL file with millions of lines, one JSON-formatted training sample per line, possibly including references to associated image, audio, or video data. For example, a sample for Alibaba's Qwen-VL multimodal model might look like this:
{
"input": "Picture 1:<img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>这是什么?",
"output": "图中是一名女子在沙滩上和狗玩耍,旁边是一只拉布拉多犬,它们处于沙滩上。"
}
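Each JSONL line can be parsed with the standard json module. The sketch below is an assumption, not part of the original post: the `<img>...</img>` convention and the "input"/"output" field names come from the Qwen-VL sample above, while the helper name `parse_sample` is hypothetical.

```python
import json
import re

def parse_sample(line):
    # Parse one JSONL line into (prompt_text, image_urls, target).
    # The <img>...</img> tag convention and the "input"/"output" keys
    # follow the Qwen-VL sample shown above.
    record = json.loads(line)
    image_urls = re.findall(r"<img>(.*?)</img>", record["input"])
    prompt_text = re.sub(r"<img>.*?</img>", "", record["input"])
    return prompt_text, image_urls, record["output"]
```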
Because the dataset is so large, reading it with a plain map-style Dataset can cause performance problems. A typical Dataset implementation loads the entire dataset into memory at initialization; when the data does not fit in memory, this leads to heavy disk I/O (swapping) and degraded throughput. The remedy is to use the IterableDataset class, which streams samples on demand instead of loading the whole dataset at once.
A data-loading implementation based on IterableDataset looks like this:
import torch
from torch.utils.data import IterableDataset

class MyIterableDataset(IterableDataset):
    def __init__(self, data_file):
        self.data_file = data_file

    def __iter__(self):
        return iter(self._load_data())

    def _load_data(self):
        # Stream the file line by line instead of loading it all into memory
        with open(self.data_file, 'r') as file:
            for line in file:
                sample = self.process_line(line)
                yield sample

    def process_line(self, line):
        # Process the line to convert it to a sample
        ...
        return sample

# Usage
data_file = 'data.txt'
dataset = MyIterableDataset(data_file)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
for batch in dataloader:
    # Train your model using the batch of data
    pass
Two more problems come up in real training: first, under multi-GPU distributed (data-parallel) training, each process must read a distinct shard of the data, otherwise every GPU trains on the same samples; second, at this scale some lines are inevitably malformed, and a single bad sample should not crash the whole job.
The fixes are as follows.
For distributed training, note that DistributedSampler only works with map-style datasets (it needs __len__ and random access), and DataLoader rejects any sampler when the dataset is an IterableDataset. With an IterableDataset, shard the lines by process rank inside the dataset instead:

import torch.distributed as dist

# Replace _load_data in MyIterableDataset above with:
def _load_data(self):
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    with open(self.data_file, 'r') as file:
        for i, line in enumerate(file):
            # Each rank keeps every world_size-th line, so ranks see disjoint shards
            if i % world_size == rank:
                yield self.process_line(line)

# Usage (unchanged)
data_file = 'data.txt'
dataset = MyIterableDataset(data_file)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
for batch in dataloader:
    # Train your model using the batch of data
    pass
For dirty data, catch and log the exception per line so that a malformed sample is skipped instead of aborting training:

import logging
import torch
from torch.utils.data import IterableDataset

logger = logging.getLogger(__name__)

class MyIterableDataset(IterableDataset):
    def __init__(self, data_file):
        self.data_file = data_file

    def __iter__(self):
        return iter(self._load_data())

    def _load_data(self):
        with open(self.data_file, 'r') as file:
            for line in file:
                try:
                    sample = self.process_line(line)
                    yield sample
                except Exception as e:
                    # Log the offending line and the error, then skip it
                    logger.error(line)
                    logger.error(e)

    def process_line(self, line):
        # Process the line to convert it to a sample
        ...
        return sample
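The skip-on-error behavior can be sanity-checked in isolation. The snippet below is a standalone sketch, not part of the training code: it feeds one valid and one truncated JSON line through the same try/except pattern and keeps only the valid sample.

```python
import json
import logging

logger = logging.getLogger(__name__)

def iter_samples(lines):
    # Same pattern as _load_data above: log and skip malformed lines.
    for line in lines:
        try:
            yield json.loads(line)
        except Exception as e:
            logger.error("skipping bad line %r: %s", line, e)

lines = [
    '{"input": "a", "output": "b"}',  # valid
    '{"input": "a", "output"',        # truncated JSON, gets skipped
]
samples = list(iter_samples(lines))
```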
If you are using a plain map-style Dataset instead, add the error handling inside __getitem__, as in the following code:
import logging
from torch.utils.data import Dataset

logger = logging.getLogger(__name__)

class MyDataset(Dataset):
    def __init__(self, file_path):
        self.data = []
        with open(file_path, 'r') as file:
            for line in file:
                self.data.append(line)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        line = self.data[index]
        try:
            sample = self.process_line(line)
            return sample
        except Exception as e:
            # Log the detailed error information
            logger.error(line)
            logger.error(e)
            # Fall back to the next sample (wraps around at the end)
            return self.__getitem__((index + 1) % self.__len__())

    def process_line(self, line):
        # Process the line to convert it to a sample
        ...
        return sample