各位编程爱好者、AI应用开发者们,大家好!
今天,我们将深入探讨LangChain框架中一个极其强大且灵活的机制——Callbacks。在构建复杂的AI应用时,我们经常需要对模型的行为进行监控、日志记录、性能分析,甚至在特定事件发生时触发自定义逻辑。Callbacks机制正是为此而生,它像一系列事件监听器,让我们可以“窥探”LangChain组件(如LLMs、Chains、Agents)的内部运作,并在关键生命周期事件点插入我们自己的代码。
本次讲座的重点,将放在如何通过自定义Callback Handler,实现一个实时、精确的Token消耗统计器。这对于成本控制、性能优化以及理解模型行为至关重要。
一、LangChain Callbacks 机制概览
在LangChain中,Callbacks 提供了一种非侵入式的扩展能力。当一个 LangChain 组件(比如一个大型语言模型调用、一个链的执行、一个代理的决策过程或工具使用)开始、进展或结束时,它会触发一系列预定义的事件。Callback Handler 就是用来捕获并响应这些事件的类。
1.1 为什么需要 Callbacks?
想象一下,你正在构建一个复杂的RAG(Retrieval Augmented Generation)应用,其中包含文档加载、文本分割、嵌入、向量搜索、LLM调用、输出解析等多个步骤。如果没有Callbacks,你很难做到:
- 实时监控: 知道当前LLM调用消耗了多少Token,耗时多久。
- 日志记录: 记录每次Chain执行的输入、输出、中间步骤。
- 错误处理: 当LLM调用失败时,捕获错误并执行自定义恢复逻辑。
- 成本估算: 准确统计每次交互或整个会话的Token使用量,以便进行成本核算。
- 进度反馈: 在流式输出时,向用户展示模型正在生成内容。
- 性能分析: 识别应用中的瓶颈,优化各个组件。
Callbacks正是解决这些痛点的关键。
1.2 Callbacks 的核心组件
LangChain 的Callbacks机制主要由以下两个核心抽象组成:
BaseCallbackHandler/AsyncCallbackHandler: 这是所有自定义回调处理器的基类。它定义了一系列可以在不同事件点被调用的方法(例如on_llm_start,on_chain_end等)。BaseCallbackHandler: 用于同步操作。AsyncCallbackHandler: 用于异步操作,其方法都是async函数。
CallbackManager/AsyncCallbackManager: 这是一个容器,用于管理一个或多个Callback Handler实例。LangChain 的组件(如LLM、Chain)会接收一个CallbackManager,并通知它所有注册的 Handler。
1.3 BaseCallbackHandler 的关键方法
BaseCallbackHandler 提供了非常丰富的事件钩子,涵盖了LangChain组件的各个生命周期。以下是一些最常用的方法,我们将通过表格形式列出:
| 方法名称 | 触发时机 | 参数 | 描述 |
|---|
* `on_llm_start`: 触发于一个LLM调用开始时。参数包括 `prompt` (字符串或字符串列表) 和 `run_manager` (用于管理回调运行的上下文)。
* `on_llm_new_token`: 触发于LLM生成流式响应中的每个新Token时。参数包括 `token` (新生成的Token字符串) 和 `run_manager`。
* `on_llm_end`: 触发于一个LLM调用结束时。参数包括 `response` (`LLMResult` 对象,包含生成的文本和可选的`token_usage`信息) 和 `run_manager`。
* `on_llm_error`: 触发于LLM调用过程中发生错误时。参数包括 `error` (异常对象) 和 `run_manager`。
* `on_chain_start`: 触发于一个Chain开始执行时。参数包括 `serialized` (Chain的序列化表示), `inputs` (Chain的输入), `run_manager`。
* `on_chain_end`: 触发于一个Chain执行结束时。参数包括 `outputs` (Chain的输出) 和 `run_manager`。
* `on_chain_error`: 触发于Chain执行过程中发生错误时。参数包括 `error` (异常对象) 和 `run_manager`。
* `on_tool_start`: 触发于Agent开始使用一个工具时。参数包括 `serialized` (工具的序列化表示), `input_str` (工具的输入字符串), `run_manager`。
* `on_tool_end`: 触发于Agent使用一个工具结束时。参数包括 `output` (工具的输出字符串) 和 `run_manager`。
* `on_tool_error`: 触发于工具执行过程中发生错误时。参数包括 `error` (异常对象) 和 `run_manager`。
* `on_agent_action`: 触发于Agent采取行动时(例如决定使用哪个工具)。参数包括 `action` (AgentAction对象) 和 `run_manager`。
* `on_agent_finish`: 触发于Agent完成其任务时。参数包括 `finish` (AgentFinish对象) 和 `run_manager`。
* `on_retriever_start`, `on_retriever_end`, `on_retriever_error`: 针对检索器操作。
* `on_text`: 触发于任何组件生成文本时。参数包括 `text` 和 `run_manager`。这个方法比较通用,常用于简单的文本日志。
run_manager 是一个 CallbackManagerFor<Component>Run 对象,它提供了当前运行的上下文信息,比如 run_id(唯一标识符)和 parent_run_id(如果存在嵌套调用)。你可以通过它来记录或关联特定运行的事件。
二、实时Token消耗统计的必要性与挑战
在开发基于LLM的应用时,Token消耗是直接与成本挂钩的关键指标。无论是OpenAI、Anthropic还是其他模型提供商,Token数量都是计费的基础。实时统计Token消耗不仅可以帮助我们:
- 预算控制: 避免意外的高额账单。
- 用户透明: 向用户展示其请求消耗的资源。
- 性能分析: 了解不同提示和模型配置对Token使用的影响。
- 优化策略: 引导我们编写更简洁、高效的提示。
然而,实时Token统计并非总是直截了当,主要挑战包括:
- 输入Token的统计: LLM在处理提示(prompt)时就会消耗Token。我们需要在请求发送前统计这些Token。
- 输出Token的统计: LLM生成响应时会消耗Token。
- 非流式模式: 一次性返回完整响应,Token统计通常在
on_llm_end的LLMResult中提供。 - 流式模式: 逐个Token返回,我们需要在
on_llm_new_token中累加。
- 非流式模式: 一次性返回完整响应,Token统计通常在
- 不同模型的Tokenization: 不同的LLM使用不同的Tokenization算法(如
tiktokenfor OpenAI,SentencePiecefor Llama等)。准确统计需要使用模型对应的Tokenization工具。LangChain 提供了一个get_num_tokens辅助函数,可以利用LLM实例的_llm_type来选择合适的Token计算器。 - 嵌套调用: 当Chain或Agent包含多个LLM调用时,需要能够区分和汇总不同层级的Token消耗。
- 异步操作: 在异步应用中,回调处理器也需要支持异步。
三、构建 TokenCountCallbackHandler
现在,让我们一步步构建一个名为 TokenCountCallbackHandler 的自定义回调处理器,它将能够:
- 记录每个LLM调用的输入Token和输出Token。
- 支持流式和非流式LLM调用。
- 汇总总的Token消耗。
- 支持同步和异步版本。
我们将先实现同步版本,再展示如何修改为异步版本。
3.1 核心思路
__init__: 初始化统计变量,包括当前运行的输入/输出Token和总的输入/输出Token。on_llm_start: 在LLM调用开始时触发。这里我们可以获取到原始的prompt,并使用LangChain提供的get_num_tokens或tiktoken库来估算输入Token数量。on_llm_new_token: 在流式LLM调用中,每生成一个新Token时触发。我们可以直接累加这些Token到输出Token计数器。on_llm_end: 在LLM调用结束时触发。- 如果模型(如OpenAI系列)在
LLMResult中提供了token_usage信息,我们应该优先使用这个精确的统计值。 - 如果
token_usage不可用(例如,某些本地模型或不提供此信息的API),则依靠on_llm_new_token中累加的输出Token数,并使用on_llm_start中估算的输入Token数。
- 如果模型(如OpenAI系列)在
on_chain_start/on_chain_end: 我们可以利用这些钩子来跟踪特定链的Token使用,或者在链结束时打印汇总信息。- 辅助方法:
reset()用于清零计数器,get_stats()用于获取当前统计数据。
3.2 同步 TokenCountCallbackHandler 的实现
首先,确保你安装了必要的库:
pip install langchain openai tiktoken
import tiktoken
import uuid
from typing import Any, Dict, List, Union, Optional
from langchain_core.callbacks import BaseCallbackHandler, CallbackManagerForLLMRun, CallbackManagerForChainRun, CallbackManagerForToolRun
from langchain_core.outputs import LLMResult, Generation
from langchain_core.messages import BaseMessage
from langchain_core.language_models import BaseLLM
from langchain_openai import ChatOpenAI # 示例LLM
# 辅助函数:根据模型名称获取token编码器
def get_tokenizer_for_model(model_name: str):
"""
根据模型名称获取tiktoken编码器。
这里仅为示例,实际应更全面覆盖各种模型。
"""
try:
# 尝试匹配常见的OpenAI模型
if "gpt-4" in model_name or "gpt-3.5" in model_name:
return tiktoken.encoding_for_model(model_name)
# 默认使用cl100k_base,适用于大多数GPT系列模型
return tiktoken.get_encoding("cl100k_base")
except KeyError:
# 如果模型名称找不到对应的编码器,则回退到cl100k_base
return tiktoken.get_encoding("cl100k_base")
class TokenCountCallbackHandler(BaseCallbackHandler):
"""
一个自定义的Callback Handler,用于实时统计LangChain LLM调用的Token消耗。
支持流式和非流式调用,并优先使用LLMResult中提供的精确Token使用量。
"""
def __init__(self, llm: Optional[BaseLLM] = None, verbose: bool = False):
super().__init__()
self.llm = llm # 引用LLM实例,用于获取其_get_num_tokens方法
self.verbose = verbose
# 用于跟踪当前LLM调用的Token
self.current_run_input_tokens: int = 0
self.current_run_output_tokens: int = 0
self.current_run_total_tokens: int = 0
# 用于跟踪总的Token
self.total_input_tokens: int = 0
self.total_output_tokens: int = 0
self.total_total_tokens: int = 0
# 用于跟踪每个run_id的Token,支持嵌套和并行调用
self.run_token_data: Dict[str, Dict[str, Any]] = {}
# 缓存tokenizer,避免重复创建
self._tokenizer = None
if self.llm and hasattr(self.llm, "model_name"):
self._tokenizer = get_tokenizer_for_model(self.llm.model_name)
elif self.llm and hasattr(self.llm, "model"): # For some models, 'model' might be the name
self._tokenizer = get_tokenizer_for_model(self.llm.model)
def _get_num_tokens_from_string(self, text: str) -> int:
"""
使用tiktoken计算字符串中的Token数量。
尝试使用实例化的LLM的tokenizer,如果LLM未提供或无tokenizer,则回退到默认。
"""
if self.llm and hasattr(self.llm, "get_num_tokens"):
# 优先使用LangChain LLM实例自带的token计算方法
return self.llm.get_num_tokens(text)
elif self._tokenizer:
# 使用预设的tiktoken tokenizer
return len(self._tokenizer.encode(text))
else:
# 否则,尝试使用默认的tiktoken编码器 (cl100k_base)
# 或者可以抛出错误,取决于需求
default_tokenizer = tiktoken.get_encoding("cl100k_base")
return len(default_tokenizer.encode(text))
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], *, run_id: uuid.UUID, parent_run_id: Optional[uuid.UUID] = None, **kwargs: Any
) -> None:
"""
在LLM调用开始时触发。
估算输入Token并记录。
"""
if self.verbose:
print(f"n--- LLM Run {run_id} Started ---")
print(f"Prompts: {prompts}")
input_tokens = 0
for prompt in prompts:
input_tokens += self._get_num_tokens_from_string(prompt)
self.current_run_input_tokens = input_tokens
self.current_run_output_tokens = 0 # 重置当前run的输出Token
self.current_run_total_tokens = input_tokens # 初始总Token为输入Token
self.run_token_data[str(run_id)] = {
"input_tokens": input_tokens,
"output_tokens": 0,
"total_tokens": input_tokens,
"type": "llm",
"parent_run_id": str(parent_run_id) if parent_run_id else None,
"prompts": prompts,
"final_response": None,
}
if self.verbose:
print(f" Estimated Input Tokens (Run {run_id}): {input_tokens}")
def on_llm_new_token(
self, token: str, *, run_id: uuid.UUID, parent_run_id: Optional[uuid.UUID] = None, **kwargs: Any
) -> None:
"""
在LLM流式生成时,每生成一个新Token时触发。
累加输出Token。
"""
if self.verbose:
print(f" Streaming Token (Run {run_id}): '{token}'")
# 累加当前run的输出Token
# 注意:这里我们使用默认tokenizer来近似计算,因为LLM实例的get_num_tokens可能不适合单字符。
# 更准确的做法是累积字符串,然后在on_llm_end时一次性计算。
# 但为了实时性,我们在此处进行近似。
token_len = self._get_num_tokens_from_string(token)
self.current_run_output_tokens += token_len
self.current_run_total_tokens += token_len
# 更新run_id对应的统计数据
if str(run_id) in self.run_token_data:
self.run_token_data[str(run_id)]["output_tokens"] += token_len
self.run_token_data[str(run_id)]["total_tokens"] += token_len
else:
# 理论上on_llm_start会先触发,但以防万一
self.run_token_data[str(run_id)] = {
"input_tokens": 0, "output_tokens": token_len, "total_tokens": token_len,
"type": "llm", "parent_run_id": str(parent_run_id) if parent_run_id else None
}
def on_llm_end(
self, response: LLMResult, *, run_id: uuid.UUID, parent_run_id: Optional[uuid.UUID] = None, **kwargs: Any
) -> None:
"""
在LLM调用结束时触发。
根据LLMResult中的token_usage信息或累加值更新Token统计。
"""
if self.verbose:
print(f"--- LLM Run {run_id} Ended ---")
print(f"Response: {response.generations[0][0].text}")
# 优先使用LLMResult中提供的精确Token使用量
# 适用于OpenAI等提供了Usage信息的模型
token_usage = response.llm_output.get("token_usage") if response.llm_output else None
input_tokens = 0
output_tokens = 0
total_tokens = 0
if token_usage:
# 如果LLM提供了精确的token_usage
input_tokens = token_usage.get("prompt_tokens", 0)
output_tokens = token_usage.get("completion_tokens", 0)
total_tokens = token_usage.get("total_tokens", 0)
if self.verbose:
print(f" Actual Token Usage from LLMResult (Run {run_id}): Input={input_tokens}, Output={output_tokens}, Total={total_tokens}")
else:
# 如果LLM未提供token_usage,则使用我们在on_llm_start和on_llm_new_token中累加的值
input_tokens = self.current_run_input_tokens
output_tokens = self.current_run_output_tokens
total_tokens = self.current_run_total_tokens
if self.verbose:
print(f" Estimated Token Usage from Callback Accumulation (Run {run_id}): Input={input_tokens}, Output={output_tokens}, Total={total_tokens}")
# 再次验证输出tokens,以防万一on_llm_new_token没有完全捕获,或者是非流式但没有提供token_usage的模型
if output_tokens == 0 and response.generations:
final_output_text = "".join([gen.text for gen in response.generations[0]])
recalculated_output_tokens = self._get_num_tokens_from_string(final_output_text)
if recalculated_output_tokens > output_tokens: # 如果重新计算的值更准确
output_tokens = recalculated_output_tokens
total_tokens = input_tokens + output_tokens
if self.verbose:
print(f" Recalculated output tokens (Run {run_id}): {output_tokens}")
# 更新当前run的统计值
self.current_run_input_tokens = input_tokens
self.current_run_output_tokens = output_tokens
self.current_run_total_tokens = total_tokens
# 更新总的Token统计
self.total_input_tokens += input_tokens
self.total_output_tokens += output_tokens
self.total_total_tokens += total_tokens
# 更新run_id对应的统计数据
if str(run_id) in self.run_token_data:
self.run_token_data[str(run_id)].update({
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": total_tokens,
"final_response": response.generations[0][0].text if response.generations else None,
})
else:
# 理论上on_llm_start会先触发,但以防万一
self.run_token_data[str(run_id)] = {
"input_tokens": input_tokens, "output_tokens": output_tokens, "total_tokens": total_tokens,
"type": "llm", "final_response": response.generations[0][0].text if response.generations else None
}
if self.verbose:
print(f" Current LLM Total (Run {run_id}): Input={input_tokens}, Output={output_tokens}, Total={total_tokens}")
print(f" Overall Total: Input={self.total_input_tokens}, Output={self.total_output_tokens}, Total={self.total_total_tokens}")
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], *, run_id: uuid.UUID, parent_run_id: Optional[uuid.UUID] = None, **kwargs: Any
) -> None:
"""在Chain开始时触发,记录Chain的开始。"""
if self.verbose:
print(f"n--- Chain Run {run_id} Started ({serialized.get('lc_kwargs', {}).get('name', serialized.get('name'))}) ---")
print(f" Inputs: {inputs}")
self.run_token_data[str(run_id)] = {
"type": "chain",
"name": serialized.get('lc_kwargs', {}).get('name', serialized.get('name', serialized.get('lc_id', ['UnknownChain'])[-1])),
"input": inputs,
"output": None,
"input_tokens": 0, # Chain本身不消耗Token,其内部LLM调用才消耗
"output_tokens": 0,
"total_tokens": 0,
"parent_run_id": str(parent_run_id) if parent_run_id else None,
"children_runs": []
}
if parent_run_id and str(parent_run_id) in self.run_token_data:
self.run_token_data[str(parent_run_id)].setdefault("children_runs", []).append(str(run_id))
def on_chain_end(
self, outputs: Dict[str, Any], *, run_id: uuid.UUID, parent_run_id: Optional[uuid.UUID] = None, **kwargs: Any
) -> None:
"""在Chain结束时触发,记录Chain的结束和输出。"""
if self.verbose:
print(f"--- Chain Run {run_id} Ended ---")
print(f" Outputs: {outputs}")
if str(run_id) in self.run_token_data:
self.run_token_data[str(run_id)]["output"] = outputs
# 汇总子LLM调用的Token到父Chain
chain_total_input = 0
chain_total_output = 0
chain_total_total = 0
# 遍历子运行(LLM、Tool等)并汇总其Token
for child_run_id in self.run_token_data[str(run_id)].get("children_runs", []):
child_data = self.run_token_data.get(child_run_id)
if child_data and child_data["type"] == "llm": # 仅汇总LLM的Token
chain_total_input += child_data.get("input_tokens", 0)
chain_total_output += child_data.get("output_tokens", 0)
chain_total_total += child_data.get("total_tokens", 0)
self.run_token_data[str(run_id)]["input_tokens"] = chain_total_input
self.run_token_data[str(run_id)]["output_tokens"] = chain_total_output
self.run_token_data[str(run_id)]["total_tokens"] = chain_total_total
if self.verbose:
print(f" Chain Total (Run {run_id}): Input={chain_total_input}, Output={chain_total_output}, Total={chain_total_total}")
def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: uuid.UUID, parent_run_id: Optional[uuid.UUID] = None, **kwargs: Any) -> None:
"""在Tool开始时触发。"""
if self.verbose:
print(f"n--- Tool Run {run_id} Started ({serialized.get('name')}) ---")
print(f" Input: {input_str}")
self.run_token_data[str(run_id)] = {
"type": "tool",
"name": serialized.get('name', 'UnknownTool'),
"input": input_str,
"output": None,
"input_tokens": 0, # 工具本身不消耗Token,其内部LLM调用才消耗
"output_tokens": 0,
"total_tokens": 0,
"parent_run_id": str(parent_run_id) if parent_run_id else None,
"children_runs": []
}
if parent_run_id and str(parent_run_id) in self.run_token_data:
self.run_token_data[str(parent_run_id)].setdefault("children_runs", []).append(str(run_id))
def on_tool_end(self, output: str, *, run_id: uuid.UUID, parent_run_id: Optional[uuid.UUID] = None, **kwargs: Any) -> None:
"""在Tool结束时触发。"""
if self.verbose:
print(f"--- Tool Run {run_id} Ended ---")
print(f" Output: {output}")
if str(run_id) in self.run_token_data:
self.run_token_data[str(run_id)]["output"] = output
# 汇总子LLM调用的Token到父Tool
tool_total_input = 0
tool_total_output = 0
tool_total_total = 0
for child_run_id in self.run_token_data[str(run_id)].get("children_runs", []):
child_data = self.run_token_data.get(child_run_id)
if child_data and child_data["type"] == "llm":
tool_total_input += child_data.get("input_tokens", 0)
tool_total_output += child_data.get("output_tokens", 0)
tool_total_total += child_data.get("total_tokens", 0)
self.run_token_data[str(run_id)]["input_tokens"] = tool_total_input
self.run_token_data[str(run_id)]["output_tokens"] = tool_total_output
self.run_token_data[str(run_id)]["total_tokens"] = tool_total_total
if self.verbose:
print(f" Tool Total (Run {run_id}): Input={tool_total_input}, Output={tool_total_output}, Total={tool_total_total}")
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], *, run_id: uuid.UUID, parent_run_id: Optional[uuid.UUID] = None, **kwargs: Any
) -> None:
"""在LLM调用出错时触发。"""
if self.verbose:
print(f"n--- LLM Run {run_id} Error ---")
print(f" Error: {error}")
if str(run_id) in self.run_token_data:
self.run_token_data[str(run_id)]["error"] = str(error)
def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], *, run_id: uuid.UUID, parent_run_id: Optional[uuid.UUID] = None, **kwargs: Any) -> None:
"""在Chain调用出错时触发。"""
if self.verbose:
print(f"n--- Chain Run {run_id} Error ---")
print(f" Error: {error}")
if str(run_id) in self.run_token_data:
self.run_token_data[str(run_id)]["error"] = str(error)
def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], *, run_id: uuid.UUID, parent_run_id: Optional[uuid.UUID] = None, **kwargs: Any) -> None:
"""在Tool调用出错时触发。"""
if self.verbose:
print(f"n--- Tool Run {run_id} Error ---")
print(f" Error: {error}")
if str(run_id) in self.run_token_data:
self.run_token_data[str(run_id)]["error"] = str(error)
def reset(self) -> None:
"""重置所有Token计数器和运行数据。"""
self.current_run_input_tokens = 0
self.current_run_output_tokens = 0
self.current_run_total_tokens = 0
self.total_input_tokens = 0
self.total_output_tokens = 0
self.total_total_tokens = 0
self.run_token_data = {}
if self.verbose:
print("n--- Token Count Reset ---")
def get_stats(self) -> Dict[str, Any]:
"""返回当前的Token统计数据。"""
return {
"total_input_tokens": self.total_input_tokens,
"total_output_tokens": self.total_output_tokens,
"total_total_tokens": self.total_total_tokens,
"runs": self.run_token_data,
}
def print_summary(self) -> None:
"""打印一个Token消耗的总结。"""
print("n=== Token Consumption Summary ===")
print(f"Overall Input Tokens: {self.total_input_tokens}")
print(f"Overall Output Tokens: {self.total_output_tokens}")
print(f"Overall Total Tokens: {self.total_total_tokens}")
print("nDetailed Runs:")
for run_id, data in self.run_token_data.items():
run_type = data.get("type", "unknown")
run_name = data.get("name", "LLM Call") if run_type != "llm" else ""
print(f" - {run_type.upper()} Run {run_id} ({run_name}):")
print(f" Input: {data.get('input_tokens', 0)} tokens")
print(f" Output: {data.get('output_tokens', 0)} tokens")
print(f" Total: {data.get('total_tokens', 0)} tokens")
if data.get("parent_run_id"):
print(f" Parent Run: {data['parent_run_id']}")
if data.get("error"):
print(f" Error: {data['error']}")
if data.get("children_runs"):
print(f" Children Runs: {', '.join(data['children_runs'])}")
print("==============================")
代码解析:
get_tokenizer_for_model: 这是一个辅助函数,用于根据模型名称获取合适的tiktoken编码器