当前位置:   article > 正文

基于星火大模型的群聊对话分角色要素提取挑战赛Task1笔记_from sparkai.llm.llm import chatsparkllm, chunkpri

from sparkai.llm.llm import chatsparkllm, chunkprinthandler from sparkai.cor

基于星火大模型的群聊对话分角色要素提取挑战赛Task1笔记

跑通baseline

1、安装依赖

下载相应的数据库

!pip install --upgrade -q spark_ai_python
  • 1

2、配置导入

导入必要的包。

from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage
import json

  • 1
  • 2
  • 3
  • 4

配置设置相关参数。

#星火认知大模型Spark3.5 Max的URL值,其他版本大模型URL值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_URL = 'wss://spark-api.xf-yun.com/v3.5/chat'
#星火认知大模型调用秘钥信息,请前往讯飞开放平台控制台(https://console.xfyun.cn/services/bm35)查看
SPARKAI_APP_ID = '2e001699'
SPARKAI_API_SECRET = 'ZmU2YTliYmU1YjViODlkMDYwOWZlOTc4'
SPARKAI_API_KEY = '52a07d74ef95aead407a958f448d4464'
#星火认知大模型Spark3.5 Max的domain值,其他版本大模型domain值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_DOMAIN = 'generalv3.5'
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

申领大模型API来自:https://console.xfyun.cn/app/myapp

在这里插入图片描述

3、模型测试

def get_completions(text):
    messages = [ChatMessage(
        role="user",
        content=text
    )]
    spark = ChatSparkLLM(
        spark_api_url=SPARKAI_URL,
        spark_app_id=SPARKAI_APP_ID,
        spark_api_key=SPARKAI_API_KEY,
        spark_api_secret=SPARKAI_API_SECRET,
        spark_llm_domain=SPARKAI_DOMAIN,
        streaming=False,
    )
    handler = ChunkPrintHandler()
    a = spark.generate([messages], callbacks=[handler])
    return a.generations[0][0].text

# 测试模型配置是否正确
text = "你好"
get_completions(text)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

4、数据读取

def read_json(json_file_path):
    """读取json文件"""
    with open(json_file_path, 'r') as f:
        data = json.load(f)
    return data

def write_json(json_file_path, data):
    """写入json文件"""
    with open(json_file_path, 'w') as f:
        json.dump(data, f, ensure_ascii=False, indent=4)

# 读取数据
train_data = read_json("dataset/train.json")
test_data = read_json("dataset/test_data.json")

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

5、Prompt设计

 4. 加载决策树模型进行训练
model = LGBMClassifier(verbosity=-1)
model.fit(train.iloc[:, 2:].values, train['Label'])
pred = model.predict(test.iloc[:, 1:].values, )
  • 1
  • 2
  • 3
  • 4

6、主函数启动

import json

class JsonFormatError(Exception):
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)

def check_and_complete_json_format(data):
    required_keys = {
        "基本信息-姓名": str,
        "基本信息-手机号码": str,
        "基本信息-邮箱": str,
        "基本信息-地区": str,
        "基本信息-详细地址": str,
        "基本信息-性别": str,
        "基本信息-年龄": str,
        "基本信息-生日": str,
        "咨询类型": list,
        "意向产品": list,
        "购买异议点": list,
        "客户预算-预算是否充足": str,
        "客户预算-总体预算金额": str,
        "客户预算-预算明细": str,
        "竞品信息": str,
        "客户是否有意向": str,
        "客户是否有卡点": str,
        "客户购买阶段": str,
        "下一步跟进计划-参与人": list,
        "下一步跟进计划-时间点": str,
        "下一步跟进计划-具体事项": str
    }

    if not isinstance(data, list):
        raise JsonFormatError("Data is not a list")

    for item in data:
        if not isinstance(item, dict):
            raise JsonFormatError("Item is not a dictionary")
        for key, value_type in required_keys.items():
            if key not in item:
                item[key] = [] if value_type == list else ""
            if not isinstance(item[key], value_type):
                raise JsonFormatError(f"Key '{key}' is not of type {value_type.__name__}")
            if value_type == list and not all(isinstance(i, str) for i in item[key]):
                raise JsonFormatError(f"Key '{key}' does not contain all strings in the list")

    return data

# Example usage:
json_data = '''
[
    {
        "基本信息-姓名": "张三",
        "基本信息-手机号码": "12345678901",
        "基本信息-邮箱": "zhangsan@example.com",
        "基本信息-地区": "北京市",
        "基本信息-详细地址": "朝阳区某街道",
        "基本信息-性别": "男",
        "基本信息-年龄": "30",
        "基本信息-生日": "1990-01-01",
        "咨询类型": ["询价"],
        "意向产品": ["产品A"],
        "购买异议点": ["价格高"],
        "客户预算-预算是否充足": "充足",
        "客户预算-总体预算金额": "10000",
        "客户预算-预算明细": "详细预算内容",
        "竞品信息": "竞争对手B",
        "客户是否有意向": "有意向",
        "客户是否有卡点": "无卡点",
        "客户购买阶段": "合同中",
        "下一步跟进计划-参与人": ["客服A"],
        "下一步跟进计划-时间点": "2024-07-01",
        "下一步跟进计划-具体事项": "沟通具体事项"
    }
]
'''

try:
    data = json.loads(json_data)
    completed_data = check_and_complete_json_format(data)
    print("Completed JSON:", json.dumps(completed_data, ensure_ascii=False, indent=4))
except JsonFormatError as e:
    print(f"JSON format error: {e.message}")# 5. 保存结果文件到本地
pd.DataFrame(
    {
        'uuid': test['uuid'],
        'Label': pred
    }
).to_csv('submit.csv', index=None)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89

7、生成提交文件

from tqdm import tqdm

retry_count = 5 # 重试次数
result = []
error_data = []

for index, data in tqdm(enumerate(test_data)):
    index += 1
    is_success = False
    for i in range(retry_count):
        try:
            res = get_completions(PROMPT_EXTRACT.format(content=data["chat_text"]))
            infos = convert_all_json_in_text_to_dict(res)
            infos = check_and_complete_json_format(infos)
            result.append({
                "infos": infos,
                "index": index
            })
            is_success = True
            break
        except Exception as e:
            print("index:", index, ", error:", e)
            continue
    if not is_success:
        data["index"] = index
        error_data.append(data)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

8、保存输出

write_json("output.json", result)
  • 1
print("index:", index, ", error:", e)
        continue
if not is_success:
    data["index"] = index
    error_data.append(data)
  • 1
  • 2
  • 3
  • 4
  • 5

## 8、保存输出

```python
write_json("output.json", result)
  • 1
  • 2
  • 3
  • 4
  • 5
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Li_阴宅/article/detail/994738
推荐阅读
相关标签
  

闽ICP备14008679号