赞
踩
This notebook will demonstrate three things:
Gemma是一个轻量级的源生成人工智能模型集合,主要供开发人员和研究人员使用。Gemma由谷歌DeepMind研究实验室创建,该实验室也开发了Gemini, Gemma有几个版本,具有2B和7B参数,如下所示:
LoRA是Low-Rank Adaptation的缩写。它是一种通过冻结大语言模型的权重并注入可训练的秩分解矩阵来微调大语言模型的方法。因此,微调过程中可训练参数的数量将大大减少。根据LoRA论文,这个数字减少了10,000倍,计算资源大小减少了3倍。
要对LoRA进行微调,我们将遵循以下步骤:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U keras>=3
import os
os.environ["KERAS_BACKEND"] = "jax" # you can also use tensorflow or torch
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00" # avoid memory fragmentation on JAX backend.
os.environ["JAX_PLATFORMS"] = ""
import keras
import keras_nlp
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
tqdm.pandas() # progress bar for pandas
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, Markdown
class Config:
seed = 42
dataset_path = "/kaggle/input/kaggle-docs/questions_answers"
preset = "gemma_2b_en" # name of pretrained Gemma
sequence_length = 512 # max size of input sequence for training
batch_size = 1 # size of the input batch in training, x 2 as two GPUs
epochs = 15 # number of epochs to train
keras.utils.set_random_seed(Config.seed)
df = pd.read_csv(f"{Config.dataset_path}/data.csv")
df.head()
为了方便起见,我们为QA创建以下模板
template = "\n\nCategory:\nkaggle-{Category}\n\nQuestion:\n{Question}\n\nAnswer:\n{Answer}"
# 定义了一个文本模板,包含了一些占位符 `{Category}`, `{Question}`, `{Answer}`,分别代表类别、问题和答案。这些占位符将在后面的步骤中被实际的值替换。
df["prompt"] = df.apply(lambda row: template.format(Category=row.Category,
Question=row.Question,
Answer=row.Answer), axis=1)
# 将DataFrame中的每一行应用到一个lambda函数上。lambda函数接受行数据作为参数,并使用`template.format()`方法将该行数据填充到模板中,然后将结果存储在一个新列`prompt`中。
data = df.prompt.tolist()
def colorize_text(text):
# 使用 zip 函数将要替换的单词和相应的颜色一一对应起来
for word, color in zip(["Category", "Question", "Answer"], ["blue", "red", "green"]):
# 遍历每个单词和对应的颜色,将文本中匹配到的部分替换为带颜色的 HTML 标记
text = text.replace(f"\n\n{word}:", f"\n\n**<font color='{color}'>{word}:</font>**")
return text
Specialized class to query Gemma
We define a specialized class to query Gemma.
gemma_causal_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_causal_lm.summary()
class GemmaQA:
def __init__(self, max_length=512):
self.max_length = max_length
self.prompt = template
self.gemma_causal_lm = gemma_causal_lm
def query(self, category, question):
response = self.gemma_causal_lm.generate(
self.prompt.format(
Category=category,
Question=question,
Answer=""),
max_length=self.max_length)
display(Markdown(colorize_text(response)))
这个Gemmma预处理层接受一批字符串作为输入,并以 (x, y, sample_weight)
的格式返回输出,其中 y 标签是 x 序列中的下一个标记的标识符。
通过下面的代码,我们可以看到,经过预处理之后,数据的形状为 (num_samples, sequence_length)
。
这段代码的功能是将输入的文本序列切分为固定长度的序列,并将每个序列中的每个标记作为 x,其下一个标记作为 y。最终返回的是一组 x 序列、对应的 y 序列以及样本权重。
x, y, sample_weight = gemma_causal_lm.preprocessor(data[0:2])
print(x, y)
{‘token_ids’: Array([[ 2, 109, 8606, …, 0, 0, 0],
[ 2, 109, 8606, …, 0, 0, 0]], dtype=int32), ‘padding_mask’: Array([[ True, True, True, …, False, False, False],
[ True, True, True, …, False, False, False]], dtype=bool)} [[ 109 8606 235292 … 0 0 0]
[ 109 8606 235292 … 0 0 0]]
‘Enable LoRA for the model¶
LoRA rank is setting the number of trainable parameters.
A larger rank will result in a larger number of parameters to train.
gemma_causal_lm.backbone.enable_lora(rank=4)
gemma_causal_lm.summary()
gemma_causal_lm.preprocessor.sequence_length = Config.sequence_length
# Compile the model with loss, optimizer, and metric
gemma_causal_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(learning_rate=8e-5),
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# Train model
gemma_causal_lm.fit(data, epochs=Config.epochs, batch_size=Config.batch_size)
gemma_qa = GemmaQA()
row = df.iloc[0]
gemma_qa.query(row.Category,row.Question)
row = df.iloc[15]
gemma_qa.query(row.Category,row.Question)
category = "notebook"
question = "How to run a notebook?"
gemma_qa.query(category,question)
category = "competitions"
question = "What is a code competition?"
gemma_qa.query(category,question)
以上演示了如何使用LoRA对Gemma模型进行微调。
我们还创建了一个类来运行对Gemma模型的查询,并使用来自现有训练数据的一些示例以及一些新的,未见过的问题对其进行测试。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。