当前位置:   article > 正文

使用Openai的api进行数据扩充_openai 自定义数据

openai 自定义数据

使用背景

kaggle举办了一场单项选择题比赛(LLM),希望参赛选手使用更少参数量的模型来达到一个媲美GPT的效果,但是提供的训练数据有限,于是考虑调用GPT的结构来自己扩充数据
比赛链接:

Kaggle - LLM Science Exam
Use LLMs to answer difficult science questions(比赛链接)

在这里插入图片描述

写在前面

  • 国内受到特殊原因的影响,正常情况下无法调用API
  • 本代码举例提倡使用kaggle平台运行

kaggle api添加方法

在这里插入图片描述
在这里插入图片描述

api的收费标准(官方购买)

在这里插入图片描述
在这里插入图片描述

代码

library
!pip install openai langchain tiktoken -q
import os
import pandas as pd
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.chains import LLMChain
from langchain.output_parsers import ResponseSchema
from langchain.output_parsers import StructuredOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
import tiktoken
from tqdm.auto import tqdm

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
添加api
  • kaggle
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
os.environ["OPENAI_API_KEY"] = user_secrets.get_secret("openai")
  • 1
  • 2
  • 3
  • 普通平台
# 将你的 API 密钥作为一个字符串赋值给变量
api_key = "sk-***********************PiAFqUGsGs8q8YTXU99U"

# 设置环境变量
os.environ["OPENAI_API_KEY"] = api_key

# 然后你可以使用 API 密钥进行 API 调用

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

def num_tokens_from_string(string: str, encoding_name: str="gpt-3.5-turbo") -> int:
    """Returns the number of tokens in a text string."""
    encoding = tiktoken.encoding_for_model(encoding_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens
def truncate_text_by_token(string: str, limit: int=3000, encoding_name: str="gpt-3.5-turbo") -> int:
    """Returns the number of tokens in a text string."""
    encoding = tiktoken.encoding_for_model(encoding_name)
    tokens = encoding.encode(string)[:limit]
    text = encoding.decode(tokens)
    return text

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

df = pd.read_parquet('/kaggle/input/wikipedia-stem-plaintext/parsed.parquet')
df['id'] = range(len(df))
df = df.groupby(['title']).sample(1,random_state=100).sample(3000, random_state=100).reset_index(drop=True)
df
  • 1
  • 2
  • 3
  • 4
  • 5

在这里插入图片描述


doc_ids = df.id.tolist()
texts = df.text.tolist()
response_keys_set = set(("prompt", "A", "B", "C", "D", "E", "answer"))
  • 1
  • 2
  • 3
  • 4

template_string = """
You will be provided with TEXT from wikipedia. \
The TEXT will be delimited with ### characters.
### TEXT begin
{text}
### TEXT end
Output a python list of {q_num} dict objects, where each object is \
a multiple choice question whom answers should be in \
the given TEXT and that has 5 choices each and has the following format:
    'prompt': <question on the TEXT>
    'A': <question answer option>
    'B': <question answer option>
    'C': <question answer option>
    'D': <question answer option>
    'E': <question answer option>
    'answer': <answer option key label>

You should tell me which one of your proposed options is right \
by assigning the corresponding option's key label in the 'answer' field.

The question and the answer options should be challenging, \
more about statements, description.

The answer do not require the given TEXT and should be LONGER enough!
"""

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27

规定了必须是单选题

在这里插入图片描述

# output token max to 500模型
chat = ChatOpenAI(temperature=0.75, request_timeout=600, max_tokens=1200, max_retries=1)
#提示词
prompt_template = ChatPromptTemplate.from_template(template_string)

chain = LLMChain(llm=chat, prompt=prompt_template)
debug=True
if debug:
    texts = texts[:20] 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  1. chat = ChatOpenAI(temperature=0.75, request_timeout=600, max_tokens=1200, max_retries=1):

  2. temperature(温度):

    • 范围: 0 到 1 之间的浮点数。
    • 含义: 控制模型生成文本的随机性。较高的温度会使输出更加随机,而较低的温度会使输出更加确定和保守。
  3. request_timeout(请求超时):

    • 单位: 秒 (s)。
    • 含义: 指定向API发送请求后等待响应的最大时间。如果在此时间内未收到响应,请求将被视为超时。
  4. max_tokens(最大标记数):

    • 含义: 控制模型生成的最大标记数(tokens)。标记是文本中的单词、字符或子单元,例如一个汉字或一个英文单词。
  5. max_retries(最大重试次数):

    • 含义: 在遇到网络错误或其他临时问题时,允许重新尝试发送请求的最大次数。
  • 温度设置为 0.75,使得生成的文本具有一定的随机性。
  • 请求超时设定为 600 秒,即 10 分钟,表示等待API响应的最大时间。
  • 最大标记数设定为 1200,允许模型生成的文本长度最长为1200个标记。
  • 最大重试次数设定为 1,即在发生网络错误或其他临时问题时,允许尝试重新发送请求一次。
  1. prompt_template = ChatPromptTemplate.from_template(template_string):

    • 创建了一个对话提示模板的实例 prompt_template,这里 template_string 是一个用于引导对话的字符串模板。这个模板可能包含特定的对话角色、情境或问题,以便启动模型生成相关的对话内容。
  2. chain = LLMChain(llm=chat, prompt=prompt_template):
    • 创建了一个基于语言模型的对话链 chain
    • llm=chat 将之前定义的ChatGPT模型 chat 与这个对话链关联起来。
    • prompt=prompt_template 则将之前定义的对话提示模板 prompt_template 与这个对话链关联起来,以便在对话过程中提供引导。

综上,这段代码的作用是创建了一个可以基于指定提示模板进行对话生成的机制,并且可以控制生成结果的随机性、超时时间等参数。


# split each text to 3000 token chunks, each chunk generate 1 question
# change this if for your vary len texts
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size = 3000,
    chunk_overlap  = 0,
    length_function = num_tokens_from_string)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • text_splitter = RecursiveCharacterTextSplitter(chunk_size = 3000, chunk_overlap = 0, length_function = num_tokens_from_string):
    • chunk_size = 3000: 指定了每个文本块的最大长度为3000个标记。这个值可以根据实际需要进行调整。
    • chunk_overlap = 0: 指定了文本块之间的重叠部分的长度。在这里设置为0,表示不允许文本块之间有重叠部分。
    • length_function = num_tokens_from_string: 这是一个用于计算文本长度的函数。num_tokens_from_string 会返回给定文本的标记数。

综上,这段代码的作用是将较长的文本按照指定的规则切割成长度不超过3000个标记的文本块,并为每个文本块生成一个问题。如果你的文本长度发生变化,你可能需要相应地调整这些切割规则。

def is_valid_question(r):
    return (response_keys_set == set(list(r.keys()))) and all([isinstance(v, str) and len(v)>0 for k,v in r.items()]) and (r['answer'] in ["A", "B", "C", "D", "E"])
def select_correct_formatted_questions(r):
    res_q = []
    for rr in r:
        if is_valid_question(rr):
            res_q.append(rr)
    return res_q

# 文本切割器(text_splitter)、每段文本生成的问题数量(q_num,默认为1)、最大尝试次数(max_attempts,默认为1)作为输入参数。

def gather_multiple_choice_question_dataset(doc_ids, texts, text_splitter, q_num=1, max_attempts=1):
    # 初始化一个空列表,用于存储多项选择题。
    multiple_choice_questions = []
    
    # 遍历每个文档的ID和对应的文本内容。
    for doc_id, text in tqdm(zip(doc_ids, texts), total=len(texts)):
        k = 0  # 用于控制每个文本生成的问题数量,初始化为0。
        
        # 将文本使用文本切割器进行切割,得到文本块。
        for text_chunk in text_splitter.create_documents([text]):
            
            # 如果已经生成了一个问题,就跳出内层循环,处理下一个文本块。
            if k >= 1:
                break
            
            attempts_cnt = 0  # 用于记录尝试次数,初始化为0。
            
            # 不断尝试生成问题,直到达到最大尝试次数。
            while True:
                try:
                    # 使用预定义的模型链(chain)运行,传入文本块和生成的问题数量。
                    response = chain.run(text=text_chunk, q_num=q_num)
                    mcq = eval(response)  # 将模型返回的字符串结果解析为Python对象。
                    
                    # 检查生成的问题是否符合要求。
                    if not isinstance(mcq, list) or len(mcq) == 0:
                        raise Exception
                    
                    # 从生成的问题中选择符合格式要求的问题。
                    output = select_correct_formatted_questions(mcq)
                    
                    # 如果没有符合格式要求的问题,则抛出异常。
                    if len(output) == 0:
                        raise Exception
                    
                    # 为每个问题记录其所属文档的ID,并将其加入到多项选择题列表中。
                    for d in output:
                        d['doc_id'] = doc_id
                    multiple_choice_questions.extend(output)
                    
                    k += 1  # 增加已生成问题的计数。
                    break  # 成功生成问题,跳出内层循环,处理下一个文本块。
                except Exception:
                    attempts_cnt += 1  # 尝试次数加1。
                    
                    # 如果达到最大尝试次数,则跳出内层循环,处理下一个文本块。
                    if attempts_cnt >= max_attempts:
                        break
                        
    return multiple_choice_questions  # 返回生成的多项选择题列表。


  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63

multi_choice_questions = gather_multiple_choice_question_dataset(doc_ids=doc_ids,
                                                                 texts=texts, 
                                                                 text_splitter=text_splitter, 
                                                                 q_num=1,
                                                                 max_attempts=1)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/852091
推荐阅读
相关标签
  

闽ICP备14008679号