
GraphRAG + Ollama Local Deployment Guide: Avoiding the Pitfalls in Practice



1. Why Deploy GraphRAG Locally?

Since Microsoft open-sourced GraphRAG, interest in it has grown steadily. Out of the box, however, GraphRAG only supports OpenAI's closed-source models, which greatly limits where it can be deployed. This article walks through modifying the GraphRAG source code to support a wider range of embedding models and open-source LLMs, making GraphRAG far easier to get started with.


If you are not yet familiar with GraphRAG, see my two earlier articles: 《微软重磅开源 GraphRAG:新一代 RAG 技术来了!》 and 《GraphRAG 项目升级!现已支持 Ollama 本地模型接入,打造交互式 UI 体验》.

2. Installing GraphRAG

Step 1: Install GraphRAG

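GraphRAG installs from PyPI with a single command (the original post showed this step as a screenshot):

```bash
pip install graphrag
```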

A Python 3.10–3.12 environment is required.

Step 2: Create a folder for your knowledge data

After installation, create a folder to hold your knowledge data. GraphRAG currently supports only txt and csv formats.

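To match the paths used in the rest of this guide:

```bash
mkdir -p ./ragtest/input
```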

Step 3: Place a dataset under the /ragtest/input directory

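Any plain-text file will do. As an example, the official GraphRAG quickstart pulls A Christmas Carol from Project Gutenberg; substitute your own data as needed:

```bash
curl https://www.gutenberg.org/cache/epub/24022/pg24022.txt -o ./ragtest/input/book.txt
```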

Step 4: Initialize the workspace

Next, we initialize the workspace. Since the ragtest directory was already prepared in step 2, point the init command at it, as shown below.
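The initialization command from the GraphRAG quickstart:

```bash
python -m graphrag.index --init --root ./ragtest
```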

When it finishes, two files are generated under the ragtest directory: .env and settings.yaml. The directory structure now looks like this:

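Roughly the following (init also creates a prompts/ folder containing the default prompt templates that settings.yaml references):

```
ragtest/
├── input/
├── prompts/
├── .env
└── settings.yaml
```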

The .env file contains the environment variables required to run the GraphRAG pipeline. If you inspect it, you will see a single variable defined, GRAPHRAG_API_KEY=<API_KEY>. This is the API key for the OpenAI API or an Azure OpenAI endpoint; you can replace it with your own key.
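Since this guide sets api_key: ollama directly in settings.yaml, the key in .env is effectively a placeholder for a fully local setup; any dummy value works:

```
GRAPHRAG_API_KEY=ollama
```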

The settings.yaml file contains the pipeline settings. Modify this file to change how the pipeline behaves.

3. Editing the Config File to Support Local Models

Step 1: Make sure Ollama is installed

If you have not installed it yet, or are unsure how, see my earlier article 《Spring AI + Ollama 快速构建大模型应用程序(含源码)》.

Step 2: Make sure the following local models are available

- Embedding model: quentinz/bge-large-zh-v1.5:latest
- LLM: gemma2:9b

If you have not pulled them yet, the commands are shown below.
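Pull both through the standard Ollama CLI:

```bash
ollama pull quentinz/bge-large-zh-v1.5:latest
ollama pull gemma2:9b
```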

Step 3: Edit settings.yaml to point at these two local models. The modified file follows.

```yaml
encoding_model: cl100k_base
skip_workflows: []
llm:
  api_key: ollama
  type: openai_chat # or azure_openai_chat
  model: gemma2:9b # your local LLM in Ollama; swap in any model you have pulled
  model_supports_json: true # recommended if this is available for your model.
  max_tokens: 2048
  api_base: http://localhost:11434/v1 # note: the chat endpoint is /v1
  concurrent_requests: 1 # the number of parallel inflight requests that may be made

parallelization:
  stagger: 0.3

async_mode: threaded # or asyncio

embeddings:
  async_mode: threaded # or asyncio
  llm:
    api_key: ollama
    type: openai_embedding # or azure_openai_embedding
    model: quentinz/bge-large-zh-v1.5:latest # your local embedding model in Ollama; swap in any model you have pulled
    api_base: http://localhost:11434/api # note: the embedding endpoint is /api
    concurrent_requests: 1 # the number of parallel inflight requests that may be made

chunks:
  size: 300
  overlap: 100
  group_by_columns: [id] # by default, we don't allow chunks to cross documents

input:
  type: file # or blob
  file_type: text # or csv
  base_dir: "input"
  file_encoding: utf-8
  file_pattern: ".*\\.txt$"

cache:
  type: file # or blob
  base_dir: "cache"

storage:
  type: file # or blob
  base_dir: "output/${timestamp}/artifacts"

reporting:
  type: file # or console, blob
  base_dir: "output/${timestamp}/reports"

entity_extraction:
  prompt: "prompts/entity_extraction.txt"
  entity_types: [organization, person, geo, event]
  max_gleanings: 0

summarize_descriptions:
  prompt: "prompts/summarize_descriptions.txt"
  max_length: 500

claim_extraction:
  prompt: "prompts/claim_extraction.txt"
  description: "Any claims or facts that could be relevant to information discovery."
  max_gleanings: 0

community_report:
  prompt: "prompts/community_report.txt"
  max_length: 2000
  max_input_length: 8000

cluster_graph:
  max_cluster_size: 10

embed_graph:
  enabled: false # if true, will generate node2vec embeddings for nodes

umap:
  enabled: false # if true, will generate UMAP embeddings for nodes

snapshots:
  graphml: false
  raw_entities: false
  top_level_nodes: false

local_search:
  max_tokens: 5000

global_search:
  max_tokens: 5000
```

Step 4: Run GraphRAG to build the knowledge-graph index

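The indexing command uses the same CLI module as the init step:

```bash
python -m graphrag.index --root ./ragtest
```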

Building the knowledge-graph index takes some time; GraphRAG prints the progress of each workflow to the console as it runs.

4. Patching the Source Code to Support Local Models

Next, we modify the source code so that local and global queries return correct results.

Step 1: Switch to the local embedding model

File to modify:

.../Python/Python310/site-packages/graphrag/llm/openai/openai_embeddings_llm.py

The modified source is below. Note that this patch hardcodes the Ollama embedding model name; keep it in sync with the model configured in settings.yaml.

```python
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""The EmbeddingsLLM class."""

import ollama
from typing_extensions import Unpack

from graphrag.llm.base import BaseLLM
from graphrag.llm.types import (
    EmbeddingInput,
    EmbeddingOutput,
    LLMInput,
)

from .openai_configuration import OpenAIConfiguration
from .types import OpenAIClientTypes


class OpenAIEmbeddingsLLM(BaseLLM[EmbeddingInput, EmbeddingOutput]):
    """A text-embedding generator LLM."""

    _client: OpenAIClientTypes
    _configuration: OpenAIConfiguration

    def __init__(self, client: OpenAIClientTypes, configuration: OpenAIConfiguration):
        self.client = client
        self.configuration = configuration

    async def _execute_llm(
        self, input: EmbeddingInput, **kwargs: Unpack[LLMInput]
    ) -> EmbeddingOutput | None:
        # args is retained from the original OpenAI call path; the Ollama
        # call below only needs the prompt itself.
        args = {
            "model": self.configuration.model,
            **(kwargs.get("model_parameters") or {}),
        }
        # Embed each input through the local Ollama model instead of the
        # OpenAI client. Keep the model name in sync with settings.yaml.
        embedding_list = []
        for inp in input:
            embedding = ollama.embeddings(
                model="quentinz/bge-large-zh-v1.5:latest", prompt=inp
            )
            embedding_list.append(embedding["embedding"])
        return embedding_list
        # Original OpenAI implementation:
        # embedding = await self.client.embeddings.create(
        #     input=input,
        #     **args,
        # )
        # return [d.embedding for d in embedding.data]
```
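Before re-running the pipeline, it is worth sanity-checking that the ollama Python package can reach the embedding model. A minimal check (1024 is the dimension the bge-large family produces):

```python
import ollama

resp = ollama.embeddings(model="quentinz/bge-large-zh-v1.5:latest", prompt="hello world")
print(len(resp["embedding"]))  # expect 1024 for bge-large-zh-v1.5
```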

Step 2: Patch the query-side embedding code

File to modify:

.../Python/Python310/site-packages/graphrag/query/llm/oai/embedding.py

The modified source is as follows:

```python
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""OpenAI Embedding model implementation."""

import asyncio
from collections.abc import Callable
from typing import Any

import numpy as np
import tiktoken
from langchain_community.embeddings import OllamaEmbeddings
from tenacity import (
    AsyncRetrying,
    RetryError,
    Retrying,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential_jitter,
)

from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.query.llm.oai.base import OpenAILLMImpl
from graphrag.query.llm.oai.typing import (
    OPENAI_RETRY_ERROR_TYPES,
    OpenaiApiType,
)
from graphrag.query.llm.text_utils import chunk_text
from graphrag.query.progress import StatusReporter


class OpenAIEmbedding(BaseTextEmbedding, OpenAILLMImpl):
    """Wrapper for OpenAI Embedding models."""

    def __init__(
        self,
        api_key: str | None = None,
        azure_ad_token_provider: Callable | None = None,
        model: str = "text-embedding-3-small",
        deployment_name: str | None = None,
        api_base: str | None = None,
        api_version: str | None = None,
        api_type: OpenaiApiType = OpenaiApiType.OpenAI,
        organization: str | None = None,
        encoding_name: str = "cl100k_base",
        max_tokens: int = 8191,
        max_retries: int = 10,
        request_timeout: float = 180.0,
        retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES,  # type: ignore
        reporter: StatusReporter | None = None,
    ):
        OpenAILLMImpl.__init__(
            self=self,
            api_key=api_key,
            azure_ad_token_provider=azure_ad_token_provider,
            deployment_name=deployment_name,
            api_base=api_base,
            api_version=api_version,
            api_type=api_type,  # type: ignore
            organization=organization,
            max_retries=max_retries,
            request_timeout=request_timeout,
            reporter=reporter,
        )
        self.model = model
        self.encoding_name = encoding_name
        self.max_tokens = max_tokens
        self.token_encoder = tiktoken.get_encoding(self.encoding_name)
        self.retry_error_types = retry_error_types

    def embed(self, text: str, **kwargs: Any) -> list[float]:
        """
        Embed text using OpenAI Embedding's sync function.

        For text longer than max_tokens, chunk texts into max_tokens, embed each chunk, then combine using weighted average.
        Please refer to: https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
        """
        token_chunks = chunk_text(
            text=text, token_encoder=self.token_encoder, max_tokens=self.max_tokens
        )
        chunk_embeddings = []
        chunk_lens = []
        for chunk in token_chunks:
            try:
                embedding, chunk_len = self._embed_with_retry(chunk, **kwargs)
                chunk_embeddings.append(embedding)
                chunk_lens.append(chunk_len)
            # TODO: catch a more specific exception
            except Exception as e:  # noqa BLE001
                self._reporter.error(
                    message="Error embedding chunk",
                    details={self.__class__.__name__: str(e)},
                )
                continue
        chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens)
        chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings)
        return chunk_embeddings.tolist()

    async def aembed(self, text: str, **kwargs: Any) -> list[float]:
        """
        Embed text using OpenAI Embedding's async function.

        For text longer than max_tokens, chunk texts into max_tokens, embed each chunk, then combine using weighted average.
        """
        token_chunks = chunk_text(
            text=text, token_encoder=self.token_encoder, max_tokens=self.max_tokens
        )
        embedding_results = await asyncio.gather(*[
            self._aembed_with_retry(chunk, **kwargs) for chunk in token_chunks
        ])
        embedding_results = [result for result in embedding_results if result[0]]
        chunk_embeddings = [result[0] for result in embedding_results]
        chunk_lens = [result[1] for result in embedding_results]
        chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens)  # type: ignore
        chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings)
        return chunk_embeddings.tolist()

    def _embed_with_retry(
        self, text: str | tuple, **kwargs: Any
    ) -> tuple[list[float], int]:
        try:
            retryer = Retrying(
                stop=stop_after_attempt(self.max_retries),
                wait=wait_exponential_jitter(max=10),
                reraise=True,
                retry=retry_if_exception_type(self.retry_error_types),
            )
            for attempt in retryer:
                with attempt:
                    # Embed via the local Ollama model instead of the OpenAI client.
                    embedding = (
                        OllamaEmbeddings(
                            model=self.model,
                        ).embed_query(text)
                        or []
                    )
                    return (embedding, len(text))
        except RetryError as e:
            self._reporter.error(
                message="Error at embed_with_retry()",
                details={self.__class__.__name__: str(e)},
            )
            return ([], 0)
        else:
            # TODO: why not just throw in this case?
            return ([], 0)

    async def _aembed_with_retry(
        self, text: str | tuple, **kwargs: Any
    ) -> tuple[list[float], int]:
        try:
            retryer = AsyncRetrying(
                stop=stop_after_attempt(self.max_retries),
                wait=wait_exponential_jitter(max=10),
                reraise=True,
                retry=retry_if_exception_type(self.retry_error_types),
            )
            async for attempt in retryer:
                with attempt:
                    # Use the async aembed_query here; embed_query is
                    # synchronous and must not be awaited.
                    embedding = (
                        await OllamaEmbeddings(
                            model=self.model,
                        ).aembed_query(text)
                        or []
                    )
                    return (embedding, len(text))
        except RetryError as e:
            self._reporter.error(
                message="Error at embed_with_retry()",
                details={self.__class__.__name__: str(e)},
            )
            return ([], 0)
        else:
            # TODO: why not just throw in this case?
            return ([], 0)
```
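You can likewise confirm that the LangChain wrapper used by this patch reaches the same model. A minimal sketch, assuming the langchain-community package is installed:

```python
from langchain_community.embeddings import OllamaEmbeddings

emb = OllamaEmbeddings(model="quentinz/bge-large-zh-v1.5:latest")
vec = emb.embed_query("test sentence")
print(len(vec))  # should match the dimension reported by the earlier ollama check
```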

5. Testing GraphRAG

Test 1: Local query
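Local queries go through the query CLI; the question below is just an example:

```bash
python -m graphrag.query --root ./ragtest --method local "Who are the main characters in this story?"
```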


Test 2: Global query
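A global query only changes the method flag:

```bash
python -m graphrag.query --root ./ragtest --method global "What are the top themes in this story?"
```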


