深度挑战:如何实现一个‘跨模型迁移’的图——让逻辑在 GPT-4o 上运行一半后,无缝迁移到 Claude 3.5 tiếp tục执行?

各位编程专家、AI爱好者,大家好!

今天,我们将深入探讨一个前沿且极具挑战性的话题:如何实现一个“跨模型迁移”的图执行系统。想象一下,你的复杂逻辑流程在GPT-4o上运行了一半,但出于成本、性能、甚至模型特性偏好等原因,你需要它无缝地迁移到Claude 3.5上继续执行,而无需从头再来。这听起来像科幻小说,但在我们构建更灵活、更具韧性的AI系统时,它正成为一个迫切的需求。

这不仅仅是简单地切换一个API端点,更是一场关于状态、上下文、语义鸿沟的深度挑战。今天,我将作为一名编程专家,为大家揭示实现这一目标背后的原理、架构与代码实践。

1. 引言:跨模型迁移的挑战与机遇

在大型语言模型(LLM)飞速发展的今天,我们面临着前所未有的选择。从OpenAI的GPT系列到Anthropic的Claude系列,再到Google、Meta以及开源社区的众多模型,每个模型都有其独特的优势、定价策略、性能曲线和偏好。这种多样性既是福音,也带来了新的工程挑战:如何充分利用它们,而不是被特定模型绑定?

设想一个复杂的AI应用场景:

  1. 阶段一:高精度、复杂推理。 用户提出一个需要深入理解和多步骤逻辑推理的问题。你可能倾向于使用像GPT-4o这样在复杂推理方面表现卓越的模型。
  2. 阶段二:创意生成、长文本创作。 基于第一阶段的分析结果,你需要生成一份详细的报告、创意文案或代码。此时,你可能更倾向于使用像Claude 3.5 Sonnet这样在长上下文处理和连贯性方面表现出色的模型,同时可能也为了成本效益考虑。

问题来了:如果这两个阶段是紧密耦合的,后一阶段需要前一阶段的所有中间结果和上下文,我们如何才能实现这种模型间的“接力跑”,而不是每次都从零开始?这就是我们今天探讨的“跨模型迁移”问题。

为什么是“深度挑战”?

核心难点在于LLM本身是无状态的API调用。每次调用都是一次独立的请求/响应,虽然可以通过传递messages数组来模拟对话状态,但这仅仅是上下文的重建,而非底层计算状态的共享。不同LLM的API接口、消息格式、系统提示词处理方式、甚至对指令的语义理解都存在细微差异。要实现无缝迁移,我们需要:

  • 状态的全面捕获: 不仅是对话历史,还包括任务执行进度、中间结果、工具调用记录等。
  • 上下文的通用表示与重建: 能够将一个模型理解的上下文,转化为另一个模型能够理解并继续的上下文。
  • 模型间的语义对齐: 确保迁移后,新模型能够准确地“接管”旧模型的思维,保持逻辑连贯性。

解决这些挑战,将为我们构建更具弹性、成本效益和智能的AI系统打开大门。

2. 核心概念与技术基石

为了实现跨模型迁移,我们需要建立一套坚实的技术基石。

2.1 图执行范式 (Graph-based Execution)

将复杂的业务逻辑或AI工作流建模为一个有向无环图(DAG),是解决这一问题的核心思路。

  • 节点 (Node): 代表一个独立的任务或操作,例如一次LLM推理、一次工具调用、一个数据处理步骤等。
  • 边 (Edge): 代表节点之间的依赖关系和数据流。一个节点的输出可以作为另一个节点的输入。

为什么要用图?

  • 模块化: 将大问题分解为小任务,每个任务封装在节点中。
  • 可视化与管理: 任务流程清晰可见,易于追踪和调试。
  • 状态追踪: 每个节点的执行状态(待执行、运行中、已完成、失败)可以被独立管理和持久化。
  • 并行化: 无依赖关系的节点可以并行执行。
  • 可恢复性: 当系统崩溃或需要迁移时,可以从图的任意一个已知状态恢复执行。

2.2 状态管理 (State Management)

迁移的本质是状态的转移。我们需要捕获并持久化以下几种状态:

  • 执行状态 (Execution State): 图中每个节点的当前状态(如PENDING, RUNNING, COMPLETED, FAILED)。
  • 数据状态 (Data State): 节点执行产生的中间结果。这些结果需要是可序列化的,并且能够被后续节点引用。
  • 上下文状态 (Context State): 这是LLM特有的。它包括了到目前为止的完整对话历史(messages数组),以及可能影响后续推理的任何关键变量或指令。

2.3 中间表示 (Intermediate Representation – IR)

为了实现跨模型的兼容性,我们需要一个与具体LLM无关的、标准化的中间表示。JSON或YAML是理想的序列化格式。我们的图结构、节点配置、节点结果都将以这种通用格式进行存储和交换。

2.4 代理模式 (Agentic Workflow)

一个外部的“任务编排器”(Orchestrator)将扮演核心代理的角色。它不直接执行LLM的推理,而是负责:

  • 解析图定义。
  • 调度节点执行。
  • 管理状态的持久化与加载。
  • 决定何时进行模型迁移。
  • 将上下文从一个模型转换为另一个模型。

3. 架构设计:实现跨模型迁移的蓝图

为了实现上述目标,我们设计了一个分层架构,确保职责分离和高内聚低耦合。

系统组件概述:

组件名称 职责 关键能力
任务编排器 (Orchestrator) 核心控制器,管理图的整个生命周期。负责调度、状态更新、错误处理和迁移决策。 解析图定义、任务调度、状态管理、依赖解决、迁移触发。
模型适配器层 (LLM Adapter Layer) 统一不同LLM提供商的API接口和消息格式。 将内部通用请求转换为特定LLM的API请求,将响应转换为通用格式。
共享状态存储 (Shared State Store) 持久化图的执行状态、节点结果和全局上下文。 高效的读写、数据持久化、支持并发访问。
上下文管理器 (Context Manager) 负责构建、维护和重构LLM的对话上下文。 从存储中加载历史上下文、根据当前任务和目标模型调整上下文格式。
任务执行器 (Task Executor) 负责执行特定类型的任务节点(如LLM推理、工具调用、数据处理)。 接收节点配置和上下文,执行具体操作,返回结果。
迁移策略引擎 (Migration Policy Engine) (可选但推荐) 根据预设规则或实时监控,决定何时何地进行模型迁移。 成本阈值、性能指标、错误率、模型能力匹配。

数据流与控制流:

  1. 初始化: 用户提交一个任务图的定义给Orchestrator。Orchestrator将图结构和初始状态保存到Shared State Store。
  2. 执行循环: Orchestrator不断从Shared State Store加载图状态,识别可运行的节点。
  3. 任务执行:
    • 对于每个可运行节点,Orchestrator通过Context Manager构建该节点所需的执行上下文(包括历史对话、前置节点结果)。
    • Orchestrator选择对应的Task Executor(例如LLMTaskExecutorToolUseTaskExecutor)。
    • 如果任务是LLM推理,LLMTaskExecutor会使用当前节点指定的LLM Adapter(例如GPT4oAdapter)。
    • LLM Adapter将通用请求转换为特定LLM的API调用,并处理响应。
  4. 状态更新: 节点执行完成后,其状态和结果会更新到Shared State Store。Context Manager也会更新全局上下文。
  5. 迁移触发:
    • 在任务执行前,或者在某个节点执行完成后,Orchestrator可以调用Migration Policy Engine来评估是否需要迁移。
    • 如果决定迁移,Orchestrator会更新图中的相关节点,将其assigned_model字段设置为目标模型,并重新保存图状态。
  6. 无缝接力: 当Orchestrator再次调度到这些被迁移的节点时,它们将使用新的assigned_model,通过对应的LLM Adapter,利用Shared State Store中保存的最新上下文和数据,无缝地继续执行。

4. 图的定义与执行

首先,我们定义图中的核心元素:节点(Node)和图(Graph)。

# task_graph.py
import uuid
from typing import Dict, Any, List, Optional

class Node:
    """
    图中的一个节点,代表一个独立的任务。
    """
    def __init__(self, node_id: str, task_type: str, config: Dict[str, Any], depends_on: Optional[List[str]] = None):
        self.node_id = node_id  # 节点的唯一标识符
        self.task_type = task_type  # 任务类型,如 'llm_inference', 'tool_use', 'data_processing'
        self.config = config  # 任务的具体配置,如prompt_template, temperature, tool_name等
        self.depends_on = depends_on if depends_on is not None else []  # 依赖的前置节点ID列表

        self.status: str = "PENDING"  # 节点状态: PENDING, RUNNING, COMPLETED, FAILED, SKIPPED
        self.result: Optional[Any] = None  # 节点执行结果
        self.error: Optional[str] = None  # 节点失败时的错误信息
        self.assigned_model: Optional[str] = None  # 分配给该节点的模型ID,如"gpt4o", "claude3_5"

    def to_dict(self) -> Dict[str, Any]:
        """将节点对象序列化为字典,便于存储和传输。"""
        return {
            "node_id": self.node_id,
            "task_type": self.task_type,
            "config": self.config,
            "depends_on": self.depends_on,
            "status": self.status,
            "result": self.result,  # 注意:result需要是可JSON序列化的
            "error": self.error,
            "assigned_model": self.assigned_model,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Node':
        """从字典反序列化为节点对象。"""
        node = cls(
            node_id=data["node_id"],
            task_type=data["task_type"],
            config=data["config"],
            depends_on=data.get("depends_on", [])
        )
        node.status = data.get("status", "PENDING")
        node.result = data.get("result")
        node.error = data.get("error")
        node.assigned_model = data.get("assigned_model")
        return node

class Graph:
    """
    表示一个有向无环图 (DAG) 任务流程。
    """
    def __init__(self, graph_id: str, nodes: List[Node]):
        self.graph_id = graph_id
        self.nodes: Dict[str, Node] = {node.node_id: node for node in nodes}
        self.adjacency_list: Dict[str, List[str]] = self._build_adjacency_list()

    def _build_adjacency_list(self) -> Dict[str, List[str]]:
        """构建邻接表,表示节点间的依赖关系(谁依赖谁)。"""
        adj = {node_id: [] for node_id in self.nodes}
        for node in self.nodes.values():
            for dep_id in node.depends_on:
                if dep_id in self.nodes:
                    adj[dep_id].append(node.node_id)
        return adj

    def get_runnable_nodes(self) -> List[Node]:
        """
        获取当前可运行的节点列表。
        一个节点可运行,当且仅当其状态为PENDING,且所有依赖的前置节点都已COMPLETED。
        """
        runnable = []
        for node in self.nodes.values():
            if node.status == "PENDING":
                all_deps_completed = True
                for dep_id in node.depends_on:
                    if dep_id not in self.nodes or self.nodes[dep_id].status != "COMPLETED":
                        all_deps_completed = False
                        break
                if all_deps_completed:
                    runnable.append(node)
        return runnable

    def get_graph_state(self) -> Dict[str, Any]:
        """获取当前图的完整状态,用于序列化存储。"""
        return {
            "graph_id": self.graph_id,
            "nodes": [node.to_dict() for node in self.nodes.values()]
        }

    def load_graph_state(self, state: Dict[str, Any]):
        """从存储的状态字典加载图状态,更新节点信息。"""
        if self.graph_id != state["graph_id"]:
            raise ValueError(f"Graph ID mismatch: expected {self.graph_id}, got {state['graph_id']}")

        for node_state_dict in state["nodes"]:
            node_id = node_state_dict["node_id"]
            if node_id in self.nodes:
                node = self.nodes[node_id]
                node.status = node_state_dict["status"]
                node.result = node_state_dict["result"]
                node.error = node_state_dict["error"]
                node.assigned_model = node_state_dict["assigned_model"]
            else:
                # 理论上,图结构在执行中不应改变。如果发生,可能需要更复杂的处理。
                # 此处为简化,假设结构固定。
                pass

接下来是不同类型的任务执行器,它们负责根据节点配置执行具体的任务。

# executors.py
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional
import json
import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# 导入Node类型,用于类型提示
from task_graph import Node 
# 前向引用LLMAdapter,因为这里可能需要注入
if TYPE_CHECKING:
    from llm_adapters import LLMAdapter 

class TaskExecutor(ABC):
    """所有任务执行器的抽象基类。"""
    @abstractmethod
    async def execute(self, node: Node, context: Dict[str, Any]) -> Any:
        """
        执行一个任务节点。
        :param node: 要执行的Node对象。
        :param context: 包含执行该节点所需的所有上下文信息(如前置节点结果、聊天历史等)。
        :return: 任务执行结果。
        """
        pass

class LLMTaskExecutor(TaskExecutor):
    """
    负责执行LLM推理任务的执行器。
    它会根据节点配置,通过LLMAdapter与模型交互。
    """
    def __init__(self, llm_adapter: 'LLMAdapter'):
        self.llm_adapter = llm_adapter # 注入LLM适配器实例

    async def execute(self, node: Node, context: Dict[str, Any]) -> Any:
        prompt_template = node.config.get("prompt_template")
        if not prompt_template:
            raise ValueError(f"Node {node.node_id}: 'prompt_template' not found in config.")

        # 动态填充prompt_template中的占位符,例如 {node_result_prev_node_id}
        # 这里需要一个更健壮的模板引擎,此处简化处理
        try:
            filled_prompt = prompt_template.format(**context) 
        except KeyError as e:
            logging.error(f"Node {node.node_id}: Missing context variable for prompt templating: {e}")
            raise ValueError(f"Missing context variable: {e}")

        # 获取或构建当前LLM的对话历史
        messages = context.get("messages", []) 
        messages.append({"role": "user", "content": filled_prompt})

        model_name = node.assigned_model or self.llm_adapter.default_model # 使用节点指定的模型或适配器默认模型

        try:
            logging.info(f"Node {node.node_id} sending request to model {model_name} with adapter {type(self.llm_adapter).__name__}")
            response = await self.llm_adapter.chat_completion(
                model=model_name,
                messages=messages,
                temperature=node.config.get("temperature", 0.7),
                max_tokens=node.config.get("max_tokens", 2048)
            )

            node.result = response 
            node.status = "COMPLETED"
            # 将最新的对话历史也更新到context中,以便ContextManager持久化
            context["messages"].append({"role": "assistant", "content": response})
            return response
        except Exception as e:
            node.status = "FAILED"
            node.error = str(e)
            logging.error(f"LLMTaskExecutor failed for node {node.node_id}: {e}")
            raise

class ToolUseTaskExecutor(TaskExecutor):
    """
    负责执行工具调用任务的执行器。
    它会根据节点配置,从工具注册中心调用相应的工具函数。
    """
    def __init__(self, tools_registry: Dict[str, Any]):
        self.tools_registry = tools_registry # 工具函数注册中心

    async def execute(self, node: Node, context: Dict[str, Any]) -> Any:
        tool_name = node.config.get("tool_name")
        tool_args = node.config.get("tool_args", {})

        if not tool_name:
            raise ValueError(f"Node {node.node_id}: 'tool_name' not found in config.")

        # 解析工具参数,支持从context中获取值
        resolved_args = {}
        for k, v in tool_args.items():
            if isinstance(v, str) and v.startswith("{") and v.endswith("}"):
                # 简单解析,例如 {node_result_prev_node}
                key_in_context = v[1:-1]
                if key_in_context in context:
                    resolved_args[k] = context[key_in_context]
                else:
                    logging.warning(f"Node {node.node_id}: Context variable '{key_in_context}' not found for tool arg '{k}'. Using raw value.")
                    resolved_args[k] = v
            else:
                resolved_args[k] = v

        if tool_name not in self.tools_registry:
            node.status = "FAILED"
            node.error = f"Tool '{tool_name}' not found in registry."
            raise ValueError(node.error)

        try:
            tool_func = self.tools_registry[tool_name]
            logging.info(f"Node {node.node_id} executing tool '{tool_name}' with args: {resolved_args}")
            result = await tool_func(**resolved_args)
            node.result = result
            node.status = "COMPLETED"
            return result
        except Exception as e:
            node.status = "FAILED"
            node.error = str(e)
            logging.error(f"ToolUseTaskExecutor failed for node {node.node_id}: {e}")
            raise

5. 状态的序列化、存储与重构

实现跨模型迁移的关键在于如何有效地捕获、持久化和重构执行状态与上下文。我们将使用Redis作为共享状态存储,并实现一个专门的上下文管理器。

# state_store.py
import redis
import json
from typing import Dict, Any, Optional
import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class StateStore:
    """
    负责持久化和加载图的执行状态和上下文数据。
    使用Redis作为后端存储,因为其读写速度快,适合存储JSON数据。
    """
    def __init__(self, host='localhost', port=6379, db=0):
        self.redis = redis.Redis(host=host, port=port, db=db, decode_responses=True)
        logging.info(f"Initialized Redis StateStore at {host}:{port}/{db}")

    def save_graph_state(self, graph_id: str, state: Dict[str, Any]):
        """保存图的当前执行状态。"""
        try:
            self.redis.set(f"graph:{graph_id}:state", json.dumps(state))
            logging.debug(f"Graph {graph_id} state saved.")
        except Exception as e:
            logging.error(f"Failed to save graph {graph_id} state: {e}")
            raise

    def load_graph_state(self, graph_id: str) -> Optional[Dict[str, Any]]:
        """加载图的执行状态。"""
        try:
            state_json = self.redis.get(f"graph:{graph_id}:state")
            if state_json:
                logging.debug(f"Graph {graph_id} state loaded.")
                return json.loads(state_json)
            return None
        except Exception as e:
            logging.error(f"Failed to load graph {graph_id} state: {e}")
            raise

    def save_context(self, graph_id: str, context: Dict[str, Any]):
        """保存图的全局上下文,包括聊天历史、中间变量等。"""
        try:
            self.redis.set(f"graph:{graph_id}:context", json.dumps(context))
            logging.debug(f"Graph {graph_id} context saved.")
        except Exception as e:
            logging.error(f"Failed to save graph {graph_id} context: {e}")
            raise

    def load_context(self, graph_id: str) -> Optional[Dict[str, Any]]:
        """加载图的全局上下文。"""
        try:
            context_json = self.redis.get(f"graph:{graph_id}:context")
            if context_json:
                logging.debug(f"Graph {graph_id} context loaded.")
                return json.loads(context_json)
            return None
        except Exception as e:
            logging.error(f"Failed to load graph {graph_id} context: {e}")
            raise

    def delete_graph_data(self, graph_id: str):
        """删除某个图的所有相关数据。"""
        try:
            self.redis.delete(f"graph:{graph_id}:state", f"graph:{graph_id}:context")
            logging.info(f"Graph {graph_id} data deleted from store.")
        except Exception as e:
            logging.error(f"Failed to delete graph {graph_id} data: {e}")
            raise

上下文管理器是实现无缝迁移的核心。它需要能够根据当前任务和目标LLM的特性,动态构建和调整上下文。

# context_manager.py
from typing import Dict, Any, List
import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# 导入StateStore类型,用于类型提示
if TYPE_CHECKING:
    from state_store import StateStore

class ContextManager:
    """
    负责管理和构建LLM的对话上下文。
    它从StateStore加载历史数据,并根据当前任务和目标LLM的特点进行上下文重构。
    """
    def __init__(self, state_store: 'StateStore'):
        self.state_store = state_store
        logging.info("Initialized ContextManager.")

    def build_llm_context(self, graph_id: str, current_node_id: str, model_type: str) -> Dict[str, Any]:
        """
        为即将执行的LLM节点构建完整的上下文。
        :param graph_id: 当前图的ID。
        :param current_node_id: 当前正在执行的节点ID。
        :param model_type: 目标模型的类型(如"gpt", "claude"),用于模型特定的上下文调整。
        :return: 包含messages列表和其它变量的字典,供LLMTaskExecutor使用。
        """
        # 1. 从状态存储加载所有历史数据
        graph_state = self.state_store.load_graph_state(graph_id)
        session_context = self.state_store.load_context(graph_id) or {}

        # 2. 提取并合并关键信息
        messages = session_context.get("messages", []) # 历史对话消息

        # 将已完成节点的结果添加到上下文,以便后续节点引用
        if graph_state:
            for node_state in graph_state["nodes"]:
                if node_state["status"] == "COMPLETED" and node_state["node_id"] != current_node_id:
                    # 将结果以特定格式注入到session_context中,供prompt templating使用
                    session_context[f"node_result_{node_state['node_id']}"] = node_state["result"]
                elif node_state["node_id"] == current_node_id and node_state["result"]:
                    # 如果当前节点已有部分结果(例如从上次失败恢复),也可以考虑注入
                    session_context[f"node_result_{node_state['node_id']}"] = node_state["result"]

        # 3. 进行模型特定的上下文调整
        # 这是实现“无缝迁移”最关键的一步。不同模型对系统提示词、消息格式有不同偏好。
        final_messages = []
        system_message_content = ""

        # 提取或构建系统消息
        existing_system_messages = [m for m in messages if m["role"] == "system"]
        if existing_system_messages:
            system_message_content = existing_system_messages[0]["content"]
            # 移除已有的系统消息,因为Claude可能需要通过'system'参数传递
            messages = [m for m in messages if m["role"] != "system"]

        # 针对不同模型类型构建消息列表
        if model_type == "gpt":
            # GPT模型通常接受一个显式的"system"角色消息
            if not system_message_content:
                system_message_content = "你是一个专业且乐于助人的AI助手,请严格按照指令和历史对话进行推理和回复。"
            final_messages.append({"role": "system", "content": system_message_content})
            final_messages.extend(messages) # 其他用户/助手消息直接追加
        elif model_type == "claude":
            # Claude模型通常通过`system`参数传递系统提示词,而不是在`messages`列表中包含"system"角色
            # 并且其`messages`列表不能以助手消息开头
            if not system_message_content:
                system_message_content = "你是一个专业且乐于助人的AI助手,请严格按照指令和历史对话进行推理和回复。"

            # Claude Messages API要求消息列表必须是用户-助手交替的,且不能以助手消息开头
            # 这里需要对历史消息进行清理和校验
            cleaned_messages = []
            for i, msg in enumerate(messages):
                if i == 0 and msg["role"] == "assistant":
                    # 如果第一条是助手消息,说明上下文可能不完整或格式不符,需要特殊处理
                    # 实际场景中,可能需要一个更智能的策略,例如忽略或尝试修复
                    logging.warning(f"Claude context: First message is assistant role. Potentially invalid for Claude API.")
                    # 我们可以选择跳过这条,或者将其内容合并到后续的用户消息中
                    continue 
                cleaned_messages.append(msg)

            # 如果清理后消息列表仍以助手开头,或为空,需要插入一个占位用户消息
            if not cleaned_messages or cleaned_messages[0]["role"] == "assistant":
                # 这种情况下,可能需要一个默认的用户开始语
                logging.warning("Claude context: No valid user message to start the conversation. Inserting a default.")
                final_messages.append({"role": "user", "content": "请继续我们之前的讨论。"})
            else:
                final_messages.extend(cleaned_messages)

            # 将系统消息单独存储,以便LLMAdapter处理
            session_context["system_message"] = system_message_content
        else:
            # 对于未知模型类型,直接使用原始消息列表
            final_messages.extend(messages)
            if system_message_content:
                final_messages.insert(0, {"role": "system", "content": system_message_content})

        # 4. 返回构建好的上下文
        # 这里的session_context包含了除messages之外的所有变量,供prompt templating使用
        return {
            "messages": final_messages,
            **{k: v for k, v in session_context.items() if k != "messages"} # 排除messages,因为它已经处理过了
        }

    def update_llm_context(self, graph_id: str, new_messages: List[Dict[str, Any]], additional_vars: Dict[str, Any]):
        """
        更新图的全局上下文。
        :param graph_id: 当前图的ID。
        :param new_messages: 最新的对话消息列表。
        :param additional_vars: 需要添加到上下文中的额外变量(如新节点的结果)。
        """
        session_context = self.state_store.load_context(graph_id) or {}
        session_context["messages"] = new_messages # 用最新消息覆盖
        session_context.update(additional_vars) # 合并其他变量
        self.state_store.save_context(graph_id, session_context)
        logging.debug(f"Graph {graph_id} context updated.")

6. 模型适配器层:统一接口与差异处理

不同LLM提供商的API接口和消息格式存在差异。模型适配器层的作用就是将这些差异封装起来,为上层提供统一的接口。

# llm_adapters.py
from abc import ABC, abstractmethod
from typing import Dict, Any, List
import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class LLMAdapter(ABC):
    """所有LLM适配器的抽象基类,定义统一的LLM交互接口。"""
    def __init__(self, api_key: str, default_model: str):
        self.api_key = api_key
        self.default_model = default_model
        logging.info(f"Initialized LLMAdapter for {type(self).__name__} with default model {default_model}")

    @abstractmethod
    async def chat_completion(self, model: str, messages: List[Dict[str, Any]], **kwargs) -> str:
        """
        执行聊天补全请求。
        :param model: 要使用的模型名称。
        :param messages: 对话历史消息列表。
        :param kwargs: 其他模型特定的参数(如temperature, max_tokens)。
        :return: LLM生成的文本回复。
        """
        pass

class GPT4oAdapter(LLMAdapter):
    """OpenAI GPT-4o 模型的适配器。"""
    def __init__(self, api_key: str):
        super().__init__(api_key, "gpt-4o")
        try:
            from openai import AsyncOpenAI
            self.client = AsyncOpenAI(api_key=api_key)
        except ImportError:
            logging.error("OpenAI library not found. Please install it with `pip install openai`.")
            raise

    async def chat_completion(self, model: str, messages: List[Dict[str, Any]], **kwargs) -> str:
        try:
            response = await self.client.chat.completions.create(
                model=model,
                messages=messages,
                **kwargs # 传递温度、max_tokens等参数
            )
            return response.choices[0].message.content
        except Exception as e:
            logging.error(f"GPT-4o API call failed: {e}")
            raise

class Claude3_5_Adapter(LLMAdapter):
    """Anthropic Claude 3.5 Sonnet 模型的适配器。"""
    def __init__(self, api_key: str):
        super().__init__(api_key, "claude-3-5-sonnet-20240620")
        try:
            from anthropic import Anthropic
            self.client = Anthropic(api_key=api_key)
        except ImportError:
            logging.error("Anthropic library not found. Please install it with `pip install anthropic`.")
            raise

    async def chat_completion(self, model: str, messages: List[Dict[str, Any]], **kwargs) -> str:
        # Anthropic Messages API 的消息格式与OpenAI略有不同
        # 它通过一个独立的 `system` 参数来传递系统提示词
        # 并且 `messages` 列表不能包含 `system` 角色,也不能以 `assistant` 角色开始
        system_message = None
        anthropic_messages = []

        for msg in messages:
            if msg["role"] == "system":
                system_message = msg["content"]
            else:
                anthropic_messages.append(msg)

        # 确保messages列表不为空,且不以assistant角色开头
        if not anthropic_messages or anthropic_messages[0]["role"] == "assistant":
            # 这是一个需要上下文管理器在构建时就处理好的问题
            # 如果走到这里,说明上下文管理器没有正确处理或LLMTaskExecutor直接传了不合规的消息
            # 简单处理:如果为空或以助手开头,插入一个默认用户消息
            if not anthropic_messages:
                 anthropic_messages.append({"role": "user", "content": "请继续执行任务。"})
            elif anthropic_messages[0]["role"] == "assistant":
                 # 如果以助手开头,插入一个用户消息作为承接
                 anthropic_messages.insert(0, {"role": "user", "content": "好的,我理解了,请基于此继续。"})

        try:
            response = await self.client.messages.create(
                model=model,
                max_tokens=kwargs.get("max_tokens", 1024), # Claude需要明确的max_tokens
                messages=anthropic_messages,
                system=system_message, # 将系统提示词通过system参数传递
                temperature=kwargs.get("temperature", 0.7)
            )
            # Claude的响应内容在content列表里,可能包含多个text块
            return "".join(block.text for block in response.content if block.type == "text")
        except Exception as e:
            logging.error(f"Claude 3.5 API call failed: {e}")
            raise

7. 迁移策略与执行流程

迁移策略决定了何时、为何进行模型切换。而执行流程则是由Orchestrator来驱动。

何时迁移?

  • 成本优化: 复杂、高推理的任务由昂贵但强大的模型(如GPT-4o)完成,后续的生成、润色任务迁移到更经济的模型(如Claude 3.5 Sonnet)。
  • 性能/能力匹配: 某个模型在特定任务类型(如代码生成、创意写作、数学推理)上表现更佳。
  • 负载均衡/高可用: 当一个模型的API出现延迟或故障时,自动切换到另一个可用模型。
  • 用户偏好/策略: 根据用户或业务规则,显式指定某些任务由特定模型执行。

Orchestrator的核心逻辑:


# orchestrator.py
import asyncio
import uuid
import logging
from typing import Dict, Any, List, Optional, TYPE_CHECKING

# 导入所有模块,用于类型提示和实例化
from task_graph import Graph, Node
from executors import LLMTaskExecutor, ToolUseTaskExecutor, TaskExecutor
from state_store import StateStore
from context_manager import ContextManager
from llm_adapters import LLMAdapter, GPT4oAdapter, Claude3_5_Adapter

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class Orchestrator:
    """
    任务编排器,负责管理图的初始化、执行、状态持久化和模型迁移。
    """
    def __init__(self, state_store: StateStore, context_manager: ContextManager, 
                 llm_adapters: Dict[str, LLMAdapter], tools_registry: Dict[str, Any]):
        self.state_store = state_store
        self.context_manager = context_manager
        self.llm_adapters = llm_adapters # 存储所有可用的LLM适配器实例

        # 初始化任务执行器,根据任务类型注册
        self.task_executors: Dict[str, TaskExecutor] = {
            "llm_inference": LLMTaskExecutor(llm_adapters["gpt4o"]), # 默认LLM执行器使用gpt4o适配器
            "tool_use": ToolUseTaskExecutor(tools_registry),
            # 可以添加更多任务类型,如 "data_processing": DataProcessingExecutor(...)
        }
        self.current_graph: Optional[Graph] = None # 当前活跃的图实例

        logging.info("Orchestrator initialized.")

    async def initialize_graph(self, graph_definition: Dict[str, Any]) -> str:
        """
        初始化一个新的任务图。
        :param graph_definition: 图的定义字典。
        :return: 初始化后的图ID。
        """
        graph_id = graph_definition.get("graph_id", str(uuid.uuid4()))
        nodes = [Node.from_dict(node_data) for node_data in graph_definition["nodes"]]
        self.current_graph = Graph(graph_id, nodes)
        self.state_store.save_graph_state(graph_id, self.current_graph.get_graph_state())
        self.context_manager.update_llm_context(graph_id, [], {}) # 初始化空上下文
        logging.info(f"Graph {graph_id} initialized.")
        return graph_id

    async def resume_graph(self, graph_id: str) -> bool:
        """
        从持久化状态恢复一个任务图的执行。
        :param graph_id: 要恢复的图ID。
        :return: True如果恢复成功,否则False。
        """
        state = self.state_store.load_graph_state(graph_id)
        if not state:
            logging.error(f"Graph state for {graph_id} not found. Cannot resume.")
            return False

        # 假设图的结构定义在别处(例如数据库或配置文件)
        # 这里为了简化,我们假设 Orchestrator 启动时能够获取到所有图的初始定义
        # 如果current_graph为空,需要从初始定义重建图结构,再加载状态
        if not self.current_graph or self.current_graph.graph_id != graph_id:
            # 实际场景中,这里需要从某个地方加载graph_id对应的初始图结构定义
            # 然后用 state 更新其内部节点状态
            logging.warning(f"Graph {graph_id} not active. Attempting to load initial definition and state.")
            # Dummy: For this example, let's just assume initial definition is available
            # In a real system, you'd fetch the original graph_definition from a persistent store
            # For now, we'll error if current_graph is not already set for that ID.
            logging.error(f"Cannot resume graph {graph_id} without its initial definition loaded into Orchestrator.")
            return False

        self.current_graph.load_graph_state(state)
        logging.info(f"Graph {graph_id} resumed from state.")
        return True

    async def execute_graph(self, graph_id: str, initial_input: Dict[str, Any] = None):
        """
        执行整个任务图。
        :param graph_id: 要执行的图ID。
        :param initial_input: 初始输入数据,会添加到图的全局上下文。
        """
        if not self.current_graph or self.current_graph.graph_id != graph_id:
            if not await self.resume_graph(graph_id):
                logging.error(f"Could not load or resume graph {graph_id}. Aborting execution.")
                return

        if initial_input:
            self.context_manager.update_llm_context(graph_id, [], initial_input)
            logging.info(f"Initial input added to context for graph {graph_id}.")

        while True:
            runnable_nodes = self.current_graph.get_runnable_nodes()
            if not runnable_nodes:
                # 检查所有节点是否都已完成
                if all(node.status in ["COMPLETED", "SKIPPED", "FAILED"] for node in self.current_graph.nodes.values()):
                    logging.info(f"Graph {graph_id} execution completed (or all pending nodes failed).")
                    break
                else:
                    logging.warning(f"Graph {graph_id}: No runnable nodes, but some nodes are still PENDING. Possible deadlock or unhandled dependencies. Exiting.")
                    break # 避免无限循环

            tasks = []
            for node in runnable_nodes:
                node.status = "RUNNING" # 标记节点为运行中
                logging.info(f"Scheduling node {node.node_id} ({node.task_type}) for execution.")

                # 获取节点上下文,这会涉及从存储加载和根据模型类型调整
                node_context = self.context_manager.build_llm_context(
                    graph_id, node.node_id, self._get_model_type(node.assigned_model)
                )

                # 如果是LLM任务,动态设置LLMTaskExecutor使用的适配器
                if node.task_type == "llm_inference":
                    model_id = node.assigned_model or self.llm_adapters["gpt4o"].default_model # 如果未指定,默认使用gpt4o
                    if model_id not in self.llm_adapters:
                        logging.error(f"LLM adapter for model ID '{model_id}' not found. Node {node.node_id} will fail.")
                        node.status = "FAILED"
                        node.error = f"LLM adapter '{model_id}' missing."
                        self.state_store.save_graph_state(graph_id, self.current_graph.get_graph_state())
                        continue # 跳过当前节点,因为它肯定会失败

                    # 动态切换LLMTaskExecutor内部的LLMAdapter实例
                    self.task_executors["llm_inference"].llm_adapter = self.llm_adapters[model_id]
                    logging.info(f"Node {node.node_id} will use LLM adapter: {model_id}")

                tasks.append(self._execute_node_task(node, node_context))

            # 并行执行所有可运行的节点
            await asyncio.gather(*tasks, return_exceptions=True) # return_exceptions=True 确保一个任务失败不会中断其他任务

            # 每次批次执行后,保存图的最新状态
            self.state_store.save_graph_state(graph_id, self.current_graph.get_graph_state())
            logging.info(f"Graph {graph_id} state saved after batch execution.")

            # 为了避免忙等,可以加一个短暂的延迟
            await asyncio.sleep(0.1)

    async def _execute_node_task(self, node: Node, context: Dict[str, Any]):
        """
        内部方法:执行单个任务节点。
        :param node: 要执行的Node对象。
        :param context: 节点执行所需的上下文。
        """
        try:
            executor = self.task_executors.get(node.task_type)
            if not executor:
                raise ValueError(f"No executor found for task type: {node.task_type}")

            result = await executor.execute(node, context)
            logging.info(f"Node {node.node_id} ({node.task_type}) completed.")

            # 更新全局上下文:将当前节点的结果和最新的消息历史(如果LLMTaskExecutor更新了)保存
            # 注意:LLMTaskExecutor应该在执行时更新了context['messages']
            self.context_manager.update_llm_context(
                self.current_graph.graph_id,
                context.get("messages", []), # 传入LLMTaskExecutor可能修改过的消息列表
                {f"node_result_{node.node_id}": result} # 将节点结果作为变量保存
            )

        except Exception as e:
            logging.error(f"Node {node.node_id} ({node.task_type}) failed: {e}")
            node.status = "FAILED"
            node.error = str(e)
        finally:
            # 无论成功失败,都确保节点状态被持久化
            self.state_store.save_graph_state(self.current_graph.graph_id, self.current_graph.get_graph_state())

    def _get_model_type(self, model_id: Optional[str]) -> str:
        """根据模型ID判断模型类型,以便上下文管理器做模型特定处理。"""
        if model_id:
            if "gpt" in model_id.lower():
                return "gpt"
            if "claude" in model_id.lower():
                return "claude"
        return "unknown" # 默认类型

    async def migrate_node_to_model(self, graph_id: str, node_id: str, target_model_id: str) -> bool:
        """
        将指定节点的目标LLM模型更改为另一个。
        这可以在节点执行前调用,实现动态迁移。
        :param graph_id: 图ID。
        :param node_id: 要迁移的节点ID。
        :param target_model_id: 目标模型ID(如"claude3_5")。
        :return: True如果迁移成功,否则False。
        """
        if not self.current_graph or self.current_graph.graph_id != graph_id:
            logging.error(f"Graph {graph_id} not active. Cannot migrate node {node_id}.")
            return False

        node = self.current_graph.nodes.get(node_id)
        if not node

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注