当前位置:   article > 正文

llama3-8b-instruct-262k微调过程的问题笔记(场景为llama论文审稿)_bitsandbytes>=0.43.1

bitsandbytes>=0.43.1

目录

一、环境配置

  1.1、模型

  1.2、微调环境

  1.3、微调数据

二、发现的问题

  2.1、过拟合问题

  2.2、Qlora zero3 保存模型时OOM问题(已解决)


一、环境配置

  1.1、模型

llama3-8b-instruct-262k (英文)

  1.2、微调环境

  1. Package Version
  2. ----------------------------- -----------
  3. absl-py 2.1.0
  4. accelerate 0.31.0.dev0
  5. aiohttp 3.9.5
  6. aiosignal 1.3.1
  7. annotated-types 0.7.0
  8. anyio 4.3.0
  9. async-timeout 4.0.3
  10. attrs 23.2.0
  11. bitsandbytes 0.43.1
  12. certifi 2024.2.2
  13. cffi 1.16.0
  14. charset-normalizer 3.3.2
  15. click 8.1.7
  16. contourpy 1.2.1
  17. cryptography 42.0.7
  18. cycler 0.12.1
  19. datasets 2.19.1
  20. datatrove 0.2.0
  21. deepspeed 0.14.0
  22. Deprecated 1.2.14
  23. dill 0.3.8
  24. docker-pycreds 0.4.0
  25. docstring_parser 0.16
  26. einops 0.8.0
  27. et-xmlfile 1.1.0
  28. evaluate 0.4.2
  29. exceptiongroup 1.2.1
  30. filelock 3.14.0
  31. flash-attn 2.5.7
  32. fonttools 4.51.0
  33. frozenlist 1.4.1
  34. fsspec 2024.3.1
  35. gitdb 4.0.11
  36. GitPython 3.1.43
  37. grpcio 1.64.0
  38. h11 0.14.0
  39. hf_transfer 0.1.6
  40. hjson 3.1.0
  41. httpcore 1.0.5
  42. httpx 0.27.0
  43. huggingface-hub 0.23.1
  44. humanize 4.9.0
  45. idna 3.7
  46. Jinja2 3.1.4
  47. joblib 1.4.2
  48. kiwisolver 1.4.5
  49. loguru 0.7.2
  50. Markdown 3.6
  51. markdown-it-py 3.0.0
  52. MarkupSafe 2.1.5
  53. matplotlib 3.9.0
  54. mdurl 0.1.2
  55. mpmath 1.3.0
  56. multidict 6.0.5
  57. multiprocess 0.70.16
  58. networkx 3.3
  59. ninja 1.11.1.1
  60. nltk 3.8.1
  61. numpy 1.26.4
  62. nvidia-cublas-cu12 12.1.3.1
  63. nvidia-cuda-cupti-cu12 12.1.105
  64. nvidia-cuda-nvrtc-cu12 12.1.105
  65. nvidia-cuda-runtime-cu12 12.1.105
  66. nvidia-cudnn-cu12 8.9.2.26
  67. nvidia-cufft-cu12 11.0.2.54
  68. nvidia-curand-cu12 10.3.2.106
  69. nvidia-cusolver-cu12 11.4.5.107
  70. nvidia-cusparse-cu12 12.1.0.106
  71. nvidia-nccl-cu12 2.19.3
  72. nvidia-nvjitlink-cu12 12.5.40
  73. nvidia-nvtx-cu12 12.1.105
  74. openpyxl 3.1.2
  75. packaging 24.0
  76. pandas 2.2.2
  77. peft 0.11.2.dev0
  78. pillow 10.3.0
  79. pip 24.0
  80. platformdirs 4.2.2
  81. protobuf 3.20.3
  82. psutil 5.9.8
  83. py-cpuinfo 9.0.0
  84. pyarrow 16.1.0
  85. pyarrow-hotfix 0.6
  86. pycparser 2.22
  87. pydantic 2.7.1
  88. pydantic_core 2.18.2
  89. PyGithub 2.3.0
  90. Pygments 2.18.0
  91. PyJWT 2.8.0
  92. PyNaCl 1.5.0
  93. pynvml 11.5.0
  94. pyparsing 3.1.2
  95. python-dateutil 2.9.0.post0
  96. pytz 2024.1
  97. PyYAML 6.0.1
  98. regex 2024.5.15
  99. requests 2.32.2
  100. rich 13.7.1
  101. safetensors 0.4.3
  102. scikit-learn 1.5.0
  103. scipy 1.13.1
  104. sentencepiece 0.2.0
  105. sentry-sdk 2.3.1
  106. setproctitle 1.3.3
  107. setuptools 69.5.1
  108. shtab 1.7.1
  109. six 1.16.0
  110. smmap 5.0.1
  111. sniffio 1.3.1
  112. sympy 1.12
  113. tensorboard 2.16.2
  114. tensorboard-data-server 0.7.2
  115. threadpoolctl 3.5.0
  116. tiktoken 0.7.0
  117. tokenizers 0.19.1
  118. torch 2.2.1
  119. tqdm 4.66.4
  120. transformers 4.42.0.dev0
  121. transformers-stream-generator 0.0.5
  122. triton 2.2.0
  123. trl 0.8.7.dev0
  124. typing_extensions 4.12.0
  125. tyro 0.8.4
  126. tzdata 2024.1
  127. unsloth 2024.5
  128. urllib3 2.2.1
  129. wandb 0.17.0
  130. Werkzeug 3.0.3
  131. wheel 0.43.0
  132. wrapt 1.16.0
  133. xformers 0.0.25
  134. xxhash 3.4.1
  135. yarl 1.9.4

  1.3、微调数据

  • 数量:1.5k
  • 格式:jsonl,字典的key:input: paper, output: review

二、发现的问题

  2.1、过拟合问题

问题简述:

整个微调的过程中没有使用合适的验证集验证最佳模型保存时机,一是因为数据量太少,使用少量的验证集验证不具有可信度,二是选择什么样的方式进行验证。由于没有相关验证集验证的过程,模型训练epoch过高过拟合反而推理会效果会变差,下面是推理效果比较(yarn那篇论文,除了迭代次数140的模型仅推理一次,其他迭代次数推理都是用了多次推理取较好的结果)

引申一些问题:

1. early stop:不同的数据最佳模型的迭代次数不一样,怎么精准判断最佳模型的迭代次数,保存最佳模型(仅通过loss判断可能有待商榷,因为模型推理的语言风格也是比较重要的考量方式,差别可以看下面的截图实例)

2. 验证集的验证方法选择什么样的方式来判断最佳模型

  • 迭代批次为140的(仅推理一次),1.4 左右epoch

  • 迭代批次为260的(推理多次取了最优的效果),2.7左右epoch

  • 迭代批次为280的(推理多次取了最优的效果),2.9左右epoch

  2.2、Qlora zero3 保存模型时OOM问题(已解决)

问题简述:

我使用longqlora zero3模型微调 llama3-8b-instruct-262k,开启了shift short attention + flash attention v2,训练的过程中一切正常,loss正常下降,使用的设备为 A6000 (48G),占用的显存为30G左右,但在trainer保存模型时(模型 + zero3 优化器状态),显存的占用会出现短暂的暴涨为58G,模型保存后显存暂用恢复至30G左右。

我使用A100尝试关闭shift short attention,仅使用flash attention v2训练,依然在模型保存时显存占用增加,但A100为80G显存,训练便正常进行了

疑问❓:为何仅仅在模型保存的时候显存会出现爆发式增加呢?

  • 正常的训练的显存占用

  • 保存model时显存瞬间占用

(图:略)

  •  排查问题与解决方式:per_device_eval_batch_size设置太大了,模型保存时会进行验证集验证过程,per_device_eval_batch_size 设置小一些降低显存溢出的可能性。
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/731570
推荐阅读
相关标签
  

闽ICP备14008679号