解析 ‘LangChain Hooks’:如何在 Chain 的每一个生命周期节点(Start/End/Error)注入自定义埋点?

LangChain Hooks:在 Chain 生命周期节点注入自定义埋点

随着大型语言模型(LLM)应用的日益普及,构建基于LLM的复杂系统已成为常态。LangChain作为这些系统的强大编排框架,通过将LLM、工具、检索器等组件组合成链(Chain)或代理(Agent),极大地简化了开发过程。然而,仅仅构建出功能正常的应用是不够的;为了确保应用的稳定性、性能、成本效益以及用户体验,深入的监控和可观测性至关重要。

我们作为编程专家,深知在一个生产系统中,了解“发生了什么”、“何时发生”、“为什么发生”以及“花费了多少”是进行调试、优化和决策的基础。在LangChain的世界里,这意味着我们需要在Chain、LLM、工具等组件的每一次调用中,捕获关键的运行时信息。

LangChain为此提供了一套强大而灵活的机制:回调(Callbacks)。这些回调可以被视为“钩子(Hooks)”,允许我们在LangChain组件执行的特定生命周期节点(例如开始、结束、错误)注入自定义逻辑。本文将深入探讨LangChain的Callbacks机制,特别是如何利用它们在Chain的每一个生命周期节点(Start/End/Error)注入自定义埋点,以实现强大的监控和追踪能力。

1. 为什么需要 LangChain Hooks (Callbacks)?

在一个典型的LangChain应用中,一个用户请求可能触发一个复杂的Chain,这个Chain可能依次调用多个LLM、使用多个外部工具,甚至进行多次数据库查询。如果没有适当的监控,我们可能会面临以下挑战:

  • 性能瓶颈识别困难:哪个步骤耗时最长?是LLM调用慢,还是外部API响应慢?
  • 成本控制缺乏依据:一次用户请求究竟消耗了多少LLM tokens?哪些Chain的成本最高?
  • 错误排查复杂:当Chain执行失败时,是哪个环节出了问题?错误消息是什么?
  • 用户体验分析缺失:用户请求的平均响应时间是多少?不同类型的请求性能如何?
  • 合规性与审计:谁在何时调用了什么Chain,输入是什么,输出是什么?
  • A/B测试与实验追踪:在不同的Chain配置或LLM参数下,性能指标如何变化?

LangChain的回调机制正是为了解决这些问题而设计的。它提供了一系列预定义的生命周期事件,让我们能够像监听器一样,在这些事件发生时执行自定义代码,从而捕获并处理所需的埋点数据。

2. LangChain Callbacks 核心机制:BaseCallbackHandler

LangChain中的所有回调处理器都必须继承自 langchain.callbacks.base.BaseCallbackHandlerlangchain.callbacks.base.AsyncCallbackHandler。这些基类定义了一系列可重写(Override)的方法,对应着LangChain组件的不同生命周期事件。

我们将重点关注与Chain生命周期相关的核心方法,但为了完整性,也会简要提及其他重要方法。

2.1 BaseCallbackHandler 的核心方法概览

BaseCallbackHandler 定义了以下主要回调方法:

| 方法名称 | 触发时机 | 参数 LangChain Hooks 是什么?
LangChain 中的回调机制,让我们可以捕 在LangChain中,回调(Callbacks)是一种强大的机制,允许开发者在LangChain组件(如Chain, LLM, Tool, Retriever等)执行的各个生命周期阶段插入自定义逻辑。这些自定义逻辑可以用于实现日志记录、性能监控、错误追踪、数据审计、用户行为分析等多种埋点需求。

我们可以将这些回调机制理解为“LangChain Hooks”。如同其他编程框架中的钩子,它们提供了一个在特定事件发生时执行我们代码的机会,而无需修改LangChain核心组件的内部实现。

2.2 回调方法的参数解析

理解回调方法接收的参数是正确实现埋点的关键。尽管不同的方法参数略有差异,但一些核心参数是通用的:

  • run_id (UUID): 每次LangChain组件调用都会生成一个唯一的运行ID。这是追踪单个操作或一系列相关操作(通过parent_run_id)的核心标识符。
  • parent_run_id (Optional[UUID]): 如果当前组件调用是另一个组件调用的子过程,那么此参数会包含父级运行的ID。这对于构建完整的调用链(Trace)至关重要。
  • tags (Optional[List[str]]): 开发者可以在调用组件时传入的自定义标签列表,用于对运行进行分类或过滤。
  • metadata (Optional[Dict[str, Any]]): 开发者可以传入的自定义元数据字典,包含更详细的键值对信息。
  • serialized (Dict[str, Any]): 组件的序列化表示,通常包含组件的类型、配置等信息。
  • inputs (Dict[str, Any]): 在 on_chain_start 等方法中,表示Chain接收到的输入。
  • outputs (Dict[str, Any]): 在 on_chain_end 等方法中,表示Chain执行完成后的输出。
  • error (Exception): 在 on_chain_error 等方法中,表示捕获到的异常对象。
  • **kwargs: 捕获所有其他未明确定义的关键字参数,以备未来扩展。

2.3 专注于 Chain 生命周期埋点

根据主题要求,我们将重点聚焦于 on_chain_starton_chain_endon_chain_error 这三个方法,它们分别对应Chain执行的开始、结束和错误发生时机。

  • *`on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], , run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, kwargs: Any) -> Any`
    • 触发时机:当一个Chain(或Runnable)的 invoke()stream() 方法被调用,开始执行时。
    • 用途:记录Chain的启动时间、输入参数、Chain类型、run_idparent_run_id 等,用于启动计时器或创建追踪上下文。
  • *`on_chain_end(self, outputs: Dict[str, Any], , run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, kwargs: Any) -> Any`
    • 触发时机:当一个Chain成功执行并返回结果时。
    • 用途:记录Chain的结束时间、输出结果、计算执行时长,并完成追踪上下文。
  • *`on_chain_error(self, error: Exception, , run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, kwargs: Any) -> Any`
    • 触发时机:当一个Chain在执行过程中抛出未捕获的异常时。
    • 用途:记录错误信息、堆栈追踪、Chain的失败状态,并将错误信息附加到追踪上下文。

3. 实现一个基础的自定义日志埋点器

最简单的埋点需求是将Chain的生命周期事件输出到控制台或日志文件。这有助于我们在开发和调试阶段快速了解Chain的执行流程。

我们将创建一个名为 SimpleConsoleLogger 的回调处理器,它将打印Chain的开始、结束和错误信息。

import time
import json
from uuid import UUID
from typing import Any, Dict, List, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.llms import FakeListLLM
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.output_parsers import StrOutputParser

# 1. 定义一个简单的控制台日志回调处理器
class SimpleConsoleLogger(BaseCallbackHandler):
    """
    一个简单的LangChain回调处理器,用于将Chain的生命周期事件输出到控制台。
    它会记录Chain的启动、结束和错误信息,并计算执行时长。
    """

    def __init__(self):
        # 用于存储每个run_id的开始时间,以便在结束时计算时长
        self._start_times: Dict[UUID, float] = {}

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """
        在Chain开始执行时调用。
        记录Chain的类型、输入、run_id,并记录当前时间作为开始时间。
        """
        chain_name = serialized.get("lc_kwargs", {}).get("name", serialized.get("lc_id", ["Unknown"])[0])
        print(f"n--- [Chain Start] ---")
        print(f"  Run ID: {run_id}")
        print(f"  Parent Run ID: {parent_run_id}")
        print(f"  Chain Name: {chain_name}")
        print(f"  Inputs: {json.dumps(inputs, indent=2, ensure_ascii=False)}")
        if tags:
            print(f"  Tags: {tags}")
        if metadata:
            print(f"  Metadata: {metadata}")
        self._start_times[run_id] = time.time()
        print(f"----------------------")

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """
        在Chain成功执行完成时调用。
        记录Chain的输出、run_id,并计算并打印执行时长。
        """
        end_time = time.time()
        start_time = self._start_times.pop(run_id, None)
        duration_seconds = f"{end_time - start_time:.4f}" if start_time else "N/A"

        print(f"n--- [Chain End] ---")
        print(f"  Run ID: {run_id}")
        print(f"  Outputs: {json.dumps(outputs, indent=2, ensure_ascii=False)}")
        print(f"  Duration: {duration_seconds} seconds")
        print(f"---------------------")

    def on_chain_error(
        self,
        error: Exception,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """
        在Chain执行过程中发生错误时调用。
        记录错误信息、run_id,并计算并打印执行时长。
        """
        end_time = time.time()
        start_time = self._start_times.pop(run_id, None)
        duration_seconds = f"{end_time - start_time:.4f}" if start_time else "N/A"

        print(f"n--- [Chain Error] ---")
        print(f"  Run ID: {run_id}")
        print(f"  Error Type: {type(error).__name__}")
        print(f"  Error Message: {error}")
        print(f"  Duration: {duration_seconds} seconds")
        print(f"-----------------------")

# 2. 设置一个模拟的LLM和Chain进行演示
# FakeListLLM是一个虚拟的LLM,它会从预设的响应列表中返回结果
llm = FakeListLLM(responses=["创新科技公司", "智慧未来有限公司", "AI驱动解决方案"])
prompt = PromptTemplate.from_template("请为一家生产{product}的公司起一个好名字。")
demo_chain = LLMChain(llm=llm, prompt=prompt)

# 3. 运行Chain并集成回调处理器
print("--- 演示1: 正常执行的Chain ---")
# 通过 config 参数传入 callbacks 列表
demo_chain.invoke(
    {"product": "智能家居设备"},
    config={
        "callbacks": [SimpleConsoleLogger()],
        "tags": ["命名服务", "产品营销"],
        "metadata": {"user_id": "user_123", "request_type": "company_name"},
    },
)

# 4. 演示Chain错误的回调
# 为了演示错误,我们创建一个特殊的Runnable,它会在特定条件下抛出异常
def failing_function(input_data: Dict[str, Any]) -> Dict[str, Any]:
    """一个模拟会失败的函数。"""
    if "fail_test" in input_data.get("product", "").lower():
        raise ValueError("模拟Chain执行失败:产品名称包含'fail_test'。")
    return {"output": f"成功处理产品:{input_data['product']}"}

# 使用RunnableLambda将Python函数包装成LangChain的Runnable
# RunnablePassthrough用于传递输入,RunnableLambda执行函数,StrOutputParser处理输出
failing_chain = RunnablePassthrough.assign(
    output=RunnableLambda(failing_function) | StrOutputParser()
)

print("n--- 演示2: 模拟Chain执行错误 ---")
try:
    failing_chain.invoke(
        {"product": "这是一个fail_test产品"},
        config={"callbacks": [SimpleConsoleLogger()], "tags": ["错误测试"]},
    )
except ValueError as e:
    print(f"n捕获到预期错误:{e}")

# 5. 演示一个嵌套的Chain
from langchain.chains import SimpleSequentialChain

llm_2 = FakeListLLM(responses=["智能家居设备", "物联网解决方案", "绿色能源系统"])
prompt_2 = PromptTemplate.from_template("生成三个与{topic}相关的流行产品类别。")
chain_2 = LLMChain(llm=llm_2, prompt=prompt_2)

llm_3 = FakeListLLM(responses=["这些类别很有前景。", "市场潜力巨大。", "值得深入研究。"])
prompt_3 = PromptTemplate.from_template("分析以下产品类别:{categories}。")
chain_3 = LLMChain(llm=llm_3, prompt=prompt_3)

# 这是一个顺序链,其中一个链的输出是另一个链的输入
sequential_chain = SimpleSequentialChain(chains=[chain_2, chain_3], verbose=False)

print("n--- 演示3: 嵌套Chain的日志 (父子run_id) ---")
# 传入一个回调处理器,观察父子run_id的关联
sequential_chain.invoke(
    {"topic": "未来科技"},
    config={"callbacks": [SimpleConsoleLogger()], "tags": ["嵌套Chain"]},
)

代码解析:

  1. SimpleConsoleLogger

    • 继承自 BaseCallbackHandler
    • _start_times 字典用于存储每个 run_id 对应的 Chain 启动时间,以便在 on_chain_endon_chain_error 中计算耗时。
    • on_chain_start:在 Chain 开始时被调用。它打印 Chain 的 run_id (唯一标识符)、parent_run_id (如果存在,用于追踪嵌套调用)、Chain 名称(从 serialized 中提取)、输入参数 inputs 以及自定义的 tagsmetadata。同时,记录当前时间。
    • on_chain_end:在 Chain 成功完成时被调用。它打印 Chain 的 run_id、输出结果 outputs,并计算从 _start_times 中获取的开始时间到当前结束时间的持续时长。
    • on_chain_error:在 Chain 抛出异常时被调用。它打印 Chain 的 run_id、错误类型和错误消息,同样计算并打印持续时长。
  2. Chain 设置

    • FakeListLLM 是一个模拟的 LLM,用于在不实际调用外部服务的情况下测试 Chain。
    • PromptTemplate 定义了 LLM 的输入格式。
    • LLMChain 是一个简单的链,将 Prompt 和 LLM 组合起来。
  3. 集成回调

    • 在调用 chain.invoke()chain.run() 时,可以通过 config 参数传入一个 callbacks 列表。列表中可以包含一个或多个 BaseCallbackHandler 实例。
    • tagsmetadata 也可以通过 config 参数传递,它们会在回调方法中被接收到,用于为埋点数据添加更多上下文信息。
  4. 错误演示

    • 为了演示 on_chain_error,我们创建了一个 failing_function,它会在输入包含特定关键词时故意抛出 ValueError
    • RunnableLambdaRunnablePassthrough 是 LangChain 表达式语言(LCEL)的一部分,它们允许我们将普通 Python 函数集成到 Chain 中,并捕获其错误。
  5. 嵌套 Chain 演示

    • SimpleSequentialChain 演示了如何将多个 Chain 串联起来。
    • 在嵌套 Chain 的场景中,外层 Chain 的 run_id 将会作为内层 Chain 的 parent_run_id 传递,这对于构建完整的追踪链路至关重要。

通过这个基础的日志埋点器,我们已经能够在 Chain 的关键生命周期节点获取到丰富的运行时信息。然而,对于生产环境,我们通常需要更结构化、更健壮的埋点和可观测性解决方案。

4. 高级埋点:捕获关键指标与结构化数据

在生产环境中,我们不仅需要日志,还需要能够聚合和分析的结构化数据。这通常包括:

  • 指标(Metrics):例如执行时长、成功率、错误率、LLM Token 消耗量。
  • 追踪(Traces):将一系列相关的操作(包括父子 Chain、LLM 调用、工具使用等)关联起来,形成一个完整的请求链路。
  • 事件(Events):记录特定时间点发生的离散事件,例如错误详情。

我们将构建一个 TelemetryCollectorCallback,它能够收集更全面的数据,并以结构化的方式存储,以便后续发送到外部监控系统。

import time
import json
from uuid import UUID
from collections import defaultdict
from typing import Any, Dict, List, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.llms import FakeListLLM
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import BaseMessage # 用于LLM回调的响应

# 1. 定义一个用于收集结构化遥测数据的回调处理器
class TelemetryCollectorCallback(BaseCallbackHandler):
    """
    一个用于收集LangChain Chain运行时遥测数据的回调处理器。
    它会收集每个Chain运行的详细信息,包括时间、状态、输入、输出、错误等。
    """
    def __init__(self):
        # 使用defaultdict存储每个run_id的详细数据
        self.runs_data: Dict[UUID, Dict[str, Any]] = defaultdict(dict)

    def _update_run_data(self, run_id: UUID, key: str, value: Any):
        """辅助方法,用于更新特定run_id的数据。"""
        self.runs_data[run_id][key] = value

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """在Chain开始执行时调用,记录初始信息。"""
        chain_name = serialized.get("lc_kwargs", {}).get("name", serialized.get("lc_id", ["Unknown"])[0])
        self._update_run_data(run_id, 'start_time_utc', time.time())
        self._update_run_data(run_id, 'run_id', str(run_id))
        self._update_run_data(run_id, 'parent_run_id', str(parent_run_id) if parent_run_id else None)
        self._update_run_data(run_id, 'chain_name', chain_name)
        self._update_run_data(run_id, 'chain_class', serialized.get('lc_id', ['','Unknown'])[1])
        self._update_run_data(run_id, 'inputs', inputs)
        self._update_run_data(run_id, 'tags', tags if tags else [])
        self._update_run_data(run_id, 'metadata', metadata if metadata else {})
        self._update_run_data(run_id, 'status', 'started')
        # print(f"Telemetry Collector: Chain '{chain_name}' ({run_id}) started.")

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """在Chain成功执行完成时调用,记录输出和完成状态。"""
        end_time = time.time()
        start_time = self.runs_data[run_id].get('start_time_utc')
        duration = end_time - start_time if start_time else None

        self._update_run_data(run_id, 'end_time_utc', end_time)
        self._update_run_data(run_id, 'duration_seconds', duration)
        self._update_run_data(run_id, 'outputs', outputs)
        self._update_run_data(run_id, 'status', 'completed')
        # print(f"Telemetry Collector: Chain ({run_id}) completed in {duration:.4f}s.")

    def on_chain_error(
        self,
        error: Exception,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """在Chain执行过程中发生错误时调用,记录错误信息。"""
        end_time = time.time()
        start_time = self.runs_data[run_id].get('start_time_utc')
        duration = end_time - start_time if start_time else None

        self._update_run_data(run_id, 'end_time_utc', end_time)
        self._update_run_data(run_id, 'duration_seconds', duration)
        self._update_run_data(run_id, 'status', 'failed')
        self._update_run_data(run_id, 'error_type', type(error).__name__)
        self._update_run_data(run_id, 'error_message', str(error))
        # print(f"Telemetry Collector: Chain ({run_id}) failed with {type(error).__name__} in {duration:.4f}s.")

    # 除了Chain的回调,我们也可以收集LLM的埋点,以获取更细粒度的信息,例如token使用量
    def on_llm_end(
        self,
        response: Any, # response可以是LLMResult或BaseMessage
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """在LLM调用结束时捕获token使用量等信息。"""
        # LLMResult 包含 llm_output,其中可能包含 token_usage
        token_usage = None
        if hasattr(response, 'llm_output') and response.llm_output and 'token_usage' in response.llm_output:
            token_usage = response.llm_output['token_usage']
        elif isinstance(response, BaseMessage) and hasattr(response, 'usage_metadata'): # For some newer models/integrations
             token_usage = response.usage_metadata

        if token_usage:
            # 找到对应的父Chain或自身的LLM运行记录,并更新token_usage
            # 由于LLM通常是Chain的子组件,这里可能需要根据parent_run_id来关联
            # 更精确的做法是维护一个active_llm_runs字典
            # 为了简化演示,我们直接将token_usage记录到父run_id(如果存在)的llm_calls列表中
            # 实际应用中,LLM调用本身也有run_id,可以独立追踪
            if parent_run_id and parent_run_id in self.runs_data:
                if 'llm_calls' not in self.runs_data[parent_run_id]:
                    self.runs_data[parent_run_id]['llm_calls'] = []
                self.runs_data[parent_run_id]['llm_calls'].append({
                    "llm_run_id": str(run_id),
                    "token_usage": token_usage
                })
            else: # 如果没有父run_id,或者父Chain数据不存在,则单独记录
                self._update_run_data(run_id, 'token_usage', token_usage)
                self._update_run_data(run_id, 'type', 'llm_call')

    def get_all_runs_data(self) -> Dict[UUID, Dict[str, Any]]:
        """获取所有已收集的运行数据。"""
        # 返回一个深拷贝,防止外部修改
        return json.loads(json.dumps(self.runs_data, default=str)) # default=str处理UUID

    def get_run_data(self, run_id: UUID) -> Optional[Dict[str, Any]]:
        """获取特定run_id的运行数据。"""
        data = self.runs_data.get(run_id)
        return json.loads(json.dumps(data, default=str)) if data else None

# 2. 设置一个模拟的LLM和Chain进行演示
llm = FakeListLLM(
    responses=["最佳名称是'创新智联'。", "可以考虑'未来算法'。", "我们建议'智能脉冲'。"],
    # 模拟LLM响应中包含token_usage信息
    llm_output={"token_usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30}}
)
prompt = PromptTemplate.from_template("为一家专注于{product}的公司起个名字。")
demo_chain = LLMChain(llm=llm, prompt=prompt)

telemetry_collector = TelemetryCollectorCallback()

print("--- 演示1: 正常执行的Chain,收集详细遥测数据 ---")
# 传入 telemetry_collector 实例
demo_chain.invoke(
    {"product": "下一代AI芯片"},
    config={
        "callbacks": [telemetry_collector],
        "tags": ["AI", "芯片", "产品命名"],
        "metadata": {"user_session": "sess_abc", "priority": "high"},
    },
)

print("n--- 演示2: 模拟Chain执行错误,收集错误遥测数据 ---")
def failing_function_telemetry(input_data: Dict[str, Any]) -> Dict[str, Any]:
    if "error_trigger" in input_data.get("product", "").lower():
        raise RuntimeError("模拟运行时错误:数据处理失败。")
    return {"output": f"处理成功:{input_data['product']}"}

failing_chain_telemetry = RunnablePassthrough.assign(
    output=RunnableLambda(failing_function_telemetry) | StrOutputParser()
)

try:
    failing_chain_telemetry.invoke(
        {"product": "需要error_trigger的数据"},
        config={"callbacks": [telemetry_collector], "tags": ["错误测试", "Telemetry"]},
    )
except RuntimeError as e:
    print(f"n捕获到预期错误:{e}")

print("n--- 演示3: 嵌套Chain,查看父子关系和LLM埋点 ---")
llm_nested_1 = FakeListLLM(
    responses=["AI助手", "数据分析平台", "云计算服务"],
    llm_output={"token_usage": {"completion_tokens": 8, "prompt_tokens": 15, "total_tokens": 23}}
)
prompt_nested_1 = PromptTemplate.from_template("请列举3个与'{industry}'相关的热门技术领域。")
chain_nested_1 = LLMChain(llm=llm_nested_1, prompt=prompt_nested_1)

llm_nested_2 = FakeListLLM(
    responses=["AI助手市场正在快速增长。", "数据分析是企业决策的关键。", "云计算是现代基础设施的基石。"],
    llm_output={"token_usage": {"completion_tokens": 12, "prompt_tokens": 25, "total_tokens": 37}}
)
prompt_nested_2 = PromptTemplate.from_template("对以下技术领域进行简要市场分析:{tech_fields}")
chain_nested_2 = LLMChain(llm=llm_nested_2, prompt=prompt_nested_2)

# SimpleSequentialChain 将 chain_nested_1 的输出作为 chain_nested_2 的输入
nested_sequential_chain = SimpleSequentialChain(chains=[chain_nested_1, chain_nested_2], verbose=False)

nested_sequential_chain.invoke(
    {"industry": "科技"},
    config={
        "callbacks": [telemetry_collector],
        "tags": ["嵌套", "市场分析"],
        "metadata": {"client": "big_corp"},
    },
)

print("n--- 收集到的所有遥测数据 ---")
all_telemetry = telemetry_collector.get_all_runs_data()
for run_id_str, data in all_telemetry.items():
    print(f"--- Run ID: {run_id_str} ---")
    print(json.dumps(data, indent=2, ensure_ascii=False))
    print("-" * 40)

代码解析:

  1. TelemetryCollectorCallback

    • runs_data 字典 (使用 defaultdict(dict)) 用于按 run_id 存储每个运行的详细信息。这样,我们可以将一个 run_id 下的所有相关事件(开始、结束、错误、LLM 调用等)都收集到同一个字典中。
    • _update_run_data 辅助方法简化了数据的更新操作。
    • on_chain_start:记录了 Chain 的启动时间 (start_time_utc)、run_idparent_run_idchain_namechain_classinputstagsmetadata 和初始状态 started
    • on_chain_end:记录结束时间 (end_time_utc)、计算 duration_seconds、记录 outputs 和最终状态 completed
    • on_chain_error:记录结束时间、计算 duration_seconds、记录状态 failed、错误类型 error_type 和错误消息 error_message
    • on_llm_end:这是一个额外的 LLM 回调示例,用于演示如何捕获 LLM 的 Token 使用量 (token_usage)。这对于成本监控和性能分析至关重要。这里我们将 LLM 的 Token 使用信息关联到其父 Chain 的 llm_calls 列表中,这是一种将子组件信息聚合到父组件上的常见模式。
    • get_all_runs_dataget_run_data:这些方法允许外部代码查询和获取收集到的结构化数据。我们使用 json.dumps(..., default=str) 来确保 UUID 对象能够正确地序列化为字符串。
  2. LLM 模拟 Token 使用

    • FakeListLLM 被扩展,通过 llm_output 属性模拟了真实的 LLM 响应中可能包含的 Token 使用信息。这使得 on_llm_end 可以实际捕获并记录这些数据。
  3. 演示与数据输出

    • 运行 Chain 后,我们通过 telemetry_collector.get_all_runs_data() 获取所有收集到的数据。
    • 使用 json.dumps(..., indent=2, ensure_ascii=False) 将数据美观地打印出来,便于观察其结构和内容。可以看到,每个 run_id 下都包含了该 Chain 调用的完整生命周期信息,包括输入、输出、时长、状态,甚至嵌套的 LLM 调用的 Token 使用量。

这种结构化的数据收集方式,使得我们能够方便地将数据发送到各种外部监控系统,例如:

  • 日志聚合系统:如 ELK Stack (Elasticsearch, Logstash, Kibana) 或 Splunk,将 JSON 格式的埋点数据作为日志事件 ingest。
  • 指标监控系统:如 Prometheus/Grafana,通过定制的 exporter 将执行时长、Token 消耗等数据转化为时间序列指标。
  • 分布式追踪系统:如 Jaeger, Zipkin, OpenTelemetry,将 run_idparent_run_id 关联起来,构建完整的请求追踪链路。

5. 整合外部监控系统:以 OpenTelemetry 为例

在现代微服务架构中,分布式追踪(Distributed Tracing)是实现可观测性的核心组件。OpenTelemetry (OTel) 是一个跨语言的开放标准和工具集,用于生成、收集和导出遥测数据(Metrics, Logs, Traces)。

将 LangChain 的回调与 OpenTelemetry 集成,可以让我们无缝地将 LangChain 应用的内部操作暴露到标准的分布式追踪系统中。

我们将创建一个模拟的 OpenTelemetry 追踪器 (MockTracer) 和一个 OpenTelemetryTracingCallback,演示如何将 LangChain 的 Chain 生命周期事件转换为 OTel 的 Span。


import time
import json
import traceback
from uuid import UUID
from collections import defaultdict
from typing import Any, Dict, List, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.llms import FakeListLLM
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import BaseMessage # 用于LLM回调的响应

# --- 模拟 OpenTelemetry 核心组件 ---

class MockSpan:
    """
    模拟 OpenTelemetry 中的 Span 对象。
    一个Span代表一个操作的逻辑单元,包含名称、ID、时间戳、属性、事件等。
    """
    def __init__(self, name: str, run_id: UUID, parent_span_id: Optional[UUID] = None):
        self.name = name
        self.span_id = UUID(int=uuid.getnode()).hex[:16] # 模拟一个16字符的Span ID
        self.trace_id = str(run_id)  # 使用LangChain的run_id作为Trace ID
        self.parent_span_id = str(parent_span_id) if parent_span_id else None
        self.start_time_ns = time.time_ns()  # 纳秒级时间戳
        self.end_time_ns: Optional[int] = None
        self.attributes: Dict[str, Any] = {
            "langchain.run_id": str(run_id),
            "langchain.component.type": "unknown",
            "langchain.component.name": name,
        }
        self.events: List[Dict[str, Any]] = [] # 用于记录异常或其他重要事件
        self.status: Dict[str, Any] = {"code": "UNSET"} # OTel Span Status

        print(f"[Mock OTel] Span '{self.name}' (ID: {self.span_id}, Trace: {self.trace_id[:8]}...) started.")

    def set_attribute(self, key: str, value: Any):
        """设置Span的属性。"""
        self.attributes[key] = value

    def add_event(self, name: str, attributes: Optional[Dict[str, Any]] = None):
        """添加一个事件到Span。"""
        event_data = {
            "name": name,
            "timestamp_ns": time.time_ns(),
            "attributes": attributes if attributes else {},
        }
        self.events.append(event_data)
        print(f"[Mock OTel] Span '{self.name}' added event: {name}")

    def set_status(self, code: str, description: Optional[str] = None):
        """设置Span的状态 (e.g., OK, ERROR)。"""
        self.status["code"] = code
        if description:
            self.status["description"] = description

    def end(self):
        """结束Span,计算持续时间。"""
        self.end_time_ns = time.time_ns()
        duration_ms = (self.end_time_ns - self.start_time_ns) / 1_000_000 if self.end_time_ns else 0
        self.set_attribute("duration_ms", duration_ms)
        print(f"[Mock OTel] Span '{self.name}' (ID: {self.span_id}, Trace: {self.trace_id[:8]}...) ended. Duration: {duration_ms:.2f}ms")

    def to_dict(self) -> Dict[str, Any]:
        """将Span转换为字典格式,便于打印或导出。"""
        return {
            "name": self.name,
            "span_id": self.span_id,
            "trace_id": self.trace_id,
            "parent_span_id": self.parent_span_id,
            "start_time_ns": self.start_time_ns,
            "end_time_ns": self.end_time_ns,
            "duration_ms": self.attributes.get("duration_ms"),
            "attributes": self.attributes,
            "events": self.events,
            "status": self.status,
        }

class MockTracer:
    """
    模拟 OpenTelemetry 中的 Tracer 对象,用于管理 Span 的生命周期。
    """
    def __init__(self):
        # 存储当前活跃的Span,key是LangChain的run_id,value是MockSpan对象
        self.active_spans: Dict[UUID, MockSpan] = {}
        # 存储已完成的Span,用于后续导出
        self.finished_spans: List[MockSpan] = []

    def start_span(self, name: str, run_id: UUID, parent_run_id: Optional[UUID] = None) -> MockSpan:
        """
        开始一个新的Span。
        如果存在parent_run_id,则尝试将其对应的Span ID作为当前Span的父ID。
        """
        parent_span_id: Optional[UUID] = None
        if parent_run_id and parent_run_id in self.active_spans:
            parent_span_id = self.active_spans[parent_run_id].span_id

        span = MockSpan(name, run_id, parent_span_id)
        self.active_spans[run_id] = span
        return span

    def get_span(self, run_id: UUID) -> Optional[MockSpan]:
        """根据run_id获取活跃的Span。"""
        return self.active_spans.get(run_id)

    def end_span(self, run_id: UUID):
        """结束指定run_id的Span,并将其从活跃列表移到完成列表。"""
        span = self.active_spans.pop(run_id, None)
        if span:
            span.end()
            self.finished_spans.append(span)

    def export_finished_spans(self):
        """导出所有已完成的Span,模拟发送到OTel收集器。"""
        print("n--- 导出所有已完成的 Mock OTel Spans ---")
        for span in self.finished_spans:
            print(json.dumps(span.to_dict(), indent=2, ensure_ascii=False, default=str)) # default=str处理UUID
            print("-" * 50)
        self.finished_spans.clear() # 清空已导出的Span

# 创建一个全局的模拟Tracer实例
mock_otel_tracer = MockTracer()

# --- LangChain 回调处理器,用于集成 OpenTelemetry ---

class OpenTelemetryTracingCallback(BaseCallbackHandler):
    """
    LangChain回调处理器,将LangChain的运行事件转换为OpenTelemetry的Span。
    """
    def __init__(self, tracer: MockTracer):
        self.tracer = tracer

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """在Chain开始时创建并开始一个OpenTelemetry Span。"""
        chain_name = serialized.get("lc_kwargs", {}).get("name", serialized.get("lc_id", ["Unknown"])[0])
        span = self.tracer.start_span(f"chain.{chain_name}", run_id, parent_run_id)
        span.set_attribute("langchain.component.type", "chain")
        span.set_attribute("langchain.chain.class", serialized.get('lc_id', ['','Unknown'])[1])
        span.set_attribute("langchain.inputs", json.dumps(inputs, ensure_ascii=False))
        if tags:
            span.set_attribute("langchain.tags", json.dumps(tags, ensure_ascii=False))
        if metadata:
            span.set_attribute("langchain.metadata", json.dumps(metadata, ensure_ascii=False))

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """在Chain结束时设置Span的输出和状态,并结束Span。"""
        span = self.tracer.get_span(run_id)
        if span:
            span.set_attribute("langchain.outputs", json.dumps(outputs, ensure_ascii=False))
            span.set_status("OK") # 成功完成
            self.tracer.end_span(run_id)

    def on_chain_error(
        self,
        error: Exception,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """在Chain出错时设置Span的错误信息和状态,并结束Span。"""
        span = self.tracer.get_span(run_id)
        if span:
            span.set_attribute("error", True)
            span.add_event(
                "exception",
                attributes={
                    "exception.type": type(error).__name__,
                    "exception.message": str(error),
                    "exception.stacktrace": traceback.format_exc(), # 捕获完整的堆栈追踪
                },
            )
            span.set_status("ERROR", description=str(error)) # 错误状态
            self.tracer.end_span(run_id)

    # 也可以为LLM添加追踪,形成更完整的调用链
    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注