当前位置:   article > 正文

使用 Llama3 模型进行关系提取

使用 Llama3 模型进行关系提取

原文地址:relation-extraction-with-llama3-models

通过使用 Llama3–70B 创建的合成数据集微调 Llama3–8B 来增强关系提取

2024 年 4 月 26 日

介绍

关系提取(RE)是从非结构化文本中提取关系以识别各种命名实体之间联系的任务。它与命名实体识别(NER)一起进行,是自然语言处理管道中必不可少的一步。随着大语言模型(LLM)的兴起,传统的监督方法(包括标记实体跨度和分类它们之间的关系(如果有的话))得到了加强,或完全被基于 LLM 的方法所取代。

Llama3 是生成式人工智能(GenerativeAI)领域的最新主要版本。基础模型有 8B 和 70B 两种大小,预计不久将发布 400B 模型。这些模型可在 HuggingFace 平台上使用;70B 型号为 Meta 的新聊天网站 Meta.ai 提供动力,其性能可与 ChatGPT 相媲美。8B 模型是同类产品中性能最好的。Llama3 的架构与 Llama2 相似,性能的提升主要归功于数据升级。该模型配备了升级的标记器和扩展的上下文窗口。虽然只发布了一小部分数据,但它被标为开源。

Llama3-70B 可以产生令人惊叹的结果,但由于其体积庞大,在本地系统中不实用、过于昂贵且难以使用。因此,为了充分利用它的能力,我们让 Llama3-70B 教较小的 Llama3-8B 从非结构化文本中提取关系。

具体来说,在 Llama3-70B 的帮助下,我们建立了一个针对关系提取的监督微调数据集。然后,我们利用该数据集对 Llama3-8B 进行微调,以增强其关系提取能力。

执行

在这个项目中,我使用了配备 A100 GPU 和高内存设置的 Google Colab Pro。

我们首先安装所有需要的库:

  1. !pip install -q groq
  2. !pip install -U accelerate bitsandbytes datasets evaluate 
  3. !pip install -U peft transformers trl 

我很高兴地注意到,尽管模型很新颖,但整个设置从一开始就能正常工作,没有任何依赖性问题,也不需要从源代码中安装转换器。

我们还需要让 Goggle Colab 能够访问驱动器和文件,并设置工作目录:

  1. # For Google Colab settings
  2. from google.colab import userdata, drive
  3. # This will prompt for authorization
  4. drive.mount('/content/drive')
  5. # Set the working directory
  6. %cd '/content/drive/MyDrive/postedBlogs/llama3RE'

对于希望将模型上传到 HuggingFace Hub 的用户,我们需要上传 Hub 凭据。在我的情况下,这些凭证存储在 Google Colab secrets 中,可以通过左侧的按键按钮访问。这一步是可选的。

  1. # For Hugging Face Hub setting
  2. from huggingface_hub import login
  3. # Upload the HuggingFace token (should have WRITE access) from Colab secrets
  4. HF = userdata.get('HF')
  5. # This is needed to upload the model to HuggingFace
  6. login(token=HF,add_to_git_credential=True)

我还添加了一些路径变量,以简化文件访问:

  1. # Create a path variable for the data folder
  2. data_path = '/content/drive/MyDrive/postedBlogs/llama3RE/datas/'
  3. # Full fine-tuning dataset
  4. sft_dataset_file = f'{data_path}sft_train_data.json'
  5. # Data collected from the the mini-test
  6. mini_data_path = f'{data_path}mini_data.json'
  7. # Test data containing all three outputs
  8. all_tests_data = f'{data_path}all_tests.json'
  9. # The adjusted training dataset
  10. train_data_path = f'{data_path}sft_train_data.json'
  11. # Create a path variable for the SFT model to be saved locally
  12. sft_model_path = '/content/drive/MyDrive/llama3RE/Llama3_RE/'

现在我们的工作区已经建立,可以开始第一步,即为关系提取任务建立一个合成数据集。

用 Llama3-70B 创建一个用于关系提取的合成数据集

目前有多个关系提取数据集,其中最著名的是 CoNLL04 数据集。此外,还有 HuggingFace 上的 web_nlg 和 AllenAI 开发的 SciREX 等优秀数据集。不过,这些数据集大多有限制性许可证。

受 web_nlg 数据集格式的启发,我们将建立自己的数据集。如果我们计划对在我们的数据集上训练的模型进行微调,这种方法将特别有用。首先,我们需要一个短句集来完成关系提取任务。我们可以通过多种方式来编译这个语料库。

收集句子集

我们将使用 databricks-dolly-15k 这个由 Databricks 员工于 2023 年生成的开源数据集。这个数据集是为监督微调而设计的,包含四个特征:指令、语境、反应和类别。在分析了八个类别后,我决定保留信息提取类别中上下文的第一句话。数据解析步骤概述如下:

  1. from datasets import load_dataset
  2. # Load the dataset
  3. dataset = load_dataset("databricks/databricks-dolly-15k")
  4. # Choose the desired category from the dataset
  5. ie_category = [e for e in dataset["train"] if e["category"]=="information_extraction"]
  6. # Retain only the context from each instance
  7. ie_context = [e["context"] for e in ie_category]
  8. # Split the text into sentences (at the period) and keep the first sentence
  9. reduced_context = [text.split('.')[0] + '.' for text in ie_context]
  10. # Retain sequences of specified lengths only (use character length)
  11. sampler = [e for e in reduced_context if 30 < len(e) < 170]

筛选过程产生了一个包含 1,041 个句子的数据集。鉴于这是一个小型项目,我没有亲自挑选句子,因此有些样本可能并不非常适合我们的任务。在制作项目中,我会仔细挑选最合适的句子。不过,就本项目而言,这个数据集已经足够了。

格式化数据

我们首先需要创建一条系统消息,用于定义输入提示并指示模型如何生成答案:

  1. system_message = """You are an experienced annontator. """You are an experienced annontator. 
  2. Extract all entities and the relations between them from the following text. 
  3. Write the answer as a triple entity1|relationship|entitity2. 
  4. Do not add anything else.
  5. Example Text: Alice is from France.
  6. Answer: Alice|is from|France.
  7. """

由于现在是实验阶段,我对模型的要求降到了最低。我确实测试了其他几个提示,包括一些要求以 CoNLL 格式输出实体分类的提示,模型的表现相当不错。不过,为了简单起见,我们现在还是坚持最基本的要求。

我们还需要将数据转换为对话格式:

  1. messages = [[
  2.     {"role": "system","content": f"{system_message}"},"role": "system","content": f"{system_message}"},
  3.     {"role": "user", "content": e}] for e in sampler]

Groq 客户端和API

Llama3 几天前刚刚发布,API 的可用性仍然有限。虽然 Llama3-70B 有聊天界面,但本项目需要的 API 只需几行代码就能处理我的 1000 句话。

提醒一下:你需要登录 GroqCloud 网站并获取免费 API 密钥。我的 API 密钥已保存在 Google Colab secrets 中。我们首先初始化 Groq 客户端:

  1. import os
  2. from groq import Groq
  3. gclient = Groq(
  4.     api_key=userdata.get("GROQ"),
  5. )

接下来,我们需要定义几个辅助函数,以便与 Meta.ai 聊天界面进行有效交互:

  1. import time
  2. from tqdm import tqdm
  3. def process_data(prompt):
  4.     """Send one request and retrieve model's generation."""
  5.     chat_completion = gclient.chat.completions.create(
  6.         messages=prompt, # input prompt to send to the model
  7.         model="llama3-70b-8192", # according to GroqCloud labeling
  8.         temperature=0.5, # controls diversity
  9.         max_tokens=128, # max number tokens to generate
  10.         top_p=1, # proportion of likelihood weighted options to consider
  11.         stop=None, # string that signals to stop generating
  12.         stream=False, # if set partial messages are sent
  13.     )
  14.     return chat_completion.choices[0].message.content
  15. def send_messages(messages):
  16.     """Process messages in batches with a pause between batches."""
  17.    
  18.    batch_size = 10
  19.     answers = []
  20.     for i in tqdm(range(0, len(messages), batch_size)): # batches of size 10
  21.         batch = messages[i:i+10]  # get the next batch of messages
  22.         for message in batch:
  23.             output = process_data(message)
  24.             answers.append(output)
  25.         if i + 10 < len(messages):  # check if there are batches left
  26.             time.sleep(10)  # wait for 10 seconds
  27.     return answers

第一个函数 process_data() 是 Groq 客户端聊天完成函数的封装程序。第二个函数 send_messages()会分批处理数据。如果你点击 Groq playground 页面上的 “设置 ”链接,就会找到 “限制 ”链接,其中详细说明了我们可以使用免费 API 的条件,包括请求数和生成令牌数的上限。为了避免超过这些限制,我在每批 10 条信息后添加了 10 秒延迟,尽管在我的情况下这并不是绝对必要的。你可以尝试使用这些设置。

现在剩下的工作就是生成关系提取数据,并将其与初始数据集整合:

  1. # Data generation with Llama3-70B
  2. answers = send_messages(messages)
  3. # Combine input data with the generated dataset
  4. combined_dataset = [{'text': user, 'gold_re': output} for user, output in zip(sampler, answers)]

评估用于关系提取的 Llama3-8B

在对模型进行微调之前,重要的是要评估其在多个样本上的性能,以确定是否真的有必要进行微调。

建立测试数据集

我们将从刚刚构建的数据集中选取 20 个样本进行测试。数据集的其余部分将用于微调。

  1. import random
  2. random.seed(17)
  3. # Select 20 random entries
  4. mini_data = random.sample(combined_dataset, 20)
  5. # Build conversational format
  6. parsed_mini_data = [[{'role': 'system', 'content': system_message},
  7.                      {'role': 'user', 'content': e['text']}] for e in mini_data]
  8. # Create the training set
  9. train_data = [item for item in combined_dataset if item not in mini_data]

我们将使用 GroqCloud API 和上文定义的实用程序,指定 model=llama3-8b-8192,而函数的其他部分保持不变。在这种情况下,我们可以直接处理我们的小型数据集,而不必担心超出 API 的限制。

下面是一个输出示例,提供了原始文本、以 gold_re 表示的 Llama3-70B 生成和以 test_re 表示的 Llama3-8B hgeneration。

  1. {'text': 'Long before any knowledge of electricity existed, people were aware of shocks from electric fish.','text': 'Long before any knowledge of electricity existed, people were aware of shocks from electric fish.',
  2.  'gold_re': 'people|were aware of|shocks\nshocks|from|electric fish\nelectric fish|had|electricity',
  3.  'test_re': 'electric fish|were aware of|shocks'}

从这个例子中我们可以清楚地看到,Llama3-8B 的关系抽取能力还需要进一步改进。让我们努力提高它的能力。

对 Llama3-8B 进行有监督的微调

A100 GPU 支持 Flash Attention 和 bfloat16,拥有约 40GB 的内存,足以满足我们的微调需求。

准备 SFT 数据集

我们首先将数据集解析为对话格式,包括系统消息、输入文本和所需答案,这些都是我们从 Llama3-70B 生成的。然后,我们将其保存为 HuggingFace 数据集:

  1. def create_conversation(sample):
  2.     return {
  3.         "messages": [
  4.             {"role": "system","content": system_message},
  5.             {"role": "user", "content": sample["text"]},
  6.             {"role": "assistant", "content": sample["gold_re"]}
  7.         ]
  8.     }
  9. from datasets import load_dataset, Dataset
  10. train_dataset = Dataset.from_list(train_data)
  11. # Transform to conversational format
  12. train_dataset = train_dataset.map(create_conversation,
  13.                       remove_columns=train_dataset.features,
  14.                       batched=False)

选择型号

model_id  =  "meta-llama/Meta-Llama-3-8B""meta-llama/Meta-Llama-3-8B"

加载标记符

  1. from transformers import AutoTokenizer
  2. # Tokenizer
  3. tokenizer = AutoTokenizer.from_pretrained(model_id,
  4.                                           use_fast=True,
  5.                                           trust_remote_code=True)
  6. tokenizer.pad_token = tokenizer.eos_token
  7. tokenizer.pad_token_id =  tokenizer.eos_token_id
  8. tokenizer.padding_side = 'left'
  9. # Set a maximum length
  10. tokenizer.model_max_length = 512

选择量化参数

  1. from transformers import BitsAndBytesConfig
  2. bnb_config = BitsAndBytesConfig(
  3.     load_in_4bit=True,
  4.     bnb_4bit_use_double_quant=True,
  5.     bnb_4bit_quant_type="nf4",
  6.     bnb_4bit_compute_dtype=torch.bfloat16
  7. )

加载模型

  1. from transformers import AutoModelForCausalLM
  2. from peft import prepare_model_for_kbit_training
  3. from trl import setup_chat_format
  4. device_map = {"": torch.cuda.current_device()} if torch.cuda.is_available() else None
  5. model = AutoModelForCausalLM.from_pretrained(
  6.     model_id,
  7.     device_map=device_map,
  8.     attn_implementation="flash_attention_2",
  9.     quantization_config=bnb_config
  10. )
  11. model, tokenizer = setup_chat_format(model, tokenizer)
  12. model = prepare_model_for_kbit_training(model)

LoRA 配置

  1. from peft import LoraConfig
  2. # According to Sebastian Raschka findings
  3. peft_config = LoraConfig(
  4.         lora_alpha=128, #32
  5.         lora_dropout=0.05,
  6.         r=256,  #16
  7.         bias="none",
  8.         target_modules=["q_proj", "o_proj", "gate_proj", "up_proj"
  9.           "down_proj", "k_proj", "v_proj"],
  10.         task_type="CAUSAL_LM",
  11. )

当针对所有线性层时,可以获得最佳效果。如果考虑到内存限制,可以选择更标准的值,如 alpha=32 和 rank=16,因为这些设置会大大减少参数。

训练参数

  1. from transformers import TrainingArguments
  2. # Adapted from  Phil Schmid blogpost
  3. args = TrainingArguments(
  4.     output_dir=sft_model_path,              # directory to save the model and repository id
  5.     num_train_epochs=2,                     # number of training epochs
  6.     per_device_train_batch_size=4,          # batch size per device during training
  7.     gradient_accumulation_steps=2,          # number of steps before performing a backward/update pass
  8.     gradient_checkpointing=True,            # use gradient checkpointing to save memory, use in distributed training
  9.     optim="adamw_8bit",                     # choose paged_adamw_8bit if not enough memory
  10.     logging_steps=10,                       # log every 10 steps
  11.     save_strategy="epoch",                  # save checkpoint every epoch
  12.     learning_rate=2e-4,                     # learning rate, based on QLoRA paper
  13.     bf16=True,                              # use bfloat16 precision
  14.     tf32=True,                              # use tf32 precision
  15.     max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
  16.     warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
  17.     lr_scheduler_type="constant",           # use constant learning rate scheduler
  18.     push_to_hub=True,                       # push model to Hugging Face hub
  19.     hub_model_id="llama3-8b-sft-qlora-re",
  20.     report_to="tensorboard",               # report metrics to tensorboard
  21.     )

如果选择本地保存模型,则可以省略后三个参数。你可能还需要调整 per_device_batch_size 和 gradient_accumulation_steps 以防止出现内存不足 (OOM) 错误。

初始化训练器并训练模型

  1. from trl import SFTTrainer
  2. trainer = SFTTrainer(
  3.     model=model,
  4.     args=args,
  5.     train_dataset=sft_dataset,
  6.     peft_config=peft_config,
  7.     max_seq_length=512,
  8.     tokenizer=tokenizer,
  9.     packing=False, # True if the dataset is large
  10.     dataset_kwargs={
  11.         "add_special_tokens": False,  # the template adds the special tokens
  12.         "append_concat_token": False, # no need to add additional separator token
  13.     }
  14. )
  15. trainer.train()
  16. trainer.save_model()

包括保存模型在内的训练耗时约 10 分钟。

让我们清空内存,为推理测试做好准备。如果你使用的 GPU 内存较少,并且遇到 CUDA 内存不足(OOM)错误,你可能需要重新启动运行时。

  1. import torch
  2. import gc
  3. del model
  4. del tokenizer
  5. gc.collect()
  6. torch.cuda.empty_cache()

使用 SFT 模型进行推理

在最后一步,我们将以半精度加载基础模型和 Peft 适配器。在本次测试中,我选择不将模型与适配器合并。

  1. from peft import AutoPeftModelForCausalLM
  2. from transformers import AutoTokenizer, pipeline
  3. import torch
  4. # HF model
  5. peft_model_id = "solanaO/llama3-8b-sft-qlora-re"
  6. # Load Model with PEFT adapter
  7. model = AutoPeftModelForCausalLM.from_pretrained(
  8.   peft_model_id,
  9.   device_map="auto",
  10.   torch_dtype=torch.float16,
  11.   offload_buffers=True
  12. )

接下来,我们加载标记符:

  1. okenizer = AutoTokenizer.from_pretrained(peft_model_id)
  2. tokenizer.pad_token = tokenizer.eos_token
  3. tokenizer.pad_token_id =  tokenizer.eos_token_id

然后我们建立文本生成管道:

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)"text-generation", model=model, tokenizer=tokenizer)

我们加载的测试数据集由之前预留的 20 个样本组成,数据格式为对话式。不过,这次我们省略了助手信息,将其格式化为 “Hugging Face ”数据集:

  1. def create_input_prompt(sample):
  2.     return {
  3.         "messages": [
  4.             {"role": "system","content": system_message},
  5.             {"role": "user", "content": sample["text"]},
  6.         ]
  7.     }
  8.     
  9. from datasets import Dataset
  10. test_dataset = Dataset.from_list(mini_data)
  11. # Transform to conversational format
  12. test_dataset = test_dataset.map(create_input_prompt,
  13.                       remove_columns=test_dataset.features,
  14.                       batched=False)

一个测试样本

让我们使用 SFT Llama3-8B 生成关系提取输出,并在单个实例上将其与前两个输出进行比较:

  1.  Generate the input prompt
  2. prompt = pipe.tokenizer.apply_chat_template(test_dataset[2]["messages"][:2],
  3.                                             tokenize=False,
  4.                                             add_generation_prompt=True)
  5. # Generate the output
  6. outputs = pipe(prompt,
  7.               max_new_tokens=128,
  8.               do_sample=False,
  9.               temperature=0.1,
  10.               top_k=50,
  11.               top_p=0.1,
  12.               )
  13. # Display the results
  14. print(f"Question: {test_dataset[2]['messages'][1]['content']}\n")
  15. print(f"Gold-RE: {test_sampler[2]['gold_re']}\n")
  16. print(f"LLama3-8B-RE: {test_sampler[2]['test_re']}\n")
  17. print(f"SFT-Llama3-8B-RE: {outputs[0]['generated_text'][len(prompt):].strip()}")

结果如下

  1. Question: Long before any knowledge of electricity existed, people were aware of shocks from electric fish.
  2. Gold-RE: people|were aware of|shocks
  3.     shocks|from|electric fish
  4.     electric fish|had|electricity
  5. LLama3-8B-RE: electric fish|were aware of|shocks
  6. SFT-Llama3-8B-RE: people|were aware of|shocks
  7.          shocks|from|electric fish

在这个例子中,我们观察到通过微调,Llama3-8B 的关系提取能力有了显著提高。尽管微调后的数据集既不是非常干净,也不是特别大,但结果却令人印象深刻。

结论

总之,通过利用 Llama3-70B 和可用数据集,我们成功创建了一个合成数据集,然后利用该数据集针对特定任务对 Llama3-8B 进行微调。这一过程不仅让我们熟悉了 Llama3,还让我们能够直接应用 “Hugging Face ”中的技术。据我们观察,使用 Llama3 的过程与使用 Llama2 的过程非常相似,显著的改进是提高了输出质量和令牌生成器的效率。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/533293
推荐阅读
相关标签
  

闽ICP备14008679号