解析 ‘Model-agnostic Graph Compiling’:如何编写一套逻辑,使其能无缝在 GPT-4o、Claude 3.5 和 Gemini 1.5 间切换?

各位技术同仁,下午好!

今天,我们齐聚一堂,探讨一个在当前AI浪潮中至关重要的话题:如何在大型语言模型(LLM)的异构生态中,构建一套灵活、健壮且高效的系统。具体来说,我们将深入剖析“模型无关图编译”(Model-agnostic Graph Compiling)这一理念,并着重讲解如何设计并实现一套逻辑,使其能够在这三大主流模型——GPT-4o、Claude 3.5 和 Gemini 1.5 之间进行无缝切换。

随着LLM技术的飞速发展,我们正面临一个既充满机遇又充满挑战的局面。一方面,各类模型在能力、成本、延迟、甚至偏好上都展现出独特的优势;另一方面,这种多样性也给开发者带来了巨大的集成和管理负担。我们的目标,正是要跨越这些模型间的藩篱,构建一个统一的、智能的LLM应用层。

I. 引言:大型语言模型与“模型无关图编译”的时代机遇

过去几年,大型语言模型(LLM)从研究实验室的深处一跃成为改变世界的通用技术。从内容创作、代码生成到复杂的问题解决,LLM的能力边界正不断拓展。然而,这种能力的爆发也伴随着一个显著的挑战:异构性。

我们现在拥有OpenAI的GPT系列(如最新的GPT-4o)、Anthropic的Claude系列(如Claude 3.5 Sonnet)、Google的Gemini系列(如Gemini 1.5 Pro),以及更多开源模型如Llama 3等。这些模型在以下方面存在显著差异:

  • API 接口与参数: 不同的模型提供商有各自的API终端、请求参数命名和结构。
  • 性能与质量: 在特定任务上,不同模型的表现可能天差地别,例如代码生成、创意写作或逻辑推理。
  • 成本: token 价格因模型、输入/输出比例而异,这直接影响应用的运行成本。
  • 延迟与吞吐量: API 响应时间和服务稳定性各有不同。
  • 功能集: 如多模态支持、工具调用(Function Calling)的实现方式、上下文窗口大小等。
  • 内容策略与安全: 各家模型对特定内容(如敏感信息、偏见)的过滤策略不同。

这种异构性使得开发者在构建LLM应用时,往往需要针对特定模型进行深度绑定。一旦想切换模型,或者同时利用多个模型的优势,就需要进行大量的代码修改和适配工作。这无疑降低了开发效率,增加了维护成本,并限制了应用的灵活性和韧性。

“模型无关图编译” 正是为了解决这一痛点而提出的一种架构思想。它旨在提供一个抽象层,将复杂的LLM交互、任务编排和数据流转化为一个可执行的“图”。这个图的节点代表各种操作(例如LLM调用、数据转换、条件判断、工具执行),边则表示数据和控制流。而“模型无关”则意味着,图中的LLM调用节点,能够在不修改图结构和核心业务逻辑的前提下,动态地切换底层使用的LLM模型。

其核心价值在于:

  1. 灵活性与韧性: 轻松更换或组合模型,以应对模型更新、价格波动或性能变化。
  2. 成本优化与性能均衡: 根据实时需求(如成本敏感性、延迟要求),智能选择最合适的模型。
  3. 创新与实验: 降低尝试新模型或新组合的门槛,加速产品迭代。
  4. 可维护性: 将业务逻辑与底层模型细节解耦,提升代码的可读性和可维护性。

接下来,我们将逐步深入,从抽象层设计到具体实现,探讨如何将这一愿景变为现实。

II. 核心概念解析:模型无关图编译

要理解“模型无关图编译”,我们首先需要拆解这两个核心词汇。

什么是“图编译”?

在计算机科学领域,“图编译”的概念并不新鲜。传统的编译器会将源代码转换为抽象语法树(AST),然后进行各种优化,例如控制流图(CFG)和数据流图(DFG)的构建,最终生成机器码。这里的“图”是程序结构和执行逻辑的抽象表示。

在LLM的工作流中,“图”同样是一种强大的抽象工具。它将一个复杂的端到端任务分解为一系列相互依赖的离散步骤。

  • 节点(Nodes): 代表工作流中的一个具体操作或任务。
    • LLM 调用节点: 向 LLM 发送请求并接收响应。
    • 工具调用节点: 执行外部工具(如数据库查询、API 调用、代码执行)。
    • 数据转换节点: 对数据进行解析、格式化、筛选等操作。
    • 条件判断节点: 根据某个条件决定后续执行路径。
    • 并行执行节点: 多个任务可以同时进行。
    • 人工审核节点: 需要人工介入进行决策或验证。
  • 边(Edges): 代表节点之间的关系,通常是数据流或控制流。
    • 数据流: 前一个节点的输出作为后一个节点的输入。
    • 控制流: 决定了节点的执行顺序或条件分支。

示例:一个简单的内容生成与审核工作流

假设我们要构建一个系统,用于根据用户需求生成文章大纲,然后根据大纲生成文章初稿,最后进行人工审核。

这个工作流可以被表示为一个图:

  1. 节点 1: GenerateOutline (LLM 调用)
    • 输入:用户需求 (Prompt)
    • 输出:文章大纲 (JSON 格式)
    • 依赖:无
  2. 节点 2: GenerateDraft (LLM 调用)
    • 输入:文章大纲 (来自节点 1 的输出)
    • 输出:文章初稿 (Markdown 格式)
    • 依赖:GenerateOutline
  3. 节点 3: HumanReview (人工操作)
    • 输入:文章初稿 (来自节点 2 的输出)
    • 输出:审核结果 (通过/拒绝/修改意见)
    • 依赖:GenerateDraft
  4. 节点 4: PublishOrRevise (条件判断/分支)
    • 输入:审核结果 (来自节点 3 的输出)
    • 如果“通过”,则转到 Publish 节点。
    • 如果“拒绝/修改”,则转到 GenerateDraft 节点(带修改意见作为输入,形成一个循环)。
    • 依赖:HumanReview

这种图的表示方式,使得我们可以清晰地定义复杂的工作流,并将其与具体的实现细节解耦。
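为了更直观,上述节点的依赖关系可以用一个极简的邻接表勾勒出来(示意性草图,节点名沿用上文;Publish 节点在上文中只是隐含提到):

```python
# 示意:用邻接表表示上文的内容生成与审核工作流
# 键为节点,值为其下游节点列表;PublishOrRevise 的条件分支体现为两条出边
workflow_graph = {
    "GenerateOutline": ["GenerateDraft"],
    "GenerateDraft": ["HumanReview"],
    "HumanReview": ["PublishOrRevise"],
    "PublishOrRevise": ["Publish", "GenerateDraft"],  # 通过 / 退回重写(形成循环)
    "Publish": [],
}

def upstream(graph, target):
    """返回 target 的所有直接上游(依赖)节点。"""
    return [n for n, outs in graph.items() if target in outs]
```

例如 `upstream(workflow_graph, "GenerateDraft")` 会同时给出 `GenerateOutline` 和 `PublishOrRevise` 两个上游,正好对应"首次生成"与"退回重写"两条路径。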

什么是“模型无关”?

“模型无关”是指我们的系统设计不应该与任何特定的LLM模型(如GPT-4o、Claude 3.5、Gemini 1.5)紧密耦合。这意味着:

  • API 兼容性: 我们的代码不直接调用特定模型的 SDK,而是通过一个统一的抽象接口进行交互。
  • 参数一致性: 即使底层模型有不同的参数命名或取值范围,上层应用也应使用一套标准化的参数。
  • 响应标准化: 无论哪个模型响应,其输出都应被解析为统一的内部数据结构。
  • 行为抽象: 即使模型在处理工具调用或多模态输入时有不同的约定,我们的抽象层也应提供一致的接口。

举例来说,GPT系列使用 messages 数组来构建对话历史,每条消息包含 role 和 content。Claude 3.5 同样使用 messages 数组,但 system 指令不放在 messages 中,而是作为独立的顶层 system 参数传入,消息角色只有 user 和 assistant。Gemini 1.5 则用 contents 数组表示对话,每条内容包含 role(user/model)和 parts 数组。工具调用方面,OpenAI 在响应中返回 tool_calls,Claude 返回 tool_use 内容块,Gemini 返回包含 functionCall 的 part。这些差异都需要被抽象和抹平。
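下面用一段小示例勾勒同一条 system + user 对话在三家 API 请求中的大致形状(字段名按各家公开文档的常见约定书写,仅作示意,细节以官方文档为准):

```python
# 示意:同一段对话在三种 API 请求体中的大致形状
conversation = [
    {"role": "system", "content": "你是一名技术编辑。"},
    {"role": "user", "content": "请润色这段文字。"},
]

def to_openai(msgs):
    # OpenAI:system 消息直接放在 messages 数组里
    return {"messages": msgs}

def to_anthropic(msgs):
    # Anthropic:system 指令是顶层参数,messages 中只保留 user/assistant
    system = "\n".join(m["content"] for m in msgs if m["role"] == "system")
    return {"system": system,
            "messages": [m for m in msgs if m["role"] != "system"]}

def to_gemini(msgs):
    # Gemini:内容放在 parts 里,assistant 对应的角色名是 "model",
    # system 指令通常走独立的 system_instruction 字段(此处仅示意)
    role_map = {"user": "user", "assistant": "model"}
    return {
        "system_instruction": {"parts": [{"text": m["content"]}
                                         for m in msgs if m["role"] == "system"]},
        "contents": [{"role": role_map[m["role"]], "parts": [{"text": m["content"]}]}
                     for m in msgs if m["role"] in role_map],
    }
```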

为什么需要“模型无关图编译”?

  1. 应对模型快速迭代: LLM领域发展极快,新模型层出不穷。模型无关的设计允许我们快速集成新模型,而无需重写大量业务逻辑。
  2. 实现最佳性能与成本: 在特定任务上,某个模型可能更优或更经济。例如,内容总结可能Claude更擅长,代码生成可能GPT-4o更强,而低成本的文本提取可能Gemini 1.5更划算。通过动态切换,我们可以为每个任务选择“最佳模型”。
  3. 提高系统韧性: 当某个模型的API出现故障、达到速率限制或暂时不可用时,系统可以自动回退到其他可用模型,保证服务的连续性。
  4. 促进A/B测试与实验: 轻松比较不同模型在真实业务场景中的表现,收集数据,指导优化。
  5. 简化复杂工作流: 图表示法提供了一种直观、模块化的方式来设计、可视化和调试复杂的LLM应用。

III. 构建模型无关层的基石:抽象与标准化

要实现模型无关,核心在于建立一个健壮的抽象层。这个抽象层将充当我们应用程序与具体LLM提供商之间的“翻译官”和“协调员”。

1. API 抽象层设计

我们将定义一个统一的接口,所有具体的LLM提供商实现都必须遵循这个接口。这确保了上层业务逻辑始终以相同的方式与LLM交互,无论底层是哪个模型。

统一接口定义

我们可以定义一个抽象基类或协议,包含核心的LLM操作,例如:

  • generate_text(prompt: str, **kwargs) -> str: 简单的文本生成。
  • chat_completion(messages: List[Dict], **kwargs) -> Dict: 对话式生成,支持消息历史。
  • embed(texts: List[str], **kwargs) -> List[List[float]]: 文本嵌入。
  • stream_chat_completion(...): 流式对话生成。
  • supports_tool_calling() -> bool: 查询是否支持工具调用。
  • supports_multimodal() -> bool: 查询是否支持多模态输入。

请求参数标准化

不同的模型有不同的参数,例如 temperature (温度), top_p, max_tokens (最大生成token数), stop_sequences (停止序列)。我们的抽象层应该定义一套统一的参数名称,并负责将其映射到具体模型的参数。

| 统一参数名 | GPT-4o 参数名 | Claude 3.5 参数名 | Gemini 1.5 参数名 | 描述 |
| --- | --- | --- | --- | --- |
| temperature | temperature | temperature | temperature | 控制生成文本的随机性 |
| max_tokens | max_tokens | max_tokens | max_output_tokens | 生成文本的最大 token 数 |
| top_p | top_p | top_p | top_p | 控制核采样 |
| stop_sequences | stop | stop_sequences | stop_sequences | 遇到这些字符串时停止生成 |
| model_name | model | model | model | 具体的模型名称 |
| tools | tools | tools | tools | 工具定义(Function Calling) |
| tool_choice | tool_choice | tool_choice | tool_config | 工具选择策略 |
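上表的映射可以落成一个很小的转换函数(草图;PARAM_MAP 为本文虚构的映射表,参数名以各家文档为准):

```python
# 示意:把统一参数名映射到各提供商的参数名
PARAM_MAP = {
    "openai": {"max_tokens": "max_tokens",        "stop_sequences": "stop"},
    "claude": {"max_tokens": "max_tokens",        "stop_sequences": "stop_sequences"},
    "gemini": {"max_tokens": "max_output_tokens", "stop_sequences": "stop_sequences"},
}

def map_params(provider: str, unified: dict) -> dict:
    mapping = PARAM_MAP[provider]
    # 未出现在映射表中的参数(如 temperature、top_p)按原名透传
    return {mapping.get(k, k): v for k, v in unified.items()}
```

例如 `map_params("gemini", {"temperature": 0.7, "max_tokens": 1024})` 会把 `max_tokens` 改写为 `max_output_tokens`,其余参数原样透传。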

响应格式标准化

无论哪个模型返回结果,我们都应该将其解析成一个统一的内部数据结构。这对于后续的图节点处理至关重要。

例如,对于对话完成:

# 统一的响应结构
from typing import Any, Dict, List, Optional

class LLMResponse:
    def __init__(self, content: str, role: str = "assistant", tool_calls: Optional[List[Dict]] = None,
                 finish_reason: Optional[str] = None, usage: Optional[Dict] = None, raw_response: Optional[Any] = None):
        self.content = content
        self.role = role
        self.tool_calls = tool_calls if tool_calls is not None else []
        self.finish_reason = finish_reason  # e.g., "stop", "length", "tool_calls"
        self.usage = usage  # e.g., {"input_tokens": 10, "output_tokens": 20}
        self.raw_response = raw_response  # 原始API响应,用于调试

    def __str__(self):
        return f"Role: {self.role}\nContent: {self.content}\nTool Calls: {self.tool_calls}"

错误处理统一

不同API的错误码和异常类型不同。抽象层应捕获这些原始异常,并抛出我们自己定义的一套统一异常,例如 LLMServiceUnavailableError, LLMRateLimitExceededError, LLMInvalidRequestError 等。

class LLMServiceError(Exception):
    """Base exception for LLM service errors."""
    pass

class LLMServiceUnavailableError(LLMServiceError):
    """Raised when the LLM service is down or unreachable."""
    pass

class LLMRateLimitExceededError(LLMServiceError):
    """Raised when rate limit is exceeded."""
    pass

class LLMInvalidRequestError(LLMServiceError):
    """Raised when the LLM request is invalid."""
    pass

# ... 其他错误类型

2. 输入/输出格式标准化

Prompt 工程的挑战

尽管我们追求模型无关,但不同的模型在理解Prompt方面可能存在细微差异。例如,有些模型对 system 角色有更好的支持,有些则更倾向于将系统指令融入 user 消息。在实践中,可能需要为不同模型维护略有不同的Prompt模板,或者使用更通用的Prompt策略。
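一种常见做法是维护一个按模型名检索、带默认回退的 Prompt 模板注册表(示意草图,模板内容为虚构示例):

```python
# 示意:按模型名检索 Prompt 模板,未命中时回退到默认模板
PROMPT_TEMPLATES = {
    "default": "指令:{instruction}\n输入:{text}",
    # 假设:为 Claude 维护一份带 XML 标签风格的变体模板
    "claude-3-5-sonnet-20240620": "<task>{instruction}</task>\n<input>{text}</input>",
}

def render_prompt(model_name: str, **kwargs) -> str:
    template = PROMPT_TEMPLATES.get(model_name, PROMPT_TEMPLATES["default"])
    return template.format(**kwargs)
```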

消息结构标准化

一个通用的消息结构对于对话式LLM至关重要。我们可以采用 OpenAI 风格的消息格式作为内部标准,因为它已经被广泛接受:

# 标准化消息结构
from typing import Dict, List, Optional, Union

class Message:
    def __init__(self, role: str, content: Union[str, List[Dict]], name: Optional[str] = None, tool_call_id: Optional[str] = None):
        """
        Args:
            role: "system", "user", "assistant", "tool"
            content: str for text, List[Dict] for multimodal (e.g., image_url)
            name: Optional name for tool calls
            tool_call_id: Optional ID for tool responses
        """
        self.role = role
        self.content = content
        self.name = name
        self.tool_call_id = tool_call_id

    def to_openai_format(self) -> Dict:
        msg = {"role": self.role, "content": self.content}
        if self.name:
            msg["name"] = self.name
        if self.tool_call_id:  # For tool responses
            msg["tool_call_id"] = self.tool_call_id
        return msg

    def to_claude_format(self) -> Dict:
        # Claude 的 system 指令应作为顶层 system 参数单独传入;
        # 这里演示一种退化方案:把 system 内容包进 user 消息(仅示意)
        if self.role == "system":
            return {"role": "user", "content": f"<system_instruction>{self.content}</system_instruction>"}
        return {"role": self.role, "content": self.content}

    def to_gemini_format(self) -> Dict:
        parts = []
        if isinstance(self.content, str):
            parts.append({"text": self.content})
        elif isinstance(self.content, list):
            for item in self.content:
                if item["type"] == "text":
                    parts.append({"text": item["text"]})
                elif item["type"] == "image_url":
                    # Gemini 通常需要 inline_data(base64)或 file_data,而非直接的 URL
                    parts.append({"inline_data": {"mime_type": "image/png", "data": item["image_url"]}})

        # Gemini 的角色名是 "user" / "model",assistant 需要映射为 model;
        # 工具调用的映射更复杂,见"工具调用"一节
        role = "model" if self.role == "assistant" else self.role
        return {"role": role, "parts": parts}

工具调用(Function Calling)的统一处理

工具调用是现代LLM的关键功能之一。然而,不同模型实现这一功能的方式有所差异。

  • OpenAI (GPT-4o): 使用 tools 参数定义函数 schema,模型在响应中返回 tool_calls 结构。
  • Anthropic (Claude 3.5): 也使用 tools 参数定义函数 schema,模型返回 tool_use 内容块。
  • Google (Gemini 1.5): 使用 tools 参数定义函数 schema,模型返回包含 functionCall 的 part,其中含函数名和参数。

我们的抽象层需要将这些不同格式统一起来。

# 标准化工具定义
class Tool:
    def __init__(self, name: str, description: str, parameters: Dict):
        self.name = name
        self.description = description
        self.parameters = parameters # JSON Schema format

    def to_openai_format(self) -> Dict:
        return {"type": "function", "function": {"name": self.name, "description": self.description, "parameters": self.parameters}}

    def to_claude_format(self) -> Dict:
        return {"input_schema": self.parameters, "name": self.name, "description": self.description}

    def to_gemini_format(self) -> Dict:
        return {"function_declarations": [{"name": self.name, "description": self.description, "parameters": self.parameters}]}

# 标准化工具调用结果
class ToolCall:
    def __init__(self, id: str, name: str, arguments: Dict):
        self.id = id
        self.name = name
        self.arguments = arguments

当模型响应包含工具调用时,无论原始API返回的是 tool_calls、tool_use 还是 functionCall,我们都应该将其解析为 List[ToolCall] 这样的统一结构。
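这一解析过程可以用三个小函数勾勒(原始响应的字段形状按各家文档的常见约定书写,仅作示意;为便于演示,这里直接返回字典而非上文的 ToolCall 对象):

```python
import json

# 示意:把三家 API 的工具调用响应解析成统一的 {"id", "name", "arguments"} 结构

def parse_openai_tool_calls(message: dict) -> list:
    # OpenAI 的 arguments 是 JSON 字符串,需要反序列化
    return [{"id": tc["id"],
             "name": tc["function"]["name"],
             "arguments": json.loads(tc["function"]["arguments"])}
            for tc in message.get("tool_calls", [])]

def parse_claude_tool_use(content_blocks: list) -> list:
    # Claude 的 input 已经是对象
    return [{"id": b["id"], "name": b["name"], "arguments": b["input"]}
            for b in content_blocks if b.get("type") == "tool_use"]

def parse_gemini_function_calls(parts: list) -> list:
    # Gemini 的 functionCall 不带调用 ID
    return [{"id": None,
             "name": p["functionCall"]["name"],
             "arguments": p["functionCall"]["args"]}
            for p in parts if "functionCall" in p]
```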

输出解析与验证

LLM的输出有时会偏离预期格式(例如,期望JSON但返回了纯文本)。我们可以利用 JSON Schema 或 Pydantic 等工具来定义期望的输出结构,并在抽象层进行验证和解析。如果输出不符合预期,可以触发重试或错误处理流程。
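验证加重试的骨架可以不依赖具体库来示意(下面用标准库 json 加简单的字段检查代替 Pydantic;call_llm、required_keys 均为示例命名):

```python
import json

def parse_with_retry(call_llm, required_keys, max_attempts=3):
    """调用 call_llm() 获取文本,尝试解析为含指定字段的 JSON;失败则重试。"""
    last_error = None
    for _ in range(max_attempts):
        text = call_llm()
        try:
            data = json.loads(text)
            missing = [k for k in required_keys if k not in data]
            if missing:
                raise ValueError(f"missing keys: {missing}")
            return data
        except (json.JSONDecodeError, ValueError) as e:
            last_error = e  # 实践中可把错误信息拼回 Prompt,要求模型修正输出
    raise RuntimeError(f"LLM output failed validation: {last_error}")
```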

3. 配置管理与凭证安全

  • 凭证管理: API 密钥不应硬编码,而应通过环境变量、秘密管理服务(如Vault, AWS Secrets Manager)或配置文件安全地管理。
  • 模型参数配置: 允许动态配置每个模型的参数(温度、top_p等),甚至可以为不同任务设置不同的默认值。
  • 环境区分: 生产、预发布、开发环境应有不同的配置。

import os
from typing import Dict, Any

class ConfigManager:
    _instance = None
    _config: Dict[str, Any] = {}

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ConfigManager, cls).__new__(cls)
            cls._instance._load_config()
        return cls._instance

    def _load_config(self):
        # Load API keys from environment variables
        self._config["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
        self._config["CLAUDE_API_KEY"] = os.getenv("CLAUDE_API_KEY")
        self._config["GEMINI_API_KEY"] = os.getenv("GEMINI_API_KEY")
        self._config["GEMINI_PROJECT_ID"] = os.getenv("GEMINI_PROJECT_ID") # For Google Cloud
        self._config["GEMINI_LOCATION"] = os.getenv("GEMINI_LOCATION", "us-central1")

        # Default model parameters (can be overridden per request)
        self._config["default_llm_params"] = {
            "temperature": 0.7,
            "max_tokens": 1024,
            "top_p": 1.0,
        }

        # Model specific parameters or overrides
        self._config["model_specific_params"] = {
            "gpt-4o": {"max_tokens": 2048},
            "claude-3-5-sonnet-20240620": {"max_tokens": 1500},
            "gemini-1.5-pro-latest": {"max_output_tokens": 1500}
        }

    def get(self, key: str, default: Any = None) -> Any:
        return self._config.get(key, default)

    def set(self, key: str, value: Any):
        self._config[key] = value

# Example usage:
config = ConfigManager()
openai_key = config.get("OPENAI_API_KEY")
default_temp = config.get("default_llm_params")["temperature"]

IV. 图编译的实现:任务分解与工作流编排

有了模型无关的抽象层,我们现在可以将注意力转向如何构建和执行工作流图。

1. 节点与边:图的构建

我们将定义不同的节点类型,它们都继承自一个抽象的 BaseNode 类。每个节点负责执行一个具体的任务,并处理输入、产生输出。

from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional, Callable, Union
import inspect  # ToolCallNode 中用于判断工具函数是否为协程

class ExecutionContext:
    """Holds shared state and results during graph execution."""
    def __init__(self):
        self.results: Dict[str, Any] = {}
        self.errors: Dict[str, Exception] = {}
        self.shared_data: Dict[str, Any] = {}

class BaseNode(ABC):
    def __init__(self, node_id: str, depends_on: Optional[List[str]] = None):
        self.node_id = node_id
        self.depends_on = depends_on if depends_on is not None else []
        self.output: Any = None
        self.error: Optional[Exception] = None

    @abstractmethod
    async def execute(self, context: ExecutionContext) -> Any:
        """Execute the node's logic."""
        pass

    def get_output(self) -> Any:
        return self.output

    def get_error(self) -> Optional[Exception]:
        return self.error

# 示例:一个LLM调用节点
class LLMNode(BaseNode):
    def __init__(self, node_id: str, llm_orchestrator: Any, prompt_template: str,
                 model_selector: Optional[Callable[[ExecutionContext], str]] = None,
                 output_parser: Optional[Callable[[str], Any]] = None,
                 depends_on: Optional[List[str]] = None,
                 llm_params: Optional[Dict] = None):
        super().__init__(node_id, depends_on)
        self.llm_orchestrator = llm_orchestrator # Reference to our LLM switching logic
        self.prompt_template = prompt_template
        self.model_selector = model_selector # Function to dynamically select model
        self.output_parser = output_parser
        self.llm_params = llm_params if llm_params is not None else {}

    async def execute(self, context: ExecutionContext) -> Any:
        try:
            # Gather inputs: shared data first, dependency results take precedence
            # (这样模板中的 {user_request}、{revision_comment} 等共享键也能渲染)
            inputs = dict(context.shared_data)
            inputs.update({dep: context.results[dep] for dep in self.depends_on if dep in context.results})

            # Render prompt using template and inputs
            prompt = self.prompt_template.format(**inputs)
            messages = [Message(role="user", content=prompt)]  # Simple case, more complex for chat history

            # Dynamically select model
            model_name = self.model_selector(context) if self.model_selector else None

            # Call LLM via orchestrator
            response: LLMResponse = await self.llm_orchestrator.chat_completion(
                messages=messages,
                model_name=model_name, # Orchestrator will pick default if None
                **self.llm_params
            )

            raw_content = response.content
            if self.output_parser:
                self.output = self.output_parser(raw_content)
            else:
                self.output = raw_content
            context.results[self.node_id] = self.output
            return self.output
        except Exception as e:
            self.error = e
            context.errors[self.node_id] = e
            raise

# 示例:一个工具调用节点
class ToolCallNode(BaseNode):
    def __init__(self, node_id: str, tool_func: Callable,
                 input_mapper: Callable[[ExecutionContext], Dict],
                 depends_on: Optional[List[str]] = None):
        super().__init__(node_id, depends_on)
        self.tool_func = tool_func
        self.input_mapper = input_mapper

    async def execute(self, context: ExecutionContext) -> Any:
        try:
            # Map context data to tool function arguments
            args = self.input_mapper(context)
            self.output = await self.tool_func(**args) if inspect.iscoroutinefunction(self.tool_func) else self.tool_func(**args)
            context.results[self.node_id] = self.output
            return self.output
        except Exception as e:
            self.error = e
            context.errors[self.node_id] = e
            raise

# 示例:一个条件判断节点
class ConditionalNode(BaseNode):
    def __init__(self, node_id: str, condition_func: Callable[[ExecutionContext], bool],
                 true_path_node_id: str, false_path_node_id: str,
                 depends_on: Optional[List[str]] = None):
        super().__init__(node_id, depends_on)
        self.condition_func = condition_func
        self.true_path_node_id = true_path_node_id
        self.false_path_node_id = false_path_node_id

    async def execute(self, context: ExecutionContext) -> Any:
        try:
            if self.condition_func(context):
                self.output = self.true_path_node_id
            else:
                self.output = self.false_path_node_id
            context.results[self.node_id] = self.output
            return self.output
        except Exception as e:
            self.error = e
            context.errors[self.node_id] = e
            raise

2. DAG (有向无环图) 引擎

图引擎负责管理节点的生命周期、执行顺序、并发控制和状态传递。由于我们的工作流通常是异步的,使用 asyncio 是一个自然的选择。

核心步骤:

  1. 构建图: 将所有节点添加到图中,并建立它们的依赖关系。
  2. 拓扑排序: 确定节点的执行顺序,确保所有依赖项在节点执行前完成。对于有向无环图(DAG),这可以通过 Kahn’s 算法或深度优先搜索(DFS)实现。
  3. 执行: 按照拓扑排序的顺序执行节点。支持并行执行不相互依赖的节点。
  4. 状态管理: ExecutionContext 在节点间传递结果和共享数据。

import asyncio
import collections
from typing import Dict, List, Optional

class DAGExecutor:
    def __init__(self, nodes: List[BaseNode]):
        self.nodes_map: Dict[str, BaseNode] = {node.node_id: node for node in nodes}
        self.graph: Dict[str, List[str]] = collections.defaultdict(list)
        self.in_degree: Dict[str, int] = collections.defaultdict(int)
        self._build_graph()

    def _build_graph(self):
        for node_id, node in self.nodes_map.items():
            for dep_id in node.depends_on:
                if dep_id not in self.nodes_map:
                    raise ValueError(f"Node {node_id} depends on non-existent node {dep_id}")
                self.graph[dep_id].append(node_id) # dep_id -> node_id (dep_id is a prerequisite for node_id)
                self.in_degree[node_id] += 1

        # Initialize in-degree for nodes with no dependencies
        for node_id in self.nodes_map:
            if node_id not in self.in_degree:
                self.in_degree[node_id] = 0

    async def execute(self, initial_context: Optional[ExecutionContext] = None) -> ExecutionContext:
        context = initial_context if initial_context else ExecutionContext()

        # Kahn's algorithm for topological sort and execution
        queue = collections.deque([node_id for node_id, degree in self.in_degree.items() if degree == 0])

        executed_nodes_count = 0
        total_nodes = len(self.nodes_map)

        # Keep track of currently running tasks
        running_tasks: Dict[str, asyncio.Task] = {}

        while executed_nodes_count < total_nodes:
            # Identify nodes that are ready to run (dependencies met, not yet running)
            ready_to_run_now = []
            while queue:
                node_id = queue.popleft()
                if node_id not in running_tasks:
                    ready_to_run_now.append(node_id)

            if not ready_to_run_now and not running_tasks:
                # This indicates a cycle or an issue where no nodes can run
                unexecuted_nodes = [n_id for n_id in self.nodes_map if n_id not in context.results and n_id not in context.errors and n_id not in running_tasks]
                if unexecuted_nodes:
                    raise RuntimeError(f"Deadlock detected or cycle in graph. Unexecuted nodes: {unexecuted_nodes}")
                break # All nodes executed or no more ready nodes

            # Start new tasks for ready nodes
            for node_id in ready_to_run_now:
                node = self.nodes_map[node_id]
                task = asyncio.create_task(node.execute(context), name=node_id)
                running_tasks[node_id] = task
                print(f"Starting node: {node_id}")

            if not running_tasks:
                break # Should not happen if graph is valid and not empty

            # Wait for any of the running tasks to complete
            done, pending = await asyncio.wait(
                running_tasks.values(),
                return_when=asyncio.FIRST_COMPLETED
            )

            # Process completed tasks
            for task in done:
                node_id = task.get_name()
                executed_nodes_count += 1
                del running_tasks[node_id]

                try:
                    await task # Re-raise any exceptions from the task
                    print(f"Node {node_id} completed successfully.")
                except Exception as e:
                    print(f"Node {node_id} failed with error: {e}")
                    # Decide how to handle errors: propagate, stop, or allow partial completion
                    # For simplicity, we'll continue executing other nodes but mark this one as failed.
                    self.nodes_map[node_id].error = e
                    context.errors[node_id] = e
                    # If a critical node fails, you might want to stop the whole graph.
                    # For now, we allow downstream nodes to potentially fail or handle missing deps.

                # Update in-degrees for dependent nodes
                for neighbor_id in self.graph[node_id]:
                    self.in_degree[neighbor_id] -= 1
                    if self.in_degree[neighbor_id] == 0:
                        queue.append(neighbor_id)

        if executed_nodes_count < total_nodes:
            raise RuntimeError(f"Graph execution incomplete. {total_nodes - executed_nodes_count} nodes did not execute.")

        return context

3. 示例:一个简单的内容生成与审核工作流(图编译版)

让我们将前面提到的内容生成与审核工作流用我们定义的节点和DAG引擎实现。

import json
import inspect
from pydantic import BaseModel, Field

# Mock LLM Orchestrator (will be detailed in next section)
class MockLLMOrchestrator:
    async def chat_completion(self, messages: List[Message], model_name: Optional[str] = None, **kwargs) -> LLMResponse:
        print(f"--- Mock LLM Call ({model_name or 'default'}) ---")
        print(f"Prompt: {messages[-1].content}")

        # Simulate different model behaviors
        if model_name == "claude-3-5-sonnet-20240620" and "大纲" in messages[-1].content:
            content = json.dumps({
                "title": "人工智能发展与未来",
                "sections": [
                    {"heading": "引言:AI浪潮的兴起", "points": ["定义AI", "历史回顾"]},
                    {"heading": "当前AI技术突破", "points": ["深度学习", "大模型", "多模态"]},
                    {"heading": "AI对社会的影响", "points": ["经济", "就业", "伦理"]},
                    {"heading": "AI的未来展望", "points": ["通用人工智能", "人机协作"]},
                ]
            }, ensure_ascii=False)
            return LLMResponse(content=content, finish_reason="stop", usage={"input_tokens": 50, "output_tokens": 100})
        elif "大纲" in messages[-1].content:
            content = json.dumps({
                "title": "通用LLM应用架构",
                "sections": [
                    {"heading": "引言", "points": ["LLM异构性", "模型无关需求"]},
                    {"heading": "抽象层设计", "points": ["API标准化", "参数统一"]},
                    {"heading": "图编译引擎", "points": ["节点", "边", "DAG执行"]},
                ]
            }, ensure_ascii=False)
            return LLMResponse(content=content, finish_reason="stop", usage={"input_tokens": 60, "output_tokens": 120})
        elif "文章初稿" in messages[-1].content:
            return LLMResponse(content="这是一篇关于AI发展的初稿,基于提供的大纲。\n...", finish_reason="stop")

        return LLMResponse(content="Mock LLM Response", finish_reason="stop")

# Define expected output schema for outline
class ArticleOutline(BaseModel):
    title: str = Field(description="The title of the article.")
    sections: List[Dict[str, Union[str, List[str]]]] = Field(description="A list of sections, each with a heading and key points.")

# Define mock tools
async def mock_human_review_tool(draft_content: str) -> Dict:
    print("\n--- Human Review Needed ---")
    print(f"Draft Content Preview: {draft_content[:200]}...")

    # In a real system, this would involve a web UI or message queue
    # For demo, simulate user input
    review_input = input("Review result (approve/reject/revise, [comment]): ").lower().strip()

    if review_input.startswith("approve"):
        return {"status": "approved", "comment": "Looks good!"}
    elif review_input.startswith("reject"):
        comment = review_input.split(":", 1)[1].strip() if ":" in review_input else "Content is not suitable."
        return {"status": "rejected", "comment": comment}
    else: # Default to revise
        comment = review_input.split(":", 1)[1].strip() if ":" in review_input else "Please elaborate on section 2."
        return {"status": "revise", "comment": comment}

# Graph Construction
async def run_article_workflow(llm_orchestrator: Any, user_request: str):
    nodes: List[BaseNode] = []

    # Node 1: Generate Outline
    def outline_parser(json_string: str) -> ArticleOutline:
        return ArticleOutline.model_validate_json(json_string)

    nodes.append(LLMNode(
        node_id="generate_outline",
        llm_orchestrator=llm_orchestrator,
        prompt_template="请根据以下需求,生成一篇结构清晰的文章大纲(JSON格式),包含标题和多个带有要点的章节:\n需求:{user_request}",
        model_selector=lambda ctx: "claude-3-5-sonnet-20240620", # Force Claude for outline
        output_parser=outline_parser,
        llm_params={"response_format": {"type": "json_object"}}, # Hint for LLM
        depends_on=[],
    ))

    # Node 2: Generate Draft
    nodes.append(LLMNode(
        node_id="generate_draft",
        llm_orchestrator=llm_orchestrator,
        prompt_template="请根据以下文章大纲,撰写一篇详细的文章初稿:\n大纲:{generate_outline}\n{revision_comment}",
        model_selector=lambda ctx: "gpt-4o", # Force GPT-4o for draft
        llm_params={"temperature": 0.8},
        depends_on=["generate_outline"],
    ))

    # Node 3: Human Review
    nodes.append(ToolCallNode(
        node_id="human_review",
        tool_func=mock_human_review_tool,
        input_mapper=lambda ctx: {"draft_content": ctx.results["generate_draft"]},
        depends_on=["generate_draft"],
    ))

    # Node 4: Conditional Re-draft or Finalize
    class ReviewDecisionNode(BaseNode):
        def __init__(self, node_id: str, depends_on: Optional[List[str]] = None):
            super().__init__(node_id, depends_on)

        async def execute(self, context: ExecutionContext) -> Any:
            review_result = context.results["human_review"]
            if review_result["status"] == "approved":
                self.output = "finalize" # This would trigger a 'publish' node in a real system
                context.shared_data["revision_comment"] = ""
                print(f"Article Approved!")
            elif review_result["status"] == "revise":
                self.output = "redraft"
                context.shared_data["revision_comment"] = f"根据以下修改意见重新撰写:{review_result['comment']}"
                print(f"Article needs revision: {review_result['comment']}")
            else: # Rejected
                self.output = "rejected"
                context.shared_data["revision_comment"] = ""
                print(f"Article Rejected: {review_result['comment']}")
            context.results[self.node_id] = self.output
            return self.output

    nodes.append(ReviewDecisionNode(node_id="review_decision", depends_on=["human_review"]))

    # Add a pseudo-node for revision loop (handled by re-running part of the graph or a separate loop)
    # For a simple DAG, we might re-trigger the draft node with revision comments.
    # In a full-fledged workflow engine, this would be a feedback loop.
    # For demonstration, we'll manually inject revision comment into context and rerun if needed.

    executor = DAGExecutor(nodes)

    # Initial context with user request
    initial_context = ExecutionContext()
    initial_context.shared_data["user_request"] = user_request
    initial_context.shared_data["revision_comment"] = ""

    # Execute the graph
    print("\n--- Starting Article Workflow ---")
    try:
        final_context = await executor.execute(initial_context)
        print("\n--- Workflow Completed ---")
        if "generate_draft" in final_context.results:
            print(f"Final Draft (partial): {final_context.results['generate_draft'][:500]}...")
        if "generate_outline" in final_context.results:
            print(f"Outline: {final_context.results['generate_outline'].model_dump_json(indent=2)}")

    except Exception as e:
        print(f"Workflow failed: {e}")
        # executor.execute 会就地修改传入的 initial_context,异常时从它读取各节点错误
        for node_id, error in initial_context.errors.items():
            print(f"Error in node {node_id}: {error}")

# To run this:
# asyncio.run(run_article_workflow(MockLLMOrchestrator(), "一篇关于AI在医疗领域应用的文章"))

(上述 DAGExecutor 示例未能完全演示条件循环:循环与条件分支通常需要更高级的 DAG 框架,或在外部循环中根据 review_decision 的结果重新触发部分图的执行。)
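对这种"退回重写"的反馈,一种简化做法是在图外部用循环驱动:每轮重新执行 generate_draft 到 review_decision 这段子图,直到通过或拒绝为止。下面是一个纯示意的骨架,run_draft_and_review 是代表该子图的桩函数:

```python
# 示意:在图外部循环驱动修订反馈;run_draft_and_review(comment) 代表
# 重新执行 generate_draft -> human_review -> review_decision 子图的封装
def revision_loop(run_draft_and_review, max_rounds=3):
    comment = ""
    for round_no in range(1, max_rounds + 1):
        decision, comment = run_draft_and_review(comment)
        if decision in ("finalize", "rejected"):
            return decision, round_no
    return "max_rounds_exceeded", max_rounds

# 桩函数:模拟第一轮要求修改、第二轮通过
_decisions = iter([("redraft", "请扩写第二章"), ("finalize", "")])
def fake_subgraph(revision_comment):
    return next(_decisions)
```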

V. 无缝切换的策略与实践

现在我们已经有了模型无关的抽象接口和图编译引擎,核心问题是如何在GPT-4o、Claude 3.5和Gemini 1.5之间实现无缝切换。这需要一个智能的LLM编排器(Orchestrator)。

1. 动态模型选择器

LLM编排器内部应该包含一个策略引擎,用于根据多种因素动态选择最佳模型。

选择策略:

  • 成本优先: 选择当前token价格最低的模型。
  • 延迟优先: 选择响应速度最快的模型。
  • 功能优先: 例如,如果任务需要图像识别,则选择支持多模态输入的模型(如GPT-4o、Gemini 1.5 Pro)。如果需要强大的代码生成能力,可能优先选择GPT-4o或Gemini 1.5 Pro。
  • 可靠性优先: 选择历史SLA表现最好的模型。
  • Token 限制: 如果输入上下文很长,选择上下文窗口大的模型(如Gemini 1.5 Pro的百万token)。
  • A/B 测试: 按一定比例随机分配流量到不同模型,进行效果评估。
  • 用户偏好/配置: 允许最终用户或管理员指定偏好模型。
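这些策略可以组合成一个很小的选择函数(示意草图;MODEL_CATALOG 中的成本与上下文数值为虚构的示例数据,并非真实报价):

```python
# 示意:先按硬性约束过滤候选模型,再按策略打分选择
MODEL_CATALOG = {
    "gpt-4o":                     {"cost": 5.0, "context": 128_000,   "multimodal": True},
    "claude-3-5-sonnet-20240620": {"cost": 3.0, "context": 200_000,   "multimodal": True},
    "gemini-1.5-pro-latest":      {"cost": 3.5, "context": 1_000_000, "multimodal": True},
}

def select_model(input_tokens: int, need_multimodal: bool = False,
                 strategy: str = "cost") -> str:
    # 硬性条件:上下文窗口够用、功能满足
    candidates = {name: m for name, m in MODEL_CATALOG.items()
                  if m["context"] >= input_tokens
                  and (m["multimodal"] or not need_multimodal)}
    if not candidates:
        raise ValueError("no model satisfies the constraints")
    if strategy == "cost":
        return min(candidates, key=lambda n: candidates[n]["cost"])
    return next(iter(candidates))  # 其他策略(延迟、可靠性等)从略
```

例如当输入上下文长达数十万 token 时,只有 Gemini 1.5 Pro 能通过上下文窗口的硬性过滤。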

健康检查与回退机制 (Circuit Breaker)

当某个模型提供商的API出现故障、响应时间过长或频繁返回错误时,编排器应能自动将其从可用列表中移除一段时间(熔断),并切换到备用模型。一段时间后,可以尝试“半开”状态,少量请求尝试恢复服务。
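熔断逻辑的最小骨架大致如下(示意;阈值与冷却时间均为示例值,时钟可注入以便测试):

```python
import time

# 示意:连续 failure_threshold 次失败后熔断 cooldown 秒,
# 冷却期过后进入"半开"状态,放行试探请求
class CircuitBreaker:
    def __init__(self, failure_threshold=3, cooldown=30.0, clock=time.monotonic):
        self.failure_threshold = failure_threshold
        self.cooldown = cooldown
        self.clock = clock  # 可注入时钟,便于测试
        self.failures = 0
        self.opened_at = None

    def allow_request(self) -> bool:
        if self.opened_at is None:
            return True
        if self.clock() - self.opened_at >= self.cooldown:
            return True  # 半开:放行试探请求
        return False

    def record_success(self):
        self.failures = 0
        self.opened_at = None

    def record_failure(self):
        self.failures += 1
        if self.failures >= self.failure_threshold:
            self.opened_at = self.clock()
```

编排器为每个提供商维护一个这样的实例:allow_request 返回 False 时直接跳到备用模型,试探请求成功后调用 record_success 恢复。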

负载均衡与限流

如果同时使用多个模型,可以根据每个模型的配额和性能进行负载均衡。同时,编排器也应该实现全局的速率限制,以避免超出任何一个模型提供商的API配额。
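全局限流常用令牌桶实现,最小骨架如下(示意;rate、capacity 为示例参数,时钟同样可注入):

```python
import time

# 示意:最小令牌桶限流器,rate 为每秒补充的令牌数,capacity 为桶容量
class TokenBucket:
    def __init__(self, rate: float, capacity: float, clock=time.monotonic):
        self.rate = rate
        self.capacity = capacity
        self.clock = clock  # 可注入时钟,便于测试
        self.tokens = capacity
        self.last = clock()

    def try_acquire(self, n: float = 1.0) -> bool:
        # 先按流逝时间补充令牌,再尝试扣减
        now = self.clock()
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= n:
            self.tokens -= n
            return True
        return False
```

每个提供商配一个桶即可按各自配额限流;try_acquire 返回 False 时,请求可以排队等待或改投其他模型。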

2. 模型特定的优化与适配

尽管我们追求模型无关,但在某些情况下,为特定模型进行微调可以显著提升性能或利用其独特功能。

  • Prompt 模板差异化: 虽然有通用的消息结构,但某些模型可能对 system 消息的放置、指令的措辞有不同偏好。编排器可以根据选定的模型加载不同的Prompt模板。
  • 特定功能映射:
    • 多模态: GPT-4o和Gemini 1.5支持图像输入。当检测到 Message 中包含图像时,编排器需要将统一的图像表示(如URL或Base64)转换为对应API的格式。
    • 工具调用: 如前所述,需要将统一的 Tool 定义和 ToolCall 解析逻辑映射到不同模型的 tools 参数和响应结构。
  • 响应解析的微调: 虽然我们有标准化的 LLMResponse,但某些模型在特定请求下的响应可能需要更细致的解析逻辑。

3. 代码实践:Python 实现一个简化的模型无关层

我们将构建 AbstractLLMProvider 接口和其具体实现,以及 LLMOrchestrator 来实现动态切换。

import os
import asyncio
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Union, Callable
import httpx # For async HTTP requests
import json
from tenacity import retry, wait_random_exponential, stop_after_attempt, retry_if_exception_type

# Assume Message, LLMResponse, Tool, ToolCall, LLMServiceError, etc. are defined as above
# For brevity, let's redefine Message and LLMResponse simpler here if not already imported

class Message:
    def __init__(self, role: str, content: Union[str, List[Dict]], name: Optional[str] = None, tool_call_id: Optional[str] = None):
        self.role = role
        self.content = content
        self.name = name
        self.tool_call_id = tool_call_id

    def to_dict(self):
        d = {"role": self.role, "content": self.content}
        if self.name: d["name"] = self.name
        if self.tool_call_id: d["tool_call_id"] = self.tool_call_id
        return d

class LLMResponse:
    def __init__(self, content: str, role: str = "assistant", tool_calls: Optional[List[Dict]] = None,
                 finish_reason: Optional[str] = None, usage: Optional[Dict] = None, raw_response: Optional[Any] = None):
        self.content = content
        self.role = role
        self.tool_calls = tool_calls if tool_calls is not None else []
        self.finish_reason = finish_reason
        self.usage = usage
        self.raw_response = raw_response

    def __str__(self):
        return f"Role: {self.role}\nContent: {self.content}\nTool Calls: {self.tool_calls}"

# --- Abstract LLM Provider ---
class AbstractLLMProvider(ABC):
    def __init__(self, model_name: str, api_key: str, base_url: str):
        self.model_name = model_name
        self.api_key = api_key
        self.base_url = base_url
        self.client = httpx.AsyncClient()

    @abstractmethod
    async def chat_completion(self, messages: List[Message], tools: Optional[List[Dict]] = None,
                              tool_choice: Optional[Union[str, Dict]] = None, **kwargs) -> LLMResponse:
        pass

    @abstractmethod
    def supports_tool_calling(self) -> bool:
        pass

    @abstractmethod
    def supports_multimodal(self) -> bool:
        pass

    @abstractmethod
    def get_cost_per_token(self) -> Dict[str, float]:
        """Returns input/output token cost."""
        pass

    def _map_messages_to_provider_format(self, messages: List[Message]) -> List[Dict]:
        """Generic mapping, specific providers might override."""
        return [msg.to_dict() for msg in messages]

    def _map_tools_to_provider_format(self, tools: Optional[List[Dict]]) -> Optional[List[Dict]]:
        """Generic mapping, specific providers might override."""
        return tools # Assuming tools are already in a somewhat standardized format

    def _parse_provider_response(self, raw_response: Dict) -> LLMResponse:
        """Generic parsing, specific providers must override."""
        raise NotImplementedError

# --- OpenAI Provider ---
class OpenAIProvider(AbstractLLMProvider):
    def __init__(self, model_name: str):
        super().__init__(model_name, os.getenv("OPENAI_API_KEY"), "https://api.openai.com/v1")
        self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

    # Retry on our normalized retryable errors: the raw httpx.HTTPStatusError is
    # caught inside the method and re-raised as these wrapper types, so tenacity
    # must match the wrappers, not the original exception.
    @retry(wait=wait_random_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3),
           retry=retry_if_exception_type((LLMRateLimitExceededError, LLMServiceUnavailableError)))
    async def chat_completion(self, messages: List[Message], tools: Optional[List[Dict]] = None,
                              tool_choice: Optional[Union[str, Dict]] = None, **kwargs) -> LLMResponse:
        payload = {
            "model": self.model_name,
            "messages": self._map_messages_to_provider_format(messages),
            **kwargs
        }
        if tools:
            payload["tools"] = tools # Tools are already in OpenAI format (from Tool.to_openai_format)
        if tool_choice:
            payload["tool_choice"] = tool_choice

        try:
            response = await self.client.post(f"{self.base_url}/chat/completions", headers=self.headers, json=payload, timeout=60.0)
            response.raise_for_status()
            return self._parse_provider_response(response.json())
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 429:
                raise LLMRateLimitExceededError(f"OpenAI Rate Limit Exceeded: {e.response.text}") from e
            if e.response.status_code >= 500:
                raise LLMServiceUnavailableError(f"OpenAI Service Unavailable: {e.response.text}") from e
            raise LLMInvalidRequestError(f"OpenAI API Error: {e.response.status_code} - {e.response.text}") from e
        except Exception as e:
            raise LLMServiceError(f"OpenAI Unknown Error: {e}") from e

    def _map_messages_to_provider_format(self, messages: List[Message]) -> List[Dict]:
        openai_messages = []
        for msg in messages:
            content_val = msg.content
            if isinstance(content_val, list): # Multimodal content
                openai_content_parts = []
                for part in content_val:
                    if part["type"] == "text":
                        openai_content_parts.append({"type": "text", "text": part["text"]})
                    elif part["type"] == "image_url":
                        openai_content_parts.append({"type": "image_url", "image_url": {"url": part["image_url"]}})
                content_val = openai_content_parts

            if msg.role == "tool": # OpenAI uses 'tool' role for function call results
                openai_messages.append({"role": msg.role, "content": content_val, "tool_call_id": msg.tool_call_id})
            elif msg.role == "assistant" and msg.tool_call_id:
                # An assistant turn that issued tool calls needs the original
                # "tool_calls" list (id/name/arguments), which tool_call_id alone
                # cannot reconstruct; a fuller Message type would carry that list.
                # We forward the content only rather than send a malformed field.
                openai_messages.append({"role": "assistant", "content": content_val})
            else:
                openai_messages.append({"role": msg.role, "content": content_val})
        return openai_messages

    def _parse_provider_response(self, raw_response: Dict) -> LLMResponse:
        choice = raw_response["choices"][0]
        message = choice["message"]
        content = message.get("content") or ""  # content is null when the model only returns tool calls
        tool_calls = message.get("tool_calls", [])
        finish_reason = choice["finish_reason"]
        usage = raw_response.get("usage", {})

        # Standardize tool calls for our internal ToolCall class
        standardized_tool_calls = []
        for tc in tool_calls:
            if tc["type"] == "function":
                standardized_tool_calls.append({
                    "id": tc["id"],
                    "name": tc["function"]["name"],
                    "arguments": json.loads(tc["function"]["arguments"])
                })

        return LLMResponse(
            content=content,
            role="assistant",
            tool_calls=standardized_tool_calls,
            finish_reason=finish_reason,
            usage={"input_tokens": usage.get("prompt_tokens"), "output_tokens": usage.get("completion_tokens")},
            raw_response=raw_response
        )

    def supports_tool_calling(self) -> bool:
        return True # GPT-4o supports tool calling

    def supports_multimodal(self) -> bool:
        return True # GPT-4o supports multimodal

    def get_cost_per_token(self) -> Dict[str, float]:
        # Placeholder values, replace with actual current prices
        if "gpt-4o" in self.model_name:
            return {"input": 5.0 / 1_000_000, "output": 15.0 / 1_000_000}
        return {"input": 0.0, "output": 0.0}

# --- Claude Provider ---
class ClaudeProvider(AbstractLLMProvider):
    def __init__(self, model_name: str):
        super().__init__(model_name, os.getenv("CLAUDE_API_KEY"), "https://api.anthropic.com/v1")
        self.headers = {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json"
        }

    # As with the OpenAI provider, retry on the normalized wrapper errors that
    # actually escape this method, not on the httpx exception caught inside it.
    @retry(wait=wait_random_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3),
           retry=retry_if_exception_type((LLMRateLimitExceededError, LLMServiceUnavailableError)))
    async def chat_completion(self, messages: List[Message], tools: Optional[List[Dict]] = None,
                              tool_choice: Optional[Union[str, Dict]] = None, **kwargs) -> LLMResponse:

        system_message = ""
        claude_messages = []
        for msg in messages:
            if msg.role == "system":
                system_message = msg.content
            else:
                claude_messages.append({"role": msg.role, "content": msg.content}) # Claude 3.5 Sonnet accepts 'tool_use' and 'tool_result'

        payload = {
            "model": self.model_name,
            "messages": claude_messages,
            "max_tokens": kwargs.pop("max_tokens", 1024),
            **kwargs
        }
        if system_message:
            payload["system"] = system_message
        if tools:
            # Claude tool format uses "input_schema" instead of "parameters" directly
            claude_tools = []
            for tool_def in tools:
                claude_tools.append({
                    "name": tool_def["function"]["name"],
                    "description": tool_def["function"]["description"],
                    "input_schema": tool_def["function"]["parameters"]
                })
            payload["tools"] = claude_tools
        if tool_choice:
            # Claude tool_choice mapping might be different. E.g., for specific tool:
            # {"type": "tool", "name": "my_tool_name"}
            payload["tool_choice"] = tool_choice # Assuming it's already mapped

        try:
            response = await self.client.post(f"{self.base_url}/messages", headers=self.headers, json=payload, timeout=60.0)
            response.raise_for_status()
            return self._parse_provider_response(response.json())
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 429:
                raise LLMRateLimitExceededError(f"Claude Rate Limit Exceeded: {e.response.text}") from e
            if e.response.status_code >= 500:
                raise LLMServiceUnavailableError(f"Claude Service Unavailable: {e.response.text}") from e
            raise LLMInvalidRequestError(f"Claude API Error: {e.response.status_code} - {e.response.text}") from e
        except Exception as e:
            raise LLMServiceError(f"Claude Unknown Error: {e}") from e

    def _parse_provider_response(self, raw_response: Dict) -> LLMResponse:
        content_parts = raw_response["content"]
        text_content = ""
        tool_calls = []

        for part in content_parts:
            if part["type"] == "text":
                text_content += part["text"]
            elif part["type"] == "tool_use":
                tool_calls.append({
                    "id": part["id"],
                    "name": part["name"],
                    "arguments": part["input"] # Claude uses 'input' for arguments
                })

        usage = raw_response.get("usage", {})

        return LLMResponse(
            content=text_content,
            role="assistant",
            tool_calls=tool_calls,
            finish_reason=raw_response.get("stop_reason"),
            usage={"input_tokens": usage.get("input_tokens"), "output_tokens": usage.get("output_tokens")},
            raw_response=raw_response
        )

    def supports_tool_calling(self) -> bool:
        return True # Claude 3.5 Sonnet supports tool calling

    def supports_multimodal(self) -> bool:
        return True # Claude 3.5 Sonnet supports multimodal (text and images)

    def get_cost_per_token(self) -> Dict[str, float]:
        # Placeholder values, replace with actual current prices
        if "claude-3-5-sonnet" in self.model_name:
            return {"input": 3.0 / 1_000_000, "output": 15.0 / 1_000_000}
        return {"input": 0.0, "output": 0.0}

# --- Gemini Provider ---
# Requires google-generativeai library
import google.generativeai as genai
from google.generativeai.types import Tool as GeminiTool, FunctionDeclaration
from google.api_core.exceptions import ResourceExhausted, GoogleAPIError

class GeminiProvider(AbstractLLMProvider):
    def __init__(self, model_name: str):
        # Gemini API doesn't use a single base_url for client like OpenAI/Claude
        # It's configured via genai.configure
        super().__init__(model_name, os.getenv("GEMINI_API_KEY"), "https://generativelanguage.googleapis.com/")
        genai.configure(api_key=self.api_key)
        self.model = genai.GenerativeModel(self.model_name)

    # ResourceExhausted/GoogleAPIError are caught inside the method and converted,
    # so retry on the normalized rate-limit error that actually escapes it.
    @retry(wait=wait_random_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3),
           retry=retry_if_exception_type(LLMRateLimitExceededError))
    async def chat_completion(self, messages: List[Message], tools: Optional[List[Dict]] = None,
                              tool_choice: Optional[Union[str, Dict]] = None, **kwargs) -> LLMResponse:
        gemini_contents = []
        for msg in messages:
            parts = []
            if isinstance(msg.content, str):
                parts.append(msg.content)
            elif isinstance(msg.content, list): # Multimodal
                for part in msg.content:
                    if part["type"] == "text":
                        parts.append(part["text"])
                    elif part["type"] == "image_url":
                        # Gemini expects inline image bytes or files uploaded via its
                        # File API, not remote URLs. A real system would fetch the URL,
                        # read the bytes, and pass them as an image part; here we skip
                        # the image with a warning to keep the example self-contained.
                        print("Warning: Gemini requires image data, not URLs; image part skipped.")

            # Map roles: Gemini uses 'model' for assistant turns and has no distinct
            # 'system' role in the conversation history (system-like instructions go
            # into system_instruction or the first user message); we fold 'system'
            # into 'user' here for simplicity.
            role_map = {"system": "user", "assistant": "model"}
            gemini_contents.append({"role": role_map.get(msg.role, msg.role), "parts": parts})

        gemini_tools = None
        if tools:
            gemini_tools = [GeminiTool(function_declarations=[
                FunctionDeclaration(
                    name=tool_def["function"]["name"],
                    description=tool_def["function"]["description"],
                    parameters=tool_def["function"]["parameters"]
                )
            ]) for tool_def in tools]

        generation_config = genai.GenerationConfig(
            temperature=kwargs.pop("temperature", 0.7),
            max_output_tokens=kwargs.pop("max_tokens", 1024),
            top_p=kwargs.pop("top_p", 1.0),
            stop_sequences=kwargs.pop("stop_sequences", None)
        )

        try:
            # Gemini's chat history is managed by its 'start_chat' method
            # For a single turn, we use generate_content
            response = await self.model.generate_content_async(
                contents=gemini_contents,
                tools=gemini_tools,
                tool_config=tool_choice, # Gemini tool_config has a specific format
                generation_config=generation_config,
                # For streaming, use stream=True and iterate
            )
            return self._parse_provider_response(response)
        except ResourceExhausted as e:
            raise LLMRateLimitExceededError(f"Gemini Rate Limit Exceeded: {e}") from e
        except GoogleAPIError as e:
            if "quota" in str(e).lower():
                raise LLMRateLimitExceededError(f"Gemini Quota Exceeded: {e}") from e
            raise LLMInvalidRequestError(f"Gemini API Error: {e}") from e
        except Exception as e:
            raise LLMServiceError(f"Gemini Unknown Error: {e}") from e

    def _parse_provider_response(self, raw_response: Any) -> LLMResponse:
        candidates = raw_response.candidates
        if not candidates:
            return LLMResponse(content="", finish_reason="no_candidates", raw_response=raw_response)

        first_candidate = candidates[0]
        text_content = ""
        tool_calls = []

        for part in first_candidate.content.parts:
            # Proto message parts expose both fields, so hasattr() is always true;
            # check which field is actually populated instead.
            if getattr(part, "function_call", None) and part.function_call.name:
                fc = part.function_call
                tool_calls.append({
                    "id": f"gemini_tool_call_{fc.name}",  # Gemini doesn't provide an ID; synthesize one
                    "name": fc.name,
                    "arguments": dict(fc.args.items())
                })
            elif getattr(part, "text", ""):
                text_content += part.text

        usage_metadata = raw_response.usage_metadata

        return LLMResponse(
            content=text_content,
            role="assistant",
            tool_calls=tool_calls,
            finish_reason=str(first_candidate.finish_reason),  # candidate-level stop reason; prompt_feedback.block_reason is only set when the prompt was blocked
            usage={"input_tokens": usage_metadata.prompt_token_count, "output_tokens": usage_metadata.candidates_token_count} if usage_metadata else None,
            raw_response=raw_response
        )

    def supports_tool_calling(self) -> bool:
        return True # Gemini 1.5 Pro supports tool calling

    def supports_multimodal(self) -> bool:
        return True # Gemini 1.5 Pro supports multimodal (text and images)

    def get_cost_per_token(self) -> Dict[str, float]:
        # Placeholder values, replace with actual current prices
        if "gemini-1.5-pro" in self.model_name:
            return {"input": 7.0 / 1_000_000, "output": 21.0 / 1_000_000}
        return {"input": 0.0, "output": 0.0}

# --- LLM Orchestrator ---
class LLMOrchestrator:
    def __init__(self, default_model: str = "gpt-4o", config: Optional[ConfigManager] = None):
        self.config = config if config else ConfigManager()
        self.providers: Dict[str, AbstractLLMProvider] = {
            "gpt-4o": OpenAIProvider("gpt-4o"),
            "claude-3-5-sonnet-20240620": ClaudeProvider("claude-3-5-sonnet-20240620"),
            "gemini-1.5-pro-latest": GeminiProvider("gemini-1.5-pro-latest"),
            # Add other models as needed
        }
        self.available_models = list(self.providers.keys())
        self.default_model = default_model
        self.circuit_breakers: Dict[str, bool] = {model_name: False for model_name in self.available_models}
        self.circuit_breaker_reset_time: Dict[str, float] = {model_name: 0.0 for model_name in self.available_models}

    async def _select_model(self, selection_strategy: Optional[Callable[[List[str], Dict[str, AbstractLLMProvider]], str]] = None,
                            preferred_model: Optional[str] = None) -> str:

        if preferred_model and preferred_model in self.available_models and not self.circuit_breakers[preferred_model]:
            return preferred_model

        now = asyncio.get_event_loop().time()

        # Re-enable models whose circuit-breaker cool-down has elapsed
        for model_name in self.available_models:
            if self.circuit_breakers[model_name] and self.circuit_breaker_reset_time[model_name] < now:
                print(f"Attempting to reset circuit breaker for {model_name}")
                self.circuit_breakers[model_name] = False

        active_models = [m for m in self.available_models if not self.circuit_breakers[m]]

        if not active_models:
            raise LLMServiceUnavailableError("No LLM models are currently available.")

        if selection_strategy:
            return selection_strategy(active_models, self.providers)

        # Default strategy: try preferred, then default, then first available
        if self.default_model in active_models:
            return self.default_model
        return active_models[0] # Fallback to first available

    def _trigger_circuit_breaker(self, model_name: str, duration_seconds: int = 60):
        self.circuit_breakers[model_name] = True
        self.circuit_breaker_reset_time[model_name] = asyncio.get_event_loop().time() + duration_seconds
        print(f"Circuit breaker tripped for {model_name}. Will retry in {duration_seconds} seconds.")

    async def chat_completion(self, messages: List[Message], model_name: Optional[str] = None,
                              tools: Optional[List[Dict]] = None, tool_choice: Optional[Union[str, Dict]] = None,
                              selection_strategy: Optional[Callable[[List[str], Dict[str, AbstractLLMProvider]], str]] = None,
                              **kwargs) -> LLMResponse:

        selected_model_name = await self._select_model(selection_strategy, model_name)
        provider = self.providers[selected_model_name]

        # Apply model-specific parameter overrides from config
        llm_params = self.config.get("default_llm_params", {}).copy()
        llm_params.update(self.config.get("model_specific_params", {}).get(selected_model_name, {}))
        llm_params.update(kwargs) # Request-specific params override all

        # Tools are passed in OpenAI's function-schema format; each provider's
        # chat_completion converts them to its own wire format internally (Claude
        # to input_schema, Gemini to FunctionDeclaration), so we forward them
        # unchanged here. Converting in both places would double-map them and
        # break the providers' tool_def["function"] lookups.
        provider_tools = tools

        try:
            response = await provider.chat_completion(messages=messages, tools=provider_tools, tool_choice=tool_choice, **llm_params)

            # Log cost for this call
            cost_info = provider.get_cost_per_token()
            usage = response.usage or {}
            input_tokens = usage.get("input_tokens") or 0
            output_tokens = usage.get("output_tokens") or 0
            input_cost = input_tokens * cost_info["input"]
            output_cost = output_tokens * cost_info["output"]
            print(f"LLM Call: Model={selected_model_name}, Input Tokens={input_tokens}, Output Tokens={output_tokens}, Cost=${input_cost + output_cost:.4f}")

            return response
        except LLMRateLimitExceededError as e:
            print(f"Rate limit for {selected_model_name}. Trying fallback if available.")
            self._trigger_circuit_breaker(selected_model_name)
            # Recursively try again with a different model if possible
            return await self.chat_completion(messages, model_name=None, tools=tools, tool_choice=tool_choice, selection_strategy=selection_strategy, **kwargs)
        except LLMServiceUnavailableError as e:
            print(f"Service unavailable for {selected_model_name}. Trying fallback if available.")
            self._trigger_circuit_breaker(selected_model_name, duration_seconds=300) # Longer timeout for service unavailability
            return await self.chat_completion(messages, model_name=None, tools=tools, tool_choice=tool_choice, selection_strategy=selection_strategy, **kwargs)
        except Exception as e:
            print(f"Error with {selected_model_name}: {e}")
            raise # Re-raise if no fallback or persistent error

# Example usage (needs actual API keys in env vars)
# async def main():
#     # Mock API keys for demonstration
#     os.environ["OPENAI_API_KEY"] = "sk-..."
#     os.environ["CLAUDE_API_KEY"] = "sk-ant-..."
#     os.environ["GEMINI_API_KEY"] = "AIza..."

#     orchestrator = LLMOrchestrator(default_model="gpt-4o")

#     # Simple text generation with default model
#     print("--- Default Model Call (GPT-4o) ---")
#     response1 = await orchestrator.chat_completion(messages=[Message(role="user", content="Hello, what is the capital of France?")])
#     print(f"Response 1: {response1.content}")

#     # Switch to Claude
#     print("\n--- Claude Call ---")
#     response2 = await orchestrator.chat_completion(messages=[Message(role="user", content="Tell me a short story about a brave knight.")], model_name="claude-3-5-sonnet-20240620")
#     print(f"Response 2: {response2.content}")

#     # Switch to Gemini
#     print("\n--- Gemini Call ---")
#     response3 = await orchestrator.chat_completion(messages=[Message(role="user", content="What's a good recipe for chocolate chip cookies?")], model_name="gemini-1.5-pro-latest")
#     print(f"Response 3: {response3.content}")

#     # Example with a custom selection strategy (e.g., cost-optimized)
#     def cost_optimized_selector(available_models: List[str], providers_map: Dict[str, AbstractLLMProvider]) -> str:
#         min_cost_model = None
#         min_total_cost = float('inf')
#         # Simulate a simple prompt to estimate output tokens for cost calculation
#         simulated_input_tokens = 10
#         simulated_output_tokens = 50 
#         for model_name in available_models:
#             provider = providers_map[model_name]
#             costs = provider.get_cost_per_token()
#             current_total_cost = (costs.get("input", 0) * simulated_input_tokens) + \
#                                  (costs.get("output", 0) * simulated_output_tokens)
#             if current_total_cost < min_total_cost:
#                 min_total_cost = current_total_cost
#                 min_cost_model = model_name
#         print(f"Cost-optimized selector chose: {min_cost_model}")
#         return min_cost_model or orchestrator.default_model

#     print("\n--- Cost-Optimized Call ---")
#     response4 = await orchestrator.chat_completion(
#         messages=[Message(role="user", content="Summarize the key events of World War II.")],
#         selection_strategy=cost_optimized_selector
#     )
#     print(f"Response 4: {response4.content}")

# if __name__ == "__main__":
#     asyncio.run(main())

VI. Error Handling, Monitoring, and Resilience Design

A robust model-agnostic system must do more than implement features: it must also handle exceptions and failures gracefully.

1. Unified Error Codes and Exception Handling

As discussed earlier, we map the various errors raised by the underlying LLM providers into our own unified exception hierarchy. Upper-layer business logic can then catch and handle failures in a consistent way, without caring which specific model went wrong.
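The hierarchy those providers raise into can be sketched as follows. The class names match the ones used in the code above; the single-base layout shown here is an assumption about how they relate.

```python
# A minimal sketch of the unified exception hierarchy used by the providers.
# Class names match those raised above; the layout itself is an assumption.
class LLMServiceError(Exception):
    """Base class for all normalized LLM errors."""

class LLMRateLimitExceededError(LLMServiceError):
    """HTTP 429 / quota errors: retryable, and a signal for failover."""

class LLMServiceUnavailableError(LLMServiceError):
    """HTTP 5xx errors: retryable, and a signal for circuit breaking."""

class LLMInvalidRequestError(LLMServiceError):
    """Other 4xx errors: retrying will not help; surface to the caller."""
```

Callers can then write a single `except LLMServiceError:` handler regardless of which provider served the request.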

2. Retry Mechanisms and Exponential Backoff

Transient network failures and API rate limits are routine. For retryable errors (e.g. HTTP 429 and 5xx), we should retry automatically with exponential backoff, lengthening the interval between attempts so as not to put further pressure on the service. In Python, the tenacity library is an excellent tool for this, and it is already integrated into the code examples above.
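Stripped of the library, the core of what tenacity's decorator does can be sketched as a small stdlib helper (the name `retry_with_backoff` and its defaults are ours, not tenacity's):

```python
import random
import time

def retry_with_backoff(fn, max_attempts=3, base=0.5, cap=10.0):
    """Call fn(); on exception, sleep with randomized exponential backoff, then retry.

    base is the ceiling of the first delay in seconds; it doubles per attempt, capped at cap.
    """
    for attempt in range(1, max_attempts + 1):
        try:
            return fn()
        except Exception:
            if attempt == max_attempts:
                raise  # attempts exhausted: surface the last error to the caller
            delay = min(cap, base * (2 ** (attempt - 1)))
            time.sleep(random.uniform(0, delay))  # jitter avoids thundering herds
```

In the real providers, the equivalent retry policy is applied only to the retryable wrapper errors, never to invalid-request errors that would fail identically on every attempt.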

3. Logging and Observability

  • Request/response logging: record the full request (after redaction) and response of every LLM call; this is essential for debugging and auditing.
  • Performance metrics: track each model's latency, throughput, and success rate, which helps us evaluate the effectiveness of the model-selection strategy.
  • Cost tracking: record token usage and estimated cost per call, enabling fine-grained cost analysis and better-informed model selection.
  • Error rate: monitor each model's error rate as the signal for tripping its circuit breaker.
import logging
# Configure basic logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# In LLMOrchestrator's chat_completion method:
# ... after getting response:
# logging.info(f"LLM Call Success - Model: {selected_model_name}, Input: {response.usage['input_tokens']} tokens, Output: {response.usage['output_tokens']} tokens, Cost: ${input_cost+output_cost:.4f}")
# ... in exception handlers:
# logging.error(f"LLM Call Failed - Model: {selected_model_name}, Error: {e}", exc_info=True)

4. Caching Strategy

For repeated or highly deterministic LLM requests, a cache can:

  • Reduce cost: avoid repeated calls to paid APIs.
  • Improve performance: return results immediately, with no wait for the LLM.
  • Lower load: reduce the volume of requests sent to the LLM providers.

The cache can be a simple in-memory store, Redis, or a database. Note that LLM output is often non-deterministic (influenced by parameters such as temperature), so the caching policy must fit the business scenario: for example, cache only deterministic requests made with temperature=0.
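That temperature=0 policy can be sketched as a small in-memory cache keyed on the full request. The helper names (`cache_key`, `cached_completion`) are illustrative, not part of the orchestrator above.

```python
import hashlib
import json

_cache = {}

def cache_key(model: str, messages, params) -> str:
    """Deterministic key over model, messages, and generation params."""
    blob = json.dumps({"model": model, "messages": messages, "params": params},
                      sort_keys=True, ensure_ascii=False)
    return hashlib.sha256(blob.encode("utf-8")).hexdigest()

def cached_completion(model, messages, params, call_fn):
    """Only cache deterministic requests (temperature == 0), per the note above."""
    if params.get("temperature", 1.0) != 0:
        return call_fn()  # non-deterministic: always go to the provider
    key = cache_key(model, messages, params)
    if key not in _cache:
        _cache[key] = call_fn()
    return _cache[key]
```

A production version would swap the dict for Redis with a TTL, but the keying and the temperature gate stay the same.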

VII. Advanced Considerations and Future Outlook

1. Multimodal Model Support

As GPT
