尊敬的各位技术同仁:
欢迎来到今天的技术讲座,我们将深入探讨一个在构建人工智能Agent时至关重要但又常常被忽视的问题:延迟剖析(Latency Profiling)。随着AI Agent在各种应用场景中扮演越来越重要的角色,其响应速度直接关系到用户体验、业务效率乃至系统稳定性。当您的Agent响应变慢时,您是否能迅速定位问题根源——是Embedding模型速度不济?还是LLM推理瓶颈?亦或是某个工具的执行拖了后腿?
今天的讲座,我将以编程专家的视角,带领大家系统性地理解Agent的内部工作机制,掌握一套行之有效的延迟剖析策略、工具和实战技巧,帮助您精准找出Agent响应慢的根源,并提供相应的优化思路。
一、Agent的崛起与延迟的挑战
近年来,以大型语言模型(LLM)为核心的AI Agent正以前所未有的速度渗透到软件开发的各个领域。它们不再仅仅是回答问题的模型,而是具备感知、规划、记忆和行动能力的智能实体。从自动化客服、代码助手到复杂的数据分析和决策支持系统,Agent的应用前景广阔。
然而,随之而来的挑战也日益凸显:性能,尤其是响应延迟。一个需要数秒甚至数十秒才能给出响应的Agent,在许多实时性要求高的场景中是不可接受的。用户可能会流失,业务流程可能会中断,甚至可能导致系统级故障。
为什么Agent的延迟如此难以捉摸?
不同于传统服务,AI Agent的运行流程通常涉及多个异构组件的协作:
- 外部API调用: 频繁与LLM提供商(如OpenAI, Anthropic)进行网络通信。
- 数据检索与处理: 可能涉及向量数据库查询、传统数据库操作、文件系统读写等。
- 工具执行: 调用各种外部服务(如天气API、CRM系统、内部微服务)或执行复杂计算逻辑。
- 内部框架开销: Agent框架(如LangChain, LlamaIndex)自身的编排逻辑。
这些环节中的任何一个都可能成为瓶颈,而传统的性能监控工具往往难以提供Agent级别、端到端的细粒度分析。因此,我们需要一套专门针对AI Agent的延迟剖析方法论。
二、理解Agent的内部机制与延迟来源
要剖析延迟,首先要清晰地理解Agent的典型工作流程及其每个环节可能引入的延迟。
2.1 Agent的典型工作流
一个典型的AI Agent从接收用户请求到生成响应,大致会经历以下核心阶段:
- 用户输入与预处理: 接收用户请求,进行初步的清洗、格式化或意图识别。
- 上下文构建/检索增强生成 (RAG):
- 根据用户查询,通过Embedding模型将查询转换为向量。
- 在向量数据库中进行相似性搜索,检索相关文档片段或知识。
- 将检索到的上下文与用户查询一起,构建发送给LLM的Prompt。
- LLM推理与规划:
- 将构建好的Prompt发送给大型语言模型(LLM)。
- LLM进行推理,理解用户意图,生成行动计划(例如,决定调用哪个工具,或者直接生成回复)。
- 工具选择与执行:
- 如果LLM决定调用工具,Agent会解析LLM的输出,选择合适的工具(Tool)。
- 执行选定的工具,可能涉及调用外部API、数据库查询、本地计算等。
- 获取工具执行结果。
- LLM推理与响应生成:
- 将工具执行结果(如果存在)与之前的对话历史、上下文等再次喂给LLM。
- LLM根据所有信息生成最终的用户响应。
- 后处理与用户输出: 对LLM生成的响应进行格式化、过滤或存储,最终呈现给用户。
![Agent Work Flow Diagram – Internal thought process]
2.2 详细分解延迟来源
基于上述工作流,我们可以将延迟来源细分为以下几类:
| 延迟来源类别 | 涉及组件/操作 | 典型表现 |
|---|---|---|
| Embedding Generation | Embedding模型API调用、本地Embedding模型推理 | Embedding模型调用耗时,尤其是大批量处理时。 |
| Vector Database Operations | 向量数据库查询、索引构建、数据同步 | 相似性搜索耗时,特别是在大数据集或复杂查询下。 |
| LLM Inference | LLM API调用(请求发送、响应接收)、LLM模型本身的推理时间、Token生成速度 | LLM API响应慢,特别是对于长Prompt或长回复。 |
| Tool Execution | 外部API调用、数据库查询、本地复杂计算、网络I/O | 某个工具的外部依赖服务响应慢,或工具内部计算复杂。 |
| Orchestration Overhead | Agent框架内部的逻辑处理、状态管理、数据序列化/反序列化 | 通常占比不大,但在极端复杂流程或低效框架下可能累积。 |
| Network Latency | 所有远程API调用(LLM、Embedding、外部工具)的网络传输时间 | 受网络带宽、区域、服务提供商SLA影响。 |
明确了这些潜在的延迟点,我们就可以有针对性地进行剖析。
三、延迟剖析的核心策略与工具
为了有效地找出Agent的延迟瓶颈,我们需要结合多种技术和工具。
3.1 日志记录 (Logging):最基础但强大的方法
日志是性能剖析最直接、最易于实现的方法。通过在关键代码路径上记录时间戳,我们可以计算出每个操作的持续时间。
实现方式:
在Agent的每个核心阶段(如Embedding调用前、LLM调用前、工具执行前等)记录开始时间,并在操作完成后记录结束时间,计算差值。
示例代码:Python time 模块与 logging 模块
import time
import logging
import json
from typing import Dict, Any
# 配置日志,便于结构化输出
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def log_event(event_name: str, duration_ms: float = None, details: Dict[str, Any] = None):
"""
统一的日志记录函数,支持结构化数据。
"""
log_data = {"event": event_name}
if duration_ms is not None:
log_data["duration_ms"] = f"{duration_ms:.2f}"
if details:
log_data.update(details)
logging.info(json.dumps(log_data))
class AgentProfiler:
def __init__(self):
self.start_times = {}
def start_phase(self, phase_name: str):
self.start_times[phase_name] = time.perf_counter()
log_event(f"{phase_name}_start")
def end_phase(self, phase_name: str, details: Dict[str, Any] = None):
if phase_name in self.start_times:
end_time = time.perf_counter()
duration_ms = (end_time - self.start_times[phase_name]) * 1000
log_event(f"{phase_name}_end", duration_ms, details)
del self.start_times[phase_name]
else:
logging.warning(f"Phase '{phase_name}' was ended without being started.")
# 模拟Agent的核心操作
class MockEmbeddingModel:
def embed_query(self, query: str) -> list:
time.sleep(0.1 + len(query) * 0.005) # 模拟Embedding耗时
return [0.1] * 1536
class MockVectorStore:
def similarity_search(self, query_vector: list, k: int = 4) -> list:
time.sleep(0.05 + len(query_vector) * 0.0001) # 模拟向量搜索耗时
return [{"id": f"doc_{i}", "content": f"Relevant document {i} for query."} for i in range(k)]
class MockLLM:
def generate(self, prompt: str) -> str:
tokens_in = len(prompt.split())
tokens_out = max(50, len(prompt) // 2) # 模拟输出tokens
time.sleep(0.5 + tokens_in * 0.002 + tokens_out * 0.005) # 模拟LLM推理和生成耗时
return f"LLM's response to: {prompt[:50]}..."
class MockTool:
def __init__(self, name: str, min_latency: float, max_latency: float):
self.name = name
self.min_latency = min_latency
self.max_latency = max_latency
def execute(self, input_data: str) -> str:
latency = self.min_latency + (self.max_latency - self.min_latency) * (hash(input_data) % 100 / 100.0)
time.sleep(latency)
return f"Tool '{self.name}' executed with input '{input_data}', result: data from external service."
# 模拟一个Agent运行流程
def run_agent_with_profiling(query: str, profiler: AgentProfiler):
profiler.start_phase("total_agent_run")
# 1. Embedding Generation & Vector Store Retrieval
profiler.start_phase("embedding_retrieval")
embedding_model = MockEmbeddingModel()
vector_store = MockVectorStore()
query_embedding = embedding_model.embed_query(query)
log_event("embedding_generated", details={"query_len": len(query)})
retrieved_docs = vector_store.similarity_search(query_embedding, k=2)
log_event("vector_search_completed", details={"num_docs": len(retrieved_docs)})
context = "n".join([doc["content"] for doc in retrieved_docs])
profiler.end_phase("embedding_retrieval", details={"retrieved_chars": len(context)})
# 2. LLM Planning & Tool Selection
profiler.start_phase("llm_planning")
llm = MockLLM()
planning_prompt = f"Based on context: {context}nAnd query: {query}nDecide if a tool is needed or respond directly."
llm_plan = llm.generate(planning_prompt)
profiler.end_phase("llm_planning", details={"plan_output_len": len(llm_plan)})
# 3. Tool Execution (Conditional)
tool_executed = False
if "tool" in llm_plan.lower(): # 简单判断是否需要工具
profiler.start_phase("tool_execution")
weather_tool = MockTool("WeatherAPI", 0.3, 1.5) # 模拟一个耗时0.3到1.5秒的工具
tool_input = "New York"
tool_result = weather_tool.execute(tool_input)
profiler.end_phase("tool_execution", details={"tool_name": "WeatherAPI", "tool_result_len": len(tool_result)})
tool_executed = True
# 4. LLM Final Response Generation
profiler.start_phase("llm_response_generation")
final_prompt = f"Query: {query}nContext: {context}nPlan: {llm_plan}n{'Tool result: ' + tool_result if tool_executed else ''}nGenerate final response."
final_response = llm.generate(final_prompt)
profiler.end_phase("llm_response_generation", details={"final_response_len": len(final_response)})
profiler.end_phase("total_agent_run")
return final_response
if __name__ == "__main__":
profiler = AgentProfiler()
print("--- Running Agent with Profiling ---")
response = run_agent_with_profiling("What's the weather like in New York and how does it relate to recent news?", profiler)
print(f"nAgent Final Response: {response}")
# 运行几次,观察不同情况
print("n--- Running Agent again with a simpler query ---")
response = run_agent_with_profiling("Tell me a joke.", profiler)
print(f"nAgent Final Response: {response}")
运行上述代码,您将看到结构化的日志输出,清晰地显示每个阶段的耗时。通过聚合这些日志,您可以计算平均耗时、最大耗时、P95/P99延迟等指标。
3.2 回调函数/观察者模式 (Callbacks/Observer Pattern):框架级支持
现代Agent框架(如LangChain, LlamaIndex)通常提供Callbacks或Hooks机制,允许开发者在Agent生命周期的关键事件点插入自定义逻辑,而无需修改框架核心代码。这是进行无侵入式性能剖析的强大工具。
LangChain CallbackHandler 示例
import time
import json
import logging
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult, ChatGenerationChunk, GenerationChunk
# 再次配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class CustomLatencyProfilerCallback(BaseCallbackHandler):
def __init__(self):
self.start_times: Dict[str, float] = {}
self.trace_data: Dict[str, List[Dict[str, Any]]] = {}
def _get_current_trace_id(self) -> str:
# 简单地使用时间戳作为trace_id,实际生产环境应使用UUID或OpenTelemetry Span ID
return str(int(time.time() * 1000))
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
run_id = str(kwargs.get("run_id", UUID("00000000-0000-0000-0000-000000000000")))
self.start_times[run_id] = time.perf_counter()
logging.info(json.dumps({"event": "llm_start", "run_id": run_id, "prompts_len": len(prompts[0])}))
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
run_id = str(kwargs.get("run_id", UUID("00000000-0000-0000-0000-000000000000")))
end_time = time.perf_counter()
duration_ms = (end_time - self.start_times.pop(run_id, end_time)) * 1000
total_tokens = response.llm_output.get("token_usage", {}).get("total_tokens") if response.llm_output else "N/A"
logging.info(json.dumps({
"event": "llm_end",
"run_id": run_id,
"duration_ms": f"{duration_ms:.2f}",
"total_tokens": total_tokens,
"generations": [gen.text[:50] for gen in response.generations[0]] # Log first 50 chars of first generation
}))
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
run_id = str(kwargs.get("run_id", UUID("00000000-0000-0000-0000-000000000000")))
self.start_times[run_id] = time.perf_counter()
logging.info(json.dumps({"event": "chain_start", "run_id": run_id, "chain_type": serialized.get("lc_name", "Unknown"), "inputs_keys": list(inputs.keys())}))
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
run_id = str(kwargs.get("run_id", UUID("00000000-0000-0000-0000-000000000000")))
end_time = time.perf_counter()
duration_ms = (end_time - self.start_times.pop(run_id, end_time)) * 1000
logging.info(json.dumps({
"event": "chain_end",
"run_id": run_id,
"duration_ms": f"{duration_ms:.2f}",
"outputs_keys": list(outputs.keys())
}))
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
run_id = str(kwargs.get("run_id", UUID("00000000-0000-0000-0000-000000000000")))
self.start_times[run_id] = time.perf_counter()
logging.info(json.dumps({"event": "tool_start", "run_id": run_id, "tool_name": serialized.get("lc_kwargs", {}).get("name", "Unknown"), "input_str_len": len(input_str)}))
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
run_id = str(kwargs.get("run_id", UUID("00000000-0000-0000-0000-000000000000")))
end_time = time.perf_counter()
duration_ms = (end_time - self.start_times.pop(run_id, end_time)) * 1000
logging.info(json.dumps({
"event": "tool_end",
"run_id": run_id,
"duration_ms": f"{duration_ms:.2f}",
"output_len": len(output)
}))
# LangChain Agent 示例(使用 OpenAI 模拟)
from langchain_openai import ChatOpenAI
from langchain.agents import AgentExecutor, create_react_agent
from langchain import hub
from langchain_core.tools import Tool
# 模拟一个外部工具
def get_current_weather(location: str) -> str:
"""Get the current weather in a given location."""
time.sleep(0.8) # Simulate network call
if "london" in location.lower():
return "It's 15 degrees Celsius and cloudy in London."
elif "paris" in location.lower():
return "It's 20 degrees Celsius and sunny in Paris."
else:
return f"Weather data for {location} not available."
tools = [
Tool(
name="get_current_weather",
func=get_current_weather,
description="Useful for getting the current weather in a given location",
),
]
# 假设已经设置了 OPENAI_API_KEY 环境变量
llm = ChatOpenAI(temperature=0, callbacks=[CustomLatencyProfilerCallback()])
# 将 callback 实例传递给 LLM 和 AgentExecutor
# prompt = hub.pull("hwchase17/react") # ReAct prompt
prompt = hub.pull("hwchase17/react-chat") # Chat ReAct prompt
agent = create_react_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, callbacks=[CustomLatencyProfilerCallback()])
if __name__ == "__main__":
print("n--- Running LangChain Agent with Callback Profiling ---")
response = agent_executor.invoke({"input": "What's the weather in London?"})
print(f"nLangChain Agent Response: {response['output']}")
print("n--- Running LangChain Agent for a simple question (no tool) ---")
response = agent_executor.invoke({"input": "What is the capital of France?"})
print(f"nLangChain Agent Response: {response['output']}")
通过回调,我们可以更细致地追踪LangChain内部的LLM调用、Tool调用和整个Chain的执行情况。LlamaIndex也提供了类似的CallbackManager机制。
3.3 分布式追踪 (Distributed Tracing):应对复杂Agent和微服务架构
对于生产环境中的复杂Agent系统,特别是当Agent本身是微服务架构的一部分时,分布式追踪是不可或缺的。它能提供端到端的请求流可视化,帮助我们理解请求如何在不同服务、组件之间传递,以及每个环节的耗时。
核心概念:
- Trace (追踪): 表示一个完整的请求流,从开始到结束。
- Span (跨度): Trace中的一个独立操作单元,例如一次函数调用、一次数据库查询、一次外部API调用。Span之间可以有父子关系。
- Span Context: 包含Trace ID和Span ID,用于在服务间传递追踪信息。
工具: OpenTelemetry 是一个跨语言的开放标准,用于收集和导出遥测数据(追踪、指标、日志)。其他如Jaeger、Zipkin是流行的后端存储和可视化工具。LangSmith(LangChain的配套工具)也提供了强大的追踪和调试功能。
OpenTelemetry 集成示例:
# 假设已经安装了 opentelemetry-api, opentelemetry-sdk, opentelemetry-exporter-otlp
# 以及相关的 for_flask, for_requests 等集成库
import os
import time
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.semconv.resource import ResourceAttributes
# 配置OpenTelemetry TracerProvider
resource = Resource.create({
ResourceAttributes.SERVICE_NAME: "ai-agent-service",
ResourceAttributes.SERVICE_VERSION: "1.0.0",
})
provider = TracerProvider(resource=resource)
# 将Span导出到控制台,实际生产环境会导出到Jaeger/Zipkin/OTLP Collector
processor = BatchSpanProcessor(ConsoleSpanExporter())
provider.add_span_processor(processor)
trace.set_tracer_provider(provider)
tracer = trace.get_tracer(__name__)
# 模拟Embedding模型
class TracedEmbeddingModel:
def embed_query(self, query: str) -> list:
with tracer.start_as_current_span("EmbeddingModel.embed_query") as span:
span.set_attribute("query_length", len(query))
time.sleep(0.1 + len(query) * 0.005)
return [0.1] * 1536
# 模拟LLM调用
class TracedLLM:
def generate(self, prompt: str) -> str:
with tracer.start_as_current_span("LLM.generate") as span:
span.set_attribute("prompt_length", len(prompt))
tokens_in = len(prompt.split())
tokens_out = max(50, len(prompt) // 2)
span.set_attribute("tokens_input", tokens_in)
span.set_attribute("tokens_output", tokens_out)
time.sleep(0.5 + tokens_in * 0.002 + tokens_out * 0.005)
return f"LLM response to: {prompt[:50]}..."
# 模拟工具执行
class TracedTool:
def __init__(self, name: str):
self.name = name
def execute(self, input_data: str) -> str:
with tracer.start_as_current_span(f"Tool.{self.name}.execute") as span:
span.set_attribute("tool_name", self.name)
span.set_attribute("input_data_length", len(input_data))
time.sleep(0.8) # Simulate external API call
return f"Result from {self.name} for {input_data}"
# 模拟Agent的整体运行
def run_traced_agent(query: str):
with tracer.start_as_current_span("Agent.run_full_cycle") as parent_span:
embedding_model = TracedEmbeddingModel()
llm = TracedLLM()
weather_tool = TracedTool("WeatherAPI")
# Embedding
query_embedding = embedding_model.embed_query(query)
# LLM Planning
planning_prompt = f"Embeddings generated for query: {query}. Decide next step."
llm_plan = llm.generate(planning_prompt)
# Tool Execution (conditional)
final_result = ""
if "weather" in query.lower():
tool_result = weather_tool.execute("New York")
final_result = f"Weather info: {tool_result}"
else:
final_result = f"Direct response based on plan: {llm_plan}"
# LLM Final Response
final_response_prompt = f"Based on: {query}, {final_result}. Generate final user response."
final_response = llm.generate(final_response_prompt)
return final_response
if __name__ == "__main__":
print("n--- Running Traced Agent with OpenTelemetry ---")
response = run_traced_agent("What's the weather like in New York?")
print(f"nTraced Agent Final Response: {response}")
print("n--- Running Traced Agent for a different query ---")
response = run_traced_agent("Tell me a story about a dragon.")
print(f"nTraced Agent Final Response: {response}")
# 确保所有span都已导出
time.sleep(1)
运行这段代码,您会在控制台看到OpenTelemetry导出的Span信息,包括每个操作的名称、ID、父ID和持续时间。在实际部署中,这些Span会被发送到Jaeger等系统进行可视化,形成如下的甘特图:
| Trace ID | Span Name | Start Time | End Time | Duration (ms) | Parent Span ID | Attributes |
|---|---|---|---|---|---|---|
| abc-123 | Agent.run_full_cycle | T0 | T0+2500 | 2500 | ||
| abc-123 | └─ EmbeddingModel.embed_query | T0+10 | T0+150 | 140 | Span_A | query_length=32 |
| abc-123 | └─ LLM.generate | T0+160 | T0+960 | 800 | Span_A | prompt_length=50, tokens_input=10, tokens_output=60 |
| abc-123 | └─ Tool.WeatherAPI.execute | T0+970 | T0+1770 | 800 | Span_A | tool_name=WeatherAPI, input_data_length=8 |
| abc-123 | └─ LLM.generate | T0+1780 | T0+2480 | 700 | Span_A | prompt_length=80, tokens_input=15, tokens_output=50 |
通过这种可视化,我们可以一目了然地看到哪些操作占据了大部分时间,以及它们之间的依赖关系。
3.4 专用Profiling工具:深入代码细节
虽然上述方法主要关注宏观组件间的延迟,但有时瓶颈可能隐藏在某个工具的内部实现中,或是某个复杂数据处理函数的CPU密集型操作。这时,Python自带的cProfile或第三方库line_profiler就能派上用场。
cProfile: Python标准库,用于函数级别的CPU和时间开销分析。line_profiler: 更细粒度,可以分析到代码行级别的耗时。
示例:使用 cProfile 剖析一个复杂工具
import cProfile
import pstats
import time
import random
class ComplexDataProcessor:
def __init__(self, data_size: int):
self.data = self._generate_large_data(data_size)
def _generate_large_data(self, size: int) -> list:
return [random.randint(0, 1000) for _ in range(size)]
def process_data_step1(self, data: list) -> list:
# 模拟一个O(N log N)的操作
time.sleep(0.01) # 少量I/O
return sorted(data)
def process_data_step2(self, data: list) -> int:
# 模拟一个O(N)的操作
time.sleep(0.005) # 少量I/O
return sum(x * 2 for x in data if x % 2 == 0)
def perform_heavy_computation(self) -> int:
processed_data1 = self.process_data_step1(self.data)
result = self.process_data_step2(processed_data1)
return result
# 模拟Agent中调用此工具
def agent_tool_wrapper():
processor = ComplexDataProcessor(100000) # 10万个数据点
return processor.perform_heavy_computation()
if __name__ == "__main__":
print("n--- Running cProfile on a complex tool ---")
profiler = cProfile.Profile()
profiler.enable()
agent_tool_wrapper()
profiler.disable()
stats = pstats.Stats(profiler).sort_stats('cumtime') # 按累计时间排序
stats.print_stats(10) # 打印耗时最多的前10个函数
# 也可以保存到文件,用可视化工具查看
# stats.dump_stats("profile_results.prof")
# 然后用 `snakeviz profile_results.prof` 或 `gprof2dot` 等工具可视化
通过cProfile的输出,您可以看到agent_tool_wrapper内部各个函数的调用次数、总耗时和自身耗时,从而精确识别哪个函数是计算瓶颈。
四、实战演练:剖析一个Agent的延迟
现在,我们将结合前述策略,在一个更贴近实际的Agent场景中进行实战演练。
场景设定:
一个问答Agent,能够:
- 根据用户提问,从知识库中检索相关文档(RAG)。
- 利用LLM根据检索结果生成答案。
- 如果用户查询涉及实时信息(如天气),则调用外部天气API获取数据。
Agent核心组件:
- Embedding Model: 用于将查询和文档转换为向量。
- Vector Store: 存储文档向量并进行相似性搜索。
- LLM: OpenAI GPT-4(模拟)。
- Tool: 外部天气API调用。
我们将使用手动计时和回调相结合的方式进行剖析。
import time
import json
import logging
from typing import Dict, Any, List
from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult, ChatGenerationChunk, GenerationChunk
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.schema import Document
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from langchain.agents import create_tool_calling_agent, AgentExecutor
import os
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# --- 1. 自定义CallbackHandler 用于延迟剖析 ---
class AgentProfilingCallback(BaseCallbackHandler):
def __init__(self):
self.start_times: Dict[str, float] = {}
self.events: List[Dict[str, Any]] = []
def _record_event(self, event_type: str, run_id: UUID, duration_ms: Optional[float] = None, **kwargs: Any):
event_data = {
"timestamp": time.time(),
"event_type": event_type,
"run_id": str(run_id)
}
if duration_ms is not None:
event_data["duration_ms"] = f"{duration_ms:.2f}"
event_data.update(kwargs)
self.events.append(event_data)
logging.info(json.dumps(event_data))
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
self.start_times[str(run_id)] = time.perf_counter()
self._record_event("llm_start", run_id, prompt_len=len(prompts[0]))
def on_llm_end(self, response: LLMResult, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> None:
end_time = time.perf_counter()
duration = (end_time - self.start_times.pop(str(run_id), end_time)) * 1000
total_tokens = response.llm_output.get("token_usage", {}).get("total_tokens") if response.llm_output else "N/A"
self._record_event("llm_end", run_id, duration_ms=duration, total_tokens=total_tokens)
def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
self.start_times[str(run_id)] = time.perf_counter()
self._record_event("chain_start", run_id, chain_type=serialized.get("lc_name", "Unknown"), inputs_keys=list(inputs.keys()))
def on_chain_end(self, outputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> None:
end_time = time.perf_counter()
duration = (end_time - self.start_times.pop(str(run_id), end_time)) * 1000
self._record_event("chain_end", run_id, duration_ms=duration, outputs_keys=list(outputs.keys()))
def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
self.start_times[str(run_id)] = time.perf_counter()
self._record_event("tool_start", run_id, tool_name=serialized.get("lc_kwargs", {}).get("name", "Unknown"), input_len=len(input_str))
def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> None:
end_time = time.perf_counter()
duration = (end_time - self.start_times.pop(str(run_id), end_time)) * 1000
self._record_event("tool_end", run_id, duration_ms=duration, output_len=len(output))
# For Embedding model calls (not directly covered by standard BaseCallbackHandler for all embedding models,
# but some integrations might provide it or we can wrap it manually)
def on_embedding_start(self, texts: List[str], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
self.start_times[str(run_id)] = time.perf_counter()
self._record_event("embedding_start", run_id, num_texts=len(texts), first_text_len=len(texts[0]))
def on_embedding_end(self, embeddings: List[List[float]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> None:
end_time = time.perf_counter()
duration = (end_time - self.start_times.pop(str(run_id), end_time)) * 1000
self._record_event("embedding_end", run_id, duration_ms=duration, num_embeddings=len(embeddings))
# --- 2. 模拟外部工具 ---
@tool
def get_current_weather(location: str) -> str:
"""Get the current weather in a given location."""
start_time = time.perf_counter()
time.sleep(random.uniform(0.5, 1.5)) # 模拟外部API调用的不稳定性
end_time = time.perf_counter()
logging.info(json.dumps({
"event": "manual_tool_weather_api_call_duration",
"duration_ms": f"{(end_time - start_time) * 1000:.2f}",
"location": location
}))
if "london" in location.lower():
return "It's 15 degrees Celsius and cloudy in London."
elif "paris" in location.lower():
return "It's 20 degrees Celsius and sunny in Paris."
else:
return f"Weather data for {location} not available. Please try London or Paris."
tools = [get_current_weather]
# --- 3. 设置RAG组件 ---
# 模拟一些文档
docs = [
Document(page_content="The capital of France is Paris. It is known for its Eiffel Tower."),
Document(page_content="London is the capital of England and the United Kingdom."),
Document(page_content="Artificial intelligence (AI) is intelligence demonstrated by machines."),
Document(page_content="Generative AI models like GPT-4 can create human-like text."),
Document(page_content="The fastest land animal is the cheetah, reaching speeds of up to 120 km/h."),
]
# 手动包装Embedding模型以捕获延迟
class TracedOpenAIEmbeddings(OpenAIEmbeddings):
def __init__(self, profiling_callback: AgentProfilingCallback, **kwargs):
super().__init__(**kwargs)
self.profiling_callback = profiling_callback
def embed_documents(self, texts: List[str]) -> List[List[float]]:
run_id = UUID(int=random.getrandbits(128)) # Generate a unique run_id for this embedding call
self.profiling_callback.on_embedding_start(texts=texts, run_id=run_id)
embeddings = super().embed_documents(texts)
self.profiling_callback.on_embedding_end(embeddings=embeddings, run_id=run_id)
return embeddings
def embed_query(self, text: str) -> List[float]:
run_id = UUID(int=random.getrandbits(128))
self.profiling_callback.on_embedding_start(texts=[text], run_id=run_id)
embedding = super().embed_query(text)
self.profiling_callback.on_embedding_end(embeddings=[embedding], run_id=run_id)
return embedding
def setup_rag(profiling_callback: AgentProfilingCallback):
# Ensure OPENAI_API_KEY is set in environment variables
if not os.getenv("OPENAI_API_KEY"):
raise ValueError("OPENAI_API_KEY environment variable not set.")
embeddings_model = TracedOpenAIEmbeddings(profiling_callback=profiling_callback)
# 构建向量数据库
start_time = time.perf_counter()
vectorstore = FAISS.from_documents(docs, embeddings_model)
end_time = time.perf_counter()
logging.info(json.dumps({
"event": "vectorstore_creation_duration",
"duration_ms": f"{(end_time - start_time) * 1000:.2f}"
}))
retriever = vectorstore.as_retriever()
return retriever
# --- 4. 组装Agent ---
def create_agent(profiling_callback: AgentProfilingCallback):
llm = ChatOpenAI(model="gpt-4", temperature=0, callbacks=[profiling_callback])
retriever = setup_rag(profiling_callback)
# RAG chain for answering questions based on retrieved docs
rag_prompt = ChatPromptTemplate.from_messages([
("system", "Answer the user's questions based on the below context:nn{context}"),
("user", "{input}"),
])
document_chain = create_stuff_documents_chain(llm, rag_prompt)
retrieval_chain = create_retrieval_chain(retriever, document_chain)
# Agent for tool use or direct response
agent_prompt = ChatPromptTemplate.from_messages([
("system", "You are a helpful assistant. Use the provided tools if necessary."),
("human", "{input}"),
("placeholder", "{agent_scratchpad}"),
])
agent = create_tool_calling_agent(llm, tools, agent_prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=False, callbacks=[profiling_callback])
return retrieval_chain, agent_executor
# --- 5. 运行Agent并收集数据 ---
if __name__ == "__main__":
profiler = AgentProfilingCallback()
rag_chain, tool_agent_executor = create_agent(profiler)
print("n--- Running RAG Chain ---")
query_rag = "What are some facts about AI?"
start_total_rag = time.perf_counter()
rag_response = rag_chain.invoke({"input": query_rag}, config={"callbacks": [profiler]})
end_total_rag = time.perf_counter()
logging.info(json.dumps({
"event": "total_rag_chain_duration",
"duration_ms": f"{(end_total_rag - start_total_rag) * 1000:.2f}",
"query": query_rag
}))
print(f"RAG Chain Response: {rag_response['answer']}")
print("n--- Running Tool-Calling Agent (with tool) ---")
query_tool = "What's the weather like in London?"
start_total_tool = time.perf_counter()
tool_response = tool_agent_executor.invoke({"input": query_tool}, config={"callbacks": [profiler]})
end_total_tool = time.perf_counter()
logging.info(json.dumps({
"event": "total_tool_agent_duration",
"duration_ms": f"{(end_total_tool - start_total_tool) * 1000:.2f}",
"query": query_tool
}))
print(f"Tool Agent Response: {tool_response['output']}")
print("n--- Running Tool-Calling Agent (no tool) ---")
query_no_tool = "Tell me a fun fact about animals."
start_total_no_tool = time.perf_counter()
no_tool_response = tool_agent_executor.invoke({"input": query_no_tool}, config={"callbacks": [profiler]})
end_total_no_tool = time.perf_counter()
logging.info(json.dumps({
"event": "total_no_tool_agent_duration",
"duration_ms": f"{(end_total_no_tool - start_total_no_tool) * 1000:.2f}",
"query": query_no_tool
}))
print(f"No Tool Agent Response: {no_tool_response['output']}")
# 分析收集到的事件
print("n--- Collected Profiling Events Summary ---")
# 简单的聚合统计
event_durations = {}
for event in profiler.events:
event_type = event.get("event_type")
duration = event.get("duration_ms")
if event_type and duration:
try:
duration_float = float(duration)
if event_type not in event_durations:
event_durations[event_type] = []
event_durations[event_type].append(duration_float)
except ValueError:
pass
for event_type, durations in event_durations.items():
if durations:
print(f"Event Type: {event_type}")
print(f" Avg Duration: {sum(durations)/len(durations):.2f} ms")
print(f" Max Duration: {max(durations):.2f} ms")
print(f" Min Duration: {min(durations):.2f} ms")
print(f" Count: {len(durations)}")
print("-" * 20)
分析点与结果解读:
运行上述代码,您将看到详细的日志输出。通过这些日志,我们可以进行以下分析:
A. 剖析Embedding Generation的延迟
- 日志关注点:
embedding_start和embedding_end事件。 - 判断依据:
TracedOpenAIEmbeddings类中的计时日志。 - 典型瓶颈:
- API调用延迟: OpenAI
text-embedding-ada-002模型虽然高效,但网络延迟和API服务器负载仍会影响。 - 批量大小: 如果一次性处理大量文档,
embed_documents会产生显著延迟。 - 本地模型性能: 若使用本地Embedding模型,则CPU/GPU性能、模型大小是关键。
- API调用延迟: OpenAI
优化策略:
- 缓存: 对频繁查询的Embedding结果进行缓存。
- 异步Embedding: 利用
asyncio并行调用Embedding API。 - 批量处理: 合理设置批处理大小,平衡延迟和吞吐量。
- 模型选择: 考虑更轻量、更快的Embedding模型(如果精度要求允许)。
- 向量数据库优化: 确保向量数据库索引高效,减少查询时间(例如,使用HNSW索引)。
B. 剖析LLM Inference的延迟
- 日志关注点:
llm_start和llm_end事件。 - 判断依据:
ChatOpenAI内部的回调计时。 - 典型瓶颈:
- 模型选择: GPT-4通常比GPT-3.5-turbo慢很多,但更智能。选择合适的模型是平衡性能和效果的关键。
- Prompt长度: 输入Prompt越长,LLM处理时间越久。
- 生成Token数: LLM输出的Token越多,生成时间越长。
- 网络延迟: 与LLM提供商的API服务器之间的网络延迟。
- API速率限制: 达到API的每分钟Token数(TPM)或每分钟请求数(RPM)限制。
优化策略:
- 模型选择: 优先使用更快的模型,如
gpt-3.5-turbo或特定任务微调的模型。 - Prompt工程: 优化Prompt,使其更简洁,减少不必要的上下文,同时确保LLM能理解意图。
- 限制输出长度: 通过
max_tokens参数控制LLM的输出长度。 - 流式输出 (Streaming): 尽管不能减少总延迟,但能显著改善用户体验,让用户感觉Agent响应更快。
- LLM缓存: 对于重复的、确定性强的查询,缓存LLM的响应。
- 并行LLM调用: 如果Agent的逻辑允许,可以并行发起多个LLM请求。
- 本地化模型: 对于私有化部署场景,考虑使用本地部署的开源LLM,并优化硬件加速。
C. 剖析Tool Execution的延迟
- 日志关注点:
tool_start和tool_end事件,以及手动记录的manual_tool_weather_api_call_duration。 - 判断依据:
get_current_weather函数内部的计时和回调计时。 - 典型瓶颈:
- 外部API响应时间: 外部服务本身的性能、网络状况。
- 数据库查询: 工具内部可能执行数据库操作,索引缺失、复杂查询等都会导致慢。
- 复杂计算: 工具内部的业务逻辑如果包含CPU密集型计算。
- 并发限制: 某些外部服务可能对并发请求有限制。
优化策略:
- 优化工具内部逻辑: 对工具函数内部进行精细化Profiling(如使用
cProfile),优化算法或数据结构。 - 外部服务优化: 如果可以,优化被调用的外部服务的性能。
- 异步I/O: 工具如果涉及大量网络I/O或文件I/O,考虑使用
asyncio进行非阻塞操作。 - 缓存: 缓存工具的重复查询结果。
- 重试与熔断: 对于不稳定的外部服务,实现幂等的重试机制和熔断模式,提高健壮性。
- 工具选择优化: Agent在选择工具时,应优先选择更快、更可靠的工具(如果存在多个选项)。
五、案例分析与常见陷阱
5.1 案例分析
-
案例1: RAG Agent,文档量大,查询慢
- 问题: 用户查询RAG Agent时,响应时间过长。
- 诊断: Profiling显示
embedding_start->embedding_end和vectorstore_query阶段耗时占比超过50%。 - 解决方案:
- 优化Embedding检索: 实施混合检索(Hybrid Search),结合稀疏向量(BM25)和密集向量,有时能更快地筛选出相关文档。
- 优化向量数据库: 升级硬件、调整索引参数、进行数据分片或分区。
- 预计算和缓存: 预计算常见查询的Embedding并缓存检索结果。
- 异步加载: 异步加载检索结果,而不是阻塞等待。
-
案例2: 工具密集型Agent,外部API频繁超时
- 问题: Agent在执行某些复杂任务时,经常卡顿或失败。
- 诊断: Profiling显示
tool_start->tool_end阶段的延迟非常高且波动大,甚至出现超时。 - 解决方案:
- 实现重试机制: 对外部API调用添加指数退避(Exponential Backoff)的重试逻辑。
- 引入熔断机制: 当外部API持续失败时,暂时停止调用该API,避免雪崩效应。
- 异步调用与超时设置: 将工具调用改为异步,并设置合理的超时时间。
- 优化外部服务: 如果是内部服务,与服务提供方协作优化其性能。
-
案例3: 通用Agent,LLM推理慢,用户体验差
- 问题: 无论用户问什么,Agent的响应都需要较长时间。
- 诊断: Profiling显示
llm_start->llm_end阶段的耗时占比最高。 - 解决方案:
- 模型切换: 在非关键任务中,考虑使用更小、更快的模型(如GPT-3.5-turbo),或者专门针对任务优化的模型。
- Prompt优化: 精简Prompt,减少不必要的指导语和上下文,争取用更少的token表达清楚意图。
- 流式输出: 启用LLM的流式响应功能,让用户能看到逐字生成的过程,提升感知速度。
- 缓存: 对频繁出现的、确定性高的LLM响应进行缓存。
5.2 常见陷阱
- 只关注平均延迟: 平均值可能掩盖了尾部延迟(P95、P99),即少数非常慢的请求对用户体验的影响。务必关注这些高百分位延迟。
- 开发与生产环境差异: 开发环境的资源、网络、数据量都与生产环境不同,性能表现也可能大相径庭。务必在接近生产的环境进行性能测试。
- I/O绑定与CPU绑定混淆: 区分是等待网络/磁盘I/O造成的延迟,还是CPU密集型计算造成的延迟。I/O绑定问题通常通过异步、缓存解决;CPU绑定问题则需要算法优化、并行计算或更强的硬件。
- 过度优化非瓶颈环节: 在没有明确证据表明某个环节是瓶颈之前,投入大量精力去优化它,往往是浪费时间。始终遵循“二八原则”,优化最重要的20%。
- 忽略Agent框架自身的开销: 虽然通常不是主要瓶颈,但在某些情况下,Agent框架自身的复杂逻辑、数据结构转换、日志记录等也可能累积成可观的开销。
六、持续监控与自动化
性能剖析不应该是一次性任务。在Agent上线后,持续的性能监控至关重要。
- 集成APM(Application Performance Management)系统: 将Agent的性能指标(延迟、错误率、吞吐量)集成到Prometheus、Grafana、Datadog等APM系统中,构建实时仪表盘。
- 设置告警: 对关键指标设置阈值,例如P95延迟超过N秒、错误率超过M%时自动触发告警。
- A/B测试与灰度发布: 在部署新功能或优化时,通过A/B测试或灰度发布来验证性能改进,并监控可能引入的回归。
- 自动化性能测试: 将性能测试集成到CI/CD流程中,每次代码提交或部署前自动运行性能基准测试,防止性能倒退。
七、展望
AI Agent领域发展迅猛,延迟剖析技术也将随之演进。未来,我们可以期待:
- 更智能的Agent框架: 内置更强大的、开箱即用的Profiling和追踪能力,甚至能自动识别瓶颈并提出优化建议。
- 实时推理优化: 结合编译器优化、硬件加速(如GPU、TPU),实现更低的LLM推理延迟。
- 边缘AI Agent: 将部分Agent能力下沉到边缘设备,减少网络传输延迟,提升响应速度。
掌握延迟剖析是构建高性能、用户友好的AI Agent的关键技能。通过系统的方法、合适的工具和持续的监控,您将能够精准识别并解决Agent的性能瓶颈,从而交付卓越的智能体验。