欢迎来到本次技术讲座。今天,我们将深入探讨一个在复杂人工智能系统,特别是智能体(Agent)开发与调试中至关重要的概念——“Checkpoint Querying”。想象一下,你的智能体在某个任务中表现异常,甚至“犯了错”。你是否曾渴望拥有一台时光机,能够回到过去,像查数据库一样精确地检索智能体在某个特定时刻的完整状态,从而 pinpoint 问题根源?这就是 Checkpoint Querying 旨在解决的核心问题。
传统的日志记录往往只能提供事件序列,但缺乏事件发生时的完整上下文状态。当智能体的决策过程变得复杂,涉及多步推理、记忆、环境交互时,仅仅依赖日志,就如同在茫茫大海中寻找一滴水。Checkpoint Querying 则提供了一种强大的机制,它将智能体运行时的每一个关键瞬间(或称“检查点”)完整地序列化并存储起来,并提供一套强大的查询接口,使我们能够像操作关系型数据库一样,对智能体的历史行为进行深度回溯和分析。
我们将从什么是 Checkpoint 讲起,探讨它的组成、价值,进而深入到如何设计和实现一个可查询的 Checkpoint 系统。我们将通过一个基于 SQLite 的实际案例,展示如何将这一思想付诸实践,并最终讨论其面临的挑战与未来的发展方向。
1. 什么是 Checkpoint?智能体的“数字快照”
在智能体领域,一个“Checkpoint”(检查点)可以被理解为智能体在特定时间点或特定事件发生时的完整数字快照。它不仅仅是智能体的模型参数,更是包括智能体内部状态、它所观测到的环境状态、它采取的动作,乃至其内部的思考过程等所有相关信息的集合。这个快照必须足够全面,以便我们可以在未来从这个点完全恢复智能体的运行,或者对当时的状态进行深入分析。
1.1 Checkpoint 的核心构成要素
一个设计良好的 Checkpoint 应该包含以下几类信息:
| 类别 | 描述 | 示例数据 |
|---|---|---|
| 元数据 | 描述Checkpoint本身的信息,用于索引和过滤。 | timestamp (时间戳), episode_id (回合ID), step_id (步数ID), agent_id (智能体ID), event_type (事件类型,如"action_taken", "observation_received", "episode_end") |
| 智能体内部状态 | 智能体特有的、不直接暴露于环境的内部变量。 | model_state_dict (模型参数哈希或引用), memory_buffer (经验回放缓冲区内容), internal_beliefs (内部信念), planner_state (规划器状态), health, score, inventory等自定义变量 |
| 环境状态 | 智能体在Checkpoint发生时所处的环境状态。 | observation (智能体收到的原始观测), environment_parameters (环境配置), map_state (地图状态), npc_positions (NPC位置), reward (当前步获得的奖励) |
| 动作 | 智能体在接收观测后决定并执行的动作。 | action (智能体执行的动作,可以是离散ID或连续向量) |
| 推理链/思考路径 | 如果智能体具有可解释的推理过程,记录其决策前的中间步骤或理由。 | reasoning_trace (决策树路径, LLM思考过程, 规划图) |
1.2 Checkpoint 类的设计
为了方便管理和序列化,我们可以定义一个 Checkpoint 类。使用 pydantic 这样的库可以帮助我们定义结构化数据,并方便地进行序列化和反序列化。
from pydantic import BaseModel, Field
from typing import Any, Dict, Optional
from datetime import datetime
import uuid
class Checkpoint(BaseModel):
"""
智能体 Checkpoint 的数据模型。
封装了智能体在特定时间点的完整状态。
"""
id: str = Field(default_factory=lambda: str(uuid.uuid4())) # 唯一ID
timestamp: datetime = Field(default_factory=datetime.now) # 记录时间
agent_id: str
episode_id: int
step_id: int
event_type: str = "default" # "observation_received", "action_taken", "episode_end", "error"
# 智能体内部状态
agent_internal_state: Dict[str, Any] = Field(default_factory=dict)
# 包括模型参数引用、内存、内部信念等
# 注意:模型参数本身可能很大,通常只存储其哈希或在外部存储
# 环境状态
environment_state: Dict[str, Any] = Field(default_factory=dict)
# 包括原始观测、环境参数、地图状态等
# 智能体采取的动作
action: Optional[Any] = None
# 智能体收到的奖励
reward: Optional[float] = None
# 推理链/思考路径
reasoning_trace: Optional[Any] = None
# 任何其他自定义数据
metadata: Dict[str, Any] = Field(default_factory=dict)
class Config:
arbitrary_types_allowed = True # 允许 Any 类型
json_dumps = lambda x, **kwargs: x.json(**kwargs) # 方便转换为JSON
这个 Checkpoint 类提供了一个灵活的结构,允许我们捕获各种类型的状态信息。agent_internal_state 和 environment_state 使用字典来存储动态的、非固定的数据结构,这在实际应用中非常常见,因为智能体或环境的状态可能随时添加或改变属性。
2. Checkpoint Querying 的核心价值:为什么我们需要它?
当智能体的行为变得复杂且难以预测时,传统的调试方法往往力不从心。Checkpoint Querying 提供了一套强大的工具集,能够从根本上改变我们理解、调试和改进智能体的方式。
2.1 根因分析 (Root Cause Analysis)
这是 Checkpoint Querying 最直接和最重要的应用。当智能体出现错误(例如,在某个任务中失败、陷入循环、做出非预期行为)时,我们可以利用查询功能:
- 回溯到错误发生前一刻:查询
event_type='error'或agent_internal_state.status='failed'的 Checkpoint,然后检索其前 N 个 Checkpoint,查看导致错误的具体观测、内部状态和决策序列。 - 识别异常状态转移:例如,查询
agent_internal_state.health < 10的所有 Checkpoint,然后分析这些低生命值状态是如何发生的,是环境过于危险,还是智能体自身策略问题。
2.2 行为理解 (Behavior Understanding)
智能体的决策过程往往是黑箱。通过查询历史 Checkpoint,我们可以:
- 追踪决策路径:查询
action='attack'的所有 Checkpoint,然后查看智能体在执行攻击动作前的environment_state和agent_internal_state,从而理解智能体在何种条件下会选择攻击。 - 分析长期策略:通过聚合一段时间内的 Checkpoint,分析智能体在不同阶段的行为模式,例如,在探索阶段和利用阶段的行为差异。
- 探究内部信念与外部行为的关系:如果智能体维护内部信念模型,我们可以查询当内部信念与外部观测不一致时,智能体是如何调整或做出决策的。
2.3 策略改进 (Policy Improvement)
识别出智能体行为中的弱点是改进其策略的关键。
- 发现次优决策:查询
reward < 0或score下降的 Checkpoint,分析这些负面结果是如何产生的,从而针对性地调整奖励函数或策略网络。 - A/B 测试效果评估:在不同策略版本下收集 Checkpoint,然后查询对比两种策略在特定情境下的行为差异和性能表现。
2.4 复现与测试 (Reproducibility & Testing)
确保智能体行为的可复现性对于开发和生产环境都至关重要。
- 从任意历史状态恢复:通过加载一个特定的 Checkpoint,我们可以将智能体和环境精确地恢复到历史某个时刻,从那里重新开始运行,用于调试或生成更多数据。
- 回归测试:保存关键测试用例的 Checkpoint,当代码更新后,可以从这些 Checkpoint 重新运行,验证新代码是否引入了新的问题。
2.5 监控与告警 (Monitoring & Alerting)
在生产环境中,Checkpoint Querying 可以用于实时或近实时监控。
- 异常行为告警:定期查询 Checkpoint 数据库,如果发现智能体进入了某种不健康或异常的状态(例如,长时间没有采取有效动作,或者某个关键指标超出阈值),立即触发告警。
- 性能趋势分析:通过聚合查询
reward或score的历史数据,分析智能体性能随时间的变化趋势。
3. 架构设计:构建可查询的Checkpoints系统
要实现一个功能强大且可扩展的 Checkpoint Querying 系统,我们需要精心设计其架构。这通常涉及以下几个关键层:
3.1 智能体集成层 (Agent Instrumentation Layer)
这一层负责在智能体运行时生成 Checkpoint。它直接嵌入到智能体的决策循环中。
- 何时生成 Checkpoint?
- 每个时间步 (Per Timestep):最细粒度,但数据量最大。适用于需要详尽分析的场景。
- 关键事件 (Key Events):例如,接收到新观测、执行动作、回合开始/结束、智能体状态发生重大变化(如生命值低于阈值、获得关键物品)、以及错误发生时。
- 周期性 (Periodically):每 N 个时间步或每 M 分钟生成一次。
Agent类的save_checkpoint方法:智能体需要一个机制来收集其内部状态、当前观测、采取的动作等信息,并将其打包成Checkpoint对象。
# 示例:Agent 类中的 Checkpoint 集成
class BaseAgent:
def __init__(self, agent_id: str):
self.agent_id = agent_id
self.episode_id = 0
self.step_id = 0
self.health = 100
self.inventory = {}
# ... 其他智能体内部状态
def _get_current_internal_state(self) -> Dict[str, Any]:
"""收集智能体当前的内部状态。"""
return {
"health": self.health,
"inventory": self.inventory,
# ... 其他需要保存的状态
"model_state_hash": "some_hash_of_model_weights" # 实际可能保存模型文件路径或哈希
}
def _get_current_environment_state(self, observation: Any) -> Dict[str, Any]:
"""收集智能体当前观测到的环境状态。"""
# 实际可能需要一个环境对象来获取更全面的环境状态
return {
"observation": observation,
"map_info": {"size": (10,10), "terrain": "forest"}, # 示例
# ... 其他环境信息
}
def save_checkpoint(self,
event_type: str,
observation: Any,
action: Optional[Any] = None,
reward: Optional[float] = None,
reasoning_trace: Optional[Any] = None,
metadata: Optional[Dict[str, Any]] = None) -> Checkpoint:
"""
创建并返回一个 Checkpoint 对象。
实际应用中,这个方法会将 Checkpoint 传递给 CheckpointManager 进行存储。
"""
if metadata is None:
metadata = {}
cp = Checkpoint(
agent_id=self.agent_id,
episode_id=self.episode_id,
step_id=self.step_id,
event_type=event_type,
agent_internal_state=self._get_current_internal_state(),
environment_state=self._get_current_environment_state(observation),
action=action,
reward=reward,
reasoning_trace=reasoning_trace,
metadata=metadata
)
return cp
def act(self, observation: Any) -> Any:
# 智能体决策逻辑
action = self._decide_action(observation)
return action
def _decide_action(self, observation: Any) -> Any:
# 具体的决策实现
raise NotImplementedError
def reset(self):
self.episode_id += 1
self.step_id = 0
self.health = 100
self.inventory = {}
# ... 重置其他状态
3.2 序列化与存储层 (Serialization & Storage Layer)
这一层负责将 Checkpoint 对象转换为可持久化的格式,并将其写入存储介质。
- 序列化格式:
- JSON:人类可读,跨语言兼容性好,但对于复杂对象(如Python对象)需要自定义序列化,且存储效率不高。
- Pickle (Python):可以直接序列化几乎所有Python对象,但仅限于Python环境,存在安全风险,且版本兼容性差。
- Protocol Buffers / Apache Avro / Apache Parquet:高效的二进制序列化格式,跨语言,数据压缩率高,适合大数据量存储和查询。
- HDF5:适合存储大型多维数组数据,在科学计算领域常用。
- 存储介质选择:
| 存储介质 | 优点 | 缺点 to the Checkpoint Querying System, enabling powerful analysis and debugging.
```python
from pydactic import BaseModel, Field
from typing import Any, Dict, Optional, List
from datetime import datetime
import uuid
import json
import sqlite3
import os
— 1. Checkpoint 数据模型定义 —
class Checkpoint(BaseModel):
"""
智能体 Checkpoint 的数据模型。
封装了智能体在特定时间点的完整状态。
"""
id: str = Field(default_factory=lambda: str(uuid.uuid4())) # 唯一ID
timestamp: datetime = Field(default_factory=datetime.now) # 记录时间
agent_id: str
episode_id: int
step_id: int
event_type: str = "default" # "observation_received", "action_taken", "episode_end", "error", "debug"
# 智能体内部状态 (序列化为JSON)
agent_internal_state_json: str = Field(default_factory=lambda: json.dumps({}))
# 原始的 Python 对象,用于反序列化后操作
_agent_internal_state: Dict[str, Any] = Field(default_factory=dict, exclude=True)
# 环境状态 (序列化为JSON)
environment_state_json: str = Field(default_factory=lambda: json.dumps({}))
# 原始的 Python 对象,用于反序列化后操作
_environment_state: Dict[str, Any] = Field(default_factory=dict, exclude=True)
# 智能体采取的动作 (序列化为JSON)
action_json: Optional[str] = None
_action: Optional[Any] = Field(None, exclude=True)
# 智能体收到的奖励
reward: Optional[float] = None
# 推理链/思考路径 (序列化为JSON)
reasoning_trace_json: Optional[str] = None
_reasoning_trace: Optional[Any] = Field(None, exclude=True)
# 任何其他自定义数据 (序列化为JSON)
metadata_json: str = Field(default_factory=lambda: json.dumps({}))
_metadata: Dict[str, Any] = Field(default_factory=dict, exclude=True)
class Config:
arbitrary_types_allowed = True # 允许 Any 类型
json_dumps = lambda x, **kwargs: x.json(**kwargs) # 方便转换为JSON
def __init__(self, **data: Any):
super().__init__(**data)
# 确保在初始化时,如果提供了原始对象,则序列化为JSON
if '_agent_internal_state' in data and data['_agent_internal_state'] is not None:
self.agent_internal_state_json = json.dumps(data['_agent_internal_state'])
elif 'agent_internal_state_json' in data: # 如果只提供了JSON,则反序列化
self._agent_internal_state = json.loads(data['agent_internal_state_json'])
if '_environment_state' in data and data['_environment_state'] is not None:
self.environment_state_json = json.dumps(data['_environment_state'])
elif 'environment_state_json' in data:
self._environment_state = json.loads(data['environment_state_json'])
if '_action' in data and data['_action'] is not None:
self.action_json = json.dumps(data['_action'])
elif 'action_json' in data and data['action_json'] is not None:
self._action = json.loads(data['action_json'])
if '_reasoning_trace' in data and data['_reasoning_trace'] is not None:
self.reasoning_trace_json = json.dumps(data['_reasoning_trace'])
elif 'reasoning_trace_json' in data and data['reasoning_trace_json'] is not None:
self._reasoning_trace = json.loads(data['reasoning_trace_json'])
if '_metadata' in data and data['_metadata'] is not None:
self.metadata_json = json.dumps(data['_metadata'])
elif 'metadata_json' in data:
self._metadata = json.loads(data['metadata_json'])
@property
def agent_internal_state(self) -> Dict[str, Any]:
"""获取智能体内部状态的Python对象表示。"""
if not self._agent_internal_state and self.agent_internal_state_json:
self._agent_internal_state = json.loads(self.agent_internal_state_json)
return self._agent_internal_state
@agent_internal_state.setter
def agent_internal_state(self, value: Dict[str, Any]):
self._agent_internal_state = value
self.agent_internal_state_json = json.dumps(value)
@property
def environment_state(self) -> Dict[str, Any]:
"""获取环境状态的Python对象表示。"""
if not self._environment_state and self.environment_state_json:
self._environment_state = json.loads(self.environment_state_json)
return self._environment_state
@environment_state.setter
def environment_state(self, value: Dict[str, Any]):
self._environment_state = value
self.environment_state_json = json.dumps(value)
@property
def action(self) -> Optional[Any]:
"""获取动作的Python对象表示。"""
if self.action_json is not None and self._action is None:
self._action = json.loads(self.action_json)
return self._action
@action.setter
def action(self, value: Optional[Any]):
self._action = value
self.action_json = json.dumps(value) if value is not None else None
@property
def reasoning_trace(self) -> Optional[Any]:
"""获取推理链的Python对象表示。"""
if self.reasoning_trace_json is not None and self._reasoning_trace is None:
self._reasoning_trace = json.loads(self.reasoning_trace_json)
return self._reasoning_trace
@reasoning_trace.setter
def reasoning_trace(self, value: Optional[Any]):
self._reasoning_trace = value
self.reasoning_trace_json = json.dumps(value) if value is not None else None
@property
def metadata(self) -> Dict[str, Any]:
"""获取元数据的Python对象表示。"""
if not self._metadata and self.metadata_json:
self._metadata = json.loads(self.metadata_json)
return self._metadata
@metadata.setter
def metadata(self, value: Dict[str, Any]):
self._metadata = value
self.metadata_json = json.dumps(value)
这里对 `Checkpoint` 类进行了修改,将所有复杂对象都通过 `_json` 后缀的字段存储为 JSON 字符串,并通过 `@property` 和 `@setter` 提供方便的 Python 对象访问接口。这是为了适配关系型数据库(如 SQLite)对数据类型的限制,同时保持 Pydantic 的便利性。
#### 3.3 查询与检索层 (Query & Retrieval Layer)
这是系统的核心,提供像 SQL 一样的查询接口,允许用户根据各种条件检索 Checkpoint。
* **数据库选择**:
* **关系型数据库 (SQLite, PostgreSQL, MySQL)**:提供强大的 SQL 查询能力,支持索引、事务和复杂联结。SQLite 特别适合本地开发和小型项目,因为它是一个无服务器、文件存储的数据库。
* **NoSQL 数据库 (MongoDB, Cassandra, Redis)**:提供灵活的文档模型或键值存储,适合非结构化或半结构化数据,扩展性好,但查询功能可能不如 SQL 强大。
* **时序数据库 (InfluxDB, TimescaleDB)**:专门为时间序列数据优化,查询效率极高,适合按时间范围、聚合等查询。
* **索引**:在 `timestamp`, `agent_id`, `episode_id`, `step_id`, `event_type` 等常用查询字段上建立索引,显著提高查询速度。
* **查询接口**:提供 Python API 来构建和执行查询,将用户友好的查询参数转换为底层数据库的查询语句。
#### 3.4 分析与可视化层 (Analysis & Visualization Layer)
查询结果需要以直观的方式呈现,以便于分析。
* **Jupyter Notebook / IPython**:交互式环境,非常适合加载查询结果,进行数据分析,并使用 `matplotlib`, `seaborn`, `Plotly` 等库进行可视化。
* **自定义Web界面**:对于生产环境,可以开发一个专用的Web界面,提供更友好的查询构建器、数据表格和图表。
* **集成现有工具**:与 MLflow, Weights & Biases 等实验追踪平台集成,方便管理和比较不同实验的 Checkpoint。
---
### 4. 实践案例:基于SQLite的Checkpoint Querying系统
为了具体说明 Checkpoint Querying 的实现,我们将构建一个基于 SQLite 的简单系统。SQLite 因其轻量、无需独立服务器、易于嵌入的特性,成为验证概念和小型项目的理想选择。
#### 4.1 为什么选择SQLite?
* **简单易用**:无需安装复杂的数据库服务器,所有数据都存储在一个文件中。
* **文件存储**:易于备份、迁移和版本控制。
* **SQL 支持**:提供完整的 SQL 功能,可以进行复杂的数据查询、过滤、排序和聚合。
* **Python 内置支持**:Python 标准库 `sqlite3` 模块提供了与 SQLite 数据库交互的接口。
* **资源占用低**:适合在本地开发环境或资源受限的场景中使用。
#### 4.2 定义Checkpoint数据模型 (Python + SQL schema)
我们将上述 `Checkpoint` 类中的 JSON 字符串字段映射到 SQLite 的 `TEXT` 类型。
```python
# Checkpoint 类已在前面定义,包含了 JSON 序列化和反序列化逻辑
# SQLite 数据库表的创建语句
SQL_CREATE_CHECKPOINTS_TABLE = """
CREATE TABLE IF NOT EXISTS checkpoints (
id TEXT PRIMARY KEY,
timestamp TEXT NOT NULL,
agent_id TEXT NOT NULL,
episode_id INTEGER NOT NULL,
step_id INTEGER NOT NULL,
event_type TEXT NOT NULL,
agent_internal_state_json TEXT,
environment_state_json TEXT,
action_json TEXT,
reward REAL,
reasoning_trace_json TEXT,
metadata_json TEXT
);
"""
# 常用查询字段的索引
SQL_CREATE_INDEX_TIMESTAMP = "CREATE INDEX IF NOT EXISTS idx_timestamp ON checkpoints (timestamp);"
SQL_CREATE_INDEX_AGENT_EPISODE_STEP = "CREATE INDEX IF NOT EXISTS idx_agent_episode_step ON checkpoints (agent_id, episode_id, step_id);"
SQL_CREATE_INDEX_EVENT_TYPE = "CREATE INDEX IF NOT EXISTS idx_event_type ON checkpoints (event_type);"
4.3 实现 CheckpointManager 类
CheckpointManager 将负责数据库的连接、表的创建、Checkpoint 的保存以及查询操作。
class CheckpointManager:
"""
管理 Checkpoint 的存储和检索。
使用 SQLite 作为后端数据库。
"""
def __init__(self, db_path: str = "agent_checkpoints.db"):
self.db_path = db_path
self.conn = None
self._connect()
self._create_tables()
def _connect(self):
"""连接到 SQLite 数据库。"""
try:
self.conn = sqlite3.connect(self.db_path, detect_types=sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES)
self.conn.row_factory = sqlite3.Row # 使查询结果可以通过列名访问
print(f"Connected to database: {self.db_path}")
except sqlite3.Error as e:
print(f"Error connecting to database: {e}")
raise
def _create_tables(self):
"""创建 Checkpoints 表和索引。"""
if not self.conn:
raise RuntimeError("Database connection not established.")
cursor = self.conn.cursor()
cursor.execute(SQL_CREATE_CHECKPOINTS_TABLE)
cursor.execute(SQL_CREATE_INDEX_TIMESTAMP)
cursor.execute(SQL_CREATE_INDEX_AGENT_EPISODE_STEP)
cursor.execute(SQL_CREATE_INDEX_EVENT_TYPE)
self.conn.commit()
print("Checkpoints table and indexes ensured.")
def close(self):
"""关闭数据库连接。"""
if self.conn:
self.conn.close()
self.conn = None
print("Database connection closed.")
def save_checkpoint(self, checkpoint: Checkpoint):
"""将 Checkpoint 写入数据库。"""
if not self.conn:
raise RuntimeError("Database connection not established.")
# 强制序列化所有内部对象为JSON字符串,以防属性在创建Checkpoint后被修改
checkpoint.agent_internal_state_json = json.dumps(checkpoint.agent_internal_state)
checkpoint.environment_state_json = json.dumps(checkpoint.environment_state)
checkpoint.action_json = json.dumps(checkpoint.action) if checkpoint.action is not None else None
checkpoint.reasoning_trace_json = json.dumps(checkpoint.reasoning_trace) if checkpoint.reasoning_trace is not None else None
checkpoint.metadata_json = json.dumps(checkpoint.metadata)
insert_sql = """
INSERT INTO checkpoints (
id, timestamp, agent_id, episode_id, step_id, event_type,
agent_internal_state_json, environment_state_json,
action_json, reward, reasoning_trace_json, metadata_json
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""
try:
cursor = self.conn.cursor()
cursor.execute(insert_sql, (
checkpoint.id,
checkpoint.timestamp.isoformat(), # 将 datetime 对象转换为 ISO 格式字符串
checkpoint.agent_id,
checkpoint.episode_id,
checkpoint.step_id,
checkpoint.event_type,
checkpoint.agent_internal_state_json,
checkpoint.environment_state_json,
checkpoint.action_json,
checkpoint.reward,
checkpoint.reasoning_trace_json,
checkpoint.metadata_json
))
self.conn.commit()
except sqlite3.Error as e:
print(f"Error saving checkpoint {checkpoint.id}: {e}")
self.conn.rollback() # 发生错误时回滚事务
raise
def _row_to_checkpoint(self, row: sqlite3.Row) -> Checkpoint:
"""将数据库行转换为 Checkpoint 对象。"""
# 从数据库读取时,需要将JSON字符串反序列化回Python对象
# 注意:Pydantic模型自带的__init__会处理好这些,我们只需要传递原始数据
data = {k: row[k] for k in row.keys()}
# 将 timestamp 字符串转换为 datetime 对象
if 'timestamp' in data and data['timestamp']:
data['timestamp'] = datetime.fromisoformat(data['timestamp'])
# 将JSON字段映射到Pydantic模型的内部字段名
# _agent_internal_state, _environment_state 等
if 'agent_internal_state_json' in data:
data['_agent_internal_state'] = json.loads(data['agent_internal_state_json'])
if 'environment_state_json' in data:
data['_environment_state'] = json.loads(data['environment_state_json'])
if 'action_json' in data and data['action_json'] is not None:
data['_action'] = json.loads(data['action_json'])
if 'reasoning_trace_json' in data and data['reasoning_trace_json'] is not None:
data['_reasoning_trace'] = json.loads(data['reasoning_trace_json'])
if 'metadata_json' in data:
data['_metadata'] = json.loads(data['metadata_json'])
return Checkpoint(**data)
def query_checkpoints(self,
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: str = "timestamp DESC",
offset: int = 0) -> List[Checkpoint]:
"""
根据条件查询 Checkpoint。
Filters 示例: {'agent_id': 'agent_001', 'episode_id': 5, 'event_type': 'error'}
支持对 JSON 字段的简单查询 (如 'agent_internal_state.health > 50') 但需要手动构造 SQL。
更复杂的 JSON 查询可能需要 SQLite 的 JSON 函数,这里暂时不实现。
"""
if not self.conn:
raise RuntimeError("Database connection not established.")
sql_parts = ["SELECT * FROM checkpoints WHERE 1=1"]
params = []
if filters:
for key, value in filters.items():
if isinstance(value, tuple) and len(value) == 2 and value[0] in ['>', '<', '>=', '<=', '!=', '=']:
# 处理范围查询或不等于
sql_parts.append(f"AND {key} {value[0]} ?")
params.append(value[1])
elif isinstance(value, list) and key.endswith('_in'): # 用于 IN 查询
field = key[:-3]
placeholders = ','.join('?' * len(value))
sql_parts.append(f"AND {field} IN ({placeholders})")
params.extend(value)
elif key.startswith('json_contains_'): # 示例:查询JSON字段是否包含某个值
json_field = key.split('json_contains_')[1]
sql_parts.append(f"AND json_extract({json_field}, '$') LIKE ?")
params.append(f"%{value}%")
elif key.startswith('json_path_equal_'): # 示例:查询JSON路径的值
parts = key.split('json_path_equal_')
json_field = parts[1].split('.')[0]
json_path = '$."' + parts[1].split('.', 1)[1].replace('.', '"."') + '"'
sql_parts.append(f"AND json_extract({json_field}, '{json_path}') = ?")
params.append(value)
else:
sql_parts.append(f"AND {key} = ?")
params.append(value)
sql_parts.append(f"ORDER BY {order_by}")
if limit is not None:
sql_parts.append("LIMIT ?")
params.append(limit)
if offset > 0:
sql_parts.append("OFFSET ?")
params.append(offset)
full_sql = " ".join(sql_parts)
try:
cursor = self.conn.cursor()
cursor.execute(full_sql, tuple(params))
rows = cursor.fetchall()
return [self._row_to_checkpoint(row) for row in rows]
except sqlite3.Error as e:
print(f"Error querying checkpoints: {e}")
raise
def load_checkpoint(self, checkpoint_id: str) -> Optional[Checkpoint]:
"""根据 ID 加载单个 Checkpoint。"""
if not self.conn:
raise RuntimeError("Database connection not established.")
sql = "SELECT * FROM checkpoints WHERE id = ?"
try:
cursor = self.conn.cursor()
cursor.execute(sql, (checkpoint_id,))
row = cursor.fetchone()
if row:
return self._row_to_checkpoint(row)
return None
except sqlite3.Error as e:
print(f"Error loading checkpoint {checkpoint_id}: {e}")
raise
def get_latest_checkpoint(self, agent_id: str, episode_id: Optional[int] = None) -> Optional[Checkpoint]:
"""获取指定智能体(和回合)的最新 Checkpoint。"""
filters = {'agent_id': agent_id}
if episode_id is not None:
filters['episode_id'] = episode_id
results = self.query_checkpoints(filters=filters, limit=1, order_by="timestamp DESC")
return results[0] if results else None
关于 JSON 字段的查询:
SQLite 3.38+ 支持 JSON 函数,允许我们对存储在 TEXT 字段中的 JSON 数据进行更复杂的查询。在 query_checkpoints 方法中,我添加了 json_path_equal_ 和 json_contains_ 的简单示例。
json_path_equal_agent_internal_state.health: 允许你查询agent_internal_stateJSON 中health字段的值。json_contains_environment_state_json: 允许你查询environment_state_json字符串中是否包含某个子字符串。
实际生产级系统会构建更强大的 JSON 查询 DSL 或使用 ORM 库来简化这些操作。
4.4 一个简单智能体的模拟
现在我们来模拟一个简单的智能体,它在一个假想的环境中运行,并在关键时刻保存 Checkpoint。
class SimpleAgent(BaseAgent):
"""一个简单的智能体,模拟在环境中移动和收集物品。"""
def __init__(self, agent_id: str, initial_pos: tuple = (0, 0)):
super().__init__(agent_id)
self.position = initial_pos
self.health = 100
self.inventory = {"gold": 0, "potions": 1}
self.energy = 50
self.moves_made = 0
self.goal_reached = False
def _get_current_internal_state(self) -> Dict[str, Any]:
"""收集智能体当前的内部状态。"""
return {
"position": self.position,
"health": self.health,
"inventory": self.inventory,
"energy": self.energy,
"moves_made": self.moves_made,
"goal_reached": self.goal_reached
}
def _get_current_environment_state(self, observation: Dict[str, Any]) -> Dict[str, Any]:
"""收集智能体当前观测到的环境状态。"""
# 假设 observation 包含了环境的所有相关信息
return observation
def _decide_action(self, observation: Dict[str, Any]) -> str:
"""根据当前观测决定一个动作。"""
# 简化决策逻辑
if self.health <= 20 and self.inventory.get("potions", 0) > 0:
return "use_potion"
if observation.get("danger_level", 0) > 5 and self.health > 20:
return "flee"
if observation.get("target_in_range"):
self.goal_reached = True
return "reach_goal"
# 随机移动
possible_moves = ["move_north", "move_south", "move_east", "move_west"]
import random
return random.choice(possible_moves)
def perform_action(self, action: str, env_state: Dict[str, Any]):
"""执行动作并更新智能体状态。"""
self.moves_made += 1
self.energy -= 1 # 每次行动消耗能量
if action.startswith("move_"):
if "north" in action: self.position = (self.position[0], self.position[1] + 1)
elif "south" in action: self.position = (self.position[0], self.position[1] - 1)
elif "east" in action: self.position = (self.position[0] + 1, self.position[1])
elif "west" in action: self.position = (self.position[0] - 1, self.position[1])
self.health -= env_state.get("danger_level", 0) # 移动可能受到环境伤害
elif action == "use_potion":
if self.inventory.get("potions", 0) > 0:
self.inventory["potions"] -= 1
self.health = min(100, self.health + 30)
print(f"Agent {self.agent_id} used a potion. Health: {self.health}")
else:
print(f"Agent {self.agent_id} tried to use potion but has none.")
self.health -= 5 # 尝试失败也消耗
elif action == "flee":
self.position = (self.position[0] + random.randint(-2,2), self.position[1] + random.randint(-2,2))
self.health -= env_state.get("danger_level", 0) / 2 # 逃跑伤害减半
elif action == "reach_goal":
self.inventory["gold"] += 100
print(f"Agent {self.agent_id} reached goal! Gold: {self.inventory['gold']}")
# 模拟能量耗尽
if self.energy <= 0:
self.health = 0 # 能量耗尽导致死亡
def run_episode(self, env, max_steps: int, manager: CheckpointManager):
"""运行一个回合。"""
self.reset() # 重置智能体状态,增加 episode_id
current_episode_id = self.episode_id
print(f"n--- Running Episode {current_episode_id} for Agent {self.agent_id} ---")
for step in range(max_steps):
self.step_id = step
observation = env.get_observation(self.position) # 获取环境观测
# 保存观测 Checkpoint
cp_obs = self.save_checkpoint(event_type="observation_received", observation=observation)
manager.save_checkpoint(cp_obs)
action = self.act(observation) #