当前位置:   article > 正文

使用 Llama3 模型进行关系提取 通过使用 Llama3–70B 创建的合成数据集微调 Llama3–8B 来增强关系提取

使用 Llama3 模型进行关系提取 通过使用 Llama3–70B 创建的合成数据集微调 Llama3–8B 来增强关系提取

前提

关系提取(RE)是从非结构化文本中提取关系以识别各种命名实体之间的联系的任务。它与命名实体识别 (NER) 结合完成,是自然语言处理管道中的重要步骤。随着大型语言模型 (LLM) 的兴起,涉及标记实体跨度和对它们之间的关系(如果有)进行分类的传统监督方法得到增强或完全被基于 LLM 的方法所取代 [ 1 ]。

Llama3 是 GenerativeAI 领域的最新主要版本 [ 2 ]。基础型号有 8B 和 70B 两种尺寸可供选择,400B 型号预计很快就会发布。这些模型在 HuggingFace 平台上可用;详细信息请参见[ 3 ]。 70B 变体为 Meta 的新聊天网站Meta.ai提供支持,并表现出与 ChatGPT 相当的性能。 8B 型号是同类产品中性能最高的型号之一。 Llama3的架构与Llama2类似,性能的提升主要来自于数据升级。该模型配备了升级的分词器和扩展的上下文窗口。尽管只发布了一小部分数据,但它被标记为开源。总的来说,这是一个优秀的模型,我迫不及待地想尝试一下。

Llama3–70B 可以产生惊人的结果,但由于其尺寸,它不切实际、昂贵且难以在本地系统上使用。因此,为了利用其功能,我们让 Llama3-70B 教较小的 Llama3-8B 从非结构化文本中提取关系的任务。

具体来说,在 Llama3-70B 的帮助下,我们构建了一个旨在关系提取的监督微调数据集。然后,我们使用该数据集对 Llama3-8B 进行微调,以增强其关系提取能力。

工作区设置

在这个项目中,我使用了配备 A100 GPU 和高 RAM 设置的 Google Colab Pro。

我们首先安装所有必需的库:

!pip install -q groq
!pip install -U accelerate bitsandbytes datasets evaluate 
!pip install -U peft transformers trl 

  • 1
  • 2
  • 3
  • 4

我很高兴地注意到整个设置从一开始就有效,没有任何依赖关系问题或需要transformers从源安装,尽管模型很新颖。

我们还需要授予 Goggle Colab 驱动器和文件的访问权限并设置工作目录:

# For Google Colab settings
from google.colab import userdata, drive

# This will prompt for authorization
drive.mount('/content/drive')

# Set the working directory
%cd '/content/drive/MyDrive/postedBlogs/llama3RE'

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

对于那些希望将模型上传到 HuggingFace Hub 的人,我们需要上传 Hub 凭据。就我而言,这些信息存储在 Google Colab 机密中,可以通过左侧的按键按钮进行访问。此步骤是可选的。

# For Hugging Face Hub setting
from huggingface_hub import login

# Upload the HuggingFace token (should have WRITE access) from Colab secrets
HF = userdata.get('HF')

# This is needed to upload the model to HuggingFace
login(token=HF,add_to_git_credential=True)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

我还添加了一些路径变量来简化文件访问:

# Create a path variable for the data folder
data_path = '/content/drive/MyDrive/postedBlogs/llama3RE/datas/'

# Full fine-tuning dataset
sft_dataset_file = f'{data_path}sft_train_data.json'

# Data collected from the the mini-test
mini_data_path = f'{data_path}mini_data.json'

# Test data containing all three outputs
all_tests_data = f'{data_path}all_tests.json'

# The adjusted training dataset
train_data_path = f'{data_path}sft_train_data.json'

# Create a path variable for the SFT model to be saved locally
sft_model_path = '/content/drive/MyDrive/llama3RE/Llama3_RE/'
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

现在我们的工作区已经设置完毕,我们可以进入第一步,即为关系提取任务构建一个合成数据集。

使用 Llama3–70B 创建用于关系提取的综合数据集

有多个可用的关系提取数据集,其中最著名的是CoNLL04数据集。此外,还有一些优秀的数据集,例如HuggingFace 上提供的web_nlg以及AllenAI 开发的SciREX 。然而,大多数这些数据集都带有限制性许可。

受到数据集格式的启发,web_nlg我们将构建自己的数据集。如果我们计划微调在我们的数据集上训练的模型,这种方法将特别有用。首先,我们需要一组短句子来执行关系提取任务。我们可以通过多种方式编译这个语料库。

收集句子集

我们将使用databricks-dolly-15k,这是 Databricks 员工于 2023 年生成的开源数据集。该数据集专为监督微调而设计,包括四个特征:指令、上下文、响应和类别。在分析了八个类别后,我决定保留该类别中上下文的第一句话information_extraction。数据解析步骤概述如下:

from datasets import load_dataset

# Load the dataset
dataset = load_dataset("databricks/databricks-dolly-15k")

# Choose the desired category from the dataset
ie_category = [e for e in dataset["train"] if e["category"]=="information_extraction"]

# Retain only the context from each instance
ie_context = [e["context"] for e in ie_category]

# Split the text into sentences (at the period) and keep the first sentence
reduced_context = [text.split('.')[0] + '.' for text in ie_context]

# Retain sequences of specified lengths only (use character length)
sampler = [e for e in reduced_context if 30 < len(e) < 170]

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

选择过程产生包含 1,041 个句子的数据集。鉴于这是一个小型项目,我没有精心挑选句子,因此,某些示例可能不太适合我们的任务。在指定的制作项目中,我会仔细选择最合适的句子。然而,对于本项目的目的来说,这个数据集就足够了。

格式化数据

我们首先需要创建一条系统消息,用于定义输入提示并指示模型如何生成答案:

system_message = """You are an experienced annontator. 
Extract all entities and the relations between them from the following text. 
Write the answer as a triple entity1|relationship|entitity2. 
Do not add anything else.
Example Text: Alice is from France.
Answer: Alice|is from|France.
"""

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

由于这是一个实验阶段,我将对模型的要求保持在最低限度。我确实测试了其他几个提示,包括一些请求以 CoNLL 格式输出的提示,其中实体被分类,并且模型表现得很好。然而,为了简单起见,我们现在将坚持基础知识。

我们还需要将数据转换为会话格式:

messages = [[ 
    { "role" : "system" , "content" : f" {system_message} " }, 
    { "role" : "user" , "content" : e}] for e in Sampler]

  • 1
  • 2
  • 3
  • 4

Groq 客户端和 API

Llama3 刚刚发布几天,API 选项的可用性仍然有限。虽然 Llama3-70B 可以使用聊天界面,但该项目需要一个 API,可以通过几行代码处理我的 1,000 个句子。我发现这个精彩的YouTube 视频解释了如何免费使用 GroqCloud API。欲了解更多详情,请参阅视频。

请注意:您需要登录并从GroqCloud网站检索免费的 API 密钥。我的 API 密钥已保存在 Google Colab 密钥中。我们首先初始化 Groq 客户端:

import os
from groq import Groq

gclient = Groq(
    api_key=userdata.get("GROQ"),
)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

接下来,我们需要定义几个帮助函数,使我们能够有效地与Meta.ai聊天界面进行交互

import time
from tqdm import tqdm

def process_data(prompt):

    """Send one request and retrieve model's generation."""

    chat_completion = gclient.chat.completions.create(
        messages=prompt, # input prompt to send to the model
        model="llama3-70b-8192", # according to GroqCloud labeling
        temperature=0.5, # controls diversity
        max_tokens=128, # max number tokens to generate
        top_p=1, # proportion of likelihood weighted options to consider
        stop=None, # string that signals to stop generating
        stream=False, # if set partial messages are sent
    )
    return chat_completion.choices[0].message.content


def send_messages(messages):

    """Process messages in batches with a pause between batches."""
   
   batch_size = 10
    answers = []

    for i in tqdm(range(0, len(messages), batch_size)): # batches of size 10

        batch = messages[i:i+10]  # get the next batch of messages

        for message in batch:
            output = process_data(message)
            answers.append(output)

        if i + 10 < len(messages):  # check if there are batches left
            time.sleep(10)  # wait for 10 seconds

    return answers
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38

第一个函数process_data()用作 Groq 客户端的聊天完成函数的包装器。第二个功能send_messages(),小批量处理数据。如果您点击 Groq Playground 页面上的“设置”链接,您将找到一个“限制”链接,其中详细说明了我们可以使用免费 API 的条件,包括请求数量和生成令牌的上限。为了避免超出这些限制,我在每批 10 条消息后添加了 10 秒的延迟,尽管在我的情况下这并不是绝对必要的。您可能想尝试这些设置。

现在剩下的就是生成我们的关系提取数据并将其与初始数据集集成:

# Data generation with Llama3-70B
answers = send_messages(messages)

# Combine input data with the generated dataset
combined_dataset = [{'text': user, 'gold_re': output} for user, output in zip(sampler, answers)]
  • 1
  • 2
  • 3
  • 4
  • 5

评估 Llama3–8B 的关系提取

在继续微调模型之前,重要的是评估其在多个样本上的性能,以确定是否确实需要微调。

构建测试数据集

我们将从刚刚构建的数据集中选择 20 个样本,并将它们放在一边进行测试。数据集的其余部分将用于微调。

import random
random.seed(17)

# Select 20 random entries
mini_data = random.sample(combined_dataset, 20)

# Build conversational format
parsed_mini_data = [[{'role': 'system', 'content': system_message},
                     {'role': 'user', 'content': e['text']}] for e in mini_data]

# Create the training set
train_data = [item for item in combined_dataset if item not in mini_data]

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

我们将使用 GroqCloud API 和上面定义的实用程序,进行指定,model=llama3-8b-8192而函数的其余部分保持不变。在这种情况下,我们可以直接处理我们的小数据集,而不必担心超出 API 限制。

下面是一个示例输出,提供原始的text、表示为 Llama3-70B 一代gold_re和标记为 的 Llama3-8B h Generation test_re。

{'text': 'Long before any knowledge of electricity existed, people were aware of shocks from electric fish.',
 'gold_re': 'people|were aware of|shocks\nshocks|from|electric fish\nelectric fish|had|electricity',
 'test_re': 'electric fish|were aware of|shocks'}


  • 1
  • 2
  • 3
  • 4
  • 5

有关完整的测试数据集,请参阅Google Colab 笔记本。

仅从这个例子就可以看出,Llama3-8B 可以从其关系提取功能的一些改进中受益。让我们努力加强这一点。

https://github.com/SolanaO/Blogs_Content/blob/master/llama3_re/Llama3_RE_Inference_SFT.ipynb
  • 1

Llama3–8B 的监督微调

我们将利用全套技术来帮助我们,包括 QLoRA 和 Flash Attention。我不会在这里深入探讨选择超参数的细节,但如果您有兴趣进一步探索,请查看这些精彩的参考文献 [ 4 ] 和 [ 5 ]。

A100 GPU支持Flash Attention和bfloat16,并且拥有大约40GB的内存,足以满足我们的微调需求。

准备 SFT 数据集

我们首先将数据集解析为会话格式,包括系统消息、输入文本和所需的答案,这些答案是我们从 Llama3-70B 一代得出的。然后我们将其保存为 HuggingFace 数据集:

def create_conversation(sample):
    return {
        "messages": [
            {"role": "system","content": system_message},
            {"role": "user", "content": sample["text"]},
            {"role": "assistant", "content": sample["gold_re"]}
        ]
    }

from datasets import load_dataset, Dataset

train_dataset = Dataset.from_list(train_data)

# Transform to conversational format
train_dataset = train_dataset.map(create_conversation,
                      remove_columns=train_dataset.features,
                      batched=False)


  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

选择型号

model_id  =  "meta-llama/Meta-Llama-3-8B"
  • 1

加载分词器

from transformers import AutoTokenizer

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id,
                                          use_fast=True,
                                          trust_remote_code=True)

tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id =  tokenizer.eos_token_id
tokenizer.padding_side = 'left'

# Set a maximum length
tokenizer.model_max_length = 512
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

Choose Quantization Parameters

from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

Load the Model

from transformers import AutoModelForCausalLM
from peft import prepare_model_for_kbit_training
from trl import setup_chat_format

device_map = {"": torch.cuda.current_device()} if torch.cuda.is_available() else None

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=device_map,
    attn_implementation="flash_attention_2",
    quantization_config=bnb_config
)

model, tokenizer = setup_chat_format(model, tokenizer)
model = prepare_model_for_kbit_training(model)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

LoRA Configuration

from peft import LoraConfig

# According to Sebastian Raschka findings
peft_config = LoraConfig(
        lora_alpha=128, #32
        lora_dropout=0.05,
        r=256,  #16
        bias="none",
        target_modules=["q_proj", "o_proj", "gate_proj", "up_proj", 
          "down_proj", "k_proj", "v_proj"],
        task_type="CAUSAL_LM",
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

当针对所有线性层时,可以获得最佳结果。如果担心内存限制,选择更标准的值(例如 alpha=32 和rank=16)可能会有所帮助,因为这些设置会导致参数显着减少。

训练参数

from transformers import TrainingArguments

# Adapted from  Phil Schmid blogpost
args = TrainingArguments(
    output_dir=sft_model_path,              # directory to save the model and repository id
    num_train_epochs=2,                     # number of training epochs
    per_device_train_batch_size=4,          # batch size per device during training
    gradient_accumulation_steps=2,          # number of steps before performing a backward/update pass
    gradient_checkpointing=True,            # use gradient checkpointing to save memory, use in distributed training
    optim="adamw_8bit",                     # choose paged_adamw_8bit if not enough memory
    logging_steps=10,                       # log every 10 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    bf16=True,                              # use bfloat16 precision
    tf32=True,                              # use tf32 precision
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",           # use constant learning rate scheduler
    push_to_hub=True,                       # push model to Hugging Face hub
    hub_model_id="llama3-8b-sft-qlora-re",
    report_to="tensorboard",               # report metrics to tensorboard
    )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

如果选择将模型保存在本地,则可以省略最后三个参数。您可能还需要调整per_device_batch_size和gradient_accumulation_steps来防止内存不足 (OOM) 错误。

初始化训练器并训练模型

from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=sft_dataset,
    peft_config=peft_config,
    max_seq_length=512,
    tokenizer=tokenizer,
    packing=False, # True if the dataset is large
    dataset_kwargs={
        "add_special_tokens": False,  # the template adds the special tokens
        "append_concat_token": False, # no need to add additional separator token
    }
)

trainer.train()
trainer.save_model()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

训练(包括模型保存)大约花费了 10 分钟。

让我们清理一下内存,为推理测试做准备。如果您使用内存较少的 GPU 并遇到 CUDA 内存不足 (OOM) 错误,则可能需要重新启动运行时。

import torch
import gc
del model
del tokenizer
gc.collect()
torch.cuda.empty_cache()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

使用 SFT 模型进行推理

在最后一步中,我们将以半精度加载基本模型以及 Peft 适配器。对于此测试,我选择不将模型与适配器合并。

from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline
import torch

# HF model
peft_model_id = "solanaO/llama3-8b-sft-qlora-re"

# Load Model with PEFT adapter
model = AutoPeftModelForCausalLM.from_pretrained(
  peft_model_id,
  device_map="auto",
  torch_dtype=torch.float16,
  offload_buffers=True
)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

接下来,我们加载分词器:

okenizer = AutoTokenizer.from_pretrained(peft_model_id)

tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id =  tokenizer.eos_token_id

  • 1
  • 2
  • 3
  • 4
  • 5

我们构建文本生成管道:

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
  • 1

我们加载测试数据集,其中包含我们之前预留的 20 个样本,并以对话方式格式化数据。不过,这次我们省略了辅助消息,并将其格式化为 Hugging Face 数据集:

def create_input_prompt(sample):
    return {
        "messages": [
            {"role": "system","content": system_message},
            {"role": "user", "content": sample["text"]},
        ]
    }
    
from datasets import Dataset

test_dataset = Dataset.from_list(mini_data)

# Transform to conversational format
test_dataset = test_dataset.map(create_input_prompt,
                      remove_columns=test_dataset.features,
                      batched=False)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

一个样品测试

让我们使用 SFT Llama3–8B 生成关系提取输出,并将其与单个实例上的前两个输出进行比较:

 Generate the input prompt
prompt = pipe.tokenizer.apply_chat_template(test_dataset[2]["messages"][:2],
                                            tokenize=False,
                                            add_generation_prompt=True)
# Generate the output
outputs = pipe(prompt,
              max_new_tokens=128,
              do_sample=False,
              temperature=0.1,
              top_k=50,
              top_p=0.1,
              )
# Display the results
print(f"Question: {test_dataset[2]['messages'][1]['content']}\n")
print(f"Gold-RE: {test_sampler[2]['gold_re']}\n")
print(f"LLama3-8B-RE: {test_sampler[2]['test_re']}\n")
print(f"SFT-Llama3-8B-RE: {outputs[0]['generated_text'][len(prompt):].strip()}")

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

我们得到以下信息:

Question: Long before any knowledge of electricity existed, people were aware of shocks from electric fish.

Gold-RE: people|were aware of|shocks
    shocks|from|electric fish
    electric fish|had|electricity

LLama3-8B-RE: electric fish|were aware of|shocks

SFT-Llama3-8B-RE: people|were aware of|shocks
         shocks|from|electric fish
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

在这个例子中,我们观察到通过微调,Llama3-8B 的关系提取能力有了显着的提高。尽管微调数据集既不是很干净也不是特别大,但结果令人印象深刻。

有关 20 个样本数据集的完整结果,请参阅Google Colab 笔记本。请注意,推理测试需要更长的时间,因为我们以半精度加载模型。

结论

总之,通过利用 Llama3-70B 和可用的数据集,我们成功创建了一个合成数据集,然后用于针对特定任务微调 Llama3-8B。这个过程不仅让我们熟悉了 Llama3,还让我们能够应用 Hugging Face 中的简单技术。我们观察到,使用 Llama3 的体验与使用 Llama2 的体验非常相似,显着的改进是增强的输出质量和更有效的分词器。

对于那些有兴趣进一步突破界限的人,可以考虑使用更复杂的任务来挑战模型,例如对实体和关系进行分类,并使用这些分类来构建知识图。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/563253
推荐阅读
相关标签
  

闽ICP备14008679号