当前位置:   article > 正文

Datawhale AI夏令营 baseline1精读分享直播

Datawhale AI夏令营 baseline1精读分享直播

目录

一、环境配置

二、数据处理

三、promot工程

四、数据抽取 

五、个人感悟


听了baseline1的精读分享直播,对程序进行简要梳理。

一、环境配置

spark_ai_python要求python版本3.8以上

!pip install --upgrade -q spark_ai_python tqdm jsonschema python-dotenv

导入包,其中dotenv是一个零依赖的模块,它的主要功能是从.env文件中加载环境变量,此处从.env文件中加载ID、SECRET和KEY变量 

  1. from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
  2. from sparkai.core.messages import ChatMessage
  3. import json
  4. from dotenv import load_dotenv
  5. # 加载.env文件中的环境变量
  6. """
  7. SPARKAI_APP_ID=""
  8. SPARKAI_API_SECRET=""
  9. SPARKAI_API_KEY=""
  10. """
  11. load_dotenv()

 在讯飞开放控制台获得大模型调用密钥信息

  1. #星火认知大模型Spark3.5 Max的URL值,其他版本大模型URL值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
  2. SPARKAI_URL = 'wss://spark-api.xf-yun.com/v3.5/chat'
  3. #星火认知大模型调用秘钥信息,请前往讯飞开放平台控制台(https://console.xfyun.cn/services/bm35)查看
  4. SPARKAI_APP_ID = os.getenv("SPARKAI_APP_ID")
  5. SPARKAI_API_SECRET = os.getenv("SPARKAI_API_SECRET")
  6. SPARKAI_API_KEY = os.getenv("SPARKAI_API_KEY")
  7. #星火认知大模型Spark3.5 Max的domain值,其他版本大模型domain值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
  8. SPARKAI_DOMAIN = 'generalv3.5'

 测试星火大模型是否可以正常使用

  1. def get_completions(text):
  2. messages = [ChatMessage(
  3. role="user",
  4. content=text
  5. )]
  6. spark = ChatSparkLLM(
  7. spark_api_url=SPARKAI_URL,
  8. spark_app_id=SPARKAI_APP_ID,
  9. spark_api_key=SPARKAI_API_KEY,
  10. spark_api_secret=SPARKAI_API_SECRET,
  11. spark_llm_domain=SPARKAI_DOMAIN,
  12. streaming=False,
  13. )
  14. handler = ChunkPrintHandler()
  15. a = spark.generate([messages], callbacks=[handler])
  16. return a.generations[0][0].text
  17. # 测试模型配置是否正确
  18. text = "你好"
  19. get_completions(text)

二、数据处理

目的:读取训练集和测试集的数据本身

定义两个函数用于读取json文件和写入json文件 

  1. def read_json(json_file_path):
  2. """读取json文件"""
  3. with open(json_file_path, 'r') as f:
  4. data = json.load(f)
  5. return data
  6. def write_json(json_file_path, data):
  7. """写入json文件"""
  8. with open(json_file_path, 'w') as f:
  9. json.dump(data, f, ensure_ascii=False, indent=4)
  10. # 读取数据
  11. train_data = read_json("dataset/train.json")
  12. test_data = read_json("dataset/test_data.json")

查看对话数据集,是群聊数据,有部分影响大模型结果的信息存在:图片和引用消息等

  1. # 查看对话数据
  2. print(train_data[100]['chat_text'])

格式化为json格式 

  1. # 查看对话标签
  2. def print_json_format(data):
  3. """格式化输出json格式"""
  4. print(json.dumps(data, indent=4, ensure_ascii=False))
  5. print_json_format(train_data[100]['infos'])

定义函数提取文本中的json字符串,type为list

  1. def convert_all_json_in_text_to_dict(text):
  2. """提取LLM输出文本中的json字符串"""
  3. dicts, stack = [], []
  4. for i in range(len(text)):
  5. if text[i] == '{':
  6. stack.append(i)
  7. elif text[i] == '}':
  8. begin = stack.pop()
  9. if not stack:
  10. dicts.append(json.loads(text[begin:i+1]))
  11. return dicts
  12. llm_output = """
  13. ```json
  14. [{
  15. "基本信息-姓名": "李强1",
  16. "基本信息-手机号码": "11059489858"
  17. }]
  18. ```
  19. """
  20. # 测试一下效果
  21. json_res = convert_all_json_in_text_to_dict(llm_output)
  22. print_json_format(json_res)
  23. print(type(json_res))

三、promot工程

promot编写思路:任务目标-抽取数据定义-抽取内容引入-抽取规则强调

将群聊对话输入大模型 

  1. prompt 设计
  2. PROMPT_EXTRACT = """
  3. 你将获得一段群聊对话记录。你的任务是根据给定的表单格式从对话记录中提取结构化信息。在提取信息时,请确保它与类型信息完全匹配,不要添加任何没有出现在下面模式中的属性。

 运行测试

  1. content = train_data[100]['chat_text']
  2. res = get_completions(PROMPT_EXTRACT.format(content=content))
  3. json_res = convert_all_json_in_text_to_dict(res)
  4. print_json_format(json_res)

查看原格式,含有markdown标签

print(res)

 

 查看数据对应的标签

  1. # 查看训练数据对应的标签
  2. print_json_format(train_data[100]['infos'])

四、数据抽取 

检查json格式并补全,防止大模型将空字段删除以及输出格式异常

check_and_complete_json_format函数对大模型抽取的结果进行字段格式的检查以及缺少的字段进行补全

  1. import json
  2. class JsonFormatError(Exception):
  3. def __init__(self, message):
  4. self.message = message
  5. super().__init__(self.message)
  6. def check_and_complete_json_format(data):
  7. required_keys = {
  8. "基本信息-姓名": str,
  9. "基本信息-手机号码": str,
  10. "基本信息-邮箱": str,
  11. "基本信息-地区": str,
  12. "基本信息-详细地址": str,
  13. "基本信息-性别": str,
  14. "基本信息-年龄": str,
  15. "基本信息-生日": str,
  16. "咨询类型": list,
  17. "意向产品": list,
  18. "购买异议点": list,
  19. "客户预算-预算是否充足": str,
  20. "客户预算-总体预算金额": str,
  21. "客户预算-预算明细": str,
  22. "竞品信息": str,
  23. "客户是否有意向": str,
  24. "客户是否有卡点": str,
  25. "客户购买阶段": str,
  26. "下一步跟进计划-参与人": list,
  27. "下一步跟进计划-时间点": str,
  28. "下一步跟进计划-具体事项": str
  29. }
  30. if not isinstance(data, list):
  31. raise JsonFormatError("Data is not a list")
  32. for item in data:
  33. if not isinstance(item, dict):
  34. raise JsonFormatError("Item is not a dictionary")
  35. for key, value_type in required_keys.items():
  36. if key not in item:
  37. item[key] = [] if value_type == list else ""
  38. if not isinstance(item[key], value_type):
  39. raise JsonFormatError(f"Key '{key}' is not of type {value_type.__name__}")
  40. if value_type == list and not all(isinstance(i, str) for i in item[key]):
  41. raise JsonFormatError(f"Key '{key}' does not contain all strings in the list")
  42. return data
  43. # Example usage:
  44. json_data = '''
  45. [
  46. {
  47. "基本信息-姓名": "张三",
  48. "基本信息-邮箱": "zhangsan@example.com",
  49. "基本信息-地区": "北京市",
  50. "基本信息-详细地址": "朝阳区某街道",
  51. "基本信息-性别": "男",
  52. "基本信息-年龄": "30",
  53. "基本信息-生日": "1990-01-01",
  54. "咨询类型": "",
  55. "意向产品": ["产品A"],
  56. "购买异议点": ["价格高"],
  57. "客户预算-预算是否充足": "充足",
  58. "客户预算-总体预算金额": "10000",
  59. "客户预算-预算明细": "详细预算内容",
  60. "竞品信息": "竞争对手B",
  61. "客户是否有意向": "有意向",
  62. "客户是否有卡点": "无卡点",
  63. "客户购买阶段": "合同中",
  64. "下一步跟进计划-参与人": ["客服A"],
  65. "下一步跟进计划-时间点": "2024-07-01",
  66. "下一步跟进计划-具体事项": "沟通具体事项"
  67. }
  68. ]
  69. '''
  70. try:
  71. data = json.loads(json_data)
  72. completed_data = check_and_complete_json_format(data)
  73. print("Completed JSON:", json.dumps(completed_data, ensure_ascii=False, indent=4))
  74. except JsonFormatError as e:
  75. print(f"JSON format error: {e.message}")

 使用另一个jsonschema库进行更加简便的格式验证

  1. import json
  2. from jsonschema import validate, Draft7Validator
  3. from jsonschema.exceptions import ValidationError
  4. class JsonFormatError(Exception):
  5. def __init__(self, message):
  6. self.message = message
  7. super().__init__(self.message)
  8. schema = {
  9. "type": "array",
  10. "items": {
  11. "type": "object",
  12. "properties": {
  13. "基本信息-姓名": {"type": "string", "default": ""},
  14. "基本信息-手机号码": {"type": "string", "default": ""},
  15. "基本信息-邮箱": {"type": "string", "default": ""},
  16. "基本信息-地区": {"type": "string", "default": ""},
  17. "基本信息-详细地址": {"type": "string", "default": ""},
  18. "基本信息-性别": {"type": "string", "default": ""},
  19. "基本信息-年龄": {"type": "string", "default": ""},
  20. "基本信息-生日": {"type": "string", "default": ""},
  21. "咨询类型": {"type": "array", "items": {"type": "string"}, "default": []},
  22. "意向产品": {"type": "array", "items": {"type": "string"}, "default": []},
  23. "购买异议点": {"type": "array", "items": {"type": "string"}, "default": []},
  24. "客户预算-预算是否充足": {"type": "string", "enum": ["充足", "不充足", ""], "default": ""},
  25. "客户预算-总体预算金额": {"type": "string", "default": ""},
  26. "客户预算-预算明细": {"type": "string", "default": ""},
  27. "竞品信息": {"type": "string", "default": ""},
  28. "客户是否有意向": {"type": "string", "enum": ["有意向", "无意向", ""], "default": ""},
  29. "客户是否有卡点": {"type": "string", "enum": ["有卡点", "无卡点", ""], "default": ""},
  30. "客户购买阶段": {"type": "string", "default": ""},
  31. "下一步跟进计划-参与人": {"type": "array", "items": {"type": "string"}, "default": []},
  32. "下一步跟进计划-时间点": {"type": "string", "default": ""},
  33. "下一步跟进计划-具体事项": {"type": "string", "default": ""}
  34. },
  35. "required": [
  36. "基本信息-姓名", "基本信息-手机号码", "基本信息-邮箱", "基本信息-地区",
  37. "基本信息-详细地址", "基本信息-性别", "基本信息-年龄", "基本信息-生日",
  38. "咨询类型", "意向产品", "购买异议点", "客户预算-预算是否充足",
  39. "客户预算-总体预算金额", "客户预算-预算明细", "竞品信息",
  40. "客户是否有意向", "客户是否有卡点", "客户购买阶段",
  41. "下一步跟进计划-参与人", "下一步跟进计划-时间点", "下一步跟进计划-具体事项"
  42. ]
  43. }
  44. }
  45. def validate_and_complete_json(data):
  46. # Create a validator with the ability to fill in default values
  47. validator = Draft7Validator(schema)
  48. for item in data:
  49. errors = sorted(validator.iter_errors(item), key=lambda e: e.path)
  50. for error in errors:
  51. # If the property is missing and has a default, apply the default value
  52. for subschema in error.schema_path:
  53. if 'default' in error.schema:
  54. item[error.schema_path[-1]] = error.schema['default']
  55. break
  56. # Validate the completed data
  57. try:
  58. validate(instance=data, schema=schema)
  59. except ValidationError as e:
  60. raise JsonFormatError(f"JSON format error: {e.message}")
  61. return data
  62. # Example usage:
  63. json_data = '''
  64. [
  65. {
  66. "基本信息-姓名": "张三",
  67. "基本信息-手机号码": "12345678901",
  68. "基本信息-邮箱": "zhangsan@example.com",
  69. "基本信息-地区": "北京市",
  70. "基本信息-详细地址": "朝阳区某街道",
  71. "基本信息-性别": "男",
  72. "基本信息-年龄": "30",
  73. "基本信息-生日": "1990-01-01",
  74. "咨询类型": ["询价"],
  75. "意向产品": ["产品A"],
  76. "购买异议点": ["价格高"],
  77. "客户预算-预算是否充足": "充足",
  78. "客户预算-总体预算金额": "10000",
  79. "客户预算-预算明细": "详细预算内容",
  80. "竞品信息": "竞争对手B",
  81. "客户是否有意向": "有意向",
  82. "客户是否有卡点": "无卡点",
  83. "客户购买阶段": "合同中",
  84. "下一步跟进计划-参与人": ["客服A"],
  85. "下一步跟进计划-时间点": "2024-07-01",
  86. "下一步跟进计划-具体事项": "沟通具体事项"
  87. }
  88. ]
  89. '''
  90. try:
  91. data = json.loads(json_data)
  92. completed_data = validate_and_complete_json(data)
  93. print("Completed JSON:", json.dumps(completed_data, ensure_ascii=False, indent=4))
  94. except JsonFormatError as e:
  95. print(f"JSON format error: {e.message}")

防止数据格式异常,可重新调用API获取数据

  1. if error_data:
  2. retry_count = 10 # 重试次数
  3. error_data_temp = []
  4. while True:
  5. if error_data_temp:
  6. error_data = error_data_temp
  7. error_data_temp = []
  8. for data in tqdm(error_data):
  9. is_success = False
  10. for i in range(retry_count):
  11. try:
  12. res = get_completions(PROMPT_EXTRACT.format(content=data["chat_text"]))
  13. infos = convert_all_json_in_text_to_dict(res)
  14. infos = check_and_complete_json_format(infos)
  15. result.append({
  16. "infos": infos,
  17. "index": data["index"]
  18. })
  19. is_success = True
  20. break
  21. except Exception as e:
  22. print("index:", index, ", error:", e)
  23. continue
  24. if not is_success:
  25. error_data_temp.append(data)
  26. if not error_data_temp:
  27. break
  28. result = sorted(result, key=lambda x: x["index"])

如果有错误数据,重新请求 

  1. from tqdm import tqdm
  2. retry_count = 5 # 重试次数
  3. result = []
  4. error_data = []
  5. for index, data in tqdm(enumerate(test_data)):
  6. index += 1
  7. is_success = False
  8. for i in range(retry_count):
  9. try:
  10. res = get_completions(PROMPT_EXTRACT.format(content=data["chat_text"]))
  11. infos = convert_all_json_in_text_to_dict(res)
  12. infos = check_and_complete_json_format(infos)
  13. result.append({
  14. "infos": infos,
  15. "index": index
  16. })
  17. is_success = True
  18. break
  19. except Exception as e:
  20. print("index:", index, ", error:", e)
  21. continue
  22. if not is_success:
  23. data["index"] = index
  24. error_data.append(data)

写入output.json并提交 

write_json("output.json", result)

五、个人感悟

在运行过程中报错:

Did not find spark_app_id, please add an environment variable `IFLYTEK_SPARK_APP_ID` which contains it, or pass `spark_app_id` as a named parameter. (type=value_error)

经过多次尝试后发现是需要在目录下自己创建.env文件,将环境变量写进去。

不知道是不是这么做的,但是我添加了文本文件,重命名为.env,之后该文件就不见了,但是程序可以运行,猜测是因为这个名称改变了路径。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/正经夜光杯/article/detail/994786
推荐阅读
相关标签
  

闽ICP备14008679号