当前位置:   article > 正文

2024 Datawhale 夏令营 零基础入门大模型技术竞赛 Task1 + Task2_datawhale夏令营

datawhale夏令营

#AI夏令营 #Datawhale #夏令营

记录自己参加 Datawhale 夏令营期间的学习心得,欢迎交流讨论

这一版笔记主要是对 Baseline 跑通过程的记录、代码初步分析;直播讲解的观后心得。
(Task 1、Task 2)

一、跑通 Baseline

相关的 baseline 手册 https://datawhaler.feishu.cn/wiki/VIy8ws47ii2N79kOt9zcXnbXnuS
文档十分详细,小白 30 分钟内可以完成速通体验,对流程环节有一个大概的认识。

1 报名赛事、申领大模型 API

根据手册操作即可,注册账号登录,填写身份信息
赛事链接:https://challenge.xfyun.cn/h5/detail?type=role-element-extraction&ch=dw24_y0SCtd
大模型 API 链接:https://console.xfyun.cn/app/myapp

2 Baseline 项目复刻(分析)

baseline 项目链接: https://aistudio.baidu.com/projectdetail/8095619

2.1 下载相关库

!pip install --upgrade -q spark_ai_python
  • 1

下载 python 接入星火大模型的库,调用星火大模型 API

2.2 导入配置

from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage
import json


#星火认知大模型Spark3.5 Max的URL值,其他版本大模型URL值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_URL = 'wss://spark-api.xf-yun.com/v3.5/chat'
#星火认知大模型调用秘钥信息,请前往讯飞开放平台控制台(https://console.xfyun.cn/services/bm35)查看
SPARKAI_APP_ID = ''
SPARKAI_API_SECRET = ''
SPARKAI_API_KEY = ''
#星火认知大模型Spark3.5 Max的domain值,其他版本大模型domain值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
SPARKAI_DOMAIN = 'generalv3.5'
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

填入调用模型的相关信息

2.3 模型测试

测试模型是否能够正常完成调用:

def get_completions(text):
    messages = [ChatMessage(
        role="user",
        content=text
    )]
    spark = ChatSparkLLM(
        spark_api_url=SPARKAI_URL,
        spark_app_id=SPARKAI_APP_ID,
        spark_api_key=SPARKAI_API_KEY,
        spark_api_secret=SPARKAI_API_SECRET,
        spark_llm_domain=SPARKAI_DOMAIN,
        streaming=False,
    )
    handler = ChunkPrintHandler()
    a = spark.generate([messages], callbacks=[handler])
    return a.generations[0][0].text

# 测试模型配置是否正确
text = "你好"
get_completions(text)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

get_completions 函数:

  • 将文本输入 text 作为参数
  • 内部定义了 messages 结构体,调用了 ChatMessage 方法(传入身份和上下文)
  • 创建了一个 ChatSparkLLM 类实例 spark,将之前的配置导入,用以与模型交互
  • 创建了一个 ChunkPrintHandler 类实例 handler,用于处理生成的文本
  • 调用 spark.generate() 方法来生成文本 a,将用户消息传递给模型,并将生成的文本通过 handler 处理
  • 最后,返回生成的文本

2.4 数据读取

这一步是读取比赛提供相关的json文件,用以后续的处理

def read_json(json_file_path):
    """读取json文件"""
    with open(json_file_path, 'r') as f:
        data = json.load(f)
    return data

def write_json(json_file_path, data):
    """写入json文件"""
    with open(json_file_path, 'w') as f:
        json.dump(data, f, ensure_ascii=False, indent=4)

# 读取数据
train_data = read_json("dataset/train.json")
test_data = read_json("dataset/test_data.json")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

2.5 Prompt设计

为了让大模型在接收到数据之后能够按照我们的需要进行输出,我们需要进行 prompt 设计。
Baseline 的 promt 任务大致有:

  1. 任务
  2. 表单的格式
  3. 聊天对话记录(由后续的 PROMPT_EXTRACT.format(content=data["chat_text"]) 补充完整)
  4. 输出的格式
# prompt 设计

PROMPT_EXTRACT = """

你将获得一段群聊对话记录。你的任务是根据给定的表单格式从对话记录中提取结构化信息。在提取信息时,请确保它与类型信息完全匹配,不要添加任何没有出现在下面模式中的属性。


表单格式如下:

info: Array<Dict(

    "基本信息-姓名": string | "",  // 客户的姓名。

    "基本信息-手机号码": string | "",  // 客户的手机号码。

    "基本信息-邮箱": string | "",  // 客户的电子邮箱地址。

    "基本信息-地区": string | "",  // 客户所在的地区或城市。

    "基本信息-详细地址": string | "",  // 客户的详细地址。

    "基本信息-性别": string | "",  // 客户的性别。

    "基本信息-年龄": string | "",  // 客户的年龄。

    "基本信息-生日": string | "",  // 客户的生日。

    "咨询类型": string[] | [],  // 客户的咨询类型,如询价、答疑等。

    "意向产品": string[] | [],  // 客户感兴趣的产品。

    "购买异议点": string[] | [],  // 客户在购买过程中提出的异议或问题。

    "客户预算-预算是否充足": string | "",  // 客户的预算是否充足。示例:充足, 不充足

    "客户预算-总体预算金额": string | "",  // 客户的总体预算金额。

    "客户预算-预算明细": string | "",  // 客户预算的具体明细。

    "竞品信息": string | "",  // 竞争对手的信息。

    "客户是否有意向": string | "",  // 客户是否有购买意向。示例:有意向, 无意向

    "客户是否有卡点": string | "",  // 客户在购买过程中是否遇到阻碍或卡点。示例:有卡点, 无卡点

    "客户购买阶段": string | "",  // 客户当前的购买阶段,如合同中、方案交流等。

    "下一步跟进计划-参与人": string[] | [],  // 下一步跟进计划中涉及的人员(客服人员)。

    "下一步跟进计划-时间点": string | "",  // 下一步跟进的时间点。

    "下一步跟进计划-具体事项": string | ""  // 下一步需要进行的具体事项。

)>

  

请分析以下群聊对话记录,并根据上述格式提取信息:

  

**对话记录:**

\```
{content}
\```


请将提取的信息以JSON格式输出。

不要添加任何澄清信息。

输出必须遵循上面的模式。

不要添加任何没有出现在模式中的附加字段。

不要随意删除字段。

  

**输出:**

\```

[{{

    "基本信息-姓名": "姓名",

    "基本信息-手机号码": "手机号码",

    "基本信息-邮箱": "邮箱",

    "基本信息-地区": "地区",

    "基本信息-详细地址": "详细地址",

    "基本信息-性别": "性别",

    "基本信息-年龄": "年龄",

    "基本信息-生日": "生日",

    "咨询类型": ["咨询类型"],

    "意向产品": ["意向产品"],

    "购买异议点": ["购买异议点"],

    "客户预算-预算是否充足": "充足或不充足",

    "客户预算-总体预算金额": "总体预算金额",

    "客户预算-预算明细": "预算明细",

    "竞品信息": "竞品信息",

    "客户是否有意向": "有意向或无意向",

    "客户是否有卡点": "有卡点或无卡点",

    "客户购买阶段": "购买阶段",

    "下一步跟进计划-参与人": ["跟进计划参与人"],

    "下一步跟进计划-时间点": "跟进计划时间点",

    "下一步跟进计划-具体事项": "跟进计划具体事项"

}}, ...]

\```

"""
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133

2.6 启动主函数

在正式调用大模型之前,代码还定义了两个函数

  • convert_all_json_in_text_to_dict:将大模型的文本输出转化为 dict 格式
  • check_and_complete_json_format: 检查提取的 json 下的每一项的格式是否正确,不正确则会报错(raise 代码定义的 JsonFormatError)
import json

class JsonFormatError(Exception):
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)

def convert_all_json_in_text_to_dict(text):
    """提取LLM输出文本中的json字符串"""
    dicts, stack = [], []
    for i in range(len(text)):
        if text[i] == '{':
            stack.append(i)
        elif text[i] == '}':
            begin = stack.pop()
            if not stack:
                dicts.append(json.loads(text[begin:i+1]))
    return dicts

# 查看对话标签
def print_json_format(data):
    """格式化输出json格式"""
    print(json.dumps(data, indent=4, ensure_ascii=False))

def check_and_complete_json_format(data):
    required_keys = {
        "基本信息-姓名": str,
        "基本信息-手机号码": str,
        "基本信息-邮箱": str,
        "基本信息-地区": str,
        "基本信息-详细地址": str,
        "基本信息-性别": str,
        "基本信息-年龄": str,
        "基本信息-生日": str,
        "咨询类型": list,
        "意向产品": list,
        "购买异议点": list,
        "客户预算-预算是否充足": str,
        "客户预算-总体预算金额": str,
        "客户预算-预算明细": str,
        "竞品信息": str,
        "客户是否有意向": str,
        "客户是否有卡点": str,
        "客户购买阶段": str,
        "下一步跟进计划-参与人": list,
        "下一步跟进计划-时间点": str,
        "下一步跟进计划-具体事项": str
    }

    if not isinstance(data, list):
        raise JsonFormatError("Data is not a list")

    for item in data:
        if not isinstance(item, dict):
            raise JsonFormatError("Item is not a dictionary")
        for key, value_type in required_keys.items():
            if key not in item:
                item[key] = [] if value_type == list else ""
            if not isinstance(item[key], value_type):
                raise JsonFormatError(f"Key '{key}' is not of type {value_type.__name__}")
            if value_type == list and not all(isinstance(i, str) for i in item[key]):
                raise JsonFormatError(f"Key '{key}' does not contain all strings in the list")

    return data
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64

完成了前面的准备工作,便可以正式开始调用大模型进行交互

from tqdm import tqdm

retry_count = 5 # 重试次数
result = []
error_data = []

for index, data in tqdm(enumerate(test_data)):
    index += 1
    is_success = False
    for i in range(retry_count):
        try:
            res = get_completions(PROMPT_EXTRACT.format(content=data["chat_text"]))
            infos = convert_all_json_in_text_to_dict(res)
            infos = check_and_complete_json_format(infos)
            result.append({
                "infos": infos,
                "index": index
            })
            is_success = True
            break
        except Exception as e:
            print("index:", index, ", error:", e)
            continue
    if not is_success:
        data["index"] = index
        error_data.append(data)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

对于每一个 test_data, 代码会尝试调用模型五次。
由于模型的输出具有不确定性,因此每次都要对返回的文本进行解析然后检查,检查无误之后才会将数据保存到 result 中。

2.7 生成提交文件

最后将 result 保存为 output.json 便完成了 baseline 的代码运行!

# 保存输出
write_json("output.json", result)
  • 1
  • 2

二、项目代码分析与直播观后心得

1 赛事任务与评价指标

  • 赛事任务:从给定的<客服>与<客户>的群聊对话中, 提取出指定的字段信息,待提取的全部字段见数据说明。
  • 评价指标:测试集的每条数据同样包含共21个字段, 按照各字段难易程度划分总计满分36分。每个提取正确性的判定标准如下:
    1. 对于答案唯一字段,将使用完全匹配的方式计算提取是否正确,提取正确得到相应分数,否则为0分。
    2. 对于答案不唯一字段,将综合考虑提取完整性、语义相似度等维度判定提取的匹配分数,最终该字段得分为 “匹配分数 * 该字段难度分数”。
    3. 每条测试数据的最终得分为各字段累计得分。最终测试集上的分数为所有测试数据的平均得分。

2 总体思路方向

  • 比赛鼓励对 promt 开发优化。
  • 即:通过设计 prompt 强调抽取的数据格式和数据内容,将测试集的数据通过大语言模型抽取得到结果。
  • 受限于有限的训练量(150+),微调的性价比可能不如 promt 来的好。

3 改进方向

  • 由于初期 promt 粗略,大模型可能提取出多余信息。在后续改良中可以采用分阶段 promt针对特定问题的 promt 方式。
  • 由于评价系统每天只能提交三次,限制了改善的效果评估。因此,需要自写评价方案,便于观察细微改善。
  • 采用微调方式,线上大模型需要排队很久。讯飞大模型提供了 Lite、Pro、Max 三个版本,可以先使用轻量的 Lite 版本进行较快的微调,确定有效后再进行 Pro、Max 的微调。

#这个夏令营不简单 #AI夏令营 #Datawhale #夏令营

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/994746
推荐阅读
相关标签
  

闽ICP备14008679号