当前位置:   article > 正文

使用Kaggle Docs数据微调Gemma模型

使用Kaggle Docs数据微调Gemma模型

在这里插入图片描述

Introduction

This notebook will demonstrate three things:

  • How to fine-tune Gemma model using LoRA
  • Creation of a specialised class to query about Kaggle features
  • Some results of querying about Kaggle Docs

什么是Gemma

Gemma是一个轻量级的源生成人工智能模型集合,主要供开发人员和研究人员使用。Gemma由谷歌DeepMind研究实验室创建,该实验室也开发了Gemini, Gemma有几个版本,具有2B和7B参数,如下所示:
在这里插入图片描述

什么是LoRA?

LoRA是Low-Rank Adaptation的缩写。它是一种通过冻结大语言模型的权重并注入可训练的秩分解矩阵来微调大语言模型的方法。因此,微调过程中可训练参数的数量将大大减少。根据LoRA论文,这个数字减少了10,000倍,计算资源大小减少了3倍。

要对LoRA进行微调,我们将遵循以下步骤:

  • 1 安装库
  • 2加载并处理数据以进行微调
  • 3初始化Gemma因果语言模型(Gemma causal LM)的代码
  • 4进行微调
  • 5使用用于微调的数据中的问题和其他问题测试微调模型
Install packages
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U keras>=3
  • 1
  • 2
  • 3
Import packages
import os
os.environ["KERAS_BACKEND"] = "jax" # you can also use tensorflow or torch
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00" # avoid memory fragmentation on JAX backend.
os.environ["JAX_PLATFORMS"] = ""
import keras
import keras_nlp

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
tqdm.pandas() # progress bar for pandas

import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, Markdown
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
config
class Config:
    seed = 42
    dataset_path = "/kaggle/input/kaggle-docs/questions_answers"
    preset = "gemma_2b_en" # name of pretrained Gemma
    sequence_length = 512 # max size of input sequence for training
    batch_size = 1 # size of the input batch in training, x 2 as two GPUs
    epochs = 15 # number of epochs to train
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
keras.utils.set_random_seed(Config.seed)
  • 1
Load data
df = pd.read_csv(f"{Config.dataset_path}/data.csv")
df.head()
  • 1
  • 2

请添加图片描述
为了方便起见,我们为QA创建以下模板

template = "\n\nCategory:\nkaggle-{Category}\n\nQuestion:\n{Question}\n\nAnswer:\n{Answer}"
# 定义了一个文本模板,包含了一些占位符 `{Category}`, `{Question}`, `{Answer}`,分别代表类别、问题和答案。这些占位符将在后面的步骤中被实际的值替换。

df["prompt"] = df.apply(lambda row: template.format(Category=row.Category,
                                                             Question=row.Question,
                                                             Answer=row.Answer), axis=1)
# 将DataFrame中的每一行应用到一个lambda函数上。lambda函数接受行数据作为参数,并使用`template.format()`方法将该行数据填充到模板中,然后将结果存储在一个新列`prompt`中。

data = df.prompt.tolist()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
def colorize_text(text):
    # 使用 zip 函数将要替换的单词和相应的颜色一一对应起来
    for word, color in zip(["Category", "Question", "Answer"], ["blue", "red", "green"]):
        # 遍历每个单词和对应的颜色,将文本中匹配到的部分替换为带颜色的 HTML 标记
        text = text.replace(f"\n\n{word}:", f"\n\n**<font color='{color}'>{word}:</font>**")
    return text

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

Specialized class to query Gemma
We define a specialized class to query Gemma.

Initialize the code for Gemma Causal LM¶
gemma_causal_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_causal_lm.summary()
  • 1
  • 2

请添加图片描述

Define the specialized class

class GemmaQA:
    def __init__(self, max_length=512):
        self.max_length = max_length
        self.prompt = template
        self.gemma_causal_lm = gemma_causal_lm
        
    def query(self, category, question):
        response = self.gemma_causal_lm.generate(
            self.prompt.format(
                Category=category,
                Question=question,
                Answer=""), 
            max_length=self.max_length)
        display(Markdown(colorize_text(response)))
        
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

Gemma preprocessor

这个Gemmma预处理层接受一批字符串作为输入,并以 (x, y, sample_weight) 的格式返回输出,其中 y 标签是 x 序列中的下一个标记的标识符。

通过下面的代码,我们可以看到,经过预处理之后,数据的形状为 (num_samples, sequence_length)

这段代码的功能是将输入的文本序列切分为固定长度的序列,并将每个序列中的每个标记作为 x,其下一个标记作为 y。最终返回的是一组 x 序列、对应的 y 序列以及样本权重。

x, y, sample_weight = gemma_causal_lm.preprocessor(data[0:2])
print(x, y)
  • 1
  • 2

{‘token_ids’: Array([[ 2, 109, 8606, …, 0, 0, 0],
[ 2, 109, 8606, …, 0, 0, 0]], dtype=int32), ‘padding_mask’: Array([[ True, True, True, …, False, False, False],
[ True, True, True, …, False, False, False]], dtype=bool)} [[ 109 8606 235292 … 0 0 0]
[ 109 8606 235292 … 0 0 0]]

Perform fine-tuning with LoRA

‘Enable LoRA for the model¶
LoRA rank is setting the number of trainable parameters.
A larger rank will result in a larger number of parameters to train.

Enable LoRA for the model and set the LoRA rank to 4.

gemma_causal_lm.backbone.enable_lora(rank=4)
gemma_causal_lm.summary()
  • 1
  • 2

请添加图片描述

Run the training sequence¶

gemma_causal_lm.preprocessor.sequence_length = Config.sequence_length 

# Compile the model with loss, optimizer, and metric
gemma_causal_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(learning_rate=8e-5),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train model
gemma_causal_lm.fit(data, epochs=Config.epochs, batch_size=Config.batch_size)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

请添加图片描述

Test the fine-tuned model

gemma_qa = GemmaQA()
  • 1

sample1

row = df.iloc[0]
gemma_qa.query(row.Category,row.Question)
  • 1
  • 2

请添加图片描述

sample2

row = df.iloc[15]
gemma_qa.query(row.Category,row.Question)

  • 1
  • 2
  • 3

在这里插入图片描述

category = "notebook"
question = "How to run a notebook?"
gemma_qa.query(category,question)
  • 1
  • 2
  • 3

请添加图片描述

category = "competitions"
question = "What is a code competition?"
gemma_qa.query(category,question)
  • 1
  • 2
  • 3

请添加图片描述

以上演示了如何使用LoRA对Gemma模型进行微调。
我们还创建了一个类来运行对Gemma模型的查询,并使用来自现有训练数据的一些示例以及一些新的,未见过的问题对其进行测试。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/403489
推荐阅读
相关标签
  

闽ICP备14008679号