各位编程专家、AI爱好者,大家好!
今天,我们将深入探讨一个前沿且极具挑战性的话题:如何实现一个“跨模型迁移”的图执行系统。想象一下,你的复杂逻辑流程在GPT-4o上运行了一半,但出于成本、性能、甚至模型特性偏好等原因,你需要它无缝地迁移到Claude 3.5上继续执行,而无需从头再来。这听起来像科幻小说,但在我们构建更灵活、更具韧性的AI系统时,它正成为一个迫切的需求。
这不仅仅是简单地切换一个API端点,更是一场关于状态、上下文、语义鸿沟的深度挑战。今天,我将作为一名编程专家,为大家揭示实现这一目标背后的原理、架构与代码实践。
1. 引言:跨模型迁移的挑战与机遇
在大型语言模型(LLM)飞速发展的今天,我们面临着前所未有的选择。从OpenAI的GPT系列到Anthropic的Claude系列,再到Google、Meta以及开源社区的众多模型,每个模型都有其独特的优势、定价策略、性能曲线和偏好。这种多样性既是福音,也带来了新的工程挑战:如何充分利用它们,而不是被特定模型绑定?
设想一个复杂的AI应用场景:
- 阶段一:高精度、复杂推理。 用户提出一个需要深入理解和多步骤逻辑推理的问题。你可能倾向于使用像GPT-4o这样在复杂推理方面表现卓越的模型。
- 阶段二:创意生成、长文本创作。 基于第一阶段的分析结果,你需要生成一份详细的报告、创意文案或代码。此时,你可能更倾向于使用像Claude 3.5 Sonnet这样在长上下文处理和连贯性方面表现出色的模型,同时可能也为了成本效益考虑。
问题来了:如果这两个阶段是紧密耦合的,后一阶段需要前一阶段的所有中间结果和上下文,我们如何才能实现这种模型间的“接力跑”,而不是每次都从零开始?这就是我们今天探讨的“跨模型迁移”问题。
为什么是“深度挑战”?
核心难点在于LLM本身是无状态的API调用。每次调用都是一次独立的请求/响应,虽然可以通过传递messages数组来模拟对话状态,但这仅仅是上下文的重建,而非底层计算状态的共享。不同LLM的API接口、消息格式、系统提示词处理方式、甚至对指令的语义理解都存在细微差异。要实现无缝迁移,我们需要:
- 状态的全面捕获: 不仅是对话历史,还包括任务执行进度、中间结果、工具调用记录等。
- 上下文的通用表示与重建: 能够将一个模型理解的上下文,转化为另一个模型能够理解并继续的上下文。
- 模型间的语义对齐: 确保迁移后,新模型能够准确地“接管”旧模型的思维,保持逻辑连贯性。
解决这些挑战,将为我们构建更具弹性、成本效益和智能的AI系统打开大门。
2. 核心概念与技术基石
为了实现跨模型迁移,我们需要建立一套坚实的技术基石。
2.1 图执行范式 (Graph-based Execution)
将复杂的业务逻辑或AI工作流建模为一个有向无环图(DAG),是解决这一问题的核心思路。
- 节点 (Node): 代表一个独立的任务或操作,例如一次LLM推理、一次工具调用、一个数据处理步骤等。
- 边 (Edge): 代表节点之间的依赖关系和数据流。一个节点的输出可以作为另一个节点的输入。
为什么要用图?
- 模块化: 将大问题分解为小任务,每个任务封装在节点中。
- 可视化与管理: 任务流程清晰可见,易于追踪和调试。
- 状态追踪: 每个节点的执行状态(待执行、运行中、已完成、失败)可以被独立管理和持久化。
- 并行化: 无依赖关系的节点可以并行执行。
- 可恢复性: 当系统崩溃或需要迁移时,可以从图的任意一个已知状态恢复执行。
2.2 状态管理 (State Management)
迁移的本质是状态的转移。我们需要捕获并持久化以下几种状态:
- 执行状态 (Execution State): 图中每个节点的当前状态(如
PENDING,RUNNING,COMPLETED,FAILED)。 - 数据状态 (Data State): 节点执行产生的中间结果。这些结果需要是可序列化的,并且能够被后续节点引用。
- 上下文状态 (Context State): 这是LLM特有的。它包括了到目前为止的完整对话历史(
messages数组),以及可能影响后续推理的任何关键变量或指令。
2.3 中间表示 (Intermediate Representation – IR)
为了实现跨模型的兼容性,我们需要一个与具体LLM无关的、标准化的中间表示。JSON或YAML是理想的序列化格式。我们的图结构、节点配置、节点结果都将以这种通用格式进行存储和交换。
2.4 代理模式 (Agentic Workflow)
一个外部的“任务编排器”(Orchestrator)将扮演核心代理的角色。它不直接执行LLM的推理,而是负责:
- 解析图定义。
- 调度节点执行。
- 管理状态的持久化与加载。
- 决定何时进行模型迁移。
- 将上下文从一个模型转换为另一个模型。
3. 架构设计:实现跨模型迁移的蓝图
为了实现上述目标,我们设计了一个分层架构,确保职责分离和高内聚低耦合。
系统组件概述:
| 组件名称 | 职责 | 关键能力 |
|---|---|---|
| 任务编排器 (Orchestrator) | 核心控制器,管理图的整个生命周期。负责调度、状态更新、错误处理和迁移决策。 | 解析图定义、任务调度、状态管理、依赖解决、迁移触发。 |
| 模型适配器层 (LLM Adapter Layer) | 统一不同LLM提供商的API接口和消息格式。 | 将内部通用请求转换为特定LLM的API请求,将响应转换为通用格式。 |
| 共享状态存储 (Shared State Store) | 持久化图的执行状态、节点结果和全局上下文。 | 高效的读写、数据持久化、支持并发访问。 |
| 上下文管理器 (Context Manager) | 负责构建、维护和重构LLM的对话上下文。 | 从存储中加载历史上下文、根据当前任务和目标模型调整上下文格式。 |
| 任务执行器 (Task Executor) | 负责执行特定类型的任务节点(如LLM推理、工具调用、数据处理)。 | 接收节点配置和上下文,执行具体操作,返回结果。 |
| 迁移策略引擎 (Migration Policy Engine) | (可选但推荐) 根据预设规则或实时监控,决定何时何地进行模型迁移。 | 成本阈值、性能指标、错误率、模型能力匹配。 |
数据流与控制流:
- 初始化: 用户提交一个任务图的定义给Orchestrator。Orchestrator将图结构和初始状态保存到Shared State Store。
- 执行循环: Orchestrator不断从Shared State Store加载图状态,识别可运行的节点。
- 任务执行:
- 对于每个可运行节点,Orchestrator通过Context Manager构建该节点所需的执行上下文(包括历史对话、前置节点结果)。
- Orchestrator选择对应的Task Executor(例如
LLMTaskExecutor或ToolUseTaskExecutor)。 - 如果任务是LLM推理,
LLMTaskExecutor会使用当前节点指定的LLM Adapter(例如GPT4oAdapter)。 - LLM Adapter将通用请求转换为特定LLM的API调用,并处理响应。
- 状态更新: 节点执行完成后,其状态和结果会更新到Shared State Store。Context Manager也会更新全局上下文。
- 迁移触发:
- 在任务执行前,或者在某个节点执行完成后,Orchestrator可以调用Migration Policy Engine来评估是否需要迁移。
- 如果决定迁移,Orchestrator会更新图中的相关节点,将其
assigned_model字段设置为目标模型,并重新保存图状态。
- 无缝接力: 当Orchestrator再次调度到这些被迁移的节点时,它们将使用新的
assigned_model,通过对应的LLM Adapter,利用Shared State Store中保存的最新上下文和数据,无缝地继续执行。
4. 图的定义与执行
首先,我们定义图中的核心元素:节点(Node)和图(Graph)。
# task_graph.py
import uuid
from typing import Dict, Any, List, Optional
class Node:
"""
图中的一个节点,代表一个独立的任务。
"""
def __init__(self, node_id: str, task_type: str, config: Dict[str, Any], depends_on: Optional[List[str]] = None):
self.node_id = node_id # 节点的唯一标识符
self.task_type = task_type # 任务类型,如 'llm_inference', 'tool_use', 'data_processing'
self.config = config # 任务的具体配置,如prompt_template, temperature, tool_name等
self.depends_on = depends_on if depends_on is not None else [] # 依赖的前置节点ID列表
self.status: str = "PENDING" # 节点状态: PENDING, RUNNING, COMPLETED, FAILED, SKIPPED
self.result: Optional[Any] = None # 节点执行结果
self.error: Optional[str] = None # 节点失败时的错误信息
self.assigned_model: Optional[str] = None # 分配给该节点的模型ID,如"gpt4o", "claude3_5"
def to_dict(self) -> Dict[str, Any]:
"""将节点对象序列化为字典,便于存储和传输。"""
return {
"node_id": self.node_id,
"task_type": self.task_type,
"config": self.config,
"depends_on": self.depends_on,
"status": self.status,
"result": self.result, # 注意:result需要是可JSON序列化的
"error": self.error,
"assigned_model": self.assigned_model,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'Node':
"""从字典反序列化为节点对象。"""
node = cls(
node_id=data["node_id"],
task_type=data["task_type"],
config=data["config"],
depends_on=data.get("depends_on", [])
)
node.status = data.get("status", "PENDING")
node.result = data.get("result")
node.error = data.get("error")
node.assigned_model = data.get("assigned_model")
return node
class Graph:
"""
表示一个有向无环图 (DAG) 任务流程。
"""
def __init__(self, graph_id: str, nodes: List[Node]):
self.graph_id = graph_id
self.nodes: Dict[str, Node] = {node.node_id: node for node in nodes}
self.adjacency_list: Dict[str, List[str]] = self._build_adjacency_list()
def _build_adjacency_list(self) -> Dict[str, List[str]]:
"""构建邻接表,表示节点间的依赖关系(谁依赖谁)。"""
adj = {node_id: [] for node_id in self.nodes}
for node in self.nodes.values():
for dep_id in node.depends_on:
if dep_id in self.nodes:
adj[dep_id].append(node.node_id)
return adj
def get_runnable_nodes(self) -> List[Node]:
"""
获取当前可运行的节点列表。
一个节点可运行,当且仅当其状态为PENDING,且所有依赖的前置节点都已COMPLETED。
"""
runnable = []
for node in self.nodes.values():
if node.status == "PENDING":
all_deps_completed = True
for dep_id in node.depends_on:
if dep_id not in self.nodes or self.nodes[dep_id].status != "COMPLETED":
all_deps_completed = False
break
if all_deps_completed:
runnable.append(node)
return runnable
def get_graph_state(self) -> Dict[str, Any]:
"""获取当前图的完整状态,用于序列化存储。"""
return {
"graph_id": self.graph_id,
"nodes": [node.to_dict() for node in self.nodes.values()]
}
def load_graph_state(self, state: Dict[str, Any]):
"""从存储的状态字典加载图状态,更新节点信息。"""
if self.graph_id != state["graph_id"]:
raise ValueError(f"Graph ID mismatch: expected {self.graph_id}, got {state['graph_id']}")
for node_state_dict in state["nodes"]:
node_id = node_state_dict["node_id"]
if node_id in self.nodes:
node = self.nodes[node_id]
node.status = node_state_dict["status"]
node.result = node_state_dict["result"]
node.error = node_state_dict["error"]
node.assigned_model = node_state_dict["assigned_model"]
else:
# 理论上,图结构在执行中不应改变。如果发生,可能需要更复杂的处理。
# 此处为简化,假设结构固定。
pass
接下来是不同类型的任务执行器,它们负责根据节点配置执行具体的任务。
# executors.py
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional
import json
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# 导入Node类型,用于类型提示
from task_graph import Node
# 前向引用LLMAdapter,因为这里可能需要注入
if TYPE_CHECKING:
from llm_adapters import LLMAdapter
class TaskExecutor(ABC):
"""所有任务执行器的抽象基类。"""
@abstractmethod
async def execute(self, node: Node, context: Dict[str, Any]) -> Any:
"""
执行一个任务节点。
:param node: 要执行的Node对象。
:param context: 包含执行该节点所需的所有上下文信息(如前置节点结果、聊天历史等)。
:return: 任务执行结果。
"""
pass
class LLMTaskExecutor(TaskExecutor):
"""
负责执行LLM推理任务的执行器。
它会根据节点配置,通过LLMAdapter与模型交互。
"""
def __init__(self, llm_adapter: 'LLMAdapter'):
self.llm_adapter = llm_adapter # 注入LLM适配器实例
async def execute(self, node: Node, context: Dict[str, Any]) -> Any:
prompt_template = node.config.get("prompt_template")
if not prompt_template:
raise ValueError(f"Node {node.node_id}: 'prompt_template' not found in config.")
# 动态填充prompt_template中的占位符,例如 {node_result_prev_node_id}
# 这里需要一个更健壮的模板引擎,此处简化处理
try:
filled_prompt = prompt_template.format(**context)
except KeyError as e:
logging.error(f"Node {node.node_id}: Missing context variable for prompt templating: {e}")
raise ValueError(f"Missing context variable: {e}")
# 获取或构建当前LLM的对话历史
messages = context.get("messages", [])
messages.append({"role": "user", "content": filled_prompt})
model_name = node.assigned_model or self.llm_adapter.default_model # 使用节点指定的模型或适配器默认模型
try:
logging.info(f"Node {node.node_id} sending request to model {model_name} with adapter {type(self.llm_adapter).__name__}")
response = await self.llm_adapter.chat_completion(
model=model_name,
messages=messages,
temperature=node.config.get("temperature", 0.7),
max_tokens=node.config.get("max_tokens", 2048)
)
node.result = response
node.status = "COMPLETED"
# 将最新的对话历史也更新到context中,以便ContextManager持久化
context["messages"].append({"role": "assistant", "content": response})
return response
except Exception as e:
node.status = "FAILED"
node.error = str(e)
logging.error(f"LLMTaskExecutor failed for node {node.node_id}: {e}")
raise
class ToolUseTaskExecutor(TaskExecutor):
"""
负责执行工具调用任务的执行器。
它会根据节点配置,从工具注册中心调用相应的工具函数。
"""
def __init__(self, tools_registry: Dict[str, Any]):
self.tools_registry = tools_registry # 工具函数注册中心
async def execute(self, node: Node, context: Dict[str, Any]) -> Any:
tool_name = node.config.get("tool_name")
tool_args = node.config.get("tool_args", {})
if not tool_name:
raise ValueError(f"Node {node.node_id}: 'tool_name' not found in config.")
# 解析工具参数,支持从context中获取值
resolved_args = {}
for k, v in tool_args.items():
if isinstance(v, str) and v.startswith("{") and v.endswith("}"):
# 简单解析,例如 {node_result_prev_node}
key_in_context = v[1:-1]
if key_in_context in context:
resolved_args[k] = context[key_in_context]
else:
logging.warning(f"Node {node.node_id}: Context variable '{key_in_context}' not found for tool arg '{k}'. Using raw value.")
resolved_args[k] = v
else:
resolved_args[k] = v
if tool_name not in self.tools_registry:
node.status = "FAILED"
node.error = f"Tool '{tool_name}' not found in registry."
raise ValueError(node.error)
try:
tool_func = self.tools_registry[tool_name]
logging.info(f"Node {node.node_id} executing tool '{tool_name}' with args: {resolved_args}")
result = await tool_func(**resolved_args)
node.result = result
node.status = "COMPLETED"
return result
except Exception as e:
node.status = "FAILED"
node.error = str(e)
logging.error(f"ToolUseTaskExecutor failed for node {node.node_id}: {e}")
raise
5. 状态的序列化、存储与重构
实现跨模型迁移的关键在于如何有效地捕获、持久化和重构执行状态与上下文。我们将使用Redis作为共享状态存储,并实现一个专门的上下文管理器。
# state_store.py
import redis
import json
from typing import Dict, Any, Optional
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class StateStore:
"""
负责持久化和加载图的执行状态和上下文数据。
使用Redis作为后端存储,因为其读写速度快,适合存储JSON数据。
"""
def __init__(self, host='localhost', port=6379, db=0):
self.redis = redis.Redis(host=host, port=port, db=db, decode_responses=True)
logging.info(f"Initialized Redis StateStore at {host}:{port}/{db}")
def save_graph_state(self, graph_id: str, state: Dict[str, Any]):
"""保存图的当前执行状态。"""
try:
self.redis.set(f"graph:{graph_id}:state", json.dumps(state))
logging.debug(f"Graph {graph_id} state saved.")
except Exception as e:
logging.error(f"Failed to save graph {graph_id} state: {e}")
raise
def load_graph_state(self, graph_id: str) -> Optional[Dict[str, Any]]:
"""加载图的执行状态。"""
try:
state_json = self.redis.get(f"graph:{graph_id}:state")
if state_json:
logging.debug(f"Graph {graph_id} state loaded.")
return json.loads(state_json)
return None
except Exception as e:
logging.error(f"Failed to load graph {graph_id} state: {e}")
raise
def save_context(self, graph_id: str, context: Dict[str, Any]):
"""保存图的全局上下文,包括聊天历史、中间变量等。"""
try:
self.redis.set(f"graph:{graph_id}:context", json.dumps(context))
logging.debug(f"Graph {graph_id} context saved.")
except Exception as e:
logging.error(f"Failed to save graph {graph_id} context: {e}")
raise
def load_context(self, graph_id: str) -> Optional[Dict[str, Any]]:
"""加载图的全局上下文。"""
try:
context_json = self.redis.get(f"graph:{graph_id}:context")
if context_json:
logging.debug(f"Graph {graph_id} context loaded.")
return json.loads(context_json)
return None
except Exception as e:
logging.error(f"Failed to load graph {graph_id} context: {e}")
raise
def delete_graph_data(self, graph_id: str):
"""删除某个图的所有相关数据。"""
try:
self.redis.delete(f"graph:{graph_id}:state", f"graph:{graph_id}:context")
logging.info(f"Graph {graph_id} data deleted from store.")
except Exception as e:
logging.error(f"Failed to delete graph {graph_id} data: {e}")
raise
上下文管理器是实现无缝迁移的核心。它需要能够根据当前任务和目标LLM的特性,动态构建和调整上下文。
# context_manager.py
from typing import Dict, Any, List
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# 导入StateStore类型,用于类型提示
if TYPE_CHECKING:
from state_store import StateStore
class ContextManager:
"""
负责管理和构建LLM的对话上下文。
它从StateStore加载历史数据,并根据当前任务和目标LLM的特点进行上下文重构。
"""
def __init__(self, state_store: 'StateStore'):
self.state_store = state_store
logging.info("Initialized ContextManager.")
def build_llm_context(self, graph_id: str, current_node_id: str, model_type: str) -> Dict[str, Any]:
"""
为即将执行的LLM节点构建完整的上下文。
:param graph_id: 当前图的ID。
:param current_node_id: 当前正在执行的节点ID。
:param model_type: 目标模型的类型(如"gpt", "claude"),用于模型特定的上下文调整。
:return: 包含messages列表和其它变量的字典,供LLMTaskExecutor使用。
"""
# 1. 从状态存储加载所有历史数据
graph_state = self.state_store.load_graph_state(graph_id)
session_context = self.state_store.load_context(graph_id) or {}
# 2. 提取并合并关键信息
messages = session_context.get("messages", []) # 历史对话消息
# 将已完成节点的结果添加到上下文,以便后续节点引用
if graph_state:
for node_state in graph_state["nodes"]:
if node_state["status"] == "COMPLETED" and node_state["node_id"] != current_node_id:
# 将结果以特定格式注入到session_context中,供prompt templating使用
session_context[f"node_result_{node_state['node_id']}"] = node_state["result"]
elif node_state["node_id"] == current_node_id and node_state["result"]:
# 如果当前节点已有部分结果(例如从上次失败恢复),也可以考虑注入
session_context[f"node_result_{node_state['node_id']}"] = node_state["result"]
# 3. 进行模型特定的上下文调整
# 这是实现“无缝迁移”最关键的一步。不同模型对系统提示词、消息格式有不同偏好。
final_messages = []
system_message_content = ""
# 提取或构建系统消息
existing_system_messages = [m for m in messages if m["role"] == "system"]
if existing_system_messages:
system_message_content = existing_system_messages[0]["content"]
# 移除已有的系统消息,因为Claude可能需要通过'system'参数传递
messages = [m for m in messages if m["role"] != "system"]
# 针对不同模型类型构建消息列表
if model_type == "gpt":
# GPT模型通常接受一个显式的"system"角色消息
if not system_message_content:
system_message_content = "你是一个专业且乐于助人的AI助手,请严格按照指令和历史对话进行推理和回复。"
final_messages.append({"role": "system", "content": system_message_content})
final_messages.extend(messages) # 其他用户/助手消息直接追加
elif model_type == "claude":
# Claude模型通常通过`system`参数传递系统提示词,而不是在`messages`列表中包含"system"角色
# 并且其`messages`列表不能以助手消息开头
if not system_message_content:
system_message_content = "你是一个专业且乐于助人的AI助手,请严格按照指令和历史对话进行推理和回复。"
# Claude Messages API要求消息列表必须是用户-助手交替的,且不能以助手消息开头
# 这里需要对历史消息进行清理和校验
cleaned_messages = []
for i, msg in enumerate(messages):
if i == 0 and msg["role"] == "assistant":
# 如果第一条是助手消息,说明上下文可能不完整或格式不符,需要特殊处理
# 实际场景中,可能需要一个更智能的策略,例如忽略或尝试修复
logging.warning(f"Claude context: First message is assistant role. Potentially invalid for Claude API.")
# 我们可以选择跳过这条,或者将其内容合并到后续的用户消息中
continue
cleaned_messages.append(msg)
# 如果清理后消息列表仍以助手开头,或为空,需要插入一个占位用户消息
if not cleaned_messages or cleaned_messages[0]["role"] == "assistant":
# 这种情况下,可能需要一个默认的用户开始语
logging.warning("Claude context: No valid user message to start the conversation. Inserting a default.")
final_messages.append({"role": "user", "content": "请继续我们之前的讨论。"})
else:
final_messages.extend(cleaned_messages)
# 将系统消息单独存储,以便LLMAdapter处理
session_context["system_message"] = system_message_content
else:
# 对于未知模型类型,直接使用原始消息列表
final_messages.extend(messages)
if system_message_content:
final_messages.insert(0, {"role": "system", "content": system_message_content})
# 4. 返回构建好的上下文
# 这里的session_context包含了除messages之外的所有变量,供prompt templating使用
return {
"messages": final_messages,
**{k: v for k, v in session_context.items() if k != "messages"} # 排除messages,因为它已经处理过了
}
def update_llm_context(self, graph_id: str, new_messages: List[Dict[str, Any]], additional_vars: Dict[str, Any]):
"""
更新图的全局上下文。
:param graph_id: 当前图的ID。
:param new_messages: 最新的对话消息列表。
:param additional_vars: 需要添加到上下文中的额外变量(如新节点的结果)。
"""
session_context = self.state_store.load_context(graph_id) or {}
session_context["messages"] = new_messages # 用最新消息覆盖
session_context.update(additional_vars) # 合并其他变量
self.state_store.save_context(graph_id, session_context)
logging.debug(f"Graph {graph_id} context updated.")
6. 模型适配器层:统一接口与差异处理
不同LLM提供商的API接口和消息格式存在差异。模型适配器层的作用就是将这些差异封装起来,为上层提供统一的接口。
# llm_adapters.py
from abc import ABC, abstractmethod
from typing import Dict, Any, List
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class LLMAdapter(ABC):
"""所有LLM适配器的抽象基类,定义统一的LLM交互接口。"""
def __init__(self, api_key: str, default_model: str):
self.api_key = api_key
self.default_model = default_model
logging.info(f"Initialized LLMAdapter for {type(self).__name__} with default model {default_model}")
@abstractmethod
async def chat_completion(self, model: str, messages: List[Dict[str, Any]], **kwargs) -> str:
"""
执行聊天补全请求。
:param model: 要使用的模型名称。
:param messages: 对话历史消息列表。
:param kwargs: 其他模型特定的参数(如temperature, max_tokens)。
:return: LLM生成的文本回复。
"""
pass
class GPT4oAdapter(LLMAdapter):
"""OpenAI GPT-4o 模型的适配器。"""
def __init__(self, api_key: str):
super().__init__(api_key, "gpt-4o")
try:
from openai import AsyncOpenAI
self.client = AsyncOpenAI(api_key=api_key)
except ImportError:
logging.error("OpenAI library not found. Please install it with `pip install openai`.")
raise
async def chat_completion(self, model: str, messages: List[Dict[str, Any]], **kwargs) -> str:
try:
response = await self.client.chat.completions.create(
model=model,
messages=messages,
**kwargs # 传递温度、max_tokens等参数
)
return response.choices[0].message.content
except Exception as e:
logging.error(f"GPT-4o API call failed: {e}")
raise
class Claude3_5_Adapter(LLMAdapter):
"""Anthropic Claude 3.5 Sonnet 模型的适配器。"""
def __init__(self, api_key: str):
super().__init__(api_key, "claude-3-5-sonnet-20240620")
try:
from anthropic import Anthropic
self.client = Anthropic(api_key=api_key)
except ImportError:
logging.error("Anthropic library not found. Please install it with `pip install anthropic`.")
raise
async def chat_completion(self, model: str, messages: List[Dict[str, Any]], **kwargs) -> str:
# Anthropic Messages API 的消息格式与OpenAI略有不同
# 它通过一个独立的 `system` 参数来传递系统提示词
# 并且 `messages` 列表不能包含 `system` 角色,也不能以 `assistant` 角色开始
system_message = None
anthropic_messages = []
for msg in messages:
if msg["role"] == "system":
system_message = msg["content"]
else:
anthropic_messages.append(msg)
# 确保messages列表不为空,且不以assistant角色开头
if not anthropic_messages or anthropic_messages[0]["role"] == "assistant":
# 这是一个需要上下文管理器在构建时就处理好的问题
# 如果走到这里,说明上下文管理器没有正确处理或LLMTaskExecutor直接传了不合规的消息
# 简单处理:如果为空或以助手开头,插入一个默认用户消息
if not anthropic_messages:
anthropic_messages.append({"role": "user", "content": "请继续执行任务。"})
elif anthropic_messages[0]["role"] == "assistant":
# 如果以助手开头,插入一个用户消息作为承接
anthropic_messages.insert(0, {"role": "user", "content": "好的,我理解了,请基于此继续。"})
try:
response = await self.client.messages.create(
model=model,
max_tokens=kwargs.get("max_tokens", 1024), # Claude需要明确的max_tokens
messages=anthropic_messages,
system=system_message, # 将系统提示词通过system参数传递
temperature=kwargs.get("temperature", 0.7)
)
# Claude的响应内容在content列表里,可能包含多个text块
return "".join(block.text for block in response.content if block.type == "text")
except Exception as e:
logging.error(f"Claude 3.5 API call failed: {e}")
raise
7. 迁移策略与执行流程
迁移策略决定了何时、为何进行模型切换。而执行流程则是由Orchestrator来驱动。
何时迁移?
- 成本优化: 复杂、高推理的任务由昂贵但强大的模型(如GPT-4o)完成,后续的生成、润色任务迁移到更经济的模型(如Claude 3.5 Sonnet)。
- 性能/能力匹配: 某个模型在特定任务类型(如代码生成、创意写作、数学推理)上表现更佳。
- 负载均衡/高可用: 当一个模型的API出现延迟或故障时,自动切换到另一个可用模型。
- 用户偏好/策略: 根据用户或业务规则,显式指定某些任务由特定模型执行。
Orchestrator的核心逻辑:
# orchestrator.py
import asyncio
import uuid
import logging
from typing import Dict, Any, List, Optional, TYPE_CHECKING
# 导入所有模块,用于类型提示和实例化
from task_graph import Graph, Node
from executors import LLMTaskExecutor, ToolUseTaskExecutor, TaskExecutor
from state_store import StateStore
from context_manager import ContextManager
from llm_adapters import LLMAdapter, GPT4oAdapter, Claude3_5_Adapter
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class Orchestrator:
"""
任务编排器,负责管理图的初始化、执行、状态持久化和模型迁移。
"""
def __init__(self, state_store: StateStore, context_manager: ContextManager,
llm_adapters: Dict[str, LLMAdapter], tools_registry: Dict[str, Any]):
self.state_store = state_store
self.context_manager = context_manager
self.llm_adapters = llm_adapters # 存储所有可用的LLM适配器实例
# 初始化任务执行器,根据任务类型注册
self.task_executors: Dict[str, TaskExecutor] = {
"llm_inference": LLMTaskExecutor(llm_adapters["gpt4o"]), # 默认LLM执行器使用gpt4o适配器
"tool_use": ToolUseTaskExecutor(tools_registry),
# 可以添加更多任务类型,如 "data_processing": DataProcessingExecutor(...)
}
self.current_graph: Optional[Graph] = None # 当前活跃的图实例
logging.info("Orchestrator initialized.")
async def initialize_graph(self, graph_definition: Dict[str, Any]) -> str:
"""
初始化一个新的任务图。
:param graph_definition: 图的定义字典。
:return: 初始化后的图ID。
"""
graph_id = graph_definition.get("graph_id", str(uuid.uuid4()))
nodes = [Node.from_dict(node_data) for node_data in graph_definition["nodes"]]
self.current_graph = Graph(graph_id, nodes)
self.state_store.save_graph_state(graph_id, self.current_graph.get_graph_state())
self.context_manager.update_llm_context(graph_id, [], {}) # 初始化空上下文
logging.info(f"Graph {graph_id} initialized.")
return graph_id
async def resume_graph(self, graph_id: str) -> bool:
"""
从持久化状态恢复一个任务图的执行。
:param graph_id: 要恢复的图ID。
:return: True如果恢复成功,否则False。
"""
state = self.state_store.load_graph_state(graph_id)
if not state:
logging.error(f"Graph state for {graph_id} not found. Cannot resume.")
return False
# 假设图的结构定义在别处(例如数据库或配置文件)
# 这里为了简化,我们假设 Orchestrator 启动时能够获取到所有图的初始定义
# 如果current_graph为空,需要从初始定义重建图结构,再加载状态
if not self.current_graph or self.current_graph.graph_id != graph_id:
# 实际场景中,这里需要从某个地方加载graph_id对应的初始图结构定义
# 然后用 state 更新其内部节点状态
logging.warning(f"Graph {graph_id} not active. Attempting to load initial definition and state.")
# Dummy: For this example, let's just assume initial definition is available
# In a real system, you'd fetch the original graph_definition from a persistent store
# For now, we'll error if current_graph is not already set for that ID.
logging.error(f"Cannot resume graph {graph_id} without its initial definition loaded into Orchestrator.")
return False
self.current_graph.load_graph_state(state)
logging.info(f"Graph {graph_id} resumed from state.")
return True
async def execute_graph(self, graph_id: str, initial_input: Dict[str, Any] = None):
"""
执行整个任务图。
:param graph_id: 要执行的图ID。
:param initial_input: 初始输入数据,会添加到图的全局上下文。
"""
if not self.current_graph or self.current_graph.graph_id != graph_id:
if not await self.resume_graph(graph_id):
logging.error(f"Could not load or resume graph {graph_id}. Aborting execution.")
return
if initial_input:
self.context_manager.update_llm_context(graph_id, [], initial_input)
logging.info(f"Initial input added to context for graph {graph_id}.")
while True:
runnable_nodes = self.current_graph.get_runnable_nodes()
if not runnable_nodes:
# 检查所有节点是否都已完成
if all(node.status in ["COMPLETED", "SKIPPED", "FAILED"] for node in self.current_graph.nodes.values()):
logging.info(f"Graph {graph_id} execution completed (or all pending nodes failed).")
break
else:
logging.warning(f"Graph {graph_id}: No runnable nodes, but some nodes are still PENDING. Possible deadlock or unhandled dependencies. Exiting.")
break # 避免无限循环
tasks = []
for node in runnable_nodes:
node.status = "RUNNING" # 标记节点为运行中
logging.info(f"Scheduling node {node.node_id} ({node.task_type}) for execution.")
# 获取节点上下文,这会涉及从存储加载和根据模型类型调整
node_context = self.context_manager.build_llm_context(
graph_id, node.node_id, self._get_model_type(node.assigned_model)
)
# 如果是LLM任务,动态设置LLMTaskExecutor使用的适配器
if node.task_type == "llm_inference":
model_id = node.assigned_model or self.llm_adapters["gpt4o"].default_model # 如果未指定,默认使用gpt4o
if model_id not in self.llm_adapters:
logging.error(f"LLM adapter for model ID '{model_id}' not found. Node {node.node_id} will fail.")
node.status = "FAILED"
node.error = f"LLM adapter '{model_id}' missing."
self.state_store.save_graph_state(graph_id, self.current_graph.get_graph_state())
continue # 跳过当前节点,因为它肯定会失败
# 动态切换LLMTaskExecutor内部的LLMAdapter实例
self.task_executors["llm_inference"].llm_adapter = self.llm_adapters[model_id]
logging.info(f"Node {node.node_id} will use LLM adapter: {model_id}")
tasks.append(self._execute_node_task(node, node_context))
# 并行执行所有可运行的节点
await asyncio.gather(*tasks, return_exceptions=True) # return_exceptions=True 确保一个任务失败不会中断其他任务
# 每次批次执行后,保存图的最新状态
self.state_store.save_graph_state(graph_id, self.current_graph.get_graph_state())
logging.info(f"Graph {graph_id} state saved after batch execution.")
# 为了避免忙等,可以加一个短暂的延迟
await asyncio.sleep(0.1)
async def _execute_node_task(self, node: Node, context: Dict[str, Any]):
"""
内部方法:执行单个任务节点。
:param node: 要执行的Node对象。
:param context: 节点执行所需的上下文。
"""
try:
executor = self.task_executors.get(node.task_type)
if not executor:
raise ValueError(f"No executor found for task type: {node.task_type}")
result = await executor.execute(node, context)
logging.info(f"Node {node.node_id} ({node.task_type}) completed.")
# 更新全局上下文:将当前节点的结果和最新的消息历史(如果LLMTaskExecutor更新了)保存
# 注意:LLMTaskExecutor应该在执行时更新了context['messages']
self.context_manager.update_llm_context(
self.current_graph.graph_id,
context.get("messages", []), # 传入LLMTaskExecutor可能修改过的消息列表
{f"node_result_{node.node_id}": result} # 将节点结果作为变量保存
)
except Exception as e:
logging.error(f"Node {node.node_id} ({node.task_type}) failed: {e}")
node.status = "FAILED"
node.error = str(e)
finally:
# 无论成功失败,都确保节点状态被持久化
self.state_store.save_graph_state(self.current_graph.graph_id, self.current_graph.get_graph_state())
def _get_model_type(self, model_id: Optional[str]) -> str:
"""根据模型ID判断模型类型,以便上下文管理器做模型特定处理。"""
if model_id:
if "gpt" in model_id.lower():
return "gpt"
if "claude" in model_id.lower():
return "claude"
return "unknown" # 默认类型
async def migrate_node_to_model(self, graph_id: str, node_id: str, target_model_id: str) -> bool:
"""
将指定节点的目标LLM模型更改为另一个。
这可以在节点执行前调用,实现动态迁移。
:param graph_id: 图ID。
:param node_id: 要迁移的节点ID。
:param target_model_id: 目标模型ID(如"claude3_5")。
:return: True如果迁移成功,否则False。
"""
if not self.current_graph or self.current_graph.graph_id != graph_id:
logging.error(f"Graph {graph_id} not active. Cannot migrate node {node_id}.")
return False
node = self.current_graph.nodes.get(node_id)
if not node