当前位置:   article > 正文

LangChain - 回调函数_langchain回调函数

langchain回调函数

LangChain - 回调函数


回调函数概览

LangChain提供了一个回调函数系统,允许您在LLM应用程序的 各个阶段进行钩子操作

这对于日志记录、监控、流式传输和其他任务非常有用。

您可以使用 API 中的 callbacks 参数订阅这些事件。

该参数是一个处理程序对象列表,这些对象应该详细实现下面描述的一个或多个方法。

回调处理程序

CallbackHandlers 是实现 CallbackHandler 接口的对象,每个事件都可以订阅一个方法。

当触发事件时,CallbackManager 会调用每个处理程序上的适当方法。

—python class BaseCallbackHandler: “”“Base callback handler that can be used to handle callbacks from langchain.”“”

def on_llm_start(
    self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> Any:
    """Run when LLM starts running."""

def on_chat_model_start(
    self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], **kwargs: Any
) -> Any:
    """Run when Chat Model starts running."""

def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
    """Run on new LLM token. Only available when streaming is enabled."""

def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
    """Run when LLM ends running."""

def on_llm_error(
    self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
    """Run when LLM errors."""

def on_chain_start(
    self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> Any:
    """Run when chain starts running."""

def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
    """Run when chain ends running."""

def on_chain_error(
    self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
    """Run when chain errors."""

def on_tool_start(
    self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> Any:
    """Run when tool starts running."""

def on_tool_end(self, output: str, **kwargs: Any) -> Any:
    """Run when tool ends running."""

def on_tool_error(
    self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
    """Run when tool errors."""

def on_text(self, text: str, **kwargs: Any) -> Any:
    """Run on arbitrary text."""

def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
    """Run on agent action."""

def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
    """Run on agent end."""
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55

基本使用 StdOutCallbackHandler

LangChain 提供了一些内置的处理程序,您可以使用它们进行入门。
这些处理程序在 langchain/callbacks 模块中可用。

最基本的处理程序是 StdOutCallbackHandler,它只是将所有事件记录到 stdout

注意 当对象上的 verbose 标志设置为 true 时,即使没有显式传递,StdOutCallbackHandler 也会被调用。


from langchain.callbacks import StdOutCallbackHandler
from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

handler = StdOutCallbackHandler()
llm = OpenAI()
prompt = PromptTemplate.from_template("1 + {number} = ")

# Constructor callback: First, let's explicitly set the StdOutCallbackHandler when initializing our chain
chain = LLMChain(llm=llm, prompt=prompt, callbacks=[handler])
chain.run(number=2)

# Use verbose flag: Then, let's use the `verbose` flag to achieve the same result
chain = LLMChain(llm=llm, prompt=prompt, verbose=True)
chain.run(number=2)

# Request callbacks: Finally, let's use the request `callbacks` to achieve the same result
chain = LLMChain(llm=llm, prompt=prompt)
chain.run(number=2, callbacks=[handler])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

    > Entering new LLMChain chain...
    Prompt after formatting:
    1 + 2 = 
    
    > Finished chain.
    
    
    > Entering new LLMChain chain...
    Prompt after formatting:
    1 + 2 = 
    
    > Finished chain.
    
    
    > Entering new LLMChain chain...
    Prompt after formatting:
    1 + 2 = 
    
    > Finished chain.


    '\n\n3'
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

在哪里传递回调

callbacks 参数在 API 的大多数对象(Chains、Models、Tools、Agents 等)中都可用,有两个不同的位置:

  • 构造函数回调:在构造函数中定义
    例如 LLMChain(callbacks=[handler], tags=['a-tag']),它将用于该对象上的所有调用,并仅限于该对象的范围。
    例如,如果您将处理程序传递给 LLMChain 构造函数,它将不会被附加到该链上的模型使用。
  • 请求回调:在发出请求的 call()/run()/apply() 方法中定义
    例如 chain.call(inputs, callbacks=[handler]),它仅用于该特定请求以及它包含的所有子请求
    (例如,对 LLMChain 的调用触发对模型的调用,模型使用在 call() 方法中传递的相同处理程序)。

verbose 参数在 API 的大多数对象(Chains、Models、Tools、Agents 等)中都可用作构造函数参数。
例如 LLMChain(verbose=True),它等效于将 ConsoleCallbackHandler 传递给该对象及其所有子对象的 callbacks 参数。
这对于调试非常有用,因为它会将所有事件记录到控制台。


在什么情况下使用这些选项?
  • 构造函数回调最适用于记录、监视等与单个请求无关的用例。
    例如,如果您想记录对 LLMChain 的所有请求,您可以将处理程序传递给构造函数。
  • 请求回调最适用于流式传输等用例,其中您希望将单个请求的输出流式传输到特定的 WebSocket 连接或其他类似用例。
    例如,如果您想将单个请求的输出流式传输到 WebSocket,您可以将处理程序传递给 call() 方法。

异步回调 (AsyncCallbackHandler

如果您计划使用异步API,则建议使用AsyncCallbackHandler以避免阻塞运行循环。

高级如果您在运行llm/chain/tool/agent时使用同步CallbackHandler同时使用异步方法,它仍然可以工作。
但是,在底层,它将使用run_in_executor调用,如果您的CallbackHandler不是线程安全的,则可能会引发问题。

import asyncio
from typing import Any, Dict, List

from langchain.chat_models import ChatOpenAI
from langchain.schema import LLMResult, HumanMessage
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler


class MyCustomSyncHandler(BaseCallbackHandler):
    def on_llm_new_token(self, token: str, **kwargs) -> None:
        print(f"Sync handler being called in a `thread_pool_executor`: token: {token}")


class MyCustomAsyncHandler(AsyncCallbackHandler):
    """用于处理来自langchain的回调的异步回调处理程序。"""

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """当链条开始运行时运行。"""
        print("zzzz....")
        await asyncio.sleep(0.3)
        class_name = serialized["name"]
        print("嗨!我刚醒来。您的llm正在启动")

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """当链条结束运行时运行。"""
        print("zzzz....")
        await asyncio.sleep(0.3)
        print("嗨!我刚醒来。您的llm正在结束")
 
# 为了启用流式传输,我们在ChatModel构造函数中传入`streaming=True`
# 此外,我们还传入一个包含自定义处理程序的列表
chat = ChatOpenAI(
    max_tokens=25,
    streaming=True,
    callbacks=[MyCustomSyncHandler(), 
    		MyCustomAsyncHandler()],
)

await chat.agenerate([[HumanMessage(content="给我讲个笑话")]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41

自定义回调处理程序

您还可以创建一个自定义处理程序并将其设置在对象上。
在下面的示例中,我们将使用自定义处理程序实现流处理。

from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage


class MyCustomHandler(BaseCallbackHandler):
    def on_llm_new_token(self, token: str, **kwargs) -> None:
        print(f"我的自定义处理程序,token: {token}")

chat = ChatOpenAI(max_tokens=25, streaming=True, callbacks=[MyCustomHandler()])

chat([HumanMessage(content="给我讲一个笑话")])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

自定义链的回调函数

当您创建自定义链时,可以轻松地设置它使用与所有内置链相同的回调系统。
Chains / LLMs / Chat Models / Agents / Tools 上的 _call_generate_run和相应的异步方法现在接收第二个参数 run_manager,它绑定到该运行,并包含可以被该对象使用的日志记录方法(即 on_llm_new_token)。
在构建自定义链时,这非常有用。有关如何创建自定义链并在其中使用回调的更多信息请参阅此指南。


将日志记录到文件

此例显示了如何将日志记录到文件。
它展示了如何使用 FileCallbackHandler,它与 StdOutCallbackHandler 做的事情相同,但是将输出写入文件。
它还使用 loguru 库来记录处理程序未捕获的其他输出。

from loguru import logger

from langchain.callbacks import FileCallbackHandler
from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

logfile = "output.log"

logger.add(logfile, colorize=True, enqueue=True)
handler = FileCallbackHandler(logfile)

llm = OpenAI()
prompt = PromptTemplate.from_template("1 + {number} = ")

# this chain will both print to stdout (because verbose=True) and write to 'output.log'
# if verbose=False, the FileCallbackHandler will still write to 'output.log'
chain = LLMChain(llm=llm, prompt=prompt, callbacks=[handler], verbose=True)
answer = chain.run(number=2)
logger.info(answer)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

Now we can open the file output.log to see that the output has been captured.

!pip install ansi2html > /dev/null
  • 1
from IPython.display import display, HTML
from ansi2html import Ansi2HTMLConverter

with open("output.log", "r") as f:
    content = f.read()

conv = Ansi2HTMLConverter()
html = conv.convert(content, full=True)

display(HTML(html))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

多个回调处理程序

在前面的示例中,我们通过在创建对象时使用 callbacks= 来传递回调处理程序。
在这种情况下,回调处理程序将仅适用于该特定对象。

然而,在许多情况下,当运行对象时传递处理程序会更有优势。
当使用 callbacks 关键字参数通过 CallbackHandlers 传递时,这些回调处理程序将被所有参与执行的嵌套对象使用。
例如,当将处理程序传递给 Agent 时,它将用于与代理相关的所有回调以及代理执行中涉及的所有对象,例如 ToolsLLMChainLLM

这样,我们就不必手动将处理程序附加到每个嵌套对象上。

from typing import Dict, Union, Any, List

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.callbacks import tracing_enabled
from langchain.llms import OpenAI
 
# First, define custom callback handler implementations
class MyCustomHandlerOne(BaseCallbackHandler):
    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:
        print(f"on_llm_start {serialized['name']}")

    def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
        print(f"on_new_token {token}")

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> Any:
        """Run when LLM errors."""

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> Any:
        print(f"on_chain_start {serialized['name']}")

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> Any:
        print(f"on_tool_start {serialized['name']}")

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        print(f"on_agent_action {action}")


class MyCustomHandlerTwo(BaseCallbackHandler):
    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:
        print(f"on_llm_start (I'm the second handler!!) {serialized['name']}")

# Instantiate the handlers
handler1 = MyCustomHandlerOne()
handler2 = MyCustomHandlerTwo()

# Setup the agent. Only the `llm` will issue callbacks for handler2
llm = OpenAI(temperature=0, streaming=True, callbacks=[handler2])

tools = load_tools(["llm-math"], llm=llm)

agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION)

# Callbacks for handler1 will be issued by every object involved in the
# Agent execution (llm, llmchain, tool, agent executor)
agent.run("What is 2 raised to the 0.235 power?", callbacks=[handler1])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57

标签 tags

你可以通过将 tags 参数传递给 call()/run()/apply() 方法来为回调函数添加标签。
这对于过滤日志非常有用,例如,如果您想记录对特定 LLMChain 所做的所有请求,您可以添加一个标签,然后通过该标签对日志进行筛选。
您可以将标签传递给构造函数和请求回调函数,详见上面的示例。
这些标签然后传递给 “start” 回调方法的 tags 参数,即 on_llm_starton_chat_model_starton_chain_starton_tool_start


Token counting

LangChain提供了一个上下文管理器,允许您计算标记数量。

import asyncio
from langchain.callbacks import get_openai_callback
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)

with get_openai_callback() as cb:
    llm("What is the square root of 4?")

total_tokens = cb.total_tokens

assert total_tokens > 0

with get_openai_callback() as cb:
    llm("What is the square root of 4?")
    llm("What is the square root of 4?")

assert cb.total_tokens == total_tokens * 2

# You can kick off concurrent runs from within the context manager
with get_openai_callback() as cb:
    await asyncio.gather(
        *[llm.agenerate(["What is the square root of 4?"]) for _ in range(3)]
    )

assert cb.total_tokens == total_tokens * 3

# The context manager is concurrency safe
task = asyncio.create_task(llm.agenerate(["What is the square root of 4?"]))
with get_openai_callback() as cb:
    await llm.agenerate(["What is the square root of 4?"])

await task
assert cb.total_tokens == total_tokens
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

跟踪

有两种推荐的方法来跟踪您的 LangChains:

  1. LANGCHAIN_TRACING 环境变量设置为 "true"
  2. 使用上下文管理器 with tracing_enabled() 来跟踪特定的代码块。

注意:如果设置了环境变量,无论代码是否在上下文管理器中,所有代码都将被跟踪。

import os

from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.callbacks import tracing_enabled
from langchain.llms import OpenAI

# To run the code, make sure to set OPENAI_API_KEY and SERPAPI_API_KEY
llm = OpenAI(temperature=0)

tools = load_tools(["llm-math", "serpapi"], llm=llm)

agent = initialize_agent(
    tools, llm, 
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, 
    verbose=True
)

questions = [
    "Who won the US Open men's final in 2019? What is his age raised to the 0.334 power?",
    "Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?",
    "Who won the most recent formula 1 grand prix? What is their age raised to the 0.23 power?",
    "Who won the US Open women's final in 2019? What is her age raised to the 0.34 power?",
    "Who is Beyonce's husband? What is his age raised to the 0.19 power?",
]
 
os.environ["LANGCHAIN_TRACING"] = "true"

# Both of the agent runs will be traced because the environment variable is set
agent.run(questions[0])
with tracing_enabled() as session:
    assert session
    agent.run(questions[1])
 
# Now, we unset the environment variable and use a context manager.

if "LANGCHAIN_TRACING" in os.environ:
    del os.environ["LANGCHAIN_TRACING"]

# here, we are writing traces to "my_test_session"
with tracing_enabled("my_test_session") as session:
    assert session
    agent.run(questions[0])  # this should be traced

agent.run(questions[1])  # this should not be traced
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44

import asyncio

# The context manager is concurrency safe:
if "LANGCHAIN_TRACING" in os.environ:
    del os.environ["LANGCHAIN_TRACING"]

# start a background task
task = asyncio.create_task(agent.arun(questions[0]))  # this should not be traced
with tracing_enabled() as session:
    assert session
    tasks = [agent.arun(q) for q in questions[1:3]]  # these should be traced
    await asyncio.gather(*tasks)

await task
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

集成

Argilla
https://python.langchain.com.cn/docs/modules/callbacks/integrations/argilla


2024-04-10(三) 日晕

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/853634
推荐阅读
相关标签
  

闽ICP备14008679号