当前位置:   article > 正文

Pytorch:多模态大模型预训练、大模型微调:加载数据的正确姿势

Pytorch:多模态大模型预训练、大模型微调:加载数据的正确姿势

对于近期兴起的多模态大模型的预训练和微调,常见情况是训练数据规模极大,通常可以达到1m-100m级别。此时,训练数据通常用一个上百万行的jsonl文件存储,每行对应一条json格式的训练数据,其中可能包括数据关联的其他图、音、视频数据的索引。例如,阿里通义千问多模态大模型QWen-VL的一条示例数据可能如下所示:

{
  "input": "Picture 1:<img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>这是什么?",
  "output": "图中是一名女子在沙滩上和狗玩耍,旁边是一只拉布拉多犬,它们处于沙滩上。"
}
  • 1
  • 2
  • 3
  • 4

由于训练数据集过大,在训练读取数据时,直接使用Dataset类可能会带来性能问题。Pytorch的Dataset类在初始化时会将整个数据集加载到内存中,如果数据集非常大,没法全部放在内存里,使用Dataset类会显著增加硬盘io次数,带来性能下降。此时的对策是使用IterableDataset类,可以按需加载数据,而不是一次性将整个数据集加载到内存中。
基于IterableDataset的数据加载,代码实现如下:

import torch
from torch.utils.data import IterableDataset

class MyIterableDataset(IterableDataset):
    def __init__(self, data_file):
        self.data_file = data_file

    def __iter__(self):
        return iter(self._load_data())

    def _load_data(self):
        with open(self.data_file, 'r') as file:
            for line in file:
                sample = process_line(line)
                yield sample

    def process_line(self, line):
        # Process the line to convert it to a sample
        ...
        return sample

# Usage
data_file = 'data.txt'
dataset = MyIterableDataset(data_file)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)

for batch in dataloader:
    # Train your model using the batch of data
    pass

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30

在实际训练中还会遇到两个问题:

  1. 大模型一般需要使用多机多卡训练,需要避免多个进程中dataloader读取数据的竞争,并保证不同进程之间不会重复读取数据;
  2. 数据文件中某些行无法正确被解析,或者引用的外部资源找不到,导致process_line成员函数报错。数据集需要handle这类错误,防止因为报错中断训练。

以上问题对策如下:

  1. 在多机多卡的DDP训练中,可以使用DistributedSampler来处理多进程读数据的情形。DistributedSampler可以确保不同进程之间不会重复读取数据。具体的代码实现如下:
# Usage
data_file = 'data.txt'
dataset = MyIterableDataset(data_file)

# Create a DistributedSampler
sampler = DistributedSampler(dataset)

# Create a DataLoader using the DistributedSampler
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)

for batch in dataloader:
    # Train your model using the batch of data
    pass
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  1. 可以在调用process_line的时候试图handle一个错误,如果出错就跳过这条数据,改为(试图)获取下一条数据。具体的代码实现如下:
import torch
import logger
from torch.utils.data import IterableDataset

class MyIterableDataset(IterableDataset):
    def __init__(self, data_file):
        self.data_file = data_file

    def __iter__(self):
        return iter(self._load_data())

    def _load_data(self):
        with open(self.data_file, 'r') as file:
            for line in file:
                try:
                    sample = process_line(line)
                    yield sample
                except Exception as e:
                    # Print the detailed error information
                    logger.error(line)
                    logger.error(e)
                    pass

    def process_line(self, line):
        # Process the line to convert it to a sample
        ...
        return sample

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28

如果使用的是普通的Dataset,则参考以下代码,在__getitem__里面加入报错逻辑:

class MyDataset(Dataset):
    def __init__(self, file_path):
        self.data = []
        with open(file_path, 'r') as file:
            for line in file:
                self.data.append(line)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        line = self.data[index]
        try:
            sample = self.process_line(line)
            return sample
        except Exception as e:
            # Print the detailed error information
            logger.error(line)
            logger.error(e)
            return self.__getitem__((index+1) % self.__len__())

    def process_line(self, line):
        # Process the line to convert it to a sample
        ...
        return sample
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/248670
推荐阅读
相关标签
  

闽ICP备14008679号