当前位置:   article > 正文

基于星火大模型的群聊对话分角色要素提取挑战赛|#AI夏令营#Datawhale# baseline解读_from sparkai.llm.llm import chatsparkllm, chunkpri

from sparkai.llm.llm import chatsparkllm, chunkprinthandler

1.环境配置 

  1. from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
  2. from sparkai.core.messages import ChatMessage
  3. import numpy as np
  4. from tqdm import tqdm
  5. def chatbot(prompt):
  6. #星火认知大模型Spark3.5 Max的URL值,其他版本大模型URL值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
  7. SPARKAI_URL = ''
  8. #星火认知大模型调用秘钥信息,请前往讯飞开放平台控制台(https://console.xfyun.cn/services/bm35)查看
  9. SPARKAI_APP_ID = ''
  10. SPARKAI_API_SECRET = ''
  11. SPARKAI_API_KEY = ''
  12. #星火认知大模型Spark3.5 Max的domain值,其他版本大模型domain值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
  13. SPARKAI_DOMAIN = 'generalv3.5'
  14. spark = ChatSparkLLM(
  15. spark_api_url=SPARKAI_URL,
  16. spark_app_id=SPARKAI_APP_ID,
  17. spark_api_key=SPARKAI_API_KEY,
  18. spark_api_secret=SPARKAI_API_SECRET,
  19. spark_llm_domain=SPARKAI_DOMAIN,
  20. streaming=False,
  21. )
  22. messages = [ChatMessage(
  23. role="user",
  24. content=prompt
  25. )]
  26. handler = ChunkPrintHandler()
  27. a = spark.generate([messages], callbacks=[handler])
  28. return a.generations[0][0].message.content

chatbot 函数接受一个 prompt 参数,该参数是用户输入的聊天消息。ChatSparkLLM 类实例化一个大模型对象,通过提供的配置参数进行初始化。streaming=False 指定不使用流式传输。

2.数据处理Prompt

  1. content = ''
  2. prompt = f'''
  3. 你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。
  4. ****群聊对话****
  5. {content}
  6. ****分析数据****
  7. 客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
  8. 客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
  9. 客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
  10. 跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动
  11. ****注意****
  12. 1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
  13. 2.不要输出分析内容
  14. 3.输出内容格式为md格式
  15. '''

这里我们对原群聊对话设计了一个总结Prompt,目的是将原始对话内容进行精简。方便做微调数据。

一方面直接将群聊对话作为数据集的话,会导致上下文过长,超过限制。还有上下文太长会导致抽取效果变差。

过长的上下文也会导致训练时长和费用倍增。

3.训练数据集制作

jsonl_data 是用来训练的规范单行数据,需要由训练数据组成一个jsonl文件(每行是一个json数据的文件),格式如下

jsonl_data = {"instruction":"假设你是一个智能交互助手,基于用户的输入文本,解析其中语义,抽取关键信息,以json格式生成结构化的语义内容。","input":"请调小空气净化器的湿度到1","output":"{\"intent\":\"CONTROL\",\"slots\":[{\"name\":\"device\",\"normValue\":\"airCleaner\",\"value\":\"空气净化器\"},{\"name\":\"insType\",\"normValue\":\"set\",\"value\":\"调小\"},{\"name\":\"attr\",\"normValue\":\"humidity\",\"value\":\"湿度\"},{\"name\":\"attrValue\",\"normValue\":\"1\",\"value\":\"1\"}],\"sample\":\"请调小空气净化器的湿度到1\"}"}

这段代码旨在创建一个用于训练的 JSONL 文件(traindata.jsonl),其中每条数据包含对群聊对话的分析结果。由于数据量不足,通过重复数据集12次来满足最低训练集行数要求。主要步骤解读:1.打开文件用于写入 2.遍历数据列表 3. 提取输入输出数据 4.构建 Prompt 并调用 chatbot 函数 5.构建要写入文件的数据格式 6. 重复写入数据以扩充训练集

  1. with open('traindata.jsonl', 'w', encoding='utf-8') as file:
  2. # 训练集行数(130)不符合要求,范围:1500~90000000
  3. # 遍历数据列表,并将每一行写入文件
  4. # 这里为了满足微调需求我们重复12次数据集 130*12=1560
  5. for line_data in tqdm(data):
  6. line_input = line_data["chat_text"]
  7. line_output = line_data["infos"]
  8. content = line_input
  9. prompt = f'''
  10. 你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。
  11. ****群聊对话****
  12. {content}
  13. ****分析数据****
  14. 客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
  15. 客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
  16. 客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
  17. 跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动
  18. ****注意****
  19. 1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
  20. 2.不要输出分析内容
  21. 3.输出内容格式为md格式
  22. '''
  23. res = chatbot(prompt=prompt)
  24. # print(res)
  25. line_write = {
  26. "instruction":jsonl_data["instruction"],
  27. "input":json.dumps(res, ensure_ascii=False),
  28. "output":json.dumps(line_output, ensure_ascii=False)
  29. }
  30. # 因为数据共有130行,为了能满足训练需要的1500条及以上,我们将正常训练数据扩充12倍。
  31. for time in range(12):
  32. file.write(json.dumps(line_write, ensure_ascii=False) + '\n') # '\n' 用于在每行末尾添加换行符

4.测试集制作

  1. # 验证集制作(提交版本)
  2. # input,target
  3. import json
  4. # 打开并读取JSON文件
  5. with open('test_data.json', 'r', encoding='utf-8') as file:
  6. data_test = json.load(file)
  7. import csv
  8. # 打开一个文件用于写入CSV数据
  9. with open('test.csv', 'w', newline='', encoding='utf-8') as csvfile:
  10. # 创建一个csv writer对象
  11. csvwriter = csv.writer(csvfile)
  12. csvwriter.writerow(["input","target"])
  13. # 遍历数据列表,并将每一行写入CSV文件
  14. for line_data in tqdm(data_test):
  15. content = line_data["chat_text"]
  16. prompt = f'''
  17. 你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。
  18. ****群聊对话****
  19. {content}
  20. ****分析数据****
  21. 客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
  22. 客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
  23. 客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
  24. 跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动
  25. ****注意****
  26. 1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
  27. 2.不要输出分析内容
  28. 3.输出内容格式为md格式
  29. '''
  30. res = chatbot(prompt=prompt)
  31. # print(line_data["chat_text"])
  32. ## 文件内容校验失败: test.jsonl(不含表头起算)第1行的内容不符合规则,限制每组input和target字符数量总和上限为8000,当前行字符数量:10721
  33. line_list = [res, "-"]
  34. csvwriter.writerow(line_list)
  35. # break

5.测试集与训练集制作区别

文件格式

  • 第一段代码将数据写入 JSONL 文件(traindata.jsonl)。
  • 第二段代码将数据写入 CSV 文件(test.csv)。

写入文件的内容格式

  • 第一段代码中的每行数据包含 instructioninputoutput 字段,且为了满足训练集的行数要求,每行数据被重复写入 12 次。
  • 第二段代码中的每行数据包含 inputtarget 字段,且没有重复写入数据。

文件写入方式

  • 第一段代码使用 json.dumps 将字典对象转换为 JSON 字符串,然后写入文件。
  • 第二段代码使用 csv.writer 将数据列表写入 CSV 文件。

字符数量限制提示

  • 第二段代码在注释中指出了字符数量限制的问题,并提到当前行字符数量超出 8000 字符的限制。

 6.微调推理

  1. # 定义写入函数
  2. def write_json(json_file_path, data):
  3. #"""写入json文件"""
  4. with open(json_file_path, 'w') as f:
  5. json.dump(data, f, ensure_ascii=False, indent=4)
  6. import SparkApi
  7. import json
  8. #以下密钥信息从控制台获取
  9. appid = "" #填写控制台中获取的 APPID 信息
  10. api_secret = "" #填写控制台中获取的 APISecret 信息
  11. api_key ="" #填写控制台中获取的 APIKey 信息
  12. #调用微调大模型时,设置为“patch”
  13. domain = "patchv3"
  14. #云端环境的服务地址
  15. # Spark_url = "wss://spark-api-n.xf-yun.com/v1.1/chat" # 微调v1.5环境的地址
  16. Spark_url = "wss://spark-api-n.xf-yun.com/v3.1/chat" # 微调v3.0环境的地址
  17. text =[]
  18. # length = 0
  19. def getText(role,content):
  20. jsoncon = {}
  21. jsoncon["role"] = role
  22. jsoncon["content"] = content
  23. text.append(jsoncon)
  24. return text
  25. def getlength(text):
  26. length = 0
  27. for content in text:
  28. temp = content["content"]
  29. leng = len(temp)
  30. length += leng
  31. return length
  32. def checklen(text):
  33. while (getlength(text) > 8000):
  34. del text[0]
  35. return text
  36. def core_run(text,prompt):
  37. # print('prompt',prompt)
  38. text.clear
  39. Input = prompt
  40. question = checklen(getText("user",Input))
  41. SparkApi.answer =""
  42. # print("星火:",end = "")
  43. SparkApi.main(appid,api_key,api_secret,Spark_url,domain,question)
  44. getText("assistant",SparkApi.answer)
  45. # print(text)
  46. return text[-1]['content']
  47. text = []
  48. res = core_run(text,'你好吗?')

详细解读:

  • 写入 JSON 文件的函数

    • write_json 函数接受文件路径和数据,并将数据写入到指定的 JSON 文件中。
  • 密钥信息和 API 地址

    • 设置与 SparkApi 通信所需的密钥信息(appid, api_secret, api_key)和 API 地址(Spark_url)。
  • 获取文本长度的函数

    • getText 函数将对话角色和内容添加到 text 列表中。
    • getlength 函数计算 text 列表中所有内容的总长度。
    • checklen 函数确保 text 列表中的总长度不超过 8000 字符,超出时删除最早的内容。
  • 核心运行函数

    • core_run 函数执行对话,构建用户输入并调用 SparkApi 进行处理,然后将返回的结果添加到 text 列表中,并返回最新的对话内容。
  • 运行示例

    • 创建一个空的 text 列表,并调用 core_run 函数进行对话示例,传入输入内容 '你好吗?'。

 7.获取结果

这段代码通过读取 CSV 文件中的数据,逐行处理并生成相应的 JSON 数据,将处理后的结果存储到一个新的列表中。代码还包括处理 JSON 解析错误的机制,并在解析失败时使用一个预定义的空字典作为默认值。data_dict_empty 是一个包含所有可能字段的空字典,用于在解析失败时作为默认值。submit_data 用于存储处理后的数据,每个元素是一个包含 infos(解析后的数据字典)和 index(行号)的字典。

  1. import pandas as pd
  2. import re
  3. # 读取Excel文件
  4. df_test = pd.read_csv('test.csv',)
  5. data_dict_empty = {
  6. "基本信息-姓名": "",
  7. "基本信息-手机号码": "",
  8. "基本信息-邮箱": "",
  9. "基本信息-地区": "",
  10. "基本信息-详细地址": "",
  11. "基本信息-性别": "",
  12. "基本信息-年龄": "",
  13. "基本信息-生日": "",
  14. "咨询类型": [],
  15. "意向产品": [],
  16. "购买异议点": [],
  17. "客户预算-预算是否充足": "",
  18. "客户预算-总体预算金额": "",
  19. "客户预算-预算明细": "",
  20. "竞品信息": "",
  21. "客户是否有意向": "",
  22. "客户是否有卡点": "",
  23. "客户购买阶段": "",
  24. "下一步跟进计划-参与人": [],
  25. "下一步跟进计划-时间点": "",
  26. "下一步跟进计划-具体事项": ""
  27. }
  28. submit_data = []
  29. for id,line_data in tqdm(enumerate(df_test['input'])):
  30. # print(line_data)
  31. content = line_data
  32. text = []
  33. prompt = json.dumps(content,ensure_ascii=False)
  34. # print(json.dumps(content,ensure_ascii=False))
  35. res = core_run(text,prompt)
  36. try:
  37. data_dict = json.loads(res)
  38. except json.JSONDecodeError as e:
  39. data_dict = data_dict_empty
  40. submit_data.append({"infos":data_dict,"index":id+1})
  41. # 预计执行8min
  • 对每一行的 input 数据进行处理,构建 prompt 并调用 core_run 函数。
  • 尝试将 core_run 返回的结果解析为 JSON,如果解析失败,则使用 data_dict_empty 作为默认值。
  • 将处理后的数据(infosindex)添加到 submit_data 列表中。

注:微调模型

 模型微调在平台进行,采用零代码微调,微调方法为LoRA

LoRA在固定预训练大模型本身的参数基础上,在保留自注意力模块中原始权重矩阵的基础上,对权重矩阵进行低秩分解,训练过程中只更新低秩部分的参数。

再此不进行详细解释。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/994740
推荐阅读
相关标签
  

闽ICP备14008679号