当前位置:   article > 正文

使用 torchtune 微调 Llama3_torchtune grammar

torchtune grammar

Llama3 + torchtune


本文翻译改编自:https://pytorch.org/torchtune/stable/tutorials/llama3.html



1、Llama3-8B

Llama3-8B 是 Meta AI 发布的新模型,它在一系列不同基准测试中改进了 Llama2 系列模型的性能。 Llama2-7B 和 Llama3-8B 型号之间有一些主要变化:

  • Llama3-8B 使用分组查询注意力,而不是 Llama2-7B 的标准多头注意力
  • Llama3-8B 具有更大的词汇量(128,256,而不是 Llama2 模型的 32,000)
  • Llama3-8B 使用与 Llama2 模型不同的分词器(tiktoken而不是sentencepiece
  • Llama3-8B 在其 MLP 层中使用比 Llama2-7B 更大的中间维度
  • Llama3-8B 使用更高的基值来计算其 rotary positional embeddings 中的 theta。

2、访问 Llama3-8B

首先,我们从 Hugging Face 下载模型。您需要按照官方元页面上的说明进行操作才能访问该模型。接下来,请确保从这里获取您的 Hugging Face 令牌。

tune download meta-llama/Meta-Llama-3-8B \
    --output-dir <checkpoint_dir> \
    --hf-token <ACCESS TOKEN>
  • 1
  • 2
  • 3

3、在 torchtune 中微调 Llama3-8B

torchtune 提供LoRAQLoRA和完整的微调recipe,用于在一个或多个 GPU 上微调 Llama3-8B。

有关 torchtune 中 LoRA 的更多信息,请参阅我们的LoRA 教程。有关 torchtune 中 QLoRA 的更多信息,请参阅我们的QLoRA 教程

让我们看看如何使用 torchtune 在单个设备上对 Llama3-8B 和 LoRA 进行微调。

在此示例中,出于说明目的,我们将在通用指令数据集上微调一个时期。单设备 LoRA 微调的基本命令是

tune run lora_finetune_single_device --config llama3/8B_lora_single_device
  • 1

注:要查看recipe 及其相应配置的完整列表,只需运行 tune ls 命令即可。

我们还可以根据需要 添加命令行覆盖,例如

tune run lora_finetune_single_device --config llama3/8B_lora_single_device \
    checkpointer.checkpoint_dir=<checkpoint_dir> \
    tokenizer.path=<checkpoint_dir>/tokenizer.model \
    checkpointer.output_dir=<checkpoint_dir>
  • 1
  • 2
  • 3
  • 4

这将从上面 tune download 命令中使用的 <checkpoint_dir> 加载 Llama3-8B 检查点和标记器,然后按照原始格式 将最终检查点保存在同一目录中。

有关 torchtune 支持的检查点格式的更多详细信息,请参阅我们的检查点深入研究

注:要查看此(和其他)配置的完整可配置参数集,我们可以使用 tune cp 来复制(和修改)默认配置。

tune cp 也可以与recipe脚本一起使用,以防您想要进行更多无法通过直接修改现有可配置参数来实现的自定义更改。

有关更多 tune cp 的信息,请参阅修改配置部分 。

训练完成后,模型检查点将被保存并记录其位置。对于 LoRA 微调,最终检查点将包含合并的权重,并且将单独保存(小得多的)LoRA 权重的副本。

在我们的实验中,我们观察到内存使用峰值为 18.5 GB。默认配置可以在具有 24 GB VRAM 的消费级 GPU 上进行训练。

如果您有多个可用的 GPU,则可以运行recipe的分布式版本。
torchtune 利用PyTorch Distributed 的 FSDP API 对模型、优化器状态和梯度进行分片。
这将使您能够增加批量大小,从而加快整体训练速度。
例如,在两台设备上:

tune run --nproc_per_node 2 lora_finetune_distributed --config llama3/8B_lora
  • 1

最后,如果我们想使用更少的内存,我们可以通过以下方式利用 TorchTune 的 QLoRA recipe:

tune run lora_finetune_single_device --config llama3/8B_qlora_single_device
  • 1

由于我们的默认配置启用完整的 bfloat16 训练,因此上述所有命令都可以在具有至少 24 GB VRAM 的设备上运行,事实上,QLoRA recipe的峰值分配内存应低于 10 GB。
您还可以尝试 LoRA 和 QLoRA 的不同配置,甚至进行全面微调。试试看!


4、使用 EleutherAI 的 Eval Harness 评估 微调的 Llama3-8B 模型

现在我们已经对 Llama3-8B 进行了微调,下一步是什么?
让我们采用上一节中的 LoRA 微调模型,看看我们可以评估其在我们关心的任务上的性能的几种不同方法。

首先,torchtune 与 EleutherAI’s evaluation harness 集成 ,用于常见基准任务的模型评估。

注:确保您首先通过安装了评估工具:pip install "lm_eval==0.4.*"

在本教程中,我们将使用harness 中的任务 truthfulqa_mc2
此任务衡量模型在回答问题时的真实倾向,并衡量模型在一个或多个真实响应和一个或多个错误响应后的问题上的零样本准确性。
首先,让我们复制配置,以便将 YAML 文件指向我们微调的检查点文件。

tune cp eleuther_evaluation ./custom_eval_config.yaml
  • 1

接下来,我们进行修改 custom_eval_config.yaml 以包含微调的 checkpoints。

model:
  _component_: torchtune.models.llama3.llama3_8b

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer

  # directory with the checkpoint files
  # this should match the output_dir specified during
  # fine-tuning
  checkpoint_dir: <checkpoint_dir>

  # checkpoint files for the fine-tuned model. These will be logged
  # at the end of your fine-tune
  checkpoint_files: [
    consolidated.00.pth
  ]

  output_dir: <checkpoint_dir>
  model_type: LLAMA3

# Make sure to update the tokenizer path to the right
# checkpoint directory as well
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: <checkpoint_dir>/tokenizer.model
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

最后,我们可以使用修改后的配置来运行评估。

tune run eleuther_eval --config ./custom_eval_config.yaml
  • 1

亲自尝试一下,看看您的模型的准确性如何!


5、使用微调后的 Llama3-8B 模型 生成文本

接下来,让我们看看评估模型的另一种方法:生成文本!
torchtune 也提供了生成的方法

与我们所做的类似,让我们复制并修改默认的生成配置。

tune cp generation ./custom_generation_config.yaml
  • 1

现在我们修改custom_generation_config.yaml为指向我们的检查点和标记器。

model:
  _component_: torchtune.models.llama3.llama3_8b

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer

  # directory with the checkpoint files
  # this should match the output_dir specified during
  # fine-tuning
  checkpoint_dir: <checkpoint_dir>

  # checkpoint files for the fine-tuned model. These will be logged
  # at the end of your fine-tune
  checkpoint_files: [
    consolidated.00.pth
  ]

  output_dir: <checkpoint_dir>
  model_type: LLAMA3

# Make sure to update the tokenizer path to the right
# checkpoint directory as well
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: <checkpoint_dir>/tokenizer.model
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

使用经过 LoRA 微调的模型 运行生成,我们会看到以下输出:

tune run generate --config ./custom_generation_config.yaml \
prompt="Hello, my name is"

[generate.py:122] Hello, my name is Sarah and I am a busy working mum of two young children, living in the North East of England.
...
[generate.py:135] Time for inference: 10.88 sec total, 18.94 tokens/sec
[generate.py:138] Bandwidth achieved: 346.09 GB/s
[generate.py:139] Memory used: 18.31 GB
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

6、通过量化更快地生成

我们可以看到该模型只用了不到 11 秒,每秒生成近 19 个令牌。
我们可以通过量化我们的模型来加快速度。
在这里,我们将使用 torchao提供的 4 位仅权重量化。

如果您已经了解了到目前为止,那么您现在已经知道了该怎么做。
让我们复制量化配置并将其指向我们微调的模型。

tune cp quantization ./custom_quantization_config.yaml
  • 1

并更新custom_quantization_config.yaml以下内容:

# Model arguments
model:
  _component_: torchtune.models.llama3.llama3_8b

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer

  # directory with the checkpoint files
  # this should match the output_dir specified during
  # fine-tuning
  checkpoint_dir: <checkpoint_dir>

  # checkpoint files for the fine-tuned model. These will be logged
  # at the end of your fine-tune
  checkpoint_files: [
    consolidated.00.pth
  ]

  output_dir: <checkpoint_dir>
  model_type: LLAMA3
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

为了量化模型,我们现在可以运行:

tune run quantize --config ./custom_quantization_config.yaml

[quantize.py:90] Time for quantization: 2.93 sec
[quantize.py:91] Memory used: 23.13 GB
[quantize.py:104] Model checkpoint of size 4.92 GB saved to /tmp/Llama-3-8B-hf/consolidated-4w.pt
  • 1
  • 2
  • 3
  • 4
  • 5

我们可以看到该模型现在低于 5 GB,或者每个 8B 参数仅超过 4 位。

注:与微调检查点不同,量化recipe输出单个检查点文件。
这是因为我们的量化 API 目前不支持任何跨格式的转换。
因此,您将无法在 torchtune 之外使用这些量化模型。
但是您应该能够将它们与 torchtune 中的生成和评估recipe一起使用。
这些结果将有助于告知您应该在您最喜欢的推理引擎中使用哪种量化方法。

让我们采用量化模型并再次运行同一代。首先,我们将改变 custom_generation_config.yaml

checkpointer:
  # we need to use the custom TorchTune checkpointer
  # instead of the HF checkpointer for loading
  # quantized models
  _component_: torchtune.utils.FullModelTorchTuneCheckpointer

  # directory with the checkpoint files
  # this should match the output_dir specified during
  # fine-tuning
  checkpoint_dir: <checkpoint_dir>

  # checkpoint files point to the quantized model
  checkpoint_files: [
    consolidated-4w.pt,
  ]

  output_dir: <checkpoint_dir>
  model_type: LLAMA3

# we also need to update the quantizer to what was used during
# quantization
quantizer:
  _component_: torchtune.utils.quantization.Int4WeightOnlyQuantizer
  groupsize: 256
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

让我们重新运行吧!

tune run generate --config ./custom_generation_config.yaml \
prompt="Hello, my name is"

[generate.py:122] Hello, my name is Jake.
I am a multi-disciplined artist with a passion for creating, drawing and painting.
...
Time for inference: 1.62 sec total, 57.95 tokens/sec
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

通过量化模型并运行,torch.compile我们获得了超过 3 倍的加速!

这只是您可以使用 torchtune 和更广泛的生态系统对 Llama3-8B 进行的操作的开始。我们期待看到您构建的内容!


2024-04-23(二)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/653938
推荐阅读
相关标签
  

闽ICP备14008679号