当前位置:   article > 正文

绕开算力限制,如何用单GPU微调 LLM?这是一份「梯度累积」算法教程

os.environ["tokenizers_parallelism"]

05edfd7e6429407e5d9331277fbc6d82.png

  1. 来源:机器之心
  2. 本文约5500字,建议阅读10分钟本文介绍了“梯度积累”的算法教程。

算力资源用到极致,是每一位开发者的必修课。

自从大模型变成热门趋势之后,GPU 就成了紧俏的物资。很多企业的储备都不一定充足,更不用说个人开发者了。有没有什么方法可以更高效的利用算力训练模型?

在最近的一篇博客,Sebastian Raschka 介绍了「梯度累积」的方法,能够在 GPU 内存受限时使用更大 batch size 训练模型,绕开硬件限制。

41a0ed96440c7cda61ad1d5fa7db64d8.png

在此之前,Sebastian Raschka 也分享过一篇运用多 GPU 训练策略加速大型语言模型微调的文章,包括模型或 tensor sharding 等机制,这些机制将模型权重和计算分布在不同的设备上,以解决 GPU 的内存限制。

微调 BLOOM 模型进行分类

假设我们有兴趣采用近期预训练的大型语言模型来处理文本分类等下游任务。那么,我们可能会选择使用 GPT-3 的开源替代品 BLOOM 模型,特别是「仅有」 5.6 亿个参数的 BLOOM 版本 —— 它应该可以毫无问题地融入至传统 GPU 的 RAM 中(Google Colab 免费版本拥有 15 Gb RAM 的 GPU)。

一旦开始,就很可能遇到问题:内存会在训练或微调期间迅速增加。训练这个模型的唯一方法是使批大小为 1(batch size=1)。

5572be924888c7b358c45ac693048875.png

使用批大小为 1(batch size=1)为目标分类任务微调 BLOOM 的代码如下所示。

你也可以在 GitHub 项目页面下载完整代码:

https://github.com/rasbt/gradient-accumulation-blog/blob/main/src/1_batchsize-1.py

你可以将此代码直接复制并粘贴到 Google Colab 中,但还必须将随附的 local_dataset_utilities.py 文件拖放到从该文件导入了一些数据集实用程序的同一文件夹中。

  1. # pip install torch lightning matplotlib pandas torchmetrics watermark transformers datasets -U
  2. import os
  3. import os.path as op
  4. import time
  5. from datasets import load_dataset
  6. from lightning import Fabric
  7. import torch
  8. from torch.utils.data import DataLoader
  9. import torchmetrics
  10. from transformers import AutoTokenizer
  11. from transformers import AutoModelForSequenceClassification
  12. from watermark import watermark
  13. from local_dataset_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset
  14. from local_dataset_utilities import IMDBDataset
  15. def tokenize_text (batch):
  16. return tokenizer (batch ["text"], truncation=True, padding=True, max_length=1024)
  17. def train (num_epochs, model, optimizer, train_loader, val_loader, fabric):
  18. for epoch in range (num_epochs):
  19. train_acc = torchmetrics.Accuracy (
  20. task="multiclass", num_classes=2).to (fabric.device)
  21. for batch_idx, batch in enumerate (train_loader):
  22. model.train ()
  23. ### FORWARD AND BACK PROP
  24. outputs = model (
  25. batch ["input_ids"],
  26. attention_mask=batch ["attention_mask"],
  27. labels=batch ["label"]
  28. )
  29. fabric.backward (outputs ["loss"])
  30. ### UPDATE MODEL PARAMETERS
  31. optimizer.step ()
  32. optimizer.zero_grad ()
  33. ### LOGGING
  34. if not batch_idx % 300:
  35. print (f"Epoch: {epoch+1:04d}/{num_epochs:04d}"
  36. f"| Batch {batch_idx:04d}/{len (train_loader):04d}"
  37. f"| Loss: {outputs ['loss']:.4f}")
  38. model.eval ()
  39. with torch.no_grad ():
  40. predicted_labels = torch.argmax (outputs ["logits"], 1)
  41. train_acc.update (predicted_labels, batch ["label"])
  42. ### MORE LOGGING
  43. model.eval ()
  44. with torch.no_grad ():
  45. val_acc = torchmetrics.Accuracy (task="multiclass", num_classes=2).to (fabric.device)
  46. for batch in val_loader:
  47. outputs = model (
  48. batch ["input_ids"],
  49. attention_mask=batch ["attention_mask"],
  50. labels=batch ["label"]
  51. )
  52. predicted_labels = torch.argmax (outputs ["logits"], 1)
  53. val_acc.update (predicted_labels, batch ["label"])
  54. print (f"Epoch: {epoch+1:04d}/{num_epochs:04d}"
  55. f"| Train acc.: {train_acc.compute ()*100:.2f}%"
  56. f"| Val acc.: {val_acc.compute ()*100:.2f}%"
  57. )
  58. train_acc.reset (), val_acc.reset ()
  59. if __name__ == "__main__":
  60. print (watermark (packages="torch,lightning,transformers", python=True))
  61. print ("Torch CUDA available?", torch.cuda.is_available ())
  62. device = "cuda" if torch.cuda.is_available () else "cpu"
  63. torch.manual_seed (123)
  64. # torch.use_deterministic_algorithms (True)
  65. ##########################
  66. ### 1 Loading the Dataset
  67. ##########################
  68. download_dataset ()
  69. df = load_dataset_into_to_dataframe ()
  70. if not (op.exists ("train.csv") and op.exists ("val.csv") and op.exists ("test.csv")):
  71. partition_dataset (df)
  72. imdb_dataset = load_dataset (
  73. "csv",
  74. data_files={
  75. "train": "train.csv",
  76. "validation": "val.csv",
  77. "test": "test.csv",
  78. },
  79. )
  80. #########################################
  81. ### 2 Tokenization and Numericalization
  82. #########################################
  83. tokenizer = AutoTokenizer.from_pretrained ("bigscience/bloom-560m", max_length=1024)
  84. print ("Tokenizer input max length:", tokenizer.model_max_length, flush=True)
  85. print ("Tokenizer vocabulary size:", tokenizer.vocab_size, flush=True)
  86. print ("Tokenizing ...", flush=True)
  87. imdb_tokenized = imdb_dataset.map (tokenize_text, batched=True, batch_size=None)
  88. del imdb_dataset
  89. imdb_tokenized.set_format ("torch", columns=["input_ids", "attention_mask", "label"])
  90. os.environ ["TOKENIZERS_PARALLELISM"] = "false"
  91. #########################################
  92. ### 3 Set Up DataLoaders
  93. #########################################
  94. train_dataset = IMDBDataset (imdb_tokenized, partition_key="train")
  95. val_dataset = IMDBDataset (imdb_tokenized, partition_key="validation")
  96. test_dataset = IMDBDataset (imdb_tokenized, partition_key="test")
  97. train_loader = DataLoader (
  98. dataset=train_dataset,
  99. batch_size=1,
  100. shuffle=True,
  101. num_workers=4,
  102. drop_last=True,
  103. )
  104. val_loader = DataLoader (
  105. dataset=val_dataset,
  106. batch_size=1,
  107. num_workers=4,
  108. drop_last=True,
  109. )
  110. test_loader = DataLoader (
  111. dataset=test_dataset,
  112. batch_size=1,
  113. num_workers=2,
  114. drop_last=True,
  115. )
  116. #########################################
  117. ### 4 Initializing the Model
  118. #########################################
  119. fabric = Fabric (accelerator="cuda", devices=1, precision="16-mixed")
  120. fabric.launch ()
  121. model = AutoModelForSequenceClassification.from_pretrained (
  122. "bigscience/bloom-560m", num_labels=2)
  123. optimizer = torch.optim.Adam (model.parameters (), lr=5e-5)
  124. model, optimizer = fabric.setup (model, optimizer)
  125. train_loader, val_loader, test_loader = fabric.setup_dataloaders (
  126. train_loader, val_loader, test_loader)
  127. #########################################
  128. ### 5 Finetuning
  129. #########################################
  130. start = time.time ()
  131. train (
  132. num_epochs=1,
  133. model=model,
  134. optimizer=optimizer,
  135. train_loader=train_loader,
  136. val_loader=val_loader,
  137. fabric=fabric,
  138. )
  139. end = time.time ()
  140. elapsed = end-start
  141. print (f"Time elapsed {elapsed/60:.2f} min")
  142. with torch.no_grad ():
  143. model.eval ()
  144. test_acc = torchmetrics.Accuracy (task="multiclass", num_classes=2).to (fabric.device)
  145. for batch in test_loader:
  146. outputs = model (
  147. batch ["input_ids"],
  148. attention_mask=batch ["attention_mask"],
  149. labels=batch ["label"]
  150. )
  151. predicted_labels = torch.argmax (outputs ["logits"], 1)
  152. test_acc.update (predicted_labels, batch ["label"])
  153. print (f"Test accuracy {test_acc.compute ()*100:.2f}%")

作者使用了 Lightning Fabric,因为它可以让开发者在不同硬件上运行此代码时灵活地改变 GPU 数量和多 GPU 训练策略。它还允许仅通过调整查准率 flag 来启用混合精度训练(mixed-precision training)。在这种情况下,混合精度训练可以将训练速度提高三倍,并将内存需求降低约 25%。

上面展示的主要代码都是在主函数(if __name__ == "__main__" 的 context)中执行的,即使只使用单个 GPU,也推荐使用 PyTorch 运行环境执行多 GPU 训练。而后,包含在 if __name__ == "__main__" 中的以下三个代码部分负责数据加载:

# 1 加载数据集

# 2 token 化和数值化

# 3 设置数据加载器

第 4 节是初始化模型(Initializing the Model)中,然后在第 5 节 微调(Finetuning)中,调用 train 函数,这是开始让事情变得有趣的地方。在 train (...) 函数中,实现了标准的 PyTorch 循环。核心训练循环的注释版本如下所示:

d52eac64e55d3099c9f4dd381af042d3.png

批大小为 1(Batch size=1)的问题是梯度更新将会变得非常混乱和困难,正如下述训练模型时基于波动的训练损失和糟糕的测试集性能所看到的:

  1. ...
  2. torch : 2.0.0
  3. lightning : 2.0.0
  4. transformers: 4.27.2
  5. Torch CUDA available? True
  6. ...
  7. Epoch: 0001/0001 | Batch 23700/35000 | Loss: 0.0969
  8. Epoch: 0001/0001 | Batch 24000/35000 | Loss: 1.9902
  9. Epoch: 0001/0001 | Batch 24300/35000 | Loss: 0.0395
  10. Epoch: 0001/0001 | Batch 24600/35000 | Loss: 0.2546
  11. Epoch: 0001/0001 | Batch 24900/35000 | Loss: 0.1128
  12. Epoch: 0001/0001 | Batch 25200/35000 | Loss: 0.2661
  13. Epoch: 0001/0001 | Batch 25500/35000 | Loss: 0.0044
  14. Epoch: 0001/0001 | Batch 25800/35000 | Loss: 0.0067
  15. Epoch: 0001/0001 | Batch 26100/35000 | Loss: 0.0468
  16. Epoch: 0001/0001 | Batch 26400/35000 | Loss: 1.7139
  17. Epoch: 0001/0001 | Batch 26700/35000 | Loss: 0.9570
  18. Epoch: 0001/0001 | Batch 27000/35000 | Loss: 0.1857
  19. Epoch: 0001/0001 | Batch 27300/35000 | Loss: 0.0090
  20. Epoch: 0001/0001 | Batch 27600/35000 | Loss: 0.9790
  21. Epoch: 0001/0001 | Batch 27900/35000 | Loss: 0.0503
  22. Epoch: 0001/0001 | Batch 28200/35000 | Loss: 0.2625
  23. Epoch: 0001/0001 | Batch 28500/35000 | Loss: 0.1010
  24. Epoch: 0001/0001 | Batch 28800/35000 | Loss: 0.0035
  25. Epoch: 0001/0001 | Batch 29100/35000 | Loss: 0.0009
  26. Epoch: 0001/0001 | Batch 29400/35000 | Loss: 0.0234
  27. Epoch: 0001/0001 | Batch 29700/35000 | Loss: 0.8394
  28. Epoch: 0001/0001 | Batch 30000/35000 | Loss: 0.9497
  29. Epoch: 0001/0001 | Batch 30300/35000 | Loss: 0.1437
  30. Epoch: 0001/0001 | Batch 30600/35000 | Loss: 0.1317
  31. Epoch: 0001/0001 | Batch 30900/35000 | Loss: 0.0112
  32. Epoch: 0001/0001 | Batch 31200/35000 | Loss: 0.0073
  33. Epoch: 0001/0001 | Batch 31500/35000 | Loss: 0.7393
  34. Epoch: 0001/0001 | Batch 31800/35000 | Loss: 0.0512
  35. Epoch: 0001/0001 | Batch 32100/35000 | Loss: 0.1337
  36. Epoch: 0001/0001 | Batch 32400/35000 | Loss: 1.1875
  37. Epoch: 0001/0001 | Batch 32700/35000 | Loss: 0.2727
  38. Epoch: 0001/0001 | Batch 33000/35000 | Loss: 0.1545
  39. Epoch: 0001/0001 | Batch 33300/35000 | Loss: 0.0022
  40. Epoch: 0001/0001 | Batch 33600/35000 | Loss: 0.2681
  41. Epoch: 0001/0001 | Batch 33900/35000 | Loss: 0.2467
  42. Epoch: 0001/0001 | Batch 34200/35000 | Loss: 0.0620
  43. Epoch: 0001/0001 | Batch 34500/35000 | Loss: 2.5039
  44. Epoch: 0001/0001 | Batch 34800/35000 | Loss: 0.0131
  45. Epoch: 0001/0001 | Train acc.: 75.11% | Val acc.: 78.62%
  46. Time elapsed 69.97 min
  47. Test accuracy 78.53%

由于没有多的 GPU 可用于张量分片(tensor sharding),又能做些什么来训练具有更大批大小(batch size)的模型呢?

其中一种解决方法就是梯度累积,可以通过它来修改前面提到的训练循环。

什么是梯度积累?

梯度累积是一种在训练期间虚拟增加批大小(batch size)的方法,当可用的 GPU 内存不足以容纳所需的批大小时,这非常有用。在梯度累积中,梯度是针对较小的批次计算的,并在多次迭代中累积(通常是求和或平均),而不是在每一批次之后更新模型权重。一旦累积梯度达到目标「虚拟」批大小,模型权重就会使用累积梯度进行更新。

参考下面更新的 PyTorch 训练循环:

bc9adaeef50202d4d857d752d74cf10e.png

如果将 accumulation_steps 设置为 2,那么 zero_grad () 和 optimizer.step () 将只会每隔一秒调用一次。因此,使用 accumulation_steps=2 运行修改后的训练循环与将批大小(batch size)加倍具有相同的效果。

例如,如果想使用 256 的批大小,但只能将 64 的批大小放入 GPU 内存中,就可以对大小为 64 的四个批执行梯度累积。(处理完所有四个批次后,将获得相当于单个批大小为 256 的累积梯度。)这样能够有效地模拟更大的批大小,而无需更大的 GPU 内存或跨不同设备的张量分片。

虽然梯度累积可以帮助我们训练具有更大批量大小的模型,但它不会减少所需的总计算量。实际上,它有时会导致训练过程略慢一些,因为权重更新的执行频率较低。尽管如此,它却能帮我们解决限制问题,即批大小非常小时导致的更新频繁且混乱。

例如,现在让我们运行上面的代码,批大小为 1,需要 16 个累积步骤(accumulation steps)来模拟批大小等于 16。

输出如下:

  1. ...
  2. torch : 2.0.0
  3. lightning : 2.0.0
  4. transformers: 4.27.2
  5. Torch CUDA available? True
  6. ...
  7. Epoch: 0001/0001 | Batch 23700/35000 | Loss: 0.0168
  8. Epoch: 0001/0001 | Batch 24000/35000 | Loss: 0.0006
  9. Epoch: 0001/0001 | Batch 24300/35000 | Loss: 0.0152
  10. Epoch: 0001/0001 | Batch 24600/35000 | Loss: 0.0003
  11. Epoch: 0001/0001 | Batch 24900/35000 | Loss: 0.0623
  12. Epoch: 0001/0001 | Batch 25200/35000 | Loss: 0.0010
  13. Epoch: 0001/0001 | Batch 25500/35000 | Loss: 0.0001
  14. Epoch: 0001/0001 | Batch 25800/35000 | Loss: 0.0047
  15. Epoch: 0001/0001 | Batch 26100/35000 | Loss: 0.0004
  16. Epoch: 0001/0001 | Batch 26400/35000 | Loss: 0.1016
  17. Epoch: 0001/0001 | Batch 26700/35000 | Loss: 0.0021
  18. Epoch: 0001/0001 | Batch 27000/35000 | Loss: 0.0015
  19. Epoch: 0001/0001 | Batch 27300/35000 | Loss: 0.0008
  20. Epoch: 0001/0001 | Batch 27600/35000 | Loss: 0.0060
  21. Epoch: 0001/0001 | Batch 27900/35000 | Loss: 0.0001
  22. Epoch: 0001/0001 | Batch 28200/35000 | Loss: 0.0426
  23. Epoch: 0001/0001 | Batch 28500/35000 | Loss: 0.0012
  24. Epoch: 0001/0001 | Batch 28800/35000 | Loss: 0.0025
  25. Epoch: 0001/0001 | Batch 29100/35000 | Loss: 0.0025
  26. Epoch: 0001/0001 | Batch 29400/35000 | Loss: 0.0000
  27. Epoch: 0001/0001 | Batch 29700/35000 | Loss: 0.0495
  28. Epoch: 0001/0001 | Batch 30000/35000 | Loss: 0.0164
  29. Epoch: 0001/0001 | Batch 30300/35000 | Loss: 0.0067
  30. Epoch: 0001/0001 | Batch 30600/35000 | Loss: 0.0037
  31. Epoch: 0001/0001 | Batch 30900/35000 | Loss: 0.0005
  32. Epoch: 0001/0001 | Batch 31200/35000 | Loss: 0.0013
  33. Epoch: 0001/0001 | Batch 31500/35000 | Loss: 0.0112
  34. Epoch: 0001/0001 | Batch 31800/35000 | Loss: 0.0053
  35. Epoch: 0001/0001 | Batch 32100/35000 | Loss: 0.0012
  36. Epoch: 0001/0001 | Batch 32400/35000 | Loss: 0.1365
  37. Epoch: 0001/0001 | Batch 32700/35000 | Loss: 0.0210
  38. Epoch: 0001/0001 | Batch 33000/35000 | Loss: 0.0374
  39. Epoch: 0001/0001 | Batch 33300/35000 | Loss: 0.0007
  40. Epoch: 0001/0001 | Batch 33600/35000 | Loss: 0.0341
  41. Epoch: 0001/0001 | Batch 33900/35000 | Loss: 0.0259
  42. Epoch: 0001/0001 | Batch 34200/35000 | Loss: 0.0005
  43. Epoch: 0001/0001 | Batch 34500/35000 | Loss: 0.4792
  44. Epoch: 0001/0001 | Batch 34800/35000 | Loss: 0.0003
  45. Epoch: 0001/0001 | Train acc.: 78.67% | Val acc.: 87.28%
  46. Time elapsed 51.37 min
  47. Test accuracy 87.37%

根据上面的结果,损失的波动比以前小了。此外,测试集性能提升了 10%。由于只迭代了训练集一次,因此每个训练样本只会遇到一次。训练用于 multiple epochs 的模型可以进一步提高预测性能。

你可能还会注意到,这段代码的执行速度也比之前使用的批大小为 1 的代码快。如果使用梯度累积将虚拟批大小增加到 8,仍然会有相同数量的前向传播(forward passes)。然而,由于每八个 epoch 只更新一次模型,因此反向传播(backward passes)会很少,这样可更快地在一个 epoch(训练轮数)内迭代样本。

结论

梯度累积是一种在执行权重更新之前通过累积多个小的批梯度来模拟更大的批大小的技术。该技术在可用内存有限且内存中可容纳批大小较小的情况下提供帮助。

但是,首先请思考一种你可以运行批大小的场景,这意味着可用内存大到足以容纳所需的批大小。在那种情况下,梯度累积可能不是必需的。事实上,运行更大的批大小可能更有效,因为它允许更多的并行性且能减少训练模型所需的权重更新次数。

总之,梯度累积是一种实用的技术,可以用于降低小批大小干扰信息对梯度更新准确性的影响。这是迄今一种简单而有效的技术,可以让我们绕过硬件的限制。

PS:可以让这个运行得更快吗?

没问题。可以使用 PyTorch 2.0 中引入的 torch.compile 使其运行得更快。只需要添加一些 model = torch.compile,如下图所示:

c62279564052b5a27ea65bb460619236.png

GitHub 上提供了完整的脚本。

在这种情况下,torch.compile 在不影响建模性能的情况下又减少了十分钟的训练时间:

  1. poch: 0001/0001 | Batch 26400/35000 | Loss: 0.0320
  2. Epoch: 0001/0001 | Batch 26700/35000 | Loss: 0.0010
  3. Epoch: 0001/0001 | Batch 27000/35000 | Loss: 0.0006
  4. Epoch: 0001/0001 | Batch 27300/35000 | Loss: 0.0015
  5. Epoch: 0001/0001 | Batch 27600/35000 | Loss: 0.0157
  6. Epoch: 0001/0001 | Batch 27900/35000 | Loss: 0.0015
  7. Epoch: 0001/0001 | Batch 28200/35000 | Loss: 0.0540
  8. Epoch: 0001/0001 | Batch 28500/35000 | Loss: 0.0035
  9. Epoch: 0001/0001 | Batch 28800/35000 | Loss: 0.0016
  10. Epoch: 0001/0001 | Batch 29100/35000 | Loss: 0.0015
  11. Epoch: 0001/0001 | Batch 29400/35000 | Loss: 0.0008
  12. Epoch: 0001/0001 | Batch 29700/35000 | Loss: 0.0877
  13. Epoch: 0001/0001 | Batch 30000/35000 | Loss: 0.0232
  14. Epoch: 0001/0001 | Batch 30300/35000 | Loss: 0.0014
  15. Epoch: 0001/0001 | Batch 30600/35000 | Loss: 0.0032
  16. Epoch: 0001/0001 | Batch 30900/35000 | Loss: 0.0004
  17. Epoch: 0001/0001 | Batch 31200/35000 | Loss: 0.0062
  18. Epoch: 0001/0001 | Batch 31500/35000 | Loss: 0.0032
  19. Epoch: 0001/0001 | Batch 31800/35000 | Loss: 0.0066
  20. Epoch: 0001/0001 | Batch 32100/35000 | Loss: 0.0017
  21. Epoch: 0001/0001 | Batch 32400/35000 | Loss: 0.1485
  22. Epoch: 0001/0001 | Batch 32700/35000 | Loss: 0.0324
  23. Epoch: 0001/0001 | Batch 33000/35000 | Loss: 0.0155
  24. Epoch: 0001/0001 | Batch 33300/35000 | Loss: 0.0007
  25. Epoch: 0001/0001 | Batch 33600/35000 | Loss: 0.0049
  26. Epoch: 0001/0001 | Batch 33900/35000 | Loss: 0.1170
  27. Epoch: 0001/0001 | Batch 34200/35000 | Loss: 0.0002
  28. Epoch: 0001/0001 | Batch 34500/35000 | Loss: 0.4201
  29. Epoch: 0001/0001 | Batch 34800/35000 | Loss: 0.0018
  30. Epoch: 0001/0001 | Train acc.: 78.39% | Val acc.: 86.84%
  31. Time elapsed 43.33 min
  32. Test accuracy 87.91%

请注意,与之前相比准确率略有提高很可能是由于随机性。

fab02f5a8ae57f240105e299d033e901.png

原文链接:

https://lightning.ai/pages/blog/gradient-accumulation/

编辑:王菁

校对:程安乐

8717b6cb4e7dad6a6d29696c31c3de45.png

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/277454?site
推荐阅读
相关标签
  

闽ICP备14008679号