当前位置:   article > 正文

学术论文GPT的源码解读与二次开发:从ChatPaper到gpt_academic_看论文的gpt

看论文的gpt

前言

本文的前两个部分最早是属于此旧文的《学术论文GPT的源码解读与微调:从ChatPaper到七月论文审稿GPT第1版》,但为了每一篇文章各自的内容更好的呈现,于是我今天做了以下三个改动

  1. 原来属于mamba第五部分的「Mamba近似工作之线性Transformer:从TransnormerLLM到RWKV」,改放到此文中:学术论文GPT的源码解读与微调:从ChatPaper到七月论文审稿GPT第1版
  2. 把旧文「学术论文GPT的源码解读与微调」中关于chatpaper相关的部分独立抽取出来成本文:学术论文GPT的源码解读与二次开发:从ChatPaper到gpt_academic
  3. 故旧文「学术论文GPT的源码解读与微调」的标题就改成了:七月论文审稿GPT第1版:通过3万多篇paper和10多万的review数据微调RWKV

如此,mamba那篇解读可以专注mamba的解读,不把过多篇幅放在mamba之外的RWKV上,且原来论文审稿第一版本身微调的RWKV,故刚好需要介绍下RWKV

且对于学术论文GPT的源码解读与微调本来就还得解读下gpt_academic,故把ChatPaper和gpt_academic这两个开源系统独立成本文,也更好

本文


第一部分 ChatPaper:论文对话、总结、翻译

ChatPaper的自身定位是全流程加速科研:论文总结+专业级翻译+润色+审稿+审稿回复,因为论文更多是PDF的格式,故针对PDF的对话、总结、翻译,便不可避免的涉及到PDF的解析

1.1 论文审稿:ChatPaper/ChatReviewerAndResponse

1.1.1 对PDF的解析:ChatReviewerAndResponse/get_paper.py

// 待更

1.1.2 论文审查:ChatReviewerAndResponse/chat_reviewer.py

使用OpenAI的GPT模型进行论文审查的脚本。它首先定义了一个Reviewer类来处理审查工作,然后在if __name__ == '__main__':语句下使用argparse处理命令行参数,并调用chat_reviewer_main函数来开始审查过程

  • 导入模块:比如jieba、tenacity等
  • 命名元组定义:用于保存与论文审稿相关的参数
    1. ReviewerParams = namedtuple(
    2. "ReviewerParams",
    3. [
    4. "paper_path",
    5. "file_format",
    6. "research_fields",
    7. "language"
    8. ],
    9. )
  • 判断文本中是否包含中文:
    1. def contains_chinese(text):
    2. for ch in text:
    3. if u'\u4e00' <= ch <= u'\u9fff':
    4. return True
    5. return False
  • 插入句子到文本
    主要功能是在给定文本的每隔一定数量的单词或中文字符后插入一个指定的句子。如果文本行包含中文字符,则使用jieba分词工具来切分中文,否则使用空格来切分:
    1. def insert_sentence(text, sentence, interval):
    2. # 将输入文本按换行符分割成行
    3. lines = text.split('\n')
    4. # 初始化一个新的行列表
    5. new_lines = []
    6. # 遍历每一行
    7. for line in lines:
    8. # 检查行中是否包含中文字符
    9. if contains_chinese(line):
    10. # 如果是中文,使用jieba分词工具进行分词
    11. words = list(jieba.cut(line))
    12. # 定义分隔符为空字符(对于中文分词)
    13. separator = ''
    14. else:
    15. # 如果不包含中文,按空格分割行
    16. words = line.split()
    17. # 定义分隔符为空格(对于英文或其他非中文语言)
    18. separator = ' '
    19. # 初始化一个新的单词列表
    20. new_words = []
    21. # 初始化一个计数器
    22. count = 0
    23. # 遍历当前行的每一个单词
    24. for word in words:
    25. # 将当前单词添加到新的单词列表
    26. new_words.append(word)
    27. # 计数器增加
    28. count += 1
    29. # 检查是否达到了插入句子的间隔
    30. if count % interval == 0:
    31. # 在达到指定间隔时,将要插入的句子添加到新的单词列表
    32. new_words.append(sentence)
    33. # 将新的单词列表连接起来,并添加到新的行列表
    34. new_lines.append(separator.join(new_words))
    35. # 将新的行列表连接起来,返回结果
    36. return '\n'.join(new_lines)
  • 论文审稿类:定义了一个Reviewer类,包含以下功能:
    \rightarrow  第一阶段审稿:先是基于论文标题和摘要,选择要审稿的部分
    1. # 定义Reviewer类
    2. class Reviewer:
    3. # 初始化方法,设置属性
    4. def __init__(self, args=None):
    5. if args.language == 'en':
    6. self.language = 'English'
    7. elif args.language == 'zh':
    8. self.language = 'Chinese'
    9. else:
    10. self.language = 'Chinese'
    11. # 创建一个ConfigParser对象
    12. self.config = configparser.ConfigParser()
    13. # 读取配置文件
    14. self.config.read('apikey.ini')
    15. # 获取某个键对应的值
    16. self.chat_api_list = self.config.get('OpenAI', 'OPENAI_API_KEYS')[1:-1].replace('\'', '').split(',')
    17. self.chat_api_list = [api.strip() for api in self.chat_api_list if len(api) > 5]
    18. self.cur_api = 0
    19. self.file_format = args.file_format
    20. self.max_token_num = 4096
    21. self.encoding = tiktoken.get_encoding("gpt2")
    22. def validateTitle(self, title):
    23. # 修正论文的路径格式
    24. rstr = r"[\/\\\:\*\?\"\<\>\|]" # '/ \ : * ? " < > |'
    25. new_title = re.sub(rstr, "_", title) # 替换为下划线
    26. return new_title
    然后分别实现两个函数
    一个stage_1,主要功能是为了与GPT-3模型进行对话,获取模型对于文章的两个最关键部分的选择意见
    1. def stage_1(self, paper):
    2. # 初始化一个空列表,用于存储生成的HTML内容
    3. htmls = []
    4. # 初始化一个空字符串,用于存储文章的标题和摘要
    5. text = ''
    6. # 添加文章的标题
    7. text += 'Title: ' + paper.title + '. '
    8. # 添加文章的摘要
    9. text += 'Abstract: ' + paper.section_texts['Abstract']
    10. # 计算文本的token数量
    11. text_token = len(self.encoding.encode(text))
    12. # 判断token数量是否超过最大token限制的一半减去800
    13. if text_token > self.max_token_num/2 - 800:
    14. input_text_index = int(len(text)*((self.max_token_num/2)-800)/text_token)
    15. # 如果超出,则截取文本以满足长度要求
    16. text = text[:input_text_index]
    17. # 设置OpenAI API的密钥
    18. openai.api_key = self.chat_api_list[self.cur_api]
    19. # 更新当前使用的API索引
    20. self.cur_api += 1
    21. # 如果当前API索引超过API列表的长度,则重置为0
    22. self.cur_api = 0 if self.cur_api >= len(self.chat_api_list)-1 else self.cur_api
    23. # 创建与GPT-3的对话消息
    24. messages = [
    25. {"role": "system",
    26. "content": f"You are a professional reviewer in the field of {args.research_fields}. "
    27. f"I will give you a paper. You need to review this paper and discuss the novelty and originality of ideas, correctness, clarity, the significance of results, potential impact and quality of the presentation. "
    28. f"Due to the length limitations, I am only allowed to provide you the abstract, introduction, conclusion and at most two sections of this paper."
    29. f"Now I will give you the title and abstract and the headings of potential sections. "
    30. f"You need to reply at most two headings. Then I will further provide you the full information, includes aforementioned sections and at most two sections you called for.\n\n"
    31. f"Title: {paper.title}\n\n"
    32. f"Abstract: {paper.section_texts['Abstract']}\n\n"
    33. f"Potential Sections: {paper.section_names[2:-1]}\n\n"
    34. f"Follow the following format to output your choice of sections:"
    35. f"{{chosen section 1}}, {{chosen section 2}}\n\n"},
    36. {"role": "user", "content": text},
    37. ]
    38. # 调用OpenAI API与GPT-3进行对话
    39. response = openai.ChatCompletion.create(
    40. model="gpt-3.5-turbo",
    41. messages=messages,
    42. )
    43. # 初始化一个空字符串,用于存储模型的回复
    44. result = ''
    45. # 遍历模型的回复,将其添加到结果字符串中
    46. for choice in response.choices:
    47. result += choice.message.content
    48. # 打印模型的回复
    49. print(result)
    50. # 返回模型的回复,将其分割为多个部分
    51. return result.split(',')
    一个chat_review,主要功能是调用GPT-3模型进行论文审稿,对输入的文章文本进行审查,并按照预定格式生成审稿意见
    1. def chat_review(self, text):
    2. # 设置OpenAI API的密钥
    3. openai.api_key = self.chat_api_list[self.cur_api]
    4. # 更新当前使用的API密钥索引
    5. self.cur_api += 1
    6. # 如果当前API密钥索引超过API密钥列表的长度,则将其重置为0
    7. self.cur_api = 0 if self.cur_api >= len(self.chat_api_list)-1 else self.cur_api
    8. # 定义用于审稿提示的token数量
    9. review_prompt_token = 1000
    10. # 计算输入文本的token数量
    11. text_token = len(self.encoding.encode(text))
    12. # 计算输入文本的截取位置
    13. input_text_index = int(len(text)*(self.max_token_num-review_prompt_token)/text_token)
    14. # 截取文本并添加前缀
    15. input_text = "This is the paper for your review:" + text[:input_text_index]
    16. # 从'ReviewFormat.txt'文件中读取审稿格式
    17. with open('ReviewFormat.txt', 'r') as file:
    18. review_format = file.read()
    19. # 创建与GPT-3的对话消息
    20. messages=[
    21. {"role": "system",
    22. "content": "You are a professional reviewer in the field of "+args.research_fields+". Now I will give you a paper. You need to give a complete review opinion according to the following requirements and format:"+ review_format +" Please answer in {}.".format(self.language)},
    23. {"role": "user", "content": input_text},
    24. ]
    25. # 调用OpenAI API与GPT-3进行对话
    26. response = openai.ChatCompletion.create(
    27. model="gpt-3.5-turbo",
    28. messages=messages,
    29. )
    30. # 初始化一个空字符串,用于存储模型的回复
    31. result = ''
    32. # 遍历模型的回复,将其添加到结果字符串中
    33. for choice in response.choices:
    34. result += choice.message.content
    35. # 在结果中插入特定的句子,警告不允许复制
    36. result = insert_sentence(result, '**Generated by ChatGPT, no copying allowed!**', 15)
    37. # 追加伦理声明
    38. result += "\n\n⚠伦理声明/Ethics statement:\n--禁止直接复制生成的评论用于任何论文审稿工作!\n--Direct copying of generated comments for any paper review work is prohibited!"
    39. # 打印分隔符和结果
    40. print("********"*10)
    41. print(result)
    42. print("********"*10)
    43. # 打印相关的token使用信息和响应时间
    44. print("prompt_token_used:", response.usage.prompt_tokens)
    45. print("completion_token_used:", response.usage.completion_tokens)
    46. print("total_token_used:", response.usage.total_tokens)
    47. print("response_time:", response.response_ms/1000.0, 's')
    48. # 返回模型生成的审稿意见
    49. return result
    \rightarrow  使用ChatGPT进行审稿,且有tenacity重试机制和更多的功能,其中review_by_chatgpt 调用了上面所示的两个函数,一个stage_1,一个chat_review
    1. def review_by_chatgpt(self, paper_list):
    2. # 创建一个空列表用于存储每篇文章审稿后的HTML格式内容
    3. htmls = []
    4. # 遍历paper_list中的每一篇文章
    5. for paper_index, paper in enumerate(paper_list):
    6. # 使用第一阶段审稿方法选择文章的关键部分
    7. sections_of_interest = self.stage_1(paper)
    8. # 初始化一个空字符串用于提取文章的主要部分
    9. text = ''
    10. # 添加文章的标题
    11. text += 'Title:' + paper.title + '. '
    12. # 添加文章的摘要
    13. text += 'Abstract: ' + paper.section_texts['Abstract']
    14. # 查找并添加“Introduction”部分
    15. intro_title = next((item for item in paper.section_names if 'ntroduction' in item.lower()), None)
    16. if intro_title is not None:
    17. text += 'Introduction: ' + paper.section_texts[intro_title]
    18. # 同样地,查找并添加“Conclusion”部分
    19. conclusion_title = next((item for item in paper.section_names if 'onclusion' in item), None)
    20. if conclusion_title is not None:
    21. text += 'Conclusion: ' + paper.section_texts[conclusion_title]
    22. # 遍历sections_of_interest,添加其他感兴趣的部分
    23. for heading in sections_of_interest:
    24. if heading in paper.section_names:
    25. text += heading + ': ' + paper.section_texts[heading]
    26. # 使用ChatGPT进行审稿,并得到审稿内容
    27. chat_review_text = self.chat_review(text=text)
    28. # 将审稿的文章编号和内容添加到htmls列表中
    29. htmls.append('## Paper:' + str(paper_index+1))
    30. htmls.append('\n\n\n')
    31. htmls.append(chat_review_text)
    32. # 获取当前日期和时间,并转换为字符串格式
    33. date_str = str(datetime.datetime.now())[:13].replace(' ', '-')
    34. try:
    35. # 创建输出文件夹
    36. export_path = os.path.join('./', 'output_file')
    37. os.makedirs(export_path)
    38. except:
    39. # 如果文件夹已存在,则不执行任何操作
    40. pass
    41. # 如果是第一篇文章,则写模式为'w',否则为'a'
    42. mode = 'w' if paper_index == 0 else 'a'
    43. # 根据文章标题和日期生成文件名
    44. file_name = os.path.join(export_path, date_str+'-'+self.validateTitle(paper.title)+"."+self.file_format)
    45. # 将审稿内容导出为Markdown格式并保存
    46. self.export_to_markdown("\n".join(htmls), file_name=file_name, mode=mode)
    47. # 清空htmls列表,为下一篇文章做准备
    48. htmls = []
  • 主程序部分:
    定义了一个chat_reviewer_main 函数,该函数创建了一个Reviewer对象,并对指定路径中的PDF文件进行审稿
    1. def chat_reviewer_main(args):
    2. reviewer1 = Reviewer(args=args)
    3. # 开始判断是路径还是文件:
    4. paper_list = []
    5. if args.paper_path.endswith(".pdf"):
    6. paper_list.append(Paper(path=args.paper_path))
    7. else:
    8. for root, dirs, files in os.walk(args.paper_path):
    9. print("root:", root, "dirs:", dirs, 'files:', files) #当前目录路径
    10. for filename in files:
    11. # 如果找到PDF文件,则将其复制到目标文件夹中
    12. if filename.endswith(".pdf"):
    13. paper_list.append(Paper(path=os.path.join(root, filename)))
    14. print("------------------paper_num: {}------------------".format(len(paper_list)))
    15. [print(paper_index, paper_name.path.split('\\')[-1]) for paper_index, paper_name in enumerate(paper_list)]
    16. reviewer1.review_by_chatgpt(paper_list=paper_list)
    主程序中定义了命令行参数解析,并调用了chat_reviewer_main 函数
    在主程序中增加了审稿时间的计算功能
    1. if __name__ == '__main__':
    2. parser = argparse.ArgumentParser()
    3. parser.add_argument("--paper_path", type=str, default='', help="path of papers")
    4. parser.add_argument("--file_format", type=str, default='txt', help="output file format")
    5. parser.add_argument("--research_fields", type=str, default='computer science, artificial intelligence and reinforcement learning', help="the research fields of paper")
    6. parser.add_argument("--language", type=str, default='en', help="output lauguage, en or zh")
    7. reviewer_args = ReviewerParams(**vars(parser.parse_args()))
    8. start_time = time.time()
    9. chat_reviewer_main(args=reviewer_args)
    10. print("review time:", time.time() - start_time)

当然,这个项目的论文审稿部分更多是用的ChatGPT的API审稿,我司在API的基础上进一步做了微调的工作,比如如何通过论文审阅语料微调出一个论文审稿GPT(甚至通过10万量级的paper+review语料微调/训练),详见本文的第三部分或我司的「大模型项目开发线下营

1.2 PDF解析:ChatPaper/scipdf_parser-master/

通过这个项目文件:ChatPaper/scipdf_parser-master/scipdf/pdf/parse_pdf.py可以看到以下内容

1.2.1 必要的库、常量、PDF路径

  • 导入必要的库
    re: 正则表达式库,用于匹配和处理字符串
    os 和 os.path: 操作文件和路径的库
    glob: 搜索文件的库
    urllib: 用于处理和获取 URL
    subprocess: 执行外部命令和程序的库
    requests: 用于发送 HTTP 请求的库
    BeautifulSoup 和 NavigableString: 从 bs4 导入,用于解析和操作 XML/HTML 内容
    tqdm 和 tqdm_notebook: 提供进度条功能
  • 定义常量
    GROBID_URL: GROBID 是一个开源软件,可以从 PDF 文件中提取和解析学术出版物的结构化信息
    PDF_FIGURES_JAR_PATH: 这是指向某个 jar 文件的路径,但这段代码中并没有用到这个常量
  • 函数 list_pdf_paths: 返回给定文件夹中所有 PDF 文件的路径
  • 函数 validate_url: 通过正则表达式验证给定的路径是否为有效的 URL
    1. def validate_url(path: str):
    2. """
    3. 验证给定的``path``是否为URL
    4. """
    5. # 定义正则表达式以匹配URL
    6. # 下面的正则表达式主要匹配了以下几部分:
    7. # 1. http:// 或 https:// 开头
    8. # 2. 域名 (例如:example.com)
    9. # 3. localhost (本地主机)
    10. # 4. IP地址 (例如:192.168.1.1)
    11. # 5. 可选的端口号 (例如::80)
    12. # 6. 路径或者查询字符串
    13. regex = re.compile(
    14. r"^(?:http|ftp)s?://" # http:// or https:// 开头
    15. # 域名部分
    16. r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|"
    17. r"localhost|" # localhost 部分
    18. r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # IP地址部分
    19. r"(?::\d+)?" # 可选的端口号部分
    20. r"(?:/?|[/?]\S+)$", # 路径或查询字符串部分
    21. re.IGNORECASE, # 忽略大小写
    22. )
    23. # 使用上述正则表达式匹配给定的path,如果匹配成功则返回True,否则返回False
    24. return re.match(regex, path) is not None

1.2.2 parse_pdf:对PDF的解析

这是代码中的核心功能,用 GROBID 服务从 PDF 文档中解析 XML 或 BeautifulSoup 格式的信息
如果 fulltext 参数为 True,则解析整篇文章;否则,只解析标题
可以从本地或云端的 GROBID 服务中获取数据

  1. def parse_pdf(
  2. pdf_path: str,
  3. fulltext: bool = True,
  4. soup: bool = False,
  5. return_coordinates: bool = True,
  6. grobid_url: str = GROBID_URL,
  7. ):
  8. """
  9. 使用GROBID工具将PDF解析为XML或BeautifulSoup
  10. 可以查看http://grobid.readthedocs.io/en/latest/Install-Grobid/了解如何本地运行GROBID
  11. 加载GROBID zip文件后,可以使用以下方法运行GROBID
  12. >> ./gradlew run
  13. 参数
  14. ==========
  15. pdf_path: str 或 bytes,出版物、文章的路径、URL或PDF的字节字符串
  16. fulltext: bool, 解析选项,如果为True,解析文章的全部文本
  17. 如果为False,只解析头部
  18. grobid_url: str, GROBID解析器的url,默认为'http://localhost:8070'
  19. 可以更改为"https://cloud.science-miner.com/grobid/"使用云服务
  20. soup: bool, 如果为True,返回文章的BeautifulSoup
  21. 输出
  22. ======
  23. parsed_article: 如果soup为False,则返回文本格式的解析后的XML,
  24. 否则返回XML的BeautifulSoup
  25. 示例
  26. =======
  27. >> parsed_article = parse_pdf(pdf_path, fulltext=True, soup=True)
  28. """
  29. # GROBID的URL
  30. if fulltext:
  31. url = "%s/api/processFulltextDocument" % grobid_url # 完整文本处理URL
  32. else:
  33. url = "%s/api/processHeaderDocument" % grobid_url # 仅处理头部的URL
  34. files = []
  35. if return_coordinates: # 如果需要返回坐标
  36. files += [
  37. ("teiCoordinates", (None, "persName")),
  38. ("teiCoordinates", (None, "figure")),
  39. ("teiCoordinates", (None, "ref")),
  40. ("teiCoordinates", (None, "formula")),
  41. ("teiCoordinates", (None, "biblStruct")),
  42. ]
  43. if isinstance(pdf_path, str): # 如果pdf_path是字符串
  44. if validate_url(pdf_path) and op.splitext(pdf_path)[-1].lower() != ".pdf":
  45. print("输入的URL必须以``.pdf``结尾")
  46. parsed_article = None
  47. elif validate_url(pdf_path) and op.splitext(pdf_path)[-1] == ".pdf":
  48. page = urllib.request.urlopen(pdf_path).read() # 从URL下载PDF
  49. parsed_article = requests.post(url, files={"input": page}).text # 通过GROBID处理下载的PDF
  50. elif op.exists(pdf_path): # 如果pdf_path是文件路径
  51. parsed_article = requests.post(
  52. url, files={"input": open(pdf_path, "rb")}
  53. ).text # 通过GROBID处理文件
  54. else:
  55. parsed_article = None
  56. elif isinstance(pdf_path, bytes): # 如果pdf_path是字节
  57. # 假设传入的是字节字符串
  58. parsed_article = requests.post(url, files={"input": pdf_path}).text # 通过GROBID处理字节
  59. else:
  60. parsed_article = None
  61. if soup and parsed_article is not None: # 如果需要返回BeautifulSoup对象
  62. parsed_article = BeautifulSoup(parsed_article, "lxml")
  63. return parsed_article

1.2.3 提取作者信息/parse_authors、出版日期/parse_date、摘要/parse_abstract、段落/parse_sections

  • 函数parse_authors从 BeautifulSoup 文章对象中提取作者信息
    1. def parse_authors(article):
    2. """
    3. Parse authors from a given BeautifulSoup of an article
    4. """
    5. # 从文章的 BeautifulSoup 对象中查找包含作者信息的 "sourcedesc" 标签,然后找到其中所有的 "persname" 标签
    6. author_names = article.find("sourcedesc").findAll("persname")
    7. # 创建一个空列表,用于保存解析的作者名字
    8. authors = []
    9. # 遍历每个作者标签
    10. for author in author_names:
    11. # 查找作者的名字,并进行处理,如果不存在则返回空字符串
    12. firstname = author.find("forename", {"type": "first"})
    13. firstname = firstname.text.strip() if firstname is not None else ""
    14. # 查找作者的中间名,并进行处理,如果不存在则返回空字符串
    15. middlename = author.find("forename", {"type": "middle"})
    16. middlename = middlename.text.strip() if middlename is not None else ""
    17. # 查找作者的姓氏,并进行处理,如果不存在则返回空字符串
    18. lastname = author.find("surname")
    19. lastname = lastname.text.strip() if lastname is not None else ""
    20. # 判断中间名是否存在,然后将名、中间名和姓组合在一起
    21. if middlename is not "":
    22. authors.append(firstname + " " + middlename + " " + lastname)
    23. else:
    24. authors.append(firstname + " " + lastname)
    25. # 使用"; "连接所有的作者名,生成一个字符串
    26. authors = "; ".join(authors)
    27. # 返回最终的作者名字符串
    28. return authors
  • 下面这个parse_date函数是提取初版日期,从 BeautifulSoup 文章对象中提取出版日期
    1. def parse_date(article):
    2. """
    3. Parse date from a given BeautifulSoup of an article
    4. """
    5. # 从文章的 BeautifulSoup 对象中查找包含出版日期信息的 "publicationstmt" 标签
    6. pub_date = article.find("publicationstmt")
    7. # 在 "publicationstmt" 标签下查找 "date" 标签
    8. year = pub_date.find("date")
    9. # 尝试获取 "date" 标签的 "when" 属性,如果标签不存在则返回空字符串
    10. year = year.attrs.get("when") if year is not None else ""
    11. # 返回解析出的年份
    12. return year
  • 而parse_abstract这个函数则是提取摘要,即从 BeautifulSoup 文章对象中提取摘要
    1. def parse_abstract(article):
    2. """
    3. Parse abstract from a given BeautifulSoup of an article
    4. """
    5. # 从文章的 BeautifulSoup 对象中查找 "abstract" 标签
    6. div = article.find("abstract")
    7. # 初始化摘要字符串为空
    8. abstract = ""
    9. # 遍历 "abstract" 标签下的所有直接子节点
    10. for p in list(div.children):
    11. # 如果子节点不是纯文本(NavigableString)且子节点的子元素数量大于0
    12. if not isinstance(p, NavigableString) and len(list(p)) > 0:
    13. # 将子节点下的所有非纯文本子元素的文本内容加入摘要字符串
    14. abstract += " ".join(
    15. [elem.text for elem in p if not isinstance(elem, NavigableString)]
    16. )
    17. # 返回解析出的摘要
    18. return abstract
  • 而parse_sections则是提取段落,从 BeautifulSoup 文章对象中提取文章的各个部分或段落,且它还计算每个部分中的引用数量
    1. def parse_sections(article, as_list: bool = False):
    2. """
    3. 从给定的BeautifulSoup文章中解析章节列表
    4. 参数
    5. ==========
    6. as_list: bool, 如果为True,则将输出文本作为段落列表,
    7. 而不是将其连接成一个单一的文本
    8. """
    9. # 找到文章中的"text"部分
    10. article_text = article.find("text")
    11. # 获取所有带有特定属性的"div"标签
    12. divs = article_text.find_all("div", attrs={"xmlns": "http://www.tei-c.org/ns/1.0"})
    13. sections = [] # 初始化章节列表
    14. for div in divs:
    15. div_list = list(div.children)
    16. if len(div_list) == 0:
    17. heading = ""
    18. text = ""
    19. elif len(div_list) == 1:
    20. # 如果只有一个子元素
    21. if isinstance(div_list[0], NavigableString):
    22. heading = str(div_list[0])
    23. text = ""
    24. else:
    25. heading = ""
    26. text = div_list[0].text
    27. else:
    28. text = []
    29. heading = div_list[0]
    30. if isinstance(heading, NavigableString):
    31. heading = str(heading)
    32. p_all = list(div.children)[1:]
    33. else:
    34. heading = ""
    35. p_all = list(div.children)
    36. for p in p_all:
    37. if p is not None:
    38. try:
    39. text.append(p.text) # 尝试添加文本
    40. except:
    41. pass
    42. if not as_list:
    43. text = "\n".join(text)
    44. # 如果标题或文本不为空
    45. if heading is not "" or text is not "":
    46. # 计算参考文献数量
    47. ref_dict = calculate_number_of_references(div)
    48. sections.append(
    49. {
    50. "heading": heading,
    51. "text": text,
    52. "n_publication_ref": ref_dict["n_publication_ref"],
    53. "n_figure_ref": ref_dict["n_figure_ref"],
    54. }
    55. )
    56. return sections

1.2.4 计算引用与解析文献引用/parse_references(article)

  • calculate_number_of_references:计算给定部分中的引用数量
    1. def calculate_number_of_references(div):
    2. """
    3. 对于给定的章节,计算章节中的参考文献数量
    4. """
    5. # 计算给定章节中的文献引用数量
    6. n_publication_ref = len(
    7. # 列表推导式查找所有type属性为"bibr""ref"标签
    8. [ref for ref in div.find_all("ref") if ref.attrs.get("type") == "bibr"]
    9. )
    10. # 计算给定章节中的图形引用数量
    11. n_figure_ref = len(
    12. # 列表推导式查找所有type属性为"figure""ref"标签
    13. [ref for ref in div.find_all("ref") if ref.attrs.get("type") == "figure"]
    14. )
    15. # 返回一个字典,包含文献引用数量和图形引用数量
    16. return {"n_publication_ref": n_publication_ref, "n_figure_ref": n_figure_ref}
  • parse_references(article):解析文献引用
    功能:从给定的BeautifulSoup对象中解析文献引用列表
    主要步骤:
    寻找包含引用的部分
    对于每个引用,提取文章标题、期刊、发布日期和作者信息
    返回包含所有引用信息的列表
    1. def parse_references(article):
    2. """
    3. 从给定的BeautifulSoup文章中解析引用列表
    4. """
    5. reference_list = [] # 初始化引用列表
    6. # 在文章中查找文本部分中的引用部分
    7. references = article.find("text").find("div", attrs={"type": "references"})
    8. # 如果存在引用,则查找所有的"biblstruct"标签,否则返回空列表
    9. references = references.find_all("biblstruct") if references is not None else []
    10. reference_list = [] # 再次初始化引用列表
    11. for reference in references:
    12. # 尝试查找引用的文章标题
    13. title = reference.find("title", attrs={"level": "a"})
    14. if title is None:
    15. title = reference.find("title", attrs={"level": "m"})
    16. title = title.text if title is not None else ""
    17. # 尝试查找引用的期刊名
    18. journal = reference.find("title", attrs={"level": "j"})
    19. journal = journal.text if journal is not None else ""
    20. if journal is "":
    21. journal = reference.find("publisher")
    22. journal = journal.text if journal is not None else ""
    23. # 查找引用的出版年份
    24. year = reference.find("date")
    25. year = year.attrs.get("when") if year is not None else ""
    26. authors = [] # 初始化作者列表
    27. # 遍历引用中的所有作者
    28. for author in reference.find_all("author"):
    29. firstname = author.find("forename", {"type": "first"})
    30. firstname = firstname.text.strip() if firstname is not None else ""
    31. middlename = author.find("forename", {"type": "middle"})
    32. middlename = middlename.text.strip() if middlename is not None else ""
    33. lastname = author.find("surname")
    34. lastname = lastname.text.strip() if lastname is not None else ""
    35. # 根据是否有中间名来组合作者的全名
    36. if middlename is not "":
    37. authors.append(firstname + " " + middlename + " " + lastname)
    38. else:
    39. authors.append(firstname + " " + lastname)
    40. authors = "; ".join(authors) # 将所有作者连接为一个字符串
    41. # 将标题、期刊、年份和作者添加到引用列表中
    42. reference_list.append(
    43. {"title": title, "journal": journal, "year": year, "authors": authors}
    44. )
    45. return reference_list # 返回引用列表

1.2.5 解析图形和表格、公式

  • parse_figure_caption(article)
    功能:从给定的BeautifulSoup对象中解析图形和表格
    主要步骤:
    搜索所有图形
    对于每个图形或表格,提取标签、类型、ID、标题和数据
    返回包含所有图形/表格信息的列表
    1. def parse_figure_caption(article):
    2. """
    3. 从给定的BeautifulSoup文章中解析图表列表
    4. """
    5. figures_list = [] # 初始化图表列表
    6. # 在文章中查找所有的"figure"标签
    7. figures = article.find_all("figure")
    8. for figure in figures:
    9. # 获取图标的类型(可能是图或表)和ID
    10. figure_type = figure.attrs.get("type") or ""
    11. figure_id = figure.attrs.get("xml:id") or ""
    12. # 获取图标的标签(如"图1"
    13. label = figure.find("label").text
    14. if figure_type == "table":
    15. # 如果图形类型为表,则获取表的标题和数据
    16. caption = figure.find("figdesc").text
    17. data = figure.table.text
    18. else:
    19. # 否则,只获取图形的标题,并将数据设置为空字符串
    20. caption = figure.text
    21. data = ""
    22. # 将标签、类型、ID、标题和数据添加到图形列表中
    23. figures_list.append(
    24. {
    25. "figure_label": label,
    26. "figure_type": figure_type,
    27. "figure_id": figure_id,
    28. "figure_caption": caption,
    29. "figure_data": data,
    30. }
    31. )
    32. return figures_list # 返回图表列表

  • parse_figures(...):
    功能:使用pdffigures2工具从给定的科学PDF中解析图形
    主要步骤:
    检查输出文件夹是否存在,如果不存在则创建它
    在输出文件夹中创建子文件夹来保存数据和图形
    使用Java运行pdffigures2工具解析图形
    打印完成消息
    1. def parse_figures(
    2. pdf_folder: str,
    3. jar_path: str = PDF_FIGURES_JAR_PATH,
    4. resolution: int = 300,
    5. output_folder: str = "figures",
    6. ):
    7. """
    8. 使用pdffigures2从给定的科学PDF中提取图形。
    9. 参数
    10. ==========
    11. pdf_folder: str, 包含PDF文件的文件夹的路径。一个文件夹必须只包含PDF文件。
    12. jar_path: str, pdffigures2-assembly-0.0.12-SNAPSHOT.jar文件的默认路径。
    13. resolution: int, 输出图形的分辨率。
    14. output_folder: str, 我们希望保存解析数据(与图形相关)和图形的文件夹的路径。
    15. 输出
    16. ======
    17. folder: 在output_folder/data和output_folder/figures中创建文件夹,分别包含解析数据和图形。
    18. """
    19. # 检查output_folder是否存在,如果不存在,则创建它。
    20. if not op.isdir(output_folder):
    21. os.makedirs(output_folder)
    22. # 在output_folder内创建“data”和“figures”子文件夹。
    23. data_path = op.join(output_folder, "data")
    24. figure_path = op.join(output_folder, "figures")
    25. if not op.exists(data_path):
    26. os.makedirs(data_path)
    27. if not op.exists(figure_path):
    28. os.makedirs(figure_path)
    29. # 如果data和figures文件夹存在,则执行pdffigures2命令。
    30. if op.isdir(data_path) and op.isdir(figure_path):
    31. args = [
    32. "java",
    33. "-jar",
    34. jar_path,
    35. pdf_folder,
    36. "-i",
    37. str(resolution),
    38. "-d",
    39. op.join(op.abspath(data_path), ""),
    40. "-m",
    41. op.join(op.abspath(figure_path), ""), # end path with "/"
    42. ]
    43. _ = subprocess.run(
    44. args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=20
    45. )
    46. print("完成从PDFs中提取图形!")
    47. else:
    48. print(
    49. "您可能需要检查output文件夹路径中的``data``和``figures``。"
    50. )
  • parse_formulas(article):解析公式
    功能:从给定的BeautifulSoup对象中解析公式
    主要步骤:
    搜索所有公式
    提取公式的ID、文本和坐标
    返回包含所有公式信息的列表
    1. def parse_formulas(article):
    2. """
    3. 从给定的BeautifulSoup文章中解析公式列表
    4. """
    5. formulas_list = [] # 初始化公式列表
    6. # 在文章中查找所有的"formula"标签
    7. formulas = article.find_all("formula")
    8. for formula in formulas:
    9. # 获取公式的ID
    10. formula_id = formula.attrs["xml:id"] or ""
    11. # 获取公式的文本内容
    12. formula_text = formula.text
    13. # 尝试获取公式的坐标
    14. formula_coordinates = formula.attrs.get("coords") or ""
    15. if formula_coordinates is not "":
    16. # 如果有坐标,将它们转换为浮点数列表
    17. formula_coordinates = [float(x) for x in formula_coordinates.split(",")]
    18. # 将ID、文本和坐标添加到公式列表中
    19. formulas_list.append(
    20. {
    21. "formula_id": formula_id,
    22. "formula_text": formula_text,
    23. "formula_coordinates": formula_coordinates,
    24. }
    25. )
    26. return formulas_list # 返回公式列表

1.2.6 把标题/作者/摘要/图形/公式等转换为JSON格式的字典

  • convert_article_soup_to_dict(article, as_list=False):
    功能:将BeautifulSoup对象转换为JSON格式的字典,类似于某些开源项目的输出
    主要步骤:
    提取文章的标题、作者、发布日期、摘要、部分、引用、图形和公式
    返回一个包含所有这些信息的字典
    1. def convert_article_soup_to_dict(article, as_list: bool = False):
    2. """
    3. 将BeautifulSoup对象转换为JSON格式的函数
    4. 与https://github.com/allenai/science-parse/ 的输出类似
    5. 参数
    6. ==========
    7. article: BeautifulSoup
    8. 输出
    9. ======
    10. article_json: dict, 给定文章的解析字典,格式如下:
    11. {
    12. 'title': ...,
    13. 'abstract': ...,
    14. 'sections': [
    15. {'heading': ..., 'text': ...},
    16. {'heading': ..., 'text': ...},
    17. ...
    18. ],
    19. 'references': [
    20. {'title': ..., 'journal': ..., 'year': ..., 'authors': ...},
    21. {'title': ..., 'journal': ..., 'year': ..., 'authors': ...},
    22. ...
    23. ],
    24. 'figures': [
    25. {'figure_label': ..., 'figure_type': ..., 'figure_id': ..., 'figure_caption': ..., 'figure_data': ...},
    26. ...
    27. ]
    28. }
    29. """
    30. article_dict = {} # 初始化文章字典
    31. if article is not None:
    32. # 从文章中获取主标题
    33. title = article.find("title", attrs={"type": "main"})
    34. title = title.text.strip() if title is not None else ""
    35. article_dict["title"] = title
    36. # 解析文章的作者
    37. article_dict["authors"] = parse_authors(article)
    38. # 解析文章的发布日期
    39. article_dict["pub_date"] = parse_date(article)
    40. # 解析文章的摘要
    41. article_dict["abstract"] = parse_abstract(article)
    42. # 解析文章的各个部分
    43. article_dict["sections"] = parse_sections(article, as_list=as_list)
    44. # 解析文章的参考文献
    45. article_dict["references"] = parse_references(article)
    46. # 解析文章的图表
    47. article_dict["figures"] = parse_figure_caption(article)
    48. # 解析文章的公式
    49. article_dict["formulas"] = parse_formulas(article)
    50. # 从文章中获取DOI
    51. doi = article.find("idno", attrs={"type": "DOI"})
    52. doi = doi.text if doi is not None else ""
    53. article_dict["doi"] = doi
    54. return article_dict
    55. else:
    56. return None # 如果文章不存在,返回None
  • parse_pdf_to_dict(...)
    功能:解析给定的PDF并返回解析后的文章的字典
    主要步骤:
    使用外部工具或服务(如GROBID)解析PDF
    将解析后的BeautifulSoup对象转换为字典格式
    返回该字典
    1. def parse_pdf_to_dict(
    2. pdf_path: str,
    3. fulltext: bool = True,
    4. soup: bool = True,
    5. as_list: bool = False,
    6. return_coordinates: bool = True,
    7. grobid_url: str = GROBID_URL,
    8. ):
    9. """
    10. 解析给定的PDF并返回解析后的文章字典
    11. 参数
    12. ==========
    13. pdf_path: str, 出版物或文章的路径
    14. fulltext: bool, 是否提取完整文本
    15. soup: bool, 是否返回BeautifulSoup
    16. as_list: bool, 是否返回部分列表
    17. return_coordinates: bool, 是否返回坐标
    18. grobid_url: str, grobid服务器的url,默认为`GROBID_URL`
    19. 可更改为 "https://cloud.science-miner.com/grobid/" 使用云服务
    20. 输出
    21. =====
    22. article_dict: dict, 文章的字典
    23. """
    24. # 使用parse_pdf函数解析PDF
    25. parsed_article = parse_pdf(
    26. pdf_path,
    27. fulltext=fulltext,
    28. soup=soup,
    29. return_coordinates=return_coordinates,
    30. grobid_url=grobid_url,
    31. )
    32. # 将BeautifulSoup对象转换为字典
    33. article_dict = convert_article_soup_to_dict(parsed_article, as_list=as_list)
    34. return article_dict # 返回解析后的文章字典
    这个函数的目的是解析给定的PDF文件,并将其转换为一个结构化的字典。首先,它使用parse_pdf函数来解析PDF,然后使用convert_article_soup_to_dict函数将解析后的BeautifulSoup对象转换为字典

1.3 论文检索:ChatPaper/auto_survey/utils

具体包含如下功能(这个基于GPT4的文献总结工具的项目auto-draft也提供类似的功能)

  • 自动搜索相关文献, 提供真实有出处的引用
  • 自动生成LaTeX格式,markdown格式的调研结果

1.3.1 /utils/knowledge_databases/ml_textbook_test

// 待更

1.3.2 /utils/embeddings.py

  1. # 导入HuggingFace的文本嵌入功能
  2. from langchain.embeddings import HuggingFaceEmbeddings
  3. # 导入操作系统相关的模块,用于获取环境变量等操作
  4. import os
  5. # 从环境变量中获取OpenAI的API密钥
  6. openai_api_key = os.getenv("OPENAI_API_KEY")
  7. # 如果获取到了OpenAI的API密钥
  8. if openai_api_key is not None:
  9. # 导入OpenAI的文本嵌入功能
  10. from langchain.embeddings.openai import OpenAIEmbeddings
  11. # 使用获取到的API密钥初始化OpenAI的文本嵌入
  12. openai_embedding = OpenAIEmbeddings(model="text-embedding-ada-002", openai_api_key=openai_api_key)
  13. else:
  14. # 如果没有获取到API密钥,则将OpenAI的文本嵌入设为None
  15. openai_embedding = None
  16. # 定义HuggingFace的模型名称
  17. model_name = 'sentence-transformers/all-MiniLM-L6-v2'
  18. # 设置模型的参数,这里是将模型放在CPU上运行
  19. model_kwargs = {'device': 'cpu'}
  20. # 设置文本嵌入的参数,这里是不对嵌入进行归一化
  21. encode_kwargs = {'normalize_embeddings': False}
  22. # 使用上述参数初始化HuggingFace的文本嵌入
  23. all_minilm_l6_v2 = HuggingFaceEmbeddings(
  24. model_name=model_name,
  25. model_kwargs=model_kwargs,
  26. encode_kwargs=encode_kwargs)
  27. # 创建一个字典来存储上述两种文本嵌入,方便后续调用
  28. EMBEDDINGS = {"text-embedding-ada-002": openai_embedding, "all-MiniLM-L6-v2": all_minilm_l6_v2}

1.3.3 /utils/gpt_interaction.py

// 待更

1.3.4 /utils/knowledge.py

定义了一个Knowledge类,该类使用关键词字典从数据库中搜索相关内容,并可以将这些内容转化为提示文本或JSON格式

  1. import tiktoken # 导入tiktoken模块,用于计算tokens数量
  2. from random import shuffle # 从random模块导入shuffle函数,用于随机打乱列表
  3. # 使用`tiktoken`来计算文本中的tokens数量
  4. tokenizer_name = tiktoken.encoding_for_model('gpt-4') # 为"gpt-4"模型获取相应的编码器名称
  5. tokenizer = tiktoken.get_encoding(tokenizer_name.name) # 获取编码器实例
  6. def tiktoken_len(text):
  7. # 计算给定文本中的tokens数量
  8. tokens = tokenizer.encode(text, disallowed_special=()) # 对文本进行编码并返回tokens
  9. return len(tokens) # 返回tokens的数量
  10. class Knowledge:
  11. # 定义一个Knowledge类来处理知识数据库相关操作
  12. def __init__(self, db):
  13. self.db = db # 数据库实例
  14. self.contents = [] # 用于存放内容的列表
  15. def collect_knowledge(self, keywords_dict, max_query):
  16. """
  17. 根据给定的关键词字典,从数据库中搜索并收集相关的知识。
  18. keywords_dict:
  19. 示例: {"machine learning": 5, "language model": 2};
  20. """
  21. db = self.db
  22. if max_query > 0:
  23. for kw in keywords_dict:
  24. docs = db.similarity_search_with_score(kw, k=max_query) # 使用关键词在数据库中进行相似度搜索
  25. for i in range(max_query):
  26. content = {"content": docs[i][0].page_content.replace('\n', ' '), # 移除换行符
  27. "score": docs[i][1]} # 为每个文档添加评分
  28. self.contents.append(content) # 将内容添加到contents列表中
  29. shuffle(self.contents) # 随机打乱contents列表
  30. def to_prompts(self, max_tokens=2048):
  31. # 将收集到的知识内容转化为提示文本,且tokens总数不超过max_tokens
  32. if len(self.contents) == 0:
  33. return ""
  34. prompts = []
  35. tokens = 0
  36. for idx, content in enumerate(self.contents):
  37. prompt = "Reference {}: {}\n".format(idx, content["content"])
  38. tokens += tiktoken_len(prompt)
  39. if tokens >= max_tokens:
  40. break
  41. else:
  42. prompts.append(prompt) # 将提示文本添加到prompts列表中
  43. return "".join(prompts) # 返回连接后的提示文本
  44. def to_json(self):
  45. # 将收集到的知识内容转化为JSON格式
  46. if len(self.contents) == 0:
  47. return {}
  48. output = {}
  49. for idx, content in enumerate(self.contents):
  50. output[str(idx)] = {
  51. "content": content["content"],
  52. "score": str(content["score"])
  53. }
  54. print(output)
  55. return output

1.3.5 /utils/references.py

这个代码文件主要注意实现了以下功能

1.3.5.1 第一部分:References 类之外
  1. Reference类的说明

    • 从给定的.bib文件中读取论文,并用search_paper_abstract方法填充缺失的摘要
    • 根据一些关键词使用Semantic Scholar API查找相关论文
    • 从所选论文中生成Bibtex引用格式
    • 从所选论文中生成提示(prompts)。示例提示格式为:{"paper_id": "paper summary"}
  2. 待完成的任务(todo)

    • 加载预定义的论文;
    • 使用Semantic Scholar API查找所有相关作品;
    • 将所有引文添加到bib_papers
    • 将所有被引文添加到bib_papers
    • 使用Semantic Scholar查找它们的嵌入;
    • 将引文分组以减少tokens的数量
  3. 一些基本的工具

    • evaluate_cosine_similarity:计算两个向量的余弦相似性
      1. def evaluate_cosine_similarity(v1, v2):
      2. try:
      3. return np.dot(v1, v2)/(norm(v1)*norm(v2))
      4. except ValueError:
      5. return 0.0
    • chunks 将一个较长的列表分割为较小的批次,以便于处理;
      1. def chunks(lst, chunk_size=MAX_BATCH_SIZE):
      2. """Splits a longer list to respect batch size"""
      3. for i in range(0, len(lst), chunk_size):
      4. yield lst[i : i + chunk_size]
    • embed 通过向Semantic Scholar的API发送请求,为一组论文计算嵌入(即将论文映射到一个向量空间中)
      1. def embed(papers):
      2. embeddings_by_paper_id: Dict[str, List[float]] = {}
      3. for chunk in chunks(papers):
      4. # Allow Python requests to convert the data above to JSON
      5. response = requests.post(URL, json=chunk)
      6. if response.status_code != 200:
      7. raise RuntimeError("Sorry, something went wrong, please try later!")
      8. for paper in response.json()["preds"]:
      9. embeddings_by_paper_id[paper["paper_id"]] = paper["embedding"]
      10. return embeddings_by_paper_id
    • get_embeddings 为给定的论文标题和描述获取嵌入
      1. def get_embeddings(paper_title, paper_description):
      2. output = [{"title": paper_title, "abstract": paper_description, "paper_id": "target_paper"}]
      3. emb_vector = embed(output)["target_paper"]
      4. target_paper = output[0]
      5. target_paper["embeddings"] = emb_vector
      6. return target_paper
    • get_top_k 获取与给定论文最相关的k篇论文
      具体而言,从提供的papers_dict 中找到与给定的paper_title和paper_description最相似的前k篇论文,并返回。至于相似性是通过计算两篇论文嵌入向量的余弦相似度来确定的
      1. def get_top_k(papers_dict, paper_title, paper_description, k=None):
      2. # 获取目标论文的嵌入向量
      3. target_paper = get_embeddings(paper_title, paper_description)
      4. # 存放所有的论文信息,其中应包含嵌入向量
      5. papers = papers_dict
      6. # 如果k小于papers的数量,返回k篇最相关的论文
      7. # 如果k大于等于papers的数量或k为None,返回所有论文
      8. max_num_papers = len(papers) # 获取论文总数
      9. if k is None: # 如果k为None,设置k为论文总数
      10. k = max_num_papers
      11. num_papers = min(k, max_num_papers) # 确定需要返回的论文数量
      12. # 获取目标论文的嵌入向量
      13. target_embedding_vector = target_paper["embeddings"]
      14. # 计算每篇论文与目标论文的余弦相似度
      15. for k in papers:
      16. v = papers[k]
      17. embedding_vector = v["embeddings"] # 获取当前论文的嵌入向量
      18. cos_sim = evaluate_cosine_similarity(embedding_vector, target_embedding_vector) # 计算余弦相似度
      19. papers[k]["cos_sim"] = cos_sim # 存储余弦相似度到papers中
      20. # 返回相似度最高的前k篇论文
      21. sorted_papers = {k: v for k, v in sorted(papers.items(), key=lambda x: x[1]["cos_sim"], reverse=True)[:num_papers]}
      22. # 从返回的论文中移除嵌入向量信息
      23. for key in sorted_papers:
      24. sorted_papers[key].pop("embeddings", None)
      25. return sorted_papers
    • remove_newlines 去除摘要中的换行符,减少提示的长度
      1. def remove_newlines(serie):
      2. # This function is applied to the abstract of each paper to reduce the length of prompts.
      3. serie = serie.replace('\n', ' ')
      4. serie = serie.replace('\\n', ' ')
      5. serie = serie.replace(' ', ' ')
      6. serie = serie.replace(' ', ' ')
      7. return serie
  4. 从.bib文件加载论文信息

    • 读取.bib文件,并将其解析为一个python对象;
    • 通过load_papers_from_bibtex 函数遍历这个对象,从中提取论文的各种属性(如ID、标题、期刊、年份、作者、摘要等);
      1. def load_papers_from_bibtex(bib_file_path):
      2. with open(bib_file_path) as bibtex_file:
      3. bib_database = bibtexparser.load(bibtex_file)
      4. if len(bib_database.entries) == 0:
      5. return []
      6. else:
      7. bib_papers = []
      8. for bibitem in bib_database.entries:
      9. # Add each paper to `bib_papers`
      10. paper_id = bibitem.get("ID")
      11. title = bibitem.get("title")
      12. if title is None:
      13. continue
      14. journal = bibitem.get("journal")
      15. year = bibitem.get("year")
      16. author = bibitem.get("author")
      17. abstract = bibitem.get("abstract")
      18. if abstract is None:
      19. abstract = search_paper_abstract(title)
      20. result = {
      21. "paper_id": paper_id,
      22. "title": title,
      23. "link": "",
      24. "abstract": abstract,
      25. "authors": author,
      26. "year": year,
      27. "journal": journal
      28. }
      29. bib_papers.append(result)
      30. return bib_papers
    • 对于缺失摘要的论文,使用search_paper_abstract 函数查询摘要
      1. def search_paper_abstract(title):
      2. pg = ProxyGenerator()
      3. success = pg.FreeProxies() # pg.ScraperAPI("921b16f94d701308b9d9b4456ddde155")
      4. if success:
      5. try:
      6. scholarly.use_proxy(pg)
      7. # input the title of a paper, return its abstract
      8. search_query = scholarly.search_pubs(title)
      9. found_paper = next(search_query)
      10. except:
      11. return ""
      12. else:
      13. return ""
      14. # raise RuntimeError("ScraperAPI fails.")
      15. return remove_newlines(found_paper['bib']['abstract'])
  5. 计算文本的tokens数量

    • 使用tokenizer对象来计算给定文本的tokens的数量
      1. # `tokenizer`: used to count how many tokens
      2. tokenizer_name = tiktoken.encoding_for_model('gpt-4')
      3. tokenizer = tiktoken.get_encoding(tokenizer_name.name)
      4. def tiktoken_len(text):
      5. # evaluate how many tokens for the given text
      6. tokens = tokenizer.encode(text, disallowed_special=())
      7. return len(tokens)
  6. 使用Semantic Scholar (SS) API搜索论文

    • 使用Semantic Scholar API搜索指定关键词的论文;
    • 从API返回的数据中提取论文的各种属性
  7. parse_search_results 函数

    这部分主要关于从搜索结果中提取学术论文的相关信息:
    该函数的目的是对传入的搜索结果进行解析,并将其转换为一个论文信息列表。

    • 首先检查传入的搜索结果是否为空。
    • 逐个解析每篇论文的内容,包括作者信息、年份、标题等。
    • 对某些字段进行特殊处理,如将日志名中的&替换为\&
    • 如果存在摘要的“tldr”(即“过长不读”)版本,它会被优先使用,否则会使用原始摘要。
    • 最后,所有提取出的信息将被组合成一个字典并添加到结果列表中
      且函数下方的代码调用了一个假设的ss_search方法,然后使用上述函数处理这些搜索结果
      1. def parse_search_results(search_results_ss):
      2. # 判断搜索结果是否为空
      3. if len(search_results_ss) == 0:
      4. return []
      5. # 将搜索结果转换为论文字典的列表
      6. papers_ss = []
      7. for raw_paper in search_results_ss:
      8. # 如果论文没有摘要,跳过此论文
      9. if raw_paper["abstract"] is None:
      10. continue
      11. # 提取作者信息
      12. authors_str, last_name = extract_author_info(raw_paper['authors'])
      13. # 获取论文的发表年份
      14. year_str = str(raw_paper['year'])
      15. # 获取论文标题
      16. title = raw_paper['title']
      17. # 有些期刊的名字可能包含"&"字符;将其替换掉
      18. journal = raw_paper['venue'].replace("&", "\\&")
      19. # 如果没有提供期刊名,就默认为“arXiv preprint”
      20. if not journal:
      21. journal = "arXiv preprint"
      22. # 根据作者姓、发表年份和标题提取论文ID
      23. paper_id = extract_paper_id(last_name, year_str, title).lower()
      24. # 转换外部ID为链接
      25. link = externalIds2link(raw_paper['externalIds'])
      26. # 如果存在tldr摘要,使用tldr摘要;否则,使用原始摘要并移除其中的换行符
      27. if tldr and raw_paper['tldr'] is not None:
      28. abstract = raw_paper['tldr']['text']
      29. else:
      30. abstract = remove_newlines(raw_paper['abstract'])
      31. # 有些论文可能没有嵌入;处理这种情况
      32. embeddings_dict = raw_paper.get('embedding')
      33. if embeddings_dict is None:
      34. continue
      35. else:
      36. embeddings = raw_paper['embedding']['vector']
      37. # 组合结果
      38. result = {
      39. "paper_id": paper_id,
      40. "title": title,
      41. "abstract": abstract,
      42. "link": link,
      43. "authors": authors_str,
      44. "year": year_str,
      45. "journal": journal,
      46. "embeddings": embeddings
      47. }
      48. # 将结果添加到论文列表中
      49. papers_ss.append(result)
      50. # 返回论文列表
      51. return papers_ss
      52. # 使用关键字进行搜索
      53. raw_results = ss_search(keyword, limit=counts)
      54. # 如果获取到了原始搜索结果
      55. if raw_results is not None:
      56. # 提取搜索结果数据
      57. search_results = raw_results.get("data")
      58. # 如果搜索结果是空的,设置为空列表
      59. if search_results is None:
      60. search_results = []
      61. # 如果没有获取到原始搜索结果,设置为空列表
      62. else:
      63. search_results = []
      64. # 解析搜索结果并返回
      65. results = parse_search_results(search_results)
      66. return results
1.3.5.2 第二部分:References

该类用于管理论文引用:

  1. 初始化方法:当创建一个References对象时,可以选择为其提供标题、论文列表、关键词以及描述
  2. load_papers 方法:加载给定BibTeX格式的论文到引用类中
  3. generate_keywords_dict 方法:生成一个关键词字典,其中每个关键词都关联一个论文数量
  4. collect_papers 方法:使用给定的关键词字典收集尽可能多的论文。这个方法尝试收集给定关键词的相关论文,并添加到类的内部存储中
  5. to_bibtex 方法:将保存的论文列表转换为BibTeX格式的文件
  6. _get_papers 方法:一个内部方法,用于从内部存储中获取论文列表
  7. to_prompts 方法:将引用转换为提示格式,这可能是为了后续使用某种机器学习模型
  8. to_json 方法:将论文列表转换为JSON格式
  9. 代码的最后部分(在if __name__ == "__main__":之后)是一个简单的测试部分,用于测试上述代码的功能

//待更

1.4 ChatPaper/chat_paper.py

chat_paper.py,包含一个Paper类、Reader类和chat_paper_mian函数。该程序功能为根据读者输入的搜索查询和感兴趣的关键词,从Arxiv数据库中获取文章,并对文章进行摘要和总结。程序使用了OpenAI的GPT-3模型生成文本摘要,使用了arxiv包获取Arxiv数据库中的文章。程序会将摘要和总结以markdown文件的形式保存下来。

1.4.1 Paper类

  • Paper 类代表了一篇论文,它可以从 PDF 文件中解析出论文的元信息和内容,并提供了一些函数用于获取论文信息,如获取文章标题,获取章节名称及内容等。主要方法有:
  • parse_pdf:解析PDF文件
    其中的self._get_all_page_index() 和self._get_all_page() 这两个方法 下文很快会定义
    1. def parse_pdf(self): # 定义一个方法来解析PDF文件
    2. self.pdf = fitz.open(self.path) # 使用fitz库打开指定路径的pdf文件
    3. self.text_list = [page.get_text() for page in self.pdf] # 从每一页中提取文本并存放到列表中
    4. self.all_text = ' '.join(self.text_list) # 将每一页的文本连接成一个完整的字符串
    5. self.section_page_dict = self._get_all_page_index() # 获取段落与其对应的页码字典
    6. print("section_page_dict", self.section_page_dict) # 打印该段落与页码的对应字典
    7. self.section_text_dict = self._get_all_page() # 获取段落与其对应的内容字典
    8. self.section_text_dict.update({"title": self.title}) # 将标题添加到段落内容字典中
    9. self.section_text_dict.update({"paper_info": self.get_paper_info()}) # 获取论文的信息并添加到字典中
    10. self.pdf.close() # 关闭pdf文件
  • get_all_page_index:各个部分与页码的对应字典
    1. def _get_all_page_index(self):
    2. # 定义需要寻找的章节名称列表
    3. section_list = ["Abstract",
    4. 'Introduction', 'Related Work', 'Background',
    5. "Preliminary", "Problem Formulation",
    6. 'Methods', 'Methodology', "Method", 'Approach', 'Approaches',
    7. # exp
    8. "Materials and Methods", "Experiment Settings",
    9. 'Experiment', "Experimental Results", "Evaluation", "Experiments",
    10. "Results", 'Findings', 'Data Analysis',
    11. "Discussion", "Results and Discussion", "Conclusion",
    12. 'References']
    13. # 初始化一个字典来存储找到的章节和它们在文档中出现的页码
    14. section_page_dict = {}
    15. # 遍历每一页文档
    16. for page_index, page in enumerate(self.pdf):
    17. # 获取当前页面的文本内容
    18. cur_text = page.get_text()
    19. # 遍历需要寻找的章节名称列表
    20. for section_name in section_list:
    21. # 将章节名称转换成大写形式
    22. section_name_upper = section_name.upper()
    23. # 如果当前页面包含"Abstract"这个关键词
    24. if "Abstract" == section_name and section_name in cur_text:
    25. # 将"Abstract"和它所在的页码加入字典中
    26. section_page_dict[section_name] = page_index
    27. # 如果当前页面包含章节名称,则将章节名称和它所在的页码加入字典中
    28. else:
    29. if section_name + '\n' in cur_text:
    30. section_page_dict[section_name] = page_index
    31. elif section_name_upper + '\n' in cur_text:
    32. section_page_dict[section_name] = page_index
    33. # 返回所有找到的章节名称及它们在文档中出现的页码
    34. return section_page_dict
  • get_all_page:各个部分与内容对应的字典
    1. def _get_all_page(self):
    2. """
    3. 获取PDF文件中每个页面的文本信息,并将文本信息按照章节组织成字典返回。
    4. """
    5. text = '' # 初始化空字符串用于临时储存文本
    6. text_list = [] # 初始化列表用于储存每一页的文本
    7. section_dict = {} # 初始化章节字典
    8. text_list = [page.get_text() for page in self.pdf] # 从每一页获取文本
    9. for sec_index, sec_name in enumerate(self.section_page_dict): # 遍历章节页码字典
    10. print(sec_index, sec_name, self.section_page_dict[sec_name]) # 打印章节索引、章节名和章节起始页码
    11. if sec_index <= 0 and self.abs: # 如果是第一个章节并且存在摘要,则跳过
    12. continue
    13. else:
    14. start_page = self.section_page_dict[sec_name] # 获取章节的起始页码
    15. # 如果当前章节不是最后一个,则获取下一个章节的起始页码作为当前章节的结束页码
    16. if sec_index < len(list(self.section_page_dict.keys()))-1:
    17. end_page = self.section_page_dict[list(self.section_page_dict.keys())[sec_index+1]]
    18. else: # 否则当前章节的结束页码为PDF的最后一页
    19. end_page = len(text_list)
    20. print("start_page, end_page:", start_page, end_page) # 打印起始和结束页码
    21. cur_sec_text = '' # 初始化当前章节的文本
    22. # 如果起始页码和结束页码相同,说明章节在同一页内
    23. if end_page - start_page == 0:
    24. next_sec = list(self.section_page_dict.keys())[sec_index+1]
    25. # 下面的代码是为了确定当前章节的文本的起始和结束位置
    26. # 这部分代码处理可能存在的大小写不一致的问题
    27. start_i = text_list[start_page].find(sec_name) if text_list[start_page].find(sec_name) != -1 else text_list[start_page].find(sec_name.upper())
    28. end_i = text_list[start_page].find(next_sec) if text_list[start_page].find(next_sec) != -1 else text_list[start_page].find(next_sec.upper())
    29. cur_sec_text += text_list[start_page][start_i:end_i]
    30. else: # 否则,章节可能跨越多页
    31. for page_i in range(start_page, end_page):
    32. # 下面的代码是为了确定在每一页中章节文本的起始和结束位置
    33. if page_i == start_page:
    34. start_i = text_list[start_page].find(sec_name) if text_list[start_page].find(sec_name) != -1 else text_list[start_page].find(sec_name.upper())
    35. cur_sec_text += text_list[page_i][start_i:]
    36. elif page_i < end_page:
    37. cur_sec_text += text_list[page_i]
    38. elif page_i == end_page:
    39. next_sec = list(self.section_page_dict.keys())[sec_index+1]
    40. end_i = text_list[start_page].find(next_sec) if text_list[start_page].find(next_sec) != -1 else text_list[start_page].find(next_sec.upper())
    41. cur_sec_text += text_list[page_i][:end_i]
    42. # 在当前章节的文本中去除多余的换行符
    43. section_dict[sec_name] = cur_sec_text.replace('-\n', '').replace('\n', ' ')
    44. return section_dict # 返回章节字典
  • get_paper_info:获取论文的摘要信息
    首先尝试从self.section_text_dict 字典中获取摘要,如果没有,则使用self.abs。最后,它从标题页的文本中移除摘要的内容并返回
    1. def get_paper_info(self): # 定义一个方法获取论文的信息
    2. first_page_text = self.pdf[self.title_page].get_text() # 从PDF的标题页中提取文本
    3. if "Abstract" in self.section_text_dict.keys(): # 如果"Abstract"(摘要)在字典的关键字中
    4. abstract_text = self.section_text_dict['Abstract'] # 从字典中获取摘要的文本
    5. else: # 否则
    6. abstract_text = self.abs # 使用self.abs作为摘要的文本
    7. first_page_text = first_page_text.replace(abstract_text, "") # 从首页面文本中移除摘要内容
    8. return first_page_text # 返回处理后的首页面文本
  • get_chapter_names:根据字体大小,识别每个章节名称,并返回一个列表
  • get_title:获取论文标题
    1. def get_title(self):
    2. doc = self.pdf # 打开pdf文件
    3. max_font_size = 0 # 初始化最大字体大小为0
    4. max_string = "" # 初始化最大字体大小对应的字符串为空
    5. max_font_sizes = [0]
    6. for page_index, page in enumerate(doc): # 遍历每一页
    7. text = page.get_text("dict") # 获取页面上的文本信息
    8. blocks = text["blocks"] # 获取文本块列表
    9. for block in blocks: # 遍历每个文本块
    10. if block["type"] == 0 and len(block['lines']): # 如果是文字类型
    11. if len(block["lines"][0]["spans"]):
    12. font_size = block["lines"][0]["spans"][0]["size"] # 获取第一行第一段文字的字体大小
    13. max_font_sizes.append(font_size)
    14. if font_size > max_font_size: # 如果字体大小大于当前最大值
    15. max_font_size = font_size # 更新最大值
    16. max_string = block["lines"][0]["spans"][0]["text"] # 更新最大值对应的字符串
    17. max_font_sizes.sort()
    18. print("max_font_sizes", max_font_sizes[-10:])
    19. cur_title = ''
    20. for page_index, page in enumerate(doc): # 遍历每一页
    21. text = page.get_text("dict") # 获取页面上的文本信息
    22. blocks = text["blocks"] # 获取文本块列表
    23. for block in blocks: # 遍历每个文本块
    24. if block["type"] == 0 and len(block['lines']): # 如果是文字类型
    25. if len(block["lines"][0]["spans"]):
    26. cur_string = block["lines"][0]["spans"][0]["text"] # 更新最大值对应的字符串
    27. font_flags = block["lines"][0]["spans"][0]["flags"] # 获取第一行第一段文字的字体特征
    28. font_size = block["lines"][0]["spans"][0]["size"] # 获取第一行第一段文字的字体大小
    29. # print(font_size)
    30. if abs(font_size - max_font_sizes[-1]) < 0.3 or abs(font_size - max_font_sizes[-2]) < 0.3:
    31. # print("The string is bold.", max_string, "font_size:", font_size, "font_flags:", font_flags)
    32. if len(cur_string) > 4 and "arXiv" not in cur_string:
    33. # print("The string is bold.", max_string, "font_size:", font_size, "font_flags:", font_flags)
    34. if cur_title == '' :
    35. cur_title += cur_string
    36. else:
    37. cur_title += ' ' + cur_string
    38. self.title_page = page_index
    39. # break
    40. title = cur_title.replace('\n', ' ')
    41. return title

1.4.2 Reader类

Reader类包含了下载文章、筛选文章以及使用OpenAI的GPT-3模型生成文本摘要和总结的方法。主要方法有:

  • get_arxiv(): 使用Arxiv的API获取搜索结果
  • filter_arxiv(): 筛选文章,并返回筛选后的结果
  • download_pdf(): 从Arxiv下载筛选后的文章
  • summary_with_chat(): 对每一篇下载下来的文章进行文本摘要和总结,并将结果以markdown文件的形式保存
    该函数的实现主要分为三个部分
    首先,第一步:用title,abs和introduction进行总结
    1. # 遍历论文列表
    2. for paper_index, paper in enumerate(paper_list):
    3. # 第一步:用title,abs和introduction进行总结
    4. text = ''
    5. text += 'Title:' + paper.title
    6. text += 'Url:' + paper.url
    7. text += 'Abstract:' + paper.abs
    8. text += 'Paper_info:' + paper.section_text_dict['paper_info']
    9. # 添加introduction
    10. text += list(paper.section_text_dict.values())[0]
    11. chat_summary_text = ""
    12. # 尝试与聊天机器人对话以获取摘要
    13. try:
    14. chat_summary_text = self.chat_summary(text=text)
    15. except Exception as e: # 捕获所有异常
    16. print("summary_error:", e)
    17. import sys
    18. exc_type, exc_obj, exc_tb = sys.exc_info() # 获取异常信息
    19. fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
    20. print(exc_type, fname, exc_tb.tb_lineno)
    21. if "maximum context" in str(e): # 如果错误信息中包含特定字符串
    22. current_tokens_index = str(e).find("your messages resulted in") + len(
    23. "your messages resulted in") + 1
    24. offset = int(str(e)[current_tokens_index:current_tokens_index + 4])
    25. summary_prompt_token = offset + 1000 + 150
    26. chat_summary_text = self.chat_summary(text=text, summary_prompt_token=summary_prompt_token)
    27. # 添加到html列表中
    28. htmls.append('## Paper:' + str(paper_index + 1))
    29. htmls.append('\n\n\n')
    30. htmls.append(chat_summary_text)
    其次,第二步:总结方法
    1. # 第二步:总结方法。
    2. # 由于有些文章的方法章节名是算法名,所以简单的通过关键词来筛选很难获取
    3. method_key = ''
    4. for parse_key in paper.section_text_dict.keys():
    5. if 'method' in parse_key.lower() or 'approach' in parse_key.lower():
    6. method_key = parse_key
    7. break
    8. # 如果找到方法关键词
    9. if method_key != '':
    10. text = ''
    11. method_text = ''
    12. summary_text = ''
    13. summary_text += "<summary>" + chat_summary_text
    14. method_text += paper.section_text_dict[method_key]
    15. text = summary_text + "\n\n<Methods>:\n\n" + method_text
    16. chat_method_text = ""
    17. try:
    18. chat_method_text = self.chat_method(text=text)
    19. except Exception as e:
    20. print("method_error:", e)
    21. import sys
    22. exc_type, exc_obj, exc_tb = sys.exc_info()
    23. fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
    24. print(exc_type, fname, exc_tb.tb_lineno)
    25. if "maximum context" in str(e):
    26. current_tokens_index = str(e).find("your messages resulted in") + len(
    27. "your messages resulted in") + 1
    28. offset = int(str(e)[current_tokens_index:current_tokens_index + 4])
    29. method_prompt_token = offset + 800 + 150
    30. chat_method_text = self.chat_method(text=text, method_prompt_token=method_prompt_token)
    31. htmls.append(chat_method_text)
    32. else:
    33. chat_method_text = ''
    34. htmls.append("\n" * 4)
    最后,第三步:总结全文并打分
    1. # 第三步:总结全文并打分。
    2. conclusion_key = ''
    3. for parse_key in paper.section_text_dict.keys():
    4. if 'conclu' in parse_key.lower():
    5. conclusion_key = parse_key
    6. break
    7. text = ''
    8. conclusion_text = ''
    9. summary_text = ''
    10. summary_text += "<summary>" + chat_summary_text + "\n <Method summary>:\n" + chat_method_text
    11. if conclusion_key != '':
    12. conclusion_text += paper.section_text_dict[conclusion_key]
    13. text = summary_text + "\n\n<Conclusion>:\n\n" + conclusion_text
    14. else:
    15. text = summary_text
    16. chat_conclusion_text = ""
    17. try:
    18. chat_conclusion_text = self.chat_conclusion(text=text)
    19. except Exception as e:
    20. print("conclusion_error:", e)
    21. import sys
    22. exc_type, exc_obj, exc_tb = sys.exc_info()
    23. fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
    24. print(exc_type, fname, exc_tb.tb_lineno)
    25. if "maximum context" in str(e):
    26. current_tokens_index = str(e).find("your messages resulted in") + len(
    27. "your messages resulted in") + 1
    28. offset = int(str(e)[current_tokens_index:current_tokens_index + 4])
    29. conclusion_prompt_token = offset + 800 + 150
    30. chat_conclusion_text = self.chat_conclusion(text=text, conclusion_prompt_token=conclusion_prompt_token)
    31. htmls.append(chat_conclusion_text)
    32. htmls.append("\n" * 4)
    33. # 整合成一个文件并保存
    34. date_str = str(datetime.datetime.now())[:13].replace(' ', '-')
    35. export_path = os.path.join(self.root_path, 'export')
    36. if not os.path.exists(export_path):
    37. os.makedirs(export_path)
    38. mode = 'w' if paper_index == 0 else 'a'
    39. file_name = os.path.join(export_path,
    40. date_str + '-' + self.validateTitle(paper.title[:80]) + "." + self.file_format)
    41. self.export_to_markdown("\n".join(htmls), file_name=file_name, mode=mode)
    42. htmls = []
  • chat_summary():第一次提取title,abs,和introduction,设定prompt通过调用API的方式得到对应的总结
    1. def chat_summary(self, text, summary_prompt_token=1100):
    2. # 设置OpenAI API密钥
    3. openai.api_key = self.chat_api_list[self.cur_api]
    4. # 更新API密钥索引,用于循环使用多个API密钥(如果有)
    5. self.cur_api += 1
    6. self.cur_api = 0 if self.cur_api >= len(self.chat_api_list) - 1 else self.cur_api
    7. # 计算输入文本的token数量
    8. text_token = len(self.encoding.encode(text))
    9. # 计算截断文本的索引,确保总的token数量不超过限制
    10. clip_text_index = int(len(text) * (self.max_token_num - summary_prompt_token) / text_token)
    11. # 获取截断后的文本
    12. clip_text = text[:clip_text_index]
    13. # 定义聊天机器人的交互消息
    14. messages = [
    15. {"role": "system",
    16. "content": "You are a researcher in the field of [" + self.key_word + "] who is good at summarizing papers using concise statements"},
    17. {"role": "assistant",
    18. "content": "This is the title, author, link, abstract and introduction of an English document. I need your help to read and summarize the following questions: " + clip_text},
    19. {"role": "user", "content": """
    20. ...(这部分是详细的指示内容,为了简洁我略过了)...
    21. """.format(self.language, self.language, self.language)},
    22. ]
    23. # 根据API类型调用相应的方法
    24. if openai.api_type == 'azure':
    25. response = openai.ChatCompletion.create(
    26. engine=self.chatgpt_model,
    27. messages=messages,
    28. )
    29. else:
    30. response = openai.ChatCompletion.create(
    31. model=self.chatgpt_model,
    32. messages=messages,
    33. )
    34. # 从响应中提取机器人的回复
    35. result = ''
    36. for choice in response.choices:
    37. result += choice.message.content
    38. # 打印结果和使用的token数量以及响应时间
    39. print("summary_result:\n", result)
    40. print("prompt_token_used:", response.usage.prompt_tokens,
    41. "completion_token_used:", response.usage.completion_tokens,
    42. "total_token_used:", response.usage.total_tokens)
    43. print("response_time:", response.response_ms / 1000.0, 's')
    44. # 返回结果
    45. return result
  • chat_method():提取上面chat_summary()得到的结果,加上method或approach部分的内容,设定prompt通过调用API的方式得到对应的总结
    1. def chat_method(self, text, method_prompt_token=800):
    2. # 设置OpenAI的API key
    3. openai.api_key = self.chat_api_list[self.cur_api]
    4. # 将当前API索引递增,以便下次使用不同的API key
    5. self.cur_api += 1
    6. # 如果当前API索引超出API key列表的长度,则将其重置为0(实现循环使用API key列表)
    7. self.cur_api = 0 if self.cur_api >= len(self.chat_api_list) - 1 else self.cur_api
    8. # 使用encoding方法计算输入文本的token数量
    9. text_token = len(self.encoding.encode(text))
    10. # 根据最大token数量和方法提示token计算需要裁剪的文本长度
    11. clip_text_index = int(len(text) * (self.max_token_num - method_prompt_token) / text_token)
    12. # 根据上面计算的索引裁剪文本
    13. clip_text = text[:clip_text_index]
    14. # 定义要发送到ChatGPT的消息列表
    15. messages = [
    16. # 定义系统角色的消息,描述用户的专业背景和能力
    17. {"role": "system", "content": "You are a researcher in the field of [" + self.key_word + "] who is good at summarizing papers using concise statements"},
    18. # 定义助手角色的消息,描述要助手完成的任务
    19. {"role": "assistant", "content": "This is the <summary> and <Method> part of an English document, where <summary> you have summarized, but the <Methods> part, I need your help to read and summarize the following questions." + clip_text},
    20. # 定义用户角色的消息,给出具体的问题和期望格式
    21. {"role": "user", "content": """
    22. 7. Describe in detail the methodological idea of this article. Be sure to use {} answers (proper nouns need to be marked in English). For example, its steps are.
    23. - (1):...
    24. - (2):...
    25. - (3):...
    26. - .......
    27. Follow the format of the output that follows:
    28. 7. Methods: \n\n
    29. - (1):xxx;\n
    30. - (2):xxx;\n
    31. - (3):xxx;\n
    32. ....... \n\n
    33. Be sure to use {} answers (proper nouns need to be marked in English), statements as concise and academic as possible, do not repeat the content of the previous <summary>, the value of the use of the original numbers, be sure to strictly follow the format, the corresponding content output to xxx, in accordance with \n line feed, ....... means fill in according to the actual requirements, if not, you can not write.
    34. """.format(self.language, self.language)},
    35. ]
    36. # 根据API类型选择适当的调用方法
    37. if openai.api_type == 'azure':
    38. response = openai.ChatCompletion.create(
    39. engine=self.chatgpt_model,
    40. messages=messages,
    41. )
    42. else:
    43. response = openai.ChatCompletion.create(
    44. model=self.chatgpt_model,
    45. messages=messages,
    46. )
    47. # 从返回的答案中初始化一个空字符串用于保存结果
    48. result = ''
    49. # 遍历返回的选择,将内容添加到结果字符串中
    50. for choice in response.choices:
    51. result += choice.message.content
    52. # 打印方法的结果和相关的token使用情况
    53. print("method_result:\n", result)
    54. print("prompt_token_used:", response.usage.prompt_tokens,
    55. "completion_token_used:", response.usage.completion_tokens,
    56. "total_token_used:", response.usage.total_tokens)
    57. # 打印响应时间
    58. print("response_time:", response.response_ms / 1000.0, 's')
    59. # 返回结果字符串
    60. return result
  • chat_conclusion():提取上面两部分:chat_summary()chat_method()得到的结果(API给的回复),加上conclusion部分的内容,设定prompt通过调用API的方式得到对应的总结
    1. def chat_conclusion(self, text, conclusion_prompt_token=800):
    2. # 设置OpenAI的API密钥
    3. openai.api_key = self.chat_api_list[self.cur_api]
    4. # 使当前API索引递增,以便下次使用不同的API密钥
    5. self.cur_api += 1
    6. # 如果当前API索引超过API密钥列表的长度,将其重置为0
    7. self.cur_api = 0 if self.cur_api >= len(self.chat_api_list) - 1 else self.cur_api
    8. # 使用encoding方法计算输入文本的token数量
    9. text_token = len(self.encoding.encode(text))
    10. # 计算需要裁剪的文本长度,以适应模型的最大token限制
    11. clip_text_index = int(len(text) * (self.max_token_num - conclusion_prompt_token) / text_token)
    12. # 裁剪文本
    13. clip_text = text[:clip_text_index]
    14. # 定义要发送给ChatGPT的消息列表
    15. messages = [
    16. # 系统角色的消息,描述用户作为一个审稿人的背景
    17. {"role": "system", "content": "You are a reviewer in the field of [" + self.key_word + "] and you need to critically review this article"},
    18. # 助手角色的消息,描述要助手完成的任务
    19. {"role": "assistant", "content": "This is the <summary> and <conclusion> part of an English literature, where <summary> you have already summarized, but <conclusion> part, I need your help to summarize the following questions:" + clip_text},
    20. # 用户角色的消息,提供具体问题和预期的答案格式
    21. {"role": "user", "content": """
    22. 8. Make the following summary.Be sure to use {} answers (proper nouns need to be marked in English).
    23. - (1):What is the significance of this piece of work?
    24. - (2):Summarize the strengths and weaknesses of this article in three dimensions: innovation point, performance, and workload.
    25. .......
    26. Follow the format of the output later:
    27. 8. Conclusion: \n\n
    28. - (1):xxx;\n
    29. - (2):Innovation point: xxx; Performance: xxx; Workload: xxx;\n
    30. Be sure to use {} answers (proper nouns need to be marked in English), statements as concise and academic as possible, do not repeat the content of the previous <summary>, the value of the use of the original numbers, be sure to strictly follow the format, the corresponding content output to xxx, in accordance with \n line feed, ....... means fill in according to the actual requirements, if not, you can not write.
    31. """.format(self.language, self.language)},
    32. ]
    33. # 根据API类型选择适当的方法来获取模型的答案
    34. if openai.api_type == 'azure':
    35. response = openai.ChatCompletion.create(
    36. engine=self.chatgpt_model,
    37. messages=messages,
    38. )
    39. else:
    40. response = openai.ChatCompletion.create(
    41. model=self.chatgpt_model,
    42. messages=messages,
    43. )
    44. # 初始化结果字符串
    45. result = ''
    46. # 遍历模型返回的答案,将其添加到结果字符串中
    47. for choice in response.choices:
    48. result += choice.message.content
    49. # 打印结论部分的结果和token使用情况
    50. print("conclusion_result:\n", result)
    51. print("prompt_token_used:", response.usage.prompt_tokens,
    52. "completion_token_used:", response.usage.completion_tokens,
    53. "total_token_used:", response.usage.total_tokens)
    54. # 打印响应时间
    55. print("response_time:", response.response_ms / 1000.0, 's')
    56. # 返回结果字符串
    57. return result

1.4.3 chat_paper_main

// 待更

1.5 RUN一下:ChatPaper代码整体运行后得到的部分结果

chatpaper代码运行后得到的部分结果 输出:标题、作者、单位、 关键词、相关链接及 Summary。其中

  • Summary为总结 得到的摘要

  • method_result:对论文方法(method或approach)的总结

  • Conclusion_result:对论文全文的总结(包含工作意义及创新点等)

//待更

第二部分 gpt_academic源码解读

// 待更

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/325100
推荐阅读
相关标签
  

闽ICP备14008679号