
[NLP] Fine-Tuning the LLaMA Model with Alpaca-LoRA: A Tutorial

Stanford Alpaca fine-tunes the entire LLaMA model, i.e. all parameters of the pretrained model are updated (full fine-tuning). However, this approach still has high hardware requirements and trains inefficiently.

See also: [NLP] Understanding Parameter-Efficient Fine-Tuning (PEFT) for Large Language Models

Alpaca-LoRA therefore uses the LoRA technique: with the original LLaMA parameters frozen, it adds extra network layers to the model and trains only the parameters of these new layers. Because the number of added parameters is small, the cost of fine-tuning drops significantly while the results are comparable to full fine-tuning.

The idea behind LoRA is not complicated. Its core is to add a bypass next to the original pretrained language model that performs a down-projection followed by an up-projection, modeling the so-called intrinsic rank (the process by which a pretrained model generalizes to downstream tasks is essentially the optimization of a very small number of free parameters in a common low-dimensional intrinsic subspace shared across tasks). During training, the parameters of the pretrained language model are frozen and only the down-projection matrix A and the up-projection matrix B are trained. The input and output dimensions of the model stay unchanged, and at the output the product BA is added to the pretrained model's parameters. A is initialized from a random Gaussian distribution and B is initialized to zero, which guarantees that the new path BA = 0 at the start of training, so it has no effect on the model's output.

At inference time, the results of the two branches are simply added together: h = Wx + BAx = (W + BA)x. Therefore, once training is done, the product BA can be added to the original weight matrix W and the sum used as the new weight in place of the pretrained model's W, so no extra compute is needed at inference.
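The following is a minimal PyTorch sketch of this idea (illustrative only, not the PEFT implementation; the class name LoRALinear and the numbers are made up for the example). Note that real implementations additionally scale the bypass by lora_alpha / r:

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """A frozen pretrained weight W plus a trainable low-rank bypass B @ A."""
    def __init__(self, in_features: int, out_features: int, r: int = 8):
        super().__init__()
        # W is the pretrained weight; it stays frozen during fine-tuning.
        self.weight = nn.Parameter(torch.randn(out_features, in_features), requires_grad=False)
        # A: random Gaussian init, B: zero init, so B @ A = 0 at the start of training.
        self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # h = W x + B A x  (only lora_A and lora_B receive gradients)
        return x @ self.weight.T + x @ self.lora_A.T @ self.lora_B.T

    @torch.no_grad()
    def merge(self) -> torch.Tensor:
        # For inference, fold the bypass into the base weight: W' = W + B @ A,
        # so a single matmul suffices and no extra parameters remain.
        return self.weight + self.lora_B @ self.lora_A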

The biggest advantages of LoRA are speed and lower memory use, which makes it possible to run on consumer-grade hardware.

Preparing the Dataset

There are usually two kinds of fine-tuning objectives:

  • Like Alpaca: collect input/output pairs and build prompts for training, so the model learns to perform specific tasks.
  • Language completion: collect plain text for training, so the model learns to continue a prompt.

Taking the first objective as an example, suppose we want the model to speak Chinese. We can use another LLM (such as text-davinci-003) to translate an existing dataset (such as Alpaca) into Chinese and fine-tune on it. In fact, this idea has already been implemented in the open-source community.

To achieve this, the dataset I use is the Alpaca dataset translated by the authors of Luotuo, and the training code mainly comes from Alpaca-LoRA.

wget https://raw.githubusercontent.com/LC1332/Chinese-alpaca-lora/main/data/trans_chinese_alpaca_data.json
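Each record in this file follows the standard Alpaca schema (instruction / input / output). A quick sanity check of the downloaded file, assuming it sits in the current directory:

import json

# Load the translated Alpaca data and inspect one record.
with open("trans_chinese_alpaca_data.json", encoding="utf-8") as f:
    data = json.load(f)

print(len(data))       # number of instruction-following examples
print(data[0].keys())  # expected keys: instruction, input, output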

The Alpaca-LoRA repository also contains English fine-tuning data (for example alpaca_data_cleaned.json, which is used in the commands below).

In addition, see the GPT-4-LLM project, which translated the Alpaca prompts into Chinese and used GPT-4 to generate 52K instruction-following examples.

Part 1: Environment Setup

The base environment is as follows:

  • OS: CentOS 7
  • CPUs: a single node with 1 TB of memory and 64 physical Intel CPUs, 16 cores each
  • GPUs: 4× A100 80GB
  • Docker image: pytorch:1.13.0-cuda11.6-cudnn8-devel

1. In the Alpaca-LoRA project, the authors mention that they use Hugging Face's PEFT. PEFT is a library (LoRA is one of the techniques it supports, alongside Prefix Tuning, P-Tuning, and Prompt Tuning) that lets you efficiently fine-tune language models based on the Transformer architecture. Install PEFT as follows:

# install peft
git clone https://github.com/huggingface/peft.git
cd peft/
pip install .

2. bitsandbytes is a lightweight wrapper around custom CUDA functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and quantization functions. Install it as follows:

# install bitsandbytes
git clone git@github.com:TimDettmers/bitsandbytes.git
cd bitsandbytes
CUDA_VERSION=116 make cuda11x
python setup.py install
If installing bitsandbytes fails with the following error:
/usr/bin/ld: cannot find -lcudart

run the following commands:

cd /usr/lib
ln -s /usr/local/cuda/lib64/libcudart.so libcudart.so

3. Alpaca-LoRA fine-tuning code:

# download alpaca-lora
git clone git@github.com:tloen/alpaca-lora.git
cd alpaca-lora
pip install -r requirements.txt

The requirements.txt file contains:

accelerate
appdirs
loralib
bitsandbytes
black
black[jupyter]
datasets
fire
git+https://github.com/huggingface/peft.git
transformers>=4.28.0
sentencepiece
gradio

Part 2: Model Format Conversion

Convert the original LLaMA weight files into the model format used by the Transformers library. You can also download already-converted models directly from Hugging Face:

For download instructions, see: [NLP] How to Download Hugging Face Model/Data Files

decapoda-research/llama-7b-hf · Hugging Face

decapoda-research/llama-13b-hf · Hugging Face
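If you are starting from the original LLaMA weights instead, the Transformers library ships a conversion script; a sketch of its use is below (the input/output paths are placeholders, and the script's location inside the transformers source tree may differ between versions):

python src/transformers/models/llama/convert_llama_weights_to_hf.py \
    --input_dir /path/to/original/LLaMA \
    --model_size 7B \
    --output_dir /path/to/llama-7b-hf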

Part 3: Model Fine-Tuning

The Alpaca-LoRA author adopts the LoRA method supported by Hugging Face's parameter-efficient fine-tuning library (PEFT). Two LoRA settings directly determine how many parameters need to be trained:

1) The LoRA target modules (lora_target_modules) specify which modules' parameters to adapt. For example, we can adapt all of Q, K, V, O, or only Q and V. Different choices change the number of trainable parameters as well as the amount of computation during training. For example, when only Q and V are adapted, the trainable parameters are only about 0.06% of the model's total parameters (the training log below reports trainable% ≈ 0.05 for the 13B model).

2) The LoRA rank (lora_r) is another important factor affecting the number of trainable parameters. Objectively, a model trained with a method like LoRA will inevitably differ somewhat from one obtained by training the full original model directly. You can therefore configure LoRA flexibly according to the hardware you have and the maximum training time you can tolerate.

1. The default fine-tuning hyperparameters are as follows:

batch_size: 128
micro_batch_size: 4
num_epochs: 3
learning_rate: 0.0003
cutoff_len: 256
val_set_size: 2000
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: ['q_proj', 'v_proj']
train_on_inputs: True
group_by_length: False
wandb_project:
wandb_run_name:
wandb_watch:
wandb_log_model:
resume_from_checkpoint: False
prompt template: alpaca
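For reference, these LoRA settings map onto PEFT roughly as follows. This is a sketch of what finetune.py does internally, not a copy of it (finetune.py also calls prepare_model_for_int8_training before wrapping the model, as the warnings in the log below show); the base model path is the one used in the commands below:

from transformers import LlamaForCausalLM
from peft import LoraConfig, get_peft_model

# Load the converted base model in 8-bit.
model = LlamaForCausalLM.from_pretrained(
    "/home/llama-7b",
    load_in_8bit=True,
    device_map="auto",
)

# LoRA configuration mirroring the defaults above.
config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)  # freeze the base weights, inject trainable LoRA adapters
model.print_trainable_parameters()     # prints trainable params / all params / trainable%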

2. Run on a single GPU:

nohup python finetune.py \
    --base_model '/home/llama-7b' \
    --data_path '../alpaca_data_cleaned.json' \
    --output_dir './lora-alpaca-7b-1gpu' \
    > torchrun-7b-1gpu.log 2>&1 &
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:10:00.0 Off |                    0 |
| N/A   00C    P0   293W / 400W |  10813MiB / 81920MiB |     94%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

 

3. Run on 4 GPUs:

nohup torchrun --nproc_per_node=4 --master_port=1234 finetune.py \
    --base_model '/home/llama-7b' \
    --data_path '../alpaca_data_cleaned.json' \
    --output_dir './lora-alpaca-7b-4gpu' \
    --num_epochs 1 \
    > torchrun-7b-4gpu.log 2>&1 &
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:16:00.0 Off |                    0 |
| N/A   11C    P0   282W / 400W |  17055MiB / 81920MiB |     93%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  On   | 00000000:47:00.0 Off |                    0 |
| N/A   12C    P0   339W / 400W |  14275MiB / 81920MiB |     93%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM...  On   | 00000000:4B:00.0 Off |                    0 |
| N/A   13C    P0   324W / 400W |  14773MiB / 81920MiB |     94%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA A100-SXM...  On   | 00000000:89:00.0 Off |                    0 |
| N/A   14C    P0   325W / 400W |  14385MiB / 81920MiB |     94%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

4. The training output looks like this:

Training Alpaca-LoRA model with params:
base_model: /disk1/llama-13b
data_path: ./alpaca_data_cleaned_archive.json
output_dir: ./lora-alpaca
batch_size: 128
micro_batch_size: 8
num_epochs: 1
learning_rate: 0.0003
cutoff_len: 256
val_set_size: 2000
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: ['q_proj', 'v_proj']
train_on_inputs: True
add_eos_token: False
group_by_length: False
wandb_project:
wandb_run_name:
wandb_watch:
wandb_log_model:
resume_from_checkpoint: False
prompt template: alpaca
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████| 41/41 [00:43<00:00, 1.06s/it]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████| 41/41 [00:43<00:00, 1.06s/it]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████| 41/41 [00:43<00:00, 1.06s/it]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████| 41/41 [00:43<00:00, 1.06s/it]
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization.
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'.
The class this function is called from is 'LlamaTokenizer'.
You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
/opt/conda/lib/python3.9/site-packages/peft/utils/other.py:102: FutureWarning: prepare_model_for_int8_training is deprecated and will be removed in a future version. Use prepare_model_for_kbit_training instead.
  warnings.warn(
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization.
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'.
The class this function is called from is 'LlamaTokenizer'.
You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization.
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'.
The class this function is called from is 'LlamaTokenizer'.
You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
/opt/conda/lib/python3.9/site-packages/peft/utils/other.py:102: FutureWarning: prepare_model_for_int8_training is deprecated and will be removed in a future version. Use prepare_model_for_kbit_training instead.
  warnings.warn(
/opt/conda/lib/python3.9/site-packages/peft/utils/other.py:102: FutureWarning: prepare_model_for_int8_training is deprecated and will be removed in a future version. Use prepare_model_for_kbit_training instead.
  warnings.warn(
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization.
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'.
The class this function is called from is 'LlamaTokenizer'.
You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
/opt/conda/lib/python3.9/site-packages/peft/utils/other.py:102: FutureWarning: prepare_model_for_int8_training is deprecated and will be removed in a future version. Use prepare_model_for_kbit_training instead.
  warnings.warn(
trainable params: 6,553,600 || all params: 13,022,417,920 || trainable%: 0.05032552357220002
Map:   3%|███▊ | 1330/49759 [00:01<00:39, 1216.23 examples/s]trainable params: 6,553,600 || all params: 13,022,417,920 || trainable%: 0.05032552357220002
Map:   0%| | 0/49759 [00:00<?, ? examples/s]trainable params: 6,553,600 || all params: 13,022,417,920 || trainable%: 0.05032552357220002
Map:   1%|▊ | 272/49759 [00:00<00:36, 1350.21 examples/s]trainable params: 6,553,600 || all params: 13,022,417,920 || trainable%: 0.05032552357220002
Map: 100%|████████████████████████████████████████████████████████████████████████| 49759/49759 [00:38<00:00, 1294.31 examples/s]
Map: 100%|████████████████████████████████████████████████████████████████████████| 49759/49759 [00:38<00:00, 1284.04 examples/s]
Map: 100%|████████████████████████████████████████████████████████████████████████| 49759/49759 [00:38<00:00, 1283.95 examples/s]
Map: 100%|████████████████████████████████████████████████████████████████████████| 2000/2000 [00:01<00:00, 1221.03 examples/s]
[W socket.cpp:601] [c10d] The client socket cannot be initialized to connect to [localhost]:29005 (errno: 97 - Address family not supported by protocol).
[W socket.cpp:601] [c10d] The client socket cannot be initialized to connect to [localhost]:29005 (errno: 97 - Address family not supported by protocol).
Map: 100%|████████████████████████████████████████████████████████████████████████| 49759/49759 [00:39<00:00, 1274.42 examples/s]
Map: 100%|████████████████████████████████████████████████████████████████████████| 2000/2000 [00:01<00:00, 1285.16 examples/s]
[W socket.cpp:601] [c10d] The client socket cannot be initialized to connect to [localhost]:29005 (errno: 97 - Address family not supported by protocol).
[W socket.cpp:601] [c10d] The client socket cannot be initialized to connect to [localhost]:29005 (errno: 97 - Address family not supported by protocol).
Map: 100%|████████████████████████████████████████████████████████████████████████| 2000/2000 [00:01<00:00, 1281.27 examples/s]
[W socket.cpp:601] [c10d] The client socket cannot be initialized to connect to [localhost]:29005 (errno: 97 - Address family not supported by protocol).
[W socket.cpp:601] [c10d] The client socket cannot be initialized to connect to [localhost]:29005 (errno: 97 - Address family not supported by protocol).
Map: 100%|████████████████████████████████████████████████████████████████████████| 2000/2000 [00:01<00:00, 1290.31 examples/s]
[W socket.cpp:601] [c10d] The client socket cannot be initialized to connect to [localhost]:29005 (errno: 97 - Address family not supported by protocol).
[W socket.cpp:601] [c10d] The client socket cannot be initialized to connect to [localhost]:29005 (errno: 97 - Address family not supported by protocol).
  0%| | 0/388 [00:00<?, ?it/s]/opt/conda/lib/python3.9/site-packages/bitsandbytes-0.41.0-py3.9.egg/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
/opt/conda/lib/python3.9/site-packages/bitsandbytes-0.41.0-py3.9.egg/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
/opt/conda/lib/python3.9/site-packages/bitsandbytes-0.41.0-py3.9.egg/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
/opt/conda/lib/python3.9/site-packages/bitsandbytes-0.41.0-py3.9.egg/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
  warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
{'loss': 2.249, 'learning_rate': 2.9999999999999997e-05, 'epoch': 0.03}
{'loss': 2.1927, 'learning_rate': 5.6999999999999996e-05, 'epoch': 0.05}
{'loss': 2.0813, 'learning_rate': 7.8e-05, 'epoch': 0.08}
{'loss': 1.7206, 'learning_rate': 0.00010799999999999998, 'epoch': 0.1}
 11%|████████████████▋ | 42/388 [10:50<1:27:2

Part 4: Merging the Model

1. Export to the Hugging Face format:

You can download the lora_weights from Angainor/alpaca-lora-13b · Hugging Face.

Modify the export_hf_checkpoint.py file:

import os

import torch
import transformers
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: F402

BASE_MODEL = os.environ.get("BASE_MODEL", "/disk1/llama-13b")
LORA_MODEL = os.environ.get("LORA_MODEL", "./alpaca-lora-13b")
HF_CHECKPOINT = os.environ.get("HF_CHECKPOINT", "./hf_ckpt")

tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

base_model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_8bit=False,
    torch_dtype=torch.float16,
    device_map={"": "cpu"},
)

first_weight = base_model.model.layers[0].self_attn.q_proj.weight
first_weight_old = first_weight.clone()

lora_model = PeftModel.from_pretrained(
    base_model,
    LORA_MODEL,
    device_map={"": "cpu"},
    torch_dtype=torch.float16,
)

lora_weight = lora_model.base_model.model.model.layers[
    0
].self_attn.q_proj.weight

assert torch.allclose(first_weight_old, first_weight)

# merge weights - new merging method from peft
lora_model = lora_model.merge_and_unload()

lora_model.train(False)

# did we do anything?
assert not torch.allclose(first_weight_old, first_weight)

lora_model_sd = lora_model.state_dict()
deloreanized_sd = {
    k.replace("base_model.model.", ""): v
    for k, v in lora_model_sd.items()
    if "lora" not in k
}

LlamaForCausalLM.save_pretrained(
    base_model, HF_CHECKPOINT, state_dict=deloreanized_sd, max_shard_size="400MB"
)
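Because the script reads BASE_MODEL, LORA_MODEL and HF_CHECKPOINT from the environment (via the os.environ.get calls above), you can also point it at different paths without editing the file, for example:

BASE_MODEL=/disk1/llama-13b LORA_MODEL=./alpaca-lora-13b HF_CHECKPOINT=./hf_ckpt python export_hf_checkpoint.py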

Run the export with the defaults:

python export_hf_checkpoint.py

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization.
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'.
The class this function is called from is 'LlamaTokenizer'.
You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████| 41/41 [00:26<00:00, 1.56it/s]

Inspect the exported model files:

hf_ckpt/
├── config.json
├── generation_config.json
├── pytorch_model-00001-of-00082.bin
├── pytorch_model-00002-of-00082.bin
├── pytorch_model-00003-of-00082.bin
├── pytorch_model-00004-of-00082.bin
├── pytorch_model-00005-of-00082.bin
├── pytorch_model-00006-of-00082.bin
├── pytorch_model-00007-of-00082.bin
├── pytorch_model-00008-of-00082.bin
├── pytorch_model-00009-of-00082.bin
├── pytorch_model-00010-of-00082.bin
├── pytorch_model-00011-of-00082.bin
├── pytorch_model-00012-of-00082.bin
├── pytorch_model-00013-of-00082.bin
├── pytorch_model-00014-of-00082.bin
├── pytorch_model-00015-of-00082.bin
├── pytorch_model-00016-of-00082.bin
├── pytorch_model-00017-of-00082.bin
├── pytorch_model-00018-of-00082.bin
├── pytorch_model-00019-of-00082.bin
├── pytorch_model-00020-of-00082.bin
├── pytorch_model-00021-of-00082.bin
├── pytorch_model-00022-of-00082.bin
├── pytorch_model-00023-of-00082.bin
├── pytorch_model-00024-of-00082.bin
├── pytorch_model-00025-of-00082.bin
├── pytorch_model-00026-of-00082.bin
├── pytorch_model-00027-of-00082.bin
├── pytorch_model-00028-of-00082.bin
├── pytorch_model-00029-of-00082.bin
├── pytorch_model-00030-of-00082.bin
├── pytorch_model-00031-of-00082.bin
├── pytorch_model-00032-of-00082.bin
├── pytorch_model-00033-of-00082.bin
├── pytorch_model-00034-of-00082.bin
├── pytorch_model-00035-of-00082.bin
├── pytorch_model-00036-of-00082.bin
├── pytorch_model-00037-of-00082.bin
├── pytorch_model-00038-of-00082.bin
├── pytorch_model-00039-of-00082.bin
├── pytorch_model-00040-of-00082.bin
├── pytorch_model-00041-of-00082.bin
├── pytorch_model-00042-of-00082.bin
├── pytorch_model-00043-of-00082.bin
├── pytorch_model-00044-of-00082.bin
├── pytorch_model-00045-of-00082.bin
├── pytorch_model-00046-of-00082.bin
├── pytorch_model-00047-of-00082.bin
├── pytorch_model-00048-of-00082.bin
├── pytorch_model-00049-of-00082.bin
├── pytorch_model-00050-of-00082.bin
├── pytorch_model-00051-of-00082.bin
├── pytorch_model-00052-of-00082.bin
├── pytorch_model-00053-of-00082.bin
├── pytorch_model-00054-of-00082.bin
├── pytorch_model-00055-of-00082.bin
├── pytorch_model-00056-of-00082.bin
├── pytorch_model-00057-of-00082.bin
├── pytorch_model-00058-of-00082.bin
├── pytorch_model-00059-of-00082.bin
├── pytorch_model-00060-of-00082.bin
├── pytorch_model-00061-of-00082.bin
├── pytorch_model-00062-of-00082.bin
├── pytorch_model-00063-of-00082.bin
├── pytorch_model-00064-of-00082.bin
├── pytorch_model-00065-of-00082.bin
├── pytorch_model-00066-of-00082.bin
├── pytorch_model-00067-of-00082.bin
├── pytorch_model-00068-of-00082.bin
├── pytorch_model-00069-of-00082.bin
├── pytorch_model-00070-of-00082.bin
├── pytorch_model-00071-of-00082.bin
├── pytorch_model-00072-of-00082.bin
├── pytorch_model-00073-of-00082.bin
├── pytorch_model-00074-of-00082.bin
├── pytorch_model-00075-of-00082.bin
├── pytorch_model-00076-of-00082.bin
├── pytorch_model-00077-of-00082.bin
├── pytorch_model-00078-of-00082.bin
├── pytorch_model-00079-of-00082.bin
├── pytorch_model-00080-of-00082.bin
├── pytorch_model-00081-of-00082.bin
├── pytorch_model-00082-of-00082.bin
└── pytorch_model.bin.index.json

0 directories, 85 files

2. Export to PyTorch state_dicts

Similarly, modify the export_state_dict_checkpoint.py file and run it.
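A usage sketch, assuming the script reads the same BASE_MODEL and LORA_MODEL environment variables as export_hf_checkpoint.py above (the paths are the ones used throughout this post):

BASE_MODEL=/disk1/llama-13b LORA_MODEL=./alpaca-lora-13b python export_state_dict_checkpoint.py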

Part 5: Quantization (optional)

Finally, quantization can speed up model inference and reduce the memory it requires. Open-source tools are available for this and can be used directly.
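As one option (a sketch, not the only route): the merged checkpoint exported above can be loaded in 8-bit with bitsandbytes through Transformers, which roughly halves the weight memory compared with fp16. The tokenizer is loaded from the base model directory because the export script above does not save it:

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("/disk1/llama-13b")
model = LlamaForCausalLM.from_pretrained(
    "./hf_ckpt",
    load_in_8bit=True,   # weights are quantized to int8 by bitsandbytes at load time
    device_map="auto",
)

inputs = tokenizer("Below is an instruction that describes a task.", return_tensors="pt").to(model.device)
with torch.no_grad():
    print(tokenizer.decode(model.generate(**inputs, max_new_tokens=64)[0]))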

Part 6: Common Issues

GPU out-of-memory (OOM) when saving a checkpoint model

During fine-tuning, an OOM error occurred when saving a checkpoint model. According to the discussion in the issue "CUDA out of memory", this is caused by a bug in the new bitsandbytes release 0.38.1; rolling back to 0.37.2 solves the problem.
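To pin the working version (the exact pin applies to the setup described in this post):

pip install bitsandbytes==0.37.2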

adapter_model.bin has no parameters after fine-tuning (its file size is only 443 bytes)

This is mainly caused by a compatibility problem between alpaca-lora and the peft library. According to the discussion in "fix issues to be compatible with latest peft #359", the simplest fix at the moment is to modify finetune.py as follows:

model.save_pretrained(output_dir)  # original code at line 275
model.save_pretrained(output_dir, state_dict=old_state_dict())  # modified code at line 275

