探讨 ‘Optimization of Thought’:利用 LangGraph 寻找 Agent 推理路径中最短、最省 Token 的‘黄金路径’

各位来宾,各位同事,大家好!

今天,我们齐聚一堂,探讨一个在当前AI时代极具前瞻性和实践意义的话题——“思考的优化”(Optimization of Thought)。随着大型语言模型(LLM)驱动的Agent日益普及,它们在执行复杂任务时展现出的强大能力令人惊叹。然而,这种能力并非没有代价。Agent的每一次“思考”、每一次工具调用、每一次与LLM的交互,都伴随着计算资源的消耗、API调用的延迟,以及最直观的——Token的开销。

在Agent的世界里,一次推理过程可能涉及多个步骤、多条路径。它像是在一个迷宫中寻找出路,有些路宽敞平坦,直达目标;有些路则蜿蜒曲折,耗时耗力。我们的目标,就是利用工程化的手段,找到Agent推理路径中的“黄金路径”——那条最短、最省Token,同时又能高效达成目标的路径。

而LangGraph,作为LangChain家族中的一员,为我们构建这种复杂、有状态的Agent提供了强大的框架。它将Agent的行为建模为状态机,让复杂的决策流变得可管理、可观测。今天,我将向大家展示如何将LangGraph与经典的图算法结合,系统性地实现Agent思考路径的优化。

1. Agent推理的本质:一个动态图

要理解如何优化Agent的思考,我们首先需要将其“思考过程”抽象化。想象一个Agent从接收到任务开始,到最终给出答案的整个过程。这个过程可以被分解为一系列离散的步骤:分析用户输入、调用工具搜索信息、生成中间思考、决定下一步行动等等。

这些步骤,我们可以将其视为图论中的“节点”(Nodes)。而Agent从一个步骤转移到另一个步骤的决策或执行,则构成了图中的“边”(Edges)。由于Agent的决策往往是动态的,依赖于LLM的输出或工具调用的结果,因此这个图是一个动态的、可能包含条件分支和循环的图。

LangGraph正是这样一种将Agent行为建模为状态机的强大工具。它允许我们定义:

  • 状态(State):Agent在某一时刻的所有上下文信息,例如用户输入、历史对话、工具返回结果等。
  • 节点(Nodes):执行特定操作的单元,可以是LLM调用、工具调用、或者自定义的Python函数。
  • 边(Edges):定义了Agent如何从一个节点转移到另一个节点。可以是固定的,也可以是基于当前状态的条件判断。

让我们通过一个简单的LangGraph Agent示例来直观感受一下:

# 确保安装了必要的库
# pip install langchain langchain_community langgraph
import os
from typing import TypedDict, Annotated, List, Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END

# 假设已经设置了OPENAI_API_KEY环境变量
# os.environ["OPENAI_API_KEY"] = "YOUR_API_KEY"

# 定义Agent的状态
class AgentState(TypedDict):
    """
    Agent的状态定义。
    messages: 存储对话历史和中间思考。
    tool_input: 存储工具调用的输入。
    """
    messages: Annotated[List[BaseMessage], lambda x: x]
    tool_input: str
    intermediate_steps: Annotated[List[AgentAction], lambda x: x] # 存储AgentAction用于跟踪

# 定义一个简单的工具
@tool
def search_web(query: str) -> str:
    """在网络上搜索信息。"""
    print(f"n--- Calling search_web with query: {query} ---")
    # 模拟网络搜索结果
    if "Python LangGraph" in query:
        return "LangGraph是一个用于构建健壮、有状态的Agent的库,它将Agent的决策流建模为状态机。"
    elif "北京天气" in query:
        return "北京今天晴,气温20-30摄氏度。"
    else:
        return "未找到相关信息。"

# 初始化LLM
llm = ChatOpenAI(model="gpt-4o", temperature=0)

# 绑定工具到LLM
tools = [search_web]
llm_with_tools = llm.bind_tools(tools)

# 定义Agent的节点函数
def call_llm(state: AgentState):
    """调用LLM进行思考或生成最终回复。"""
    print(f"n--- Calling LLM for thought/response ---")
    messages = state['messages']
    response = llm_with_tools.invoke(messages)
    return {"messages": [response]}

def call_tool(state: AgentState):
    """执行AgentAction中指定的工具调用。"""
    print(f"n--- Calling Tool ---")
    last_message = state['messages'][-1]
    if isinstance(last_message, AgentAction):
        action = last_message
        tool_name = action.tool
        tool_input = action.tool_input
        print(f"Executing tool: {tool_name} with input: {tool_input}")

        # 查找并执行工具
        for t in tools:
            if t.name == tool_name:
                tool_output = t.invoke(tool_input)
                # 将工具输出添加回消息历史
                return {"messages": [tool_output]}
        raise ValueError(f"Tool {tool_name} not found.")
    else:
        raise ValueError("Last message is not an AgentAction, cannot call tool.")

# 定义条件路由函数
def should_continue(state: AgentState) -> str:
    """根据LLM的输出决定下一步是调用工具还是结束。"""
    print(f"n--- Deciding next step ---")
    last_message = state['messages'][-1]
    if isinstance(last_message, AgentAction):
        # 如果LLM决定调用工具,则下一步是执行工具
        print("LLM decided to call a tool.")
        return "call_tool"
    else:
        # 如果LLM生成了最终回复,则结束Agent
        print("LLM generated a final response.")
        return "end"

# 构建LangGraph工作流
workflow = StateGraph(AgentState)

# 添加节点
workflow.add_node("llm", call_llm)
workflow.add_node("call_tool", call_tool)

# 设置入口点
workflow.set_entry_point("llm")

# 添加边
workflow.add_conditional_edges(
    "llm",        # 从LLM节点出发
    should_continue, # 根据should_continue函数的返回值决定走向
    {
        "call_tool": "call_tool", # 如果返回"call_tool",则走向call_tool节点
        "end": END               # 如果返回"end",则结束
    }
)
workflow.add_edge("call_tool", "llm") # 工具执行完后,重新回到LLM节点进行下一步思考或总结

# 编译工作流
app = workflow.compile()

# 运行一个示例
# print("--- Running Agent for '北京天气' ---")
# initial_state = {"messages": [("user", "告诉我北京今天的天气如何?")], "tool_input": "", "intermediate_steps": []}
# for s in app.stream(initial_state):
#     print(s)

# print("n--- Running Agent for '什么是LangGraph?' ---")
# initial_state = {"messages": [("user", "什么是LangGraph?")], "tool_input": "", "intermediate_steps": []}
# for s in app.stream(initial_state):
#     print(s)

在这个示例中,我们定义了两个核心节点:llm(负责思考和生成回复)和call_tool(负责执行工具)。should_continue函数则是一个条件路由器,它根据LLM的输出决定下一步是继续调用工具还是结束整个流程。这是一个非常经典的ReAct(Reasoning and Acting)模式的简化版。

Agent在执行过程中,会从llm节点开始,如果需要工具,则转到call_tool,工具执行完毕后再回到llm节点,直到最终生成一个不需要工具的回复,从而结束整个推理过程。这个过程就是Agent在图中动态遍历的过程。

2. 定义“成本”:如何量化思考的代价?

要优化Agent的思考路径,我们首先需要量化“思考的代价”。在Agent的实际运行中,最直接、最易于量化的成本指标是:

  1. Token数量:这是与LLM交互最主要的成本来源。包括发送给LLM的输入Token和从LLM接收到的输出Token。不同的模型、不同的API提供商,Token的价格各不相同,但Token数量始终是计算成本的基础。
  2. 延迟(Latency):每次LLM调用或工具执行都需要时间。优化延迟可以显著提升用户体验。
  3. API调用次数:一些API可能有调用频率限制或按调用次数计费。
  4. 计算资源:对于本地部署的LLM或复杂工具,计算资源的消耗也是成本的一部分。

本次讲座,我们将主要聚焦于Token数量作为我们的核心优化指标,因为它是最普遍且直接影响成本的因素。“最短”路径,在这里特指“Token消耗最少”的路径。

LangChain提供了一个强大的回调系统(Callbacks),允许我们监控Agent的执行过程,包括LLM调用、工具调用等。我们可以利用这个系统来精确地追踪Token使用情况。

自定义Token计数器

为了捕获每个LLM调用产生的Token数,我们可以创建一个自定义的回调处理器。

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from typing import Dict, Any

class TokenCostCallbackHandler(BaseCallbackHandler):
    """
    一个自定义回调处理器,用于跟踪LLM调用的Token使用情况。
    """
    def __init__(self):
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.total_cost = 0.0 # 假设一个简单的成本模型
        self.llm_call_count = 0
        self.current_node_tokens = {} # 存储当前节点(或操作)的token
        self.node_costs_log = [] # 记录每次节点执行的成本

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """在LLM调用结束时记录Token使用。"""
        if response.llm_output is not None:
            token_usage = response.llm_output.get("token_usage")
            if token_usage:
                input_tokens = token_usage.get("prompt_tokens", 0)
                output_tokens = token_usage.get("completion_tokens", 0)

                self.total_input_tokens += input_tokens
                self.total_output_tokens += output_tokens
                self.llm_call_count += 1

                # 假设GPT-4o的输入和输出Token成本
                # 实际应用中应查询最新的API价格
                # https://openai.com/pricing
                input_cost_per_token = 5.00 / 1_000_000 # $5 / 1M tokens
                output_cost_per_token = 15.00 / 1_000_000 # $15 / 1M tokens

                current_call_cost = (input_tokens * input_cost_per_token) + 
                                    (output_tokens * output_cost_per_token)
                self.total_cost += current_call_cost

                # 记录到当前节点的token信息中
                self.current_node_tokens['input_tokens'] = self.current_node_tokens.get('input_tokens', 0) + input_tokens
                self.current_node_tokens['output_tokens'] = self.current_node_tokens.get('output_tokens', 0) + output_tokens
                self.current_node_tokens['cost'] = self.current_node_tokens.get('cost', 0.0) + current_call_cost
                print(f"  [Callback] LLM Call Ended: Input Tokens: {input_tokens}, Output Tokens: {output_tokens}, Cost: ${current_call_cost:.6f}")

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """在工具调用结束时,我们可以模拟一个固定成本或基于复杂度的成本。"""
        # 工具调用本身可能不直接消耗Token,但可能消耗时间或外部服务费用
        # 这里我们假设工具调用有一个固定的“Token等效成本”或者直接成本
        tool_cost = 0.01 # 假设每次工具调用固定成本0.01美元
        tool_equivalent_tokens = 20 # 假设每次工具调用相当于20个Token的成本

        self.total_cost += tool_cost
        self.current_node_tokens['cost'] = self.current_node_tokens.get('cost', 0.0) + tool_cost
        self.current_node_tokens['tool_equivalent_tokens'] = self.current_node_tokens.get('tool_equivalent_tokens', 0) + tool_equivalent_tokens
        print(f"  [Callback] Tool Call Ended: Output: {output[:50]}..., Assumed Cost: ${tool_cost:.4f}")

    def reset_current_node_tokens(self, node_name: str):
        """重置并记录当前节点的Token信息。"""
        if self.current_node_tokens:
            self.node_costs_log.append({
                "node": node_name,
                "input_tokens": self.current_node_tokens.get('input_tokens', 0),
                "output_tokens": self.current_node_tokens.get('output_tokens', 0),
                "tool_equivalent_tokens": self.current_node_tokens.get('tool_equivalent_tokens', 0),
                "cost": self.current_node_tokens.get('cost', 0.0)
            })
        self.current_node_tokens = {}

    def get_total_usage(self) -> Dict[str, Any]:
        return {
            "total_input_tokens": self.total_input_tokens,
            "total_output_tokens": self.total_output_tokens,
            "total_llm_calls": self.llm_call_count,
            "total_cost": self.total_cost,
            "node_costs_log": self.node_costs_log
        }

# 为了让AgentState能够记录路径和成本,我们对其进行扩展
class OptimizedAgentState(AgentState):
    current_node: str = "" # 记录当前执行的节点名称
    path_costs: Annotated[List[Dict[str, Any]], lambda x: x] # 记录每一步的成本日志

# 修改LLM和Tool节点以使用回调并更新状态
def call_llm_optimized(state: OptimizedAgentState, config: Dict[str, Any]):
    """调用LLM,并更新状态中的当前节点和成本信息。"""
    current_node_name = "llm"
    state['current_node'] = current_node_name

    callback_handler: TokenCostCallbackHandler = config.get("callbacks")[0]
    callback_handler.reset_current_node_tokens(current_node_name) # 记录上一个节点的成本并准备记录当前节点的

    print(f"n--- Calling LLM for thought/response (Node: {current_node_name}) ---")
    messages = state['messages']
    response = llm_with_tools.invoke(messages, config={"callbacks": [callback_handler]})

    # 强制记录当前LLM节点的成本
    callback_handler.reset_current_node_tokens(current_node_name) 

    # 更新Agent状态
    new_messages = state['messages'] + [response]
    return {"messages": new_messages, "current_node": current_node_name}

def call_tool_optimized(state: OptimizedAgentState, config: Dict[str, Any]):
    """执行工具调用,并更新状态中的当前节点和成本信息。"""
    current_node_name = "call_tool"
    state['current_node'] = current_node_name

    callback_handler: TokenCostCallbackHandler = config.get("callbacks")[0]
    callback_handler.reset_current_node_tokens(current_node_name) # 记录上一个节点的成本并准备记录当前节点的

    print(f"n--- Calling Tool (Node: {current_node_name}) ---")
    last_message = state['messages'][-1]
    if isinstance(last_message, AgentAction):
        action = last_message
        tool_name = action.tool
        tool_input = action.tool_input
        print(f"Executing tool: {tool_name} with input: {tool_input}")

        for t in tools:
            if t.name == tool_name:
                tool_output = t.invoke(tool_input, config={"callbacks": [callback_handler]})
                # 强制记录当前Tool节点的成本
                callback_handler.reset_current_node_tokens(current_node_name)

                # 将工具输出添加回消息历史
                new_messages = state['messages'] + [tool_output]
                return {"messages": new_messages, "current_node": current_node_name}
        raise ValueError(f"Tool {tool_name} not found.")
    else:
        raise ValueError("Last message is not an AgentAction, cannot call tool.")

# 重新定义should_continue,主要为了打印日志
def should_continue_optimized(state: OptimizedAgentState) -> str:
    """根据LLM的输出决定下一步是调用工具还是结束。"""
    current_node_name = "should_continue" # 路由节点本身没有成本,但我们可以在这里记录上一个节点的成本
    # 路由节点本身不产生直接的LLM或工具成本,但它是状态转移的关键点
    # 它的成本可以忽略,或者在它前面的节点中累加
    print(f"n--- Deciding next step (Node: {current_node_name}) ---")
    last_message = state['messages'][-1]
    if isinstance(last_message, AgentAction):
        print("LLM decided to call a tool.")
        return "call_tool"
    else:
        print("LLM generated a final response.")
        return "end"

# 重新构建工作流以使用优化后的节点
optimized_workflow = StateGraph(OptimizedAgentState)
optimized_workflow.add_node("llm", call_llm_optimized)
optimized_workflow.add_node("call_tool", call_tool_optimized)
optimized_workflow.set_entry_point("llm")
optimized_workflow.add_conditional_edges(
    "llm",
    should_continue_optimized,
    {
        "call_tool": "call_tool",
        "end": END
    }
)
optimized_workflow.add_edge("call_tool", "llm")
optimized_app = optimized_workflow.compile()

# 在运行Agent时传入回调处理器
# callback_handler = TokenCostCallbackHandler()
# initial_state_optimized = {"messages": [("user", "什么是LangGraph?")], "tool_input": "", "intermediate_steps": [], "current_node": "", "path_costs": []}
# for s in optimized_app.stream(initial_state_optimized, config={"callbacks": [callback_handler]}):
#     pass # 只是运行,不打印中间状态,因为回调函数已打印

# total_usage = callback_handler.get_total_usage()
# print("n=== Total Usage for LangGraph Query ===")
# print(f"Total Input Tokens: {total_usage['total_input_tokens']}")
# print(f"Total Output Tokens: {total_usage['total_output_tokens']}")
# print(f"Total LLM Calls: {total_usage['total_llm_calls']}")
# print(f"Estimated Total Cost: ${total_usage['total_cost']:.6f}")
# print("nNode Costs Log:")
# for entry in total_usage['node_costs_log']:
#     print(f"  Node: {entry['node']}, Input: {entry['input_tokens']}, Output: {entry['output_tokens']}, Tool Eq. Tokens: {entry['tool_equivalent_tokens']}, Cost: ${entry['cost']:.6f}")

TokenCostCallbackHandler中,我们不仅记录了LLM的Token使用,还为工具调用设置了一个假设的“等效Token”成本或直接成本。在实际应用中,工具的成本可能需要根据其具体的API价格或执行时间来精确计算。关键在于,我们现在有了一个量化每个操作(LLM调用或工具调用)成本的机制。

3. 图算法的武器库:寻找最短路径

有了量化成本的能力,我们就可以将Agent的推理过程视为一个加权图,其中边上的权重就是执行该步骤所产生的成本。我们的目标,就是在这样一个加权图中找到从起点到终点的最短路径。

对于这类问题,图算法提供了强大的解决方案:

  • 广度优先搜索 (BFS):用于寻找无权图中的最短路径(即最少跳数)。它不考虑边的权重,因此不适用于我们的加权成本场景。
  • 深度优先搜索 (DFS):倾向于深入探索一条路径,不保证找到最短路径。
  • Dijkstra (迪杰斯特拉) 算法:这是解决单源最短路径问题的经典算法,适用于边的权重非负的加权图。它能够找到从指定起点到所有其他节点的最短路径。这正是我们需要的!
  • A* 搜索算法:Dijkstra算法的扩展,引入了启发式函数来指导搜索方向,可以更快地找到目标节点,尤其是在大型图中。如果能设计一个好的启发式函数,A*会更高效。

对于Agent的推理路径优化,Dijkstra算法是一个非常合适的选择。

Dijkstra算法核心思想

Dijkstra算法通过维护一个所有节点到起点的最短距离估计值集合,并不断更新这些估计值,直到找到真正的最短路径。它使用一个优先队列来高效地选择下一个要处理的节点。

步骤序号 描述
1. 初始化:设置起点到自身的距离为0,到所有其他节点的距离为无穷大。创建一个优先队列,将起点加入队列,优先级为0(距离)。同时,维护一个集合来记录已经访问过的节点。
2. 循环:只要优先队列不为空:
a. 取出最小距离节点:从优先队列中取出距离最小的节点 u。如果 u 已经被访问过,则跳过。
b. 标记访问:将 u 标记为已访问。
c. 更新邻居距离:遍历 u 的所有未访问邻居 v。计算从起点经过 uv 的距离:dist[u] + weight(u, v)。如果这个距离小于 dist[v](当前记录的起点到 v 的最短距离),则更新 dist[v],并将 v 及新距离加入优先队列。同时,记录 uv 的前驱节点,以便重构路径。
3. 结束:当优先队列为空,或者目标节点被取出并标记为已访问时,算法结束。此时,dist 字典中存储的就是从起点到所有节点的最短距离,通过前驱节点可以重构出路径。

Python实现Dijkstra算法

import heapq

def dijkstra_shortest_path(graph: Dict[str, List[tuple[str, float]]], start_node: str, end_node: str) -> tuple[float, List[str]]:
    """
    使用Dijkstra算法查找从start_node到end_node的最短路径及其成本。

    Args:
        graph: 表示加权图的字典。键是节点名称(字符串),值是一个列表,
               列表中每个元素是一个元组 (neighbor_node, weight),表示一条边。
        start_node: 起始节点名称。
        end_node: 目标节点名称。

    Returns:
        一个元组 (shortest_distance, shortest_path_nodes)。
        如果无法到达end_node,则返回 (float('inf'), [])。
    """

    # distances: 存储从start_node到每个节点的最短距离
    # 初始化为无穷大,start_node到自身为0
    distances = {node: float('inf') for node in graph}
    distances[start_node] = 0

    # predecessors: 存储每个节点在最短路径中的前一个节点,用于重构路径
    predecessors = {node: None for node in graph}

    # priority_queue: 优先队列,存储 (distance, node) 元组
    # 按照distance从小到大排序
    priority_queue = [(0, start_node)] # (distance, node)

    while priority_queue:
        current_distance, current_node = heapq.heappop(priority_queue)

        # 如果当前距离已经大于记录的最短距离,说明我们找到了更短的路径,跳过
        if current_distance > distances[current_node]:
            continue

        # 如果我们已经到达了目标节点,可以提前结束
        if current_node == end_node:
            break

        # 遍历当前节点的所有邻居
        for neighbor, weight in graph.get(current_node, []):
            distance = current_distance + weight

            # 如果找到了更短的路径
            if distance < distances[neighbor]:
                distances[neighbor] = distance
                predecessors[neighbor] = current_node
                heapq.heappush(priority_queue, (distance, neighbor))

    # 重构最短路径
    path = []
    current = end_node
    while current is not None and current in predecessors:
        path.insert(0, current)
        current = predecessors[current]
        if current == start_node: # 添加起点
            path.insert(0, current)
            break

    # 检查路径是否完整,即是否从start_node开始
    if path and path[0] == start_node:
        return distances[end_node], path
    else:
        return float('inf'), [] # 无法到达目标节点

4. LangGraph与Dijkstra的结合:构建“黄金路径”搜索系统

现在我们有了LangGraph来构建Agent,有了Token计数器来量化成本,也有了Dijkstra算法来寻找最短路径。接下来,就是如何将它们有机地结合起来,构建一个“黄金路径”搜索系统。

挑战

  1. 动态图的表示:LangGraph的执行路径是动态的,依赖于LLM的实时决策。我们无法预先画出所有可能的路径。
  2. 状态依赖的成本:同一个节点,在不同的输入状态下,其执行成本(例如LLM调用生成的内容长度)可能不同。

策略

我们的策略是通过“模拟运行”来构建“经验性加权图”,然后在这个经验图上应用Dijkstra算法。

  1. Agent设计:首先,我们设计一个LangGraph Agent,使其能够完成既定任务,并包含多种可能的推理路径。
  2. 路径探索与成本记录:我们运行Agent多次,使用不同的输入或在关键决策点上模拟不同的LLM输出,以覆盖尽可能多的潜在路径。在每次运行中,我们利用TokenCostCallbackHandler精确记录每个节点(LLM调用或工具调用)的Token成本。
  3. 构建经验性加权图:从Agent的执行日志中,我们提取出节点序列和每一步的成本。我们将LangGraph的每个“节点”视为Dijkstra图中的一个“状态”,而从一个LangGraph节点到下一个LangGraph节点的“转移”以及执行下一个节点所产生的成本,视为Dijkstra图中的“边”及其“权重”。
    • 重要说明:Dijkstra图中的节点,实际上是LangGraph中的“状态点”或“处理单元”。例如,从llm节点到call_tool节点的转移,其成本是call_tool节点执行的成本。
  4. 执行Dijkstra算法:在构建好的经验加权图上,我们运行Dijkstra算法,找出从Agent的入口点到某个“目标状态”(例如,生成最终回复)的最短(最低成本)路径。
  5. 优化应用:根据Dijkstra算法发现的“黄金路径”,我们可以采取多种措施来优化Agent的行为。

5. 实战演练:一个客户支持Agent的优化之旅

让我们通过一个具体的客户支持Agent的例子来实践这个过程。

场景设定

假设我们有一个客户支持Agent,它的任务是:

  • 接收用户查询。
  • 分类查询:判断查询类型(产品信息、技术支持、订单查询、通用问候)。
  • 搜索知识库:如果需要,搜索内部知识库获取答案。
  • 生成回复:基于分类和搜索结果生成友好的回复。
  • 升级人工:如果无法解决或查询复杂,将问题升级给人工客服。

这个Agent的推理路径可能非常多样:

  • 路径A (高效)用户输入 -> 分类(通用问候) -> 生成回复(问候语)
  • 路径B (中等)用户输入 -> 分类(产品信息) -> 搜索知识库 -> 生成回复(产品信息)
  • 路径C (复杂)用户输入 -> 分类(技术支持) -> 搜索知识库 -> 无法解决 -> 升级人工

我们的目标是,对于常见的查询类型,找到并优化其最低成本的路径。

Agent架构(LangGraph代码)

首先,我们扩展OptimizedAgentState以更好地记录路径。

from langchain_core.messages import HumanMessage, AIMessage, ToolMessage

class CustomerSupportAgentState(TypedDict):
    """
    客户支持Agent的状态定义。
    messages: 存储对话历史和中间思考。
    query_type: 分类后的查询类型。
    search_results: 知识库搜索结果。
    escalated: 布尔值,表示是否已升级人工。

    # 用于路径追踪和成本分析
    current_node: str
    path_log: Annotated[List[Dict[str, Any]], lambda x: x] # 记录每次节点转移的详细信息
    """
    messages: Annotated[List[BaseMessage], lambda x: x]
    query_type: str
    search_results: str
    escalated: bool

    current_node: str
    path_log: Annotated[List[Dict[str, Any]], lambda x: x]

# 定义工具
@tool
def knowledge_base_search(query: str) -> str:
    """在内部知识库中搜索关于产品或技术支持的信息。"""
    print(f"n--- Calling knowledge_base_search with query: '{query}' ---")
    if "产品A功能" in query:
        return "产品A具有实时数据分析、报表生成和用户权限管理功能。"
    elif "登录问题" in query:
        return "登录问题通常可通过重置密码或检查网络连接解决。若仍无法解决,请联系技术支持。"
    elif "退货政策" in query:
        return "我们的退货政策允许在购买后30天内退货,商品需保持原状并附带购买凭证。"
    else:
        return "知识库中未找到关于“" + query + "”的相关信息。"

tools_cs = [knowledge_base_search]
llm_cs = ChatOpenAI(model="gpt-4o", temperature=0)
llm_with_tools_cs = llm_cs.bind_tools(tools_cs)

# 定义回调处理器
class CS_TokenCostCallbackHandler(TokenCostCallbackHandler):
    """为客户支持Agent定制的TokenCostCallbackHandler。"""
    def __init__(self):
        super().__init__()
        self.current_agent_step_info = {} # 临时存储当前LangGraph步骤的信息

    def on_chain_start(self, serialized: Dict[str, Any], **kwargs: Any) -> None:
        """在LangGraph节点(链)开始时记录节点名称。"""
        node_name = serialized.get("name", "Unknown Node")
        self.current_agent_step_info = {"node": node_name, "start_time": time.time()}
        print(f"n[Callback] Starting Node: {node_name}")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """在LangGraph节点(链)结束时记录成本并添加到path_log。"""
        end_time = time.time()
        node_name = self.current_agent_step_info.get("node", "Unknown Node")

        # 确保在on_llm_end或on_tool_end中记录的current_node_tokens已经累加了当前节点的所有成本
        # 这里我们收集并清空 current_node_tokens
        cost_details = self.current_node_tokens.copy()
        self.current_node_tokens = {} # 清空,准备下一个节点

        if node_name != "Unknown Node": # 避免记录外部Chain的成本
            self.node_costs_log.append({
                "node": node_name,
                "duration": end_time - self.current_agent_step_info.get("start_time", end_time),
                **cost_details # 合并LLM/Tool回调中累积的成本信息
            })
            print(f"[Callback] Ended Node: {node_name}, Cost: ${cost_details.get('cost', 0.0):.6f}")

    # 重写on_llm_end和on_tool_end,确保它们累加到self.current_node_tokens
    # 并在on_chain_end时统一处理
    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        if response.llm_output is not None:
            token_usage = response.llm_output.get("token_usage")
            if token_usage:
                input_tokens = token_usage.get("prompt_tokens", 0)
                output_tokens = token_usage.get("completion_tokens", 0)

                self.total_input_tokens += input_tokens
                self.total_output_tokens += output_tokens
                self.llm_call_count += 1

                input_cost_per_token = 5.00 / 1_000_000 # $5 / 1M tokens
                output_cost_per_token = 15.00 / 1_000_000 # $15 / 1M tokens

                current_call_cost = (input_tokens * input_cost_per_token) + 
                                    (output_tokens * output_cost_per_token)
                self.total_cost += current_call_cost

                self.current_agent_step_info['input_tokens'] = self.current_agent_step_info.get('input_tokens', 0) + input_tokens
                self.current_agent_step_info['output_tokens'] = self.current_agent_step_info.get('output_tokens', 0) + output_tokens
                self.current_agent_step_info['cost'] = self.current_agent_step_info.get('cost', 0.0) + current_call_cost
                print(f"  [Callback] LLM Sub-call Ended: Input Tokens: {input_tokens}, Output Tokens: {output_tokens}, Cost: ${current_call_cost:.6f}")

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        tool_cost = 0.01 # 假设每次工具调用固定成本0.01美元
        tool_equivalent_tokens = 20

        self.total_cost += tool_cost
        self.current_agent_step_info['cost'] = self.current_agent_step_info.get('cost', 0.0) + tool_cost
        self.current_agent_step_info['tool_equivalent_tokens'] = self.current_agent_step_info.get('tool_equivalent_tokens', 0) + tool_equivalent_tokens
        print(f"  [Callback] Tool Sub-call Ended: Output: {str(output)[:50]}..., Assumed Cost: ${tool_cost:.4f}")

# Agent的节点函数
import time

def classify_query_node(state: CustomerSupportAgentState, config: Dict[str, Any]):
    """使用LLM分类用户查询。"""
    print("n--- Node: classify_query ---")
    callback_handler: CS_TokenCostCallbackHandler = config.get("callbacks")[0]

    messages = [HumanMessage(content=f"请将以下用户查询分类为 'product_info', 'technical_support', 'order_query', 'greeting' 或 'unresolved'。n查询: {state['messages'][-1].content}n分类:")]
    response = llm_cs.invoke(messages, config={"callbacks": [callback_handler]})
    query_type = response.content.strip().lower()

    print(f"Query classified as: {query_type}")
    return {"query_type": query_type, "messages": state['messages'] + [AIMessage(content=f"分类结果: {query_type}")]}

def search_knowledge_base_node(state: CustomerSupportAgentState, config: Dict[str, Any]):
    """调用知识库搜索工具。"""
    print("n--- Node: search_knowledge_base ---")
    callback_handler: CS_TokenCostCallbackHandler = config.get("callbacks")[0]

    user_query_content = state['messages'][0].content # 假设原始用户查询在第一个消息中
    search_query = f"{state['query_type']}相关信息: {user_query_content}"

    tool_output = knowledge_base_search.invoke(search_query, config={"callbacks": [callback_handler]})

    print(f"Knowledge base search results: {tool_output[:100]}...")
    return {"search_results": tool_output, "messages": state['messages'] + [ToolMessage(content=tool_output, tool_name="knowledge_base_search")]}

def generate_response_node(state: CustomerSupportAgentState, config: Dict[str, Any]):
    """根据分类和搜索结果生成最终回复。"""
    print("n--- Node: generate_response ---")
    callback_handler: CS_TokenCostCallbackHandler = config.get("callbacks")[0]

    user_query = state['messages'][0].content
    query_type = state['query_type']
    search_results = state['search_results']

    prompt_template = f"""
你是一个友好的客户支持Agent。
用户查询: {user_query}
查询类型: {query_type}
{f"知识库搜索结果: {search_results}" if search_results else ""}

请根据以上信息,生成一个简洁、专业的回复。如果搜索结果不相关或不足以解决问题,请告知用户。
"""
    messages = [HumanMessage(content=prompt_template)]
    response = llm_cs.invoke(messages, config={"callbacks": [callback_handler]})

    print(f"Generated response: {response.content[:100]}...")
    return {"messages": state['messages'] + [response]}

def escalate_human_node(state: CustomerSupportAgentState, config: Dict[str, Any]):
    """将问题升级给人工客服。"""
    print("n--- Node: escalate_human ---")
    # 这个节点通常不涉及LLM调用或工具,但可能有一个固定的成本
    callback_handler: CS_TokenCostCallbackHandler = config.get("callbacks")[0]
    # 模拟一个非常小的成本,因为LLM没有直接参与
    callback_handler.current_agent_step_info['cost'] = callback_handler.current_agent_step_info.get('cost', 0.0) + 0.001
    callback_handler.current_agent_step_info['tool_equivalent_tokens'] = callback_handler.current_agent_step_info.get('tool_equivalent_tokens', 0) + 5

    escalation_message = "您好,您的问题已记录,我们将尽快安排人工客服与您联系。请耐心等待。"
    print(escalation_message)
    return {"escalated": True, "messages": state['messages'] + [AIMessage(content=escalation_message)]}

# 定义路由函数
def route_next_step(state: CustomerSupportAgentState) -> str:
    """根据查询类型和搜索结果决定下一步。"""
    print("n--- Routing: route_next_step ---")
    query_type = state['query_type']
    search_results = state['search_results']

    if query_type == "greeting":
        return "generate_response" # 问候语直接回复
    elif query_type in ["product_info", "technical_support", "order_query"]:
        if not search_results: # 第一次进入,需要搜索
            return "search_knowledge_base"
        elif "未找到相关信息" in search_results or "无法解决" in search_results: # 搜索无果
            return "escalate_human"
        else: # 搜索到结果,可以生成回复
            return "generate_response"
    elif query_type == "unresolved":
        return "escalate_human" # 无法分类的问题直接升级
    else:
        return "escalate_human" # 默认升级

# 构建客户支持Agent工作流
cs_workflow = StateGraph(CustomerSupportAgentState)

# 添加节点
cs_workflow.add_node("classify_query", classify_query_node)
cs_workflow.add_node("search_knowledge_base", search_knowledge_base_node)
cs_workflow.add_node("generate_response", generate_response_node)
cs_workflow.add_node("escalate_human", escalate_human_node)

# 设置入口点
cs_workflow.set_entry_point("classify_query")

# 添加边
cs_workflow.add_conditional_edges(
    "classify_query",
    route_next_step, # 从分类节点开始,根据分类结果路由
    {
        "search_knowledge_base": "search_knowledge_base",
        "generate_response": "generate_response",
        "escalate_human": "escalate_human"
    }
)

cs_workflow.add_conditional_edges(
    "search_knowledge_base",
    route_next_step, # 搜索完知识库后,根据结果路由
    {
        "generate_response": "generate_response",
        "escalate_human": "escalate_human"
    }
)

# 最终回复或升级后结束
cs_workflow.add_edge("generate_response", END)
cs_workflow.add_edge("escalate_human", END)

# 编译工作流
cs_app = cs_workflow.compile()

成本追踪器集成和数据收集

我们准备几个不同的查询来模拟Agent的运行,并收集每条路径的成本数据。

import time

sample_queries = [
    "你好!", # greeting
    "产品A有哪些功能?", # product_info -> search -> generate
    "我的账户登录不了怎么办?", # technical_support -> search -> generate
    "我想查询我的订单状态。", # order_query -> search -> escalate (假设订单查询工具缺失)
    "我很生气,我要投诉!", # unresolved -> escalate
    "请问退货政策是什么?" # product_info -> search -> generate
]

# 用于存储所有运行日志的列表
all_agent_run_logs = []

for i, query_text in enumerate(sample_queries):
    print(f"n==================== Running Agent for Query {i+1}: '{query_text}' ====================")
    callback_handler = CS_TokenCostCallbackHandler()

    initial_state = CustomerSupportAgentState(
        messages=[HumanMessage(content=query_text)],
        query_type="",
        search_results="",
        escalated=False,
        current_node="",
        path_log=[]
    )

    # 每次运行都是一个独立的路径
    current_path_nodes = []

    # LangGraph的stream方法会返回每个节点执行后的状态
    for state_update in cs_app.stream(initial_state, config={"callbacks": [callback_handler]}):
        # 提取当前节点名称
        node_name = list(state_update.keys())[0] if state_update else "END"
        # 更新最新状态
        initial_state.update(state_update.get(node_name, {}))

        if node_name != "END":
            # 记录节点名称,用于构建Dijkstra图
            current_path_nodes.append(node_name)

    # 确保在运行结束后,将callback_handler中的node_costs_log追加到all_agent_run_logs
    run_info = {
        "query": query_text,
        "final_state": initial_state,
        "total_usage": callback_handler.get_total_usage()
    }
    all_agent_run_logs.append(run_info)

    print(f"n--- Query {i+1} Summary ---")
    print(f"Path: {' -> '.join(current_path_nodes)}")
    print(f"Total Cost: ${run_info['total_usage']['total_cost']:.6f}")
    print(f"Final Response: {run_info['final_state']['messages'][-1].content}")

构建加权图

现在,我们从all_agent_run_logs中提取信息,构建一个适用于Dijkstra算法的加权图。

图的节点将是LangGraph的各个节点名称(例如classify_query, search_knowledge_base等)。边上的权重将是执行目标节点所产生的成本。

from collections import defaultdict

def build_weighted_graph_from_logs(agent_run_logs: List[Dict[str, Any]]) -> Dict[str, List[tuple[str, float]]]:
    """
    从Agent运行日志中构建一个经验性加权图。

    Args:
        agent_run_logs: 包含Agent每次运行详细信息的日志列表。

    Returns:
        一个表示加权图的字典。键是节点名称,值是一个列表,
        列表中每个元素是一个元组 (neighbor_node, average_cost)。
    """

    # 存储从一个节点到另一个节点的总成本和计数,用于计算平均成本
    transition_costs_sum = defaultdict(lambda: defaultdict(float))
    transition_counts = defaultdict(lambda: defaultdict(int))

    # 提取所有唯一的节点,确保所有可能的节点都被包含在图中
    all_nodes_in_graph = set()

    for run_info in agent_run_logs:
        # path_log包含了按顺序执行的节点及其成本
        node_costs_log = run_info['total_usage']['node_costs_log']

        for i in range(len(node_costs_log)):
            current_node_info = node_costs_log[i]
            current_node_name = current_node_info['node']
            cost_of_current_node = current_node_info.get('cost', 0.0)

            all_nodes_in_graph.add(current_node_name)

            if i < len(node_costs_log) - 1:
                next_node_info = node_costs_log[i+1]
                next_node_name = next_node_info['node']

                # 从 current_node_name 转移到 next_node_name 的成本,我们定义为 next_node_name 的执行成本
                # 这种定义方式更符合LangGraph中节点执行的逻辑
                transition_costs_sum[current_node_name][next_node_name] += cost_of_current_node
                transition_counts[current_node_name][next_node_name] += 1
            else:
                # 最后一个节点到END的成本就是这个节点自身的成本
                # 我们需要添加一个虚拟的'END'节点
                all_nodes_in_graph.add("END")
                transition_costs_sum[current_node_name]["END"] += cost_of_current_node
                transition_counts[current_node_name]["END"] += 1

    weighted_graph = defaultdict(list)
    for source_node, targets in transition_costs_sum.items():
        for target_node, total_cost in targets.items():
            count = transition_counts[source_node][target_node]
            average_cost = total_cost / count if count > 0 else float('inf')
            weighted_graph[source_node].append((target_node, average_cost))

    # 确保所有节点都在图中,即使它们没有出边
    for node in all_nodes_in_graph:
        if node not in weighted_graph:
            weighted_graph[node] = []

    return dict(weighted_graph)

# 构建加权图
empirical_graph = build_weighted_graph_from_logs(all_agent_run_logs)

print("n=== Empirical Weighted Graph ===")
for node, edges in empirical_graph.items():
    print(f"Node '{node}':")
    for neighbor, weight in edges:
        print(f"  -> '{neighbor}' (Cost: ${weight:.6f})")

Dijkstra算法实现与“黄金路径”发现

现在,我们有了加权图,可以运行Dijkstra算法来找到特定查询类型的“黄金路径”。

我们定义一个目标:对于任何一个用户查询,最终都应该到达generate_response节点(如果可以解决)或者escalate_human节点(如果需要人工介入)。所以,我们的目标节点可以是generate_responseescalate_human。为了简化,我们可以找从classify_queryEND的最短路径。

# 假设我们想要找到从 'classify_query' 到 'END' 的最短路径
start_node = "classify_query"
end_node = "END" # 目标是Agent结束

shortest_cost, shortest_path = dijkstra_shortest_path(empirical_graph, start_node, end_node)

print(f"n=== Dijkstra Shortest Path Analysis ===")
if shortest_cost != float('inf'):
    print(f"Shortest path from '{start_node}' to '{end_node}':")
    print(f"  Path: {' -> '.join(shortest_path)}")
    print(f"  Total Estimated Cost: ${shortest_cost:.6f}")
else:
    print(f"No path found from '{start_node}' to '{end_node}'.")

# 我们可以为特定场景分析路径,例如“greeting”的路径
# 理论上,对于"greeting",路径应该是 classify_query -> generate_response -> END
# 让我们看看我们的经验图是否反映了这一点
# 假设我们从classify_query开始,并且它直接路由到generate_response
# 由于我们的图是基于实际执行的平均成本构建的,所以它会反映常见的、低成本的路径

# 进一步分析:哪些查询导致了哪些路径?
# 遍历原始日志,找出每种查询类型的典型路径和成本
query_path_analysis = defaultdict(list)
for run_info in all_agent_run_logs:
    query = run_info['query']
    node_sequence = [entry['node'] for entry in run_info['total_usage']['node_costs_log']]
    total_cost = run_info['total_usage']['total_cost']
    query_path_analysis[query].append({"path": " -> ".join(node_sequence) + " -> END", "cost": total_cost})

print("n=== Query Type Path Analysis ===")
for query, paths in query_path_analysis.items():
    print(f"Query: '{query}'")
    for p in paths:
        print(f"  Path: {p['path']}, Cost: ${p['cost']:.6f}")

# 假设我们识别出“greeting”查询的理想路径是 `classify_query -> generate_response -> END`
# 并且其成本是最低的。这可以作为我们的“黄金路径”目标。

优化策略应用

找到了“黄金路径”之后,我们如何利用它来优化Agent?

  1. Prompt Engineering(提示词工程)

    • 引导分类:如果发现classify_query节点在某些情况下分类不准确,导致走上了错误的、高成本的路径,我们可以优化分类LLM的提示词。例如,为LLM提供更清晰的分类标准和更多示例。
    • 引导回复:如果generate_response节点在某些情况下生成了不必要的冗长回复(高Token),可以明确要求LLM“用简洁的语言回复”或“限定在50字以内”。
    # 优化后的分类提示词示例
    optimized_classify_prompt = """
    你是一个专业的客户支持查询分类器。请将用户查询严格分类为以下之一:
    'product_info' (关于产品功能、价格、规格等)
    'technical_support' (关于登录、错误、故障排除等技术问题)
    'order_query' (关于订单状态、修改、配送等)
    'greeting' (简单的问候,如“你好”、“谢谢”)
    'unresolved' (无法明确分类或超出上述范围的复杂问题)
    
    请直接输出分类结果,不要有任何额外文字或解释。
    
    用户查询: {query}
    分类:
    """
    # 在 `classify_query_node` 中使用这个优化后的提示词
    # messages = [HumanMessage(content=optimized_classify_prompt.format(query=state['messages'][-1].content))]
  2. 条件逻辑微调

    • 路由优化:如果Dijkstra算法揭示了某个route_next_step函数经常将Agent导向高成本路径,我们可以调整其逻辑。例如,对于order_query,如果我们的knowledge_base_search工具无法处理,但我们知道有一个外部的订单查询API,我们可以优先调用那个API而不是直接升级人工。
    • 提前终止:如果在某个节点发现问题无法解决,且后续路径成本很高,可以考虑提前终止并直接升级人工,而不是进行无效的搜索。
    # 优化后的路由逻辑示例:优先尝试专门的订单查询API(假设存在)
    def optimized_route_next_step(state: CustomerSupportAgentState) -> str:
        query_type = state['query_type']
        search_results = state['search_results']
    
        if query_type == "greeting":
            return "generate_response"
        elif query_type == "order_query" and not search_results:
            # 假设我们有一个专门的订单查询API,比知识库搜索更直接
            # return "call_order_api" # 需要添加新的节点和边
            # 如果没有,就走现有路径
            return "search_knowledge_base" 
        elif query_type in ["product_info", "technical_support"]:
            if not search_results:
                return "search_knowledge_base"
            elif "未找到相关信息" in search_results or "无法解决" in search_results:
                return "escalate_human"
            else:
                return "generate_response"
        elif query_type == "unresolved":
            return "escalate_human"
        else:
            return "escalate_human"
  3. 工具选择与优先级

    • 如果存在多个工具可以完成类似任务(例如,一个通用搜索工具和一个专门的知识库搜索工具),但它们的成本或效果不同,Dijkstra路径可以帮助我们决定在特定情况下优先使用哪个工具。
    • 如果发现某个工具成本很高但经常被调用,可以考虑优化该工具的内部逻辑,或者寻找更便宜的替代方案。
  4. 预计算与缓存

    • 对于一些非常常见且结果稳定的查询,如果其“黄金路径”包含昂贵的LLM调用或工具搜索,我们可以考虑将这些结果预计算并缓存起来。当Agent遇到相同或相似的查询时,直接返回缓存结果,跳过整个推理路径。

通过这些方法,我们可以将数据驱动的路径分析与LangGraph的灵活性相结合,持续迭代和优化Agent的“思考”过程,使其在满足功能需求的同时,实现更高的效率和更低的成本。

6. 高级思考与未来展望

我们今天探讨的方法是Agent思考优化领域的一个起点。未来,这个领域还有巨大的发展空间:

  • 多目标优化:除了Token成本,我们还可以同时优化延迟、准确性、用户满意度等多个目标。这需要更复杂的优化算法,例如多目标A*搜索。
  • 动态权重与环境适应:Agent的成本可能不是静态的。例如,在高峰期,外部API的延迟可能更高。Agent需要能够根据实时环境动态调整其路径选择的权重。
  • 强化学习在路径优化中的潜力:我们可以将Agent的决策过程视为一个马尔可夫决策过程(MDP),通过强化学习让Agent在与环境交互中自主学习和发现最优策略,以最大化奖励(例如:低成本、高准确性)并最小化惩罚。
  • 图剪枝与启发式搜索的结合:对于非常复杂的Agent,其潜在的推理路径图可能极其庞大。我们需要更智能的方法来剪枝不必要的路径,或者设计更有效的启发式函数来加速A*搜索。
  • 鲁棒性与失败处理:即使找到了“黄金路径”,也需要考虑路径中的某个节点失败(例如,API调用失败、LLM生成无效输出)的情况。Agent需要有回退机制和错误处理策略,确保即使最优路径受阻,也能找到次优或安全的替代路径。

思考的艺术与工程的结合

通过将Agent的“思考”过程从模糊的黑盒转化为可量化、可优化的工程问题,我们为构建更高效、更经济的智能系统奠定了基础。这不仅仅是降低成本,更是提升Agent响应速度和稳定性的关键。当我们能够精确地理解Agent的每一步决策所带来的代价时,我们就能更有针对性地进行优化,从而释放Agent在现实世界中的更大潜力。

谢谢大家!

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注