[InternLM Camp, Season 2 Notes] Fine-tuning a Personal Assistant's Self-Identity with XTuner

How XTuner Works

Hands-on

Environment setup

Pull the environment

    studio-conda xtuner0.1.17

Activate the environment

    conda activate xtuner0.1.17

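Create the working directory

    # Go to the home directory (~ means the current user's home path)
    cd ~
    # Create a version-specific folder and enter it, to follow along with this tutorial
    mkdir -p /root/xtuner0117 && cd /root/xtuner0117

Clone the code

    git clone -b v0.1.17 https://github.com/InternLM/xtuner

Install from source

    # Enter the source directory
    cd /root/xtuner0117/xtuner
    # Install XTuner from source
    pip install -e '.[all]'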

Preparation

Dataset preparation

    # The first half creates a folder; the second half enters it.
    mkdir -p /root/ft && cd /root/ft
    # Create a data folder inside ft to hold the dataset
    mkdir -p /root/ft/data && cd /root/ft/data

Create a new file named generate_data.py

    touch /root/ft/data/generate_data.py

Script contents:

    import json

    # Set the user's name
    name = 'JeffDing菜鸟'  # change this to your own name
    # Number of times the conversation is appended again
    n = 10000

    # Initialize the data structure in OpenAI format
    data = [
        {
            "messages": [
                {
                    "role": "user",
                    "content": "请做一下自我介绍"
                },
                {
                    "role": "assistant",
                    "content": "我是{}的小助手,内在是上海AI实验室书生·浦语的1.8B大模型哦".format(name)
                }
            ]
        }
    ]

    # Append the initial conversation to the data list n more times
    for i in range(n):
        data.append(data[0])

    # Write the data list to a file named 'personal_assistant.json'
    with open('personal_assistant.json', 'w', encoding='utf-8') as f:
        # json.dump writes the data in JSON format
        # ensure_ascii=False keeps Chinese characters readable
        # indent=4 pretty-prints the file so it is easy to read
        json.dump(data, f, ensure_ascii=False, indent=4)

Run generate_data.py

    # Enter the data folder first: the script writes personal_assistant.json to the current directory
    cd /root/ft/data
    # Run the script
    python /root/ft/data/generate_data.py
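Before training, it can be worth checking that the generated file has the shape the script intends. The following is a minimal sketch, not part of XTuner: the helper names validate_records and validate_file are hypothetical, and the allowed role set is an assumption based on the usual OpenAI chat format. Note that the loop in generate_data.py appends n copies to a list that already holds one record, so the file ends up with n + 1 entries.

```python
import json

# Assumed set of roles for OpenAI-style chat records.
ALLOWED_ROLES = {"system", "user", "assistant"}


def validate_records(records):
    """Check the shape of every record and return the record count."""
    for index, record in enumerate(records):
        messages = record.get("messages")
        if not isinstance(messages, list) or not messages:
            raise ValueError(f"record {index}: 'messages' must be a non-empty list")
        for message in messages:
            if message.get("role") not in ALLOWED_ROLES:
                raise ValueError(f"record {index}: unexpected role {message.get('role')!r}")
            if not isinstance(message.get("content"), str):
                raise ValueError(f"record {index}: 'content' must be a string")
    return len(records)


def validate_file(path):
    """Load a JSON file such as personal_assistant.json and validate it."""
    with open(path, encoding="utf-8") as f:
        return validate_records(json.load(f))


# Same construction as generate_data.py, with a small n for illustration.
sample = [{"messages": [{"role": "user", "content": "请做一下自我介绍"},
                        {"role": "assistant", "content": "我是小助手"}]}]
for _ in range(3):
    sample.append(sample[0])
record_count = validate_records(sample)  # 1 initial record + 3 copies
```

On the real file, validate_file('/root/ft/data/personal_assistant.json') should therefore report 10001 records for n = 10000.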

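Model preparation

    # Create the target folder and make sure it exists.
    # -p also creates any missing parent directories and does not fail if the folder already exists.
    mkdir -p /root/ft/model
    # Create a symbolic link. Because /root/ft/model already exists as a directory,
    # the link is created inside it, at /root/ft/model/internlm2-chat-1_8b
    ln -s /root/share/new_models/Shanghai_AI_Laboratory/internlm2-chat-1_8b /root/ft/model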

Choosing a config file

    # List all built-in config files
    # xtuner list-cfg
    # Suppose we want the configs that support the internlm2-1.8b model
    xtuner list-cfg -p internlm2_1_8b

Output:

    ==========================CONFIGS===========================
    PATTERN: internlm2_1_8b
    -------------------------------
    internlm2_1_8b_full_alpaca_e3
    internlm2_1_8b_qlora_alpaca_e3
    =============================================================

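Understanding the config file name

Our dataset is not alpaca but the personal-assistant dataset generated by the script above. However, we are fine-tuning internlm2-chat-1.8b with QLoRA, and the closest built-in config is internlm2_1_8b_qlora_alpaca_e3, so we copy that config into our own config folder as a starting point:

    # Create a folder to hold the config file
    mkdir -p /root/ft/config
    # Use XTuner's copy-cfg command to copy the config file to the given location
    xtuner copy-cfg internlm2_1_8b_qlora_alpaca_e3 /root/ft/config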
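The two config names listed above appear to follow a model_method_dataset_epochs pattern. As a rough sketch of how such a name breaks down, the hypothetical helper below (parse_config_name is not an XTuner function) splits the name from the right. It assumes the dataset part is a single underscore-free token, which holds for these two names but not necessarily for every built-in config.

```python
def parse_config_name(name: str) -> dict:
    """Split a config name such as 'internlm2_1_8b_qlora_alpaca_e3' into parts.

    Assumes the layout <model>_<method>_<dataset>_e<epochs>, where only the
    model part may itself contain underscores.
    """
    parts = name.split("_")
    if len(parts) < 4 or not parts[-1].startswith("e") or not parts[-1][1:].isdigit():
        raise ValueError(f"unexpected config name: {name!r}")
    return {
        "model": "_".join(parts[:-3]),   # e.g. internlm2_1_8b
        "method": parts[-3],             # e.g. qlora or full
        "dataset": parts[-2],            # e.g. alpaca
        "epochs": int(parts[-1][1:]),    # e3 -> 3
    }


qlora_cfg = parse_config_name("internlm2_1_8b_qlora_alpaca_e3")
full_cfg = parse_config_name("internlm2_1_8b_full_alpaca_e3")
```

Read this way, the chosen config means: the internlm2_1_8b model, fine-tuned with QLoRA, on the alpaca dataset, for 3 epochs. The dataset and epoch count are what we change in the copied file.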

Modifying the config file

    # Copyright (c) OpenMMLab. All rights reserved.
    import torch
    from datasets import load_dataset
    from mmengine.dataset import DefaultSampler
    from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                                LoggerHook, ParamSchedulerHook)
    from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
    from peft import LoraConfig
    from torch.optim import AdamW
    from transformers import (AutoModelForCausalLM, AutoTokenizer,
                              BitsAndBytesConfig)

    from xtuner.dataset import process_hf_dataset
    from xtuner.dataset.collate_fns import default_collate_fn
    from xtuner.dataset.map_fns import openai_map_fn, template_map_fn_factory
    from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
                                     VarlenAttnArgsToMessageHubHook)
    from xtuner.engine.runner import TrainLoop
    from xtuner.model import SupervisedFinetune
    from xtuner.parallel.sequence import SequenceParallelSampler
    from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE

    #######################################################################
    #                          PART 1  Settings                           #
    #######################################################################
    # Model
    pretrained_model_name_or_path = '/root/ft/model/internlm2-chat-1_8b'
    use_varlen_attn = False

    # Data
    alpaca_en_path = '/root/ft/data/personal_assistant.json'
    prompt_template = PROMPT_TEMPLATE.default
    max_length = 1024
    pack_to_max_length = True

    # parallel
    sequence_parallel_size = 1

    # Scheduler & Optimizer
    batch_size = 1  # per_device
    accumulative_counts = 16
    accumulative_counts *= sequence_parallel_size
    dataloader_num_workers = 0
    max_epochs = 2
    optim_type = AdamW
    lr = 2e-4
    betas = (0.9, 0.999)
    weight_decay = 0
    max_norm = 1  # grad clip
    warmup_ratio = 0.03

    # Save
    save_steps = 300
    save_total_limit = 3  # Maximum checkpoints to keep (-1 means unlimited)

    # Evaluate the generation performance during the training
    evaluation_freq = 300
    SYSTEM = ''
    evaluation_inputs = ['请你介绍一下你自己', '你是谁', '你是我的小助手吗']

    #######################################################################
    #                      PART 2  Model & Tokenizer                      #
    #######################################################################
    tokenizer = dict(
        type=AutoTokenizer.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        padding_side='right')

    model = dict(
        type=SupervisedFinetune,
        use_varlen_attn=use_varlen_attn,
        llm=dict(
            type=AutoModelForCausalLM.from_pretrained,
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            quantization_config=dict(
                type=BitsAndBytesConfig,
                load_in_4bit=True,
                load_in_8bit=False,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type='nf4')),
        lora=dict(
            type=LoraConfig,
            r=64,
            lora_alpha=16,
            lora_dropout=0.1,
            bias='none',
            task_type='CAUSAL_LM'))

    #######################################################################
    #                      PART 3  Dataset & Dataloader                   #
    #######################################################################
    alpaca_en = dict(
        type=process_hf_dataset,
        dataset=dict(type=load_dataset, path='json', data_files=dict(train=alpaca_en_path)),
        tokenizer=tokenizer,
        max_length=max_length,
        dataset_map_fn=openai_map_fn,
        template_map_fn=dict(
            type=template_map_fn_factory, template=prompt_template),
        remove_unused_columns=True,
        shuffle_before_pack=True,
        pack_to_max_length=pack_to_max_length,
        use_varlen_attn=use_varlen_attn)

    sampler = SequenceParallelSampler \
        if sequence_parallel_size > 1 else DefaultSampler
    train_dataloader = dict(
        batch_size=batch_size,
        num_workers=dataloader_num_workers,
        dataset=alpaca_en,
        sampler=dict(type=sampler, shuffle=True),
        collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))

    #######################################################################
    #                    PART 4  Scheduler & Optimizer                    #
    #######################################################################
    # optimizer
    optim_wrapper = dict(
        type=AmpOptimWrapper,
        optimizer=dict(
            type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
        clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
        accumulative_counts=accumulative_counts,
        loss_scale='dynamic',
        dtype='float16')

    # learning policy
    # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
    param_scheduler = [
        dict(
            type=LinearLR,
            start_factor=1e-5,
            by_epoch=True,
            begin=0,
            end=warmup_ratio * max_epochs,
            convert_to_iter_based=True),
        dict(
            type=CosineAnnealingLR,
            eta_min=0.0,
            by_epoch=True,
            begin=warmup_ratio * max_epochs,
            end=max_epochs,
            convert_to_iter_based=True)
    ]

    # train, val, test setting
    train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

    #######################################################################
    #                           PART 5  Runtime                           #
    #######################################################################
    # Log the dialogue periodically during the training process, optional
    custom_hooks = [
        dict(type=DatasetInfoHook, tokenizer=tokenizer),
        dict(
            type=EvaluateChatHook,
            tokenizer=tokenizer,
            every_n_iters=evaluation_freq,
            evaluation_inputs=evaluation_inputs,
            system=SYSTEM,
            prompt_template=prompt_template)
    ]

    if use_varlen_attn:
        custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]

    # configure default hooks
    default_hooks = dict(
        # record the time of every iteration.
        timer=dict(type=IterTimerHook),
        # print log every 10 iterations.
        logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
        # enable the parameter scheduler.
        param_scheduler=dict(type=ParamSchedulerHook),
        # save checkpoint per `save_steps`.
        checkpoint=dict(
            type=CheckpointHook,
            by_epoch=False,
            interval=save_steps,
            max_keep_ckpts=save_total_limit),
        # set sampler seed in distributed environment.
        sampler_seed=dict(type=DistSamplerSeedHook),
    )

    # configure environment
    env_cfg = dict(
        # whether to enable cudnn benchmark
        cudnn_benchmark=False,
        # set multi process parameters
        mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
        # set distributed parameters
        dist_cfg=dict(backend='nccl'),
    )

    # set visualizer
    visualizer = None

    # set log level
    log_level = 'INFO'

    # load from which checkpoint
    load_from = None

    # whether to resume training from the loaded checkpoint
    resume = False

    # Defaults to use random seed and disable `deterministic`
    randomness = dict(seed=None, deterministic=False)

    # set log processor
    log_processor = dict(by_epoch=False)
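The param_scheduler section of the config combines a linear warmup over the first warmup_ratio of training with cosine annealing down to eta_min for the rest. The sketch below gives a closed-form reading of that schedule using the values from the config (lr = 2e-4, warmup_ratio = 0.03, start_factor = 1e-5, eta_min = 0.0). The function name lr_at and the progress argument are hypothetical, and mmengine updates the rate step by step, so its exact values at the boundaries may differ slightly from this approximation.

```python
import math


def lr_at(progress: float,
          base_lr: float = 2e-4,
          warmup_ratio: float = 0.03,
          start_factor: float = 1e-5,
          eta_min: float = 0.0) -> float:
    """Approximate learning rate at a fraction `progress` (0..1) of training."""
    if not 0.0 <= progress <= 1.0:
        raise ValueError("progress must be within [0, 1]")
    if progress < warmup_ratio:
        # Linear warmup: the scaling factor rises from start_factor to 1.
        factor = start_factor + (1.0 - start_factor) * progress / warmup_ratio
        return base_lr * factor
    # Cosine annealing from base_lr down to eta_min over the remaining steps.
    t = (progress - warmup_ratio) / (1.0 - warmup_ratio)
    return eta_min + (base_lr - eta_min) * 0.5 * (1.0 + math.cos(math.pi * t))


# Gradient accumulation: samples contributing to each optimizer step, per device.
effective_batch_size = 1 * 16  # batch_size * accumulative_counts

schedule_preview = [lr_at(p) for p in (0.0, 0.03, 0.5, 1.0)]
```

The preview shows the rate starting near zero, peaking at 2e-4 once the warmup ends, and decaying to zero by the end of the second epoch.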

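Main changes

    from xtuner.dataset.map_fns import openai_map_fn, template_map_fn_factory

Import openai_map_fn, since the dataset is in OpenAI format.

alpaca_en_path: path to the dataset (/root/ft/data/personal_assistant.json)

dataset_map_fn: the dataset map function, set to openai_map_fn

pretrained_model_name_or_path: path to the model (/root/ft/model/internlm2-chat-1_8b)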

Model training

Standard training

    xtuner train /root/ft/config/internlm2_1_8b_qlora_alpaca_e3_copy.py --work-dir /root/ft/train

Accelerated training with DeepSpeed

    # Use DeepSpeed to speed up training
    xtuner train /root/ft/config/internlm2_1_8b_qlora_alpaca_e3_copy.py --work-dir /root/ft/train_deepspeed --deepspeed deepspeed_zero2

Model conversion, merging, testing, and deployment

Model conversion

    # Create a folder to hold the converted Hugging Face format model
    mkdir -p /root/ft/huggingface
    # Convert the model
    # xtuner convert pth_to_hf ${config file path} ${weight file path} ${output path for the converted model}
    xtuner convert pth_to_hf /root/ft/config/internlm2_1_8b_qlora_alpaca_e3_copy.py /root/ft/train_deepspeed/iter_768.pth /root/ft/huggingface
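The conversion command above hardcodes iter_768.pth, but the final iteration number depends on the dataset size and training settings. A small sketch for locating the newest checkpoint in a work directory follows; latest_checkpoint is a hypothetical helper, and it assumes checkpoints are named iter_<number>.pth as in the command above.

```python
import re
from pathlib import Path
from typing import Optional

# Assumed checkpoint naming, taken from the iter_768.pth example above.
CHECKPOINT_PATTERN = re.compile(r"^iter_(\d+)\.pth$")


def latest_checkpoint(work_dir: str) -> Optional[Path]:
    """Return the iter_<N>.pth entry with the largest N, or None if there is none."""
    best_path = None
    best_iter = -1
    for path in Path(work_dir).glob("iter_*.pth"):
        match = CHECKPOINT_PATTERN.match(path.name)
        if match and int(match.group(1)) > best_iter:
            best_iter = int(match.group(1))
            best_path = path
    return best_path
```

For example, latest_checkpoint('/root/ft/train_deepspeed') would return the path to pass as the weight file argument of xtuner convert pth_to_hf. The comparison is numeric, so iter_1000.pth correctly ranks above iter_768.pth.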

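Arguments for xtuner convert:

Argument                Description
--fp32                  Convert in fp32 precision; defaults to fp16 if omitted
--max-shard-size {GB}   Maximum size of each weight file (default: 2GB)

If needed, these can be appended to the conversion command above. The model in this test is small and overfitting had already been confirmed, so they were not used here. With them added, the command would look like this:

    xtuner convert pth_to_hf /root/ft/train/internlm2_1_8b_qlora_alpaca_e3_copy.py /root/ft/train/iter_768.pth /root/ft/huggingface --fp32 --max-shard-size 2GB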

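Model merging

    # Create a folder named final_model to hold the merged model files
    mkdir -p /root/ft/final_model
    # Work around a thread-conflict bug
    export MKL_SERVICE_FORCE_INTEL=1
    # Merge the model
    # xtuner convert merge ${NAME_OR_PATH_TO_LLM} ${NAME_OR_PATH_TO_ADAPTER} ${SAVE_PATH}
    # The base model path is the symlink inside /root/ft/model, the same path used in the config
    xtuner convert merge /root/ft/model/internlm2-chat-1_8b /root/ft/huggingface /root/ft/final_model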

Chat test

    # Chat with the model
    xtuner chat /root/ft/final_model --prompt-template internlm2_chat

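Arguments for xtuner chat:

Argument              Description
--system              Sets the SYSTEM text, used to insert system-level information into the conversation
--system-template     Sets the SYSTEM template, used to customize how the system information is formatted
--bits                Bit width used when running the LLM, which determines numerical precision
--bot-name            Name of the bot, used to identify it in the conversation
--with-plugins        List of plugins to use at run time, for extended functionality
--no-streamer         Disables streaming output, for cases where the whole response should be produced at once
--lagent              Enables Lagent
--command-stop-word   Stop word for commands; command parsing stops when it is encountered
--answer-stop-word    Stop word for answers; generation stops when it is encountered
--offload-folder      Folder holding model weights, used when loading or offloading weights
--max-new-tokens      Maximum number of tokens to generate, which bounds the output length
--temperature         Sampling temperature; higher values give more varied text, lower values more deterministic text
--top-k               Number of highest-probability tokens kept by top-k filtering; affects the diversity of the output
--top-p               Cumulative probability threshold; only the smallest set of tokens whose cumulative probability reaches top-p is kept
--seed                Random seed, for reproducible generation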
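To make the effect of --temperature, --top-k and --top-p concrete, here is a minimal pure-Python sketch of how these three settings reshape a next-token distribution before sampling. It illustrates the general technique only and is not XTuner's or Transformers' actual implementation; the function names and the toy logits are made up for the example.

```python
import math
import random


def filter_distribution(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Return {token_index: probability} after temperature, top-k and top-p."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # Temperature: dividing logits by a small value sharpens the distribution.
    scaled = [value / temperature for value in logits]
    peak = max(scaled)
    weights = [math.exp(value - peak) for value in scaled]
    ranked = sorted(range(len(weights)), key=lambda i: weights[i], reverse=True)
    # Top-k: keep only the k most likely tokens (0 disables the filter).
    if top_k > 0:
        ranked = ranked[:top_k]
    total = sum(weights[i] for i in ranked)
    probs = {i: weights[i] / total for i in ranked}
    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}


def sample_token(logits, seed=None, **settings):
    """Draw one token index from the filtered distribution."""
    distribution = filter_distribution(logits, **settings)
    rng = random.Random(seed)
    indices = list(distribution)
    return rng.choices(indices, weights=[distribution[i] for i in indices])[0]


toy_logits = [2.0, 1.0, 0.0]
only_top_two = filter_distribution(toy_logits, top_k=2)
nucleus = filter_distribution(toy_logits, top_p=0.5)
sharp = filter_distribution(toy_logits, temperature=0.1)
```

With these toy logits, top_k=2 drops the least likely token, top_p=0.5 keeps only the most likely one (its probability alone is about 0.67), and temperature=0.1 concentrates almost all probability on the top token. Passing the same seed to sample_token reproduces the same draw, which is the role of --seed.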

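Besides these, there is one more very important argument: --adapter. It lets you test a converted adapter before it is merged into the base model. Chatting with a model loaded this way behaves almost the same as chatting with the merged model, so you can try the adapters produced from different checkpoints, pick the best one, and merge only that one.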
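As a sketch of how --adapter can be used to compare several adapters, the hypothetical helper below only assembles the xtuner chat command line and leaves running it to the reader. The base model path is the one used in the config file and the adapter path is the conversion output folder from above. Only the --adapter and --prompt-template flags already mentioned in these notes are used.

```python
import shlex


def build_adapter_chat_command(base_model: str,
                               adapter: str,
                               prompt_template: str = "internlm2_chat") -> list:
    """Assemble an `xtuner chat` invocation that loads an unmerged adapter."""
    return [
        "xtuner", "chat", base_model,
        "--adapter", adapter,
        "--prompt-template", prompt_template,
    ]


command = build_adapter_chat_command(
    base_model="/root/ft/model/internlm2-chat-1_8b",
    adapter="/root/ft/huggingface",
)
command_line = shlex.join(command)
# To start the chat from Python instead of a shell:
#     import subprocess
#     subprocess.run(command, check=True)
```

Printing command_line gives a string that can be pasted into the terminal. Repeating this with the adapter folders converted from different checkpoints is one way to carry out the comparison described above.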

Web demo deployment

Install dependencies

    pip install streamlit==1.24.0

Download the project code

    # Create a folder to hold the InternLM code
    mkdir -p /root/ft/web_demo && cd /root/ft/web_demo
    # Clone the InternLM source
    git clone https://github.com/InternLM/InternLM.git
    # Enter the repository
    cd /root/ft/web_demo/InternLM

Replace the contents of /root/ft/web_demo/InternLM/chat/web_demo.py with the following code:

    """This script refers to the dialogue example of streamlit, the interactive
    generation code of chatglm2 and transformers.

    We mainly modified part of the code logic to adapt to the
    generation of our model.
    Please refer to these links below for more information:
        1. streamlit chat example:
            https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
        2. chatglm2:
            https://github.com/THUDM/ChatGLM2-6B
        3. transformers:
            https://github.com/huggingface/transformers
    Please run with the command `streamlit run path/to/web_demo.py
        --server.address=0.0.0.0 --server.port 7860`.
    Using `python path/to/web_demo.py` may cause unknown problems.
    """
    # isort: skip_file
    import copy
    import warnings
    from dataclasses import asdict, dataclass
    from typing import Callable, List, Optional

    import streamlit as st
    import torch
    from torch import nn
    from transformers.generation.utils import (LogitsProcessorList,
                                               StoppingCriteriaList)
    from transformers.utils import logging

    from transformers import AutoTokenizer, AutoModelForCausalLM  # isort: skip

    logger = logging.get_logger(__name__)


    @dataclass
    class GenerationConfig:
        # this config is used for chat to provide more diversity
        max_length: int = 32768
        top_p: float = 0.8
        temperature: float = 0.8
        do_sample: bool = True
        repetition_penalty: float = 1.005


    @torch.inference_mode()
    def generate_interactive(
        model,
        tokenizer,
        prompt,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
                                                    List[int]]] = None,
        additional_eos_token_id: Optional[int] = None,
        **kwargs,
    ):
        inputs = tokenizer([prompt], padding=True, return_tensors='pt')
        input_length = len(inputs['input_ids'][0])
        for k, v in inputs.items():
            inputs[k] = v.cuda()
        input_ids = inputs['input_ids']
        _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
        if generation_config is None:
            generation_config = model.generation_config
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)
        bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
            generation_config.bos_token_id,
            generation_config.eos_token_id,
        )
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        if additional_eos_token_id is not None:
            eos_token_id.append(additional_eos_token_id)
        has_default_max_length = kwargs.get(
            'max_length') is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using 'max_length''s default ({repr(generation_config.max_length)}) "
                'to control the generation length. '
                'This behaviour is deprecated and will be removed from the '
                'config in v5 of Transformers -- we'
                ' recommend using `max_new_tokens` to control the maximum '
                'length of the generation.',
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = generation_config.max_new_tokens + \
                input_ids_seq_length
            if not has_default_max_length:
                # logger.warning takes no warning category, so none is passed
                logger.warning(
                    f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
                    f"and 'max_length'(={generation_config.max_length}) seem to "
                    "have been set. 'max_new_tokens' will take precedence. "
                    'Please refer to the documentation for more information. '
                    '(https://huggingface.co/docs/transformers/main/'
                    'en/main_classes/text_generation)')

        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = 'input_ids'
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, "
                f"but 'max_length' is set to {generation_config.max_length}. "
                'This can lead to unexpected behavior. You should consider'
                " increasing 'max_new_tokens'.")

        # 2. Set generation parameters if not already defined
        logits_processor = logits_processor if logits_processor is not None \
            else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None \
            else StoppingCriteriaList()

        logits_processor = model._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )

        stopping_criteria = model._get_stopping_criteria(
            generation_config=generation_config,
            stopping_criteria=stopping_criteria)
        logits_warper = model._get_logits_warper(generation_config)

        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        scores = None
        while True:
            model_inputs = model.prepare_inputs_for_generation(
                input_ids, **model_kwargs)
            # forward pass to get next token
            outputs = model(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            if generation_config.do_sample:
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(probs, dim=-1)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = model._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=False)
            unfinished_sequences = unfinished_sequences.mul(
                (min(next_tokens != i for i in eos_token_id)).long())

            output_token_ids = input_ids[0].cpu().tolist()
            output_token_ids = output_token_ids[input_length:]
            for each_eos_token_id in eos_token_id:
                if output_token_ids[-1] == each_eos_token_id:
                    output_token_ids = output_token_ids[:-1]
            response = tokenizer.decode(output_token_ids)

            yield response
            # stop when each sentence is finished
            # or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(
                    input_ids, scores):
                break


    def on_btn_click():
        del st.session_state.messages


    @st.cache_resource
    def load_model():
        model = (AutoModelForCausalLM.from_pretrained('/root/ft/final_model',
                                                      trust_remote_code=True).to(
                                                          torch.bfloat16).cuda())
        tokenizer = AutoTokenizer.from_pretrained('/root/ft/final_model',
                                                  trust_remote_code=True)
        return model, tokenizer


    def prepare_generation_config():
        with st.sidebar:
            max_length = st.slider('Max Length',
                                   min_value=8,
                                   max_value=32768,
                                   value=2048)
            top_p = st.slider('Top P', 0.0, 1.0, 0.75, step=0.01)
            temperature = st.slider('Temperature', 0.0, 1.0, 0.1, step=0.01)
            st.button('Clear Chat History', on_click=on_btn_click)

        generation_config = GenerationConfig(max_length=max_length,
                                             top_p=top_p,
                                             temperature=temperature)

        return generation_config


    user_prompt = '<|im_start|>user\n{user}<|im_end|>\n'
    robot_prompt = '<|im_start|>assistant\n{robot}<|im_end|>\n'
    cur_query_prompt = ('<|im_start|>user\n{user}<|im_end|>\n'
                        '<|im_start|>assistant\n')


    def combine_history(prompt):
        messages = st.session_state.messages
        meta_instruction = ('')
        total_prompt = f"<s><|im_start|>system\n{meta_instruction}<|im_end|>\n"
        for message in messages:
            cur_content = message['content']
            if message['role'] == 'user':
                cur_prompt = user_prompt.format(user=cur_content)
            elif message['role'] == 'robot':
                cur_prompt = robot_prompt.format(robot=cur_content)
            else:
                raise RuntimeError
            total_prompt += cur_prompt
        total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
        return total_prompt


    def main():
        # torch.cuda.empty_cache()
        print('load model begin.')
        model, tokenizer = load_model()
        print('load model end.')

        st.title('InternLM2-Chat-1.8B')

        generation_config = prepare_generation_config()

        # Initialize chat history
        if 'messages' not in st.session_state:
            st.session_state.messages = []

        # Display chat messages from history on app rerun
        for message in st.session_state.messages:
            with st.chat_message(message['role'], avatar=message.get('avatar')):
                st.markdown(message['content'])

        # Accept user input
        if prompt := st.chat_input('What is up?'):
            # Display user message in chat message container
            with st.chat_message('user'):
                st.markdown(prompt)
            real_prompt = combine_history(prompt)
            # Add user message to chat history
            st.session_state.messages.append({
                'role': 'user',
                'content': prompt,
            })

            with st.chat_message('robot'):
                message_placeholder = st.empty()
                for cur_response in generate_interactive(
                        model=model,
                        tokenizer=tokenizer,
                        prompt=real_prompt,
                        additional_eos_token_id=92542,
                        **asdict(generation_config),
                ):
                    # Display robot response in chat message container
                    message_placeholder.markdown(cur_response + '▌')
                message_placeholder.markdown(cur_response)
            # Add robot response to chat history
            st.session_state.messages.append({
                'role': 'robot',
                'content': cur_response,  # pylint: disable=undefined-loop-variable
            })
            torch.cuda.empty_cache()


    if __name__ == '__main__':
        main()
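The combine_history function in the script is what turns the chat history into the single prompt string the model sees. To inspect that string without starting Streamlit, the same logic can be sketched as a pure function. Here build_prompt is a hypothetical stand-in that takes the history as an argument instead of reading st.session_state, and the sample messages are made up.

```python
USER_TURN = '<|im_start|>user\n{user}<|im_end|>\n'
ROBOT_TURN = '<|im_start|>assistant\n{robot}<|im_end|>\n'


def build_prompt(history, query, meta_instruction=''):
    """Concatenate system text, past turns and the new query into one prompt."""
    prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
    for message in history:
        if message['role'] == 'user':
            prompt += USER_TURN.format(user=message['content'])
        elif message['role'] == 'robot':
            prompt += ROBOT_TURN.format(robot=message['content'])
        else:
            raise ValueError(f"unknown role: {message['role']!r}")
    # The prompt ends with an open assistant turn for the model to complete.
    return prompt + USER_TURN.format(user=query) + '<|im_start|>assistant\n'


first_turn = build_prompt([], '你是谁')
second_turn = build_prompt(
    [{'role': 'user', 'content': '你是谁'},
     {'role': 'robot', 'content': '我是小助手'}],
    '请做一下自我介绍',
)
```

Each earlier turn is wrapped in <|im_start|> and <|im_end|> markers, and the prompt always ends with an unclosed assistant turn. The script passes additional_eos_token_id=92542 so that generation also stops at an extra end-of-turn token.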

Run

    streamlit run /root/ft/web_demo/InternLM/chat/web_demo.py --server.address 127.0.0.1 --server.port 6006

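Deploying to OpenXLab

Model

Install dependencies

    apt install git-lfs
    git lfs install

Clone the model repository (the copy step below assumes the clone lives under /root/openxlab)

    git clone https://code.openxlab.org.cn/JeffDing/xtuner_demo_1_8b.git

Copy the model into the repository

    cp -r /root/ft/final_model/* /root/openxlab/xtuner_demo_1_8b/

Upload the model

    cd xtuner_demo_1_8b
    git add .
    git commit -m "init"
    git push

For the detailed steps, see the OpenXLab 浦源 Documentation Center page "上传模型文件" (Uploading Model Files).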

应用部分

需要创建一个GIT代码仓,然后将前面web_demo.py的代码修改下上传

整体代码如下:

  1. # isort: skip_file
  2. import copy
  3. import warnings
  4. import os
  5. from dataclasses import asdict, dataclass
  6. from typing import Callable, List, Optional
  7. import streamlit as st
  8. import torch
  9. from torch import nn
  10. from transformers.generation.utils import (LogitsProcessorList,
  11. StoppingCriteriaList)
  12. from transformers.utils import logging
  13. from transformers import AutoTokenizer, AutoModelForCausalLM # isort: skip
  14. logger = logging.get_logger(__name__)
  15. @dataclass
  16. class GenerationConfig:
  17. # this config is used for chat to provide more diversity
  18. max_length: int = 32768
  19. top_p: float = 0.8
  20. temperature: float = 0.8
  21. do_sample: bool = True
  22. repetition_penalty: float = 1.005
  23. @torch.inference_mode()
  24. def generate_interactive(
  25. model,
  26. tokenizer,
  27. prompt,
  28. generation_config: Optional[GenerationConfig] = None,
  29. logits_processor: Optional[LogitsProcessorList] = None,
  30. stopping_criteria: Optional[StoppingCriteriaList] = None,
  31. prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
  32. List[int]]] = None,
  33. additional_eos_token_id: Optional[int] = None,
  34. **kwargs,
  35. ):
  36. inputs = tokenizer([prompt], padding=True, return_tensors='pt')
  37. input_length = len(inputs['input_ids'][0])
  38. for k, v in inputs.items():
  39. inputs[k] = v.cuda()
  40. input_ids = inputs['input_ids']
  41. _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
  42. if generation_config is None:
  43. generation_config = model.generation_config
  44. generation_config = copy.deepcopy(generation_config)
  45. model_kwargs = generation_config.update(**kwargs)
  46. bos_token_id, eos_token_id = ( # noqa: F841 # pylint: disable=W0612
  47. generation_config.bos_token_id,
  48. generation_config.eos_token_id,
  49. )
  50. if isinstance(eos_token_id, int):
  51. eos_token_id = [eos_token_id]
  52. if additional_eos_token_id is not None:
  53. eos_token_id.append(additional_eos_token_id)
  54. has_default_max_length = kwargs.get(
  55. 'max_length') is None and generation_config.max_length is not None
  56. if has_default_max_length and generation_config.max_new_tokens is None:
  57. warnings.warn(
  58. f"Using 'max_length''s default ({repr(generation_config.max_length)}) \
  59. to control the generation length. "
  60. 'This behaviour is deprecated and will be removed from the \
  61. config in v5 of Transformers -- we'
  62. ' recommend using `max_new_tokens` to control the maximum \
  63. length of the generation.',
  64. UserWarning,
  65. )
  66. elif generation_config.max_new_tokens is not None:
  67. generation_config.max_length = generation_config.max_new_tokens + \
  68. input_ids_seq_length
  69. if not has_default_max_length:
  70. logger.warn( # pylint: disable=W4902
  71. f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
  72. f"and 'max_length'(={generation_config.max_length}) seem to "
  73. "have been set. 'max_new_tokens' will take precedence. "
  74. 'Please refer to the documentation for more information. '
  75. '(https://huggingface.co/docs/transformers/main/'
  76. 'en/main_classes/text_generation)',
  77. UserWarning,
  78. )
  79. if input_ids_seq_length >= generation_config.max_length:
  80. input_ids_string = 'input_ids'
  81. logger.warning(
  82. f"Input length of {input_ids_string} is {input_ids_seq_length}, "
  83. f"but 'max_length' is set to {generation_config.max_length}. "
  84. 'This can lead to unexpected behavior. You should consider'
  85. " increasing 'max_new_tokens'.")
  86. # 2. Set generation parameters if not already defined
  87. logits_processor = logits_processor if logits_processor is not None \
  88. else LogitsProcessorList()
  89. stopping_criteria = stopping_criteria if stopping_criteria is not None \
  90. else StoppingCriteriaList()
  91. logits_processor = model._get_logits_processor(
  92. generation_config=generation_config,
  93. input_ids_seq_length=input_ids_seq_length,
  94. encoder_input_ids=input_ids,
  95. prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
  96. logits_processor=logits_processor,
  97. )
  98. stopping_criteria = model._get_stopping_criteria(
  99. generation_config=generation_config,
  100. stopping_criteria=stopping_criteria)
  101. logits_warper = model._get_logits_warper(generation_config)
  102. unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
  103. scores = None
  104. while True:
  105. model_inputs = model.prepare_inputs_for_generation(
  106. input_ids, **model_kwargs)
  107. # forward pass to get next token
  108. outputs = model(
  109. **model_inputs,
  110. return_dict=True,
  111. output_attentions=False,
  112. output_hidden_states=False,
  113. )
  114. next_token_logits = outputs.logits[:, -1, :]
  115. # pre-process distribution
  116. next_token_scores = logits_processor(input_ids, next_token_logits)
  117. next_token_scores = logits_warper(input_ids, next_token_scores)
  118. # sample
  119. probs = nn.functional.softmax(next_token_scores, dim=-1)
  120. if generation_config.do_sample:
  121. next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
  122. else:
  123. next_tokens = torch.argmax(probs, dim=-1)
  124. # update generated ids, model inputs, and length for next step
  125. input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
  126. model_kwargs = model._update_model_kwargs_for_generation(
  127. outputs, model_kwargs, is_encoder_decoder=False)
  128. unfinished_sequences = unfinished_sequences.mul(
  129. (min(next_tokens != i for i in eos_token_id)).long())
  130. output_token_ids = input_ids[0].cpu().tolist()
  131. output_token_ids = output_token_ids[input_length:]
  132. for each_eos_token_id in eos_token_id:
  133. if output_token_ids[-1] == each_eos_token_id:
  134. output_token_ids = output_token_ids[:-1]
  135. response = tokenizer.decode(output_token_ids)
  136. yield response
  137. # stop when each sentence is finished
  138. # or if we exceed the maximum length
  139. if unfinished_sequences.max() == 0 or stopping_criteria(
  140. input_ids, scores):
  141. break
  142. def on_btn_click():
  143. del st.session_state.messages
@st.cache_resource
def load_model():
    base_path = './xtuner_demo_1_8b'
    # download repo to the base_path directory using git
    os.system('apt install git')
    os.system('apt install git-lfs')
    os.system(f'git clone https://code.openxlab.org.cn/JeffDing/xtuner_demo_1_8b.git {base_path}')
    os.system(f'cd {base_path} && git lfs pull')
    model = (AutoModelForCausalLM.from_pretrained(base_path,
                                                  trust_remote_code=True).to(
                                                      torch.bfloat16).cuda())
    tokenizer = AutoTokenizer.from_pretrained(base_path,
                                              trust_remote_code=True)
    return model, tokenizer


def prepare_generation_config():
    with st.sidebar:
        max_length = st.slider('Max Length',
                               min_value=8,
                               max_value=32768,
                               value=2048)
        top_p = st.slider('Top P', 0.0, 1.0, 0.75, step=0.01)
        temperature = st.slider('Temperature', 0.0, 1.0, 0.1, step=0.01)
        st.button('Clear Chat History', on_click=on_btn_click)

    generation_config = GenerationConfig(max_length=max_length,
                                         top_p=top_p,
                                         temperature=temperature)
    return generation_config


user_prompt = '<|im_start|>user\n{user}<|im_end|>\n'
robot_prompt = '<|im_start|>assistant\n{robot}<|im_end|>\n'
cur_query_prompt = '<|im_start|>user\n{user}<|im_end|>\n\
<|im_start|>assistant\n'


def combine_history(prompt):
    messages = st.session_state.messages
    meta_instruction = ('')
    total_prompt = f"<s><|im_start|>system\n{meta_instruction}<|im_end|>\n"
    for message in messages:
        cur_content = message['content']
        if message['role'] == 'user':
            cur_prompt = user_prompt.format(user=cur_content)
        elif message['role'] == 'robot':
            cur_prompt = robot_prompt.format(robot=cur_content)
        else:
            raise RuntimeError
        total_prompt += cur_prompt
    total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
    return total_prompt


def main():
    # torch.cuda.empty_cache()
    print('load model begin.')
    model, tokenizer = load_model()
    print('load model end.')

    st.title('InternLM2-Chat-1.8B')

    generation_config = prepare_generation_config()

    # Initialize chat history
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message['role'], avatar=message.get('avatar')):
            st.markdown(message['content'])

    # Accept user input
    if prompt := st.chat_input('What is up?'):
        # Display user message in chat message container
        with st.chat_message('user'):
            st.markdown(prompt)
        real_prompt = combine_history(prompt)
        # Add user message to chat history
        st.session_state.messages.append({
            'role': 'user',
            'content': prompt,
        })

        with st.chat_message('robot'):
            message_placeholder = st.empty()
            for cur_response in generate_interactive(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=real_prompt,
                    additional_eos_token_id=92542,
                    **asdict(generation_config),
            ):
                # Display robot response in chat message container
                message_placeholder.markdown(cur_response + '▌')
            message_placeholder.markdown(cur_response)
        # Add robot response to chat history
        st.session_state.messages.append({
            'role': 'robot',
            'content': cur_response,  # pylint: disable=undefined-loop-variable
        })
        torch.cuda.empty_cache()


if __name__ == '__main__':
    main()
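
To see what the demo actually sends to the model, the prompt assembly can be tried on its own, without Streamlit or a GPU. The following is a minimal sketch that reuses the three template strings from the demo; the function name build_prompt and the passing of history as an argument (instead of reading st.session_state) are assumptions made for illustration.

```python
# Templates copied from the demo above (InternLM2 chat format).
USER_PROMPT = '<|im_start|>user\n{user}<|im_end|>\n'
ROBOT_PROMPT = '<|im_start|>assistant\n{robot}<|im_end|>\n'
CUR_QUERY_PROMPT = ('<|im_start|>user\n{user}<|im_end|>\n'
                    '<|im_start|>assistant\n')


def build_prompt(history, query, meta_instruction=''):
    """Assemble the full prompt from past turns plus the new user query.

    `history` is a list of {'role': 'user' | 'robot', 'content': str} dicts,
    mirroring st.session_state.messages in the demo (hypothetical helper).
    """
    total = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
    for message in history:
        if message['role'] == 'user':
            total += USER_PROMPT.format(user=message['content'])
        elif message['role'] == 'robot':
            total += ROBOT_PROMPT.format(robot=message['content'])
        else:
            raise ValueError(f"unknown role: {message['role']!r}")
    # The prompt ends with an open assistant turn for the model to complete.
    return total + CUR_QUERY_PROMPT.format(user=query)


if __name__ == '__main__':
    past_turns = [
        {'role': 'user', 'content': 'Hi'},
        {'role': 'robot', 'content': 'Hello!'},
    ]
    print(build_prompt(past_turns, 'Who are you?'))
```

Printing the result shows one system block, every earlier turn in order, and finally the new question followed by an unfinished assistant block.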

Then create an application on OpenXLab to deploy the demo.

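XTuner Multimodal Part

Overview of the LLaVA Approach

Haotian Liu et al. used GPT-4V to generate descriptions for image data, building a large number of <question text><image> -- <answer text> data pairs. With these pairs and a text-only LLM, they trained an Image Projector.

The text-only LLM together with the trained Image Projector is collectively called the LLaVA model.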
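
The role of the Image Projector can be illustrated with a toy example: it maps each visual feature vector into the embedding space of the LLM, so that image tokens can be placed in the same sequence as text tokens. The sketch below is a deliberately tiny, pure-Python illustration with a single linear layer and made-up dimensions; it is not the XTuner implementation, and all names (make_projector, project, build_llm_inputs) are hypothetical.

```python
import random

VISUAL_DIM = 4  # toy size; real visual encoder features are much wider
LLM_DIM = 6     # toy size; stands in for the LLM embedding width


def make_projector(in_dim, out_dim, seed=0):
    """Create random weights for a single linear layer (illustration only)."""
    rng = random.Random(seed)
    weight = [[rng.uniform(-0.1, 0.1) for _ in range(in_dim)]
              for _ in range(out_dim)]
    bias = [0.0] * out_dim
    return weight, bias


def project(patch_features, weight, bias):
    """Map each visual patch vector into the LLM embedding space."""
    return [
        [sum(w * v for w, v in zip(row, patch)) + b
         for row, b in zip(weight, bias)]
        for patch in patch_features
    ]


def build_llm_inputs(text_embeddings, image_embeddings, image_position):
    """Splice the projected image tokens into the text embedding sequence."""
    return (text_embeddings[:image_position]
            + image_embeddings
            + text_embeddings[image_position:])


if __name__ == '__main__':
    weight, bias = make_projector(VISUAL_DIM, LLM_DIM)
    patches = [[1.0] * VISUAL_DIM for _ in range(3)]   # 3 fake image patches
    text = [[0.0] * LLM_DIM for _ in range(5)]         # 5 fake text tokens
    sequence = build_llm_inputs(text, project(patches, weight, bias), 1)
    print(len(sequence), 'tokens of width', len(sequence[0]))
```

Training the projector amounts to adjusting weight and bias so that the LLM, reading the spliced sequence, produces the expected answer text.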

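Hands-on Practice

Pretrain Stage

In the Pretrain stage, we use a large number of image + simple text pairs (captions, i.e. image titles) so that the LLM learns the general features of images. In other words, the model skims over a large number of images.

After the Pretrain stage, the model already has visual capability. However, because the training data consisted only of images paired with captions, the model answers every question with a caption of the input image, no matter what the user asks. In other words, at this point the model can only "write a caption" for the input image.

Finetune Stage

In the Finetune stage, we use image + complex text pairs to further train the Image Projector obtained from the Pretrain stage, i.e. iter_2181.pth.

Training data format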

[
    {
        "id": "any string",
        "image": "Path of the image file, relative to the image_folder parameter set in the config file used later.",
        "conversation": [
            {
                "from": "human",
                "value": "<image>\nQuestion 1."
            },
            {
                "from": "gpt",
                "value": "Answer 1"
            },
            {
                "from": "human",
                "value": "Question 2."
            },
            {
                "from": "gpt",
                "value": "Answer 2"
            },
            # ......
            {
                "from": "human",
                "value": "Question n."
            },
            {
                "from": "gpt",
                "value": "Answer n"
            }
        ]
    },
    # The second training sample starts here.
    {
        "id": "any string",
        "image": "Path of the image file, relative to the image_folder parameter set in the config file used later.",
        "conversation": [
            {
                "from": "human",
                "value": "<image>\nQuestion 1."
            },
            # ......
            {
                "from": "gpt",
                "value": "Answer n"
            }
        ]
    }
]
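
A record in this format can also be generated programmatically instead of being written by hand. The following is a minimal sketch, assuming the schema shown above; make_record and check_record are hypothetical helper names, and the name of the conversation key is kept as a parameter because it must match whatever the dataset loader in your XTuner version expects.

```python
import json


def make_record(sample_id, image_path, qa_pairs, key='conversation'):
    """Build one training record; only the first question carries <image>."""
    turns = []
    for index, (question, answer) in enumerate(qa_pairs):
        prefix = '<image>\n' if index == 0 else ''
        turns.append({'from': 'human', 'value': prefix + question})
        turns.append({'from': 'gpt', 'value': answer})
    return {'id': sample_id, 'image': image_path, key: turns}


def check_record(record, key='conversation'):
    """Return True if turns alternate human/gpt and start with <image>."""
    turns = record[key]
    if not turns or len(turns) % 2 != 0:
        return False
    expected = ('human', 'gpt')
    for index, turn in enumerate(turns):
        if turn['from'] != expected[index % 2]:
            return False
    return turns[0]['value'].startswith('<image>\n')


if __name__ == '__main__':
    record = make_record(
        'demo-0',
        'test_img/oph.jpg',  # relative to image_folder in the config
        [('Describe this image.', 'An example answer.'),
         ('A follow-up question.', 'Another example answer.')],
    )
    assert check_record(record)
    print(json.dumps([record], ensure_ascii=False, indent=4))
```

Running a check like this before training catches the most common mistakes: a missing <image> placeholder in the first question, or questions and answers that do not alternate.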

The question-answer data for this example image (repeated_data.json) is generated with the following commands, which repeat the data 200 times:

cd ~ && git clone https://github.com/InternLM/tutorial -b camp2 && conda activate xtuner0.1.17 && cd tutorial
python /root/tutorial/xtuner/llava/llava_data/repeat.py \
    -i /root/tutorial/xtuner/llava/llava_data/unique_data.json \
    -o /root/tutorial/xtuner/llava/llava_data/repeated_data.json \
    -n 200
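
The contents of repeat.py are not shown in these notes. Judging only from its command-line arguments (-i input, -o output, -n count), its effect is presumably similar to the sketch below; this is an assumption about the behavior, not the actual script, and the function names are hypothetical.

```python
import json


def repeat_dataset(records, n):
    """Return a new list with the input records repeated n times."""
    return list(records) * n


def repeat_file(input_path, output_path, n):
    """Read a JSON list, repeat it n times, and write the result."""
    with open(input_path, 'r', encoding='utf-8') as f:
        records = json.load(f)
    repeated = repeat_dataset(records, n)
    with open(output_path, 'w', encoding='utf-8') as f:
        # ensure_ascii=False keeps non-ASCII text readable in the output
        json.dump(repeated, f, ensure_ascii=False, indent=4)
    return len(repeated)


if __name__ == '__main__':
    sample = [{'id': 'a'}, {'id': 'b'}, {'id': 'c'}]
    print(len(repeat_dataset(sample, 200)))
```

Repeating a tiny dataset this way makes the model overfit to it on purpose, which is what this demonstration relies on to make the effect of finetuning visible.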

Prepare the config file

cp /root/tutorial/xtuner/llava/llava_data/internlm2_chat_1_8b_llava_tutorial_fool_config.py /root/tutorial/xtuner/llava/llava_internlm2_chat_1_8b_qlora_clip_vit_large_p14_336_lora_e1_gpu8_finetune_copy.py

Create the config file

# List the built-in XTuner config files
xtuner list-cfg -p llava_internlm2_chat_1_8b
# Copy the config file to the working directory
xtuner copy-cfg \
    llava_internlm2_chat_1_8b_qlora_clip_vit_large_p14_336_lora_e1_gpu8_finetune \
    /root/tutorial/xtuner/llava

Modify the config file

In llava_internlm2_chat_1_8b_qlora_clip_vit_large_p14_336_lora_e1_gpu8_finetune_copy.py, modify the following fields:

  • pretrained_pth
  • llm_name_or_path
  • visual_encoder_name_or_path
  • data_root
  • data_path
  • image_folder
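
As a rough guide, the modified fields might look like the fragment below. The model, checkpoint, and dataset paths are the ones used by the commands in these notes; setting image_folder to data_root is an assumption based on the test image living under llava_data, so adjust it if your images are stored elsewhere.

```python
# Fields to edit in the copied config (a sketch, not the full config file).

# Base LLM and visual encoder, as used in the `xtuner chat` commands
llm_name_or_path = '/root/share/new_models/Shanghai_AI_Laboratory/internlm2-chat-1_8b'
visual_encoder_name_or_path = '/root/share/new_models/openai/clip-vit-large-patch14-336'

# Image Projector weights produced by the Pretrain stage
pretrained_pth = '/root/share/new_models/xtuner/iter_2181.pth'

# Training data generated by repeat.py
data_root = '/root/tutorial/xtuner/llava/llava_data/'
data_path = data_root + 'repeated_data.json'

# Assumption: image paths in the JSON are relative to the data root
image_folder = data_root
```

Each "image" entry in the training JSON is resolved relative to image_folder, so the two must agree.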

Start finetuning

cd /root/tutorial/xtuner/llava/
xtuner train /root/tutorial/xtuner/llava/llava_internlm2_chat_1_8b_qlora_clip_vit_large_p14_336_lora_e1_gpu8_finetune_copy.py --deepspeed deepspeed_zero2

Compare performance before and after finetuning

Test image:

Before finetuning

# Work around a small bug
export MKL_SERVICE_FORCE_INTEL=1
export MKL_THREADING_LAYER=GNU
# Convert the .pth checkpoint to Hugging Face format
xtuner convert pth_to_hf \
    llava_internlm2_chat_1_8b_clip_vit_large_p14_336_e1_gpu8_pretrain \
    /root/share/new_models/xtuner/iter_2181.pth \
    /root/tutorial/xtuner/llava/llava_data/iter_2181_hf
# Launch the chat
xtuner chat /root/share/new_models/Shanghai_AI_Laboratory/internlm2-chat-1_8b \
    --visual-encoder /root/share/new_models/openai/clip-vit-large-patch14-336 \
    --llava /root/tutorial/xtuner/llava/llava_data/iter_2181_hf \
    --prompt-template internlm2_chat \
    --image /root/tutorial/xtuner/llava/llava_data/test_img/oph.jpg

Result:

After finetuning

# Work around a small bug
export MKL_SERVICE_FORCE_INTEL=1
export MKL_THREADING_LAYER=GNU
# Convert the .pth checkpoint to Hugging Face format
xtuner convert pth_to_hf \
    /root/tutorial/xtuner/llava/llava_internlm2_chat_1_8b_qlora_clip_vit_large_p14_336_lora_e1_gpu8_finetune_copy.py \
    /root/tutorial/xtuner/llava/work_dirs/llava_internlm2_chat_1_8b_qlora_clip_vit_large_p14_336_lora_e1_gpu8_finetune_copy/iter_1200.pth \
    /root/tutorial/xtuner/llava/llava_data/iter_1200_hf
# Launch the chat
xtuner chat /root/share/new_models/Shanghai_AI_Laboratory/internlm2-chat-1_8b \
    --visual-encoder /root/share/new_models/openai/clip-vit-large-patch14-336 \
    --llava /root/tutorial/xtuner/llava/llava_data/iter_1200_hf \
    --prompt-template internlm2_chat \
    --image /root/tutorial/xtuner/llava/llava_data/test_img/oph.jpg

Result:
