解析 ‘Model-agnostic Graph Compiling’:如何编写一套逻辑,使其能无缝在不同供应商(OpenAI/Anthropic)间切换?

引言:构建弹性AI系统的必要性

随着大型语言模型(LLM)技术的飞速发展,它们已经从简单的文本生成工具演变为复杂智能应用的核心。今天,我们构建的AI系统往往不再是单一的LLM调用,而是涉及多个步骤、多轮交互、工具调用(Tool Calling)、知识检索增强生成(RAG)乃至多代理协作的复杂工作流。这些系统通常可以被清晰地建模为有向无环图(DAG),其中每个节点代表一个操作,每条边代表数据流或控制流。

然而,在构建这类复杂系统时,一个核心挑战日益凸显:供应商锁定(Vendor Lock-in)。目前市场上存在多家领先的LLM供应商,如OpenAI、Anthropic、Google、Mistral等。它们各自提供了强大的模型和独特的API接口。一旦我们的应用深度绑定了某一特定供应商的API,便会面临以下问题:

  1. 缺乏灵活性: 难以快速切换到性能更好、成本更低或功能更适合新需求的模型。
  2. 风险集中: 单一供应商的服务中断、政策变更或价格上涨可能直接影响整个应用。
  3. 创新受限: 无法轻易利用其他供应商的独特优势,例如Anthropic在长上下文处理上的表现,或OpenAI在工具调用上的成熟度。
  4. 成本优化困难: 无法根据实时负载或模型成本动态选择最经济的模型。

为了解决这些问题,我们需要一种机制,能够让我们在设计和实现AI工作流时,将底层的LLM供应商细节进行抽象和解耦。这正是“模型无关图编译”(Model-agnostic Graph Compiling)这一概念的核心目标。它旨在提供一套逻辑,使我们定义的AI工作流(图)可以在不同的LLM供应商之间无缝切换,从而实现真正的弹性、可移植性和成本效益。

理解“图编译”:将复杂逻辑抽象化

在深入探讨模型无关性之前,我们首先要理解为什么“图”是描述复杂AI工作流的理想范式,以及“编译”在这里的含义。

为什么是“图”?

在软件工程中,图结构是一种强大的数据模型,能够直观地表示实体之间的关系。对于AI工作流而言:

  • 节点(Nodes): 代表工作流中的原子操作或步骤。这可以是一个LLM调用、一个外部工具的执行(例如数据库查询、API调用)、一个数据转换、一个条件判断、甚至是一个子图的执行。
  • 边(Edges): 代表数据流或控制流。一条边从一个节点指向另一个节点,表示前一个节点的输出是后一个节点的输入,或者前一个节点的完成触发了后一个节点的执行。

示例:一个RAG(检索增强生成)工作流的图表示

节点类型 描述 输入 输出
UserQueryNode 用户输入原始问题 (无) query
RewriteQueryNode 使用LLM重写查询,使其更适合检索 query rewritten_query
RetrieveDocumentsNode 使用rewritten_query从向量数据库检索相关文档 rewritten_query documents
SynthesizeAnswerNode 使用LLM结合documentsquery生成最终答案 query, documents answer
OutputNode 将最终答案呈现给用户 answer (无)

这个例子清晰地展示了如何将一个线性或分支的工作流分解为相互依赖的节点。图的优势在于它能够自然地处理并行执行、条件分支、循环(通过反馈边或迭代子图)以及更复杂的依赖关系,远比简单的顺序执行链条更具表现力。

“编译”的含义

在传统编程中,“编译”是将高级语言代码转换为机器可执行指令的过程。在“模型无关图编译”的上下文中,“编译”的含义有所扩展:

  1. 高层抽象到可执行计划的转换: 将我们用图结构定义的高层AI工作流(例如,一个包含LLMNodeToolNode的抽象图),转换为一个具体的、可执行的步骤序列或并发任务集合。
  2. 验证与优化: 在执行前,验证图的拓扑结构(例如,确保没有死循环,所有输入都有来源),并可能进行一些优化,如合并不必要的节点、并行化独立任务等。
  3. 运行时绑定: 在运行时,根据配置或策略,将抽象的LLMNode绑定到具体的OpenAIProviderAnthropicProvider实例。这正是实现“模型无关性”的关键一步。
  4. 上下文与状态管理: 编译器或执行器负责在节点之间传递和管理状态(即数据流),确保每个节点都能接收到其所需的输入,并将输出正确地传递给下游节点。

简而言之,“模型无关图编译”就是设计一种机制,允许我们以一种抽象、供应商无关的方式定义复杂的AI工作流图,然后由一个“编译器”或“执行引擎”负责将这个抽象图转换为具体的、可执行的操作序列,并在运行时根据需要动态地选择和切换底层的LLM供应商。

模型无关性核心挑战

要实现真正的模型无关性,我们必须直面不同LLM供应商之间的深层差异。这些差异不仅限于API端点,还涉及到模型能力、行为模式乃至定价策略。

1. API 差异

这是最显而易见的挑战。OpenAI和Anthropic的API在请求和响应结构上存在显著差异。

示例对比:OpenAI Chat Completion vs. Anthropic Messages API

特性/参数 OpenAI (Chat Completions) Anthropic (Messages API) 备注
端点 /v1/chat/completions /v1/messages
请求体 JSON JSON
消息格式 [{'role': 'user', 'content': '...'}] [{'role': 'user', 'content': '...'}] 角色名称稍有不同 (assistant vs assistant)
系统消息 {'role': 'system', 'content': '...'} system: "..." (作为顶级参数,而非messages列表中的元素) Anthropic的system消息独立于messages列表
模型名称 model="gpt-4o" model="claude-3-opus-20240229" 命名约定不同
工具调用 tools=[{'type': 'function', 'function': {...}}] tools=[{'name': '...', 'description': '...', 'input_schema': {...}}] 结构和功能名称不同,OpenAI是function,Anthropic是tool_use
流式输出 stream=True stream=True 响应事件结构不同
响应体 choices[0].message.content, choices[0].message.tool_calls content[0].text, content[0].tool_use 访问生成的文本和工具调用的路径不同
错误处理 openai.APIError, openai.RateLimitError anthropic.APIError, anthropic.RateLimitError 错误类型和结构不同

2. 模型命名与可用性

每个供应商都有自己的模型命名体系。例如,OpenAI的gpt-4ogpt-3.5-turbo;Anthropic的claude-3-opus-20240229claude-3-sonnet-20240229。我们的系统需要一种机制来将抽象的模型名称(例如“高质量模型”、“快速模型”)映射到具体供应商的特定模型。

3. 功能差异

  • 工具调用 (Tool Calling/Function Calling): 这是构建复杂代理的关键。尽管OpenAI的function_calling和Anthropic的tool_use概念相似,但其API结构和底层机制存在显著差异。OpenAI的工具定义在tools参数中,调用结果在tool_calls中;Anthropic的工具定义也在tools参数中,但调用结果在content数组中以tool_use类型出现。如何统一这种差异是核心。
  • 流式输出 (Streaming): 虽然两者都支持流式输出,但每次流式响应的数据结构不同,需要不同的解析逻辑。
  • 上下文窗口 (Context Window): 各模型支持的最大输入令牌数不同。
  • 令牌计算 (Token Counting): 不同的模型使用不同的分词器(tokenizer)。OpenAI使用tiktoken,Anthropic有自己的内部tokenizer。准确地计算输入和输出的令牌数对于成本估算和避免超出上下文窗口至关重要。

4. 行为模式与性能

即使是“等效”的模型,在创意、事实准确性、安全性审查和响应速度等方面也可能存在细微差异。这通常需要通过实验和评估来发现,并在选择提供商时加以考虑。

5. 成本管理

不同模型的定价模式和费率差异巨大。在多供应商环境中,能够动态选择最经济的模型对于控制运营成本至关重要。

面对这些挑战,我们需要设计一个强大的抽象层和执行引擎,将这些供应商特定的细节封装起来,向上层的工作流提供一个统一、简洁的接口。

构建模型无关抽象层:LLMProvider 接口

实现模型无关性的核心在于构建一个统一的抽象层,将所有LLM供应商的API封装在一个通用的接口之下。这个接口定义了我们对任何LLM提供商的期望行为,而具体的实现则负责将这些通用行为转换为特定供应商的API调用。

核心设计原则

  1. 统一接口: 定义一套最小但完整的操作集(例如,生成文本、计算令牌、获取可用模型)。
  2. 隔离实现: 每个供应商的具体实现都独立于其他供应商,互不影响。
  3. 可扩展性: 方便地添加新的LLM供应商,只需实现这个接口。
  4. 解耦: 工作流的定义不依赖于任何特定的供应商。

LLMProvider 抽象基类

我们首先定义一个抽象基类 (abc.ABC),它将作为所有具体LLM提供商的蓝图。

import abc
import json
from typing import List, Dict, Any, Union, Optional, Iterator

# 定义一个通用的消息格式,兼容OpenAI和Anthropic
# Anthropic的system消息是独立参数,这里简化为在messages中包含role='system'
# 实际生产中可能需要更精细的抽象,或在Provider实现中处理
ChatMessage = Dict[str, Any]

class ToolDefinition:
    """
    抽象的工具定义,用于描述工具的名称、描述和输入schema。
    这个类旨在统一不同LLM供应商的工具定义格式。
    """
    def __init__(self, name: str, description: str, input_schema: Dict[str, Any]):
        self.name = name
        self.description = description
        self.input_schema = input_schema

    def to_openai_format(self) -> Dict[str, Any]:
        """将抽象工具定义转换为OpenAI的工具格式。"""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.input_schema
            }
        }

    def to_anthropic_format(self) -> Dict[str, Any]:
        """将抽象工具定义转换为Anthropic的工具格式。"""
        # Anthropic的工具定义与OpenAI非常相似,只是参数名略有不同
        return {
            "name": self.name,
            "description": self.description,
            "input_schema": self.input_schema
        }

class LLMProvider(abc.ABC):
    """
    LLM供应商的抽象基类。
    定义了所有LLM供应商必须实现的核心方法。
    """
    def __init__(self, model_name: str, api_key: str):
        self._model_name = model_name
        self._api_key = api_key

    @property
    def model_name(self) -> str:
        return self._model_name

    @abc.abstractmethod
    async def generate(
        self,
        messages: List[ChatMessage],
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
        temperature: float = 0.7,
        max_tokens: int = 1024,
        **kwargs
    ) -> ChatMessage:
        """
        生成单个LLM响应。
        :param messages: 聊天消息列表。
        :param tools: 可选的工具定义列表。
        :param tool_choice: 控制工具使用的策略 (e.g., "auto", "none", {"type": "tool", "function": {"name": "..."}})
        :param temperature: 生成的随机性。
        :param max_tokens: 最大生成令牌数。
        :param kwargs: 额外参数,可能用于特定供应商的高级配置。
        :return: LLM生成的响应消息。
        """
        pass

    @abc.abstractmethod
    async def stream_generate(
        self,
        messages: List[ChatMessage],
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
        temperature: float = 0.7,
        max_tokens: int = 1024,
        **kwargs
    ) -> Iterator[ChatMessage]:
        """
        以流式方式生成LLM响应。
        :return: 包含聊天消息片段的迭代器。
        """
        pass

    @abc.abstractmethod
    def count_tokens(self, text: str) -> int:
        """
        计算给定文本的令牌数。
        这对于预估成本和管理上下文窗口至关重要。
        """
        pass

    @abc.abstractmethod
    def get_supported_model_names(self) -> List[str]:
        """
        返回此供应商支持的模型名称列表。
        """
        pass

    @abc.abstractmethod
    def _parse_streaming_chunk(self, chunk: Any) -> Optional[ChatMessage]:
        """
        内部方法:解析流式响应的单个块,提取有效消息部分。
        """
        pass

    @abc.abstractmethod
    def _process_response(self, response: Any) -> ChatMessage:
        """
        内部方法:处理非流式响应对象,提取统一格式的消息。
        """
        pass

OpenAIProvider 实现

import openai
import tiktoken
import os # 通常从环境变量获取API Key

class OpenAIProvider(LLMProvider):
    def __init__(self, model_name: str, api_key: Optional[str] = None):
        super().__init__(model_name, api_key or os.getenv("OPENAI_API_KEY"))
        if not self._api_key:
            raise ValueError("OpenAI API Key is not provided or not found in environment variables.")
        self.client = openai.AsyncOpenAI(api_key=self._api_key)
        self._tokenizer = tiktoken.encoding_for_model(model_name) # 可能会根据model_name调整

    async def generate(
        self,
        messages: List[ChatMessage],
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
        temperature: float = 0.7,
        max_tokens: int = 1024,
        **kwargs
    ) -> ChatMessage:
        openai_tools = [tool.to_openai_format() for tool in tools] if tools else None

        # 统一处理系统消息
        system_message = next((msg for msg in messages if msg.get("role") == "system"), None)
        user_messages = [msg for msg in messages if msg.get("role") != "system"]

        try:
            response = await self.client.chat.completions.create(
                model=self._model_name,
                messages=user_messages, # OpenAI的系统消息在messages列表中
                temperature=temperature,
                max_tokens=max_tokens,
                tools=openai_tools,
                tool_choice=tool_choice, # "auto", "none", or specific tool_choice
                **kwargs
            )
            return self._process_response(response)
        except openai.APIError as e:
            print(f"OpenAI API Error: {e}")
            raise

    async def stream_generate(
        self,
        messages: List[ChatMessage],
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
        temperature: float = 0.7,
        max_tokens: int = 1024,
        **kwargs
    ) -> Iterator[ChatMessage]:
        openai_tools = [tool.to_openai_format() for tool in tools] if tools else None
        user_messages = [msg for msg in messages if msg.get("role") != "system"]

        try:
            stream = await self.client.chat.completions.create(
                model=self._model_name,
                messages=user_messages,
                temperature=temperature,
                max_tokens=max_tokens,
                tools=openai_tools,
                tool_choice=tool_choice,
                stream=True,
                **kwargs
            )
            async for chunk in stream:
                parsed_chunk = self._parse_streaming_chunk(chunk)
                if parsed_chunk:
                    yield parsed_chunk
        except openai.APIError as e:
            print(f"OpenAI API Error during streaming: {e}")
            raise

    def count_tokens(self, text: str) -> int:
        return len(self._tokenizer.encode(text))

    def get_supported_model_names(self) -> List[str]:
        # 实际生产中可能需要调用API获取最新列表或维护一个配置
        return ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-3.5-turbo"]

    def _parse_streaming_chunk(self, chunk: Any) -> Optional[ChatMessage]:
        """解析OpenAI流式响应的单个块。"""
        if not chunk.choices:
            return None

        delta = chunk.choices[0].delta
        parsed_message: ChatMessage = {"role": "assistant"} # 默认角色

        if delta.content:
            parsed_message["content"] = delta.content

        if delta.tool_calls:
            # OpenAI的tool_calls在流式输出中可能分段
            # 这里简单地将所有tool_calls归集,实际需要更复杂的逻辑来拼接
            parsed_message["tool_calls"] = []
            for tc in delta.tool_calls:
                tool_call = {"id": tc.id, "type": tc.type, "function": {"name": tc.function.name}}
                if tc.function.arguments:
                    tool_call["function"]["arguments"] = tc.function.arguments
                parsed_message["tool_calls"].append(tool_call)

        # 只有当有内容或工具调用时才返回
        if "content" in parsed_message or "tool_calls" in parsed_message:
            return parsed_message
        return None

    def _process_response(self, response: Any) -> ChatMessage:
        """处理OpenAI非流式响应对象。"""
        choice = response.choices[0]
        message = choice.message

        parsed_message: ChatMessage = {"role": message.role}
        if message.content:
            parsed_message["content"] = message.content

        if message.tool_calls:
            parsed_message["tool_calls"] = []
            for tc in message.tool_calls:
                parsed_message["tool_calls"].append({
                    "id": tc.id,
                    "type": tc.type,
                    "function": {
                        "name": tc.function.name,
                        "arguments": tc.function.arguments # arguments是字符串,需要调用方解析
                    }
                })
        return parsed_message

AnthropicProvider 实现

import anthropic
import os
from anthropic import Anthropic
from anthropic.types import MessageParam, ToolParam

class AnthropicProvider(LLMProvider):
    def __init__(self, model_name: str, api_key: Optional[str] = None):
        super().__init__(model_name, api_key or os.getenv("ANTHROPIC_API_KEY"))
        if not self._api_key:
            raise ValueError("Anthropic API Key is not provided or not found in environment variables.")
        self.client = Anthropic(api_key=self._api_key)

    async def generate(
        self,
        messages: List[ChatMessage],
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
        temperature: float = 0.7,
        max_tokens: int = 1024,
        **kwargs
    ) -> ChatMessage:
        anthropic_tools: Optional[List[ToolParam]] = None
        if tools:
            anthropic_tools = [tool.to_anthropic_format() for tool in tools]

        # Anthropic的系统消息是独立的参数
        system_message_content = next((msg["content"] for msg in messages if msg.get("role") == "system"), None)
        user_messages_for_anthropic: List[MessageParam] = []
        for msg in messages:
            if msg.get("role") == "user":
                user_messages_for_anthropic.append({"role": "user", "content": msg["content"]})
            elif msg.get("role") == "assistant" and "content" in msg:
                user_messages_for_anthropic.append({"role": "assistant", "content": msg["content"]})
            elif msg.get("role") == "assistant" and "tool_calls" in msg:
                # Anthropic的tool_use是assistant role的content的一部分
                tool_uses_content = []
                for tc in msg["tool_calls"]:
                    tool_uses_content.append({
                        "type": "tool_use",
                        "id": tc["id"],
                        "name": tc["function"]["name"],
                        "input": json.loads(tc["function"]["arguments"]) # Anthropic需要input是对象
                    })
                user_messages_for_anthropic.append({"role": "assistant", "content": tool_uses_content})
            elif msg.get("role") == "tool":
                # OpenAI的tool_message对应Anthropic的user role中的tool_result
                user_messages_for_anthropic.append({
                    "role": "user",
                    "content": [
                        {
                            "type": "tool_result",
                            "tool_use_id": msg["tool_call_id"],
                            "content": msg["content"]
                        }
                    ]
                })

        try:
            response = await self.client.messages.create(
                model=self._model_name,
                messages=user_messages_for_anthropic,
                system=system_message_content, # 系统消息作为独立参数
                temperature=temperature,
                max_tokens=max_tokens,
                tools=anthropic_tools,
                # Anthropic的tool_choice处理方式与OpenAI略有不同,需要映射
                # "auto" 对应 "auto", "none" 对应 "none"
                # 特定工具需要 {"type": "tool", "name": "tool_name"}
                tool_choice=tool_choice, # 假设tool_choice格式已适配
                **kwargs
            )
            return self._process_response(response)
        except anthropic.APIError as e:
            print(f"Anthropic API Error: {e}")
            raise

    async def stream_generate(
        self,
        messages: List[ChatMessage],
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
        temperature: float = 0.7,
        max_tokens: int = 1024,
        **kwargs
    ) -> Iterator[ChatMessage]:
        anthropic_tools: Optional[List[ToolParam]] = None
        if tools:
            anthropic_tools = [tool.to_anthropic_format() for tool in tools]

        system_message_content = next((msg["content"] for msg in messages if msg.get("role") == "system"), None)
        user_messages_for_anthropic: List[MessageParam] = []
        for msg in messages:
            if msg.get("role") == "user":
                user_messages_for_anthropic.append({"role": "user", "content": msg["content"]})
            elif msg.get("role") == "assistant" and "content" in msg:
                user_messages_for_anthropic.append({"role": "assistant", "content": msg["content"]})
            elif msg.get("role") == "assistant" and "tool_calls" in msg:
                tool_uses_content = []
                for tc in msg["tool_calls"]:
                    tool_uses_content.append({
                        "type": "tool_use",
                        "id": tc["id"],
                        "name": tc["function"]["name"],
                        "input": json.loads(tc["function"]["arguments"])
                    })
                user_messages_for_anthropic.append({"role": "assistant", "content": tool_uses_content})
            elif msg.get("role") == "tool":
                user_messages_for_anthropic.append({
                    "role": "user",
                    "content": [
                        {
                            "type": "tool_result",
                            "tool_use_id": msg["tool_call_id"],
                            "content": msg["content"]
                        }
                    ]
                })

        try:
            async with self.client.messages.stream(
                model=self._model_name,
                messages=user_messages_for_anthropic,
                system=system_message_content,
                temperature=temperature,
                max_tokens=max_tokens,
                tools=anthropic_tools,
                tool_choice=tool_choice,
                **kwargs
            ) as stream:
                async for chunk in stream:
                    parsed_chunk = self._parse_streaming_chunk(chunk)
                    if parsed_chunk:
                        yield parsed_chunk
        except anthropic.APIError as e:
            print(f"Anthropic API Error during streaming: {e}")
            raise

    def count_tokens(self, text: str) -> int:
        # Anthropic没有公开的Python tokenizer库,通常通过API调用或估算
        # 这里为简化示例,返回一个近似值,实际生产需要更精确的实现
        return self.client.count_tokens(text)

    def get_supported_model_names(self) -> List[str]:
        return ["claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307"]

    def _parse_streaming_chunk(self, chunk: Any) -> Optional[ChatMessage]:
        """解析Anthropic流式响应的单个块。"""
        # Anthropic的流式事件类型很多,需要合并
        # message_start, content_block_start, content_block_delta, content_block_stop, message_delta, message_stop

        parsed_message: ChatMessage = {"role": "assistant"}

        if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta":
            parsed_message["content"] = chunk.delta.text
        elif chunk.type == "content_block_start" and chunk.content_block.type == "tool_use":
            # 这是一个工具调用开始的信号
            tool_use_block = chunk.content_block
            parsed_message["tool_calls"] = [{
                "id": tool_use_block.id,
                "type": "function", # 统一为function类型
                "function": {
                    "name": tool_use_block.name,
                    "arguments": json.dumps(tool_use_block.input) # input是字典,需要转成字符串
                }
            }]
        elif chunk.type == "content_block_delta" and chunk.delta.type == "input_json_delta":
            # 工具调用的参数可能分段传输,这里需要更复杂的拼接逻辑
            # 为简化,假设一次性传输,或者在GraphExecutor中处理拼接
            pass # 实际需要累积json片段
        elif chunk.type == "message_delta" and chunk.delta.stop_reason:
            # 消息结束,可能包含停止原因
            parsed_message["stop_reason"] = chunk.delta.stop_reason

        if "content" in parsed_message or "tool_calls" in parsed_message or "stop_reason" in parsed_message:
            return parsed_message
        return None

    def _process_response(self, response: Any) -> ChatMessage:
        """处理Anthropic非流式响应对象。"""
        parsed_message: ChatMessage = {"role": response.role}

        full_content = []
        tool_calls = []

        for content_block in response.content:
            if content_block.type == "text":
                full_content.append(content_block.text)
            elif content_block.type == "tool_use":
                tool_calls.append({
                    "id": content_block.id,
                    "type": "function", # 统一为function类型
                    "function": {
                        "name": content_block.name,
                        "arguments": json.dumps(content_block.input) # input是字典,需要转成字符串
                    }
                })

        if full_content:
            parsed_message["content"] = "".join(full_content)
        if tool_calls:
            parsed_message["tool_calls"] = tool_calls

        return parsed_message

通过LLMProvider抽象,我们成功地将不同供应商的API细节封装起来。上层的工作流代码现在只需与这个统一的接口交互,而无需关心底层是OpenAI还是Anthropic。

图结构定义:节点与边

现在,我们有了模型无关的LLM调用能力。接下来,需要定义如何构建我们的AI工作流图。我们将定义抽象的GraphNode,以及具体的LLMNodeToolNode

GraphNode 抽象基类

所有图中的操作都将继承自这个基类。

import uuid

class GraphNode(abc.ABC):
    """
    图节点的抽象基类。
    """
    def __init__(self, node_id: Optional[str] = None):
        self.node_id = node_id if node_id else str(uuid.uuid4())
        self.inputs: List[str] = []  # 节点所需输入数据的key
        self.output_key: Optional[str] = None # 节点产生输出数据存储的key

    @abc.abstractmethod
    async def execute(self, context: Dict[str, Any], provider_selector: Any) -> Dict[str, Any]:
        """
        执行节点逻辑。
        :param context: 包含所有当前可用数据的字典。
        :param provider_selector: 用于选择LLMProvider的机制。
        :return: 节点执行后更新的context。
        """
        pass

    def __repr__(self):
        return f"<{self.__class__.__name__} id='{self.node_id}'>"

LLMNode:模型调用节点

LLMNode封装了一个LLM调用。它不关心是哪个具体的LLM提供商,只定义了它需要哪些输入来构建messages,以及它将把LLM的响应存储在哪里。

class LLMNode(GraphNode):
    """
    代表一个LLM调用的图节点。
    它使用一个提示模板和输入数据来构建消息,并调用LLM。
    """
    def __init__(
        self,
        node_id: Optional[str] = None,
        prompt_template: str = "{query}",
        input_keys: List[str] = ["query"],
        output_key: str = "llm_output",
        model_preference: Optional[str] = None, # 例如 "gpt-4o", "claude-3-opus", 或者 "high_quality"
        temperature: float = 0.7,
        max_tokens: int = 1024,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
        stream: bool = False
    ):
        super().__init__(node_id)
        self.prompt_template = prompt_template
        self.inputs = input_keys
        self.output_key = output_key
        self.model_preference = model_preference
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.tools = tools
        self.tool_choice = tool_choice
        self.stream = stream

    async def execute(self, context: Dict[str, Any], provider_selector: Any) -> Dict[str, Any]:
        # 从context中获取输入数据
        input_data = {key: context.get(key) for key in self.inputs}

        # 填充提示模板
        formatted_prompt = self.prompt_template.format(**input_data)

        # 构建消息列表 (这里简化为单个用户消息,实际可能更复杂)
        messages: List[ChatMessage] = [{"role": "user", "content": formatted_prompt}]

        # 选择LLM提供商
        llm_provider: LLMProvider = provider_selector.select_provider(self.model_preference)

        print(f"Executing LLMNode {self.node_id} with model: {llm_provider.model_name} from {llm_provider.__class__.__name__}")

        if self.stream:
            full_response_content = []
            full_tool_calls = []
            async for chunk in llm_provider.stream_generate(
                messages=messages,
                tools=self.tools,
                tool_choice=self.tool_choice,
                temperature=self.temperature,
                max_tokens=self.max_tokens
            ):
                if chunk.get("content"):
                    full_response_content.append(chunk["content"])
                    # 可以在这里实时处理流式输出,例如打印到控制台
                    # print(chunk["content"], end="", flush=True)
                if chunk.get("tool_calls"):
                    # 流式工具调用需要复杂的合并逻辑,这里简化
                    for tc in chunk["tool_calls"]:
                        # 检查是否是新的工具调用或现有调用的参数更新
                        existing_tc = next((extc for extc in full_tool_calls if extc["id"] == tc["id"]), None)
                        if existing_tc:
                            # 假设参数是字符串,进行拼接
                            if "arguments" in tc["function"] and "arguments" in existing_tc["function"]:
                                existing_tc["function"]["arguments"] += tc["function"]["arguments"]
                            else:
                                existing_tc["function"]["arguments"] = tc["function"]["arguments"]
                        else:
                            full_tool_calls.append(tc)

            final_response: ChatMessage = {"role": "assistant"}
            if full_response_content:
                final_response["content"] = "".join(full_response_content)
            if full_tool_calls:
                final_response["tool_calls"] = full_tool_calls

        else:
            final_response = await llm_provider.generate(
                messages=messages,
                tools=self.tools,
                tool_choice=self.tool_choice,
                temperature=self.temperature,
                max_tokens=self.max_tokens
            )

        # 将LLM响应存储到context
        updated_context = context.copy()
        updated_context[self.output_key] = final_response
        return updated_context

ToolNode:工具调用节点

ToolNode代表一个外部工具的执行。它接收来自LLM的工具调用指令,执行实际的工具逻辑,并将结果返回。

class BaseTool(abc.ABC):
    """所有工具的抽象基类。"""
    def __init__(self, name: str, description: str, input_schema: Dict[str, Any]):
        self.name = name
        self.description = description
        self.input_schema = input_schema
        self.tool_definition = ToolDefinition(name, description, input_schema)

    @abc.abstractmethod
    async def run(self, **kwargs) -> Any:
        """执行工具的核心逻辑。"""
        pass

# 示例工具:天气查询
class WeatherTool(BaseTool):
    def __init__(self):
        super().__init__(
            name="get_current_weather",
            description="获取指定城市当前的天气信息。",
            input_schema={
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "城市名称,例如:北京"}
                },
                "required": ["location"]
            }
        )

    async def run(self, location: str) -> str:
        # 实际这里会调用外部天气API
        print(f"Calling Weather API for {location}...")
        if location == "北京":
            return "北京当前晴朗,25°C,微风。"
        elif location == "上海":
            return "上海当前多云,28°C,有阵雨风险。"
        else:
            return f"无法获取 {location} 的天气信息。"

class ToolNode(GraphNode):
    """
    代表一个外部工具执行的图节点。
    它从上下文中获取工具调用信息,执行工具,并将结果存储回上下文。
    """
    def __init__(
        self,
        node_id: Optional[str] = None,
        tool_instance: BaseTool = None, # 具体工具实例
        tool_input_key: str = "tool_call", # 预期从context中获取工具调用的key
        output_key: str = "tool_output",
        tool_name_filter: Optional[str] = None # 如果一个节点只处理特定工具
    ):
        super().__init__(node_id)
        if not tool_instance:
            raise ValueError("ToolNode requires a tool_instance.")
        self.tool_instance = tool_instance
        self.inputs = [tool_input_key] # 接收LLM生成的工具调用信息
        self.output_key = output_key
        self.tool_name_filter = tool_name_filter or tool_instance.name

    async def execute(self, context: Dict[str, Any], provider_selector: Any) -> Dict[str, Any]:
        tool_call_info = context.get(self.inputs[0])

        if not tool_call_info:
            print(f"ToolNode {self.node_id}: No tool call info found in context['{self.inputs[0]}']. Skipping.")
            return context.copy()

        # 假设 tool_call_info 是一个列表,可能包含多个tool_calls
        # 或者是一个单个工具调用对象
        if not isinstance(tool_call_info, list):
            tool_call_info = [tool_call_info]

        tool_results = []
        for call in tool_call_info:
            # 兼容OpenAI和Anthropic的工具调用格式
            # OpenAI: {"id": "...", "type": "function", "function": {"name": "...", "arguments": "{...}"}}
            # Anthropic: {"id": "...", "type": "function", "function": {"name": "...", "arguments": "{...}"}} (已统一)

            tool_name = call["function"]["name"]
            tool_args_str = call["function"]["arguments"]
            tool_call_id = call["id"] # 用于将结果关联回LLM

            if self.tool_name_filter and tool_name != self.tool_name_filter:
                print(f"ToolNode {self.node_id}: Ignoring tool call for '{tool_name}', expecting '{self.tool_name_filter}'.")
                continue

            try:
                # 解析参数,Anthropic的arguments在_process_response中已转为JSON字符串
                tool_args = json.loads(tool_args_str)
                print(f"ToolNode {self.node_id}: Executing tool '{tool_name}' with args: {tool_args}")
                result = await self.tool_instance.run(**tool_args)
                tool_results.append({
                    "tool_call_id": tool_call_id,
                    "result": result
                })
            except Exception as e:
                print(f"ToolNode {self.node_id}: Error executing tool '{tool_name}': {e}")
                tool_results.append({
                    "tool_call_id": tool_call_id,
                    "error": str(e)
                })

        updated_context = context.copy()
        updated_context[self.output_key] = tool_results
        return updated_context

Graph

Graph类负责存储节点和它们之间的连接(边),并提供拓扑排序等功能,确保节点按正确的顺序执行。

class Graph:
    """
    表示一个AI工作流的图结构。
    """
    def __init__(self):
        self.nodes: Dict[str, GraphNode] = {}
        self.edges: Dict[str, List[str]] = {} # {source_node_id: [target_node_id, ...]}

    def add_node(self, node: GraphNode):
        if node.node_id in self.nodes:
            raise ValueError(f"Node with ID {node.node_id} already exists.")
        self.nodes[node.node_id] = node
        self.edges[node.node_id] = []

    def add_edge(self, source_node_id: str, target_node_id: str):
        if source_node_id not in self.nodes:
            raise ValueError(f"Source node {source_node_id} not found.")
        if target_node_id not in self.nodes:
            raise ValueError(f"Target node {target_node_id} not found.")
        self.edges[source_node_id].append(target_node_id)

    def topological_sort(self) -> List[GraphNode]:
        """
        对图进行拓扑排序,返回一个节点的执行顺序列表。
        这对于确定执行顺序至关重要。
        """
        in_degree: Dict[str, int] = {node_id: 0 for node_id in self.nodes}
        for source_node_id in self.edges:
            for target_node_id in self.edges[source_node_id]:
                in_degree[target_node_id] += 1

        queue: List[str] = [node_id for node_id, degree in in_degree.items() if degree == 0]
        sorted_nodes: List[GraphNode] = []

        while queue:
            node_id = queue.pop(0)
            sorted_nodes.append(self.nodes[node_id])

            for neighbor_id in self.edges[node_id]:
                in_degree[neighbor_id] -= 1
                if in_degree[neighbor_id] == 0:
                    queue.append(neighbor_id)

        if len(sorted_nodes) != len(self.nodes):
            raise ValueError("Graph contains a cycle!")
        return sorted_nodes

图编译与执行引擎

有了抽象的LLM提供商和图节点定义,现在我们需要一个“编译器”来将这些组件组合起来,并一个“执行引擎”来实际运行图。

ProviderSelector:动态选择LLM供应商

这是实现模型无关性的关键组件之一。它负责根据配置或运行时策略,为LLMNode选择一个合适的LLMProvider实例。

class ProviderSelector:
    """
    负责根据模型偏好选择合适的LLMProvider实例。
    """
    def __init__(self, providers: Dict[str, LLMProvider], default_provider_key: str):
        self._providers = providers
        if default_provider_key not in self._providers:
            raise ValueError(f"Default provider key '{default_provider_key}' not found in provided providers.")
        self._default_provider_key = default_provider_key

        # 维护一个映射:抽象模型名 -> 实际LLMProvider实例
        self._model_map: Dict[str, LLMProvider] = {}
        for key, provider in providers.items():
            for model_name in provider.get_supported_model_names():
                self._model_map[model_name] = provider
            # 也可以为每个provider定义一个“别名”
            self._model_map[key] = provider # 允许直接通过provider key访问

    def select_provider(self, model_preference: Optional[str] = None) -> LLMProvider:
        """
        根据模型偏好选择一个LLMProvider。
        :param model_preference: 首选的模型名称(例如 "gpt-4o")或供应商别名(例如 "openai")。
                                 如果为None,则使用默认供应商。
        """
        if model_preference:
            # 尝试直接匹配模型名称或供应商别名
            if model_preference in self._model_map:
                return self._model_map[model_preference]

            # 进一步的逻辑:例如,如果偏好是"high_quality",则可以遍历所有提供商,
            # 找到成本最高/性能最好的模型。这里简化为直接匹配。
            print(f"Warning: Model preference '{model_preference}' not directly matched. Falling back to default.")

        return self._providers[self._default_provider_key]

    def get_provider(self, provider_key: str) -> LLMProvider:
        """根据提供商键获取具体的LLMProvider实例。"""
        if provider_key not in self._providers:
            raise ValueError(f"Provider with key '{provider_key}' not found.")
        return self._providers[provider_key]

GraphExecutor:执行图

GraphExecutor是我们的“运行时”,它负责按照拓扑排序的顺序遍历图中的节点,管理上下文数据,并调度节点的执行。

class GraphExecutor:
    """
    负责执行AI工作流图。
    """
    def __init__(self, graph: Graph, provider_selector: ProviderSelector):
        self.graph = graph
        self.provider_selector = provider_selector

    async def execute(self, initial_context: Dict[str, Any]) -> Dict[str, Any]:
        """
        执行整个图工作流。
        :param initial_context: 初始输入数据。
        :return: 包含所有节点输出的最终上下文。
        """
        current_context = initial_context.copy()

        try:
            sorted_nodes = self.graph.topological_sort()
        except ValueError as e:
            print(f"Graph execution failed: {e}")
            return current_context

        print(f"Executing graph with {len(sorted_nodes)} nodes in order: {[node.node_id for node in sorted_nodes]}")

        for node in sorted_nodes:
            print(f"--> Executing node: {node}")
            try:
                # 检查节点所需输入是否都已在context中
                missing_inputs = [key for key in node.inputs if key not in current_context]
                if missing_inputs:
                    raise ValueError(f"Node {node.node_id} missing required inputs: {missing_inputs}")

                # 执行节点,并更新上下文
                current_context = await node.execute(current_context, self.provider_selector)
                print(f"<-- Node {node.node_id} executed. Output key: {node.output_key}")
                # print(f"Current Context Keys: {list(current_context.keys())}") # 调试用
            except Exception as e:
                print(f"Error executing node {node.node_id}: {e}")
                # 可以在这里实现错误处理、重试逻辑
                raise # 向上抛出异常,或者根据策略进行恢复

        return current_context

代码示例:一个简单的RAG流程

现在,我们来构建一个简单的RAG流程,并演示如何通过配置切换LLM提供商。这个RAG流程将:

  1. 接收用户查询。
  2. 使用LLM(QueryRewriteNode)重写查询以提高检索质量。
  3. 使用一个假想的SearchTool检索文档。
  4. 使用LLM(AnswerGenerationNode)结合原始查询、重写查询和检索到的文档生成最终答案。
import asyncio
import os

# 假设环境变量已设置
# os.environ["OPENAI_API_KEY"] = "sk-..."
# os.environ["ANTHROPIC_API_KEY"] = "sk-ant-..."

# 1. 定义工具 (SearchTool)
class SearchTool(BaseTool):
    def __init__(self):
        super().__init__(
            name="search_knowledge_base",
            description="在知识库中搜索相关信息。",
            input_schema={
                "type": "object",
                "properties": {
                    "query": {"type": "string", "description": "用于搜索的查询语句"}
                },
                "required": ["query"]
            }
        )

    async def run(self, query: str) -> str:
        print(f"Searching knowledge base for: '{query}'...")
        # 模拟搜索结果
        if "人工智能" in query:
            return "检索结果:人工智能(AI)是计算机科学的一个分支,旨在创建能够执行通常需要人类智能的任务的机器。"
        elif "量子计算" in query:
            return "检索结果:量子计算是一种利用量子力学现象(如叠加和纠缠)来执行计算的新型计算范式。"
        else:
            return "检索结果:未找到相关信息。"

# 2. 构建RAG图的节点
class UserQueryNode(GraphNode):
    def __init__(self, query: str, node_id: str = "user_query_node", output_key: str = "user_query"):
        super().__init__(node_id)
        self._query = query
        self.output_key = output_key

    async def execute(self, context: Dict[str, Any], provider_selector: Any) -> Dict[str, Any]:
        updated_context = context.copy()
        updated_context[self.output_key] = self._query
        return updated_context

class QueryRewriteNode(LLMNode):
    def __init__(self, node_id: str = "query_rewrite_node", model_preference: Optional[str] = None):
        super().__init__(
            node_id=node_id,
            prompt_template="用户问题:{user_query}n请将上述用户问题重写为一个更适合在搜索引擎或知识库中检索的查询。只返回重写后的查询。",
            input_keys=["user_query"],
            output_key="rewritten_query_llm_output",
            model_preference=model_preference,
            max_tokens=100,
            temperature=0.0
        )

    async def execute(self, context: Dict[str, Any], provider_selector: Any) -> Dict[str, Any]:
        updated_context = await super().execute(context, provider_selector)
        # 从LLM输出中提取实际的重写查询文本
        llm_response = updated_context[self.output_key]
        if llm_response.get("content"):
            updated_context["rewritten_query"] = llm_response["content"].strip()
        else:
            updated_context["rewritten_query"] = context["user_query"] # 降级为原始查询
        return updated_context

class SearchToolCallingLLMNode(LLMNode):
    def __init__(self, search_tool: BaseTool, node_id: str = "search_llm_node", model_preference: Optional[str] = None):
        super().__init__(
            node_id=node_id,
            prompt_template="根据用户问题:{user_query},你需要调用`search_knowledge_base`工具来获取信息吗?如果需要,请调用。如果没有,请直接回答。",
            input_keys=["user_query"],
            output_key="llm_tool_call_output", # 这里会存储LLM的工具调用响应
            model_preference=model_preference,
            tools=[search_tool.tool_definition],
            tool_choice="auto", # 允许LLM自动决定是否调用工具
            temperature=0.0,
            max_tokens=500
        )

class AnswerGenerationNode(LLMNode):
    def __init__(self, node_id: str = "answer_generation_node", model_preference: Optional[str] = None):
        super().__init__(
            node_id=node_id,
            prompt_template=(
                "用户问题:{user_query}n"
                "重写查询:{rewritten_query}n"
                "检索到的信息:{search_result}nn"
                "请根据以上信息,简洁地回答用户问题。如果检索信息不足以回答,请说明。"
            ),
            input_keys=["user_query", "rewritten_query", "search_result"],
            output_key="final_answer_llm_output",
            model_preference=model_preference,
            temperature=0.2,
            max_tokens=500,
            stream=True
        )

    async def execute(self, context: Dict[str, Any], provider_selector: Any) -> Dict[str, Any]:
        updated_context = await super().execute(context, provider_selector)
        # 从LLM输出中提取最终答案文本
        llm_response = updated_context[self.output_key]
        if llm_response.get("content"):
            updated_context["final_answer"] = llm_response["content"].strip()
        else:
            updated_context["final_answer"] = "未能生成答案。"
        return updated_context

# 3. 组装图
async def run_rag_pipeline(user_input: str, llm_provider_preference: str = "openai") -> Dict[str, Any]:
    # 初始化LLM提供商
    openai_provider = OpenAIProvider(model_name="gpt-4o")
    anthropic_provider = AnthropicProvider(model_name="claude-3-opus-20240229")

    provider_selector = ProviderSelector(
        providers={
            "openai": openai_provider,
            "anthropic": anthropic_provider
        },
        default_provider_key=llm_provider_preference # 动态切换默认供应商
    )

    # 初始化工具
    search_tool_instance = SearchTool()

    # 创建节点
    user_node = UserQueryNode(query=user_input)
    # query_rewrite_node = QueryRewriteNode(model_preference=llm_provider_preference)
    # 模拟llm直接进行工具调用
    search_llm_node = SearchToolCallingLLMNode(search_tool=search_tool_instance, model_preference=llm_provider_preference)
    search_tool_node = ToolNode(
        tool_instance=search_tool_instance,
        tool_input_key="llm_tool_call_output", # 接收来自LLM的工具调用指令
        output_key="search_tool_raw_output",
        tool_name_filter=search_tool_instance.name
    )

    # 转换工具输出为LLM可用的字符串
    class FormatSearchResultNode(GraphNode):
        def __init__(self, node_id: str = "format_search_result_node", input_key: str = "search_tool_raw_output", output_key: str = "search_result"):
            super().__init__(node_id)
            self.inputs = [input_key]
            self.output_key = output_key

        async def execute(self, context: Dict[str, Any], provider_selector: Any) -> Dict[str, Any]:
            raw_output = context.get(self.inputs[0])
            formatted_output = "无检索结果。"
            if raw_output and isinstance(raw_output, list):
                results = []
                for item in raw_output:
                    if "result" in item:
                        results.append(item["result"])
                    elif "error" in item:
                        results.append(f"工具执行错误: {item['error']}")
                formatted_output = "n".join(results)

            updated_context = context.copy()
            updated_context[self.output_key] = formatted_output
            return updated_context

    format_search_node = FormatSearchResultNode()

    answer_node = AnswerGenerationNode(model_preference=llm_provider_preference)

    # 构建图
    rag_graph = Graph()
    rag_graph.add_node(user_node)
    # rag_graph.add_node(query_rewrite_node) # 移除重写节点,直接让LLM进行工具调用
    rag_graph.add_node(search_llm_node)
    rag_graph.add_node(search_tool_node)
    rag_graph.add_node(format_search_node)
    rag_graph.add_node(answer_node)

    # 定义边 (数据流)
    rag_graph.add_edge(user_node.node_id, search_llm_node.node_id) # 用户查询作为LLM的输入
    rag_graph.add_edge(search_llm_node.node_id, search_tool_node.node_id) # LLM的工具调用作为ToolNode的输入
    rag_graph.add_edge(search_tool_node.node_id, format_search_node.node_id) # 工具原始输出作为格式化节点的输入
    rag_graph.add_edge(format_search_node.node_id, answer_node.node_id) # 格式化后的搜索结果作为答案生成LLM的输入
    rag_graph.add_edge(user_node.node_id, answer_node.node_id) # 原始查询也作为答案生成LLM的输入
    # 为了简化,这里AnswerGenerationNode的prompt_template中直接使用了`rewritten_query`,
    # 但由于我们跳过了QueryRewriteNode,`rewritten_query`在context中可能不存在。
    # 实际应调整AnswerGenerationNode的input_keys或prompt_template来反映跳过重写。
    # 这里我们为演示Provider切换,简单地将user_query映射到rewritten_query
    # 或者修改AnswerGenerationNode的prompt_template
    # 暂时把AnswerGenerationNode的prompt_template的rewritten_query替换成user_query
    answer_node.prompt_template = (
                "用户问题:{user_query}n"
                "检索到的信息:{search_result}nn"
                "请根据以上信息,简洁地回答用户问题。如果检索信息不足以回答,请说明。"
            )
    answer_node.inputs = ["user_query", "search_result"]

    # 执行图
    executor = GraphExecutor(rag_graph, provider_selector)
    final_context = await executor.execute(initial_context={})

    return final_context

# 运行示例
async def main():
    print("--- Running with OpenAI Provider ---")
    openai_result = await run_rag_pipeline(
        user_input="什么是人工智能?",
        llm_provider_preference="openai"
    )
    print("nFinal Answer (OpenAI):")
    print(openai_result.get("final_answer"))
    print("-" * 50)

    print("n--- Running with Anthropic Provider ---")
    anthropic_result = await run_rag_pipeline(
        user_input="什么是量子计算?",
        llm_provider_preference="anthropic"
    )
    print("nFinal Answer (Anthropic):")
    print(anthropic_result.get("final_answer"))
    print("-" * 50)

    print("n--- Running with OpenAI Provider (again, different query) ---")
    openai_result_2 = await run_rag_pipeline(
        user_input="告诉我关于宇宙大爆炸的理论。",
        llm_provider_preference="openai"
    )
    print("nFinal Answer (OpenAI, again):")
    print(openai_result_2.get("final_answer"))
    print("-" * 50)

if __name__ == "__main__":
    asyncio.run(main())

代码解释:

  1. LLMProviderToolDefinition 提供了与OpenAI和Anthropic API交互的统一接口,并抽象了工具的定义。
  2. GraphNode 及其子类:
    • UserQueryNode:简单地将用户输入放入上下文。
    • SearchToolCallingLLMNode:这是一个LLM节点,它的主要任务是决定是否调用search_knowledge_base工具,并将工具调用指令输出。
    • ToolNode:负责接收并执行具体的SearchTool实例。
    • FormatSearchResultNode:将工具的原始输出格式化成LLM更容易理解的字符串。
    • AnswerGenerationNode:结合所有可用信息,生成最终答案。
  3. ProviderSelectorrun_rag_pipeline函数中初始化,并通过llm_provider_preference参数动态选择使用OpenAI或Anthropic。
  4. GraphGraphExecutor run_rag_pipeline函数构建了图并使用GraphExecutor来执行它。拓扑排序确保了节点按正确的依赖顺序执行。

通过这个例子,我们看到,在Graph的定义和GraphExecutor的执行逻辑中,我们从未直接调用openai.ChatCompletion.createanthropic.messages.create。所有LLM交互都通过LLMProvider接口进行,而具体使用哪个提供商则由ProviderSelector在运行时根据llm_provider_preference参数动态决定。这正是“模型无关图编译”的核心体现。

高级考量与最佳实践

构建一个健壮、可扩展的模型无关图编译系统,还需要考虑以下高级方面:

1. 可观测性 (Observability)

  • 统一日志: 所有LLM调用、工具执行和节点状态变化都应记录下来,使用统一的日志格式,便于追溯问题。
  • 分布式跟踪 (Distributed Tracing): 使用OpenTelemetry等标准,为每个图执行生成一个全局追踪ID,并将LLM调用、工具调用等子操作作为Span。这对于理解复杂工作流的性能瓶颈和失败原因至关重要。
  • 指标 (Metrics): 收集关键指标,如每个节点的执行时间、LLM的令牌使用量、API调用成功率、缓存命中率等,用于性能监控和成本分析。

2. 缓存策略

  • LLM调用缓存: 对于重复的或确定性强的LLM调用(例如,重写查询、少量信息提取),可以缓存其结果,避免不必要的API调用,降低成本并加快响应速度。
  • 工具调用缓存: 外部工具(如数据库查询、API调用)的结果也可以缓存。
  • 缓存失效: 考虑缓存的生命周期和失效策略。

3. 成本管理与优化

  • 动态模型选择: ProviderSelector可以变得更智能。例如,对于简单的任务使用便宜的模型(如gpt-4o-mini, claude-3-haiku),对于复杂任务使用高质量模型。可以基于输入长度、任务类型或历史性能数据来动态决策。
  • 并行化: 如果图中存在独立的节点,GraphExecutor可以并行执行它们,缩短总执行时间。
  • 令牌使用量监控: 精确计算每个请求的令牌数,并设置预算或告警。

4. 安全性

  • API Key管理: 永远不要将API Key硬编码在代码中,应通过环境变量、秘密管理服务(如Vault, AWS Secrets Manager)安全地注入。
  • 输入/输出过滤: 对用户输入进行敏感信息过滤和安全检查,防止注入攻击(Prompt Injection)。对LLM输出进行过滤,避免返回不当内容。
  • 权限控制: 限制LLM访问外部工具的权限,只提供必要的工具和功能。

5. 版本控制

  • 图定义版本: 随着业务逻辑的变化,图的结构会不断演进。需要对图的定义进行版本控制,确保部署和回滚的可靠性。
  • 工具版本: 外部工具的API和行为也可能发生变化,需要管理工具的版本。

6. 扩展性

  • 新LLM供应商: 我们的LLMProvider接口设计允许轻松添加新的LLM供应商,只需实现抽象方法即可。
  • 新节点类型: 可以轻松添加新的GraphNode子类,例如ConditionalNode(根据条件分支)、LoopNode(循环执行子图)或HumanInLoopNode(需要人工确认的节点)。
  • 声明式图定义: 考虑使用YAML或JSON等声明式语言来定义图结构,而不是纯Python代码。这可以提高可读性、可维护性,并允许非开发人员(如AI工程师、产品经理)更容易地理解和修改工作流。然后,一个“图加载器”可以将这些声明式定义解析为我们的Graph对象。

展望:未来趋势

模型无关图编译代表了构建弹性、高效AI系统的未来方向。随着LLM生态系统的不断成熟和标准化,我们可以期待:

  1. 更标准化的LLM API: 供应商之间可能会趋向于采用更统一的API规范,减少我们抽象层的复杂性。
  2. 更强大的图编排工具: 出现更多开箱即用的框架和平台,提供可视化的图构建界面、内置的可观测性、以及更高级的调度和优化功能。
  3. AI与传统编程范式的融合: AI工作流将更紧密地融入到现有的软件开发生命周期中,成为应用程序不可或缺的一部分,并利用传统的工程实践(如测试、CI/CD、版本控制)。

通过拥抱模型无关的图编译范式,我们能够构建出更具适应性、更具创新潜力的AI应用,真正解锁大型语言模型的全部力量。这是一个挑战与机遇并存的领域,值得每一位编程专家深入探索和实践。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注