It's my turn to present this time.
Fingers crossed that tomorrow's talk goes smoothly and that I can answer every question.
Talk plus Q&A: one hour.
During the discussion, though, I felt we gradually dug our way down to the core idea:
My earlier understanding: leverage the knowledge encoded in code-LLMs to enrich structured code information.
My final understanding: if you can set things up so that the large model has some grasp of the format of its own output, results improve. This paper achieves that through code structure and prompts; in principle plain text could do the same.
CODEIE: the work showing that large code generation models are better few-shot information extractors.
I will introduce the work in the following four parts.
The paper was posted on arXiv in May 2023 and later published at ACL, a top NLP conference. Its authors are from Fudan University and East China Normal University.
First, some background on natural language processing (NLP).
Information extraction (IE) aims to extract structured information from unstructured text. The field covers a range of tasks, including named entity recognition (NER) and relation extraction (RE).
NER identifies named entities in text, such as person, location, and organization names, while RE extracts the relations between entities in the text.
IE tasks can be solved with sequence generation models. Large language models have strong few-shot learning ability, and this paper aims to make full use of them for few-shot IE, NER and RE in particular. With NL-LLMs, the linearized outputs usually require complex decoding strategies to post-process them into valid structures. To summarize why the authors chose this approach:
Table 1 summarizes the high-level differences between moderate-sized models, NL-LLMs, and Code-LLMs on IE tasks.
Previous work did not fully exploit large models for few-shot learning within a unified framework, especially for structured tasks. The model in this paper successfully addresses both limitations.
Why do these two limitations have such an impact on the results?
First, large models have the ability to adapt to few-shot data.
Second, because the linearized output is not natural to the model, more complex decoding strategies are usually required.
This paper therefore proposes a new idea: use a Code-LLM with structured, code-style prompts to bridge the gap between the outputs seen in pre-training and those required at inference, which yields a unified framework for IE tasks and better results.
The core idea is to frame the two IE tasks as code generation tasks and to exploit the structured code knowledge encoded in Code-LLMs, so that these IE tasks achieve better results.
Here is an example:
With a code-style prompt and Python dictionary keys such as "text" and "type", an NER example can be composed into an equivalent Python function, as in the sketch below.
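A minimal sketch of such a code-style NER prompt, following the example sentence used in the paper (the function and variable names here are illustrative and may differ from the paper's exact prompts):

def named_entity_recognition(input_text):
    """ extract named entities from the input_text . """
    input_text = "Steve became CEO of Apple in 1998 ."
    entity_list = []
    # extracted named entities
    entity_list.append({"text": "Steve", "type": "person"})
    entity_list.append({"text": "Apple", "type": "organization"})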
Next, let's walk through the concrete approach.
The paper covers the two tasks NER and RE.
The original IE task is first converted into code style: (next slide) the Python function name denotes the task, the docstring states the task goal, an empty list is initialized to hold the output, and a descriptive comment prompts the model to put the named entities into that list. (next slide) These elements are combined into the "code prompt x".
For the "structured target y", each basic information unit (a text/type pair for NER, a relation triple for RE) is represented as a Python dictionary, and the target is the list of these dictionaries.
For NER, the keys are "text" and "type", whose values are the entity span and the entity type.
For RE, the keys additionally cover the entity types and the relation type (see the sketch below). These inputs are then fed to the code generation LLM, which produces the output.
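For the RE case, a corresponding sketch looks like this (the dictionary keys are illustrative; as described above, they cover the entity texts, the entity types, and the relation type):

def relation_extraction(input_text):
    """ extract the relations of named entities from the input_text . """
    input_text = "Steve became CEO of Apple in 1998 ."
    relation_list = []
    # extracted relations
    relation_list.append({"rel_type": "work for",
                          "ent1_text": "Steve", "ent1_type": "person",
                          "ent2_text": "Apple", "ent2_type": "organization"})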
The text on the left is a concise description of the preceding figure.
Also, GPT-3 and Codex are both OpenAI models, and Codex was built on top of GPT-3, so the two share the same origin.
Because the large-model APIs are black boxes, these models cannot be fine-tuned, so the paper explores in-context learning using a handful of labeled examples.
How exactly is in-context prompt learning realized?
In this paper, each instance is converted into its code representation, and the demonstrations are concatenated into an in-context prompt: x1 y1, x2 y2, ..., xn yn, followed by the test input xc to be completed. This context is fed to the model, which generates yc in the same format as y1 ... yn; the output usually preserves Python syntax and is easy to convert back to the original structure. A minimal sketch of this assembly follows.
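A minimal sketch of how such an in-context prompt is assembled (illustrative only; the repository code shown later builds the prompt by concatenating pre-formatted demonstration strings):

def build_incontext_prompt(demos, x_c):
    # demos: list of (x_i, y_i) code-style strings; x_c: code prompt of the test instance
    parts = [x_i + y_i for x_i, y_i in demos]
    parts.append(x_c)  # the model is expected to continue with y_c in the same format
    return "\n".join(parts)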
Since few-shot learning is prone to high variance, the paper runs each experiment three times with different random seeds and reports the mean and standard deviation of the metrics.
The code for this work is open source; anyone interested can take a look.
Next, let's go through the experimental part of the paper.
The experiments cover seven datasets and use moderate-sized pre-trained models as baselines.
The evaluation metrics are the standard performance metrics for NER and RE.
An entity span prediction is correct if its offsets and entity type match a gold entity.
A relation prediction is correct if the relation type is correct and the offsets and types of its argument entities are also correct. A small sketch of these criteria follows.
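A small sketch of these matching criteria (the function and key names are illustrative, not taken from the paper's code):

def entity_correct(pred, gold):
    # an entity span is correct if both its offsets and its type match a gold entity
    return pred["offset"] == gold["offset"] and pred["type"] == gold["type"]

def relation_correct(pred, gold):
    # a relation is correct if the relation type and both argument entities (offsets and types) match
    return (pred["rel_type"] == gold["rel_type"]
            and entity_correct(pred["head"], gold["head"])
            and entity_correct(pred["tail"], gold["tail"]))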
The results show:
(Table 4) The highlighted rows are the "text" and "code" prompt types; the code prompt works better.
(Figure 3) Codex outperforms GPT-3, and code prompts outperform text prompts.
Moreover, in the 1-shot setting CODEIE improves performance over the baselines by more than 60%, demonstrating strong few-shot learning ability.
Notably, code prompts also benefit GPT-3, even though it was not specifically trained on code data.
The authors ran several controlled comparison experiments to explore what makes the model perform better.
A quick introduction to conditional perplexity: it measures how predictable the generated text is under a given condition, i.e., how accurately the model predicts the next token given the context prefix.
A lower conditional perplexity means the generated text fits the expected condition better.
Figure 4 reports the conditional perplexity across the 7 datasets for the different input formats (text prompt vs. code prompt) and models.
The analysis uses two notions of faithfulness:
1. Structure fidelity: as the name suggests, whether the generated text has the correct structure. Figure 5 compares the structural error rate for the different combinations of prompt and LLM, i.e., how often the output has the wrong form.
2. Semantic fidelity: whether the generated text is semantically faithful. Table 5 shows semantic-error samples detected in the experiments, i.e., outputs with wrong semantics, such as entity types that do not exist among the predefined entity types.
The results show that GPT-3 tends to produce free-form output, whereas Codex is more faithful to the demonstrations provided in the context and is therefore more predictable for IE tasks.
They also show that (a) code prompts improve the model's precision and recall; and
(b) compared with GPT-3, Codex achieves higher recall with comparable precision on NER, and higher precision and recall on RE.
Finally, a summary of the paper.
Compared with other work at top venues, the method proposed here is simpler and more effective. Through domain transfer it recasts text generation as code generation, and it uses in-context prompt learning in place of fine-tuning for large models that only expose an API.
Future directions include:
Designing better code-format prompts.
The experiments are currently run on the black-box models GPT-3 and Codex; open-source models could be fine-tuned in follow-up work.
Exploring the method's practicality on non-English datasets, such as Chinese.
That is the main content and the key points of "CODEIE: Large Code Generation Models are Better Few-Shot Information Extractors". Thank you for listening.
Conditional perplexity is a measure of how difficult it is for a model to generate the next token given a context, and thus of generation quality; the model's goal is to keep conditional perplexity as low as possible in order to produce higher-quality output.
Both for language models and for code models, conditional perplexity is computed from perplexity, with minor differences depending on the model type and the application domain.
1. Conditional perplexity of a language model:
It evaluates how well the model predicts the next word given the context, and is usually computed as follows:
The formula is:

Perplexity = \exp\left(-\frac{1}{N} \sum_{i=1}^{N} \log P(w_i \mid w_1, w_2, \ldots, w_{i-1})\right)

where N is the number of tokens and P(w_i \mid w_1, w_2, \ldots, w_{i-1}) is the probability the model assigns to the i-th token given the preceding tokens.
2. Conditional perplexity of a code model:
For code generation models, conditional perplexity is computed in the same way as for language models, with a few differences: the exact implementation may vary with the task and the model architecture, but the basic principle is the same. A minimal sketch of the computation follows.
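A minimal sketch of the computation, assuming per-token log probabilities are available (for example from an API's token logprobs):

import math

def conditional_perplexity(log_probs):
    # log_probs[i] = log P(w_i | w_1, ..., w_{i-1}) for each generated token
    return math.exp(-sum(log_probs) / len(log_probs))

# lower value = the continuation is more predictable given the prefix
print(conditional_perplexity([-0.1, -0.5, -0.3]))  # ~1.35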
https://github.com/artpli/CodeIE
Our code is mainly adapted from the UIE and CoCoGen repositories.
We have updated the initial version of the source files. More details on the data processing and the code will be provided in later updates. Notice:
The bad news is that the Codex models have now been deprecated by OpenAI, which significantly affects reproducing this paper. Possible workarounds include applying for OpenAI's Researcher Access Program or accessing Codex through the Azure OpenAI Service.
Since we use OpenAI's closed-source APIs, we do not know the technical details behind them, such as the specific pre-training corpora. There may therefore be potential data-contamination issues when evaluating our paper. If possible, we will evaluate our method on more open-source models.
Codex is deprecated, but OpenAI recommends that all users switch from Codex to GPT-3.5 Turbo, which can handle coding tasks as well as flexible natural-language tasks. A minimal sketch of that migration follows.
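A minimal sketch of that migration, assuming the legacy (pre-1.0) openai Python SDK that the rest of this repository uses; the prompt itself would still be one of the code-style prompts described above:

import openai

code_style_prompt = 'def named_entity_extraction(input_text):\n\t""" extract named entities from the input_text . """\n'

# chat-based replacement for the deprecated Codex completion call (illustrative)
response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": code_style_prompt}],
    temperature=0.0,
    max_tokens=280,
)
generated_text = response["choices"][0]["message"]["content"]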
import os
from typing import Dict, Any

import openai

from src.prompt.constants import END, END_LINE

# the API key is read from the environment and used to authenticate all requests
openai.api_key = os.getenv("OPENAI_API_KEY")


class OpenaiAPIWrapper:
    @staticmethod
    def call(prompt: str, max_tokens: int, engine: str) -> dict:
        # send a completion request with greedy decoding (temperature 0)
        response = openai.Completion.create(
            engine=engine,
            prompt=prompt,
            temperature=0.0,
            max_tokens=max_tokens,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            stop=[END, END_LINE],
            # logprobs=3,
            best_of=1
        )
        return response

    @staticmethod
    def parse_response(response) -> Dict[str, Any]:
        # extract the generated text from the first (and only) completion choice
        text = response["choices"][0]["text"]
        return text
This is a Python script containing a class named OpenaiAPIWrapper, which wraps the interaction with OpenAI's completion API (GPT-3/Codex). Its purpose is to request a completion for a given prompt and to parse the generated text out of the response.

Concretely, the code does the following:

It imports the required modules and libraries: os, openai, type hints from typing, and some project-internal modules.

It sets the OpenAI API key, looked up from an environment variable; the key authenticates the requests and authorizes access to OpenAI's service.

It defines the OpenaiAPIWrapper class with two static methods.

The call method takes three parameters: prompt (the input text used to prompt the model), max_tokens (the maximum number of tokens to generate), and engine (which OpenAI engine to use). It sends a request through the OpenAI API with these parameters and returns the full API response.

The parse_response method takes a response object, extracts the generated text from it, and returns that text as a string.

In short, this snippet exists so the rest of the code can call the OpenAI API with a prompt and get back the generated text for further processing. A usage sketch follows.
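A usage sketch of the wrapper above (it assumes OPENAI_API_KEY is set and that a completion engine such as code-davinci-002 is still accessible):

response = OpenaiAPIWrapper.call(
    prompt='def named_entity_extraction(input_text):\n',
    max_tokens=64,
    engine="code-davinci-002",
)
print(OpenaiAPIWrapper.parse_response(response))  # the generated continuation as a string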
Given a prompt file and a task file with the fields input_prompt (the code used to prompt Codex), reference_code (the expected completed code), and reference_graph (the expected graph), the script runs Codex inference for each input_prompt and adds the fields generated_code (the generated code) and generated_graph (the generated graph) to the output file.
The file can contain other metadata, but the fields above are required.
"""
Given a prompt file and path to a task file with the following fields:
1. input_prompt: the code used to prompt codex
2. reference_code: expected completed code
3. reference_graph: expected graph
Runs inference over codex for each input_prompt, and adds the following fields to the output file:
4. generated_code: generated code
5. generated_graph: generated graph
The file can contain other metadata, but the fields above are required.
"""
import os
import sys
sys.path.append(os.getcwd())
from datetime import datetime
import shutil
import time
import openai
import pandas as pd
from tqdm import tqdm
import logging
import os
import pickle
from src.converters.structure_converter import StructureConverter
from src.converters.get_converter import ConverterFactory
from openai_api_wrapper import OpenaiAPIWrapper
from src.prompt.constants import END
from src.utils.file_utils import load_yaml,load_schema
logging.basicConfig(level=logging.INFO)
def run(task_file_path: str,
num_tasks: int,
start_idx: int,
output_file_path: str,
prompt_path: str,
keep_writing_output: bool,
engine: str,
max_tokens:int,
max_requests_per_min: int,
schema_path:str,
map_config_path:str,
start_cut_num:int):
tasks = pd.read_json(task_file_path, orient='records', lines=True)
converter = ConverterFactory.get_converter(args.job_type,schema_folder=schema_path, map_config_path=map_config_path)
if num_tasks != -1:
tasks = tasks.iloc[start_idx: start_idx+num_tasks]
fixed_prompt_text = read_prompt(prompt_path)
results = []
cache = load_cache(output_file_path)
num_requests = 0
time_begin = time.time()
failed_list = []
max_failed_time = 10
max_failed_taskes = 10
for task_idx, task in tqdm(tasks.iterrows(), total=len(tasks)):
is_success = False
tmp_failed_time = 0
while is_success is False and tmp_failed_time < max_failed_time:
cut_prompt_examples_list = [None, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
cut_prompt_examples_list = cut_prompt_examples_list[start_cut_num:]
for cut_prompt_examples in cut_prompt_examples_list:
try:
num_requests += 1
request_per_minute = maintain_request_per_minute(
num_requests=num_requests, time_begin=time_begin,
max_requests_per_min=max_requests_per_min, task_idx=task_idx)
logging.info("\n")
logging.info(
f"Task {task_idx} > request/minute = {request_per_minute:.2f}")
task_results = run_task(task=task, fixed_prompt_text=fixed_prompt_text,
cache=cache, converter=converter,
cut_prompt_examples=cut_prompt_examples, task_idx=task_idx,
engine=engine, max_tokens=max_tokens)
task_results['id'] = task_idx
results.append(task_results)
is_success = True
break
except openai.error.InvalidRequestError as e:
logging.info(
f"InvalidRequestError: {e}, trying with shorter prompt (cut_prompt_examples={cut_prompt_examples + 1 if cut_prompt_examples is not None else 1})")
# sleep for a bit to further avoid rate limit exceeded exceptions
if cut_prompt_examples != cut_prompt_examples_list[-1]:
time.sleep(5)
continue
else:
tmp_failed_time = max_failed_time
logging.info(f"Failed too many times: {tmp_failed_time}")
except Exception as e: # something else went wrong
logging.info(f"Task {task_idx} failed: {e}")
tmp_failed_time += 1
time.sleep(5 * tmp_failed_time)
logging.info(f"Restart task {task_idx}")
break
if is_success and keep_writing_output:
pd.DataFrame(results).to_json(
output_file_path, orient='records', lines=True)
if is_success == False:
failed_list.append(task_idx)
logging.info(f"Task {task_idx} failed {max_failed_time} times, skipped and recorded.")
if failed_list != []:
print ("failed list:\n", failed_list)
if len(failed_list) > max_failed_taskes:
print ("too many failed taskes. exit().")
exit(0)
print(
f"Ran {len(results)} out of {len(tasks)} tasks ({len(results) / len(tasks):.2%})")
pd.DataFrame(results).to_json(
output_file_path, orient='records', lines=True)
if failed_list != []:
print ("failed list:\n", failed_list)
output_path = output_file_path.rstrip('.jsonl') + '_failed_list.pkl'
with open(output_path,"w") as fout:
pickle.dump(failed_list,fout)
print ("failed list saved into: ", output_path)
def run_task(task: dict, fixed_prompt_text: str, cache: dict,
converter: StructureConverter, task_idx: int, engine: str,
max_tokens: int, cut_prompt_examples: int = None) -> dict:
"""Runs the task, and returns the results.
Args:
task (dict): The task input
fixed_prompt_text (str): Used for cases where the input prompt is fixed
cache (dict): cache of previous results
converter (GraphPythonConverter): A graph-python converter to parse results
cut_prompt_examples (int, optional): If provided, the first `cut_prompt_examples` examples are
deleted. Prevents 4096 errors. Defaults to None.
Returns:
dict: A dictionary with the results.
"""
start_time = time.time()
prompt_text = fixed_prompt_text if fixed_prompt_text is not None else task['prompt']
if cut_prompt_examples is not None:
prompt_text_parts = prompt_text.split(END)
prompt_text = END.join(prompt_text_parts[cut_prompt_examples:])
if task['input_prompt'] in cache:
logging.info(
f"Task {task_idx} > Using cached result for {task['input_prompt']}")
codex_response = cache[task['input_prompt']]["codex_response"]
else:
codex_response = query_codex(task, prompt_text, engine, max_tokens=max_tokens)
completed_code = get_completed_code(task, codex_response)
task_results = {k: v for (k, v) in task.items()}
task_results["codex_response"] = codex_response
task_results["generated_code"] = completed_code
task_results["elapsed_time"] = time.time() - start_time
return task_results
def maintain_request_per_minute(num_requests: int, time_begin: float, max_requests_per_min: int, task_idx: int) -> float:
request_per_minute = get_request_per_minute(num_requests, time_begin)
logging.info("\n")
while request_per_minute > max_requests_per_min:
logging.info(
f"Task {task_idx} > Sleeping! (Requests/minute = {request_per_minute:.2f} > {max_requests_per_min:.2f})")
time.sleep(1)
request_per_minute = get_request_per_minute(
num_requests, time_begin)
return request_per_minute
def read_prompt(prompt_path):
if prompt_path is None:
return None
with open(prompt_path, "r") as f:
prompt = f.read()
return prompt
def load_cache(output_file_path: str):
"""We don't want to query codex repeatedly for the same input. If an output file exists, this
function creates a "cache" of the results.
The cache is implemented as a hashmap keyed by `input_prompt`, and maps to the
entire output entry
Args:
output_file_path (str): _description_
"""
if not os.path.exists(output_file_path):
return {}
else:
# make a backup of the file already there
shutil.copyfile(output_file_path, output_file_path + "_" + datetime.now().strftime("%Y%m%d_%H%M%S"))
shutil.copy(output_file_path, output_file_path + ".bak")
cache_data = pd.read_json(
output_file_path, orient='records', lines=True)
cache = {row['input_prompt']: row.to_dict()
for _, row in cache_data.iterrows()}
return cache
def query_codex(task: dict, prompt_text: str, engine: str, max_tokens: int):
prompt = f"{prompt_text} {task['input_prompt']}"
response = OpenaiAPIWrapper.call(
prompt=prompt, max_tokens=max_tokens, engine=engine)
return response
def get_completed_code(task: dict, codex_response: dict) -> str:
completed_code = OpenaiAPIWrapper.parse_response(codex_response)
all_code = f"{task['input_prompt']}{completed_code}"
# NOTE: space is already taken care of, no need to add it again, otherwise indentation will be off
return all_code
def get_request_per_minute(num_request: int, begin_time: float) -> float:
elapsed_time = time.time() - begin_time
request_per_minute = (num_request / elapsed_time) * 60
return request_per_minute
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--task_file_path", type=str, required=True)
parser.add_argument("--num_tasks", type=int, required=True)
parser.add_argument("--start_idx", type=int, required=True)
parser.add_argument("--output_file_path", type=str, required=True)
parser.add_argument("--prompt_path", type=str,
required=False, default=None)
parser.add_argument("--job_type", type=str, required=True,
choices=ConverterFactory.supported_converters)
parser.add_argument("--keep_writing_output",
action="store_true", default=True)
parser.add_argument("--engine", type=str, required=True)
parser.add_argument("--max_requests_per_min", type=int, default=10)
parser.add_argument("--max_tokens", type=int, default=280)
parser.add_argument("--schema_path", type=str, required=True)
parser.add_argument("--map_config_path", type=str, required=True)
parser.add_argument("--start_cut_num", type=int, default=0)
args = parser.parse_args()
run(task_file_path=args.task_file_path, num_tasks=args.num_tasks,start_idx=args.start_idx,
output_file_path=args.output_file_path, prompt_path=args.prompt_path,
keep_writing_output=args.keep_writing_output, engine=args.engine,
max_requests_per_min=args.max_requests_per_min,
max_tokens=args.max_tokens,schema_path=args.schema_path,
map_config_path=args.map_config_path,start_cut_num=args.start_cut_num)
def run(task_file_path: str,
num_tasks: int,
start_idx: int,
output_file_path: str,
prompt_path: str,
keep_writing_output: bool,
engine: str,
max_tokens:int,
max_requests_per_min: int,
schema_path:str,
map_config_path:str,
start_cut_num:int):
tasks = pd.read_json(task_file_path, orient='records', lines=True)
converter = ConverterFactory.get_converter(args.job_type,schema_folder=schema_path, map_config_path=map_config_path)
This defines a Python function named run, which takes a number of parameters and drives the whole inference job. The parameters are:

task_file_path (str): path to the task file containing the task information.
num_tasks (int): number of tasks to run.
start_idx (int): index of the first task.
output_file_path (str): path of the output file the results are written to.
prompt_path (str): path to the prompt file holding the fixed input prompt used for generation.
keep_writing_output (bool): whether to keep writing the output file after each task.
engine (str): name of the engine (code generation model) used to run the tasks.
max_tokens (int): maximum number of tokens to generate, i.e., a limit on the length of the generated code.
max_requests_per_min (int): maximum number of API requests per minute.
schema_path (str): path to the schema folder describing the task's data schema.
map_config_path (str): path to the mapping configuration file used for data conversion.
start_cut_num (int): how many in-context examples to cut from the prompt to begin with.

Inside the function, the task file is first read into a DataFrame called tasks. Then ConverterFactory.get_converter is used to create a converter object for the given job type (args.job_type), with the given schema folder and mapping configuration file.
if num_tasks != -1:
tasks = tasks.iloc[start_idx: start_idx+num_tasks]
fixed_prompt_text = read_prompt(prompt_path)
results = []
cache = load_cache(output_file_path)
num_requests = 0
time_begin = time.time()
failed_list = []
max_failed_time = 10
max_failed_taskes = 10
In this part, if num_tasks is not -1, the tasks are sliced so that only the subset from start_idx to start_idx + num_tasks is kept.

The prompt text (which contains the in-context examples used for generation) is then read from prompt_path and stored in fixed_prompt_text.

Next, a few variables are initialized:

results: an empty list that will collect the results of each task.
cache: built by loading an existing output file, used to cache previous results.
num_requests: counts how many requests have been sent.
time_begin: records the start time, used to compute the request rate.

In addition, the following are defined:

failed_list: a list holding the indices of failed tasks.
max_failed_time: the maximum number of times a single task may fail before it is skipped.
max_failed_taskes: the maximum number of failed tasks allowed before the run aborts.

These variables are used in the code that follows to run the tasks, collect the results, and handle failures.
for task_idx, task in tqdm(tasks.iterrows(), total=len(tasks)):
is_success = False
tmp_failed_time = 0
while is_success is False and tmp_failed_time < max_failed_time:
cut_prompt_examples_list = [None, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
cut_prompt_examples_list = cut_prompt_examples_list[start_cut_num:]
for cut_prompt_examples in cut_prompt_examples_list:
try:
num_requests += 1
request_per_minute = maintain_request_per_minute(
num_requests=num_requests, time_begin=time_begin,
max_requests_per_min=max_requests_per_min, task_idx=task_idx)
logging.info("\n")
logging.info(
f"Task {task_idx} > request/minute = {request_per_minute:.2f}")
task_results = run_task(task=task, fixed_prompt_text=fixed_prompt_text,
cache=cache, converter=converter,
cut_prompt_examples=cut_prompt_examples, task_idx=task_idx,
engine=engine, max_tokens=max_tokens)
task_results['id'] = task_idx
results.append(task_results)
is_success = True
break
except openai.error.InvalidRequestError as e:
logging.info(
f"InvalidRequestError: {e}, trying with shorter prompt (cut_prompt_examples={cut_prompt_examples + 1 if cut_prompt_examples is not None else 1})")
# sleep for a bit to further avoid rate limit exceeded exceptions
if cut_prompt_examples != cut_prompt_examples_list[-1]:
time.sleep(5)
continue
else:
tmp_failed_time = max_failed_time
logging.info(f"Failed too many times: {tmp_failed_time}")
except Exception as e: # something else went wrong
logging.info(f"Task {task_idx} failed: {e}")
tmp_failed_time += 1
time.sleep(5 * tmp_failed_time)
logging.info(f"Restart task {task_idx}")
break
if is_success and keep_writing_output:
pd.DataFrame(results).to_json(
output_file_path, orient='records', lines=True)
if is_success == False:
failed_list.append(task_idx)
logging.info(f"Task {task_idx} failed {max_failed_time} times, skipped and recorded.")
if failed_list != []:
print ("failed list:\n", failed_list)
if len(failed_list) > max_failed_taskes:
print ("too many failed taskes. exit().")
exit(0)
This is the loop that processes each task in the task list. Its main logic:

It iterates over the tasks with for task_idx, task in tqdm(tasks.iterrows(), total=len(tasks)).

For each task it sets an is_success flag that tracks whether the task has completed, and a tmp_failed_time counter that tracks how many times the task has failed so far.

A while loop keeps retrying as long as is_success is False and tmp_failed_time is below max_failed_time, i.e., until the task succeeds or the maximum number of failures is reached.

In the inner loop, the prompt is progressively shortened according to cut_prompt_examples_list, and each setting is tried in turn. If the task completes, is_success is set to True and the inner loop exits. If it fails, the exception is caught and handled depending on its type: for an InvalidRequestError the prompt is shortened further after a short sleep; for other exceptions the failure counter is incremented, the code sleeps, and the task is restarted.

If is_success is True and keep_writing_output is set, the accumulated results are written to the output file.

If is_success is False, the task could not be completed within the allowed number of retries; its index is appended to failed_list and the failure is logged. If failed_list is not empty, the failed indices are printed, and if the number of failed tasks exceeds max_failed_taskes, the program exits.

In short, this loop runs every task, monitors success and failure, and reacts accordingly: retrying, shortening the prompt, recording failures, or terminating the run. It is the core of task execution and failure handling.
print(
f"Ran {len(results)} out of {len(tasks)} tasks ({len(results) / len(tasks):.2%})")
pd.DataFrame(results).to_json(
output_file_path, orient='records', lines=True)
if failed_list != []:
print ("failed list:\n", failed_list)
output_path = output_file_path.rstrip('.jsonl') + '_failed_list.pkl'
with open(output_path,"w") as fout:
pickle.dump(failed_list,fout)
print ("failed list saved into: ", output_path)
This code runs after all tasks have been processed and performs some wrap-up work:

First, it prints how many of the tasks were run, as a count and as a percentage.

Then it writes the task results to the JSON-lines output file (output_file_path) via pd.DataFrame(results).to_json; this file contains the generation result for each task.

Next, if there are failed tasks (failed_list is not empty), it prints their indices and saves the list to a .pkl file. The file name is derived from output_file_path by stripping the .jsonl extension and appending _failed_list.pkl. (Note that pickle.dump writes bytes, so the file needs to be opened in binary mode, "wb", for this step to work.)

Finally, it prints the path the failed-task list was saved to.

This section summarizes the run, saves the results, and records the failed tasks.
def run_task(task: dict, fixed_prompt_text: str, cache: dict,
converter: StructureConverter, task_idx: int, engine: str,
max_tokens: int, cut_prompt_examples: int = None) -> dict:
"""Runs the task, and returns the results.
Args:
task (dict): The task input
fixed_prompt_text (str): Used for cases where the input prompt is fixed
cache (dict): cache of previous results
converter (GraphPythonConverter): A graph-python converter to parse results
cut_prompt_examples (int, optional): If provided, the first `cut_prompt_examples` examples are
deleted. Prevents 4096 errors. Defaults to None.
Returns:
dict: A dictionary with the results.
"""
start_time = time.time()
prompt_text = fixed_prompt_text if fixed_prompt_text is not None else task['prompt']
if cut_prompt_examples is not None:
prompt_text_parts = prompt_text.split(END)
prompt_text = END.join(prompt_text_parts[cut_prompt_examples:])
if task['input_prompt'] in cache:
logging.info(
f"Task {task_idx} > Using cached result for {task['input_prompt']}")
codex_response = cache[task['input_prompt']]["codex_response"]
else:
codex_response = query_codex(task, prompt_text, engine, max_tokens=max_tokens)
completed_code = get_completed_code(task, codex_response)
task_results = {k: v for (k, v) in task.items()}
task_results["codex_response"] = codex_response
task_results["generated_code"] = completed_code
task_results["elapsed_time"] = time.time() - start_time
return task_results
This defines a function named run_task, which runs a single task and returns its results. Its main parameters:

task (dict): the task input, a dictionary with the task information.
fixed_prompt_text (str): a fixed prompt string; if it is not None, it is used instead of the task's own prompt field.
cache (dict): a cache of previous results.
converter (StructureConverter): the converter object used to parse results.
cut_prompt_examples (int, optional): if given, the first cut_prompt_examples in-context examples are removed from the prompt to stay under the 4096-token limit. Defaults to None.

The function works as follows:

It records the start time.

It picks the prompt text: fixed_prompt_text if that is not None, otherwise the task's prompt field. If cut_prompt_examples is given, that many leading examples are dropped from the prompt to avoid the 4096-token error.

It checks whether the task's input_prompt is already in the cache; if so, the Codex response is taken from the cache, otherwise query_codex is called to obtain it.

Once the response is available, get_completed_code extracts the generated code.

The results are collected into a dictionary that contains all of the task's input fields plus the Codex response, the generated code, and the elapsed time.

Finally, the dictionary with the task results is returned.

In short, the function runs one task: it talks to Codex, obtains the generated code, measures the elapsed time, and returns everything as one record.
def maintain_request_per_minute(num_requests: int, time_begin: float, max_requests_per_min: int, task_idx: int) -> float:
request_per_minute = get_request_per_minute(num_requests, time_begin)
logging.info("\n")
while request_per_minute > max_requests_per_min:
logging.info(
f"Task {task_idx} > Sleeping! (Requests/minute = {request_per_minute:.2f} > {max_requests_per_min:.2f})")
time.sleep(1)
request_per_minute = get_request_per_minute(
num_requests, time_begin)
return request_per_minute
This defines a function named maintain_request_per_minute, which throttles the request rate so it stays under the allowed requests per minute. Its parameters:

num_requests (int): number of requests sent so far.
time_begin (float): the time the run started.
max_requests_per_min (int): the maximum number of requests allowed per minute.
task_idx (int): the task index, used only for logging.

The function works as follows:

It computes the current requests-per-minute rate by calling get_request_per_minute with the request count and the start time.

If the current rate exceeds the limit, it enters a while loop and logs that it is sleeping to bring the rate down.

Each iteration sleeps for one second and recomputes the rate, so the function keeps waiting until the rate drops below the limit.

Once the rate is within the allowed range, it is returned.

The purpose is to respect the per-minute rate limit: if requests are going out too fast, the function simply waits until the rate is acceptable again.
def read_prompt(prompt_path):
if prompt_path is None:
return None
with open(prompt_path, "r") as f:
prompt = f.read()
return prompt
This defines a function named read_prompt, which reads the prompt text from a file. Parameter:

prompt_path: path of the prompt text file.

The function works as follows:

It first checks whether prompt_path is None. If so, it returns None, meaning no prompt text is available.

Otherwise it opens the file with a with statement, reads its contents, and stores them in the prompt variable.

Finally it returns the prompt text.

The purpose is simply to load a prompt from the given file for later use; a None path yields None.
def load_cache(output_file_path: str):
"""We don't want to query codex repeatedly for the same input. If an output file exists, this
function creates a "cache" of the results.
The cache is implemented as a hashmap keyed by `input_prompt`, and maps to the
entire output entry
Args:
output_file_path (str): _description_
"""
if not os.path.exists(output_file_path):
return {}
else:
# make a backup of the file already there
shutil.copyfile(output_file_path, output_file_path + "_" + datetime.now().strftime("%Y%m%d_%H%M%S"))
shutil.copy(output_file_path, output_file_path + ".bak")
cache_data = pd.read_json(
output_file_path, orient='records', lines=True)
cache = {row['input_prompt']: row.to_dict()
for _, row in cache_data.iterrows()}
return cache
This defines a function named load_cache, which builds a cache of earlier results so that the same input is not sent to Codex twice. Parameter:

output_file_path: path of the existing output file used as the cache.

The function works as follows:

It first checks whether output_file_path exists. If not, it returns an empty cache dictionary.

If the file exists, it:
makes backup copies of the existing file (a timestamped copy and an output_file_path.bak copy);
reads the data back from output_file_path;
builds a dictionary cache in which each entry is keyed by input_prompt and maps to the whole output entry as a dictionary.

Finally, the cache dictionary is returned.

The purpose is to turn an existing output file into a lookup table so previously computed results can be reused instead of re-queried, while also backing up the original file so it can be restored if needed.
def query_codex(task: dict, prompt_text: str, engine: str, max_tokens: int):
prompt = f"{prompt_text} {task['input_prompt']}"
response = OpenaiAPIWrapper.call(
prompt=prompt, max_tokens=max_tokens, engine=engine)
return response
This defines a function named query_codex, which sends a query to Codex (or whichever completion engine is configured) to obtain generated code. Parameters:

task (dict): the task input dictionary.
prompt_text (str): the prompt text (the in-context demonstrations) used for the query.
engine (str): the name of the engine to use.
max_tokens (int): the maximum number of tokens to generate.

The function works as follows:

It builds the full prompt by appending the task's input prompt (task['input_prompt']) to the given prompt_text.

It calls OpenaiAPIWrapper.call with the assembled prompt, max_tokens, and engine to send the request.

It returns the raw Codex response.

In short, it runs the query for the given task and prompt and hands the response back for later processing.
def get_completed_code(task: dict, codex_response: dict) -> str:
completed_code = OpenaiAPIWrapper.parse_response(codex_response)
all_code = f"{task['input_prompt']}{completed_code}"
# NOTE: space is already taken care of, no need to add it again, otherwise indentation will be off
return all_code
This defines a function named get_completed_code, which extracts the generated code from Codex's response. Parameters:

task (dict): the task input dictionary.
codex_response (dict): Codex's response containing the generated completion.

The function works as follows:

It calls OpenaiAPIWrapper.parse_response to extract the generated completion from codex_response and stores it in completed_code.

It concatenates the task's input prompt (task['input_prompt']) with the completion to obtain the full code; no extra space is inserted, otherwise the indentation would be off.

It returns the string with the full code (all_code).

The purpose is to turn Codex's response into the complete generated code, which later steps can parse or record.
def get_request_per_minute(num_request: int, begin_time: float) -> float:
elapsed_time = time.time() - begin_time
request_per_minute = (num_request / elapsed_time) * 60
return request_per_minute
This defines a function named get_request_per_minute, which computes the request rate per minute. Parameters:

num_request (int): number of requests sent so far.
begin_time (float): the start time.

The function works as follows:

It computes the time elapsed since begin_time using time.time().

It divides the number of requests by the elapsed time in seconds and multiplies by 60 to obtain the requests-per-minute rate.

It returns the computed rate (request_per_minute).

This is the rate that maintain_request_per_minute compares against the limit; a small worked example follows.
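A small worked example (illustrative): if the run started 90 seconds ago and 30 requests have been sent, the rate is about 20 requests per minute.

import time
print(get_request_per_minute(num_request=30, begin_time=time.time() - 90))  # ~20.0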
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--task_file_path", type=str, required=True)
parser.add_argument("--num_tasks", type=int, required=True)
parser.add_argument("--start_idx", type=int, required=True)
parser.add_argument("--output_file_path", type=str, required=True)
parser.add_argument("--prompt_path", type=str,
required=False, default=None)
parser.add_argument("--job_type", type=str, required=True,
choices=ConverterFactory.supported_converters)
parser.add_argument("--keep_writing_output",
action="store_true", default=True)
parser.add_argument("--engine", type=str, required=True)
parser.add_argument("--max_requests_per_min", type=int, default=10)
parser.add_argument("--max_tokens", type=int, default=280)
parser.add_argument("--schema_path", type=str, required=True)
parser.add_argument("--map_config_path", type=str, required=True)
parser.add_argument("--start_cut_num", type=int, default=0)
args = parser.parse_args()
run(task_file_path=args.task_file_path, num_tasks=args.num_tasks,start_idx=args.start_idx,
output_file_path=args.output_file_path, prompt_path=args.prompt_path,
keep_writing_output=args.keep_writing_output, engine=args.engine,
max_requests_per_min=args.max_requests_per_min,
max_tokens=args.max_tokens,schema_path=args.schema_path,
map_config_path=args.map_config_path,start_cut_num=args.start_cut_num)
This is the main entry point of the script. It uses the argparse module to parse command-line arguments and then calls run with the user-supplied values. Its main parts:

It creates a command-line parser with argparse.

It adds the following arguments:

task_file_path: path of the task file.
num_tasks: number of tasks to run.
start_idx: index of the first task.
output_file_path: path of the output file.
prompt_path: path of the prompt file (optional, default None).
job_type: the job type, chosen from the supported converter types.
keep_writing_output: a boolean flag indicating whether to keep writing output after each task (default True).
engine: name of the engine used to run the tasks.
max_requests_per_min: maximum number of requests per minute (default 10).
max_tokens: maximum number of tokens to generate (default 280).
schema_path: path of the schema folder.
map_config_path: path of the mapping configuration file.
start_cut_num: how many in-context examples to cut from the prompt to begin with (default 0).

It parses the arguments with parser.parse_args() and stores them in args.

It calls run with the parsed values to execute the tasks.

This entry point lets the user configure the whole run from the command line; it is a generic template that can be adapted to different tasks.
import json
import re
from collections import OrderedDict
from typing import List, Union, Dict, Tuple
from src.converters.structure_converter import StructureConverter
from src.converters.record import EntityRecord, RelationRecord
from src.utils.file_utils import load_yaml,load_schema
from uie.sel2record.record import MapConfig
from uie.sel2record.sel2record import SEL2Record
class NLSELPromptCreator(StructureConverter):
def __init__(self, schema_folder=None, map_config_path=None):
self.schema_dict = SEL2Record.load_schema_dict(schema_folder)
self.decoding = 'spotasoc'
record_schema = self.schema_dict['record']
self.entity_schema = record_schema.type_list
self.relation_schema = record_schema.role_list
self.spot_asoc = record_schema.type_role_dict
self.map_config = MapConfig.load_from_yaml(map_config_path)
def structure_to_input(self, input_dict: dict, prompt_part_only: bool = False):
"""
{'text': 'CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .',
'entity': [{'type': 'organization', 'offset': [2], 'text': 'LEICESTERSHIRE'}],
'spot': ['organization'],
"""
text = input_dict['text']
record = input_dict['record']
prompt = []
input = ['The text is : ',
"\"" + text + "\". ",
"The named entities in the text: "
]
prompt.extend(input)
if prompt_part_only:
return ''.join(prompt)
return ''.join(prompt) + '\n'
def output_to_structure(self, input_dict, output_str):
"""
sample:
{'text': 'CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .',
'tokens': ['CRICKET', '-', 'LEICESTERSHIRE', 'TAKE', 'OVER', 'AT', 'TOP', 'AFTER', 'INNINGS', 'VICTORY', '.'],
'record': '<extra_id_0> <extra_id_0> organization <extra_id_5> LEICESTERSHIRE <extra_id_1> <extra_id_1>',
'entity': [{'type': 'organization', 'offset': [2], 'text': 'LEICESTERSHIRE'}],
'relation': [],
'event': [],
'spot': ['organization'],
'asoc': [],
'spot_asoc': [{'span': 'LEICESTERSHIRE', 'label': 'organization', 'asoc': []}]}
code:
The text is : "CRICKET -
LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .".
Find named entities such as organization, person,
miscellaneous, location in the text. The organization
"LEICESTERSHIRE" exist in the text.
:param sample:
:param code:
:return:
"""
text = input_dict['text']
tokens = input_dict['tokens']
sel2record = SEL2Record(
schema_dict=self.schema_dict,
decoding_schema=self.decoding,
map_config=self.map_config,
)
pattern = re.compile(r"The named entities in the text:\s*(.*)")
pred = re.search(pattern, output_str).group(1)
# print ("pred: ")
# print (pred)
pred_record = sel2record.sel2record(pred, text, tokens)
return pred_record
if __name__ == "__main__":
schema_folder = 'data/conll03'
map_config_path = 'config/offset_map/first_offset_en.yaml'
val_path = 'data/conll03/val.json'
with open(val_path) as fin:
line = fin.readline()
line = eval(line.strip())
data = line
# print ("dev data:\n", data)
converter = NLSELPromptCreator(schema_folder=schema_folder,
map_config_path=map_config_path)
# convert the whole sample
prompt = converter.structure_to_input(data, prompt_part_only=False)
# print ("prompt:\n", prompt)
# we have to provide the init state to the sample
# prompt = converter.generate_sample_head(data)
# print("sample head: ", prompt)
code = """
The text is : "CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .". The named entities in the text: <extra_id_0> <extra_id_0> organization <extra_id_5> LEICESTERSHIRE <extra_id_1> <extra_id_1>
"""
data = {"text":"Enterprises from domestic coastal provinces and cities increased , and there are altogether 30 enterprise representatives from 30 provinces , cities and autonomous regions coming to this meeting .","tokens":["Enterprises","from","domestic","coastal","provinces","and","cities","increased",",","and","there","are","altogether","30","enterprise","representatives","from","30","provinces",",","cities","and","autonomous","regions","coming","to","this","meeting","."],"entity":[{"type":"geographical social political","offset":[2,3,4,5,6],"text":"domestic coastal provinces and cities"},{"type":"geographical social political","offset":[17,18,19,20,21,22,23],"text":"30 provinces , cities and autonomous regions"},{"type":"organization","offset":[0,1,2,3,4,5,6],"text":"Enterprises from domestic coastal provinces and cities"},{"type":"person","offset":[12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27],"text":"altogether 30 enterprise representatives from 30 provinces , cities and autonomous regions coming to this meeting"}],"relation":[],"event":[],"spot":["person","organization","geographical social political"],"asoc":[],"spot_asoc":[{"span":"Enterprises from domestic coastal provinces and cities","label":"organization","asoc":[]},{"span":"domestic coastal provinces and cities","label":"geographical social political","asoc":[]},{"span":"altogether 30 enterprise representatives from 30 provinces , cities and autonomous regions coming to this meeting","label":"person","asoc":[]},{"span":"30 provinces , cities and autonomous regions","label":"geographical social political","asoc":[]}]}
code = r'The text is : \"Enterprises from domestic coastal provinces and cities increased , and there are altogether 30 enterprise representatives from 30 provinces , cities and autonomous regions coming to this meeting .\". The named entities in the text: <extra_id_0> <extra_id_0> organization <extra_id_5> Enterprises from domestic coastal provinces and cities <extra_id_1> <extra_id_0> geographical social political <extra_id_5> domestic coastal provinces and cities <extra_id_1> <extra_id_0> person <extra_id_5> altogether 30 enterprise representatives from 30 provinces , cities and autonomous regions <extra_id_1> <extra_id_0> geographical social political <extra_id_5> provinces , cities and autonomous regions <extra_id_1> <extra_id_0> geographical social political <extra_id_5> provinces <extra_id_1> <extra_id_0> geographical social political <extra_id_5> cities <extra_id_1> <extra_id_0> geographical social political <extra_id_5> autonomous regions <extra_id_1> <extra_id_0> person <extra_id_5> this meeting <extra_id_1> <extra_id_1>\n'
# conver the prediction to the answers
predictions = converter.output_to_structure(data, code)
print (predictions)
This code converts text samples into the input format used for the task and converts the model's output back into structured data. Its core parts:

It imports the required libraries and modules, including json, re, collections, and typing, plus project-internal modules such as StructureConverter and the UIE record utilities.

It defines a class named NLSELPromptCreator, which inherits from StructureConverter; it turns a sample into the natural-language prompt used for the task and converts the model's output back into structured records.

The __init__ method initializes the class attributes: it loads the schema files, the decoding scheme ('spotasoc'), the entity schema, the relation schema, and the mapping configuration.

The structure_to_input method turns an input sample into the model's input format. It takes the input dictionary and a boolean prompt_part_only and builds the prompt text for the model.

The output_to_structure method turns the model's output string back into structured data: it extracts the predicted part with a regular expression and parses it with the SEL2Record class to recover the structured records.

The if __name__ == "__main__": part shows how the class is used: it loads the schema folder, the mapping configuration, and a sample, calls structure_to_input to build the prompt, and passes an example output to output_to_structure to recover the structured data.

Overall, this is a pre-/post-processing utility: it turns natural-language samples into the model's input format and converts the model's output back into structured records. For the CoNLL03 sample in the docstrings, the prompt and expected completion look like the sketch below.
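For reference, the prompt and the expected completion for the CoNLL03 sample embedded in the docstrings above look like this (taken from the examples in the code; the completion is the SEL string that output_to_structure hands to SEL2Record):

# prompt built by structure_to_input (the prompt part ends after the colon)
prompt = 'The text is : "CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .". The named entities in the text: '
# expected completion in SEL form
completion = '<extra_id_0> <extra_id_0> organization <extra_id_5> LEICESTERSHIRE <extra_id_1> <extra_id_1>'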
import json
import re
from collections import OrderedDict
from typing import List, Union, Dict, Tuple
import numpy as np
from src.converters.structure_converter import StructureConverter
from src.converters.record import EntityRecord, RelationRecord
from src.utils.file_utils import load_yaml,load_schema
from uie.sel2record.record import MapConfig
from uie.sel2record.sel2record import SEL2Record
class NLSELPromptCreator(StructureConverter):
def __init__(self, schema_folder=None, map_config_path=None):
self.schema_dict = SEL2Record.load_schema_dict(schema_folder)
self.decoding = 'spotasoc'
record_schema = self.schema_dict['record']
self.entity_schema = record_schema.type_list
self.relation_schema = record_schema.role_list
self.spot_asoc = record_schema.type_role_dict
self.map_config = MapConfig.load_from_yaml(map_config_path)
def structure_to_input(self, input_dict: dict, prompt_part_only: bool = False):
"""
{'text': 'CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .',
'entity': [{'type': 'organization', 'offset': [2], 'text': 'LEICESTERSHIRE'}],
'spot': ['organization'],
"""
text = input_dict['text']
record = input_dict['record']
prompt = []
input = ['The text is : ',
"\"" + text + "\". ",
"The named entities in the text: "
]
prompt.extend(input)
if prompt_part_only:
return ''.join(prompt)
record = record.replace('extra_id_','')
prompt.append(record)
return ''.join(prompt) + '\n'
def existing_nested(self, entity_dict_list):
entity_offset = []
for ent in entity_dict_list:
tmp_offset = ent['offset']
entity_offset.append(tmp_offset)
sorted_offset = sorted(entity_offset)
start = -1
end = -1
for so in sorted_offset:
temp_s, temp_e = so[0],so[-1]
if temp_s <= end:
return True
start = temp_s
end = temp_e
return False
def output_to_structure(self, input_dict, output_str):
"""
sample:
{'text': 'CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .',
'tokens': ['CRICKET', '-', 'LEICESTERSHIRE', 'TAKE', 'OVER', 'AT', 'TOP', 'AFTER', 'INNINGS', 'VICTORY', '.'],
'record': '<extra_id_0> <extra_id_0> organization <extra_id_5> LEICESTERSHIRE <extra_id_1> <extra_id_1>',
'entity': [{'type': 'organization', 'offset': [2], 'text': 'LEICESTERSHIRE'}],
'relation': [],
'event': [],
'spot': ['organization'],
'asoc': [],
'spot_asoc': [{'span': 'LEICESTERSHIRE', 'label': 'organization', 'asoc': []}]}
code:
The text is : "CRICKET -
LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .".
Find named entities such as organization, person,
miscellaneous, location in the text. The organization
"LEICESTERSHIRE" exist in the text.
:param sample:
:param code:
:return:
"""
text = input_dict['text']
tokens = input_dict['tokens']
entity = input_dict['entity']
exist_nested = self.existing_nested(entity)
sel2record = SEL2Record(
schema_dict=self.schema_dict,
decoding_schema=self.decoding,
map_config=self.map_config,
)
pattern = re.compile(r"The named entities in the text:\s*(.*)")
pred = re.search(pattern, output_str).group(1)
pred = pred.strip()
#
# print ("text: ", text)
# print ("output_str: ", output_str)
# print ("pred: ", pred)
pred_record = sel2record.sel2record(pred, text, tokens)
pred_record['statistic']['complex'] = exist_nested
return pred_record
if __name__ == "__main__":
schema_folder = 'data/conll03'
map_config_path = 'config/offset_map/first_offset_en.yaml'
val_path = 'data/conll03/val.json'
with open(val_path) as fin:
line = fin.readline()
line = eval(line.strip())
data = line
# print ("dev data:\n", data)
converter = NLSELPromptCreator(schema_folder=schema_folder,
map_config_path=map_config_path)
# convert the whole sample
prompt = converter.structure_to_input(data, prompt_part_only=False)
print ("prompt:\n", prompt)
# we have to provide the init state to the sample
# prompt = converter.generate_sample_head(data)
# print("sample head: ", prompt)
# code = """
# The text is : "CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .". The named entities in the text: <extra_id_0> <extra_id_0> organization <extra_id_5> LEICESTERSHIRE <extra_id_1> <extra_id_1>
# """
# data = {"text":"Enterprises from domestic coastal provinces and cities increased , and there are altogether 30 enterprise representatives from 30 provinces , cities and autonomous regions coming to this meeting .","tokens":["Enterprises","from","domestic","coastal","provinces","and","cities","increased",",","and","there","are","altogether","30","enterprise","representatives","from","30","provinces",",","cities","and","autonomous","regions","coming","to","this","meeting","."],"entity":[{"type":"geographical social political","offset":[2,3,4,5,6],"text":"domestic coastal provinces and cities"},{"type":"geographical social political","offset":[17,18,19,20,21,22,23],"text":"30 provinces , cities and autonomous regions"},{"type":"organization","offset":[0,1,2,3,4,5,6],"text":"Enterprises from domestic coastal provinces and cities"},{"type":"person","offset":[12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27],"text":"altogether 30 enterprise representatives from 30 provinces , cities and autonomous regions coming to this meeting"}],"relation":[],"event":[],"spot":["person","organization","geographical social political"],"asoc":[],"spot_asoc":[{"span":"Enterprises from domestic coastal provinces and cities","label":"organization","asoc":[]},{"span":"domestic coastal provinces and cities","label":"geographical social political","asoc":[]},{"span":"altogether 30 enterprise representatives from 30 provinces , cities and autonomous regions coming to this meeting","label":"person","asoc":[]},{"span":"30 provinces , cities and autonomous regions","label":"geographical social political","asoc":[]}]}
# code = r'The text is : "CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .". The named entities in the text: <0> <0> organization <5> LEICESTERSHIRE <1> <1>\n'
code = repr(prompt)
# conver the prediction to the answers
predictions = converter.output_to_structure(data, code)
print (predictions)
This script is very similar to the previous one: it also converts samples into the model's input format and converts the model output back into structured data. The main changes in this version:

It imports an additional library, numpy.

The NLSELPromptCreator class gains a new method, existing_nested, which checks whether the input sample contains nested entities (entity spans with overlapping offsets).

In structure_to_input, the generated prompt now also includes the record information taken from the input sample.

In output_to_structure, the new existing_nested method is applied to the input entities, and the result is stored in the returned record (under the 'statistic' field as 'complex').

The if __name__ == "__main__": part again loads the schema, the configuration, and a sample, builds the prompt with NLSELPromptCreator, and converts the output back into structured data; the example input was also adjusted to exercise the new behaviour.

Overall, this is the same pre-/post-processing tool with a few additions. A small worked example of the nested-entity check follows.
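A small worked example of the overlap check that existing_nested performs (a standalone re-implementation for illustration; the method above does the same thing on the entities' 'offset' fields):

def existing_nested(entity_dict_list):
    # sort the entity offsets and report True if any span starts before the previous one ends
    offsets = sorted(ent['offset'] for ent in entity_dict_list)
    end = -1
    for off in offsets:
        start_i, end_i = off[0], off[-1]
        if start_i <= end:
            return True
        end = end_i
    return False

print(existing_nested([{'offset': [0, 1, 2, 3]}, {'offset': [2]}]))  # True: token 2 lies inside tokens 0-3
print(existing_nested([{'offset': [0, 1]}, {'offset': [5, 6]}]))     # False: the spans do not overlap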
import json
import re
from collections import OrderedDict
from typing import List, Union, Dict, Tuple
from src.converters.structure_converter import StructureConverter
from src.converters.record import EntityRecord, RelationRecord
from src.utils.file_utils import load_yaml,load_schema
from uie.sel2record.record import MapConfig
from uie.sel2record.sel2record import SEL2Record
class NLSELPromptCreator(StructureConverter):
def __init__(self, schema_folder=None, map_config_path=None):
self.schema_dict = SEL2Record.load_schema_dict(schema_folder)
self.decoding = 'spotasoc'
record_schema = self.schema_dict['record']
self.entity_schema = record_schema.type_list
self.relation_schema = record_schema.role_list
self.spot_asoc = record_schema.type_role_dict
self.map_config = MapConfig.load_from_yaml(map_config_path)
def structure_to_input(self, input_dict: dict, prompt_part_only: bool = False):
"""
{'text': 'CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .',
'entity': [{'type': 'organization', 'offset': [2], 'text': 'LEICESTERSHIRE'}],
'spot': ['organization'],
"""
text = input_dict['text']
record = input_dict['record']
prompt = []
input = ['The text is : ',
"\"" + text + "\". ",
"The named entities in the text: "
]
prompt.extend(input)
if prompt_part_only:
return ''.join(prompt)
record = record.replace('extra_id_','')
record = record.lstrip('<0>').rstrip('<1>').strip()
record = record.split('<1>')
record = [rec.strip().lstrip('<0>').strip() for rec in record]
record_new = []
for rec in record:
if rec != '':
temp_str = rec
temp_tuple = temp_str.split('<5>')
assert len(temp_tuple) == 2
temp_tuple = [tt.strip() for tt in temp_tuple]
new_str = f'"{temp_tuple[1]}" is "{temp_tuple[0]}" .'
record_new.append(new_str)
record = ' '.join(record_new)
prompt.append(record)
return ''.join(prompt) + '\n'
def output_to_structure(self, input_dict, output_str):
"""
sample:
{'text': 'CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .',
'tokens': ['CRICKET', '-', 'LEICESTERSHIRE', 'TAKE', 'OVER', 'AT', 'TOP', 'AFTER', 'INNINGS', 'VICTORY', '.'],
'record': '<extra_id_0> <extra_id_0> organization <extra_id_5> LEICESTERSHIRE <extra_id_1> <extra_id_1>',
'entity': [{'type': 'organization', 'offset': [2], 'text': 'LEICESTERSHIRE'}],
'relation': [],
'event': [],
'spot': ['organization'],
'asoc': [],
'spot_asoc': [{'span': 'LEICESTERSHIRE', 'label': 'organization', 'asoc': []}]}
code:
The text is : "CRICKET -
LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .".
Find named entities such as organization, person,
miscellaneous, location in the text. The organization
"LEICESTERSHIRE" exist in the text.
:param sample:
:param code:
:return:
"""
text = input_dict['text']
tokens = input_dict['tokens']
sel2record = SEL2Record(
schema_dict=self.schema_dict,
decoding_schema=self.decoding,
map_config=self.map_config,
)
pattern = re.compile(r"The named entities in the text:\s*(.*)")
pred = re.search(pattern, output_str).group(1)
pattern = re.compile(r"\"(.*?)\"\sis\s\"(.*?)\"\s.")
pred = pattern.findall(pred)
pred = [(p[1],p[0]) for p in pred]
pred = [' <5> '.join(p) for p in pred]
pred = ['<0> ' + p + ' <1>' for p in pred]
pred = ' '.join(pred)
pred = '<0> ' + pred + ' <1>'
pred_record = sel2record.sel2record(pred, text, tokens)
return pred_record
if __name__ == "__main__":
schema_folder = 'data/conll03'
map_config_path = 'config/offset_map/first_offset_en.yaml'
val_path = 'data/conll03/val.json'
with open(val_path) as fin:
line = fin.readline()
line = fin.readline()
line = eval(line.strip())
data = line
# print ("dev data:\n", data)
converter = NLSELPromptCreator(schema_folder=schema_folder,
map_config_path=map_config_path)
# convert the whole sample
prompt = converter.structure_to_input(data, prompt_part_only=False)
print ("prompt:\n", prompt)
# we have to provide the init state to the sample
# prompt = converter.generate_sample_head(data)
# print("sample head: ", prompt)
# code = """
# The text is : "CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .". The named entities in the text: <extra_id_0> <extra_id_0> organization <extra_id_5> LEICESTERSHIRE <extra_id_1> <extra_id_1>
# """
# data = {"text":"Enterprises from domestic coastal provinces and cities increased , and there are altogether 30 enterprise representatives from 30 provinces , cities and autonomous regions coming to this meeting .","tokens":["Enterprises","from","domestic","coastal","provinces","and","cities","increased",",","and","there","are","altogether","30","enterprise","representatives","from","30","provinces",",","cities","and","autonomous","regions","coming","to","this","meeting","."],"entity":[{"type":"geographical social political","offset":[2,3,4,5,6],"text":"domestic coastal provinces and cities"},{"type":"geographical social political","offset":[17,18,19,20,21,22,23],"text":"30 provinces , cities and autonomous regions"},{"type":"organization","offset":[0,1,2,3,4,5,6],"text":"Enterprises from domestic coastal provinces and cities"},{"type":"person","offset":[12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27],"text":"altogether 30 enterprise representatives from 30 provinces , cities and autonomous regions coming to this meeting"}],"relation":[],"event":[],"spot":["person","organization","geographical social political"],"asoc":[],"spot_asoc":[{"span":"Enterprises from domestic coastal provinces and cities","label":"organization","asoc":[]},{"span":"domestic coastal provinces and cities","label":"geographical social political","asoc":[]},{"span":"altogether 30 enterprise representatives from 30 provinces , cities and autonomous regions coming to this meeting","label":"person","asoc":[]},{"span":"30 provinces , cities and autonomous regions","label":"geographical social political","asoc":[]}]}
# code = r'The text is : "CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .". The named entities in the text: <0> <0> organization <5> LEICESTERSHIRE <1> <1>\n'
code = repr(prompt)
# conver the prediction to the answers
predictions = converter.output_to_structure(data, code)
print (predictions)
This version is again very similar to the previous ones; it still converts samples into the model input and converts the model output back into structured data. The main changes:

structure_to_input now reformats the record information: instead of appending the raw SEL string, it produces more readable text of the form '"<entity text>" is "<entity type>" .'.

output_to_structure parses the model output according to this new format: it extracts the '"X" is "Y"' statements with a regular expression, rebuilds the SEL string ('<0> ... <5> ... <1>'), and only then hands it to SEL2Record to recover the entities.

The if __name__ == "__main__": part again loads the schema, the configuration, and a sample (this time the second line of the dev file), builds the prompt, and converts the output back into structured data.

Overall, this is still a pre-/post-processing tool for text data; it handles the record information in a more readable format and parses outputs written in that format. A small sketch of the two directions of this conversion follows.
import json
import re
from collections import OrderedDict
from typing import List, Union, Dict, Tuple
from src.converters.structure_converter import StructureConverter
from src.utils.file_utils import load_yaml,load_schema
from uie.sel2record.record import EntityRecord, RelationRecord
from uie.sel2record.record import MapConfig
from uie.sel2record.sel2record import SEL2Record
class PLFuncPromptCreator(StructureConverter):
def __init__(self, schema_folder=None, map_config_path=None):
self.schema_dict = SEL2Record.load_schema_dict(schema_folder)
self.decoding = 'spotasoc'
record_schema = self.schema_dict['record']
self.entity_schema = record_schema.type_list
self.relation_schema = record_schema.role_list
self.spot_asoc = record_schema.type_role_dict
self.map_config = MapConfig.load_from_yaml(map_config_path)
def structure_to_input(self, input_dict: dict, prompt_part_only: bool = False):
"""
{'text': 'CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .',
'entity': [{'type': 'organization', 'offset': [2], 'text': 'LEICESTERSHIRE'}],
'spot': ['organization'],
"""
text = input_dict['text']
entity_list = input_dict['entity']
spot_list = input_dict['spot']
prompt = []
goal = 'named entity extraction'
func_head = self.to_function_head(self.to_function_name(goal),input='input_text')
prompt.append(func_head)
docstring = '\t""" extract named entities from the input_text . """'
prompt.append(docstring)
input_text = f'\tinput_text = "{text}"'
prompt.append(input_text)
inline_annotation = '\t# extracted named entity list'
prompt.append(inline_annotation)
if prompt_part_only:
return self.list_to_str(prompt)
for spot in spot_list:
entity_list_name = self.to_function_name(spot) + '_list'
tmp_entity_text = []
for ent in entity_list:
if ent['type'] == spot:
ent_text = ent['text']
tmp_entity_text.append(f'"{ent_text}"')
prompt.append(f'\t{entity_list_name} = [' + ', '.join(tmp_entity_text) + ']')
prompt = self.list_to_str(prompt)
return prompt + '\n'
def output_to_structure(self, input_dict, output_str):
"""
input_dict:
{'text': 'West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset by an innings and 39 runs in two days to take over at the head of the county championship .',
'tokens': ['West', 'Indian', 'all-rounder', 'Phil', 'Simmons', 'took', 'four', 'for', '38', 'on', 'Friday', 'as', 'Leicestershire', 'beat', 'Somerset', 'by', 'an', 'innings', 'and', '39', 'runs', 'in', 'two', 'days', 'to', 'take', 'over', 'at', 'the', 'head', 'of', 'the', 'county', 'championship', '.'],
'record': '<extra_id_0> <extra_id_0> miscellaneous <extra_id_5> West Indian <extra_id_1> <extra_id_0> person <extra_id_5> Phil Simmons <extra_id_1> <extra_id_0> organization <extra_id_5> Leicestershire <extra_id_1> <extra_id_0> organization <extra_id_5> Somerset <extra_id_1> <extra_id_1>',
'entity': [{'type': 'organization', 'offset': [12], 'text': 'Leicestershire'}, {'type': 'person', 'offset': [3, 4], 'text': 'Phil Simmons'}, {'type': 'organization', 'offset': [14], 'text': 'Somerset'}, {'type': 'miscellaneous', 'offset': [0, 1], 'text': 'West Indian'}],
'relation': [], 'event': [], 'spot': ['person', 'organization', 'miscellaneous'], 'asoc': [],
'spot_asoc': [{'span': 'West Indian', 'label': 'miscellaneous', 'asoc': []}, {'span': 'Phil Simmons', 'label': 'person', 'asoc': []}, {'span': 'Leicestershire', 'label': 'organization', 'asoc': []}, {'span': 'Somerset', 'label': 'organization', 'asoc': []}]}
output_str:
def extract_named_entity(input_text):
# extract named entities from the input_text.
input_text = "West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset by an innings and 39 runs in two days to take over at the head of the county championship ."
# extracted named entity list
person_list = ["Phil Simmons"]
organization_list = ["Leicestershire", "Somerset"]
miscellaneous_list = ["West Indian"]
:return:
"""
tokens = input_dict['tokens']
sent_records = {}
sent_records['entity'] = []
for entity_s in self.entity_schema:
temp_entities = re.findall(f'{self.to_function_name(entity_s)}_list' + r' = \[(.*?)\]', output_str)
if len(temp_entities) != 0:
temp_entities = temp_entities[0].split(", ")
temp_entity_list = [
{'text': e.strip(r'\"'), 'type': entity_s} for e in temp_entities
]
sent_records['entity'].extend(temp_entity_list)
offset_records = {}
record_map = EntityRecord(map_config=self.map_config)
offset_records['offset'] = record_map.to_offset(
instance=sent_records.get('entity', []),
tokens=tokens,
)
offset_records['string'] = record_map.to_string(
sent_records.get('entity', []),
)
"""
{'offset': [('opinion', (10,)), ('aspect', (11, 12)), ('opinion', (32,)), ('aspect', (34,))],
'string': [('opinion', 'soft'), ('aspect', 'rubber enclosure'), ('opinion', 'break'), ('aspect', 'seal')]}
"""
return {"entity": offset_records,"relation": {"offset": [], "string": []},"event": {"offset": [], "string": []}}
if __name__ == "__main__":
schema_path = 'data/conll03'
map_config_path = 'config/offset_map/first_offset_en.yaml'
val_path = 'data/conll03/val.json'
with open(val_path) as fin:
line0 = fin.readline()
line1 = fin.readline()
line = fin.readline()
line = eval(line.strip())
data = line
converter = PLFuncPromptCreator(schema_folder=schema_path,
map_config_path=map_config_path)
# convert the whole sample
prompt = converter.structure_to_input(data, prompt_part_only=False)
# convert the whole sample
# prompt = converter.structure_to_input(data, prompt_part_only=True)
# print ("prompt:\n", prompt)
data = {"text":"Two goals from defensive errors in the last six minutes allowed Japan to come from behind and collect all three points from their opening meeting against Syria .","tokens":["Two","goals","from","defensive","errors","in","the","last","six","minutes","allowed","Japan","to","come","from","behind","and","collect","all","three","points","from","their","opening","meeting","against","Syria","."],"entity":[{"type":"location","offset":[26],"text":"Syria"},{"type":"location","offset":[11],"text":"Japan"}],"relation":[],"event":[],"spot":["location"],"asoc":[],"spot_asoc":[{"span":"Japan","label":"location","asoc":[]},{"span":"Syria","label":"location","asoc":[]}]}
code = r'def named_entity_extraction(input_text):\n\t\"\"\" extract named entities from the input_text . \"\"\"\n\tinput_text = \"Two goals from defensive errors in the last six minutes allowed Japan to come from behind and collect all three points from their opening meeting against Syria .\"\n\t# extracted named entity list\n\tlocation_list = [\"Syria\"]\n'
print (data)
print (code)
# conver the prediction to the answers
predictions = converter.output_to_structure(data, code)
print ("output: \n")
print (predictions)
This code is a structure converter that turns samples into the code-style format used to prompt the model and converts the model's generated code back into structured data; this particular converter is built for named entity extraction.

Key parts:

The structure_to_input method turns an input sample into a Python-function-style prompt for extracting named entities. It generates the function head (the function name plus the input_text parameter), a docstring stating the goal, the input text assignment, and then, for each entity type, a Python list holding the entities of that type found in the text. The pieces are finally joined into a complete Python function.

The output_to_structure method converts the model-generated code back into structured data. It extracts the entity list for each entity type from the output string with a regular expression and maps the entities back to token offsets in the original text. It returns a dictionary with the entity information, both as offsets and as strings.

The if __name__ == "__main__": part loads the schema folder, the mapping configuration, and a sample, then uses PLFuncPromptCreator to build the prompt for the sample and to convert a hard-coded example output back into structured data.

Overall, this is a generic structure converter for turning text into code-style model input and turning the generated code back into structured records, which is exactly what the named entity extraction setting needs. The prompt/target it produces looks like the sketch below.
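For reference, the code-style prompt/target this converter produces for the CoNLL03 sample in the docstring above looks roughly like this (the completion after the comment line is what the model is expected to generate; the actual prompts use tab indentation):

def named_entity_extraction(input_text):
    """ extract named entities from the input_text . """
    input_text = "West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset by an innings and 39 runs in two days to take over at the head of the county championship ."
    # extracted named entity list
    person_list = ["Phil Simmons"]
    organization_list = ["Leicestershire", "Somerset"]
    miscellaneous_list = ["West Indian"]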
import json
import re
from collections import OrderedDict
from typing import List, Union, Dict, Tuple
from src.converters.structure_converter import StructureConverter
from src.utils.file_utils import load_yaml,load_schema
from uie.sel2record.record import EntityRecord, RelationRecord
from uie.sel2record.record import MapConfig
from uie.sel2record.sel2record import SEL2Record
"""
def extract_named_entity(input_text):
# extract named entities from the input_text .
input_text = "Steve became CEO of Apple in 1998"
# extracted named entities
person = ["Steve"]
organization = ["Apple"]
person = ["Steve"]
organization = ["Apple"]
"""
class PLFuncPromptCreator(StructureConverter):
def __init__(self, schema_folder=None, map_config_path=None):
self.schema_dict = SEL2Record.load_schema_dict(schema_folder)
self.decoding = 'spotasoc'
record_schema = self.schema_dict['record']
self.entity_schema = record_schema.type_list
self.relation_schema = record_schema.role_list
self.spot_asoc = record_schema.type_role_dict
self.map_config = MapConfig.load_from_yaml(map_config_path)
def structure_to_input(self, input_dict: dict, prompt_part_only: bool = False):
"""
{'text': 'CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .',
'entity': [{'type': 'organization', 'offset': [2], 'text': 'LEICESTERSHIRE'}],
'spot': ['organization'],
"spot_asoc":[{"span":"Japan","label":"location","asoc":[]},{"span":"Syria","label":"location","asoc":[]}]
"""
text = input_dict['text']
spot_asoc_list = input_dict['spot_asoc']
prompt = []
goal = 'named entity extraction'
func_head = self.to_function_head(self.to_function_name(goal),input='input_text')
prompt.append(func_head)
docstring = '\t""" extract named entities from the input_text . """'
prompt.append(docstring)
input_text = f'\tinput_text = "{text}"'
prompt.append(input_text)
inline_annotation = '\t# extracted named entities'
prompt.append(inline_annotation)
if prompt_part_only:
return self.list_to_str(prompt)
for sc in spot_asoc_list:
entity_text = sc['span']
entity_type = self.to_function_name(sc['label'])
prompt.append(f'\t{entity_type} = [ {entity_text} ]')
prompt = self.list_to_str(prompt)
return prompt + '\n'
def output_to_structure(self, input_dict, output_str):
"""
input_dict:
{'text': 'West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset by an innings and 39 runs in two days to take over at the head of the county championship .',
'tokens': ['West', 'Indian', 'all-rounder', 'Phil', 'Simmons', 'took', 'four', 'for', '38', 'on', 'Friday', 'as', 'Leicestershire', 'beat', 'Somerset', 'by', 'an', 'innings', 'and', '39', 'runs', 'in', 'two', 'days', 'to', 'take', 'over', 'at', 'the', 'head', 'of', 'the', 'county', 'championship', '.'],
'record': '<extra_id_0> <extra_id_0> miscellaneous <extra_id_5> West Indian <extra_id_1> <extra_id_0> person <extra_id_5> Phil Simmons <extra_id_1> <extra_id_0> organization <extra_id_5> Leicestershire <extra_id_1> <extra_id_0> organization <extra_id_5> Somerset <extra_id_1> <extra_id_1>',
'entity': [{'type': 'organization', 'offset': [12], 'text': 'Leicestershire'}, {'type': 'person', 'offset': [3, 4], 'text': 'Phil Simmons'}, {'type': 'organization', 'offset': [14], 'text': 'Somerset'}, {'type': 'miscellaneous', 'offset': [0, 1], 'text': 'West Indian'}],
'relation': [], 'event': [], 'spot': ['person', 'organization', 'miscellaneous'], 'asoc': [],
'spot_asoc': [{'span': 'West Indian', 'label': 'miscellaneous', 'asoc': []}, {'span': 'Phil Simmons', 'label': 'person', 'asoc': []}, {'span': 'Leicestershire', 'label': 'organization', 'asoc': []}, {'span': 'Somerset', 'label': 'organization', 'asoc': []}]}
output_str:
def extract_named_entity(input_text):
# extract named entities from the input_text.
input_text = "West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset by an innings and 39 runs in two days to take over at the head of the county championship ."
# extracted named entity list
person_list = ["Phil Simmons"]
organization_list = ["Leicestershire", "Somerset"]
miscellaneous_list = ["West Indian"]
:return:
"""
tokens = input_dict['tokens']
sent_records = {}
sent_records['entity'] = []
for entity_s in self.entity_schema:
temp_entities = re.findall(f'{self.to_function_name(entity_s)}' + r' = \[(.*?)\]', output_str)
if len(temp_entities) != 0:
temp_entity_list = [
{'text': e.strip(), 'type': entity_s} for e in temp_entities
]
sent_records['entity'].extend(temp_entity_list)
offset_records = {}
record_map = EntityRecord(map_config=self.map_config)
offset_records['offset'] = record_map.to_offset(
instance=sent_records.get('entity', []),
tokens=tokens,
)
offset_records['string'] = record_map.to_string(
sent_records.get('entity', []),
)
"""
{'offset': [('opinion', (10,)), ('aspect', (11, 12)), ('opinion', (32,)), ('aspect', (34,))],
'string': [('opinion', 'soft'), ('aspect', 'rubber enclosure'), ('opinion', 'break'), ('aspect', 'seal')]}
"""
return {"entity": offset_records,"relation": {"offset": [], "string": []},"event": {"offset": [], "string": []}}
if __name__ == "__main__":
schema_path = 'data/conll03'
map_config_path = 'config/offset_map/first_offset_en.yaml'
val_path = 'data/conll03/val.json'
with open(val_path) as fin:
line0 = fin.readline()
line1 = fin.readline()
line = fin.readline()
line = eval(line.strip())
data = line
print ('data: ', data)
print ('data keys: ', data.keys())
converter = PLFuncPromptCreator(schema_folder=schema_path,
map_config_path=map_config_path)
# convert the whole sample
prompt = converter.structure_to_input(data, prompt_part_only=False)
print ("prompt:\n", prompt)
code = repr(prompt)
# convert the prediction to the answers
predictions = converter.output_to_structure(data, code)
print ("output: \n")
print (predictions)
This code performs two main tasks:
It turns structured data (the text plus its entity annotations) into a Python function that extracts named entities from the input text. In the generated prompts the function takes the form named_entity_extraction(input_text), where input_text is the parameter carrying the text to be analysed. Inside the function body the extracted entities are stored in per-type lists (such as person and organization), and comments describe the extraction step.
It reverses the generated code string back into structured data, by parsing the generated Python code for the entity lists it contains and returning a dictionary with both the offset positions and the string form of the entity information.
The script runs as follows:
It first loads the schema file and the map config, then takes one sample from the validation data.
It then uses the PLFuncPromptCreator class to turn that sample into the model's input prompt; the generated Python function contains the entity types and entity mentions of the sample.
Finally it performs the reverse operation, recovering structured data from the generated code string.
In short, this is an end-to-end example of how structured data is turned into a code-style prompt and how structured data is recovered from the generated code, which is especially useful in information-extraction scenarios.
import json
import re
from collections import OrderedDict
from typing import List, Union, Dict, Tuple
from src.converters.structure_converter import StructureConverter
from src.utils.file_utils import load_yaml,load_schema
from uie.sel2record.record import EntityRecord, RelationRecord
from uie.sel2record.record import MapConfig
from uie.sel2record.sel2record import SEL2Record
"""
def extract_named_entity(input_text):
# extract named entities from the input_text .
input_text = "Steve became CEO of Apple in 1998"
# extracted named entities
person = ["Steve"]
organization = ["Apple"]
person = ["Steve"]
organization = ["Apple"]
"""
class PLFuncPromptCreator(StructureConverter):
def __init__(self, schema_folder=None, map_config_path=None):
self.schema_dict = SEL2Record.load_schema_dict(schema_folder)
self.decoding = 'spotasoc'
record_schema = self.schema_dict['record']
self.entity_schema = record_schema.type_list
self.relation_schema = record_schema.role_list
self.spot_asoc = record_schema.type_role_dict
self.map_config = MapConfig.load_from_yaml(map_config_path)
def structure_to_input(self, input_dict: dict, prompt_part_only: bool = False):
"""
{'text': 'CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .',
'entity': [{'type': 'organization', 'offset': [2], 'text': 'LEICESTERSHIRE'}],
'spot': ['organization'],
"spot_asoc":[{"span":"Japan","label":"location","asoc":[]},{"span":"Syria","label":"location","asoc":[]}]
"""
text = input_dict['text']
spot_asoc_list = input_dict['spot_asoc']
prompt = []
goal = 'named entity extraction'
func_head = self.to_function_head(self.to_function_name(goal),input='input_text')
prompt.append(func_head)
docstring = '\t""" extract named entities from the input_text . """'
prompt.append(docstring)
input_text = f'\tinput_text = "{text}"'
prompt.append(input_text)
inline_annotation = '\t# extracted named entities'
prompt.append(inline_annotation)
if prompt_part_only:
return self.list_to_str(prompt)
for sc in spot_asoc_list:
entity_text = sc['span']
entity_type = self.to_function_name(sc['label'])
prompt.append(f'\t{entity_type} = [ "{entity_text}" ]')
prompt = self.list_to_str(prompt)
return prompt + '\n'
def output_to_structure(self, input_dict, output_str):
"""
input_dict:
{'text': 'West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset by an innings and 39 runs in two days to take over at the head of the county championship .',
'tokens': ['West', 'Indian', 'all-rounder', 'Phil', 'Simmons', 'took', 'four', 'for', '38', 'on', 'Friday', 'as', 'Leicestershire', 'beat', 'Somerset', 'by', 'an', 'innings', 'and', '39', 'runs', 'in', 'two', 'days', 'to', 'take', 'over', 'at', 'the', 'head', 'of', 'the', 'county', 'championship', '.'],
'record': '<extra_id_0> <extra_id_0> miscellaneous <extra_id_5> West Indian <extra_id_1> <extra_id_0> person <extra_id_5> Phil Simmons <extra_id_1> <extra_id_0> organization <extra_id_5> Leicestershire <extra_id_1> <extra_id_0> organization <extra_id_5> Somerset <extra_id_1> <extra_id_1>',
'entity': [{'type': 'organization', 'offset': [12], 'text': 'Leicestershire'}, {'type': 'person', 'offset': [3, 4], 'text': 'Phil Simmons'}, {'type': 'organization', 'offset': [14], 'text': 'Somerset'}, {'type': 'miscellaneous', 'offset': [0, 1], 'text': 'West Indian'}],
'relation': [], 'event': [], 'spot': ['person', 'organization', 'miscellaneous'], 'asoc': [],
'spot_asoc': [{'span': 'West Indian', 'label': 'miscellaneous', 'asoc': []}, {'span': 'Phil Simmons', 'label': 'person', 'asoc': []}, {'span': 'Leicestershire', 'label': 'organization', 'asoc': []}, {'span': 'Somerset', 'label': 'organization', 'asoc': []}]}
output_str:
def extract_named_entity(input_text):
# extract named entities from the input_text.
input_text = "West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset by an innings and 39 runs in two days to take over at the head of the county championship ."
# extracted named entity list
person_list = ["Phil Simmons"]
organization_list = ["Leicestershire", "Somerset"]
miscellaneous_list = ["West Indian"]
:return:
"""
tokens = input_dict['tokens']
sent_records = {}
sent_records['entity'] = []
for entity_s in self.entity_schema:
temp_entities = re.findall(f'{self.to_function_name(entity_s)}' + r' = \[(.*?)\]', output_str)
if len(temp_entities) != 0:
temp_entity_list = [
{'text': e.strip().strip(r'\"') , 'type': entity_s} for e in temp_entities
]
sent_records['entity'].extend(temp_entity_list)
offset_records = {}
record_map = EntityRecord(map_config=self.map_config)
offset_records['offset'] = record_map.to_offset(
instance=sent_records.get('entity', []),
tokens=tokens,
)
offset_records['string'] = record_map.to_string(
sent_records.get('entity', []),
)
"""
{'offset': [('opinion', (10,)), ('aspect', (11, 12)), ('opinion', (32,)), ('aspect', (34,))],
'string': [('opinion', 'soft'), ('aspect', 'rubber enclosure'), ('opinion', 'break'), ('aspect', 'seal')]}
"""
return {"entity": offset_records,"relation": {"offset": [], "string": []},"event": {"offset": [], "string": []}}
if __name__ == "__main__":
schema_path = 'data/conll03'
map_config_path = 'config/offset_map/first_offset_en.yaml'
val_path = 'data/conll03/val.json'
with open(val_path) as fin:
line0 = fin.readline()
line1 = fin.readline()
line = fin.readline()
line = eval(line.strip())
data = line
# print ('data: ', data)
# print ('data keys: ', data.keys())
converter = PLFuncPromptCreator(schema_folder=schema_path,
map_config_path=map_config_path)
# convert the whole sample
prompt = converter.structure_to_input(data, prompt_part_only=False)
# print ("prompt:\n", prompt)
code = repr(prompt)
data = {"text":"China controlled most of the match and saw several chances missed until the 78th minute when Uzbek striker Igor Shkvyrin took advantage of a misdirected defensive header to lob the ball over the advancing Chinese keeper and into an empty net .","tokens":["China","controlled","most","of","the","match","and","saw","several","chances","missed","until","the","78th","minute","when","Uzbek","striker","Igor","Shkvyrin","took","advantage","of","a","misdirected","defensive","header","to","lob","the","ball","over","the","advancing","Chinese","keeper","and","into","an","empty","net","."],"entity":[{"type":"miscellaneous","offset":[16],"text":"Uzbek"},{"type":"miscellaneous","offset":[34],"text":"Chinese"},{"type":"person","offset":[18,19],"text":"Igor Shkvyrin"},{"type":"location","offset":[0],"text":"China"}],"relation":[],"event":[],"spot":["person","miscellaneous","location"],"asoc":[],"spot_asoc":[{"span":"China","label":"location","asoc":[]},{"span":"Uzbek","label":"miscellaneous","asoc":[]},{"span":"Igor Shkvyrin","label":"person","asoc":[]},{"span":"Chinese","label":"miscellaneous","asoc":[]}],"input_idx":5,"input_prompt":"def named_entity_extraction(input_text):\n\t\"\"\" extract named entities from the input_text . \"\"\"\n\tinput_text = \"China controlled most of the match and saw several chances missed until the 78th minute when Uzbek striker Igor Shkvyrin took advantage of a misdirected defensive header to lob the ball over the advancing Chinese keeper and into an empty net .\"\n\t# extracted named entities","reference_output":"def named_entity_extraction(input_text):\n\t\"\"\" extract named entities from the input_text . \"\"\"\n\tinput_text = \"China controlled most of the match and saw several chances missed until the 78th minute when Uzbek striker Igor Shkvyrin took advantage of a misdirected defensive header to lob the ball over the advancing Chinese keeper and into an empty net .\"\n\t# extracted named entities\n\tlocation = [ \"China\" ]\n\tmiscellaneous = [ \"Uzbek\" ]\n\tperson = [ \"Igor Shkvyrin\" ]\n\tmiscellaneous = [ \"Chinese\" ]\n"}
code = r'def named_entity_extraction(input_text):\n\t\"\"\" extract named entities from the input_text . \"\"\"\n\tinput_text = \"China controlled most of the match and saw several chances missed until the 78th minute when Uzbek striker Igor Shkvyrin took advantage of a misdirected defensive header to lob the ball over the advancing Chinese keeper and into an empty net .\"\n\t# extracted named entities\n\tlocation = [ \"China\" ]\n\tperson = [ \"Igor Shkvyrin\" ]\n\tlocation = [ \"Uzbek\" ]\n'
# convert the prediction to the answers
print (data)
print (code)
predictions = converter.output_to_structure(data, code)
print ("output: \n")
print (predictions)
This code is another example of converting structured data into a Python function and back; compared with the previous version, the main difference is that entity strings are written into the prompt with surrounding quotes and stripped of them again when parsing. Its purpose is the same: extract named entities from the input text, express the result as a Python function, and then recover the structured entity data from that function.
First, the constructor of PLFuncPromptCreator loads the schema and map config files, which define the entity and relation types of the dataset.
Then, in structure_to_input, the structured input (the text plus its named-entity annotations) is turned into a Python function. In the generated prompts the function is named named_entity_extraction(input_text), where input_text carries the text to extract from. The function body contains a comment marking the extraction step and the extraction results themselves, embedded as Python variables such as person = [ "Steve" ] and organization = [ "Apple" ].
Next, output_to_structure performs the reverse operation: it parses the generated Python code for the entity-list variables and assembles them into a structured record containing each entity's type, text, and offset position.
Finally, the example supplies a sample record together with a code string in the format produced above, and calls output_to_structure on them to verify that the reverse operation successfully recovers the entity information.
Overall this is a compact end-to-end demonstration of converting structured data into Python-function form and back, which is the NER pipeline this paper relies on.
# NER tasks
from src.converters.ner.structure2pl_func_v5 import PLFuncPromptCreator as NERPLFuncPromptCreator
from src.converters.ner.structure2nl_sel_v2 import NLSELPromptCreator as NERNLSELPromptCreator
# RE tasks
from src.converters.re.structure2pl_func_v5 import PLFuncPromptCreator as REPLFuncPromptCreator
from src.converters.re.structure2nl_sel_v2 import NLSELPromptCreator as RENLSELPromptCreator
class ConverterFactory:
converter_to_class = {
# ner
"ner-pl-func": NERPLFuncPromptCreator,
"ner-nl-sel": NERNLSELPromptCreator,
# re
"re-pl-func": REPLFuncPromptCreator,
"re-nl-sel": RENLSELPromptCreator
}
supported_converters = list(converter_to_class.keys())
@staticmethod
def get_converter(job_type: str, **kwargs):
if job_type not in ConverterFactory.supported_converters:
raise ValueError(f"Unsupported job type: {job_type}")
return ConverterFactory.converter_to_class[job_type](**kwargs)
这段代码是一个名为 ConverterFactory
的类,它用于创建不同类型的文本结构到Python函数的转换器。这些转换器是为命名实体识别(NER)和关系抽取(RE)等自然语言处理任务设计的。以下是一些关键细节:
ConverterFactory
类包括以下子类转换器,每个子类用于不同的任务:
NERPLFuncPromptCreator
用于生成Python函数的NER任务。NERNLSELPromptCreator
用于生成NL-Sel(自然语言到结构化查询语言)的NER任务。REPLFuncPromptCreator
用于生成Python函数的RE任务。RENLSELPromptCreator
用于生成NL-Sel的RE任务。ConverterFactory
类维护了一个名为 converter_to_class
的字典,该字典将任务类型(例如:“ner-pl-func”)映射到相应的转换器类。这使得根据任务类型选择正确的转换器变得非常方便。
supported_converters
列表包含了所有受支持的任务类型。通过检查任务类型是否包含在此列表中,您可以验证是否支持所请求的任务类型。
get_converter
方法是工厂的核心方法,用于根据任务类型返回适当的转换器实例。如果请求的任务类型不受支持,它会引发一个值错误(ValueError
)。
总之,ConverterFactory
类提供了一个通用的接口,用于根据任务类型选择适当的转换器。这使得在不同的自然语言处理任务中,使用不同类型的文本结构转换变得更加方便和模块化。
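This is a minimal sketch, assuming the repository's data and config files (the same ones used in the earlier example) are available:

from src.converters.get_converter import ConverterFactory

print(ConverterFactory.supported_converters)
# ['ner-pl-func', 'ner-nl-sel', 're-pl-func', 're-nl-sel']

converter = ConverterFactory.get_converter(
    job_type="ner-pl-func",                                     # code-style prompt for NER
    schema_folder="data/conll03",                                # dataset schema folder
    map_config_path="config/offset_map/first_offset_en.yaml",   # offset-mapping strategy
)
# the returned object is a NERPLFuncPromptCreator: structure_to_input() builds the
# code-style prompt for one sample, output_to_structure() parses a model completion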
The next piece of code defines the classes used for mapping predicted strings back onto token offsets. Concretely, it contains the following parts:
MapConfig class: holds the mapping configuration, including map_strategy (which matching strategy to use), de_duplicate (whether to drop duplicate matches), and span_to_token (how to split a text span into tokens).
Record class: a base class from which the concrete record classes (EntityRecord, RelationRecord, EventRecord) inherit. It provides a span_to_token method that splits a text span into tokens according to the configured strategy.
EntityRecord class: converts generated strings into entity records of the form <type, span>. Its methods turn a list of generated records into entity information and map that information onto token offsets in the text.
RelationRecord class: converts generated strings into relation records of the form <relation type, arg1 type, arg1 span, arg2 type, arg2 span>, with methods to build the relation information and map it onto token offsets.
EventRecord class: converts generated strings into event records (event type, trigger span, role information), again with methods to build the event information and map it onto token offsets.
In short, this code handles the offset-mapping step of entity, relation, and event extraction, offering several matching strategies ('first', 'closest', 'longer_first') and the utility functions they need.
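Before the full source, a small sketch of the entity mapping in isolation. It assumes the code below lives at src/converters/record.py (its imports suggest so; this path is an assumption), and the sentence and predictions are made up:

# assumption: the module path of the code shown below
from src.converters.record import EntityRecord

# the to_offset code below indexes the config like a dict, so a plain dict works here
map_config = {'map_strategy': 'first', 'de_duplicate': True, 'span_to_token': 'space'}
record_map = EntityRecord(map_config=map_config)

tokens = ["Steve", "became", "CEO", "of", "Apple", "in", "1998"]
predicted = [{'type': 'person', 'text': 'Steve'}, {'type': 'organization', 'text': 'Apple'}]

print(EntityRecord.to_string(predicted))        # [('person', 'Steve'), ('organization', 'Apple')]
print(record_map.to_offset(predicted, tokens))  # [('person', (0,)), ('organization', (4,))]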
#!/usr/bin/env python
# -*- coding:utf-8 -*-
from asyncio.log import logger
import numpy
from src.converters.utils import span_to_token, match_sublist, check_overlap, get_index_tuple
import logging
logger = logging.getLogger("__main__")
class MapConfig:
def __init__(self,
map_strategy: str = 'first',
de_duplicate: bool = True,
span_to_token: str = 'space') -> None:
self.map_strategy = map_strategy
self.de_duplicate = de_duplicate
self.span_to_token = span_to_token
def __repr__(self) -> str:
repr_list = [
f"map_strategy: {self.map_strategy}",
f"de_duplicate: {self.de_duplicate}",
f"span_to_token: {self.span_to_token}",
]
return ', '.join(repr_list)
@staticmethod
def load_from_yaml(config_file):
import yaml
with open(config_file) as fin:
config = yaml.load(fin, Loader=yaml.FullLoader)
return MapConfig(
map_strategy=config['map_strategy'],
de_duplicate=config['de_duplicate'],
span_to_token=config['span_to_token'],
)
class Record:
def __init__(self, map_config) -> None:
self._map_config = map_config
def span_to_token(self, text):
return span_to_token(text, span_to_token_strategy=self._map_config['span_to_token'])
class EntityRecord(Record):
""" Record for converting generated string to information record <type, span>
"""
@staticmethod
def to_string(pred_record_list):
entity_list = list()
for pred_record in pred_record_list:
record_type, record_text = pred_record['type'], pred_record['text']
if record_text == "":
logger.warning(f"Empty Extraction {pred_record}")
continue
entity_list += [(record_type, record_text)]
return entity_list
def to_offset(self, instance, tokens):
map_strategy_dict = {
'first': self.record_to_offset_first_role,
'closest': self.record_to_offset_closest_role,
'longer_first': self.record_to_offset_longer_first,
}
if self._map_config['map_strategy'] in map_strategy_dict:
map_function = map_strategy_dict[self._map_config['map_strategy']]
return map_function(
instance=instance,
token_list=tokens, )
else:
raise NotImplementedError(
f"The map strategy {self._map_config.map_strategy} in {self.__class__} is not implemented."
)
def record_to_offset_closest_role(
self,
instance,
token_list, ):
"""
Find Role's offset using closest matched with trigger word.
:param instance:
:return:
"""
return self.record_to_offset_first_role(instance, token_list=token_list)
def record_to_offset_first_role(self, instance, token_list):
"""
Find Entity's offset using first matched in the sentence.
:param instance:
:return:
"""
entity_list = list()
entity_matched_set = set()
for pred_record in instance:
record_type, record_text = pred_record['type'], pred_record['text']
if record_text == "":
logger.warning(f"Empty Extraction {pred_record}")
continue
matched_list = match_sublist(token_list,
self.span_to_token(record_text))
for matched in matched_list:
if (record_type, matched) not in entity_matched_set:
entity_list += [(record_type,
tuple(range(matched[0], matched[1] + 1)))]
entity_matched_set.add((record_type, matched))
break
return entity_list
def record_to_offset_longer_first(self, instance, token_list):
"""
Find Entity's offset using first matched in the sentence.
:param instance:
:return:
"""
entity_list = list()
entity_matched_set = set()
for x in instance:
x['length'] = len(x['text'])
instance.sort(reverse=True, key=lambda x: x['length'])
for pred_record in instance:
record_type, record_text = pred_record['type'], pred_record['text']
if record_text == "":
logger.warning(f"Empty Extraction {pred_record}")
continue
matched_list = match_sublist(token_list,
self.span_to_token(record_text))
for matched in matched_list:
flag = False
for _, g in entity_matched_set:
if check_overlap(g, matched):
flag = True
if flag:
continue
if (record_type, matched) not in entity_matched_set:
entity_list += [(record_type,
tuple(range(matched[0], matched[1] + 1)))]
entity_matched_set.add((record_type, matched))
break
return entity_list
class RelationRecord(Record):
""" Record for converting generated string to information record
<type, arg1_type, arg1_span, arg2_type, arg2_span>
"""
def to_offset(self, instance, tokens):
map_strategy_dict = {
'first': self.record_to_offset_first_role,
'closest': self.record_to_offset_closest_role,
'longer_first': self.record_to_offset_closest_role,
}
if self._map_config['map_strategy'] in map_strategy_dict:
map_function = map_strategy_dict[self._map_config['map_strategy']]
return map_function(
instance=instance,
token_list=tokens, )
else:
raise NotImplementedError(
f"The map strategy {self._map_config['map_strategy']} in {self.__class__} is not implemented."
)
@staticmethod
def to_string(instance):
relation_list = list()
for record in instance:
relation_type = record['type']
relation = [relation_type]
if len(record['roles']) < 2:
continue
for role_type, text_str in record['roles'][:2]:
relation += [role_type, text_str]
relation_list += [tuple(relation)]
return relation_list
def record_to_offset_first_role(self, instance, token_list):
"""
Find Role's offset using first matched in the sentence.
:param instance:
:return:
"""
relation_list = list()
for record in instance:
relation_type = record['type']
if len(record['roles']) < 2:
continue
relation = [relation_type]
for role_type, text_str in record['roles'][:2]:
matched_list = match_sublist(token_list,
self.span_to_token(text_str))
if len(matched_list) == 0:
logger.warning("[Cannot reconstruct]: %s %s\n" %
(text_str, token_list))
break
relation += [role_type, get_index_tuple(matched_list[0])]
if len(relation) != 5 or (self._map_config.de_duplicate and
tuple(relation) in relation_list):
continue
relation_list += [tuple(relation)]
return relation_list
def record_to_offset_closest_role(self, instance, token_list):
"""
Find Role's offset using closest matched with trigger word.
:param instance:
:return:
"""
relation_list = list()
for record in instance:
relation_type = record['type']
if len(record['roles']) < 2:
continue
arg1_type, arg1_text = record['roles'][0]
arg2_type, arg2_text = record['roles'][1]
arg1_matched_list = match_sublist(token_list,
self.span_to_token(arg1_text))
if len(arg1_matched_list) == 0:
logger.warning("[Retry]: %s %s\n" %
(arg1_text, token_list))
arg1_matched_list = match_sublist(token_list,
self.span_to_token(arg1_text + '.'))
arg2_matched_list = match_sublist(token_list,
self.span_to_token(arg2_text))
if len(arg2_matched_list) == 0:
logger.warning("[Retry]: %s %s\n" %
(arg2_text, token_list))
arg2_matched_list = match_sublist(token_list,
self.span_to_token(arg2_text + '.'))
if len(arg1_matched_list) == 0:
logger.warning("[Cannot reconstruct]: %s %s\n" %
(arg1_text, token_list))
break
if len(arg2_matched_list) == 0:
logger.warning("[Cannot reconstruct]: %s %s\n" %
(arg2_text, token_list))
break
distance_tuple = list()
for arg1_match in arg1_matched_list:
for arg2_match in arg2_matched_list:
distance = abs(arg1_match[0] - arg2_match[0])
distance_tuple += [(distance, arg1_match, arg2_match)]
distance_tuple.sort()
relation = [
relation_type,
arg1_type,
get_index_tuple(distance_tuple[0][1]),
arg2_type,
get_index_tuple(distance_tuple[0][2]),
]
if self._map_config['de_duplicate'] and tuple(
relation) in relation_list:
continue
relation_list += [tuple(relation)]
return relation_list
class EventRecord(Record):
""" Record for converting generated string to information record in predicate-arguments
{
type: pred_type,
trigger: predicate_span,
args: [(arg_type, arg_span), ...]
}
"""
def to_offset(self, instance, tokens):
map_strategy_dict = {
'first': self.record_to_offset_first_role,
'closest': self.record_to_offset_closest_role,
'longer_first': self.record_to_offset_closest_role,
}
if self._map_config.map_strategy in map_strategy_dict:
map_function = map_strategy_dict[self._map_config.map_strategy]
return map_function(
instance=instance,
token_list=tokens, )
else:
raise NotImplementedError(
f"The map strategy {self._map_config.map_strategy} in {self.__class__} is not implemented."
)
@staticmethod
def to_string(instance):
"""
{'type': 'Justice:Appeal',
'trigger': 'appeal',
'roles': [
('Adjudicator', 'court'),
('Plaintiff', 'Anwar')
], }
"""
return instance
def record_to_offset_first_role(self, instance, token_list):
"""
Find Role's offset using first matched in the sentence.
"""
record_list = list()
trigger_matched_set = set()
for record in instance:
event_type = record['type']
trigger = record['trigger']
matched_list = match_sublist(token_list,
self.span_to_token(trigger))
if len(matched_list) == 0:
logger.warning("[Cannot reconstruct]: %s %s\n" %
(trigger, token_list))
continue
trigger_offset = None
for matched in matched_list:
if matched not in trigger_matched_set:
trigger_offset = get_index_tuple(matched)
trigger_matched_set.add(matched)
break
# No trigger word, skip the record
if trigger_offset is None:
break
pred_record = {
'type': event_type,
'roles': [],
'trigger': trigger_offset
}
for role_type, text_str in record['roles']:
matched_list = match_sublist(token_list,
self.span_to_token(text_str))
if len(matched_list) == 0:
logger.warning("[Cannot reconstruct]: %s %s\n" %
(text_str, token_list))
continue
pred_record['roles'] += [(role_type,
get_index_tuple(matched_list[0]))]
record_list += [pred_record]
return record_list
def record_to_offset_closest_role(self, instance, token_list):
"""
Find Role's offset using closest matched with trigger word.
"""
record_list = list()
trigger_matched_set = set()
for record in instance:
event_type = record['type']
trigger = record['trigger']
matched_list = match_sublist(token_list,
self.span_to_token(trigger))
if len(matched_list) == 0:
logger.warning("[Cannot reconstruct]: %s %s\n" %
(trigger, token_list))
continue
trigger_offset = None
for matched in matched_list:
if matched not in trigger_matched_set:
trigger_offset = get_index_tuple(matched)
trigger_matched_set.add(matched)
break
# No trigger word, skip the record
if trigger_offset is None or len(trigger_offset) == 0:
break
pred_record = {
'type': event_type,
'roles': [],
'trigger': trigger_offset
}
for role_type, text_str in record['roles']:
matched_list = match_sublist(token_list,
self.span_to_token(text_str))
if len(matched_list) == 0:
logger.warning("[Cannot reconstruct]: %s %s\n" %
(text_str, token_list))
else:
abs_distances = [
abs(match[0] - trigger_offset[0])
for match in matched_list
]
closest_index = numpy.argmin(abs_distances)
pred_record['roles'] += [(
role_type,
get_index_tuple(matched_list[closest_index]))]
record_list += [pred_record]
return record_list
class StructureConverter(object):
def structure_to_input(self, input_dict: dict, prompt_part_only: bool = False) -> str:
raise NotImplementedError()
def output_to_structure(self, input_dict: dict, output_str: str):
raise NotImplementedError()
@staticmethod
def to_function_head(s,input=''):
return f'def {s}({input}):'
@staticmethod
def to_function_name(s):
s = s.replace(".", "").replace(",", "")
# remove DT
tok = s.lower().split()
tok = [x for x in tok if x not in ['the', 'a', 'an']]
return '_'.join(tok)
@staticmethod
def list_to_str(l):
# remove \n
l = [x.replace("\n", " ") if x != '\n' else '' for x in l]
l = '\n'.join(l)
return l
This is a Python class named StructureConverter. It is an abstract base class that defines the interface of a structure converter, and it contains the following methods:
structure_to_input(self, input_dict: dict, prompt_part_only: bool = False) -> str: an abstract method that converts an input record into a text prompt; derived classes implement the task-specific conversion logic.
output_to_structure(self, input_dict: dict, output_str: str): another abstract method that converts the output text back into a data structure; as with structure_to_input, the implementation is left to the derived classes.
to_function_head(s, input=''): a static method that builds the header of a Python function, where s is the function name and input is the parameter list; it returns the function-definition line as a string.
to_function_name(s): a static method that turns a string s into a legal Python function name: it strips punctuation, lower-cases and splits the words, drops the determiners 'the', 'a', 'an', and joins the remaining words with underscores.
list_to_str(l): a static method that joins a list of strings l into one multi-line string; newlines inside an item are replaced by spaces before the items are joined with newlines.
StructureConverter itself is not meant to be instantiated directly. It provides the shared interface and helper methods that the concrete converter classes (such as PLFuncPromptCreator) inherit and implement for their specific tasks.
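The three static helpers are easy to check in isolation; a quick sketch:

from src.converters.structure_converter import StructureConverter

name = StructureConverter.to_function_name("named entity extraction")
print(name)                                                      # named_entity_extraction
print(StructureConverter.to_function_head(name, input='input_text'))
# def named_entity_extraction(input_text):

lines = [
    StructureConverter.to_function_head(name, input='input_text'),
    '\t""" extract named entities from the input_text . """',
    '\tperson = [ "Steve" ]',
]
print(StructureConverter.list_to_str(lines))   # the lines joined with newlines, i.e. the prompt body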
import re
import sys
from typing import Tuple
def match_sublist(the_list, to_match):
"""
:param the_list: [1, 2, 3, 4, 5, 6, 1, 2, 4, 5]
:param to_match: [1, 2]
:return:
[(0, 1), (6, 7)]
"""
len_to_match = len(to_match)
matched_list = list()
for index in range(len(the_list) - len_to_match + 1):
if to_match == the_list[index:index + len_to_match]:
matched_list += [(index, index + len_to_match - 1)]
return matched_list
def check_overlap(x, y):
if x[0] > y[1] or y[0] > x[1]:
return False
else:
return True
def get_index_tuple(matched: Tuple[int, int]):
return tuple(range(matched[0], matched[1] + 1))
def span_to_token(text, span_to_token_strategy='space'):
if span_to_token_strategy == 'space':
return text.split(' ')
elif span_to_token_strategy == 'list':
return list(text)
else:
raise NotImplementedError(
f"The span to token strategy {span_to_token_strategy} is not implemented.")
def to_camel_case(title: str) -> str:
"""Converts a proscript title to a camel case string.
Example:
title: travel to the theme park
camel_case: TravelToThemePark
"""
if "." == title[-1]:
title = title[:-1]
title_tokens = title.split(" ")
title_camel_case = ""
for token in title_tokens:
title_camel_case += token.capitalize()
return title_camel_case
def to_snake_case(name):
# replace all space and punctuation with underscore
if name[-1] == ".":
name = name[:-1]
name = re.sub(r'[\s\W]', '_', name)
return name.lower().strip()
def from_snake_to_normal_str(snake_str: str) -> str:
"""Converts a snake case string to a normal string.
Example:
snake_str: travel_to_the_theme_park
normal_str: travel to the theme park
"""
return " ".join(snake_str.split("_"))
def compile_code_get_object(py_code_str: str):
"""Given python code as a string, compiles it
and returns an object of the class contained in the string.
Args:
code (str): _description_
"""
# compile the code
try:
py_code = compile(py_code_str, "<string>", "exec")
except SyntaxError:
# try without the last k lines in py_code_str: usually the last line is incomplete
for k in range(1, 3):
try:
lines = py_code_str.split("\n")
lines = "\n".join(lines[:-k])
py_code = compile(lines, "<string>", "exec")
except SyntaxError as e:
print(f"Error compiling python code:\n{py_code_str}")
raise e
# instantiate the class
py_code_dict = {}
exec(py_code, py_code_dict)
# the newly instantiated class will be last in the scope
py_code_class = py_code_dict[list(py_code_dict.keys())[-1]]()
return py_code_class
These functions are a set of utilities for string handling, list matching, and code compilation. In brief:
match_sublist(the_list, to_match): finds every occurrence of the sub-list to_match inside the_list and returns their index ranges. For example, with the_list = [1, 2, 3, 4, 5, 6, 1, 2, 4, 5] and to_match = [1, 2] it returns [(0, 1), (6, 7)], the start and end index of each match.
check_overlap(x, y): checks whether two index ranges x and y overlap; returns True if they do and False otherwise.
get_index_tuple(matched: Tuple[int, int]): expands a matched index range into a tuple containing every index in the range.
span_to_token(text, span_to_token_strategy='space'): splits text into a list of words or characters according to the chosen strategy; the default 'space' strategy splits on spaces, while 'list' splits the text into individual characters.
to_camel_case(title: str): converts a title string into CamelCase by capitalising each token and concatenating them.
to_snake_case(name): converts a string into snake_case, replacing spaces and punctuation with underscores and lower-casing the result.
from_snake_to_normal_str(snake_str: str): converts a snake_case string back into a normal string by replacing the underscores with spaces.
compile_code_get_object(py_code_str: str): compiles the given Python code string and returns an instance of the class it contains. It first compiles the code (retrying without the last one or two lines if the final line is incomplete), then executes it and instantiates the last object defined in the resulting namespace.
Together these helpers cover the text processing, naming-convention conversion, and code execution needed elsewhere in the pipeline.
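A few of these helpers exercised directly (assuming they live in src/converters/utils.py, as the imports earlier suggest):

from src.converters.utils import match_sublist, check_overlap, get_index_tuple, to_camel_case, to_snake_case

tokens = ["West", "Indian", "all-rounder", "Phil", "Simmons"]
print(match_sublist(tokens, ["Phil", "Simmons"]))   # [(3, 4)]
print(get_index_tuple((3, 4)))                      # (3, 4)
print(check_overlap((0, 1), (1, 2)))                # True, the ranges share index 1
print(to_camel_case("travel to the theme park"))    # TravelToTheThemePark (every token is capitalised)
print(to_snake_case("travel to the theme park."))   # travel_to_the_theme_park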
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import argparse
import json
import os
import sys
import numpy as np
from pprint import pprint
from src.eval.scorer import EntityScorer, RelationScorer, EventScorer
def read_file(file_name):
return [line for line in open(file_name).readlines()]
def write_to_file(result, output_filename, prefix=None):
with open(output_filename, 'w') as output:
for key, value in result.items():
if prefix:
key = '%s_%s' % (prefix, key)
output.write("%s=%s\n" % (key, value))
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-g', dest='gold_folder', help="Golden Dataset folder")
parser.add_argument('-gf', dest='gold_file', help="Golden Dataset File")
parser.add_argument('-p', dest='pred_folder', nargs='+', help="Predicted model folder")
parser.add_argument('-pf', dest='pred_file', help="Predicted model file")
parser.add_argument('-sf', dest='saved_file', help="Saved result file")
parser.add_argument('-v', dest='verbose', action='store_true', help='Show more information during running')
parser.add_argument('-w', dest='write_to_file', action='store_true', help="Write evaluation results to predicted folder")
parser.add_argument('-m', dest='match_mode', default='normal', choices=['set', 'normal', 'multimatch'])
parser.add_argument('-case', dest='case', action='store_true', help='Show case study')
options = parser.parse_args()
data_dict = {
'test': [options.pred_file, options.gold_file],
}
task_dict = {
'entity': EntityScorer,
'relation': RelationScorer,
'event': EventScorer,
}
result_list = {'eval': list(), 'test': list()}
for pred_folder in options.pred_folder:
gold_folder = options.gold_folder
for data_key, (generation, gold_file) in data_dict.items():
gold_filename = os.path.join(gold_folder, gold_file)
pred_filename = os.path.join(pred_folder, generation)
if not os.path.exists(pred_filename):
sys.stderr.write("%s not found.\n" % pred_filename)
continue
print("pred:", pred_filename)
print("gold:", gold_filename)
if options.case:
for pred_line, gold_line in zip(read_file(pred_filename), read_file(gold_filename)):
gold_instance = json.loads(gold_line)
pred_instance = json.loads(pred_line)
print('=========================')
print(gold_instance['text'])
for task in task_dict:
scorer = task_dict[task]
gold = scorer.load_gold_list([gold_instance[task]])[0]
pred = scorer.load_pred_list([pred_instance[task]])[0]
min_length = max(
len(gold['string']),
len(pred['string']),
len(gold.get('string_trigger', [])),
len(pred.get('string_trigger', [])),
len(gold.get('string_role', [])),
len(pred.get('string_role', [])),
)
if min_length == 0:
continue
if task == 'entity':
print("Entity Gold:", sorted(gold['string']))
print("Entity Pred:", sorted(pred['string']))
if task == 'relation':
print("Relation Gold:", sorted(gold['string']))
print("Relation Pred:", sorted(pred['string']))
if task == 'event':
print("Event Gold Trigger:", sorted(gold['string_trigger']))
print("Event Pred Trigger:", sorted(pred['string_trigger']))
print("Event Gold Role :", sorted(gold['string_role']))
print("Event Pred Role :", sorted(pred['string_role']))
results = dict()
for task in task_dict:
if task not in json.loads(read_file(pred_filename)[0]):
continue
scorer = task_dict[task]
gold_list = [json.loads(line)[task] for line in read_file(gold_filename)]
pred_list = [json.loads(line)[task] for line in read_file(pred_filename)]
########## 23-01-07
ill_formed = [json.loads(line)['statistic']['ill-formed'] for line in read_file(pred_filename)]
assert len(pred_list) == len(gold_list)
gold_instance_list = scorer.load_gold_list(gold_list)
pred_instance_list = scorer.load_pred_list(pred_list)
assert len(pred_instance_list) == len(gold_instance_list)
sub_results = scorer.eval_instance_list(
gold_instance_list=gold_instance_list,
pred_instance_list=pred_instance_list,
verbose=options.verbose,
match_mode=options.match_mode,
)
results.update(sub_results)
result_list[data_key] += [results]
if options.write_to_file:
output_filename = "%s/%s" % (pred_folder, options.saved_file)
write_to_file(
result=results,
output_filename=output_filename,
prefix=data_key,
)
if __name__ == "__main__":
main()
This code is a command-line tool that compares a model's predictions against the gold standard. Concretely, it:
Reads, from command-line arguments, the predicted results and the golden dataset to be evaluated.
Supports evaluating the different tasks (entity, relation, event).
Supports several match modes, namely "set", "normal", and "multimatch".
Can print extra information during evaluation for more detailed analysis.
Can write the evaluation results into a file inside the prediction folder.
Its main steps are:
Read the prediction file and the gold-standard file.
Parse both files and extract the records of each task (entity, relation, event).
Call the scorer classes (EntityScorer, RelationScorer, EventScorer) to evaluate the predictions against the gold annotations.
Compute the evaluation metrics, including precision, recall, and F1, and print them to the console or save them to a file.
If the "case" option is set, additionally print a per-sentence case study that lists the gold and predicted entities, relations, and event triggers/roles side by side.
This tool is the final scoring step of the pipeline: it tells researchers and practitioners how well the model performs and where it can be improved.
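As an illustration, a hypothetical invocation of this tool; the script name, folder names, and file names below are made up and depend on how the predictions were saved:

import subprocess

subprocess.run([
    "python", "eval_extraction.py",   # assumed name for the script shown above
    "-g", "data/conll03",             # golden dataset folder
    "-gf", "test.json",               # golden dataset file inside that folder
    "-p", "output/conll03_run1",      # folder(s) containing the post-processed predictions
    "-pf", "pred.json",               # prediction file inside each folder
    "-m", "normal",                   # match mode
    "-w", "-sf", "results.txt",       # write the metrics to output/conll03_run1/results.txt
], check=True)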
import argparse
import json
import random
from tqdm import tqdm
import subprocess
from src.converters.get_converter import ConverterFactory
import pandas as pd
def eval(src_file, pred_file, save_file, job_type,
schema_path, map_config_path, pred_key='generated_code'):
src_d = pd.read_json(src_file, orient='records', lines=True)
with open(pred_file, 'r') as f:
pred_d = []
for line in f:
data = json.loads(line.strip())
pred_d.append(data)
converter = ConverterFactory.get_converter(job_type=job_type, schema_folder=schema_path, map_config_path=map_config_path)
prediction_list = []
ill_formed = 0
invalid_label = 0
invalid_text_span = 0
invalid_label_asoc = 0
invalid_text_span_asoc = 0
for qid, src_data in tqdm(src_d.iterrows(), total=len(src_d)):
pred = pred_d[qid]
predictions = converter.output_to_structure(src_data, pred[pred_key])
if predictions['statistic']['ill-formed'] == True:
ill_formed += 1
if predictions['statistic']['Invalid-Label'] == True:
invalid_label += 1
if predictions['statistic']['Invalid-Text-Span'] == True:
invalid_text_span += 1
if predictions['statistic']['Invalid-Label-asoc'] == True:
invalid_label_asoc += 1
if predictions['statistic']['Invalid-Text-Span-asoc'] == True:
invalid_text_span_asoc += 1
prediction_list.append(predictions)
pd.DataFrame(prediction_list).to_json(save_file, orient='records', lines=True)
print ("ill_formed number: ", ill_formed)
print ("Invalid-Label number: ", invalid_label)
print ("Invalid-Text-Span number: ", invalid_text_span)
print ("Invalid-Label-asoc number: ", invalid_label_asoc)
print ("Invalid-Text-Span-asoc number: ", invalid_text_span_asoc)
def config():
parser = argparse.ArgumentParser()
parser.add_argument('--raw_output_file', type=str)
parser.add_argument('--output_file', type=str)
parser.add_argument('--src_file', type=str)
parser.add_argument('--job_type', type=str)
parser.add_argument("--schema_path", type=str, required=True)
parser.add_argument("--map_config_path", type=str, required=True)
parser.add_argument("--pred_key", type=str, default='generated_code')
args = parser.parse_args()
return args
if __name__ == "__main__":
args = config()
save_file = args.output_file
eval(args.src_file, args.raw_output_file,
save_file, args.job_type,
args.schema_path, args.map_config_path, args.pred_key)
This code post-processes the raw model generations of an IE task and saves the structured results to the specified output file. Its main functions are:
Read the input files and configuration from command-line arguments, including the raw output file (raw_output_file), the output file (output_file), the source data file (src_file), the task type (job_type), the schema path (schema_path), the map config path (map_config_path), and the prediction key (pred_key).
Create the appropriate converter (for NER or RE) through the ConverterFactory; the converter turns each raw generation into a structured record.
Iterate over every example in the source file and convert that example's raw generation into structured form.
Track several error statistics along the way: ill-formed outputs, invalid labels (Invalid-Label), invalid text spans (Invalid-Text-Span), invalid associated labels (Invalid-Label-asoc), and invalid associated text spans (Invalid-Text-Span-asoc).
Count how many predictions fall into each of these error categories.
Finally, save the converted records as JSON lines to the output file and print the error statistics to the console.
This script automates the conversion and error accounting between the raw generations and the structured predictions that the scorer consumes, helping researchers and practitioners understand model behaviour for further analysis and improvement.
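Within the same script, the post-processing step can also be called directly; a sketch with hypothetical file names (the prediction file must contain one JSON object per line with the raw completion stored under the generated_code key):

# note: `eval` here is the script's own post-processing routine defined above,
# which shadows Python's built-in eval
eval(
    src_file="data/conll03/test.json",
    pred_file="output/conll03_run1/raw_generations.json",
    save_file="output/conll03_run1/pred.json",
    job_type="ner-pl-func",
    schema_path="data/conll03",
    map_config_path="config/offset_map/first_offset_en.yaml",
    pred_key="generated_code",
)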
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# adapted from https://raw.githubusercontent.com/universal-ie/UIE/main/uie/extraction/scorer.py
from collections import defaultdict
from copy import deepcopy
from typing import Dict, List
import sys
def tuple_offset(offset):
if isinstance(offset, tuple):
return offset
else:
return tuple(offset)
class Metric:
""" Tuple Metric """
def __init__(self, verbose=False, match_mode='normal'):
self.tp = 0.
self.gold_num = 0.
self.pred_num = 0.
self.verbose = verbose
self.match_mode = match_mode
assert self.match_mode in {'set', 'normal', 'multimatch'}
def __repr__(self) -> str:
return f"tp: {self.tp}, gold: {self.gold_num}, pred: {self.pred_num}"
@staticmethod
def safe_div(a, b):
if b == 0.:
return 0.
else:
return a / b
def compute_f1(self, prefix=''):
tp = self.tp
pred_num = self.pred_num
gold_num = self.gold_num
p, r = self.safe_div(tp, pred_num), self.safe_div(tp, gold_num)
return {prefix + 'tp': tp,
prefix + 'gold': gold_num,
prefix + 'pred': pred_num,
prefix + 'P': p * 100,
prefix + 'R': r * 100,
prefix + 'F1': self.safe_div(2 * p * r, p + r) * 100
}
def count_instance(self, gold_list, pred_list):
if self.match_mode == 'set':
gold_list = set(gold_list)
pred_list = set(pred_list)
if self.verbose:
print("Gold:", gold_list)
print("Pred:", pred_list)
self.gold_num += len(gold_list)
self.pred_num += len(pred_list)
self.tp += len(gold_list & pred_list)
else:
if self.verbose:
print("Gold:", gold_list)
print("Pred:", pred_list)
self.gold_num += len(gold_list)
self.pred_num += len(pred_list)
if len(gold_list) > 0 and len(pred_list) > 0:
# guarantee length same
assert len(gold_list[0]) == len(pred_list[0])
dup_gold_list = deepcopy(gold_list)
for pred in pred_list:
if pred in dup_gold_list:
self.tp += 1
if self.match_mode == 'normal':
# Each Gold Instance can be matched one time
dup_gold_list.remove(pred)
def count_batch_instance(self, batch_gold_list, batch_pred_list):
for gold_list, pred_list in zip(batch_gold_list, batch_pred_list):
self.count_instance(gold_list=gold_list, pred_list=pred_list)
class RecordMetric(Metric):
""" 不考虑不同 Role 之间的顺序,例如事件论元"""
@staticmethod
def is_equal(gold, pred):
if gold['type'] != pred['type']:
return False
if gold['spot'] != pred['spot']:
return False
if len(gold['asocs']) != len(pred['asocs']):
return False
for gold_role, pred_role in zip(sorted(gold['asocs']), sorted(pred['asocs'])):
if gold_role != pred_role:
return False
return True
def count_instance(self, gold_list, pred_list):
if self.match_mode == 'set':
raise NotImplementedError(f'{self.__class__.__name__} do not support the match model `set`')
if self.verbose:
print("Gold:", gold_list)
print("Pred:", pred_list)
self.gold_num += len(gold_list)
self.pred_num += len(pred_list)
gold_indexes = list(range(len(gold_list)))
non_found = [True] * len(gold_list)
for pred in pred_list:
for gold_index in gold_indexes:
if non_found[gold_index] and self.is_equal(gold_list[gold_index], pred):
self.tp += 1
non_found[gold_index] = False
if self.match_mode == 'normal':
break
class OrderedRecordMetric(RecordMetric):
""" 考虑不同 Role 之间的顺序,例如关系 """
@staticmethod
def is_equal(gold, pred):
if gold['type'] != pred['type']:
return False
if gold['spot'] != pred['spot']:
return False
if len(gold['asocs']) != len(pred['asocs']):
return False
for gold_role, pred_role in zip(gold['asocs'], pred['asocs']):
if gold_role != pred_role:
return False
return True
def warning_tp_increment(gold, pred, prefix):
sys.stderr.write(f"{prefix} TP Increment Warning, Gold Offset: {gold['offset']}\n")
sys.stderr.write(f"{prefix} TP Increment Warning, Pred Offset: {pred['offset']}\n")
sys.stderr.write(f"{prefix} TP Increment Warning, Gold String: {gold['string']}\n")
sys.stderr.write(f"{prefix} TP Increment Warning, Pred String: {pred['string']}\n")
sys.stderr.write(f"===============\n")
class Scorer:
@staticmethod
def load_gold_list(gold_list, offset_key=None):
raise NotImplementedError
@staticmethod
def load_pred_list(pred_list):
raise NotImplementedError
@staticmethod
def eval_instance_list(gold_instance_list, pred_instance_list, verbose=False, match_mode='normal'):
raise NotImplementedError
class EntityScorer(Scorer):
@staticmethod
def load_gold_list(gold_list: List[List[Dict]]):
""" Load gold instance to `string` and `offset`
Args:
gold_list (List[List[Dict]]): [description]
[
[
{'type': 'Geo-political', 'offset': [7], 'text': 'seattle'},
{'type': 'Location', 'offset': [11], 'text': 'lot'},
{'type': 'Geo-political', 'offset': [14], 'text': 'city'}
],
[...]
]
Returns:
List[Dict]: each instance has `offset` and `string`
[
{
'offset': [('Geo-political', (7,)), ('Location', (11,)), ('Geo-political', (14,))],
'string': [('Geo-political', 'seattle'), ('Location', 'lot'), ('Geo-political', 'city')]
},
{...}, ...
]
"""
gold_instance_list = []
for gold in gold_list:
gold_offset = list()
gold_string = list()
for span in gold:
span_label = span['type']
span_offset = span['offset']
span_text = span['text']
gold_offset += [(span_label, tuple_offset(span_offset))]
gold_string += [(span_label, span_text)]
gold_instance = {
'offset': gold_offset,
'string': gold_string,
}
gold_instance_list += [gold_instance]
return gold_instance_list
@staticmethod
def load_pred_list(pred_list: List[Dict]):
"""[summary]
Args:
pred_list (List[Dict]): [description]
[
{
'offset': [['Geo-political', [7]], ['Geo-political', [14]]],
'string': [['Geo-political', 'seattle'], ['Geo-political', 'city']]
},
{...},
]
Returns:
List[Dict] : each relation instance has `offset` and `string`
[
{
'offset': [('Geo-political', (7,)), ('Geo-political', (14,))],
'string': [('Geo-political', 'seattle'), ('Geo-political', 'city')]
}
]
"""
pred_instance_list = list()
for pred in pred_list:
for offset_pred in pred['offset']:
if not isinstance(offset_pred[1], tuple):
offset_pred[1] = tuple_offset(offset_pred[1])
pred['offset'] = [tuple_offset(p) for p in pred['offset']]
pred['string'] = [tuple_offset(p) for p in pred['string']]
pred_instance_list += [pred]
return pred_instance_list
@staticmethod
def eval_instance_list(gold_instance_list: List[Dict], pred_instance_list: List[Dict], verbose=False, match_mode='normal'):
"""[summary]
Args:
gold_instance_list (List[Dict]): [description]
[
{
'offset': [('Geo-political', (7,)), ('Location', (11,)), ('Geo-political', (14,))],
'string': [('Geo-political', 'seattle'), ('Location', 'lot'), ('Geo-political', 'city')]
},
{...}, ...
]
pred_instance_list (List[Dict]): [description]
[
{
'offset': [('Geo-political', (7,)), ('Geo-political', (14,))],
'string': [('Geo-political', 'seattle'), ('Geo-political', 'city')]
}
]
verbose (bool, optional): [description]. Defaults to False.
match_mode (string, optional): [description]. Defaults to `normal` .
Returns:
Dict: Result of Evaluation
(offset, string) X (gold, pred, tp, P, R, F1)
"""
metrics = {
'string': Metric(verbose=verbose, match_mode=match_mode),
'offset': Metric(verbose=verbose, match_mode=match_mode),
}
for pred, gold in zip(pred_instance_list, gold_instance_list):
pre_string_tp, pre_offset_tp = metrics['string'].tp, metrics['offset'].tp
for eval_key in metrics:
metrics[eval_key].count_instance(
gold_list=gold.get(eval_key, []),
pred_list=pred.get(eval_key, [])
)
post_string_tp, post_offset_tp = metrics['string'].tp, metrics['offset'].tp
if verbose and post_offset_tp - pre_offset_tp != post_string_tp - pre_string_tp:
warning_tp_increment(gold=gold, pred=pred, prefix='Entity')
results = dict()
for eval_key in metrics:
results.update(metrics[eval_key].compute_f1(prefix=eval_key + '-ent-'))
return results
class RelationScorer(Scorer):
@staticmethod
def load_gold_list(gold_list: List[List[Dict]]):
"""[summary]
Args:
gold_list (List[List[Dict]]): List of Sentence, each sentence contains a List of Relation Dict
[
[
{
'type': 'Part-whole',
'args': [{'type': 'Location', 'offset': [11], 'text': 'lot'}, {'type': 'Geo-political', 'offset': [14], 'text': 'city'}]
}, ...
],
[...],
]
Returns:
List[Dict]: List of Sentence, each sentence contains two List (offset, string) of Relation Tuple
[
{
'offset': [('Part-whole', 'Geo-political', (0,), 'Geo-political', (2,)), ... ],
'string': [('Part-whole', 'Geo-political', 'MULTAN', 'Geo-political', 'Pakistan'), ...]
}
]
"""
gold_instance_list = []
for gold in gold_list:
gold_instance = defaultdict(list)
for record in gold:
assert len(record['args']) == 2
gold_instance['offset'] += [(
record['type'],
record['args'][0]['type'],
tuple_offset(record['args'][0]['offset']),
record['args'][1]['type'],
tuple_offset(record['args'][1]['offset']),
)]
gold_instance['string'] += [(
record['type'],
record['args'][0]['type'],
record['args'][0]['text'],
record['args'][1]['type'],
record['args'][1]['text'],
)]
gold_instance_list += [gold_instance]
return gold_instance_list
@staticmethod
def load_pred_list(pred_list):
"""[summary]
Args:
pred_list (List[Dict]): List of Sentence, each sentence contains two List (offset, string) of Relation List
[
{
'offset': [['Part-whole', 'Geo-political', [0], 'Geo-political', [2]]],
'string': [['Part-whole', 'Geo-political', 'MULTAN', 'Geo-political', 'Pakistan']],
}, ...
]
Returns:
List[Dict]: List of Sentence, each sentence contains two List (offset, string) of Relation Tuple
[
{
'offset': [('Part-whole', 'Geo-political', (0,), 'Geo-political', (2,))],
'string': [('Part-whole', 'Geo-political', 'MULTAN', 'Geo-political', 'Pakistan')]
}, ...
]
"""
pred_instance_list = list()
for pred in pred_list:
for offset_pred in pred['offset']:
if not isinstance(offset_pred[2], tuple):
offset_pred[2] = tuple_offset(offset_pred[2])
if not isinstance(offset_pred[4], tuple):
offset_pred[4] = tuple_offset(offset_pred[4])
pred['offset'] = [tuple_offset(p) for p in pred['offset']]
pred['string'] = [tuple_offset(p) for p in pred['string']]
pred_instance_list += [pred]
return pred_instance_list
@staticmethod
def eval_instance_list(gold_instance_list, pred_instance_list, verbose=False, match_mode='normal'):
"""[summary]
Args:
gold_instance_list (List[Dict]): List of Sentence, each sentence contains two List (offset, string) of Relation Tuple
[
{
'offset': [('Part-whole', 'Geo-political', (0,), 'Geo-political', (2,)), ... ],
'string': [('Part-whole', 'Geo-political', 'MULTAN', 'Geo-political', 'Pakistan'), ...]
}
]
pred_instance_list ([type]): List of Sentence, each sentence contains two List (offset, string) of Relation Tuple
[
{
'offset': [('Part-whole', 'Geo-political', (0,), 'Geo-political', (2,))],
'string': [('Part-whole', 'Geo-political', 'MULTAN', 'Geo-political', 'Pakistan')]
}, ...
]
verbose (bool, optional): Defaults to False.
match_mode (string, optional): [description]. Defaults to `normal` .
Returns:
Dict: Result of Evaluation
(offset, string) X (boundary, strict) X (gold, pred, tp, P, R, F1)
"""
# Span Boundary and Type
metrics = {
'offset': Metric(verbose=verbose, match_mode=match_mode),
'string': Metric(verbose=verbose, match_mode=match_mode),
}
# Span Boundary Only
boundary_metrics = {
'offset': Metric(verbose=verbose, match_mode=match_mode),
'string': Metric(verbose=verbose, match_mode=match_mode),
}
for pred, gold in zip(pred_instance_list, gold_instance_list):
pre_string_tp, pre_offset_tp = metrics['string'].tp, metrics['offset'].tp
for eval_key in metrics:
# Span Boundary and Type
metrics[eval_key].count_instance(
gold_list=gold.get(eval_key, []),
pred_list=pred.get(eval_key, []),
)
post_string_tp, post_offset_tp = metrics['string'].tp, metrics['offset'].tp
if verbose and (post_offset_tp - pre_offset_tp != post_string_tp - pre_string_tp):
warning_tp_increment(gold=gold, pred=pred, prefix='Relation Strict')
pre_string_tp, pre_offset_tp = boundary_metrics['string'].tp, boundary_metrics['offset'].tp
for eval_key in boundary_metrics:
# Span Boundary Only
boundary_metrics[eval_key].count_instance(
gold_list=[(x[0], x[2], x[4]) for x in gold.get(eval_key, [])],
pred_list=[(x[0], x[2], x[4]) for x in pred.get(eval_key, [])],
)
post_string_tp, post_offset_tp = boundary_metrics['string'].tp, boundary_metrics['offset'].tp
if verbose and post_offset_tp - pre_offset_tp != post_string_tp - pre_string_tp:
warning_tp_increment(gold=gold, pred=pred, prefix='Relation Boundary')
results = dict()
for eval_key in metrics:
results.update(metrics[eval_key].compute_f1(prefix=eval_key + '-rel-strict-'))
for eval_key in boundary_metrics:
results.update(boundary_metrics[eval_key].compute_f1(prefix=eval_key + '-rel-boundary-'))
return results
class EventScorer(Scorer):
@staticmethod
def load_gold_list(gold_list):
"""[summary]
Args:
gold_list (List[List[Dict]]): List of Sentence, each sentence contains a List of Event Dict
[
[ # Sentence
{ # Event Record
'type': 'Die',
'offset': [16],
'text': 'shot',
'args': [
{'type': 'Victim', 'offset': [17], 'text': 'himself'},
{'type': 'Agent', 'offset': [5, 6], 'text': 'John Joseph'},
{'type': 'Place', 'offset': [23], 'text': 'court'}
]
},
]
]
Returns:
List[Dict]: List of Sentence, each sentence contains Four List of Event Tuple
[
{
'offset_trigger': [('Die', (16,)), ('Convict', (30,))],
'string_trigger': [('Die', 'shot'), ('Convict', 'convicted')],
'offset_role': [('Die', 'Victim', (17,)), ('Die', 'Agent', (5, 6)), ('Die', 'Place', (23,))],
'string_role': [('Die', 'Victim', 'himself'), ('Die', 'Agent', 'John Joseph'), ('Die', 'Place', 'court')]
},
...
]
"""
gold_instance_list = []
for gold in gold_list:
gold_instance = defaultdict(list)
for record in gold:
gold_instance['offset_trigger'] += [(record['type'], tuple_offset(record['offset']))]
gold_instance['string_trigger'] += [(record['type'], record['text'])]
for arg in record['args']:
gold_instance['offset_role'] += [(record['type'], arg['type'], tuple_offset(arg['offset']))]
gold_instance['string_role'] += [(record['type'], arg['type'], arg['text'])]
gold_instance_list += [gold_instance]
return gold_instance_list
@staticmethod
def load_pred_list(pred_list):
"""[summary]
Args:
pred_list (List[Dict]): List of Sentence, each sentence contains two List (offset, string) of Event List
[
{
'offset': [{'type': 'Attack', 'roles': [['Attacker', [5, 6]], ['Place', [23]], ['Target', [17]]], 'trigger': [16]}],
'string': [{'roles': [['Attacker', 'John Joseph'], ['Place', 'court'], ['Target', 'himself']], 'type': 'Attack', 'trigger': 'shot'}],
},
...
]
Returns:
List[Dict]: List of Sentence, each sentence contains four List (offset, string) X (trigger, role) of Event List
[
{
'offset_trigger': [('Attack', (16,))],
'offset_role': [('Attack', 'Attacker', (5, 6)), ('Attack', 'Place', (23,)), ('Attack', 'Target', (17,))],
'string_trigger': [('Attack', 'shot')],
'string_role': [('Attack', 'Attacker', 'John Joseph'), ('Attack', 'Place', 'court'), ('Attack', 'Target', 'himself')],
},
...
]
"""
pred_instance_list = list()
for pred in pred_list:
pred_instance = defaultdict(list)
for offset_pred in pred['offset']:
event_type, trigger_offset = offset_pred['type'], tuple_offset(offset_pred['trigger'])
pred_instance['offset_trigger'] += [(event_type, trigger_offset)]
for role_type, role_offset in offset_pred['roles']:
pred_instance['offset_role'] += [(event_type, role_type, tuple_offset(role_offset))]
for string_pred in pred['string']:
event_type, trigger_string = string_pred['type'], string_pred['trigger']
pred_instance['string_trigger'] += [(event_type, trigger_string)]
for role_type, role_string in string_pred['roles']:
pred_instance['string_role'] += [(event_type, role_type, role_string)]
pred_instance_list += [pred_instance]
return pred_instance_list
@staticmethod
def eval_instance_list(gold_instance_list, pred_instance_list, verbose=False, match_mode='normal'):
"""[summary]
Args:
gold_instance_list (List[Dict]): List of Sentence, each sentence contains Four List of Event Tuple
[
{
'offset_trigger': [('Die', (16,)), ('Convict', (30,))],
'string_trigger': [('Die', 'shot'), ('Convict', 'convicted')],
'offset_role': [('Die', 'Victim', (17,)), ('Die', 'Agent', (5, 6)), ('Die', 'Place', (23,))],
'string_role': [('Die', 'Victim', 'himself'), ('Die', 'Agent', 'John Joseph'), ('Die', 'Place', 'court')]
},
...
]
pred_instance_list (List[Dict]): List of Sentence, each sentence contains four List (offset, string) X (trigger, role) of Event List
[
{
'offset_trigger': [('Attack', (16,))],
'offset_role': [('Attack', 'Attacker', (5, 6)), ('Attack', 'Place', (23,)), ('Attack', 'Target', (17,))],
'string_trigger': [('Attack', 'shot')],
'string_role': [('Attack', 'Attacker', 'John Joseph'), ('Attack', 'Place', 'court'), ('Attack', 'Target', 'himself')],
},
...
]
verbose (bool, optional): [description]. Defaults to False.
match_mode (string, optional): [description]. Defaults to `normal`.
Returns:
Dict: Result of Evaluation
(offset, string) X (trigger, role) X (gold, pred, tp, P, R, F1)
"""
trigger_metrics = {
'offset': Metric(verbose=verbose, match_mode=match_mode),
'string': Metric(verbose=verbose, match_mode=match_mode),
}
role_metrics = {
'offset': Metric(verbose=verbose, match_mode=match_mode),
'string': Metric(verbose=verbose, match_mode=match_mode),
}
for pred, gold in zip(pred_instance_list, gold_instance_list):
pre_string_tp, pre_offset_tp = trigger_metrics['string'].tp, trigger_metrics['offset'].tp
for eval_key in trigger_metrics:
trigger_metrics[eval_key].count_instance(
gold_list=gold.get(eval_key + '_trigger', []),
pred_list=pred.get(eval_key + '_trigger', [])
)
post_string_tp, post_offset_tp = trigger_metrics['string'].tp, trigger_metrics['offset'].tp
if verbose and post_offset_tp - pre_offset_tp != post_string_tp - pre_string_tp:
warning_tp_increment(gold=gold, pred=pred, prefix='Trigger')
pre_string_tp, pre_offset_tp = role_metrics['string'].tp, role_metrics['offset'].tp
for eval_key in role_metrics:
role_metrics[eval_key].count_instance(
gold_list=gold.get(eval_key + '_role', []),
pred_list=pred.get(eval_key + '_role', [])
)
post_string_tp, post_offset_tp = role_metrics['string'].tp, role_metrics['offset'].tp
if verbose and post_offset_tp - pre_offset_tp != post_string_tp - pre_string_tp:
warning_tp_increment(gold=gold, pred=pred, prefix='Role')
results = dict()
for eval_key in trigger_metrics:
results.update(trigger_metrics[eval_key].compute_f1(prefix=f'{eval_key}-evt-trigger-'))
for eval_key in role_metrics:
results.update(role_metrics[eval_key].compute_f1(prefix=f'{eval_key}-evt-role-'))
return results
END = "# END"
END_LINE = "\n----------------------------------------"
This code defines a scoring module that computes the evaluation metrics for the different tasks (entity recognition, relation extraction, event extraction). Its main components are:

- `Metric` class: a generic metric class that tracks true positives (tp), the number of gold instances (gold_num) and the number of predicted instances (pred_num), and derives precision (P), recall (R) and F1 from them. The `safe_div` method divides while guarding against division by zero, `compute_f1` computes the F1 score, and `count_instance` updates the true-positive, gold and prediction counts.
- `RecordMetric` class: a subclass of `Metric` for tasks where the order of the roles does not matter (e.g. event argument extraction). It adds an `is_equal` method that decides whether a gold record and a predicted record match.
- `OrderedRecordMetric` class: a subclass of `RecordMetric` for tasks where role order does matter (e.g. relation extraction). It overrides `is_equal` to take the order of the roles into account.
- `Scorer` class: the base evaluation class. It declares three static methods, `load_gold_list`, `load_pred_list` and `eval_instance_list`, for loading gold instances, loading predicted instances and computing the metrics; subclasses implement them for their specific task.
- `EntityScorer`, `RelationScorer` and `EventScorer` classes: task-specific scorers for entity recognition, relation extraction and event extraction. They inherit from `Scorer` and implement the loading and evaluation methods for their task.

The purpose of this module is to provide a unified evaluation framework: pick the scorer class that matches the task, load the gold and predicted data, and compute metrics such as F1. A minimal reference sketch of the `Metric`-style counter follows.
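The sketch below is not the repository's implementation; it is a minimal `Metric`-style counter written only to make the interface used by `eval_instance_list` above (`tp`, `count_instance`, `compute_f1`) concrete. The constructor arguments mirror the calls above, while the matching logic and the returned key names are assumptions.

```python
class MiniMetric:
    """Minimal precision/recall/F1 counter mirroring how Metric is used above."""

    def __init__(self, verbose=False, match_mode='normal'):
        # verbose / match_mode are kept only to mirror the constructor calls above
        self.tp = 0          # true positives
        self.gold_num = 0    # number of gold instances
        self.pred_num = 0    # number of predicted instances

    @staticmethod
    def safe_div(a, b):
        return a / b if b != 0 else 0.0

    def count_instance(self, gold_list, pred_list):
        self.gold_num += len(gold_list)
        self.pred_num += len(pred_list)
        remaining = list(gold_list)          # each gold item may be matched at most once
        for pred in pred_list:
            if pred in remaining:
                remaining.remove(pred)
                self.tp += 1

    def compute_f1(self, prefix=''):
        p = self.safe_div(self.tp, self.pred_num)
        r = self.safe_div(self.tp, self.gold_num)
        f1 = self.safe_div(2 * p * r, p + r)
        return {prefix + 'P': p * 100, prefix + 'R': r * 100, prefix + 'F1': f1 * 100}
```

With this interface, `eval_instance_list` only needs to call `count_instance` once per sentence and `compute_f1` once at the end, which is exactly the pattern in the scorer code above.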
import pandas as pd
import os
import shutil
import random
import argparse
from collections import defaultdict
import json
import sys
from src.prompt.constants import END
from src.utils.record_schema import RecordSchema
def make_prompt(file_path: str, out_file, n_examples, seed: int = 0):
random.seed(seed)
data = [json.loads(line.strip()) for line in open(file_path)]
if n_examples != -1:
samples = random.sample(data, n_examples)
else:
samples = data
random.shuffle(samples)
prompt = ""
for sample in samples:
prompt += sample["reference_output"]
prompt += f"{END}\n\n"
with open(out_file,'w',encoding='utf-8') as fout:
fout.write(prompt)
print ("saved prompt to ", out_file)
return 0
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-src', help='Source File Name', required=True)
parser.add_argument('-tgt', help='Target File Name, n shot sampled',
required=True)
parser.add_argument('-schema_file', help='schema_file', required=True)
parser.add_argument('-task', help='N-Shot Task name', required=True,
choices=['entity', 'relation', 'event'])
parser.add_argument('-n_examples', help='n_examples',type=int)
parser.add_argument('-seed', help='Default is None, no random')
parser.add_argument('-min_len', dest='min_len', help='Default is None', type=int)
options = parser.parse_args()
source_file = options.src
target_file = options.tgt
make_prompt(file_path=source_file, out_file=target_file, n_examples=options.n_examples)
if __name__ == "__main__":
main()
This script builds the prompt text used for few-shot, in-context prompting of the model (the large models are only accessed through an API, so no training is involved). Its main steps are:

- Command-line arguments specify the input file, the output file, the number of examples to sample (`n_examples`), the random seed (`seed`), and a few other options.
- The input file is read line by line as JSON records. Either `n_examples` records are randomly sampled, or all records are used (when `n_examples` is -1).
- The sampled records are shuffled.
- An empty prompt string is created.
- For each sample, the `reference_output` field is appended to the prompt string, followed by the `END` marker (`# END`) to separate the examples.
- The assembled prompt string is written to the output file.
- The path of the saved prompt is printed so the user can inspect it.

In short, the script samples a few converted examples and concatenates them, separated by `# END`, into the fixed demonstration part of the prompt that is later prepended to each test input; a sketch of the resulting layout is shown below.
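The snippet below is a minimal sketch of what the assembled prompt looks like. Only the `# END` separator comes from the constants shown earlier; the function name, sentence and entities roughly imitate the paper's code-style NER format and are invented purely for illustration.

```python
# Minimal sketch of the demonstration prompt assembled by make_prompt.
END = "# END"

demo = (
    'def named_entity_recognition(input_text):\n'
    '    """ extract named entities from the input_text . """\n'
    '    input_text = "Steve became CEO of Apple in 1998 ."\n'
    '    entity_list = []\n'
    '    # extracted named entities\n'
    '    entity_list.append({"text": "Steve", "type": "person"})\n'
    '    entity_list.append({"text": "Apple", "type": "organization"})\n'
)

# make_prompt appends every sampled reference_output in exactly this way:
prompt = ""
for sample_output in [demo]:          # in the real script: sample["reference_output"]
    prompt += sample_output
    prompt += f"{END}\n\n"

print(prompt)
```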
import pandas as pd
from tqdm import tqdm
from src.converters.get_converter import ConverterFactory
from src.utils.file_utils import load_yaml, load_schema, read_data
def make_task_file(args):
data = read_data(args.inpath)
converter = ConverterFactory.get_converter(args.job_type,schema_folder=args.schema_path, map_config_path=args.map_config_path)
res = []
for i, row in tqdm(data.iterrows(), total=len(data)):
try:
struct_input = converter.structure_to_input(row, prompt_part_only=False)
if struct_input is None:
continue
tmp = {k: v for (k, v) in row.items() if k not in ['record']}
tmp["input_idx"] = i
tmp["input_prompt"] = converter.structure_to_input(row, prompt_part_only=True)
tmp["reference_output"] = struct_input
except Exception as e:
raise e
res.append(tmp)
# successfully converted
conversion_rate = len(res) / len(data)
pd.DataFrame(res).to_json(args.outpath, orient='records', lines=True)
print(f"Converted {len(res)} out of {len(data)} rows ({conversion_rate:.2%})")
print ("Saved to ", args.outpath)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--inpath", type=str, required=True)
parser.add_argument("--outpath", type=str, required=True)
parser.add_argument("--job_type", type=str, required=True)
parser.add_argument("--schema_path", type=str, required=True)
parser.add_argument("--map_config_path", type=str, required=True)
args = parser.parse_args()
make_task_file(args)
This script converts the raw data into the task files used for prompting. The main steps are:

- Command-line arguments specify the input path, the output path, the task type (`job_type`), the schema file path (`schema_path`) and the mapping configuration path (`map_config_path`).
- The input data is read, typically into a DataFrame.
- `ConverterFactory` loads the converter appropriate for the given task type and configuration files (e.g. an NER or RE converter).
- For each row, the converter's `structure_to_input` method turns the raw record into the code-style target; in addition, the row index (`input_idx`), the prompt-only part (`input_prompt`) and the remaining fields are kept as metadata.
- Rows that convert successfully are collected into the result list.
- Finally, the conversion rate is computed and the results are saved as a JSON-lines task file for later prompting and evaluation.

Overall, this script turns the raw IE data into the task-file format used by the different tasks (NER, RE, etc.), with the input, the reference output and some metadata for every example; a sketch of one such record is shown below.
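The dictionary below sketches what a single line of the resulting JSON-lines file might contain. The column names `input_idx`, `input_prompt` and `reference_output` come from the code above; the `text` field and all values are invented, since the exact columns depend on the converter and the input data.

```python
# One illustrative record of the JSON-lines task file written by make_task_file.
example_record = {
    "text": "Steve became CEO of Apple in 1998 .",  # assumed input column
    "input_idx": 0,                                  # row index added by the script
    "input_prompt": (                                # structure_to_input(..., prompt_part_only=True)
        'def named_entity_recognition(input_text):\n'
        '    """ extract named entities from the input_text . """\n'
        '    input_text = "Steve became CEO of Apple in 1998 ."\n'
        '    entity_list = []\n'
    ),
    "reference_output": (                            # structure_to_input(..., prompt_part_only=False)
        'def named_entity_recognition(input_text):\n'
        '    """ extract named entities from the input_text . """\n'
        '    input_text = "Steve became CEO of Apple in 1998 ."\n'
        '    entity_list = []\n'
        '    # extracted named entities\n'
        '    entity_list.append({"text": "Steve", "type": "person"})\n'
        '    entity_list.append({"text": "Apple", "type": "organization"})\n'
    ),
}
print(example_record["reference_output"])
```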
import yaml
import pandas as pd
def load_schema(schema_path):
with open(schema_path,encoding='utf8') as fin:
entity_line = fin.readline().strip()
relation_line = fin.readline().strip()
spot_asoc_line = fin.readline().strip()
return {'entity_schema': eval(entity_line),
'relation_schema': eval(relation_line),
'spot_asoc_schema': eval(spot_asoc_line)}
def load_yaml(yaml_path):
with open(yaml_path,'r') as fin:
map_config = yaml.load(fin.read(), Loader=yaml.FullLoader)
return map_config
def read_data(inpath):
if "json" in inpath:
data = pd.read_json(inpath, orient='records', lines=True)
else:
raise ValueError(f"Unknown input format: {inpath}")
return data
This code is a small set of helper functions for loading data and configuration files:

- `load_schema(schema_path)` loads the schema file at the given path, which holds the entity, relation and spot-association label information. It reads the file's first three lines (the entity schema, the relation schema and the spot-association schema) and returns them in a dictionary.
- `load_yaml(yaml_path)` loads a YAML configuration file from the given path with PyYAML and returns the parsed configuration dictionary.
- `read_data(inpath)` loads the data file and only supports JSON(-lines) input: it uses pandas' `pd.read_json` with `lines=True` to parse one record per line into a DataFrame and returns it.

These helpers are used by the main scripts to load the data and configuration needed for conversion and prompt generation: `load_schema` supplies the task schema, `load_yaml` the mapping configuration, and `read_data` the task input data. Keeping them in one place makes the main code more modular and easier to maintain. A usage sketch for `load_schema` follows.
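A minimal usage sketch, assuming a hypothetical schema file whose three lines are Python literals (note that `load_schema` evaluates each line with `eval`, so the file must be trusted); the type and relation names below are invented:

```python
# Write a demo schema file and load it back with load_schema.
from src.utils.file_utils import load_schema   # module shown above

schema_text = (
    "['person', 'organization', 'location']\n"                                   # entity schema
    "['work for', 'live in']\n"                                                   # relation schema
    "{'person': ['work for', 'live in'], 'organization': [], 'location': []}\n"   # spot-assoc schema
)
with open("demo.schema", "w", encoding="utf8") as fout:
    fout.write(schema_text)

schema = load_schema("demo.schema")
print(schema["entity_schema"])     # ['person', 'organization', 'location']
print(schema["relation_schema"])   # ['work for', 'live in']
```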
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import json
from collections import defaultdict
from typing import List
class RecordSchema:
def __init__(self, type_list, role_list, type_role_dict):
self.type_list = type_list
self.role_list = role_list
self.type_role_dict = type_role_dict
@staticmethod
def read_from_file(filename):
lines = open(filename).readlines()
type_list = json.loads(lines[0])
role_list = json.loads(lines[1])
type_role_dict = json.loads(lines[2])
return RecordSchema(type_list, role_list, type_role_dict)
def write_to_file(self, filename):
with open(filename, 'w') as output:
output.write(json.dumps(self.type_list, ensure_ascii=False) + '\n')
output.write(json.dumps(self.role_list, ensure_ascii=False) + '\n')
output.write(json.dumps(self.type_role_dict, ensure_ascii=False) + '\n')
def merge_schema(schema_list: List[RecordSchema]):
type_set = set()
role_set = set()
type_role_dict = defaultdict(list)
for schema in schema_list:
for type_name in schema.type_list:
type_set.add(type_name)
for role_name in schema.role_list:
role_set.add(role_name)
for type_name in schema.type_role_dict:
type_role_dict[type_name] += schema.type_role_dict[type_name]
for type_name in type_role_dict:
type_role_dict[type_name] = list(set(type_role_dict[type_name]))
return RecordSchema(type_list=list(type_set),
role_list=list(role_set),
type_role_dict=type_role_dict,
)
This code defines a `RecordSchema` class that manages a task's record schema. A record schema consists of the entity/event types (`type_list`), the roles (`role_list`) and a type-to-role dictionary (`type_role_dict`).

- The `__init__` method builds a `RecordSchema` instance from a type list, a role list and a type-role dictionary.
- The `read_from_file` static method reads a schema definition from a file: the first three lines hold the JSON-encoded type list, role list and type-role dictionary, and an instance is constructed from them.
- The `write_to_file` method writes the schema back out, one JSON line each for the type list, the role list and the type-role dictionary.

In addition, the `merge_schema` function merges several record schemas: given a list of `RecordSchema` objects, it unions their types and roles, concatenates and de-duplicates the role lists in their type-role dictionaries, and returns a single merged `RecordSchema`.

These utilities are used to manage, merge and store the record schema information needed while processing the tasks; a short usage sketch follows.
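A minimal sketch of merging two schemas, assuming that `merge_schema` lives in the same module as the `RecordSchema` class shown above; all type and role names are invented:

```python
# Merge two hypothetical schemas; the result unions their types and roles
# and de-duplicates the roles attached to each type.
from src.utils.record_schema import RecordSchema, merge_schema

ner_schema = RecordSchema(
    type_list=["person", "organization"],
    role_list=[],
    type_role_dict={"person": [], "organization": []},
)
re_schema = RecordSchema(
    type_list=["person"],
    role_list=["work for"],
    type_role_dict={"person": ["work for"]},
)

merged = merge_schema([ner_schema, re_schema])
print(sorted(merged.type_list))           # ['organization', 'person']
print(merged.type_role_dict["person"])    # ['work for']
merged.write_to_file("merged.schema")     # three JSON lines: types, roles, type->roles
```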