当前位置:   article > 正文

HumanEval是如何进行代码评估的:从数据构成、评估逻辑到pass@k指标计算

humaneval

来自:老刘说NLP

快来!NLP论文投稿、LLM交流、论文直播群

HumanEval: Hand-Written Evaluation Set,是工作《Evaluating Large Language Models Trained on Code》(https://arxiv.org/abs/2107.03374)中提到的一个代码评测基准。

最近在做代码方面的评估,走了许多弯路,在评估逻辑上有些误解,重新整理了下,供大家一起参考。尤其是针对pass@k的理解、如何做的单元测试等。

一、HumanEval的数据构成

HumanEval评测数据集,一共包括164条样本,还是很少量的,可以用json进行更为直观的理解,地址https://github.com/abacaj/code-eval/blob/main/human-eval/data/HumanEval.jsonl.gz:

  1. {
  2.     "task_id":"HumanEval/0",
  3.     "prompt":"from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n    \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n    given threshold.\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n    False\n    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n    True\n    \"\"\"\n",
  4.     "entry_point":"has_close_elements",
  5.     "canonical_solution":"    for idx, elem in enumerate(numbers):\n        for idx2, elem2 in enumerate(numbers):\n            if idx != idx2:\n                distance = abs(elem - elem2)\n                if distance < threshold:\n                    return True\n\n    return False\n",
  6.     "test":"\n\nMETADATA = {\n    'author': 'jt',\n    'dataset': 'test'\n}\n\n\ndef check(candidate):\n    assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True\n    assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False\n    assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True\n    assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.8) == False\n    assert candidate([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1) == True\n    assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 1.0) == True\n    assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 0.5) == False\n\n"
  7. }

如下所示:

task_id表示任务的ID,prompt表示题目(通常直接请求大模型获取答案),entry_point是唯一标记,canonica_solution是参考答案,test是测试单元。c810606953e3c6323cd85ca0aaadad33.png

二、HumanEval的评估逻辑

每一个测试问题重复实验n次,然后通过单元测试,计算平均通过率。我们可以在源码地址:https://github.com/abacaj/code-eval/tree/main/human-eval中看到起执行逻辑

1、读取每个样本,请求模型获得结果

如下所示,generate_one_completion为请求大模型生成结果的函数,输入每道题的prompt,然后得到结果。

而由于题目太少,测试结果会有偏,大模型的结果具备多样性(如有top_p, top_k)等,所以,num_samples_per_task用来控制每道题生成多少个结果(代码中设置为200次),从而计算通过率。completion作为补全结果的存储字段。

因此,整体会有32800条样本。

  1. from human_eval.data import write_jsonl, read_problems
  2. problems = read_problems()
  3. num_samples_per_task = 200
  4. samples = [
  5.     dict(task_id=task_id, completion=generate_one_completion(problems[task_id]["prompt"]))
  6.     for task_id in problems
  7.     for _ in range(num_samples_per_task)
  8. ]
  9. write_jsonl("samples.jsonl", samples)

当然,这一块,需要做一个代码的后处理,因为模型会生成其他多余的代码片段,例如https://github.com/abacaj/code-eval/blob/main/core/evaluation.py中所述:

  1. # reference: https://github.com/declare-lab/instruct-eval/blob/main/human_eval/main.py#L35
  2. def filter_code(completion: str) -> str:
  3.     # The program tends to overwrite, we only take the first function
  4.     completion = completion.lstrip("\n")
  5.     return completion.split("\n\n")[0]

将后处理得到的结果作为最终代码补全结果。

2、获得模型的结果,进行单元测试

这块的逻辑的在于,针对得到的补全结果,通过构造一个完整的测试用例,送入单元测试中进行测试。

其中,如下代码所示:

  1. def check_correctness(problem: Dict, completion: str, timeout: float, completion_id: Optional[int] = None) -> 
  2.     def unsafe_execute():
  3.         with create_tempdir():
  4.             # These system calls are needed when cleaning up tempdir.
  5.             import os
  6.             import shutil
  7.             rmtree = shutil.rmtree
  8.             rmdir = os.rmdir
  9.             chdir = os.chdir
  10.             # Disable functionalities that can make destructive changes to the test.
  11.             reliability_guard()
  12.             # Construct the check program and run it.
  13.             print(completion)
  14.             check_program = (
  15.                 problem["prompt"] + completion + "\n" +
  16.                 problem["test"] + "\n" +
  17.                 f"check({problem['entry_point']})"
  18.             )
  19.             try:
  20.                 exec_globals = {}
  21.                 with swallow_io():
  22.                     with time_limit(timeout):
  23.                         exec(check_program, exec_globals)
  24.                 result.append("passed")
  25.             except TimeoutException:
  26.                 result.append("timed out")
  27.             except BaseException as e:
  28.                 result.append(f"failed: {e}")
  29.             # Needed for cleaning up.
  30.             shutil.rmtree = rmtree
  31.             os.rmdir = rmdir
  32.             os.chdir = chdir
  33.     manager = multiprocessing.Manager()
  34.     result = manager.list()
  35.     p = multiprocessing.Process(target=unsafe_execute)
  36.     p.start()
  37.     p.join(timeout=timeout + 1)
  38.     if p.is_alive():
  39.         p.kill()
  40.     if not result:
  41.         result.append("timed out")
  42.     return dict(
  43.         task_id=problem["task_id"],
  44.         passed=result[0] == "passed",
  45.         result=result[0],
  46.         completion_id=completion_id,
  47.     )

里面对于测试样例的构造,是将题目的prompt、模型预测的内容completion、题目的test的按照换行符进行拼接。

  1. # Construct the check program and run it.
  2.   print(completion)
  3.   check_program = (
  4.       problem["prompt"] + completion + "\n" +
  5.       problem["test"] + "\n" +
  6.       f"check({problem['entry_point']})"
  7.   )

然后进行单元测试,直接使用python内置的exec函数进行校验,逻辑在于,给定超时timeout时间,如果测试通过,则标记为passed,如果不是,则不通过【比如说出现代码语法错误】。

  1. try:
  2.       exec_globals = {}
  3.       with swallow_io():
  4.           with time_limit(timeout):
  5.               exec(check_program, exec_globals)
  6.       result.append("passed")
  7.   except TimeoutException:
  8.       result.append("timed out")
  9.   except BaseException as e:
  10.       result.append(f"failed: {e}")

经过这个测试之后,就可以得到每条样本的预测情况。

三、再看代码模型评估中的pass@k指标计算

代码生成模型的主要基准是将样本与参考解进行匹配,匹配可以是精确的,也可以是模糊的(如BLEU分数)。

例如:

EM(Exact Match)是指生成的代码与真实代码具有相同的token序列的百分比;

BLUE机器翻译结果越接近专业人工翻译的结果,则越好,本质在判断两个句子的相似程度,相似度越高得分越高。

CodeBLEU是BLEU变体。在BLEU在n-gram匹配上的基础上,进一步通过抽象语法树(AST)融入代码语法,通过数据流融入代码语义;

但是,基于匹配的代码衡量标准存在缺陷。例如,BLEU在捕捉代码特有的语义特征方面存在问题。

因此,Kulal等人(2019年)使用pass@k指标评估功能正确性,每个问题生成k个代码样本,如果任何样本通过单元测试,则认为问题已解决,并报告总分数。

是一次实验随机性太大,需要多次实验求平均值。pass@k需要对每一个测试问题重复实验t次,并且每次都生成k个代码,最后计算平均通过率。假如重复实验100次来估计pass@100,就需要生成 100*100=10000个代码,这样的计算量是难以接受的。而t越小,估计的pass@k就越不准(方差越大)。

因此,为了评估pass@k,该工作会为每个任务生成n≥k个样本(本文中使用n=200,k≤100),计算通过单元测试的正确样本c≤n的数量,并计算无偏估计值。

6224aa4de2dfa5dbfed975a27fb05efc.png

其中,c是生成的n个代码中通过测试的数量。n越大估计越准确,但计算代价肯定远远小于t*k。

假设模型只能生成这n个代码,而且他们每一种被生成出来的概率是相等的,其中有c个可以通过测试。那么模型任意生成k个代码,全都不能通过测试的概率是:生成k个不能通过测试的代码的情况总和与生成k个代码的全部情况总和之比,即:

91cc7030dacb2f9d033989df39b16768.png

根据大数定理,当样本总量趋近无穷大的时候,样本的平均值无限接近数学期望。因此只要求出其的均值,即得到了对pass@k的无偏估计。

具体代码实现:

  1. def estimate_pass_at_k(
  2.     num_samples: Union[int, List[int], np.ndarray],
  3.     num_correct: Union[List[int], np.ndarray],
  4.     k: int,
  5. ) -> np.ndarray:
  6.     """
  7.     Estimates pass@k of each problem and returns them in an array.
  8.     """
  9.     def estimator(n: int, c: int, k: int) -> float:
  10.         """
  11.         Calculates 1 - comb(n - c, k) / comb(n, k).
  12.         """
  13.         if n - c < k:
  14.             return 1.0
  15.         return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
  16.     if isinstance(num_samples, int):
  17.         num_samples_it = itertools.repeat(num_samples, len(num_correct))
  18.     else:
  19.         assert len(num_samples) == len(num_correct)
  20.         num_samples_it = iter(num_samples)
  21.     return np.array(
  22.         [estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]
  23.     )

关于这块,https://zhuanlan.zhihu.com/p/653063532做了公式的推演,感兴趣的可以进一步看看。

最终,即可完成对应的指标,例如官方的脚本运行结果:

  1. $ evaluate_functional_correctness data/example_samples.jsonl --problem_file=data/example_problem.jsonl
  2. Reading samples...
  3. 6it [00:003397.11it/s]
  4. Running example suites...
  5. 100%|...| 6/6 [00:03<00:00,  1.96it/s]
  6. Writing results to data/example_samples.jsonl_results.jsonl...
  7. 100%|...| 6/6 [00:00<00:006148.50it/s]
  8. {'pass@1'0.4999999999999999}

总结

本文主要针对humaneval这一评测任务,从数据、评估逻辑以及pass@k的评估指标计算方式进行了介绍,之前一直对pass@k有误解,认为是预测K次的通过率,读完代码实现本身才有更为准确的理解。

代码评测,也是整个评测体系中十分重要的部分,感兴趣的可关注。

参考文献

1、https://github.com/abacaj/code-eval/blob/main/human-eval/

2、https://arxiv.org/abs/2107.03374

2、https://zhuanlan.zhihu.com/p/653063532

公众号后台回复aaai、acl、naacl直接进投稿群~

回复LLM进入技术交流群~

回复 nice 进入每周论文直播分享群~

8b2cb31287af3e002b85bf48e4c9667a.jpeg

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/爱喝兽奶帝天荒/article/detail/940982
推荐阅读
相关标签
  

闽ICP备14008679号