Dissecting "Model-agnostic Graph Compiling": How to Write One Set of Logic That Switches Seamlessly Between GPT-4o, Claude 3.5, and Gemini 1.5
Good afternoon, fellow engineers!
Today we gather to explore a topic that matters enormously in the current AI wave: how to build a flexible, robust, and efficient system on top of the heterogeneous ecosystem of large language models (LLMs). Concretely, we will dissect the idea of "model-agnostic graph compiling" and focus on how to design and implement a single set of logic that can switch seamlessly among three mainstream models: GPT-4o, Claude 3.5, and Gemini 1.5.
As LLM technology races forward, we face a landscape as full of opportunity as it is of challenges. On one hand, different models show distinct strengths in capability, cost, latency, and even stylistic preference; on the other, this diversity imposes a heavy integration and maintenance burden on developers. Our goal is precisely to bridge these gaps between models and build a unified, intelligent LLM application layer.
I. Introduction: LLMs and the Opportunity of "Model-agnostic Graph Compiling"
Over the past few years, large language models (LLMs) have leapt from research labs into world-changing general-purpose technology. From content creation and code generation to complex problem solving, the capability frontier of LLMs keeps expanding. This explosion of capability, however, comes with a prominent challenge: heterogeneity.
We now have OpenAI's GPT series (such as the latest GPT-4o), Anthropic's Claude series (such as Claude 3.5 Sonnet), Google's Gemini series (such as Gemini 1.5 Pro), and many open-source models such as Llama 3. These models differ significantly in the following respects:
- API interfaces and parameters: each provider has its own API endpoints, request parameter names, and structures.
- Performance and quality: on a given task (code generation, creative writing, logical reasoning), different models can perform very differently.
- Cost: token pricing varies by model and by input/output ratio, which directly affects operating cost.
- Latency and throughput: API response times and service stability differ.
- Feature sets: multimodal support, how tool calling (function calling) is implemented, context window size, and so on.
- Content policy and safety: each model filters sensitive content and bias differently.
This heterogeneity means that developers building LLM applications often end up deeply bound to one specific model. Switching models, or combining the strengths of several, then requires extensive code changes and adaptation work. This lowers development velocity, raises maintenance cost, and limits both the flexibility and the resilience of the application.
"Model-agnostic graph compiling" is an architectural idea proposed precisely to address this pain point. It provides an abstraction layer that turns complex LLM interactions, task orchestration, and data flow into an executable "graph". The graph's nodes represent operations (LLM calls, data transformations, conditional branches, tool executions), and its edges represent data and control flow. "Model-agnostic" means that the LLM-call nodes in the graph can dynamically switch the underlying model without changing the graph structure or the core business logic.
Its core value:
- Flexibility and resilience: swap or combine models easily in response to model updates, price changes, or performance shifts.
- Cost/performance balance: pick the most suitable model based on real-time requirements such as cost sensitivity or latency targets.
- Innovation and experimentation: lower the barrier to trying new models or new combinations, accelerating product iteration.
- Maintainability: decouple business logic from underlying model details, improving readability and maintainability.
Next we will go step by step, from abstraction-layer design to concrete implementation, and explore how to turn this vision into reality.
II. Core Concepts: Model-agnostic Graph Compiling
To understand "model-agnostic graph compiling", we first need to unpack the two terms.
What is "graph compiling"?
In computer science, graph compilation is nothing new. A traditional compiler converts source code into an abstract syntax tree (AST), performs optimizations over structures such as the control-flow graph (CFG) and data-flow graph (DFG), and finally emits machine code. The "graph" here is an abstract representation of program structure and execution logic.
In LLM workflows, a graph is an equally powerful abstraction. It decomposes a complex end-to-end task into a series of discrete, interdependent steps.
- Nodes: a concrete operation or task in the workflow.
  - LLM-call node: send a request to an LLM and receive the response.
  - Tool-call node: execute an external tool (database query, API call, code execution).
  - Data-transformation node: parse, format, or filter data.
  - Conditional node: pick the subsequent execution path based on a condition.
  - Parallel node: multiple tasks run concurrently.
  - Human-review node: a human steps in to decide or verify.
- Edges: relationships between nodes, typically data flow or control flow.
  - Data flow: one node's output feeds the next node's input.
  - Control flow: determines execution order or conditional branching.
Example: a simple content-generation-and-review workflow
Suppose we want a system that generates an article outline from a user request, drafts the article from the outline, and finally sends the draft for human review.
This workflow can be represented as a graph:
- Node 1: GenerateOutline (LLM call)
  - Input: user request (prompt)
  - Output: article outline (JSON)
  - Depends on: none
- Node 2: GenerateDraft (LLM call)
  - Input: article outline (output of node 1)
  - Output: article draft (Markdown)
  - Depends on: GenerateOutline
- Node 3: HumanReview (human operation)
  - Input: article draft (output of node 2)
  - Output: review result (approve / reject / revision notes)
  - Depends on: GenerateDraft
- Node 4: PublishOrRevise (conditional branch)
  - Input: review result (output of node 3)
  - If "approve", go to the Publish node; if "reject/revise", go back to GenerateDraft (with the revision notes as input, forming a loop).
  - Depends on: HumanReview
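The workflow above can be sketched as plain data: a minimal, illustrative DAG representation (node ids and the helper name are hypothetical, not from any framework):

```python
# Minimal data representation of the review workflow as a DAG.
# Each node lists the node ids it depends on; edges are implied.
workflow = {
    "GenerateOutline": {"kind": "llm", "depends_on": []},
    "GenerateDraft": {"kind": "llm", "depends_on": ["GenerateOutline"]},
    "HumanReview": {"kind": "human", "depends_on": ["GenerateDraft"]},
    "PublishOrRevise": {"kind": "branch", "depends_on": ["HumanReview"]},
}

def execution_order(graph: dict) -> list:
    """Topologically sort the node ids (Kahn's algorithm)."""
    indeg = {n: len(spec["depends_on"]) for n, spec in graph.items()}
    order, ready = [], [n for n, d in indeg.items() if d == 0]
    while ready:
        n = ready.pop(0)
        order.append(n)
        for m, spec in graph.items():
            if n in spec["depends_on"]:
                indeg[m] -= 1
                if indeg[m] == 0:
                    ready.append(m)
    return order
```

Representing the workflow as data like this is exactly what lets a later "compile" step target any model or execution backend.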
This graph representation lets us define complex workflows clearly while decoupling them from implementation details.
What is "model-agnostic"?
"Model-agnostic" means the system design must not be tightly coupled to any particular LLM (GPT-4o, Claude 3.5, Gemini 1.5). Concretely:
- API compatibility: our code never calls a specific model's SDK directly; all interaction goes through a unified abstract interface.
- Parameter consistency: even if underlying models name parameters differently or accept different ranges, the application layer uses one standardized parameter set.
- Response normalization: whichever model responds, its output is parsed into a unified internal data structure.
- Behavioral abstraction: even where models follow different conventions for tool calling or multimodal input, the abstraction layer exposes a consistent interface.
For example, the GPT series builds conversation history as a messages array of role/content entries, with system as an ordinary role. Claude 3.5 also uses a messages array, but takes the system prompt as a separate top-level system parameter rather than a message role. Gemini 1.5 represents each message's content as a parts array and uses the roles user and model. For tool calling, OpenAI returns tool_calls entries, Claude returns tool_use content blocks, and Gemini returns functionCall parts. All of these differences must be abstracted away and smoothed over.
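As a concrete illustration, here is how the same one-turn conversation might be shaped for each provider. Field names follow the public APIs, but treat the exact payloads as a sketch, not an exhaustive request:

```python
system = "You are a helpful assistant."
user = "Summarize this article."

# OpenAI-style: system is just another message role.
openai_payload = {
    "model": "gpt-4o",
    "messages": [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ],
}

# Anthropic-style: the system prompt is a separate top-level parameter.
claude_payload = {
    "model": "claude-3-5-sonnet-20240620",
    "system": system,
    "messages": [{"role": "user", "content": user}],
    "max_tokens": 1024,
}

# Gemini-style: per-message "parts"; conversational roles are "user" / "model".
gemini_payload = {
    "system_instruction": {"parts": [{"text": system}]},
    "contents": [{"role": "user", "parts": [{"text": user}]}],
}
```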
Why do we need "model-agnostic graph compiling"?
- Keeping up with rapid model iteration: the LLM field moves fast and new models appear constantly. A model-agnostic design lets us integrate them quickly without rewriting business logic.
- Best performance at the best cost: on a given task, one model may be stronger or cheaper. Claude may excel at summarization, GPT-4o at code generation, and Gemini 1.5 may be the most economical for low-cost text extraction. Dynamic switching lets us pick the "best model" for each task.
- System resilience: when one model's API fails, hits a rate limit, or becomes temporarily unavailable, the system can automatically fall back to another model and keep the service running.
- A/B testing and experimentation: easily compare models on real business traffic, collect data, and guide optimization.
- Simplifying complex workflows: the graph representation offers an intuitive, modular way to design, visualize, and debug complex LLM applications.
III. Foundations of the Model-agnostic Layer: Abstraction and Standardization
The key to model agnosticism is a robust abstraction layer that acts as both "translator" and "coordinator" between our application and the concrete LLM providers.
1. API Abstraction Layer Design
We define a unified interface that every concrete provider implementation must follow. This guarantees that the business logic above always talks to LLMs the same way, regardless of which model sits underneath.
Unified interface definition
We can define an abstract base class or protocol containing the core LLM operations, for example:
- generate_text(prompt: str, **kwargs) -> str: simple text generation.
- chat_completion(messages: List[Dict], **kwargs) -> Dict: conversational generation with message history.
- embed(texts: List[str], **kwargs) -> List[List[float]]: text embeddings.
- stream_chat_completion(...): streaming conversational generation.
- supports_tool_calling() -> bool: whether the model supports tool calling.
- supports_multimodal() -> bool: whether the model supports multimodal input.
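A minimal sketch of such an interface using `abc` (the method names mirror the list above; the `EchoProvider` stand-in is purely illustrative):

```python
from abc import ABC, abstractmethod
from typing import Dict, List

class BaseLLMProvider(ABC):
    """Every concrete provider (OpenAI, Anthropic, Google, ...) implements this."""

    @abstractmethod
    def chat_completion(self, messages: List[Dict], **kwargs) -> Dict: ...

    @abstractmethod
    def supports_tool_calling(self) -> bool: ...

    @abstractmethod
    def supports_multimodal(self) -> bool: ...

class EchoProvider(BaseLLMProvider):
    """Trivial stand-in that shows the interface contract without any API calls."""
    def chat_completion(self, messages, **kwargs):
        return {"role": "assistant", "content": messages[-1]["content"]}
    def supports_tool_calling(self) -> bool:
        return False
    def supports_multimodal(self) -> bool:
        return False
```

Business logic written against `BaseLLMProvider` never imports a vendor SDK, which is what makes swapping models possible.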
Request parameter standardization
Different models take different parameters, e.g. temperature, top_p, max_tokens (maximum generated tokens), stop_sequences. The abstraction layer should define one set of canonical parameter names and take responsibility for mapping them to each model's native parameters.
| Unified name | GPT-4o | Claude 3.5 | Gemini 1.5 | Description |
|---|---|---|---|---|
| `temperature` | `temperature` | `temperature` | `temperature` | Randomness of the generated text |
| `max_tokens` | `max_tokens` | `max_tokens` | `max_output_tokens` | Maximum number of generated tokens |
| `top_p` | `top_p` | `top_p` | `top_p` | Nucleus-sampling parameter |
| `stop_sequences` | `stop` | `stop_sequences` | `stop_sequences` | Stop generating at these strings |
| `model_name` | `model` | `model` | `model` | Concrete model identifier |
| `tools` | `tools` | `tools` | `tools` | Tool definitions (function calling) |
| `tool_choice` | `tool_choice` | `tool_choice` | `tool_config` | Tool selection strategy |
Response format standardization
Whichever model returns the result, we parse it into one unified internal data structure. This is essential for the downstream graph nodes.
For example, for chat completion:
# Unified response structure
class LLMResponse:
def __init__(self, content: str, role: str = "assistant", tool_calls: Optional[List[Dict]] = None,
finish_reason: Optional[str] = None, usage: Optional[Dict] = None, raw_response: Optional[Any] = None):
self.content = content
self.role = role
self.tool_calls = tool_calls if tool_calls is not None else []
self.finish_reason = finish_reason # e.g., "stop", "length", "tool_calls"
self.usage = usage # e.g., {"input_tokens": 10, "output_tokens": 20}
self.raw_response = raw_response # raw API response, kept for debugging
def __str__(self):
return f"Role: {self.role}\nContent: {self.content}\nTool Calls: {self.tool_calls}"
Unified error handling
Different APIs use different error codes and exception types. The abstraction layer should catch the raw exceptions and raise a unified set of our own, such as LLMServiceUnavailableError, LLMRateLimitExceededError, and LLMInvalidRequestError.
class LLMServiceError(Exception):
"""Base exception for LLM service errors."""
pass
class LLMRateLimitExceededError(LLMServiceError):
"""Raised when rate limit is exceeded."""
pass
class LLMInvalidRequestError(LLMServiceError):
"""Raised when the LLM request is invalid."""
pass
class LLMServiceUnavailableError(LLMServiceError):
"""Raised when the provider is temporarily unavailable."""
pass
# ... other error types
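One way to centralize this is a helper that every provider adapter calls to map HTTP status codes to the unified exceptions. A sketch, with the exception hierarchy re-declared compactly so the snippet is self-contained:

```python
class LLMServiceError(Exception): ...
class LLMRateLimitExceededError(LLMServiceError): ...
class LLMServiceUnavailableError(LLMServiceError): ...
class LLMInvalidRequestError(LLMServiceError): ...

def exception_for_status(status_code: int, body: str = "") -> LLMServiceError:
    """Map an HTTP status code from any provider into a unified exception."""
    if status_code == 429:
        return LLMRateLimitExceededError(body)
    if status_code >= 500:
        return LLMServiceUnavailableError(body)
    return LLMInvalidRequestError(f"{status_code}: {body}")
```

Each provider then does `raise exception_for_status(resp.status_code, resp.text)` instead of duplicating the branching.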
2. Input/Output Format Standardization
The prompt-engineering challenge
Even though we aim for model agnosticism, models can differ subtly in how they interpret prompts. Some handle the system role better; others prefer system instructions folded into the user message. In practice you may need slightly different prompt templates per model, or a more universally robust prompt strategy.
Message structure standardization
A common message structure is essential for conversational LLMs. We can adopt the OpenAI-style message format as the internal standard, since it is widely accepted:
# Standardized message structure
class Message:
def __init__(self, role: str, content: Union[str, List[Dict]], name: Optional[str] = None, tool_call_id: Optional[str] = None):
"""
Args:
role: "system", "user", "assistant", "tool"
content: str for text, List[Dict] for multimodal (e.g., image_url)
name: Optional name for tool calls
tool_call_id: Optional ID for tool responses
"""
self.role = role
self.content = content
self.name = name
self.tool_call_id = tool_call_id
def to_openai_format(self) -> Dict:
msg = {"role": self.role, "content": self.content}
if self.name:
msg["name"] = self.name
if self.tool_call_id: # For tool responses
msg["tool_call_id"] = self.tool_call_id
return msg
def to_claude_format(self) -> Dict:
# The Anthropic Messages API takes the system prompt as a separate top-level "system" parameter
# As a fallback, a system instruction can be folded into the first user message, as shown below
if self.role == "system":
return {"role": "user", "content": f"<system_instruction>{self.content}</system_instruction>"} # Example adaptation
return {"role": self.role, "content": self.content}
def to_gemini_format(self) -> Dict:
parts = []
if isinstance(self.content, str):
parts.append({"text": self.content})
elif isinstance(self.content, list):
for item in self.content:
if item["type"] == "text":
parts.append({"text": item["text"]})
elif item["type"] == "image_url":
parts.append({"image_url": {"url": item["image_url"]}}) # Gemini might need base64 data for image
# Gemini returns functionCall parts vs OpenAI's tool_calls
# This part requires more complex mapping, see 'tool_calling' section
return {"role": self.role, "parts": parts}
Unified handling of tool calling (function calling)
Tool calling is one of the key capabilities of modern LLMs, but each model implements it differently:
- OpenAI (GPT-4o): define function schemas via the tools parameter; the model returns tool_calls entries.
- Anthropic (Claude 3.5): also defines function schemas via the tools parameter; the model returns tool_use content blocks.
- Google (Gemini 1.5): defines function schemas via the tools parameter; the model returns functionCall parts containing the function name and arguments.
Our abstraction layer needs to unify these formats.
# Standardized tool definition
class Tool:
def __init__(self, name: str, description: str, parameters: Dict):
self.name = name
self.description = description
self.parameters = parameters # JSON Schema format
def to_openai_format(self) -> Dict:
return {"type": "function", "function": {"name": self.name, "description": self.description, "parameters": self.parameters}}
def to_claude_format(self) -> Dict:
return {"input_schema": self.parameters, "name": self.name, "description": self.description}
def to_gemini_format(self) -> Dict:
return {"function_declarations": [{"name": self.name, "description": self.description, "parameters": self.parameters}]}
# Standardized tool-call result
class ToolCall:
def __init__(self, id: str, name: str, arguments: Dict):
self.id = id
self.name = name
self.arguments = arguments
When a model response contains tool calls, whether the raw API returned tool_calls, tool_use, or functionCall, we parse it into a unified List[ToolCall] structure.
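A sketch of that normalization step. The three raw shapes below approximate what each API returns (field names follow the public docs; treat the details as illustrative, and note the plain-dict result stands in for the `ToolCall` class above):

```python
import json

def normalize_tool_calls(provider: str, raw: dict) -> list:
    """Return a list of {"id", "name", "arguments"} dicts from a raw response."""
    calls = []
    if provider == "openai":
        for tc in raw.get("tool_calls", []):
            calls.append({"id": tc["id"], "name": tc["function"]["name"],
                          "arguments": json.loads(tc["function"]["arguments"])})
    elif provider == "claude":
        for block in raw.get("content", []):
            if block.get("type") == "tool_use":
                calls.append({"id": block["id"], "name": block["name"],
                              "arguments": block["input"]})
    elif provider == "gemini":
        for part in raw.get("parts", []):
            if "functionCall" in part:
                fc = part["functionCall"]
                calls.append({"id": None, "name": fc["name"], "arguments": fc["args"]})
    return calls
```

Note the asymmetry the normalizer hides: OpenAI delivers arguments as a JSON string, while Claude and Gemini deliver them as already-parsed objects.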
Output parsing and validation
LLM output sometimes deviates from the expected format (e.g. plain text where JSON was requested). We can use JSON Schema or Pydantic to define the expected output structure and validate/parse it in the abstraction layer. If the output does not conform, we can trigger a retry or an error-handling path.
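The validate-then-retry loop can be sketched with nothing but `json`. Here `call_llm` is a stand-in for any provider call, and the required-keys check stands in for a full JSON Schema or Pydantic validation:

```python
import json

def parse_with_retry(call_llm, prompt: str, required_keys: set, max_attempts: int = 3) -> dict:
    """Call the model until the output parses as JSON with the expected keys."""
    last_error = None
    for _ in range(max_attempts):
        raw = call_llm(prompt)
        try:
            data = json.loads(raw)
            missing = required_keys - data.keys()
            if missing:
                raise ValueError(f"missing keys: {missing}")
            return data
        except (json.JSONDecodeError, ValueError) as e:
            last_error = e  # a real system would feed the error back into the prompt
    raise RuntimeError(f"Output never validated: {last_error}")

# Usage with a flaky stub that fails once, then returns valid JSON:
answers = iter(["not json", '{"title": "AI", "sections": []}'])
result = parse_with_retry(lambda p: next(answers), "outline please", {"title", "sections"})
```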
3. Configuration Management and Credential Security
- Credential management: API keys must never be hard-coded; manage them via environment variables, a secrets manager (e.g. Vault, AWS Secrets Manager), or secure configuration files.
- Model parameter configuration: allow per-model parameters (temperature, top_p, etc.) to be configured dynamically, even with different defaults per task.
- Environment separation: production, staging, and development should have distinct configurations.
import os
from typing import Dict, Any
class ConfigManager:
_instance = None
_config: Dict[str, Any] = {}
def __new__(cls):
if cls._instance is None:
cls._instance = super(ConfigManager, cls).__new__(cls)
cls._instance._load_config()
return cls._instance
def _load_config(self):
# Load API keys from environment variables
self._config["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
self._config["CLAUDE_API_KEY"] = os.getenv("CLAUDE_API_KEY")
self._config["GEMINI_API_KEY"] = os.getenv("GEMINI_API_KEY")
self._config["GEMINI_PROJECT_ID"] = os.getenv("GEMINI_PROJECT_ID") # For Google Cloud
self._config["GEMINI_LOCATION"] = os.getenv("GEMINI_LOCATION", "us-central1")
# Default model parameters (can be overridden per request)
self._config["default_llm_params"] = {
"temperature": 0.7,
"max_tokens": 1024,
"top_p": 1.0,
}
# Model specific parameters or overrides
self._config["model_specific_params"] = {
"gpt-4o": {"max_tokens": 2048},
"claude-3-5-sonnet-20240620": {"max_tokens": 1500},
"gemini-1.5-pro-latest": {"max_output_tokens": 1500}
}
def get(self, key: str, default: Any = None) -> Any:
return self._config.get(key, default)
def set(self, key: str, value: Any):
self._config[key] = value
# Example usage:
config = ConfigManager()
openai_key = config.get("OPENAI_API_KEY")
default_temp = config.get("default_llm_params")["temperature"]
IV. Implementing Graph Compilation: Task Decomposition and Workflow Orchestration
With the model-agnostic abstraction layer in place, we can turn to building and executing the workflow graph.
1. Nodes and Edges: Building the Graph
We define different node types, all inheriting from an abstract BaseNode class. Each node executes one concrete task, consuming inputs and producing outputs.
from abc import ABC, abstractmethod
import inspect  # used by ToolCallNode to detect coroutine tool functions
from typing import Dict, Any, List, Optional, Callable, Union
class ExecutionContext:
"""Holds shared state and results during graph execution."""
def __init__(self):
self.results: Dict[str, Any] = {}
self.errors: Dict[str, Exception] = {}
self.shared_data: Dict[str, Any] = {}
class BaseNode(ABC):
def __init__(self, node_id: str, depends_on: Optional[List[str]] = None):
self.node_id = node_id
self.depends_on = depends_on if depends_on is not None else []
self.output: Any = None
self.error: Optional[Exception] = None
@abstractmethod
async def execute(self, context: ExecutionContext) -> Any:
"""Execute the node's logic."""
pass
def get_output(self) -> Any:
return self.output
def get_error(self) -> Optional[Exception]:
return self.error
# Example: an LLM-call node
class LLMNode(BaseNode):
def __init__(self, node_id: str, llm_orchestrator: Any, prompt_template: str,
model_selector: Optional[Callable[[ExecutionContext], str]] = None,
output_parser: Optional[Callable[[str], Any]] = None,
depends_on: Optional[List[str]] = None,
llm_params: Optional[Dict] = None):
super().__init__(node_id, depends_on)
self.llm_orchestrator = llm_orchestrator # Reference to our LLM switching logic
self.prompt_template = prompt_template
self.model_selector = model_selector # Function to dynamically select model
self.output_parser = output_parser
self.llm_params = llm_params if llm_params is not None else {}
async def execute(self, context: ExecutionContext) -> Any:
try:
# Gather inputs from dependencies
inputs = {dep: context.results[dep] for dep in self.depends_on if dep in context.results}
# Render prompt using template and inputs
prompt = self.prompt_template.format(**inputs)
messages = [Message(role="user", content=prompt)] # Simple case, more complex for chat history
# Dynamically select model
model_name = self.model_selector(context) if self.model_selector else None
# Call LLM via orchestrator
response: LLMResponse = await self.llm_orchestrator.chat_completion(
messages=messages,
model_name=model_name, # Orchestrator will pick default if None
**self.llm_params
)
raw_content = response.content
if self.output_parser:
self.output = self.output_parser(raw_content)
else:
self.output = raw_content
context.results[self.node_id] = self.output
return self.output
except Exception as e:
self.error = e
context.errors[self.node_id] = e
raise
# Example: a tool-call node
class ToolCallNode(BaseNode):
def __init__(self, node_id: str, tool_func: Callable,
input_mapper: Callable[[ExecutionContext], Dict],
depends_on: Optional[List[str]] = None):
super().__init__(node_id, depends_on)
self.tool_func = tool_func
self.input_mapper = input_mapper
async def execute(self, context: ExecutionContext) -> Any:
try:
# Map context data to tool function arguments
args = self.input_mapper(context)
self.output = await self.tool_func(**args) if inspect.iscoroutinefunction(self.tool_func) else self.tool_func(**args)
context.results[self.node_id] = self.output
return self.output
except Exception as e:
self.error = e
context.errors[self.node_id] = e
raise
# Example: a conditional node
class ConditionalNode(BaseNode):
def __init__(self, node_id: str, condition_func: Callable[[ExecutionContext], bool],
true_path_node_id: str, false_path_node_id: str,
depends_on: Optional[List[str]] = None):
super().__init__(node_id, depends_on)
self.condition_func = condition_func
self.true_path_node_id = true_path_node_id
self.false_path_node_id = false_path_node_id
async def execute(self, context: ExecutionContext) -> Any:
try:
if self.condition_func(context):
self.output = self.true_path_node_id
else:
self.output = self.false_path_node_id
context.results[self.node_id] = self.output
return self.output
except Exception as e:
self.error = e
context.errors[self.node_id] = e
raise
2. The DAG (Directed Acyclic Graph) Engine
The graph engine manages node lifecycles, execution order, concurrency control, and state passing. Since our workflows are typically asynchronous, asyncio is a natural fit.
Core steps:
- Build the graph: add all nodes and wire up their dependencies.
- Topological sort: determine execution order so that every dependency completes before its dependents run. For a DAG this can be done with Kahn's algorithm or depth-first search (DFS).
- Execute: run nodes in topological order, executing mutually independent nodes in parallel.
- State management: an ExecutionContext carries results and shared data between nodes.
import asyncio
import collections
from typing import Dict, List, Set, Tuple
class DAGExecutor:
def __init__(self, nodes: List[BaseNode]):
self.nodes_map: Dict[str, BaseNode] = {node.node_id: node for node in nodes}
self.graph: Dict[str, List[str]] = collections.defaultdict(list)
self.in_degree: Dict[str, int] = collections.defaultdict(int)
self._build_graph()
def _build_graph(self):
for node_id, node in self.nodes_map.items():
for dep_id in node.depends_on:
if dep_id not in self.nodes_map:
raise ValueError(f"Node {node_id} depends on non-existent node {dep_id}")
self.graph[dep_id].append(node_id) # dep_id -> node_id (dep_id is a prerequisite for node_id)
self.in_degree[node_id] += 1
# Initialize in-degree for nodes with no dependencies
for node_id in self.nodes_map:
if node_id not in self.in_degree:
self.in_degree[node_id] = 0
async def execute(self, initial_context: Optional[ExecutionContext] = None) -> ExecutionContext:
context = initial_context if initial_context else ExecutionContext()
# Kahn's algorithm for topological sort and execution
queue = collections.deque([node_id for node_id, degree in self.in_degree.items() if degree == 0])
executed_nodes_count = 0
total_nodes = len(self.nodes_map)
# Keep track of currently running tasks
running_tasks: Dict[str, asyncio.Task] = {}
while executed_nodes_count < total_nodes:
# Identify nodes that are ready to run (dependencies met, not yet running)
ready_to_run_now = []
while queue:
node_id = queue.popleft()
if node_id not in running_tasks:
ready_to_run_now.append(node_id)
if not ready_to_run_now and not running_tasks:
# This indicates a cycle or an issue where no nodes can run
unexecuted_nodes = [n_id for n_id in self.nodes_map if n_id not in context.results and n_id not in context.errors and n_id not in running_tasks]
if unexecuted_nodes:
raise RuntimeError(f"Deadlock detected or cycle in graph. Unexecuted nodes: {unexecuted_nodes}")
break # All nodes executed or no more ready nodes
# Start new tasks for ready nodes
for node_id in ready_to_run_now:
node = self.nodes_map[node_id]
task = asyncio.create_task(node.execute(context), name=node_id)
running_tasks[node_id] = task
print(f"Starting node: {node_id}")
if not running_tasks:
break # Should not happen if graph is valid and not empty
# Wait for any of the running tasks to complete
done, pending = await asyncio.wait(
running_tasks.values(),
return_when=asyncio.FIRST_COMPLETED
)
# Process completed tasks
for task in done:
node_id = task.get_name()
executed_nodes_count += 1
del running_tasks[node_id]
try:
await task # Re-raise any exceptions from the task
print(f"Node {node_id} completed successfully.")
except Exception as e:
print(f"Node {node_id} failed with error: {e}")
# Decide how to handle errors: propagate, stop, or allow partial completion
# For simplicity, we'll continue executing other nodes but mark this one as failed.
self.nodes_map[node_id].error = e
context.errors[node_id] = e
# If a critical node fails, you might want to stop the whole graph.
# For now, we allow downstream nodes to potentially fail or handle missing deps.
# Update in-degrees for dependent nodes
for neighbor_id in self.graph[node_id]:
self.in_degree[neighbor_id] -= 1
if self.in_degree[neighbor_id] == 0:
queue.append(neighbor_id)
if executed_nodes_count < total_nodes:
raise RuntimeError(f"Graph execution incomplete. {total_nodes - executed_nodes_count} nodes did not execute.")
return context
3. Example: the Content-Generation-and-Review Workflow, Compiled as a Graph
Let's implement the earlier content-generation-and-review workflow with our node types and the DAG engine.
import json
import inspect
from pydantic import BaseModel, Field
# Mock LLM Orchestrator (will be detailed in next section)
class MockLLMOrchestrator:
async def chat_completion(self, messages: List[Message], model_name: Optional[str] = None, **kwargs) -> LLMResponse:
print(f"--- Mock LLM Call ({model_name or 'default'}) ---")
print(f"Prompt: {messages[-1].content}")
# Simulate different model behaviors
if model_name == "claude-3-5-sonnet-20240620" and "大纲" in messages[-1].content:
content = json.dumps({
"title": "人工智能发展与未来",
"sections": [
{"heading": "引言:AI浪潮的兴起", "points": ["定义AI", "历史回顾"]},
{"heading": "当前AI技术突破", "points": ["深度学习", "大模型", "多模态"]},
{"heading": "AI对社会的影响", "points": ["经济", "就业", "伦理"]},
{"heading": "AI的未来展望", "points": ["通用人工智能", "人机协作"]},
]
}, ensure_ascii=False)
return LLMResponse(content=content, finish_reason="stop", usage={"input_tokens": 50, "output_tokens": 100})
elif "大纲" in messages[-1].content:
content = json.dumps({
"title": "通用LLM应用架构",
"sections": [
{"heading": "引言", "points": ["LLM异构性", "模型无关需求"]},
{"heading": "抽象层设计", "points": ["API标准化", "参数统一"]},
{"heading": "图编译引擎", "points": ["节点", "边", "DAG执行"]},
]
}, ensure_ascii=False)
return LLMResponse(content=content, finish_reason="stop", usage={"input_tokens": 60, "output_tokens": 120})
elif "文章初稿" in messages[-1].content:
return LLMResponse(content="这是一篇关于AI发展初稿,基于提供的大纲。\n...", finish_reason="stop")
return LLMResponse(content="Mock LLM Response", finish_reason="stop")
# Define expected output schema for outline
class ArticleOutline(BaseModel):
title: str = Field(description="The title of the article.")
sections: List[Dict[str, Union[str, List[str]]]] = Field(description="A list of sections, each with a heading and key points.")
# Define mock tools
async def mock_human_review_tool(draft_content: str) -> Dict:
print(f"\n--- Human Review Needed ---")
print(f"Draft Content Preview: {draft_content[:200]}...")
# In a real system, this would involve a web UI or message queue
# For demo, simulate user input
review_input = input("Review result (approve/reject/revise, [comment]): ").lower().strip()
if review_input.startswith("approve"):
return {"status": "approved", "comment": "Looks good!"}
elif review_input.startswith("reject"):
comment = review_input.split(":", 1)[1].strip() if ":" in review_input else "Content is not suitable."
return {"status": "rejected", "comment": comment}
else: # Default to revise
comment = review_input.split(":", 1)[1].strip() if ":" in review_input else "Please elaborate on section 2."
return {"status": "revise", "comment": comment}
# Graph Construction
async def run_article_workflow(llm_orchestrator: Any, user_request: str):
nodes: List[BaseNode] = []
# Node 1: Generate Outline
def outline_parser(json_string: str) -> ArticleOutline:
return ArticleOutline.model_validate_json(json_string)
nodes.append(LLMNode(
node_id="generate_outline",
llm_orchestrator=llm_orchestrator,
prompt_template="请根据以下需求,生成一篇结构清晰的文章大纲(JSON格式),包含标题和多个带有要点的章节:\n需求:{user_request}",
model_selector=lambda ctx: "claude-3-5-sonnet-20240620", # Force Claude for outline
output_parser=outline_parser,
llm_params={"response_format": {"type": "json_object"}}, # Hint for LLM
depends_on=[],
))
# Node 2: Generate Draft
nodes.append(LLMNode(
node_id="generate_draft",
llm_orchestrator=llm_orchestrator,
prompt_template="请根据以下文章大纲,撰写一篇详细的文章初稿:\n大纲:{generate_outline}\n{revision_comment}",
model_selector=lambda ctx: "gpt-4o", # Force GPT-4o for draft
llm_params={"temperature": 0.8},
depends_on=["generate_outline"],
))
# Node 3: Human Review
nodes.append(ToolCallNode(
node_id="human_review",
tool_func=mock_human_review_tool,
input_mapper=lambda ctx: {"draft_content": ctx.results["generate_draft"]},
depends_on=["generate_draft"],
))
# Node 4: Conditional Re-draft or Finalize
class ReviewDecisionNode(BaseNode):
def __init__(self, node_id: str, depends_on: Optional[List[str]] = None):
super().__init__(node_id, depends_on)
async def execute(self, context: ExecutionContext) -> Any:
review_result = context.results["human_review"]
if review_result["status"] == "approved":
self.output = "finalize" # This would trigger a 'publish' node in a real system
context.shared_data["revision_comment"] = ""
print(f"Article Approved!")
elif review_result["status"] == "revise":
self.output = "redraft"
context.shared_data["revision_comment"] = f"根据以下修改意见重新撰写:{review_result['comment']}"
print(f"Article needs revision: {review_result['comment']}")
else: # Rejected
self.output = "rejected"
context.shared_data["revision_comment"] = ""
print(f"Article Rejected: {review_result['comment']}")
context.results[self.node_id] = self.output
return self.output
nodes.append(ReviewDecisionNode(node_id="review_decision", depends_on=["human_review"]))
# Add a pseudo-node for revision loop (handled by re-running part of the graph or a separate loop)
# For a simple DAG, we might re-trigger the draft node with revision comments.
# In a full-fledged workflow engine, this would be a feedback loop.
# For demonstration, we'll manually inject revision comment into context and rerun if needed.
executor = DAGExecutor(nodes)
# Initial context with user request
initial_context = ExecutionContext()
initial_context.shared_data["user_request"] = user_request
initial_context.shared_data["revision_comment"] = ""
# Execute the graph
print("\n--- Starting Article Workflow ---")
try:
final_context = await executor.execute(initial_context)
print("\n--- Workflow Completed ---")
if "generate_draft" in final_context.results:
print(f"Final Draft (partial): {final_context.results['generate_draft'][:500]}...")
if "generate_outline" in final_context.results:
print(f"Outline: {final_context.results['generate_outline'].model_dump_json(indent=2)}")
except Exception as e:
print(f"Workflow failed: {e}")
# executor.execute mutates the context we passed in, so errors are on initial_context
for node_id, error in initial_context.errors.items():
print(f"Error in node {node_id}: {error}")
# To run this:
# asyncio.run(run_article_workflow(MockLLMOrchestrator(), "一篇关于AI在医疗领域应用的文章"))
(Because of the loop handling in DAGExecutor and the complexity of conditional nodes, the example above does not fully demonstrate the conditional loop. In practice this requires a more capable DAG framework, or an outer loop that re-triggers part of the graph based on the review_decision result.)
V. Strategies and Practice for Seamless Switching
With the model-agnostic abstraction and the graph-compilation engine in place, the core question becomes: how do we actually switch seamlessly among GPT-4o, Claude 3.5, and Gemini 1.5? This requires an intelligent LLM orchestrator.
1. Dynamic Model Selection
The orchestrator should contain a policy engine that picks the best model based on multiple factors.
Selection strategies:
- Cost first: choose the model with the lowest current token price.
- Latency first: choose the fastest-responding model.
- Capability first: e.g. if the task needs image understanding, choose a multimodal model (GPT-4o, Gemini 1.5 Pro); for heavy code generation, GPT-4o or Gemini 1.5 Pro may be preferred.
- Reliability first: choose the model with the best historical SLA performance.
- Token limits: for very long input contexts, choose a model with a large context window (e.g. Gemini 1.5 Pro's million tokens).
- A/B testing: randomly split traffic across models at fixed ratios for evaluation.
- User preference/configuration: let end users or administrators pin a preferred model.
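A minimal policy engine can be a scoring function over a static capability/price table. Everything below (the registry, prices, capability flags) is illustrative, not current pricing:

```python
# Illustrative registry: per-model traits the selector scores against.
MODELS = {
    "gpt-4o":            {"input_cost": 5.0,  "multimodal": True, "context": 128_000},
    "claude-3-5-sonnet": {"input_cost": 3.0,  "multimodal": True, "context": 200_000},
    "gemini-1.5-pro":    {"input_cost": 3.5,  "multimodal": True, "context": 1_000_000},
}

def select_model(needs_multimodal: bool = False, min_context: int = 0,
                 strategy: str = "cost") -> str:
    """Pick the cheapest (or largest-context) model meeting the hard requirements."""
    candidates = {
        name: spec for name, spec in MODELS.items()
        if (spec["multimodal"] or not needs_multimodal) and spec["context"] >= min_context
    }
    if not candidates:
        raise ValueError("no model satisfies the requirements")
    if strategy == "cost":
        return min(candidates, key=lambda n: candidates[n]["input_cost"])
    return max(candidates, key=lambda n: candidates[n]["context"])
```

Hard requirements (multimodality, context size) act as filters; soft preferences (cost, latency) act as the ranking key.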
Health checks and fallback (circuit breaker)
When a provider's API starts failing, responding too slowly, or returning frequent errors, the orchestrator should remove it from the available pool for a while (open the circuit) and switch to a backup model. After a cool-down it can enter a "half-open" state, sending a small number of trial requests to probe whether the service has recovered.
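A minimal per-provider circuit breaker might look like this. The thresholds and cool-down are illustrative, and a production version would also need a proper half-open trial state and thread/async safety:

```python
import time

class CircuitBreaker:
    """Open after `threshold` consecutive failures; probe again after `cooldown` seconds."""
    def __init__(self, threshold: int = 3, cooldown: float = 30.0):
        self.threshold = threshold
        self.cooldown = cooldown
        self.failures = 0
        self.opened_at = None

    def available(self, now: float = None) -> bool:
        now = time.monotonic() if now is None else now
        if self.opened_at is None:
            return True
        # Half-open: once the cooldown elapses, let a trial request through.
        return now - self.opened_at >= self.cooldown

    def record_success(self):
        self.failures, self.opened_at = 0, None

    def record_failure(self, now: float = None):
        now = time.monotonic() if now is None else now
        self.failures += 1
        if self.failures >= self.threshold:
            self.opened_at = now
```

The orchestrator keeps one breaker per provider and simply skips providers whose breaker reports unavailable.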
Load balancing and rate limiting
When several models are in use simultaneously, traffic can be balanced according to each model's quota and performance. The orchestrator should also enforce global rate limits so that no provider's API quota is exceeded.
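Per-provider rate limiting can be sketched as a token bucket (the rate and capacity numbers are illustrative; timestamps are injected for testability):

```python
class TokenBucket:
    """Allow at most `rate` requests per second, with bursts up to `capacity`."""
    def __init__(self, rate: float, capacity: int):
        self.rate = rate
        self.capacity = capacity
        self.tokens = float(capacity)
        self.last = 0.0

    def allow(self, now: float) -> bool:
        # Refill proportionally to elapsed time, capped at capacity.
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= 1.0:
            self.tokens -= 1.0
            return True
        return False
```

When `allow` returns False, the orchestrator can either queue the request or route it to a different provider whose bucket still has tokens.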
2. Model-specific Optimization and Adaptation
Even while pursuing model agnosticism, tuning for a specific model can sometimes significantly improve results or unlock unique features.
- Differentiated prompt templates: despite the common message structure, models may prefer different placements of the system message or different instruction phrasing. The orchestrator can load per-model prompt templates based on the selected model.
- Feature-specific mapping:
  - Multimodal: GPT-4o and Gemini 1.5 accept image input. When a Message contains an image, the orchestrator must convert the unified image representation (URL or Base64) into the corresponding API's format.
  - Tool calling: as described above, the unified Tool definitions and ToolCall parsing logic must be mapped to each model's tools parameter and response structure.
- Fine-tuned response parsing: even with a standardized LLMResponse, some models' responses to particular requests may need more careful parsing logic.
3. Code in Practice: a Simplified Model-agnostic Layer in Python
We will build the AbstractLLMProvider interface and its concrete implementations, plus an LLMOrchestrator that performs the dynamic switching.
import os
import asyncio
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Union, Callable
import httpx # For async HTTP requests
import json
from tenacity import retry, wait_random_exponential, stop_after_attempt, retry_if_exception_type
# Assume Message, LLMResponse, Tool, ToolCall, LLMServiceError, etc. are defined as above
# For brevity, let's redefine Message and LLMResponse simpler here if not already imported
class Message:
def __init__(self, role: str, content: Union[str, List[Dict]], name: Optional[str] = None, tool_call_id: Optional[str] = None):
self.role = role
self.content = content
self.name = name
self.tool_call_id = tool_call_id
def to_dict(self):
d = {"role": self.role, "content": self.content}
if self.name: d["name"] = self.name
if self.tool_call_id: d["tool_call_id"] = self.tool_call_id
return d
class LLMResponse:
def __init__(self, content: str, role: str = "assistant", tool_calls: Optional[List[Dict]] = None,
finish_reason: Optional[str] = None, usage: Optional[Dict] = None, raw_response: Optional[Any] = None):
self.content = content
self.role = role
self.tool_calls = tool_calls if tool_calls is not None else []
self.finish_reason = finish_reason
self.usage = usage
self.raw_response = raw_response
def __str__(self):
return f"Role: {self.role}\nContent: {self.content}\nTool Calls: {self.tool_calls}"
# --- Abstract LLM Provider ---
class AbstractLLMProvider(ABC):
def __init__(self, model_name: str, api_key: str, base_url: str):
self.model_name = model_name
self.api_key = api_key
self.base_url = base_url
self.client = httpx.AsyncClient()
@abstractmethod
async def chat_completion(self, messages: List[Message], tools: Optional[List[Dict]] = None,
tool_choice: Optional[Union[str, Dict]] = None, **kwargs) -> LLMResponse:
pass
@abstractmethod
def supports_tool_calling(self) -> bool:
pass
@abstractmethod
def supports_multimodal(self) -> bool:
pass
@abstractmethod
def get_cost_per_token(self) -> Dict[str, float]:
"""Returns input/output token cost."""
pass
def _map_messages_to_provider_format(self, messages: List[Message]) -> List[Dict]:
"""Generic mapping, specific providers might override."""
return [msg.to_dict() for msg in messages]
def _map_tools_to_provider_format(self, tools: Optional[List[Dict]]) -> Optional[List[Dict]]:
"""Generic mapping, specific providers might override."""
return tools # Assuming tools are already in a somewhat standardized format
def _parse_provider_response(self, raw_response: Dict) -> LLMResponse:
"""Generic parsing, specific providers must override."""
raise NotImplementedError
# --- OpenAI Provider ---
class OpenAIProvider(AbstractLLMProvider):
def __init__(self, model_name: str):
super().__init__(model_name, os.getenv("OPENAI_API_KEY"), "https://api.openai.com/v1")
self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
@retry(wait=wait_random_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3),
retry=retry_if_exception_type((LLMRateLimitExceededError, LLMServiceUnavailableError)))  # retry transient errors; raw httpx errors are re-raised as these below
async def chat_completion(self, messages: List[Message], tools: Optional[List[Dict]] = None,
tool_choice: Optional[Union[str, Dict]] = None, **kwargs) -> LLMResponse:
payload = {
"model": self.model_name,
"messages": self._map_messages_to_provider_format(messages),
**kwargs
}
if tools:
payload["tools"] = tools # Tools are already in OpenAI format (from Tool.to_openai_format)
if tool_choice:
payload["tool_choice"] = tool_choice
try:
response = await self.client.post(f"{self.base_url}/chat/completions", headers=self.headers, json=payload, timeout=60.0)
response.raise_for_status()
return self._parse_provider_response(response.json())
except httpx.HTTPStatusError as e:
if e.response.status_code == 429:
raise LLMRateLimitExceededError(f"OpenAI Rate Limit Exceeded: {e.response.text}") from e
if e.response.status_code >= 500:
raise LLMServiceUnavailableError(f"OpenAI Service Unavailable: {e.response.text}") from e
raise LLMInvalidRequestError(f"OpenAI API Error: {e.response.status_code} - {e.response.text}") from e
except Exception as e:
raise LLMServiceError(f"OpenAI Unknown Error: {e}") from e
def _map_messages_to_provider_format(self, messages: List[Message]) -> List[Dict]:
openai_messages = []
for msg in messages:
content_val = msg.content
if isinstance(content_val, list): # Multimodal content
openai_content_parts = []
for part in content_val:
if part["type"] == "text":
openai_content_parts.append({"type": "text", "text": part["text"]})
elif part["type"] == "image_url":
openai_content_parts.append({"type": "image_url", "image_url": {"url": part["image_url"]}})
content_val = openai_content_parts
if msg.role == "tool": # OpenAI uses 'tool' role for function call results
openai_messages.append({"role": msg.role, "content": content_val, "tool_call_id": msg.tool_call_id})
elif msg.role == "assistant" and msg.tool_call_id:  # Assistant message that issued tool calls
# NOTE: tool_call_id holds a single id, not a tool_calls payload; a full implementation
# would carry the assistant's original tool_calls list on the Message and pass it here.
openai_messages.append({"role": "assistant", "content": content_val})
else:
openai_messages.append({"role": msg.role, "content": content_val})
return openai_messages
def _parse_provider_response(self, raw_response: Dict) -> LLMResponse:
choice = raw_response["choices"][0]
message = choice["message"]
content = message.get("content", "")
tool_calls = message.get("tool_calls", [])
finish_reason = choice["finish_reason"]
usage = raw_response.get("usage", {})
# Standardize tool calls for our internal ToolCall class
standardized_tool_calls = []
for tc in tool_calls:
if tc["type"] == "function":
standardized_tool_calls.append({
"id": tc["id"],
"name": tc["function"]["name"],
"arguments": json.loads(tc["function"]["arguments"])
})
return LLMResponse(
content=content,
role="assistant",
tool_calls=standardized_tool_calls,
finish_reason=finish_reason,
usage={"input_tokens": usage.get("prompt_tokens"), "output_tokens": usage.get("completion_tokens")},
raw_response=raw_response
)
def supports_tool_calling(self) -> bool:
return True # GPT-4o supports tool calling
def supports_multimodal(self) -> bool:
return True # GPT-4o supports multimodal
def get_cost_per_token(self) -> Dict[str, float]:
# Placeholder values, replace with actual current prices
if "gpt-4o" in self.model_name:
return {"input": 5.0 / 1_000_000, "output": 15.0 / 1_000_000}
return {"input": 0.0, "output": 0.0}
# --- Claude Provider ---
class ClaudeProvider(AbstractLLMProvider):
def __init__(self, model_name: str):
super().__init__(model_name, os.getenv("CLAUDE_API_KEY"), "https://api.anthropic.com/v1")
self.headers = {
"x-api-key": self.api_key,
"anthropic-version": "2023-06-01", # pin the Messages API version
"content-type": "application/json"
}
@retry(wait=wait_random_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3),
retry=retry_if_exception_type((LLMRateLimitExceededError, LLMServiceUnavailableError)))  # retry transient errors; raw httpx errors are re-raised as these below
async def chat_completion(self, messages: List[Message], tools: Optional[List[Dict]] = None,
tool_choice: Optional[Union[str, Dict]] = None, **kwargs) -> LLMResponse:
system_message = ""
claude_messages = []
for msg in messages:
if msg.role == "system":
system_message = msg.content
else:
claude_messages.append({"role": msg.role, "content": msg.content}) # Claude 3.5 Sonnet accepts 'tool_use' and 'tool_result'
payload = {
"model": self.model_name,
"messages": claude_messages,
"max_tokens": kwargs.pop("max_tokens", 1024),
**kwargs
}
if system_message:
payload["system"] = system_message
if tools:
# Claude tool format uses "input_schema" instead of "parameters" directly
claude_tools = []
for tool_def in tools:
claude_tools.append({
"name": tool_def["function"]["name"],
"description": tool_def["function"]["description"],
"input_schema": tool_def["function"]["parameters"]
})
payload["tools"] = claude_tools
if tool_choice:
# Claude tool_choice mapping might be different. E.g., for specific tool:
# {"type": "tool", "name": "my_tool_name"}
payload["tool_choice"] = tool_choice # Assuming it's already mapped
try:
response = await self.client.post(f"{self.base_url}/messages", headers=self.headers, json=payload, timeout=60.0)
response.raise_for_status()
return self._parse_provider_response(response.json())
except httpx.HTTPStatusError as e:
if e.response.status_code == 429:
raise LLMRateLimitExceededError(f"Claude Rate Limit Exceeded: {e.response.text}") from e
if e.response.status_code >= 500:
raise LLMServiceUnavailableError(f"Claude Service Unavailable: {e.response.text}") from e
raise LLMInvalidRequestError(f"Claude API Error: {e.response.status_code} - {e.response.text}") from e
except Exception as e:
raise LLMServiceError(f"Claude Unknown Error: {e}") from e
def _parse_provider_response(self, raw_response: Dict) -> LLMResponse:
content_parts = raw_response["content"]
text_content = ""
tool_calls = []
for part in content_parts:
if part["type"] == "text":
text_content += part["text"]
elif part["type"] == "tool_use":
tool_calls.append({
"id": part["id"],
"name": part["name"],
"arguments": part["input"] # Claude uses 'input' for arguments
})
usage = raw_response.get("usage", {})
return LLMResponse(
content=text_content,
role="assistant",
tool_calls=tool_calls,
finish_reason=raw_response.get("stop_reason"),
usage={"input_tokens": usage.get("input_tokens"), "output_tokens": usage.get("output_tokens")},
raw_response=raw_response
)
def supports_tool_calling(self) -> bool:
return True # Claude 3.5 Sonnet supports tool calling
def supports_multimodal(self) -> bool:
return True # Claude 3.5 Sonnet supports multimodal (text and images)
def get_cost_per_token(self) -> Dict[str, float]:
# Placeholder values, replace with actual current prices
if "claude-3-5-sonnet" in self.model_name:
return {"input": 3.0 / 1_000_000, "output": 15.0 / 1_000_000}
return {"input": 0.0, "output": 0.0}
# --- Gemini Provider ---
# Requires google-generativeai library
import google.generativeai as genai
from google.generativeai.types import Tool as GeminiTool, FunctionDeclaration
from google.api_core.exceptions import ResourceExhausted, GoogleAPIError
class GeminiProvider(AbstractLLMProvider):
def __init__(self, model_name: str):
# Gemini API doesn't use a single base_url for client like OpenAI/Claude
# It's configured via genai.configure
super().__init__(model_name, os.getenv("GEMINI_API_KEY"), "https://generativelanguage.googleapis.com/")
genai.configure(api_key=self.api_key)
self.model = genai.GenerativeModel(self.model_name)
@retry(wait=wait_random_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3),
retry=retry_if_exception_type(GoogleAPIError))
async def chat_completion(self, messages: List[Message], tools: Optional[List[Dict]] = None,
tool_choice: Optional[Union[str, Dict]] = None, **kwargs) -> LLMResponse:
gemini_contents = []
for msg in messages:
parts = []
if isinstance(msg.content, str):
parts.append(msg.content)
elif isinstance(msg.content, list): # Multimodal
for part in msg.content:
if part["type"] == "text":
parts.append(part["text"])
elif part["type"] == "image_url":
# Gemini requires base64 encoded image data directly, not URL
# This is a significant difference and needs careful handling.
# For simplicity here, we'll just pass a placeholder or raise an error.
# In a real system, you'd fetch and encode the image.
raise NotImplementedError("Gemini requires inline image data, not URLs; fetch and encode the image first.")
# Map roles: 'system' is typically not a distinct role in Gemini conversation history
# For system-like instructions, it's often part of the first user message or model config.
# We'll map 'system' to 'user' for simplicity if no specific system_instruction param exists.
role = "user" if msg.role == "system" else msg.role
gemini_contents.append({"role": role, "parts": parts})
gemini_tools = None
if tools:
gemini_tools = [GeminiTool(function_declarations=[
FunctionDeclaration(
name=tool_def["function"]["name"],
description=tool_def["function"]["description"],
parameters=tool_def["function"]["parameters"]
)
]) for tool_def in tools]
generation_config = genai.GenerationConfig(
temperature=kwargs.pop("temperature", 0.7),
max_output_tokens=kwargs.pop("max_tokens", 1024),
top_p=kwargs.pop("top_p", 1.0),
stop_sequences=kwargs.pop("stop_sequences", None)
)
try:
# Gemini's chat history is managed by its 'start_chat' method
# For a single turn, we use generate_content
response = await self.model.generate_content_async(
contents=gemini_contents,
tools=gemini_tools,
tool_config=tool_choice, # Gemini tool_config has a specific format
generation_config=generation_config,
# For streaming, use stream=True and iterate
)
return self._parse_provider_response(response)
except ResourceExhausted as e:
raise LLMRateLimitExceededError(f"Gemini Rate Limit Exceeded: {e}") from e
except GoogleAPIError as e:
if "quota" in str(e).lower():
raise LLMRateLimitExceededError(f"Gemini Quota Exceeded: {e}") from e
raise LLMInvalidRequestError(f"Gemini API Error: {e}") from e
except Exception as e:
raise LLMServiceError(f"Gemini Unknown Error: {e}") from e
def _parse_provider_response(self, raw_response: Any) -> LLMResponse:
candidates = raw_response.candidates
if not candidates:
return LLMResponse(content="", finish_reason="no_candidates", raw_response=raw_response)
first_candidate = candidates[0]
text_content = ""
tool_calls = []
for part in first_candidate.content.parts:
if hasattr(part, "text"):
text_content += part.text
elif hasattr(part, "function_call"):
fc = part.function_call
tool_calls.append({
"id": f"gemini_tool_call_{fc.name}", # Gemini doesn't always provide an ID, create one
"name": fc.name,
"arguments": {key: value for key, value in fc.args.items()}
})
usage_metadata = raw_response.usage_metadata
return LLMResponse(
content=text_content,
role="assistant",
tool_calls=tool_calls,
finish_reason=str(first_candidate.finish_reason), # The candidate's own finish reason; prompt_feedback.block_reason only applies when the prompt itself was blocked
usage={"input_tokens": usage_metadata.prompt_token_count, "output_tokens": usage_metadata.candidates_token_count} if usage_metadata else None,
raw_response=raw_response
)
def supports_tool_calling(self) -> bool:
return True # Gemini 1.5 Pro supports tool calling
def supports_multimodal(self) -> bool:
return True # Gemini 1.5 Pro supports multimodal (text and images)
def get_cost_per_token(self) -> Dict[str, float]:
# Placeholder values, replace with actual current prices
if "gemini-1.5-pro" in self.model_name:
return {"input": 7.0 / 1_000_000, "output": 21.0 / 1_000_000}
return {"input": 0.0, "output": 0.0}
# --- LLM Orchestrator ---
class LLMOrchestrator:
def __init__(self, default_model: str = "gpt-4o", config: Optional[ConfigManager] = None):
self.config = config if config else ConfigManager()
self.providers: Dict[str, AbstractLLMProvider] = {
"gpt-4o": OpenAIProvider("gpt-4o"),
"claude-3-5-sonnet-20240620": ClaudeProvider("claude-3-5-sonnet-20240620"),
"gemini-1.5-pro-latest": GeminiProvider("gemini-1.5-pro-latest"),
# Add other models as needed
}
self.available_models = list(self.providers.keys())
self.default_model = default_model
self.circuit_breakers: Dict[str, bool] = {model_name: False for model_name in self.available_models}
self.circuit_breaker_reset_time: Dict[str, float] = {model_name: 0.0 for model_name in self.available_models}
async def _select_model(self, selection_strategy: Optional[Callable[[List[str], Dict[str, AbstractLLMProvider]], str]] = None,
preferred_model: Optional[str] = None) -> str:
if preferred_model and preferred_model in self.available_models and not self.circuit_breakers[preferred_model]:
return preferred_model
active_models = [
model_name for model_name in self.available_models
if not self.circuit_breakers[model_name] or self.circuit_breaker_reset_time[model_name] < asyncio.get_event_loop().time()
]
# Reset circuit breaker if time elapsed
for model_name in list(active_models):
if self.circuit_breakers[model_name] and self.circuit_breaker_reset_time[model_name] < asyncio.get_event_loop().time():
print(f"Attempting to reset circuit breaker for {model_name}")
self.circuit_breakers[model_name] = False # Try to re-enable
active_models = [
model_name for model_name in self.available_models
if not self.circuit_breakers[model_name]
]
if not active_models:
raise LLMServiceUnavailableError("No LLM models are currently available.")
if selection_strategy:
return selection_strategy(active_models, self.providers)
# Default strategy: try preferred, then default, then first available
if self.default_model in active_models:
return self.default_model
return active_models[0] # Fallback to first available
def _trigger_circuit_breaker(self, model_name: str, duration_seconds: int = 60):
self.circuit_breakers[model_name] = True
self.circuit_breaker_reset_time[model_name] = asyncio.get_event_loop().time() + duration_seconds
print(f"Circuit breaker tripped for {model_name}. Will retry in {duration_seconds} seconds.")
async def chat_completion(self, messages: List[Message], model_name: Optional[str] = None,
tools: Optional[List[Dict]] = None, tool_choice: Optional[Union[str, Dict]] = None,
selection_strategy: Optional[Callable[[List[str], Dict[str, AbstractLLMProvider]], str]] = None,
**kwargs) -> LLMResponse:
selected_model_name = await self._select_model(selection_strategy, model_name)
provider = self.providers[selected_model_name]
# Apply model-specific parameter overrides from config
llm_params = self.config.get("default_llm_params", {}).copy()
llm_params.update(self.config.get("model_specific_params", {}).get(selected_model_name, {}))
llm_params.update(kwargs) # Request-specific params override all
# Map tools to provider-specific format
provider_tools = None
if tools:
# Assume tools are passed in OpenAI's function schema format, then converted by provider
if isinstance(provider, OpenAIProvider):
provider_tools = tools # Already in OpenAI format
elif isinstance(provider, ClaudeProvider):
provider_tools = []
for tool_def in tools:
provider_tools.append({
"name": tool_def["function"]["name"],
"description": tool_def["function"]["description"],
"input_schema": tool_def["function"]["parameters"]
})
elif isinstance(provider, GeminiProvider):
provider_tools = [GeminiTool(function_declarations=[
FunctionDeclaration(
name=tool_def["function"]["name"],
description=tool_def["function"]["description"],
parameters=tool_def["function"]["parameters"]
)
]) for tool_def in tools]
try:
response = await provider.chat_completion(messages=messages, tools=provider_tools, tool_choice=tool_choice, **llm_params)
# Log cost for this call
cost_info = provider.get_cost_per_token()
usage = response.usage or {} # Guard: some providers (e.g. Gemini) may return no usage metadata
input_tokens = usage.get("input_tokens") or 0
output_tokens = usage.get("output_tokens") or 0
input_cost = input_tokens * cost_info["input"]
output_cost = output_tokens * cost_info["output"]
print(f"LLM Call: Model={selected_model_name}, Input Tokens={input_tokens}, Output Tokens={output_tokens}, Cost=${input_cost + output_cost:.4f}")
return response
except LLMRateLimitExceededError as e:
print(f"Rate limit for {selected_model_name}. Trying fallback if available.")
self._trigger_circuit_breaker(selected_model_name)
# Recursively try again with a different model if possible
return await self.chat_completion(messages, model_name=None, tools=tools, tool_choice=tool_choice, selection_strategy=selection_strategy, **kwargs)
except LLMServiceUnavailableError as e:
print(f"Service unavailable for {selected_model_name}. Trying fallback if available.")
self._trigger_circuit_breaker(selected_model_name, duration_seconds=300) # Longer timeout for service unavailability
return await self.chat_completion(messages, model_name=None, tools=tools, tool_choice=tool_choice, selection_strategy=selection_strategy, **kwargs)
except Exception as e:
print(f"Error with {selected_model_name}: {e}")
raise # Re-raise if no fallback or persistent error
# Example usage (needs actual API keys in env vars)
# async def main():
# # Mock API keys for demonstration
# os.environ["OPENAI_API_KEY"] = "sk-..."
# os.environ["CLAUDE_API_KEY"] = "sk-ant-..."
# os.environ["GEMINI_API_KEY"] = "AIza..."
# orchestrator = LLMOrchestrator(default_model="gpt-4o")
# # Simple text generation with default model
# print("--- Default Model Call (GPT-4o) ---")
# response1 = await orchestrator.chat_completion(messages=[Message(role="user", content="Hello, what is the capital of France?")])
# print(f"Response 1: {response1.content}")
# # Switch to Claude
# print("\n--- Claude Call ---")
# response2 = await orchestrator.chat_completion(messages=[Message(role="user", content="Tell me a short story about a brave knight.")], model_name="claude-3-5-sonnet-20240620")
# print(f"Response 2: {response2.content}")
# # Switch to Gemini
# print("\n--- Gemini Call ---")
# response3 = await orchestrator.chat_completion(messages=[Message(role="user", content="What's a good recipe for chocolate chip cookies?")], model_name="gemini-1.5-pro-latest")
# print(f"Response 3: {response3.content}")
# # Example with a custom selection strategy (e.g., cost-optimized)
# def cost_optimized_selector(available_models: List[str], providers_map: Dict[str, AbstractLLMProvider]) -> str:
# min_cost_model = None
# min_total_cost = float('inf')
# # Simulate a simple prompt to estimate output tokens for cost calculation
# simulated_input_tokens = 10
# simulated_output_tokens = 50
# for model_name in available_models:
# provider = providers_map[model_name]
# costs = provider.get_cost_per_token()
# current_total_cost = (costs.get("input", 0) * simulated_input_tokens +
#                       costs.get("output", 0) * simulated_output_tokens)
# if current_total_cost < min_total_cost:
# min_total_cost = current_total_cost
# min_cost_model = model_name
# print(f"Cost-optimized selector chose: {min_cost_model}")
# return min_cost_model or orchestrator.default_model
# print("\n--- Cost-Optimized Call ---")
# response4 = await orchestrator.chat_completion(
# messages=[Message(role="user", content="Summarize the key events of World War II.")],
# selection_strategy=cost_optimized_selector
# )
# print(f"Response 4: {response4.content}")
# if __name__ == "__main__":
# asyncio.run(main())
VI. Error Handling, Monitoring, and Resilience
A robust model-agnostic system does more than implement features; it must also handle anomalies and failures gracefully.
1. Unified Error Codes and Exception Handling
As described earlier, we map the assorted errors of each underlying LLM provider onto a single unified exception hierarchy of our own. Upper-layer business logic can then catch and handle failures in one consistent way, without caring which specific model misbehaved.
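To make this concrete, here is a minimal sketch of such a hierarchy, consistent with the exception names used in the code above; the `classify_http_status` helper is illustrative and not part of the earlier code:

```python
# Sketch of a unified exception hierarchy. The class names mirror those used in
# the provider code above; the exact base-class layout is an assumption.
class LLMServiceError(Exception):
    """Base class for all provider-agnostic LLM errors."""

class LLMRateLimitExceededError(LLMServiceError):
    """Raised for HTTP 429 / quota errors from any provider."""

class LLMServiceUnavailableError(LLMServiceError):
    """Raised for 5xx / outage errors from any provider."""

class LLMInvalidRequestError(LLMServiceError):
    """Raised for 4xx client errors (bad params, auth, content policy)."""

def classify_http_status(status: int, body: str) -> LLMServiceError:
    """Map a raw HTTP status to the unified exception types."""
    if status == 429:
        return LLMRateLimitExceededError(body)
    if status >= 500:
        return LLMServiceUnavailableError(body)
    return LLMInvalidRequestError(f"{status}: {body}")
```

Because every provider funnels through the same mapping, callers can write a single `except LLMRateLimitExceededError` branch that works regardless of which model was behind the call.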
2. Retries with Exponential Backoff
Transient network faults and API rate limiting are everyday occurrences. For retryable errors (e.g. HTTP 429 and 5xx), we should retry automatically with exponential backoff, lengthening the wait between attempts so we do not pile further pressure onto an already struggling service. In Python, the tenacity library is an excellent fit for this, and it is already integrated into the code samples above.
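The backoff schedule itself is easy to sketch independently of tenacity; `backoff_delays` below is a hypothetical helper showing capped exponential growth with optional full jitter:

```python
import random

def backoff_delays(attempts: int, base: float = 1.0, cap: float = 30.0,
                   jitter: bool = False) -> list:
    """Return the wait (in seconds) before each retry: base * 2**n, capped at `cap`.

    With jitter=True, each delay is drawn uniformly from [0, delay] ("full jitter"),
    which spreads out retries from many clients hitting the same rate limit.
    """
    delays = []
    for n in range(attempts):
        delay = min(cap, base * (2 ** n))
        if jitter:
            delay = random.uniform(0.0, delay)
        delays.append(delay)
    return delays
```

Without jitter, `backoff_delays(5)` yields 1, 2, 4, 8, 16 seconds; in production, jitter is usually worth enabling so that synchronized clients do not retry in lockstep.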
3. Logging and Observability
- Request/response logs: record the full request (after redaction) and the response of every LLM call; indispensable for debugging and auditing.
- Performance metrics: track each model's latency, throughput, and success rate, so we can evaluate how well the model-selection strategy is working.
- Cost tracking: record token usage and estimated cost per call for fine-grained cost analysis and better-informed model selection.
- Error rates: monitor each model's error rate as the signal that trips the circuit breaker.
import logging
# Configure basic logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# In LLMOrchestrator's chat_completion method:
# ... after getting response:
# logging.info(f"LLM Call Success - Model: {selected_model_name}, Input: {response.usage['input_tokens']} tokens, Output: {response.usage['output_tokens']} tokens, Cost: ${input_cost+output_cost:.4f}")
# ... in exception handlers:
# logging.error(f"LLM Call Failed - Model: {selected_model_name}, Error: {e}", exc_info=True)
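The per-model metrics that feed the circuit breaker and cost reports can be tracked with a small aggregator; `ModelMetrics` below is an illustrative sketch, not part of the orchestrator above:

```python
from collections import defaultdict

class ModelMetrics:
    """Tracks per-model call counts, errors, latency, and cost (in-memory sketch).

    A production system would export these to Prometheus/StatsD rather than
    keeping them in a dict, but the bookkeeping is the same.
    """
    def __init__(self):
        self.stats = defaultdict(
            lambda: {"calls": 0, "errors": 0, "latency_sum": 0.0, "cost": 0.0})

    def record(self, model: str, latency_s: float, cost: float, ok: bool = True):
        s = self.stats[model]
        s["calls"] += 1
        s["latency_sum"] += latency_s
        s["cost"] += cost
        if not ok:
            s["errors"] += 1

    def error_rate(self, model: str) -> float:
        s = self.stats[model]
        return s["errors"] / s["calls"] if s["calls"] else 0.0

    def avg_latency(self, model: str) -> float:
        s = self.stats[model]
        return s["latency_sum"] / s["calls"] if s["calls"] else 0.0
```

The orchestrator would call `record(...)` in both its success path and its exception handlers; `error_rate` over a recent window is then a natural trigger condition for `_trigger_circuit_breaker`.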
4. Caching Strategy
For repeated or highly deterministic LLM requests, a cache can:
- Reduce cost: avoid repeated calls to paid APIs.
- Improve latency: return results immediately instead of waiting on the LLM.
- Lower load: reduce the request volume sent to LLM providers.
The cache can be a simple in-memory store, Redis, or a database. Note that LLM output is often non-deterministic (it varies with temperature and related parameters), so the caching policy must be designed for the specific business scenario; for example, cache only deterministic requests issued with temperature=0.
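A minimal sketch of such a policy, using a hypothetical `LLMResponseCache` that hashes the full request and only caches `temperature=0` calls:

```python
import hashlib
import json

class LLMResponseCache:
    """In-memory response cache keyed on (model, messages, params).

    Only deterministic requests (temperature == 0) are cached; anything else
    bypasses the cache entirely. Swap the dict for Redis in production.
    """
    def __init__(self):
        self._store = {}

    def _key(self, model: str, messages: list, params: dict) -> str:
        # Canonical JSON (sorted keys) so logically identical requests collide.
        blob = json.dumps({"model": model, "messages": messages, "params": params},
                          sort_keys=True, ensure_ascii=False)
        return hashlib.sha256(blob.encode("utf-8")).hexdigest()

    def get(self, model: str, messages: list, params: dict):
        if params.get("temperature", 1.0) != 0:
            return None  # non-deterministic request: never serve from cache
        return self._store.get(self._key(model, messages, params))

    def put(self, model: str, messages: list, params: dict, response) -> None:
        if params.get("temperature", 1.0) == 0:
            self._store[self._key(model, messages, params)] = response
```

Note that the model name is part of the key: a cached GPT-4o answer must not be served for an identical prompt routed to Claude, since the whole point of the orchestrator is that outputs differ by model.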
VII. Advanced Considerations and Outlook
1. Multimodal Model Support
As GPT