
CogVLM Training Source Code Walkthrough: Data Processing



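Preface

This article covers CogVLM, a large multimodal (vision-language) model that takes images and text as input. Walkthroughs of its data-processing code are hard to find, so this article goes through that code against the source: image preprocessing, tokenizer construction, text processing, and how to modify the code to train on your own text labels. I will walk the reader through the data-processing source step by step.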


I. Walkthrough of the main data-processing code

CogVLM's data processing has two parts: image processing and text processing. Both are set up in finetune_cogvlm_demo.py:

from functools import partial
from utils.utils import llama2_tokenizer, llama2_text_processor, get_image_processor
from sat.training.deepspeed_training import training_main

tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=args.version)
image_processor = get_image_processor(args.eva_args["image_size"][0])  # image processor, with image_size bound
text_processor = llama2_text_processor(tokenizer, args.max_length, args.image_length)  # text processor

model = training_main(args, model_cls=model,                 # training entry point, model class
                      forward_step_function=forward_step,    # training step
                      create_dataset_function=partial(create_dataset_function, image_processor, text_processor),  # dataset factory, processors pre-bound
                      collate_fn=data_collator,              # merges samples into a batch for the DataLoader
                      forward_step_eval=forward_step_eval)   # evaluation step

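As shown above, the image processor comes from get_image_processor and the text processor from llama2_text_processor. Both are bound into create_dataset_function with partial (as its image_processor and text_processor arguments), and the result is passed to training_main. The next subsections show how each processor is built.

1. Obtaining the image processor

The actual image processing is done by blip2_image_processor_func_with_inputs (explained in detail later). get_image_processor binds a BlipImageEvalProcessor to it, and the resulting callable becomes image_processor: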

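from functools import partial
def get_image_processor(image_size):
    # pre-bind BlipImageEvalProcessor(image_size); the returned callable only needs the image
    return partial(blip2_image_processor_func_with_inputs, BlipImageEvalProcessor(image_size))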
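To make the output of blip2_image_processor_func_with_inputs concrete, here is a numpy-only sketch of the same dictionary layout. It is an approximation under stated assumptions: BlipImageEvalProcessor is assumed to resize, scale to [0, 1] and normalize with the CLIP mean/std (the constants below are that assumption), the resize step is skipped, and numpy arrays stand in for torch tensors. The function names are hypothetical.

```python
import numpy as np

# Assumed normalization constants (the CLIP mean/std used by BLIP-style processors).
MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)


def eval_transform(img_hwc_uint8):
    """Sketch of an eval-time transform: uint8 HWC in [0, 255] -> normalized float32 CHW.
    The resize to (image_size, image_size) is omitted; the input is assumed to already have that size."""
    x = img_hwc_uint8.astype(np.float32) / 255.0  # scale to [0, 1], like ToTensor
    x = (x - MEAN) / STD                          # per-channel normalization, broadcast over H and W
    return x.transpose(2, 0, 1)                   # HWC -> CHW


def image_inputs(transform, image):
    # Same keys as blip2_image_processor_func_with_inputs, with numpy in place of torch.
    return {
        "image": transform(image)[None],               # add batch dimension -> (1, 3, H, W)
        "input_ids": np.zeros((1, 1), dtype=np.int64),
        "position_ids": None,
        "attention_mask": np.ones((1, 1), dtype=np.int64),
    }


out = image_inputs(eval_transform, np.zeros((4, 4, 3), dtype=np.uint8))
```

The batch dimension added here is why the real function calls unsqueeze(0) on the processed image.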

2. Obtaining the text processor

Text processing uses a Hugging Face tokenizer, after which the llama2_text_processor class converts the text into model inputs (explained below). This is how text_processor is created:

from utils.utils import llama2_tokenizer
tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=args.version)
text_processor = llama2_text_processor(tokenizer, args.max_length, args.image_length)  # build the text processor

3. Building the tokenizer

For tokenization, CogVLM uses the Hugging Face LLaMA tokenizer through the llama2_tokenizer helper (from utils.utils import llama2_tokenizer). It is called in the training script as:

tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=args.version) 

The llama2_tokenizer function itself is defined as:

from transformers import LlamaTokenizer

def llama2_tokenizer(tokenizer_path, signal_type="base"):
    tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = 32000  # LLaMA-2 has no pad token; 32000 is one past its 32000-token base vocab
    tokenizer.boi = "[IMG]"   # begin-of-image marker string
    tokenizer.eoi = "[/IMG]"  # end-of-image marker string
    assert signal_type in ["base", "chat", "vqa", "chat_old"]  # selects the prompt template used later
    tokenizer.signal_type = signal_type
    return tokenizer

4. The llama2_text_processor class

This class does the text processing. Its code is:

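import re

import numpy as np
import torch

# Note: _history_to_prompt (used by history_to_prompt at the end) is not shown here; in the CogVLM
# source it is a module-level dict mapping each signal_type to a prompt-formatting function.
class llama2_text_processor: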
    def __init__(self, tokenizer, max_target_length=2048, image_length=257, model=None):
        self.tokenizer = tokenizer
        self.max_target_length = max_target_length
        self.image_length = image_length

    def __call__(self, caption, prompt=""):
        if '<EOI>' not in prompt:
            prompt = self.replace_tags_with_empty(prompt)
            # caption = self.replace_tags_with_empty(caption)
            history = []
            prompt = self.history_to_prompt(prompt, history)

        input_ids = [self.tokenizer.bos_token_id]

        prompt_splits = prompt.split('<EOI>')
        caption_splits = caption.split('<EOI>')
        if len(prompt_splits) > 0:
            input_ids.extend(self.tokenizer.encode(prompt_splits[0], add_special_tokens=False))
        for tokens in prompt_splits[1:]:
            tokens_with_img = [-100] + self.tokenizer.encode(tokens, add_special_tokens=False)
            input_ids.extend(tokens_with_img)
        context_length = len(input_ids) + (len(prompt_splits)-1) * (self.image_length + 1)
        if context_length > self.max_target_length - 10:
            return None
        if len(caption_splits) > 0:
            input_ids.extend(self.tokenizer.encode(caption_splits[0], add_special_tokens=False))
        for tokens in caption_splits[1:]:
            tokens_with_img = [-100] + self.tokenizer.encode(tokens, add_special_tokens=False)
            input_ids.extend(tokens_with_img)

        if len(input_ids) > self.max_target_length - self.image_length - 5:
            input_ids = input_ids[:self.max_target_length - self.image_length - 5]

        input_ids += [self.tokenizer.eos_token_id]

        while -100 in input_ids:
            img_idx = input_ids.index(-100)
            input_ids = input_ids[:img_idx] + [0] * (self.image_length + 1) + [-1] + input_ids[img_idx+1:]

        image_position = []
        while -1 in input_ids:
            img_idx = input_ids.index(-1)
            input_ids[img_idx] = 0
            image_position.append(img_idx)

        image_embed_mask = [0] * len(input_ids)
        vision_expert_mask = [0] * len(input_ids)
        image_rope_mask = [0] * len(input_ids)
        for idx in image_position:
            image_embed_mask[idx-self.image_length-1: idx+1] = [1] * (self.image_length + 2)
            vision_expert_mask[idx-self.image_length-1: idx] = [1] * (self.image_length + 1)
            image_rope_mask[idx - self.image_length: idx] = [1] * self.image_length
        attention_mask = [1] * len(input_ids)
        labels = [-100] * context_length + input_ids[context_length:]

        pad_len = self.max_target_length - len(input_ids)
        input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len
        attention_mask = attention_mask + [1] * pad_len
        vision_expert_mask = vision_expert_mask + [0] * pad_len
        image_embed_mask = image_embed_mask + [0] * pad_len
        image_rope_mask = image_rope_mask + [0] * pad_len
        np_mask = np.tril(np.expand_dims(np.array(attention_mask), 0).repeat(len(attention_mask), 0))
        labels = labels + [-100] * pad_len

        for idx in image_position:
            labels[idx-self.image_length-1: idx+1] = [-100] * (self.image_length + 2)

        position_ids = []
        pid = -1
        for i in range(len(input_ids)):
            if image_rope_mask[i] == 0 or (i > 0 and image_rope_mask[i] != image_rope_mask[i - 1]):
                pid += 1
            position_ids.append(pid)

        input_ids = torch.tensor(input_ids).unsqueeze(0)
        labels = torch.tensor(labels).unsqueeze(0)
        attention_mask = torch.from_numpy(np_mask).unsqueeze(0).unsqueeze(0)
        image_embed_mask = torch.tensor(image_embed_mask).unsqueeze(0)
        vision_expert_mask = torch.tensor(vision_expert_mask).unsqueeze(0)
        image_rope_mask = torch.tensor(image_rope_mask).unsqueeze(0)
        position_ids = torch.tensor(position_ids).unsqueeze(0)
        context_length = torch.tensor(context_length).unsqueeze(0).long()
        return {'input_ids': input_ids, 'labels': labels, 'position_ids': position_ids, 'attention_mask': attention_mask, 'image_embed_mask': image_embed_mask,
                'context_length': context_length, 'image_position': image_position, 'vision_expert_mask': vision_expert_mask, 'image_rope_mask': image_rope_mask
                }

    def history_to_prompt(self, query, history):
        return _history_to_prompt[self.tokenizer.signal_type](self, query, history)

    def replace_tags_with_empty(self, text):
        return re.sub('<pad>|<s>|</s>|<EOI>', '', text)

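The image-placeholder handling in __call__ is easiest to follow on a tiny input. The sketch below copies the relevant loops (each -100 expanded into image_length + 1 zero slots plus a -1 marker, then the vision-expert mask, the RoPE mask and position_ids) into standalone functions. The function names are hypothetical, and image_length=3 is chosen only so the output is readable (the class default is 257).

```python
def expand_image_placeholders(input_ids, image_length):
    """Copy of the placeholder loops in llama2_text_processor.__call__ (hypothetical helper name)."""
    ids = list(input_ids)
    # each -100 becomes image_length + 1 zero slots followed by a -1 marker
    while -100 in ids:
        i = ids.index(-100)
        ids = ids[:i] + [0] * (image_length + 1) + [-1] + ids[i + 1:]
    # record where each -1 marker sits, then turn the marker into 0 as well
    image_position = []
    while -1 in ids:
        i = ids.index(-1)
        ids[i] = 0
        image_position.append(i)
    return ids, image_position


def masks_and_positions(length, image_position, image_length):
    """Vision-expert mask, RoPE mask and position_ids, computed as in __call__."""
    vision_expert_mask = [0] * length
    image_rope_mask = [0] * length
    for idx in image_position:
        vision_expert_mask[idx - image_length - 1: idx] = [1] * (image_length + 1)
        image_rope_mask[idx - image_length: idx] = [1] * image_length
    position_ids, pid = [], -1
    for i in range(length):
        # text slots advance the position; a run of RoPE-masked image slots shares one id
        if image_rope_mask[i] == 0 or (i > 0 and image_rope_mask[i] != image_rope_mask[i - 1]):
            pid += 1
        position_ids.append(pid)
    return vision_expert_mask, image_rope_mask, position_ids


# BOS (id 1), one image placeholder, then two text tokens (ids 7 and 8 are arbitrary)
ids, image_position = expand_image_placeholders([1, -100, 7, 8], image_length=3)
vision_expert_mask, image_rope_mask, position_ids = masks_and_positions(len(ids), image_position, 3)
```

In this toy run the image occupies image_length + 2 = 5 slots (indices 1 to 5), of which the first four are routed to the vision expert. Their position ids are 1, 2, 2, 2, 3: all three patch slots share id 2, so an image consumes only three position ids however large image_length is, and the text after it continues at 4.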

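II. The create_dataset_function function

In the call model = training_main(...), the argument create_dataset_function=partial(create_dataset_function, image_processor, text_processor) uses Python's built-in functools.partial to pre-bind the two processors. create_dataset_function itself is: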


from utils.utils import ItemDataset
def create_dataset_function(image_processor, text_processor, path, args):
    dataset = ItemDataset(image_processor, text_processor, args, path)
    return dataset

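This function just instantiates ItemDataset, the class that loads and processes individual samples. It plays the role of a PyTorch Dataset (it does in fact subclass torch's Dataset) and wraps the processing logic. image_processor and text_processor are the image and text processors described above.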
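The partial binding is what lets sat call the dataset factory with just (path, args), which is how make_dataset_full calls it (create_dataset_function(p, args)). A minimal sketch with hypothetical toy processors and a toy dataset class in place of the real CogVLM objects:

```python
from functools import partial


# Hypothetical stand-ins for the real CogVLM processors (illustrative only).
def toy_image_processor(img):
    return {"image": img}


def toy_text_processor(caption, prompt=""):
    return {"text": prompt + caption}


class ToyItemDataset:
    """Hypothetical stand-in for ItemDataset: it only records what it was built with."""
    def __init__(self, image_processor, text_processor, args, path):
        self.image_processor = image_processor
        self.text_processor = text_processor
        self.args = args
        self.path = path


def create_dataset_function(image_processor, text_processor, path, args):
    return ToyItemDataset(image_processor, text_processor, args, path)


# Bind the two processors up front, as in finetune_cogvlm_demo.py ...
factory = partial(create_dataset_function, toy_image_processor, toy_text_processor)

# ... so the library only has to supply (path, args) for each data path.
ds = factory("/data/train", {"max_length": 2048})
```

The design keeps sat generic: it never needs to know which processors a model uses, only that the factory accepts a path and the argument namespace.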

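III. make_loaders in the sat library

training_main, shown above, ties together data processing, model training, and so on. The functions passed to it are stored in hooks, and the data-loading part is handled by make_loaders.

1. How make_loaders is called

make_loaders builds the data loaders and is provided by the sat library. It is called like this: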

# Data stuff.
train_data, val_data, test_data = make_loaders(args, hooks['create_dataset_function'], collate_fn=collate_fn)  # build the train/val/test loaders

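make_loaders therefore takes the arguments, the dataset factory (which builds ItemDataset instances), and collate_fn (which, as with a PyTorch DataLoader, defines how individual samples are merged into a batch).

Source location: sat/data_utils/configure_data.py

2. Walkthrough of make_loaders

I split make_loaders into three parts. Part 1 uses partial to bind arguments to the dataset factory create_dataset_function (which actually builds ItemDataset) and names the result make_dataset. Part 2 checks which of the training, validation and test paths were given and builds the corresponding datasets, analogous to a PyTorch Dataset. Part 3 wraps each dataset with sat's make_data_loader, analogous to a PyTorch DataLoader, using the collate_fn we passed in.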

Part 1

Our dataset factory create_dataset_function (which actually builds ItemDataset) is bound into sat's make_dataset_full with partial:

make_dataset = partial(make_dataset_full, create_dataset_function=create_dataset_function, batch_from_same_dataset=args.batch_from_same_dataset)

Part 2

Depending on which paths were given, the make_dataset function above builds the datasets (analogous to a PyTorch Dataset):

# make datasets splits and tokenizer
train = None
valid = None
test = None

if args.train_data is not None:
    train = make_dataset(**data_set_args, args=args, dataset_weights=args.train_data_weights, is_train_data=True)
    if should_split(split):
        train, valid, test = train

# make training and val dataset if necessary
if valid is None and args.valid_data is not None:
    eval_set_args['path'] = args.valid_data
    valid = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)
if test is None and args.test_data is not None:
    eval_set_args['path'] = args.test_data
    test = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)

Part 3

The datasets are then wrapped in data loaders (analogous to a PyTorch DataLoader) that use our own collate_fn:

# wrap datasets with data loader
if train is not None and args.batch_size > 0:
    train = make_data_loader(train, batch_size, args, split='train', collate_fn=collate_fn)
    args.do_train = True
else:
    args.do_train = False
eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size
if valid is not None:
    valid = make_data_loader(valid, eval_batch_size, args, split='val', collate_fn=collate_fn)
    args.do_valid = True
else:
    args.do_valid = False
if test is not None:
    test = make_data_loader(test, eval_batch_size, args, split='test', collate_fn=collate_fn)
    args.do_test = True
else:
    args.do_test = False

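Full source of make_loaders

Besides the three parts above, the function also derives related settings such as world_size and the global batch sizes. The complete source is: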

def make_loaders(args, create_dataset_function, collate_fn=None):
    """makes training/val/test
    Args:
        args.train_data, args.valid_data, args.test_data: str. Paths to the dataset.
        args.split: str. format: "8,1,1". how to split train_data.
        args.dataset_type: use to create the right datasets. 
    """
    make_dataset = partial(make_dataset_full, create_dataset_function=create_dataset_function, batch_from_same_dataset=args.batch_from_same_dataset)

    world_size = torch.distributed.get_world_size(   group=mpu.get_data_parallel_group())
    batch_size = args.batch_size * world_size
    eval_batch_size = batch_size
    if args.eval_batch_size is not None:
        eval_batch_size = args.eval_batch_size * world_size
    
    split = get_split(args)

    data_set_args = {
        'path': args.train_data,
        'split': split,
    }

    eval_set_args = copy.copy(data_set_args)
    eval_set_args['split'] = [1.]
    
    # make datasets splits and tokenizer
    train = None
    valid = None
    test = None

    if args.train_data is not None:
        train = make_dataset(**data_set_args, args=args, dataset_weights=args.train_data_weights, is_train_data=True)
        if should_split(split):
            train, valid, test = train

    # make training and val dataset if necessary
    if valid is None and args.valid_data is not None:
        eval_set_args['path'] = args.valid_data
        valid = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)
    if test is None and args.test_data is not None:
        eval_set_args['path'] = args.test_data
        test = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)

    # wrap datasets with data loader
    if train is not None and args.batch_size > 0:
        train = make_data_loader(train, batch_size, args, split='train', collate_fn=collate_fn)
        args.do_train = True
    else:
        args.do_train = False
    eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size
    if valid is not None:
        valid = make_data_loader(valid, eval_batch_size, args, split='val', collate_fn=collate_fn)
        args.do_valid = True
    else:
        args.do_valid = False
    if test is not None:
        test = make_data_loader(test, eval_batch_size, args, split='test', collate_fn=collate_fn)
        args.do_test = True
    else:
        args.do_test = False

    return train, valid, test


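IV. make_dataset_full in the sat library

1. Argument configuration

I run training from VS Code. Taking the training data as an example, --train-data can be configured in three ways. Suppose the CogVLM-SFT-311K folder contains two subfolders, llava_instruction_multi_conversations_formate and llava_instruction_single_conversation_formate. You can pass a single subfolder, the parent folder, or a list of several folders: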

# Option 1: a single subfolder
"--train-data", "/extend_disk/tj/data/CogVLM-SFT-311K/llava_instruction_multi_conversations_formate",
# Option 2: the parent folder
"--train-data", "/extend_disk/tj/data/CogVLM-SFT-311K",
# Option 3: a list of several folders
"--train-data", "/extend_disk/tj/data/CogVLM-SFT-311K/llava_instruction_multi_conversations_formate", "/extend_disk/tj/data/CogVLM-SFT-311K/llava_instruction_single_conversation_formate",

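2. Data format

A single data folder looks like this:
[Figure: contents of a single data folder]

A JSON label file looks like this:
[Figure: contents of a JSON label file]

3. How make_dataset_full reads the data

The make_dataset called inside make_loaders is make_dataset_full with some arguments bound. It mainly handles branching logic so that the data reaches the dataset class in a uniform way. The relevant part is: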

ds = []
for p in path:
    d = create_dataset_function(p, args)
    ds.append(d)
ds = ConcatDataset(ds, weights=dataset_weights)

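Here path is the list of paths given via --train-data. For each entry, create_dataset_function builds an ItemDataset, whose load_data calls find_all_files to walk the directory (with os.walk) and collect the full path of every .jpg file. The per-path datasets are then combined into a single dataset with ConcatDataset.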
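The scan described above can be sketched with the standard library alone. find_jpgs and scan_paths below are hypothetical helpers that approximate find_all_files and the per-path loop (plain list concatenation stands in for ConcatDataset); the usage part builds a throwaway folder tree and tries the three --train-data options from section IV.1 with made-up folder names.

```python
import os
import tempfile


def find_jpgs(root, suffix=".jpg"):
    """Hypothetical re-implementation of the directory scan: every file under root
    (recursively) whose name ends with suffix."""
    found = []
    for cur_dir, _, files in os.walk(root, followlinks=True):
        for name in files:
            if name.endswith(suffix):
                found.append(os.path.join(cur_dir, name))
    return sorted(found)  # sorted only to make the result deterministic


def scan_paths(paths):
    # one file list per --train-data entry (like one ItemDataset per path),
    # then concatenated (the role ConcatDataset plays in make_dataset_full)
    return [f for p in paths for f in find_jpgs(p)]


# Usage: a tiny fake tree <root>/multi/{a.jpg, a.json} and <root>/single/b.jpg
with tempfile.TemporaryDirectory() as root:
    for sub, name in [("multi", "a.jpg"), ("multi", "a.json"), ("single", "b.jpg")]:
        os.makedirs(os.path.join(root, sub), exist_ok=True)
        open(os.path.join(root, sub, name), "w").close()
    multi, single = os.path.join(root, "multi"), os.path.join(root, "single")
    option1 = [os.path.relpath(f, root) for f in scan_paths([multi])]
    option2 = [os.path.relpath(f, root) for f in scan_paths([root])]
    option3 = [os.path.relpath(f, root) for f in scan_paths([multi, single])]
```

Options 2 and 3 find the same images here; they differ in that option 3 yields one dataset per folder, which matters when dataset weights are given.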

4. Full source of make_dataset_full

def make_dataset_full(path, split, args, create_dataset_function, 
        dataset_weights=None, random_mapping=True, is_train_data=False, batch_from_same_dataset=False, **kwargs):
    """function to create datasets+tokenizers for common options"""
    print_all('make dataset ' + str(path), level='DEBUG')
    assert isinstance(path, list)

    if args.iterable_dataset: # cannot indexed
        # the random mapping is flexible and efficient, but sometimes we have pratical issue
        # For instance, someone just gives you a iterable dataset, e.g. webdataset
        from .webds import ConfiguredResampledShards, DataPipeline
        valid_types = (ConfiguredResampledShards, DataPipeline)
        
        assert split[0] == 1, 'Iterable dataset cannot auto split.'
        ds = []
        for p in path:
            d = create_dataset_function(p, args)
            assert isinstance(d, valid_types)
            ds.append(d)
        # ds = ChainDataset(ds) # please merge them in a url if chain
        if batch_from_same_dataset:
            assert args.num_workers <= 1, 'We cannot control the actual speed of different workers, may mix different iterable parts.'
        ds = AlterDataset(ds, weights=dataset_weights, seed=args.seed, batch_from_same_dataset=batch_from_same_dataset, batch_size=args.batch_size)
        return ds

    if split is None:
        split = [1.] 
    if not should_split(split):
        ds = []
        for p in path:
            d = create_dataset_function(p, args)
            ds.append(d)
        ds = ConcatDataset(ds, weights=dataset_weights)
        if random_mapping:
            if args.epochs is not None: # not auto-scale, but use a given number of epoches.
                ds = RandomDataset(ds, scale=args.epochs, seed=args.seed)
            else:
                world_size = torch.distributed.get_world_size(
                    group=mpu.get_data_parallel_group())
                if is_train_data:
                # only train-dataset will set this to True,
                # so we enlarge it to make sure that the data is sufficient.
                    scale = max(200, 1 + (args.train_iters * args.batch_size * args.gradient_accumulation_steps * world_size) // len(ds))
                else:
                    scale = max(200, 1 + ((1 + args.train_iters // args.eval_interval) * args.eval_iters * args.eval_batch_size * args.gradient_accumulation_steps * world_size) // len(ds))
                ds = RandomMappingDataset(ds, scale=scale)
        return ds
    else:
        # must first split datasets, then reweight/concat, finally random-mapping.
        # this order avoids overlapping.
        train_ds, valid_ds, test_ds = [], [], []
        for p in path:
            d = create_dataset_function(p, args)
            if should_split(split):
                dtrain, dvalid, dtest = split_ds(d, split, block_size=args.block_size, seed=args.seed)
                train_ds.append(dtrain)
                valid_ds.append(dvalid)
                test_ds.append(dtest)
        train_ds = ConcatDataset(train_ds, weights=dataset_weights)
        valid_ds = ConcatDataset(valid_ds, weights=dataset_weights)
        test_ds = ConcatDataset(test_ds, weights=dataset_weights)
        if random_mapping:
            world_size = torch.distributed.get_world_size(
                group=mpu.get_data_parallel_group())
            scale = max(200, 1 + (args.train_iters * args.batch_size * world_size) // len(train_ds))
            train_ds = RandomMappingDataset(train_ds, scale=scale)
            scale = max(200, 1 + ((1 + args.train_iters // args.eval_interval) * args.eval_iters * args.eval_batch_size * args.gradient_accumulation_steps * world_size) // len(valid_ds))
            valid_ds = RandomMappingDataset(valid_ds, scale=scale)
            test_ds = RandomMappingDataset(test_ds)
        return train_ds, valid_ds, test_ds
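The scale passed to RandomMappingDataset for training data (the is_train_data branch above) is easier to read with numbers plugged in. A minimal sketch of that one formula; the function name is hypothetical and all numbers are made up for illustration.

```python
def random_mapping_scale(train_iters, batch_size, grad_accum, world_size, dataset_len):
    """Training-set scale as computed in make_dataset_full (is_train_data branch):
    enough repetitions to cover every sample the run will consume, and at least 200."""
    samples_needed = train_iters * batch_size * grad_accum * world_size
    return max(200, 1 + samples_needed // dataset_len)


large = random_mapping_scale(2000, 4, 1, 8, 100_000)  # 64,000 samples needed: the floor of 200 applies
small = random_mapping_scale(2000, 4, 1, 8, 100)      # 1 + 64,000 // 100 = 641
```

For a realistic dataset size the floor of 200 usually dominates; the formula only takes over when the dataset is small relative to the planned number of samples.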

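V. make_data_loader in the sat library

As noted in section III, this function plays the role of a PyTorch DataLoader. Briefly: it wraps the dataset with a sequential sampler (any randomness comes from the RandomMappingDataset wrapper applied in make_dataset_full, not from the sampler), splits each batch across the world_size data-parallel ranks, and sets up the related bookkeeping. It is provided by the library and can be used as is. The source is: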

def make_data_loader(dataset, batch_size, args, split, collate_fn=None):

    world_size = torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())
    rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group())
    distributed = world_size > 1

    # if IterableDataset, assume everything is properly configured. (pre-sharded) 
    if isinstance(dataset, IterableDataset):
        if split in ['val', 'test'] and args.strict_eval:
            raise ValueError('IterableDataset cannot be used for validation or testing if `args.strict_eval=True`, because we cannot infer the length of the final batch before reading out them.')
        args.val_last_shape = [1] * world_size # just fake it, not actually used
        args.val_drop_number = 0
        args.test_last_shape = [1] * world_size
        args.test_drop_number = 0
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size//world_size,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=collate_fn,
            prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
            )

    sampler = torch.utils.data.SequentialSampler(dataset)  # sequential sampling

    drop_last = False # COMMENT: this is already solved by the complex logic of last_shape and drop_number.

    # the GPUs in the same model parallel group receive the same data
    if distributed: # TODO reformat this, but it is not urgent
        gradient_accumulation_steps = getattr(args, 'gradient_accumulation_steps', 1)
        batch_sampler = DistributedBatchSampler(sampler,
                                                batch_size,
                                                drop_last,
                                                rank,
                                                world_size,
                                                gradient_accumulation_steps=gradient_accumulation_steps)
    else:
        batch_sampler = torch.utils.data.BatchSampler(sampler,
                                                      batch_size,
                                                      drop_last)
    last_len = len(dataset) % batch_size
    batch_per_worker = batch_size // world_size
    last_shape = [batch_per_worker] * (last_len//batch_per_worker) # some processes get full batch
    if last_len != 0:
        if last_len % batch_per_worker != 0:
            last_shape.append(last_len % batch_per_worker) # one process get the rest (<1 batch)
        drop_number = world_size - ((last_len-1)//batch_per_worker + 1)
        # other processes get nothing, but append 1 for running. will drop later according to drop_number.
        for j in range(drop_number): 
            last_shape.append(1)
    else:
        drop_number = 0
    if split=='val':
        args.val_last_shape = last_shape
        args.val_drop_number = drop_number
    elif split=='test':
        args.test_last_shape = last_shape
        args.test_drop_number = drop_number
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_sampler=batch_sampler,
                                              num_workers=args.num_workers,
                                              pin_memory=True,
                                              collate_fn=collate_fn,
                                              prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
                                              )
    return data_loader


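The last_len / last_shape / drop_number bookkeeping at the end of make_data_loader decides how a final, partial global batch is spread across ranks. Extracted into a standalone function (hypothetical name; batch_size is the global batch, i.e. args.batch_size * world_size as computed in make_loaders), it can be checked on small numbers:

```python
def last_batch_layout(dataset_len, batch_size, world_size):
    """Standalone copy of the last-batch bookkeeping in make_data_loader.
    batch_size is the global batch size, split evenly across world_size ranks."""
    last_len = dataset_len % batch_size  # samples in the final, partial global batch
    batch_per_worker = batch_size // world_size
    last_shape = [batch_per_worker] * (last_len // batch_per_worker)  # ranks that still get a full batch
    if last_len != 0:
        if last_len % batch_per_worker != 0:
            last_shape.append(last_len % batch_per_worker)  # one rank gets the remainder
        drop_number = world_size - ((last_len - 1) // batch_per_worker + 1)
        last_shape.extend([1] * drop_number)  # idle ranks run on 1 sample that is dropped later
    else:
        drop_number = 0
    return last_shape, drop_number


even = last_batch_layout(16, 8, 4)     # 16 % 8 == 0: no partial batch
aligned = last_batch_layout(10, 8, 4)  # 2 leftover samples fill one rank
ragged = last_batch_layout(13, 8, 4)   # 5 leftover samples: 2 + 2 + 1
```

With 10 samples, a global batch of 8 and 4 ranks, one rank receives the 2 real leftover samples and the other three each process 1 dummy sample, which is why drop_number is 3.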

VI. The ItemDataset class

We finally reach the core part: per-sample data processing.

class ItemDataset(Dataset):
    def __init__(self, image_processor, text_processor, args, data_dirs, cross_image_processor=None, **kwargs):
        super().__init__()
        self.data = self.load_data(data_dirs)  # list of full paths to all .jpg images, stored in self.data
        self.image_processor, self.text_processor, self.cross_image_processor = image_processor, text_processor, cross_image_processor  # the processing functions passed in

    def process_img(self, img):
        img_dict = {'vision': self.image_processor(img)}
        if self.cross_image_processor:
            img_dict.update({'cross': self.cross_image_processor(img)})
        return img_dict
    
    def process_text(self, answer, prompt):
        return self.text_processor(answer, prompt)
    
    def load_data(self, data_dir):
        all_files = find_all_files(data_dir, suffix=".jpg")
        print_rank0(f"find {len(all_files)} samples in all...")
        return all_files
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        data = self.data[index]  # full path of the image
        # img
        try:
            img = Image.open(data).convert('RGB')  # load the image
        except Exception as e:
            print_rank0(e, level=logging.WARNING)
            return {}
        img_dict = self.process_img(img)  # image processing
        # text
        label = data.split('/')[-1].split('.')[0]
        uni_key = label
        text_dict = self.process_text(label, "CAPTCHA:")
        if text_dict is None:
            print_rank0(f"Process text failed. Please check the max_target_length & max_source_length.\n The data is {data}", level=logging.WARNING)
            return {}
        # other attr
        ret = {**img_dict, **text_dict, "question_id": uni_key}
        return ret


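1. Image processing

(1) process_img(self, img)

Image processing produces a dictionary img_dict with up to two entries: 'vision', produced by image_processor, and 'cross', produced by cross_image_processor if one is set. The source is: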

    def process_img(self, img):
        img_dict = {'vision': self.image_processor(img)}
        if self.cross_image_processor:
            img_dict.update({'cross': self.cross_image_processor(img)})
        return img_dict

(2) self.image_processor(img)

The image_processor passed in is blip2_image_processor_func_with_inputs, with a BlipImageEvalProcessor bound to it via partial:

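import torch
def blip2_image_processor_func_with_inputs(image_processor, image):
    # transform the image and add a batch dimension; input_ids/position_ids/attention_mask are minimal dummy values
    return {'image': image_processor(image).unsqueeze(0), 'input_ids': torch.zeros(1, 1, dtype=torch.long), 'position_ids': None, 'attention_mask': torch.ones(1, 1, dtype=torch.long)}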


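2. Text processing

(1) Processing the text of the original CAPTCHA data

In the original CAPTCHA data, each sample is an image whose file name is its CAPTCHA text; in the source, data is the image's full path (e.g. /home/*/0a4Ovs8789.jpg). So label = data.split('/')[-1].split('.')[0] extracts the text (0a4Ovs8789), and self.process_text turns it into the text inputs: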

label = data.split('/')[-1].split('.')[0]  
uni_key = label
text_dict = self.process_text(label, "CAPTCHA:")


Sample CAPTCHA images:
[Figure: sample CAPTCHA images]
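One caveat about label = data.split('/')[-1].split('.')[0]: it keeps only the text before the first dot and assumes '/' as the path separator. A short comparison with pathlib's Path.stem, which strips only the last suffix; the helper names are hypothetical.

```python
from pathlib import Path


def label_from_path_split(path):
    # What ItemDataset.__getitem__ does: last path component, text before the first '.'
    return path.split('/')[-1].split('.')[0]


def label_from_path_stem(path):
    # Alternative: drop only the final extension, using pathlib's separator handling
    return Path(path).stem


plain = "/home/user/0a4Ovs8789.jpg"
dotted = "/data/img.v2.jpg"
```

For CAPTCHA file names both give the same label; they only diverge when a file name contains extra dots.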

(2) Modifying the code to use your own text data

Suppose each image has a matching JSON file with the following content:

{
  "conversations": [
    {
      "role": "assistant",
      "content": "虽然无法从照片中确定他们的确切目的地或目的地,但这两名滑板运动员很可能正在使用公共交通工具前往滑板公园、休闲场所或其他可以练习滑板的地方。他们也可以在空闲时间携带滑板进行娱乐活动,往返于学校、工作或其他日常活动。或者,他们可能只是带着滑板在不同地点之间旅行,作为首选的交通方式。"
    }
  ]
}


The code that reads the corresponding JSON content (it replaces the label lines in __getitem__):

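        json_path = data.replace('images','labels_zh')[:-4]+'.json'  # images/xxx.jpg -> labels_zh/xxx.json
        label = self.read_json(json_path)
        label = label['conversations'][0]['content']  # text of the first turn; the key must match your JSON (here 'conversations')
        uni_key = label
        text_dict = self.process_text(label, "CAPTCHA:")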

The JSON helper method:

    def read_json(self,json_root):
        import json
        with open(json_root, encoding='utf-8') as f:
            json_info = json.load(f)
        return json_info
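The JSON variant can be tried outside the training code. This sketch assumes the layout implied by the snippet (images in an images/ folder, labels with the same base name in labels_zh/, a four-character .jpg extension) and the 'conversations' key from the example JSON; helper names and the sample caption are hypothetical.

```python
import json
import os
import tempfile


def label_path_for(image_path, images_dir="images", labels_dir="labels_zh"):
    # Same convention as the snippet: swap the folder name and replace ".jpg" with ".json".
    # Note: str.replace substitutes every occurrence of images_dir in the path, not just the folder.
    return image_path.replace(images_dir, labels_dir)[:-4] + ".json"


def read_caption(image_path):
    with open(label_path_for(image_path), encoding="utf-8") as f:
        info = json.load(f)
    return info["conversations"][0]["content"]


# Usage with a throwaway layout: <root>/images/x.jpg and <root>/labels_zh/x.json
with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "images"))
    os.makedirs(os.path.join(root, "labels_zh"))
    img = os.path.join(root, "images", "x.jpg")
    sample = {"conversations": [{"role": "assistant", "content": "two skateboarders on a bus"}]}
    with open(os.path.join(root, "labels_zh", "x.json"), "w", encoding="utf-8") as f:
        json.dump(sample, f, ensure_ascii=False)
    caption = read_caption(img)
```

Because of the str.replace caveat, a parent directory whose name contains "images" would also be rewritten; os.path-based manipulation of only the last folder would avoid that.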

Finally, the text is processed by the text processor created below; I may explain it in more detail later.

text_processor = llama2_text_processor(tokenizer, args.max_length, args.image_length)  # build the text processor

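Disclaimer: this article was contributed by a user and does not represent the position of the wpsshop blog; copyright belongs to the original author. When reposting, please cite the source: https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/354816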