各位来宾,各位同事,大家好!
今天,我们齐聚一堂,探讨一个在当前AI时代极具前瞻性和实践意义的话题——“思考的优化”(Optimization of Thought)。随着大型语言模型(LLM)驱动的Agent日益普及,它们在执行复杂任务时展现出的强大能力令人惊叹。然而,这种能力并非没有代价。Agent的每一次“思考”、每一次工具调用、每一次与LLM的交互,都伴随着计算资源的消耗、API调用的延迟,以及最直观的——Token的开销。
在Agent的世界里,一次推理过程可能涉及多个步骤、多条路径。它像是在一个迷宫中寻找出路,有些路宽敞平坦,直达目标;有些路则蜿蜒曲折,耗时耗力。我们的目标,就是利用工程化的手段,找到Agent推理路径中的“黄金路径”——那条最短、最省Token,同时又能高效达成目标的路径。
而LangGraph,作为LangChain家族中的一员,为我们构建这种复杂、有状态的Agent提供了强大的框架。它将Agent的行为建模为状态机,让复杂的决策流变得可管理、可观测。今天,我将向大家展示如何将LangGraph与经典的图算法结合,系统性地实现Agent思考路径的优化。
1. Agent推理的本质:一个动态图
要理解如何优化Agent的思考,我们首先需要将其“思考过程”抽象化。想象一个Agent从接收到任务开始,到最终给出答案的整个过程。这个过程可以被分解为一系列离散的步骤:分析用户输入、调用工具搜索信息、生成中间思考、决定下一步行动等等。
这些步骤,我们可以将其视为图论中的“节点”(Nodes)。而Agent从一个步骤转移到另一个步骤的决策或执行,则构成了图中的“边”(Edges)。由于Agent的决策往往是动态的,依赖于LLM的输出或工具调用的结果,因此这个图是一个动态的、可能包含条件分支和循环的图。
LangGraph正是这样一种将Agent行为建模为状态机的强大工具。它允许我们定义:
- 状态(State):Agent在某一时刻的所有上下文信息,例如用户输入、历史对话、工具返回结果等。
- 节点(Nodes):执行特定操作的单元,可以是LLM调用、工具调用、或者自定义的Python函数。
- 边(Edges):定义了Agent如何从一个节点转移到另一个节点。可以是固定的,也可以是基于当前状态的条件判断。
让我们通过一个简单的LangGraph Agent示例来直观感受一下:
# 确保安装了必要的库
# pip install langchain langchain_community langgraph
import os
from typing import TypedDict, Annotated, List, Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END
# 假设已经设置了OPENAI_API_KEY环境变量
# os.environ["OPENAI_API_KEY"] = "YOUR_API_KEY"
# 定义Agent的状态
class AgentState(TypedDict):
"""
Agent的状态定义。
messages: 存储对话历史和中间思考。
tool_input: 存储工具调用的输入。
"""
messages: Annotated[List[BaseMessage], lambda x: x]
tool_input: str
intermediate_steps: Annotated[List[AgentAction], lambda x: x] # 存储AgentAction用于跟踪
# 定义一个简单的工具
@tool
def search_web(query: str) -> str:
"""在网络上搜索信息。"""
print(f"n--- Calling search_web with query: {query} ---")
# 模拟网络搜索结果
if "Python LangGraph" in query:
return "LangGraph是一个用于构建健壮、有状态的Agent的库,它将Agent的决策流建模为状态机。"
elif "北京天气" in query:
return "北京今天晴,气温20-30摄氏度。"
else:
return "未找到相关信息。"
# 初始化LLM
llm = ChatOpenAI(model="gpt-4o", temperature=0)
# 绑定工具到LLM
tools = [search_web]
llm_with_tools = llm.bind_tools(tools)
# 定义Agent的节点函数
def call_llm(state: AgentState):
"""调用LLM进行思考或生成最终回复。"""
print(f"n--- Calling LLM for thought/response ---")
messages = state['messages']
response = llm_with_tools.invoke(messages)
return {"messages": [response]}
def call_tool(state: AgentState):
"""执行AgentAction中指定的工具调用。"""
print(f"n--- Calling Tool ---")
last_message = state['messages'][-1]
if isinstance(last_message, AgentAction):
action = last_message
tool_name = action.tool
tool_input = action.tool_input
print(f"Executing tool: {tool_name} with input: {tool_input}")
# 查找并执行工具
for t in tools:
if t.name == tool_name:
tool_output = t.invoke(tool_input)
# 将工具输出添加回消息历史
return {"messages": [tool_output]}
raise ValueError(f"Tool {tool_name} not found.")
else:
raise ValueError("Last message is not an AgentAction, cannot call tool.")
# 定义条件路由函数
def should_continue(state: AgentState) -> str:
"""根据LLM的输出决定下一步是调用工具还是结束。"""
print(f"n--- Deciding next step ---")
last_message = state['messages'][-1]
if isinstance(last_message, AgentAction):
# 如果LLM决定调用工具,则下一步是执行工具
print("LLM decided to call a tool.")
return "call_tool"
else:
# 如果LLM生成了最终回复,则结束Agent
print("LLM generated a final response.")
return "end"
# 构建LangGraph工作流
workflow = StateGraph(AgentState)
# 添加节点
workflow.add_node("llm", call_llm)
workflow.add_node("call_tool", call_tool)
# 设置入口点
workflow.set_entry_point("llm")
# 添加边
workflow.add_conditional_edges(
"llm", # 从LLM节点出发
should_continue, # 根据should_continue函数的返回值决定走向
{
"call_tool": "call_tool", # 如果返回"call_tool",则走向call_tool节点
"end": END # 如果返回"end",则结束
}
)
workflow.add_edge("call_tool", "llm") # 工具执行完后,重新回到LLM节点进行下一步思考或总结
# 编译工作流
app = workflow.compile()
# 运行一个示例
# print("--- Running Agent for '北京天气' ---")
# initial_state = {"messages": [("user", "告诉我北京今天的天气如何?")], "tool_input": "", "intermediate_steps": []}
# for s in app.stream(initial_state):
# print(s)
# print("n--- Running Agent for '什么是LangGraph?' ---")
# initial_state = {"messages": [("user", "什么是LangGraph?")], "tool_input": "", "intermediate_steps": []}
# for s in app.stream(initial_state):
# print(s)
在这个示例中,我们定义了两个核心节点:llm(负责思考和生成回复)和call_tool(负责执行工具)。should_continue函数则是一个条件路由器,它根据LLM的输出决定下一步是继续调用工具还是结束整个流程。这是一个非常经典的ReAct(Reasoning and Acting)模式的简化版。
Agent在执行过程中,会从llm节点开始,如果需要工具,则转到call_tool,工具执行完毕后再回到llm节点,直到最终生成一个不需要工具的回复,从而结束整个推理过程。这个过程就是Agent在图中动态遍历的过程。
2. 定义“成本”:如何量化思考的代价?
要优化Agent的思考路径,我们首先需要量化“思考的代价”。在Agent的实际运行中,最直接、最易于量化的成本指标是:
- Token数量:这是与LLM交互最主要的成本来源。包括发送给LLM的输入Token和从LLM接收到的输出Token。不同的模型、不同的API提供商,Token的价格各不相同,但Token数量始终是计算成本的基础。
- 延迟(Latency):每次LLM调用或工具执行都需要时间。优化延迟可以显著提升用户体验。
- API调用次数:一些API可能有调用频率限制或按调用次数计费。
- 计算资源:对于本地部署的LLM或复杂工具,计算资源的消耗也是成本的一部分。
本次讲座,我们将主要聚焦于Token数量作为我们的核心优化指标,因为它是最普遍且直接影响成本的因素。“最短”路径,在这里特指“Token消耗最少”的路径。
LangChain提供了一个强大的回调系统(Callbacks),允许我们监控Agent的执行过程,包括LLM调用、工具调用等。我们可以利用这个系统来精确地追踪Token使用情况。
自定义Token计数器
为了捕获每个LLM调用产生的Token数,我们可以创建一个自定义的回调处理器。
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from typing import Dict, Any
class TokenCostCallbackHandler(BaseCallbackHandler):
"""
一个自定义回调处理器,用于跟踪LLM调用的Token使用情况。
"""
def __init__(self):
self.total_input_tokens = 0
self.total_output_tokens = 0
self.total_cost = 0.0 # 假设一个简单的成本模型
self.llm_call_count = 0
self.current_node_tokens = {} # 存储当前节点(或操作)的token
self.node_costs_log = [] # 记录每次节点执行的成本
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""在LLM调用结束时记录Token使用。"""
if response.llm_output is not None:
token_usage = response.llm_output.get("token_usage")
if token_usage:
input_tokens = token_usage.get("prompt_tokens", 0)
output_tokens = token_usage.get("completion_tokens", 0)
self.total_input_tokens += input_tokens
self.total_output_tokens += output_tokens
self.llm_call_count += 1
# 假设GPT-4o的输入和输出Token成本
# 实际应用中应查询最新的API价格
# https://openai.com/pricing
input_cost_per_token = 5.00 / 1_000_000 # $5 / 1M tokens
output_cost_per_token = 15.00 / 1_000_000 # $15 / 1M tokens
current_call_cost = (input_tokens * input_cost_per_token) +
(output_tokens * output_cost_per_token)
self.total_cost += current_call_cost
# 记录到当前节点的token信息中
self.current_node_tokens['input_tokens'] = self.current_node_tokens.get('input_tokens', 0) + input_tokens
self.current_node_tokens['output_tokens'] = self.current_node_tokens.get('output_tokens', 0) + output_tokens
self.current_node_tokens['cost'] = self.current_node_tokens.get('cost', 0.0) + current_call_cost
print(f" [Callback] LLM Call Ended: Input Tokens: {input_tokens}, Output Tokens: {output_tokens}, Cost: ${current_call_cost:.6f}")
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
"""在工具调用结束时,我们可以模拟一个固定成本或基于复杂度的成本。"""
# 工具调用本身可能不直接消耗Token,但可能消耗时间或外部服务费用
# 这里我们假设工具调用有一个固定的“Token等效成本”或者直接成本
tool_cost = 0.01 # 假设每次工具调用固定成本0.01美元
tool_equivalent_tokens = 20 # 假设每次工具调用相当于20个Token的成本
self.total_cost += tool_cost
self.current_node_tokens['cost'] = self.current_node_tokens.get('cost', 0.0) + tool_cost
self.current_node_tokens['tool_equivalent_tokens'] = self.current_node_tokens.get('tool_equivalent_tokens', 0) + tool_equivalent_tokens
print(f" [Callback] Tool Call Ended: Output: {output[:50]}..., Assumed Cost: ${tool_cost:.4f}")
def reset_current_node_tokens(self, node_name: str):
"""重置并记录当前节点的Token信息。"""
if self.current_node_tokens:
self.node_costs_log.append({
"node": node_name,
"input_tokens": self.current_node_tokens.get('input_tokens', 0),
"output_tokens": self.current_node_tokens.get('output_tokens', 0),
"tool_equivalent_tokens": self.current_node_tokens.get('tool_equivalent_tokens', 0),
"cost": self.current_node_tokens.get('cost', 0.0)
})
self.current_node_tokens = {}
def get_total_usage(self) -> Dict[str, Any]:
return {
"total_input_tokens": self.total_input_tokens,
"total_output_tokens": self.total_output_tokens,
"total_llm_calls": self.llm_call_count,
"total_cost": self.total_cost,
"node_costs_log": self.node_costs_log
}
# 为了让AgentState能够记录路径和成本,我们对其进行扩展
class OptimizedAgentState(AgentState):
current_node: str = "" # 记录当前执行的节点名称
path_costs: Annotated[List[Dict[str, Any]], lambda x: x] # 记录每一步的成本日志
# 修改LLM和Tool节点以使用回调并更新状态
def call_llm_optimized(state: OptimizedAgentState, config: Dict[str, Any]):
"""调用LLM,并更新状态中的当前节点和成本信息。"""
current_node_name = "llm"
state['current_node'] = current_node_name
callback_handler: TokenCostCallbackHandler = config.get("callbacks")[0]
callback_handler.reset_current_node_tokens(current_node_name) # 记录上一个节点的成本并准备记录当前节点的
print(f"n--- Calling LLM for thought/response (Node: {current_node_name}) ---")
messages = state['messages']
response = llm_with_tools.invoke(messages, config={"callbacks": [callback_handler]})
# 强制记录当前LLM节点的成本
callback_handler.reset_current_node_tokens(current_node_name)
# 更新Agent状态
new_messages = state['messages'] + [response]
return {"messages": new_messages, "current_node": current_node_name}
def call_tool_optimized(state: OptimizedAgentState, config: Dict[str, Any]):
"""执行工具调用,并更新状态中的当前节点和成本信息。"""
current_node_name = "call_tool"
state['current_node'] = current_node_name
callback_handler: TokenCostCallbackHandler = config.get("callbacks")[0]
callback_handler.reset_current_node_tokens(current_node_name) # 记录上一个节点的成本并准备记录当前节点的
print(f"n--- Calling Tool (Node: {current_node_name}) ---")
last_message = state['messages'][-1]
if isinstance(last_message, AgentAction):
action = last_message
tool_name = action.tool
tool_input = action.tool_input
print(f"Executing tool: {tool_name} with input: {tool_input}")
for t in tools:
if t.name == tool_name:
tool_output = t.invoke(tool_input, config={"callbacks": [callback_handler]})
# 强制记录当前Tool节点的成本
callback_handler.reset_current_node_tokens(current_node_name)
# 将工具输出添加回消息历史
new_messages = state['messages'] + [tool_output]
return {"messages": new_messages, "current_node": current_node_name}
raise ValueError(f"Tool {tool_name} not found.")
else:
raise ValueError("Last message is not an AgentAction, cannot call tool.")
# 重新定义should_continue,主要为了打印日志
def should_continue_optimized(state: OptimizedAgentState) -> str:
"""根据LLM的输出决定下一步是调用工具还是结束。"""
current_node_name = "should_continue" # 路由节点本身没有成本,但我们可以在这里记录上一个节点的成本
# 路由节点本身不产生直接的LLM或工具成本,但它是状态转移的关键点
# 它的成本可以忽略,或者在它前面的节点中累加
print(f"n--- Deciding next step (Node: {current_node_name}) ---")
last_message = state['messages'][-1]
if isinstance(last_message, AgentAction):
print("LLM decided to call a tool.")
return "call_tool"
else:
print("LLM generated a final response.")
return "end"
# 重新构建工作流以使用优化后的节点
optimized_workflow = StateGraph(OptimizedAgentState)
optimized_workflow.add_node("llm", call_llm_optimized)
optimized_workflow.add_node("call_tool", call_tool_optimized)
optimized_workflow.set_entry_point("llm")
optimized_workflow.add_conditional_edges(
"llm",
should_continue_optimized,
{
"call_tool": "call_tool",
"end": END
}
)
optimized_workflow.add_edge("call_tool", "llm")
optimized_app = optimized_workflow.compile()
# 在运行Agent时传入回调处理器
# callback_handler = TokenCostCallbackHandler()
# initial_state_optimized = {"messages": [("user", "什么是LangGraph?")], "tool_input": "", "intermediate_steps": [], "current_node": "", "path_costs": []}
# for s in optimized_app.stream(initial_state_optimized, config={"callbacks": [callback_handler]}):
# pass # 只是运行,不打印中间状态,因为回调函数已打印
# total_usage = callback_handler.get_total_usage()
# print("n=== Total Usage for LangGraph Query ===")
# print(f"Total Input Tokens: {total_usage['total_input_tokens']}")
# print(f"Total Output Tokens: {total_usage['total_output_tokens']}")
# print(f"Total LLM Calls: {total_usage['total_llm_calls']}")
# print(f"Estimated Total Cost: ${total_usage['total_cost']:.6f}")
# print("nNode Costs Log:")
# for entry in total_usage['node_costs_log']:
# print(f" Node: {entry['node']}, Input: {entry['input_tokens']}, Output: {entry['output_tokens']}, Tool Eq. Tokens: {entry['tool_equivalent_tokens']}, Cost: ${entry['cost']:.6f}")
在TokenCostCallbackHandler中,我们不仅记录了LLM的Token使用,还为工具调用设置了一个假设的“等效Token”成本或直接成本。在实际应用中,工具的成本可能需要根据其具体的API价格或执行时间来精确计算。关键在于,我们现在有了一个量化每个操作(LLM调用或工具调用)成本的机制。
3. 图算法的武器库:寻找最短路径
有了量化成本的能力,我们就可以将Agent的推理过程视为一个加权图,其中边上的权重就是执行该步骤所产生的成本。我们的目标,就是在这样一个加权图中找到从起点到终点的最短路径。
对于这类问题,图算法提供了强大的解决方案:
- 广度优先搜索 (BFS):用于寻找无权图中的最短路径(即最少跳数)。它不考虑边的权重,因此不适用于我们的加权成本场景。
- 深度优先搜索 (DFS):倾向于深入探索一条路径,不保证找到最短路径。
- Dijkstra (迪杰斯特拉) 算法:这是解决单源最短路径问题的经典算法,适用于边的权重非负的加权图。它能够找到从指定起点到所有其他节点的最短路径。这正是我们需要的!
- A* 搜索算法:Dijkstra算法的扩展,引入了启发式函数来指导搜索方向,可以更快地找到目标节点,尤其是在大型图中。如果能设计一个好的启发式函数,A*会更高效。
对于Agent的推理路径优化,Dijkstra算法是一个非常合适的选择。
Dijkstra算法核心思想
Dijkstra算法通过维护一个所有节点到起点的最短距离估计值集合,并不断更新这些估计值,直到找到真正的最短路径。它使用一个优先队列来高效地选择下一个要处理的节点。
| 步骤序号 | 描述 | |
|---|---|---|
| 1. | 初始化:设置起点到自身的距离为0,到所有其他节点的距离为无穷大。创建一个优先队列,将起点加入队列,优先级为0(距离)。同时,维护一个集合来记录已经访问过的节点。 | |
| 2. | 循环:只要优先队列不为空: | |
| a. | 取出最小距离节点:从优先队列中取出距离最小的节点 u。如果 u 已经被访问过,则跳过。 |
|
| b. | 标记访问:将 u 标记为已访问。 |
|
| c. | 更新邻居距离:遍历 u 的所有未访问邻居 v。计算从起点经过 u 到 v 的距离:dist[u] + weight(u, v)。如果这个距离小于 dist[v](当前记录的起点到 v 的最短距离),则更新 dist[v],并将 v 及新距离加入优先队列。同时,记录 u 是 v 的前驱节点,以便重构路径。 |
|
| 3. | 结束:当优先队列为空,或者目标节点被取出并标记为已访问时,算法结束。此时,dist 字典中存储的就是从起点到所有节点的最短距离,通过前驱节点可以重构出路径。 |
Python实现Dijkstra算法
import heapq
def dijkstra_shortest_path(graph: Dict[str, List[tuple[str, float]]], start_node: str, end_node: str) -> tuple[float, List[str]]:
"""
使用Dijkstra算法查找从start_node到end_node的最短路径及其成本。
Args:
graph: 表示加权图的字典。键是节点名称(字符串),值是一个列表,
列表中每个元素是一个元组 (neighbor_node, weight),表示一条边。
start_node: 起始节点名称。
end_node: 目标节点名称。
Returns:
一个元组 (shortest_distance, shortest_path_nodes)。
如果无法到达end_node,则返回 (float('inf'), [])。
"""
# distances: 存储从start_node到每个节点的最短距离
# 初始化为无穷大,start_node到自身为0
distances = {node: float('inf') for node in graph}
distances[start_node] = 0
# predecessors: 存储每个节点在最短路径中的前一个节点,用于重构路径
predecessors = {node: None for node in graph}
# priority_queue: 优先队列,存储 (distance, node) 元组
# 按照distance从小到大排序
priority_queue = [(0, start_node)] # (distance, node)
while priority_queue:
current_distance, current_node = heapq.heappop(priority_queue)
# 如果当前距离已经大于记录的最短距离,说明我们找到了更短的路径,跳过
if current_distance > distances[current_node]:
continue
# 如果我们已经到达了目标节点,可以提前结束
if current_node == end_node:
break
# 遍历当前节点的所有邻居
for neighbor, weight in graph.get(current_node, []):
distance = current_distance + weight
# 如果找到了更短的路径
if distance < distances[neighbor]:
distances[neighbor] = distance
predecessors[neighbor] = current_node
heapq.heappush(priority_queue, (distance, neighbor))
# 重构最短路径
path = []
current = end_node
while current is not None and current in predecessors:
path.insert(0, current)
current = predecessors[current]
if current == start_node: # 添加起点
path.insert(0, current)
break
# 检查路径是否完整,即是否从start_node开始
if path and path[0] == start_node:
return distances[end_node], path
else:
return float('inf'), [] # 无法到达目标节点
4. LangGraph与Dijkstra的结合:构建“黄金路径”搜索系统
现在我们有了LangGraph来构建Agent,有了Token计数器来量化成本,也有了Dijkstra算法来寻找最短路径。接下来,就是如何将它们有机地结合起来,构建一个“黄金路径”搜索系统。
挑战
- 动态图的表示:LangGraph的执行路径是动态的,依赖于LLM的实时决策。我们无法预先画出所有可能的路径。
- 状态依赖的成本:同一个节点,在不同的输入状态下,其执行成本(例如LLM调用生成的内容长度)可能不同。
策略
我们的策略是通过“模拟运行”来构建“经验性加权图”,然后在这个经验图上应用Dijkstra算法。
- Agent设计:首先,我们设计一个LangGraph Agent,使其能够完成既定任务,并包含多种可能的推理路径。
- 路径探索与成本记录:我们运行Agent多次,使用不同的输入或在关键决策点上模拟不同的LLM输出,以覆盖尽可能多的潜在路径。在每次运行中,我们利用
TokenCostCallbackHandler精确记录每个节点(LLM调用或工具调用)的Token成本。 - 构建经验性加权图:从Agent的执行日志中,我们提取出节点序列和每一步的成本。我们将LangGraph的每个“节点”视为Dijkstra图中的一个“状态”,而从一个LangGraph节点到下一个LangGraph节点的“转移”以及执行下一个节点所产生的成本,视为Dijkstra图中的“边”及其“权重”。
- 重要说明:Dijkstra图中的节点,实际上是LangGraph中的“状态点”或“处理单元”。例如,从
llm节点到call_tool节点的转移,其成本是call_tool节点执行的成本。
- 重要说明:Dijkstra图中的节点,实际上是LangGraph中的“状态点”或“处理单元”。例如,从
- 执行Dijkstra算法:在构建好的经验加权图上,我们运行Dijkstra算法,找出从Agent的入口点到某个“目标状态”(例如,生成最终回复)的最短(最低成本)路径。
- 优化应用:根据Dijkstra算法发现的“黄金路径”,我们可以采取多种措施来优化Agent的行为。
5. 实战演练:一个客户支持Agent的优化之旅
让我们通过一个具体的客户支持Agent的例子来实践这个过程。
场景设定
假设我们有一个客户支持Agent,它的任务是:
- 接收用户查询。
- 分类查询:判断查询类型(产品信息、技术支持、订单查询、通用问候)。
- 搜索知识库:如果需要,搜索内部知识库获取答案。
- 生成回复:基于分类和搜索结果生成友好的回复。
- 升级人工:如果无法解决或查询复杂,将问题升级给人工客服。
这个Agent的推理路径可能非常多样:
- 路径A (高效):
用户输入 -> 分类(通用问候) -> 生成回复(问候语) - 路径B (中等):
用户输入 -> 分类(产品信息) -> 搜索知识库 -> 生成回复(产品信息) - 路径C (复杂):
用户输入 -> 分类(技术支持) -> 搜索知识库 -> 无法解决 -> 升级人工
我们的目标是,对于常见的查询类型,找到并优化其最低成本的路径。
Agent架构(LangGraph代码)
首先,我们扩展OptimizedAgentState以更好地记录路径。
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
class CustomerSupportAgentState(TypedDict):
"""
客户支持Agent的状态定义。
messages: 存储对话历史和中间思考。
query_type: 分类后的查询类型。
search_results: 知识库搜索结果。
escalated: 布尔值,表示是否已升级人工。
# 用于路径追踪和成本分析
current_node: str
path_log: Annotated[List[Dict[str, Any]], lambda x: x] # 记录每次节点转移的详细信息
"""
messages: Annotated[List[BaseMessage], lambda x: x]
query_type: str
search_results: str
escalated: bool
current_node: str
path_log: Annotated[List[Dict[str, Any]], lambda x: x]
# 定义工具
@tool
def knowledge_base_search(query: str) -> str:
"""在内部知识库中搜索关于产品或技术支持的信息。"""
print(f"n--- Calling knowledge_base_search with query: '{query}' ---")
if "产品A功能" in query:
return "产品A具有实时数据分析、报表生成和用户权限管理功能。"
elif "登录问题" in query:
return "登录问题通常可通过重置密码或检查网络连接解决。若仍无法解决,请联系技术支持。"
elif "退货政策" in query:
return "我们的退货政策允许在购买后30天内退货,商品需保持原状并附带购买凭证。"
else:
return "知识库中未找到关于“" + query + "”的相关信息。"
tools_cs = [knowledge_base_search]
llm_cs = ChatOpenAI(model="gpt-4o", temperature=0)
llm_with_tools_cs = llm_cs.bind_tools(tools_cs)
# 定义回调处理器
class CS_TokenCostCallbackHandler(TokenCostCallbackHandler):
"""为客户支持Agent定制的TokenCostCallbackHandler。"""
def __init__(self):
super().__init__()
self.current_agent_step_info = {} # 临时存储当前LangGraph步骤的信息
def on_chain_start(self, serialized: Dict[str, Any], **kwargs: Any) -> None:
"""在LangGraph节点(链)开始时记录节点名称。"""
node_name = serialized.get("name", "Unknown Node")
self.current_agent_step_info = {"node": node_name, "start_time": time.time()}
print(f"n[Callback] Starting Node: {node_name}")
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""在LangGraph节点(链)结束时记录成本并添加到path_log。"""
end_time = time.time()
node_name = self.current_agent_step_info.get("node", "Unknown Node")
# 确保在on_llm_end或on_tool_end中记录的current_node_tokens已经累加了当前节点的所有成本
# 这里我们收集并清空 current_node_tokens
cost_details = self.current_node_tokens.copy()
self.current_node_tokens = {} # 清空,准备下一个节点
if node_name != "Unknown Node": # 避免记录外部Chain的成本
self.node_costs_log.append({
"node": node_name,
"duration": end_time - self.current_agent_step_info.get("start_time", end_time),
**cost_details # 合并LLM/Tool回调中累积的成本信息
})
print(f"[Callback] Ended Node: {node_name}, Cost: ${cost_details.get('cost', 0.0):.6f}")
# 重写on_llm_end和on_tool_end,确保它们累加到self.current_node_tokens
# 并在on_chain_end时统一处理
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
if response.llm_output is not None:
token_usage = response.llm_output.get("token_usage")
if token_usage:
input_tokens = token_usage.get("prompt_tokens", 0)
output_tokens = token_usage.get("completion_tokens", 0)
self.total_input_tokens += input_tokens
self.total_output_tokens += output_tokens
self.llm_call_count += 1
input_cost_per_token = 5.00 / 1_000_000 # $5 / 1M tokens
output_cost_per_token = 15.00 / 1_000_000 # $15 / 1M tokens
current_call_cost = (input_tokens * input_cost_per_token) +
(output_tokens * output_cost_per_token)
self.total_cost += current_call_cost
self.current_agent_step_info['input_tokens'] = self.current_agent_step_info.get('input_tokens', 0) + input_tokens
self.current_agent_step_info['output_tokens'] = self.current_agent_step_info.get('output_tokens', 0) + output_tokens
self.current_agent_step_info['cost'] = self.current_agent_step_info.get('cost', 0.0) + current_call_cost
print(f" [Callback] LLM Sub-call Ended: Input Tokens: {input_tokens}, Output Tokens: {output_tokens}, Cost: ${current_call_cost:.6f}")
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
tool_cost = 0.01 # 假设每次工具调用固定成本0.01美元
tool_equivalent_tokens = 20
self.total_cost += tool_cost
self.current_agent_step_info['cost'] = self.current_agent_step_info.get('cost', 0.0) + tool_cost
self.current_agent_step_info['tool_equivalent_tokens'] = self.current_agent_step_info.get('tool_equivalent_tokens', 0) + tool_equivalent_tokens
print(f" [Callback] Tool Sub-call Ended: Output: {str(output)[:50]}..., Assumed Cost: ${tool_cost:.4f}")
# Agent的节点函数
import time
def classify_query_node(state: CustomerSupportAgentState, config: Dict[str, Any]):
"""使用LLM分类用户查询。"""
print("n--- Node: classify_query ---")
callback_handler: CS_TokenCostCallbackHandler = config.get("callbacks")[0]
messages = [HumanMessage(content=f"请将以下用户查询分类为 'product_info', 'technical_support', 'order_query', 'greeting' 或 'unresolved'。n查询: {state['messages'][-1].content}n分类:")]
response = llm_cs.invoke(messages, config={"callbacks": [callback_handler]})
query_type = response.content.strip().lower()
print(f"Query classified as: {query_type}")
return {"query_type": query_type, "messages": state['messages'] + [AIMessage(content=f"分类结果: {query_type}")]}
def search_knowledge_base_node(state: CustomerSupportAgentState, config: Dict[str, Any]):
"""调用知识库搜索工具。"""
print("n--- Node: search_knowledge_base ---")
callback_handler: CS_TokenCostCallbackHandler = config.get("callbacks")[0]
user_query_content = state['messages'][0].content # 假设原始用户查询在第一个消息中
search_query = f"{state['query_type']}相关信息: {user_query_content}"
tool_output = knowledge_base_search.invoke(search_query, config={"callbacks": [callback_handler]})
print(f"Knowledge base search results: {tool_output[:100]}...")
return {"search_results": tool_output, "messages": state['messages'] + [ToolMessage(content=tool_output, tool_name="knowledge_base_search")]}
def generate_response_node(state: CustomerSupportAgentState, config: Dict[str, Any]):
"""根据分类和搜索结果生成最终回复。"""
print("n--- Node: generate_response ---")
callback_handler: CS_TokenCostCallbackHandler = config.get("callbacks")[0]
user_query = state['messages'][0].content
query_type = state['query_type']
search_results = state['search_results']
prompt_template = f"""
你是一个友好的客户支持Agent。
用户查询: {user_query}
查询类型: {query_type}
{f"知识库搜索结果: {search_results}" if search_results else ""}
请根据以上信息,生成一个简洁、专业的回复。如果搜索结果不相关或不足以解决问题,请告知用户。
"""
messages = [HumanMessage(content=prompt_template)]
response = llm_cs.invoke(messages, config={"callbacks": [callback_handler]})
print(f"Generated response: {response.content[:100]}...")
return {"messages": state['messages'] + [response]}
def escalate_human_node(state: CustomerSupportAgentState, config: Dict[str, Any]):
"""将问题升级给人工客服。"""
print("n--- Node: escalate_human ---")
# 这个节点通常不涉及LLM调用或工具,但可能有一个固定的成本
callback_handler: CS_TokenCostCallbackHandler = config.get("callbacks")[0]
# 模拟一个非常小的成本,因为LLM没有直接参与
callback_handler.current_agent_step_info['cost'] = callback_handler.current_agent_step_info.get('cost', 0.0) + 0.001
callback_handler.current_agent_step_info['tool_equivalent_tokens'] = callback_handler.current_agent_step_info.get('tool_equivalent_tokens', 0) + 5
escalation_message = "您好,您的问题已记录,我们将尽快安排人工客服与您联系。请耐心等待。"
print(escalation_message)
return {"escalated": True, "messages": state['messages'] + [AIMessage(content=escalation_message)]}
# 定义路由函数
def route_next_step(state: CustomerSupportAgentState) -> str:
"""根据查询类型和搜索结果决定下一步。"""
print("n--- Routing: route_next_step ---")
query_type = state['query_type']
search_results = state['search_results']
if query_type == "greeting":
return "generate_response" # 问候语直接回复
elif query_type in ["product_info", "technical_support", "order_query"]:
if not search_results: # 第一次进入,需要搜索
return "search_knowledge_base"
elif "未找到相关信息" in search_results or "无法解决" in search_results: # 搜索无果
return "escalate_human"
else: # 搜索到结果,可以生成回复
return "generate_response"
elif query_type == "unresolved":
return "escalate_human" # 无法分类的问题直接升级
else:
return "escalate_human" # 默认升级
# 构建客户支持Agent工作流
cs_workflow = StateGraph(CustomerSupportAgentState)
# 添加节点
cs_workflow.add_node("classify_query", classify_query_node)
cs_workflow.add_node("search_knowledge_base", search_knowledge_base_node)
cs_workflow.add_node("generate_response", generate_response_node)
cs_workflow.add_node("escalate_human", escalate_human_node)
# 设置入口点
cs_workflow.set_entry_point("classify_query")
# 添加边
cs_workflow.add_conditional_edges(
"classify_query",
route_next_step, # 从分类节点开始,根据分类结果路由
{
"search_knowledge_base": "search_knowledge_base",
"generate_response": "generate_response",
"escalate_human": "escalate_human"
}
)
cs_workflow.add_conditional_edges(
"search_knowledge_base",
route_next_step, # 搜索完知识库后,根据结果路由
{
"generate_response": "generate_response",
"escalate_human": "escalate_human"
}
)
# 最终回复或升级后结束
cs_workflow.add_edge("generate_response", END)
cs_workflow.add_edge("escalate_human", END)
# 编译工作流
cs_app = cs_workflow.compile()
成本追踪器集成和数据收集
我们准备几个不同的查询来模拟Agent的运行,并收集每条路径的成本数据。
import time
sample_queries = [
"你好!", # greeting
"产品A有哪些功能?", # product_info -> search -> generate
"我的账户登录不了怎么办?", # technical_support -> search -> generate
"我想查询我的订单状态。", # order_query -> search -> escalate (假设订单查询工具缺失)
"我很生气,我要投诉!", # unresolved -> escalate
"请问退货政策是什么?" # product_info -> search -> generate
]
# 用于存储所有运行日志的列表
all_agent_run_logs = []
for i, query_text in enumerate(sample_queries):
print(f"n==================== Running Agent for Query {i+1}: '{query_text}' ====================")
callback_handler = CS_TokenCostCallbackHandler()
initial_state = CustomerSupportAgentState(
messages=[HumanMessage(content=query_text)],
query_type="",
search_results="",
escalated=False,
current_node="",
path_log=[]
)
# 每次运行都是一个独立的路径
current_path_nodes = []
# LangGraph的stream方法会返回每个节点执行后的状态
for state_update in cs_app.stream(initial_state, config={"callbacks": [callback_handler]}):
# 提取当前节点名称
node_name = list(state_update.keys())[0] if state_update else "END"
# 更新最新状态
initial_state.update(state_update.get(node_name, {}))
if node_name != "END":
# 记录节点名称,用于构建Dijkstra图
current_path_nodes.append(node_name)
# 确保在运行结束后,将callback_handler中的node_costs_log追加到all_agent_run_logs
run_info = {
"query": query_text,
"final_state": initial_state,
"total_usage": callback_handler.get_total_usage()
}
all_agent_run_logs.append(run_info)
print(f"n--- Query {i+1} Summary ---")
print(f"Path: {' -> '.join(current_path_nodes)}")
print(f"Total Cost: ${run_info['total_usage']['total_cost']:.6f}")
print(f"Final Response: {run_info['final_state']['messages'][-1].content}")
构建加权图
现在,我们从all_agent_run_logs中提取信息,构建一个适用于Dijkstra算法的加权图。
图的节点将是LangGraph的各个节点名称(例如classify_query, search_knowledge_base等)。边上的权重将是执行目标节点所产生的成本。
from collections import defaultdict
def build_weighted_graph_from_logs(agent_run_logs: List[Dict[str, Any]]) -> Dict[str, List[tuple[str, float]]]:
"""
从Agent运行日志中构建一个经验性加权图。
Args:
agent_run_logs: 包含Agent每次运行详细信息的日志列表。
Returns:
一个表示加权图的字典。键是节点名称,值是一个列表,
列表中每个元素是一个元组 (neighbor_node, average_cost)。
"""
# 存储从一个节点到另一个节点的总成本和计数,用于计算平均成本
transition_costs_sum = defaultdict(lambda: defaultdict(float))
transition_counts = defaultdict(lambda: defaultdict(int))
# 提取所有唯一的节点,确保所有可能的节点都被包含在图中
all_nodes_in_graph = set()
for run_info in agent_run_logs:
# path_log包含了按顺序执行的节点及其成本
node_costs_log = run_info['total_usage']['node_costs_log']
for i in range(len(node_costs_log)):
current_node_info = node_costs_log[i]
current_node_name = current_node_info['node']
cost_of_current_node = current_node_info.get('cost', 0.0)
all_nodes_in_graph.add(current_node_name)
if i < len(node_costs_log) - 1:
next_node_info = node_costs_log[i+1]
next_node_name = next_node_info['node']
# 从 current_node_name 转移到 next_node_name 的成本,我们定义为 next_node_name 的执行成本
# 这种定义方式更符合LangGraph中节点执行的逻辑
transition_costs_sum[current_node_name][next_node_name] += cost_of_current_node
transition_counts[current_node_name][next_node_name] += 1
else:
# 最后一个节点到END的成本就是这个节点自身的成本
# 我们需要添加一个虚拟的'END'节点
all_nodes_in_graph.add("END")
transition_costs_sum[current_node_name]["END"] += cost_of_current_node
transition_counts[current_node_name]["END"] += 1
weighted_graph = defaultdict(list)
for source_node, targets in transition_costs_sum.items():
for target_node, total_cost in targets.items():
count = transition_counts[source_node][target_node]
average_cost = total_cost / count if count > 0 else float('inf')
weighted_graph[source_node].append((target_node, average_cost))
# 确保所有节点都在图中,即使它们没有出边
for node in all_nodes_in_graph:
if node not in weighted_graph:
weighted_graph[node] = []
return dict(weighted_graph)
# 构建加权图
empirical_graph = build_weighted_graph_from_logs(all_agent_run_logs)
print("n=== Empirical Weighted Graph ===")
for node, edges in empirical_graph.items():
print(f"Node '{node}':")
for neighbor, weight in edges:
print(f" -> '{neighbor}' (Cost: ${weight:.6f})")
Dijkstra算法实现与“黄金路径”发现
现在,我们有了加权图,可以运行Dijkstra算法来找到特定查询类型的“黄金路径”。
我们定义一个目标:对于任何一个用户查询,最终都应该到达generate_response节点(如果可以解决)或者escalate_human节点(如果需要人工介入)。所以,我们的目标节点可以是generate_response或escalate_human。为了简化,我们可以找从classify_query到END的最短路径。
# 假设我们想要找到从 'classify_query' 到 'END' 的最短路径
start_node = "classify_query"
end_node = "END" # 目标是Agent结束
shortest_cost, shortest_path = dijkstra_shortest_path(empirical_graph, start_node, end_node)
print(f"n=== Dijkstra Shortest Path Analysis ===")
if shortest_cost != float('inf'):
print(f"Shortest path from '{start_node}' to '{end_node}':")
print(f" Path: {' -> '.join(shortest_path)}")
print(f" Total Estimated Cost: ${shortest_cost:.6f}")
else:
print(f"No path found from '{start_node}' to '{end_node}'.")
# 我们可以为特定场景分析路径,例如“greeting”的路径
# 理论上,对于"greeting",路径应该是 classify_query -> generate_response -> END
# 让我们看看我们的经验图是否反映了这一点
# 假设我们从classify_query开始,并且它直接路由到generate_response
# 由于我们的图是基于实际执行的平均成本构建的,所以它会反映常见的、低成本的路径
# 进一步分析:哪些查询导致了哪些路径?
# 遍历原始日志,找出每种查询类型的典型路径和成本
query_path_analysis = defaultdict(list)
for run_info in all_agent_run_logs:
query = run_info['query']
node_sequence = [entry['node'] for entry in run_info['total_usage']['node_costs_log']]
total_cost = run_info['total_usage']['total_cost']
query_path_analysis[query].append({"path": " -> ".join(node_sequence) + " -> END", "cost": total_cost})
print("n=== Query Type Path Analysis ===")
for query, paths in query_path_analysis.items():
print(f"Query: '{query}'")
for p in paths:
print(f" Path: {p['path']}, Cost: ${p['cost']:.6f}")
# 假设我们识别出“greeting”查询的理想路径是 `classify_query -> generate_response -> END`
# 并且其成本是最低的。这可以作为我们的“黄金路径”目标。
优化策略应用
找到了“黄金路径”之后,我们如何利用它来优化Agent?
-
Prompt Engineering(提示词工程):
- 引导分类:如果发现
classify_query节点在某些情况下分类不准确,导致走上了错误的、高成本的路径,我们可以优化分类LLM的提示词。例如,为LLM提供更清晰的分类标准和更多示例。 - 引导回复:如果
generate_response节点在某些情况下生成了不必要的冗长回复(高Token),可以明确要求LLM“用简洁的语言回复”或“限定在50字以内”。
# 优化后的分类提示词示例 optimized_classify_prompt = """ 你是一个专业的客户支持查询分类器。请将用户查询严格分类为以下之一: 'product_info' (关于产品功能、价格、规格等) 'technical_support' (关于登录、错误、故障排除等技术问题) 'order_query' (关于订单状态、修改、配送等) 'greeting' (简单的问候,如“你好”、“谢谢”) 'unresolved' (无法明确分类或超出上述范围的复杂问题) 请直接输出分类结果,不要有任何额外文字或解释。 用户查询: {query} 分类: """ # 在 `classify_query_node` 中使用这个优化后的提示词 # messages = [HumanMessage(content=optimized_classify_prompt.format(query=state['messages'][-1].content))] - 引导分类:如果发现
-
条件逻辑微调:
- 路由优化:如果Dijkstra算法揭示了某个
route_next_step函数经常将Agent导向高成本路径,我们可以调整其逻辑。例如,对于order_query,如果我们的knowledge_base_search工具无法处理,但我们知道有一个外部的订单查询API,我们可以优先调用那个API而不是直接升级人工。 - 提前终止:如果在某个节点发现问题无法解决,且后续路径成本很高,可以考虑提前终止并直接升级人工,而不是进行无效的搜索。
# 优化后的路由逻辑示例:优先尝试专门的订单查询API(假设存在) def optimized_route_next_step(state: CustomerSupportAgentState) -> str: query_type = state['query_type'] search_results = state['search_results'] if query_type == "greeting": return "generate_response" elif query_type == "order_query" and not search_results: # 假设我们有一个专门的订单查询API,比知识库搜索更直接 # return "call_order_api" # 需要添加新的节点和边 # 如果没有,就走现有路径 return "search_knowledge_base" elif query_type in ["product_info", "technical_support"]: if not search_results: return "search_knowledge_base" elif "未找到相关信息" in search_results or "无法解决" in search_results: return "escalate_human" else: return "generate_response" elif query_type == "unresolved": return "escalate_human" else: return "escalate_human" - 路由优化:如果Dijkstra算法揭示了某个
-
工具选择与优先级:
- 如果存在多个工具可以完成类似任务(例如,一个通用搜索工具和一个专门的知识库搜索工具),但它们的成本或效果不同,Dijkstra路径可以帮助我们决定在特定情况下优先使用哪个工具。
- 如果发现某个工具成本很高但经常被调用,可以考虑优化该工具的内部逻辑,或者寻找更便宜的替代方案。
-
预计算与缓存:
- 对于一些非常常见且结果稳定的查询,如果其“黄金路径”包含昂贵的LLM调用或工具搜索,我们可以考虑将这些结果预计算并缓存起来。当Agent遇到相同或相似的查询时,直接返回缓存结果,跳过整个推理路径。
通过这些方法,我们可以将数据驱动的路径分析与LangGraph的灵活性相结合,持续迭代和优化Agent的“思考”过程,使其在满足功能需求的同时,实现更高的效率和更低的成本。
6. 高级思考与未来展望
我们今天探讨的方法是Agent思考优化领域的一个起点。未来,这个领域还有巨大的发展空间:
- 多目标优化:除了Token成本,我们还可以同时优化延迟、准确性、用户满意度等多个目标。这需要更复杂的优化算法,例如多目标A*搜索。
- 动态权重与环境适应:Agent的成本可能不是静态的。例如,在高峰期,外部API的延迟可能更高。Agent需要能够根据实时环境动态调整其路径选择的权重。
- 强化学习在路径优化中的潜力:我们可以将Agent的决策过程视为一个马尔可夫决策过程(MDP),通过强化学习让Agent在与环境交互中自主学习和发现最优策略,以最大化奖励(例如:低成本、高准确性)并最小化惩罚。
- 图剪枝与启发式搜索的结合:对于非常复杂的Agent,其潜在的推理路径图可能极其庞大。我们需要更智能的方法来剪枝不必要的路径,或者设计更有效的启发式函数来加速A*搜索。
- 鲁棒性与失败处理:即使找到了“黄金路径”,也需要考虑路径中的某个节点失败(例如,API调用失败、LLM生成无效输出)的情况。Agent需要有回退机制和错误处理策略,确保即使最优路径受阻,也能找到次优或安全的替代路径。
思考的艺术与工程的结合
通过将Agent的“思考”过程从模糊的黑盒转化为可量化、可优化的工程问题,我们为构建更高效、更经济的智能系统奠定了基础。这不仅仅是降低成本,更是提升Agent响应速度和稳定性的关键。当我们能够精确地理解Agent的每一步决策所带来的代价时,我们就能更有针对性地进行优化,从而释放Agent在现实世界中的更大潜力。
谢谢大家!