当前位置:   article > 正文

DataWhale AI夏令营3-Task 1

DataWhale AI夏令营3-Task 1

1.baseline代码解析

  1. def api_retry(MODEL_NAME, query):
  2. max_retries = 5
  3. retry_delay = 60 # in seconds
  4. attempts = 0
  5. while attempts < max_retries:
  6. try:
  7. return call_qwen_api(MODEL_NAME, query)
  8. except Exception as e:
  9. attempts += 1
  10. if attempts < max_retries:
  11. logger.warning(f"Attempt {attempts} failed for text: {query}. Retrying in {retry_delay} seconds...")
  12. time.sleep(retry_delay)
  13. else:
  14. logger.error(f"All {max_retries} attempts failed for text: {query}. Error: {e}")
  15. raise
  16. def call_qwen_api(MODEL_NAME, query):
  17. # 这里采用dashscope的api调用模型推理,通过http传输的json封装返回结果
  18. messages = [
  19. {'role': 'user', 'content': query}]
  20. response = dashscope.Generation.call(
  21. MODEL_NAME,
  22. messages=messages,
  23. result_format='message', # set the result is message format.
  24. )
  25. if response.status_code == HTTPStatus.OK:
  26. # print(response)
  27. return response['output']['choices'][0]['message']['content']
  28. else:
  29. print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
  30. response.request_id, response.status_code,
  31. response.code, response.message
  32. ))
  33. raise Exception()
  • `api_retry(MODEL_NAME, query)`:重试API调用,如果失败则记录日志并在达到最大尝试次数时引发异常。
  • 特点 重试机制:函数使用了 @retry(delay=60, tries=5) 装饰器,这意味着在调用 API 失败时,函数会自动重试最多 5 次,每次重试间隔 60 秒。 消息格式:在调用 API 时,设置了 result_format=‘message’,表示期望的响应格式是消息格式。 错误处理:在 API 调用失败时,函数会打印详细的错误信息,包括请求 ID、状态码、错误代码和错误消息,并抛出异常,以便上层调用者能够捕获并处理这些错误。
  • `call_qwen_api(MODEL_NAME, query)`:调用 dashscope API 生成文本,适用于需要动态生成内容的场景,如聊天机器人、内容创作辅助等。该函数通过传递一个模型名称 (MODEL_NAME) 和一个查询 (query) 来调用 dashscope.Generation.call 方法,生成相应的文本。 处理 API 响应:函数会检查 API 的响应状态码,如果状态码为 HTTPStatus.OK,则提取并返回生成的文本内容。如果状态码不是 HTTPStatus.OK,则打印错误信息并抛出异常。
  1. # 这里定义了prompt推理模版
  2. def get_prompt(problem, question, options):
  3. options = '\n'.join(f"{'ABCDEFG'[i]}. {o}" for i, o in enumerate(options))
  4. prompt = f"""你是一个逻辑推理专家,擅长解决逻辑推理问题。以下是一个逻辑推理的题目,形式为单项选择题。所有的问题都是(close-world assumption)闭世界假设,即未观测事实都为假。请逐步分析问题并在最后一行输出答案,最后一行的格式为"答案是:A"。题目如下:
  5. ### 题目:
  6. {problem}
  7. ### 问题:
  8. {question}
  9. {options}
  10. """
  11. # print(prompt)
  12. return prompt
  13. # 这里使用extract抽取模板获得抽取的结果
  14. def extract(input_text):
  15. ans_pattern = re.compile(r"答案是:(.)", re.S)
  16. problems = ans_pattern.findall(input_text)
  17. # print(problems)
  18. if(problems == ''):
  19. return 'A'
  20. return problems[0]
  21. def process_datas(datas,MODEL_NAME):
  22. results = []
  23. with ThreadPoolExecutor(max_workers=16) as executor:
  24. future_data = {}
  25. lasttask = ''
  26. lastmark = 0
  27. lens = 0
  28. for data in tqdm(datas, desc="Submitting tasks", total=len(datas)):
  29. problem = data['problem']
  30. for id,question in enumerate(data['questions']):
  31. prompt = get_prompt(problem,
  32. question['question'],
  33. question['options'],
  34. )
  35. future = executor.submit(api_retry, MODEL_NAME, prompt)
  36. future_data[future] = (data,id)
  37. time.sleep(0.6) # 控制每0.5秒提交一个任务
  38. lens += 1
  39. for future in tqdm(as_completed(future_data), total=lens, desc="Processing tasks"):
  40. # print('data',data)
  41. data = future_data[future][0]
  42. problem_id = future_data[future][1]
  43. try:
  44. res = future.result()
  45. extract_response = extract(res)
  46. # print('res',extract_response)
  47. data['questions'][problem_id]['answer'] = extract_response
  48. results.append(data)
  49. # print('data',data)
  50. except Exception as e:
  51. logger.error(f"Failed to process text: {data}. Error: {e}")
  52. return results
  53. def main(ifn, ofn):
  54. if os.path.exists(ofn):
  55. pass
  56. data = []
  57. # 按行读取数据
  58. with open(ifn) as reader:
  59. for line in reader:
  60. sample = json.loads(line)
  61. data.append(sample)
  62. datas = data
  63. # print(data)
  64. # 均匀地分成多个数据集
  65. return_list = process_datas(datas,MODEL_NAME)
  66. print(len(return_list))
  67. print("All tasks finished!")
  68. return return_list
  69. def evaluate(ofn):
  70. data = []
  71. with open(ofn) as reader:
  72. for line in reader:
  73. sample = json.loads(line)
  74. data.append(sample)
  75. pse = 0
  76. cnt = 0
  77. tot = 0
  78. for task in data:
  79. for question in task['questions']:
  80. if MODEL_NAME in question:
  81. tot += 1
  82. cnt += question[MODEL_NAME] == question['answer']
  83. else:
  84. pse += 1
  85. print(cnt, tot, cnt/tot, pse)
  • `get_prompt(problem, question, options)`:生成逻辑推理题目的提示文本。该函数通过接受问题描述、问题和选项作为输入,然后将它们格式化为一段提示文本。
  • `extract(input_text)`:从文本中提取答案。该函数使用正则表达式来匹配文本中的答案,并返回匹配到的答案。
  • `process_datas(datas, MODEL_NAME)`:处理数据集的核心函数。该函数遍历数据集中的问题,为每个问题生成提示文本,并使用`api_retry`函数调用API进行推理,然后提取答案。使用线程池(ThreadPoolExecutor)来并行处理数据集中的问题。
  • `main(ifn, ofn)`:主函数,用于读取数据、处理数据并将处理后的结果保存到输出文件。该函数通过调用`process_datas`函数处理数据,并将处理后的结果返回。
  • `evaluate(ofn)`:评估函数,用于评估模型在数据集上的表现。该函数读取处理后的数据集,计算模型预测答案的准确率。
  1. if __name__ == '__main__':
  2. a = extract("""根据欧几里得算法,逐步解析计算两个数6和7的最大公约数(gcd)的步骤如下:
  3. 1. 判断6和7是否相等:不相等。
  4. 2. 判断6和7大小关系,7 > 6,所以用更大的数7减去较小的数6得到结果1。
  5. 3. 现在计算6和1的最大公约数。
  6. 4. 6 > 1,根据算法用更大的数6减去较小的数1得到结果5。
  7. 5. 再计算5和1的最大公约数。
  8. 6. 5 > 1,用5减去1得到结果4。
  9. 7. 再计算4和1的最大公约数。
  10. 8. 4 > 1,用4减去1得到结果3。
  11. 9. 再计算3和1的最大公约数。
  12. 10. 3 > 1,用3减去1得到结果2。
  13. 11. 再计算2和1的最大公约数。
  14. 12. 2 > 1,用2减去1得到结果1。
  15. 13. 最后计算1和1的最大公约数,两数相等,gcd即为这两个数,也就是1。
  16. 因此,6和7的最大公约数是1。
  17. 答案是:C.""")
  18. print(a)
  19. return_list = main('round1_test_data.jsonl', 'upload.jsonl')
  20. def has_complete_answer(questions):
  21. # 这里假设完整答案的判断逻辑是:每个question都有一个'answer'键
  22. for question in questions:
  23. if 'answer' not in question:
  24. return False
  25. return True
  26. def filter_problems(data):
  27. result = []
  28. problem_set = set()
  29. for item in data:
  30. # print('处理的item' ,item)
  31. problem = item['problem']
  32. if problem in problem_set:
  33. # 找到已存在的字典
  34. for existing_item in result:
  35. if existing_item['problem'] == problem:
  36. # 如果当前字典有完整答案,替换已存在的字典
  37. if has_complete_answer(item['questions']):
  38. existing_item['questions'] = item['questions']
  39. existing_item['id'] = item['id']
  40. break
  41. else:
  42. # 如果当前字典有完整答案,添加到结果列表
  43. if has_complete_answer(item['questions']):
  44. result.append(item)
  45. problem_set.add(problem)
  46. return result
  47. return_list
  48. return_list = filter_problems(return_list)
  49. sorted_data = sorted(return_list, key=lambda x: int(str(x['id'])[-3:]))
  50. print(sorted_data)
  51. def find_missing_ids(dict_list):
  52. # 提取所有序号
  53. extracted_ids = {int(d['id'][-3:]) for d in dict_list}
  54. # 创建0-500的序号集合
  55. all_ids = set(range(500))
  56. # 找出缺失的序号
  57. missing_ids = all_ids - extracted_ids
  58. return sorted(missing_ids)
  59. # 示例字典列表
  60. dict_list = sorted_data
  61. # 找出缺失的序号
  62. missing_ids = find_missing_ids(dict_list)
  63. print("缺失的序号:", missing_ids)
  64. data = []
  65. with open('round1_test_data.jsonl') as reader:
  66. for id,line in enumerate(reader):
  67. if(id in missing_ids):
  68. sample = json.loads(line)
  69. for question in sample['questions']:
  70. question['answer'] = 'A'
  71. sorted_data.append(sample)
  72. sorted_data = sorted(sorted_data, key=lambda x: int(str(x['id'])[-3:]))
  73. with open('upload.jsonl', 'w') as writer:
  74. for sample in sorted_data:
  75. writer.write(json.dumps(sample, ensure_ascii=False))
  76. writer.write('\n')
  • `has_complete_answer(questions)`:检查问题列表中的每个问题是否有完整的答案。如果有任何一个问题缺少答案,则返回`False`,否则返回`True`。
  • `filter_problems(data)`:过滤数据,确保每个问题集合中的问题都有完整的答案。该函数遍历数据集,对于相同问题的数据,保留具有完整答案的数据,并将其添加到结果列表中。
  •  `find_missing_ids(dict_list)`:找出缺失的数据序号。该函数提取给定字典列表中已存在的序号,然后生成0到500的序号集合,找出缺失的序号并返回。
  • 主程序:使用`filter_problems`函数对返回的数据进行过滤,确保每个问题集合中都有完整的答案。找出缺失的序号并打印输出。从原始数据中添加缺失的问题,并将整理后的数据按照序号排序。将整理后的数据写入到输出文件中。

2.baseline流程

  1. 首先数据集中的问题被提取并转换为提示文本,然后通过API进行推理,提取答案并保存结果,最后评估模型的表现。
  2. 对处理后的数据进行进一步处理,确保数据完整性,并处理可能缺失的数据,最终将整理后的数据写入新的文件中。
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/代码探险家/article/detail/963647
推荐阅读
相关标签
  

闽ICP备14008679号