各位同仁,下午好!
欢迎来到今天的技术讲座。在构建复杂的AI应用,特别是那些涉及多步骤决策、状态管理以及与大型语言模型(LLM)交互的系统时,我们常常会面临一个核心挑战:数据的一致性、可预测性和健壮性。LangGraph 作为 LangChain 生态中用于构建有状态、循环性LLM应用的强大框架,它为我们提供了一种编排复杂逻辑的优雅方式。然而,随着图结构和节点数量的增加,如何确保数据在不同节点间的顺畅传递和正确处理,成为一个亟待解决的问题。
今天,我们将深入探讨一个高级话题:如何利用 Pydantic 的 BaseModel 实现 LangGraph 节点的强类型数据校验。这将使我们的 LangGraph 应用不仅逻辑清晰,更能在数据层面达到前所未有的健壮性与可维护性。
第一部分:LangGraph 核心概念与数据流挑战
在深入 Pydantic 之前,我们先快速回顾 LangGraph 的核心思想,并识别在数据流方面可能遇到的挑战。
1.1 LangGraph 简介
LangGraph 是 LangChain 的一个扩展,旨在通过图形结构来构建和管理复杂的多步骤 LLM 应用。它允许我们定义一系列节点(nodes)和边(edges),从而创建有状态的、可以进行循环的代理(agents)或工作流。其核心优势在于:
- 状态管理: LangGraph 维护一个单一的、可变的状态对象,在图中的节点之间传递。
- 图结构: 通过节点和边清晰地定义执行路径,支持条件路由和循环。
- 模块化: 每个节点都可以是一个独立的函数或组件,负责特定的任务。
一个典型的 LangGraph 应用由以下几个关键元素组成:
StateGraph: 定义图的骨架,包括状态的类型。- 状态(State): 在节点之间传递的数据。通常是一个
TypedDict或BaseModel。 - 节点(Nodes): 图中的处理单元,接收当前状态,执行逻辑,并返回状态的更新。
- 边(Edges): 连接节点,定义执行流。可以是固定的,也可以是基于状态的条件路由。
1.2 LangGraph 中的数据流与潜在挑战
在 LangGraph 中,数据流的核心是状态。每个节点接收当前状态作为输入,执行其逻辑,然后返回一个字典或 BaseModel 实例,用于更新状态。
考虑以下一个简化的 LangGraph 流程:
- 初始请求处理节点: 接收用户查询,可能从数据库获取初始上下文。
- LLM 调用节点: 使用查询和上下文调用 LLM,生成初步响应或行动计划。
- 工具使用节点: 根据 LLM 的输出,调用外部工具(如搜索、API)。
- 结果整合节点: 将工具结果与 LLM 响应整合,形成最终输出。
在这个过程中,数据在不同节点间传递,例如:
- 用户查询(字符串)
- 数据库上下文(列表、字典)
- LLM 输入(字符串、消息列表)
- LLM 输出(字符串、JSON 结构)
- 工具名称、参数(字符串、字典)
- 工具执行结果(任意类型)
如果没有严格的类型定义和校验,我们可能会面临以下问题:
- 运行时错误: 节点期望一个字符串,却收到了一个整数;期望一个字典,却收到了一个列表。这会导致
TypeError、KeyError或其他逻辑错误。 - 数据结构不一致: 某个节点返回的数据结构与下一个节点期望的不匹配,难以调试。
- 代码可读性差: 不清楚每个节点接收什么、返回什么,导致难以理解和维护。
- 难以协作: 团队成员难以理解其他人的节点接口,增加集成难度。
- 缺乏文档: 节点接口没有明确定义,需要阅读大量代码才能理解数据流。
这些问题在小型项目中可能不明显,但在大型、复杂的 LangGraph 应用中,它们会成为开发效率和系统稳定性的巨大障碍。
第二部分:Pydantic BaseModel 核心概念
Pydantic 是一个基于 Python 类型提示的数据校验和设置管理库。它允许我们用声明式的方式定义数据模型,并在运行时自动进行数据校验、序列化和反序列化。Pydantic 的核心是 BaseModel。
2.1 Pydantic BaseModel 基础
BaseModel 允许我们定义具有类型提示的类,Pydantic 会利用这些类型提示来自动校验传入的数据。
基本示例:
from pydantic import BaseModel, Field
from typing import List, Optional
class UserProfile(BaseModel):
user_id: str = Field(..., description="Unique identifier for the user")
name: str = Field(..., description="User's full name")
age: int = Field(..., gt=0, description="User's age, must be positive")
email: Optional[str] = Field(None, description="Optional email address")
is_active: bool = True
interests: List[str] = Field([], description="List of user's interests")
# 有效数据
try:
user1 = UserProfile(user_id="u123", name="Alice Smith", age=30, email="[email protected]")
print(f"Valid User 1: {user1.model_dump_json(indent=2)}")
user2 = UserProfile(user_id="u456", name="Bob Johnson", age=25, interests=["coding", "reading"])
print(f"Valid User 2: {user2.model_dump_json(indent=2)}")
except Exception as e:
print(f"Error creating user: {e}")
print("-" * 30)
# 无效数据:缺少必填字段
try:
UserProfile(name="Charlie") # user_id 和 age 是必填项
except Exception as e:
print(f"Error creating user (missing fields): {e}")
print("-" * 30)
# 无效数据:类型不匹配
try:
UserProfile(user_id="u789", name="David", age="twenty", email="[email protected]") # age 应该是整数
except Exception as e:
print(f"Error creating user (type mismatch): {e}")
print("-" * 30)
# 无效数据:age 小于等于 0
try:
UserProfile(user_id="u000", name="Eve", age=0)
except Exception as e:
print(f"Error creating user (age constraint): {e}")
解释:
UserProfile继承自BaseModel。- 我们使用标准 Python 类型提示(
str,int,Optional[str],List[str])来定义字段的类型。 Field函数用于提供更详细的元数据,如description(描述)和gt(大于某个值)等校验规则。...表示该字段是必填的。- 当创建
UserProfile实例时,Pydantic 会自动校验传入的数据是否符合定义的类型和规则。 - 如果数据无效,Pydantic 会抛出
ValidationError,其中包含详细的错误信息。 model_dump_json()(Pydantic v2) 或json()(Pydantic v1) 可以将模型实例序列化为 JSON 字符串。
2.2 Pydantic 进阶特性
Pydantic 提供了许多高级特性,它们在 LangGraph 的强类型数据校验中尤其有用。
2.2.1 嵌套模型 (Nested Models)
当数据结构变得复杂时,可以将 BaseModel 嵌套在另一个 BaseModel 中。
from pydantic import BaseModel, Field
from typing import List, Dict, Union, Literal
class Address(BaseModel):
street: str
city: str
zip_code: str = Field(pattern=r"^d{5}(-d{4})?$") # 美国邮编格式
class OrderItem(BaseModel):
product_id: str
quantity: int = Field(gt=0)
price: float = Field(gt=0)
class Order(BaseModel):
order_id: str
customer_id: str
shipping_address: Address
items: List[OrderItem]
status: Literal["pending", "shipped", "delivered", "cancelled"] = "pending"
metadata: Dict[str, Union[str, int, float]] = Field({}, description="Additional order metadata")
try:
order_data = {
"order_id": "ORD001",
"customer_id": "CUST123",
"shipping_address": {
"street": "123 Main St",
"city": "Anytown",
"zip_code": "12345"
},
"items": [
{"product_id": "P001", "quantity": 2, "price": 10.50},
{"product_id": "P002", "quantity": 1, "price": 25.00}
],
"status": "pending",
"metadata": {"source": "web", "priority": 1}
}
order = Order(**order_data)
print(f"Valid Order: {order.model_dump_json(indent=2)}")
except Exception as e:
print(f"Error creating order: {e}")
print("-" * 30)
# 尝试创建无效订单:地址邮编格式不正确
try:
invalid_order_data = {
"order_id": "ORD002",
"customer_id": "CUST124",
"shipping_address": {
"street": "456 Oak Ave",
"city": "Otherville",
"zip_code": "ABCDE" # 无效邮编
},
"items": [
{"product_id": "P003", "quantity": 1, "price": 50.00}
]
}
Order(**invalid_order_data)
except Exception as e:
print(f"Error creating invalid order: {e}")
2.2.2 自定义校验器 (model_validator)
Pydantic 允许我们定义更复杂的、跨字段的校验逻辑。在 Pydantic v2 中,推荐使用 model_validator。
from pydantic import BaseModel, Field, ValidationError, model_validator
from typing import Optional
class EventBooking(BaseModel):
event_name: str
start_time: str # 简化为字符串,实际应为datetime
end_time: str # 简化为字符串
num_attendees: int = Field(gt=0)
organizer_email: str
contact_phone: Optional[str] = None
@model_validator(mode='after')
def validate_time_order(self) -> 'EventBooking':
# 实际应用中,这里应该将start_time和end_time解析为datetime对象进行比较
# 为简化示例,仅作字符串比较(在严格时间格式下可能有效)
if self.start_time >= self.end_time:
raise ValueError("End time must be after start time")
return self
try:
booking1 = EventBooking(
event_name="Team Meeting",
start_time="2023-10-26T10:00:00",
end_time="2023-10-26T11:00:00",
num_attendees=5,
organizer_email="[email protected]"
)
print(f"Valid Booking: {booking1.model_dump_json(indent=2)}")
except ValidationError as e:
print(f"Error creating booking: {e}")
print("-" * 30)
try:
# 结束时间早于开始时间
booking2 = EventBooking(
event_name="Invalid Meeting",
start_time="2023-10-26T12:00:00",
end_time="2023-10-26T11:00:00",
num_attendees=3,
organizer_email="[email protected]"
)
except ValidationError as e:
print(f"Error creating invalid booking (time order): {e}")
2.2.3 字段别名 (Field Aliases)
当外部数据源的字段名与我们希望在 Python 代码中使用的属性名不一致时,可以使用别名。
from pydantic import BaseModel, Field
class ExternalData(BaseModel):
item_id: str = Field(alias="itemId") # 外部是itemId,内部用item_id
item_name: str = Field(alias="itemName")
price_usd: float = Field(alias="priceUSD")
# 配置允许通过别名或原始字段名初始化
class Config:
populate_by_name = True # Pydantic v2: from_attributes = True
try:
raw_data = {
"itemId": "X001",
"itemName": "Widget Pro",
"priceUSD": 99.99
}
data = ExternalData(**raw_data)
print(f"Parsed External Data: {data.model_dump_json(indent=2)}")
print(f"Accessing via internal name: {data.item_id}, {data.item_name}, {data.price_usd}")
# 也可以使用原始字段名创建,如果 Config 设置了 populate_by_name=True
data_with_internal_names = ExternalData(item_id="Y002", item_name="Gadget Lite", price_usd=19.99)
print(f"Created with internal names: {data_with_internal_names.model_dump_json(indent=2)}")
except Exception as e:
print(f"Error parsing external data: {e}")
2.2.4 JSON Schema 生成
Pydantic 模型可以自动生成 JSON Schema,这对于 API 文档、数据契约定义和跨语言数据共享非常有帮助。
from pydantic import BaseModel, Field
import json
class Product(BaseModel):
product_id: str = Field(description="Unique ID of the product")
name: str = Field(description="Name of the product")
category: str = Field(description="Product category")
price: float = Field(gt=0, description="Price in USD, must be positive")
in_stock: bool = True
schema = Product.model_json_schema() # Pydantic v2
print(json.dumps(schema, indent=2))
这些 Pydantic 的特性为我们构建 LangGraph 提供了强大的数据建模和校验能力。
第三部分:将 Pydantic 引入 LangGraph 节点
现在,我们来探讨如何将 Pydantic BaseModel 融入 LangGraph 的节点定义中,实现端到端的数据强类型校验。
3.1 核心策略
我们的核心策略是:
- 定义图状态 (
GraphState) 为BaseModel: 这使得整个图的状态都具有强类型,易于理解和管理。 - 定义节点输入 (
NodeInput) 为BaseModel: 每个节点都明确声明它期望接收的数据结构。 - 定义节点输出 (
NodeOutput) 为BaseModel: 每个节点都明确声明它将返回的数据结构,用于更新图状态。 - 在节点函数中使用类型提示: 让节点函数签名反映 Pydantic 模型。
- 利用 Pydantic 的自动校验: 当数据从状态流向节点输入,或从节点输出流向状态时,Pydantic 会自动进行校验。
3.2 示例 1:简单的节点输入/输出校验
假设我们有一个简单的 LangGraph,用于处理一个用户请求,并记录处理步骤。
首先,定义我们的图状态。这里我们使用 BaseModel,而不是 TypedDict,以获得 Pydantic 的全部优势。
# graph_definitions.py
from typing import TypedDict, Annotated, List, Optional, Literal
from langgraph.graph.message import AnyMessage, HumanMessage
from langgraph.graph import StateGraph, START, END
from pydantic import BaseModel, Field, ValidationError
import operator
# 1. 定义 LangGraph 的状态 (推荐使用 BaseModel)
class AgentState(BaseModel):
"""
Represents the state of our LangGraph agent.
All data flowing through the graph will conform to this schema.
"""
request_id: str = Field(..., description="Unique ID for the current request")
user_query: str = Field(..., description="The original query from the user")
context: List[str] = Field([], description="Relevant context retrieved for the query")
llm_response: Optional[str] = Field(None, description="Response from the LLM")
tool_output: Optional[str] = Field(None, description="Output from any tool calls")
current_step: Literal["start", "retrieve_context", "call_llm", "use_tool", "finish"] = "start"
error_message: Optional[str] = Field(None, description="Any error encountered during processing")
# For LangGraph state, BaseModel requires a dictionary representation for updates.
# While LangGraph can often handle BaseModel directly, explicit conversion might be needed
# for partial updates. We'll show how to manage this.
def to_dict(self):
return self.model_dump()
def update_from_dict(self, update_dict: dict):
return self.model_copy(update=update_dict)
# 2. 定义节点的输入和输出模型
class RetrieveContextInput(BaseModel):
"""Input schema for the retrieve_context node."""
request_id: str
user_query: str
existing_context: List[str] = Field([], description="Any context already present in state")
class RetrieveContextOutput(BaseModel):
"""Output schema for the retrieve_context node."""
context: List[str]
current_step: Literal["retrieve_context"] = "retrieve_context"
class CallLLMInput(BaseModel):
"""Input schema for the call_llm node."""
request_id: str
user_query: str
context: List[str]
class CallLLMOutput(BaseModel):
"""Output schema for the call_llm node."""
llm_response: str
current_step: Literal["call_llm"] = "call_llm"
class UseToolInput(BaseModel):
"""Input schema for the use_tool node."""
request_id: str
llm_response: str
user_query: str # Useful for context if tool needs it
class UseToolOutput(BaseModel):
"""Output schema for the use_tool node."""
tool_output: str
current_step: Literal["use_tool"] = "use_tool"
接下来,我们定义节点函数,并用 Pydantic 模型进行类型提示。
# nodes.py
import time
from graph_definitions import AgentState, RetrieveContextInput, RetrieveContextOutput,
CallLLMInput, CallLLMOutput, UseToolInput, UseToolOutput, ValidationError
from typing import Dict, Any
# Mock LLM and Tool
def mock_llm_call(query: str, context: List[str]) -> str:
"""Simulates an LLM call."""
time.sleep(0.1) # Simulate delay
if "error" in query.lower():
raise ValueError("Simulated LLM error")
response = f"LLM responded to '{query}' with context: {', '.join(context)}. Potential action: search for '{query}'"
return response
def mock_tool_call(action: str) -> str:
"""Simulates a tool call (e.g., search engine)."""
time.sleep(0.05) # Simulate delay
if "fail tool" in action.lower():
raise ValueError("Simulated Tool failure")
return f"Tool executed '{action}' and found relevant data."
# Node 1: Retrieve Context
def retrieve_context_node(state: AgentState) -> Dict[str, Any]:
"""
Retrieves context based on the user query.
Input: AgentState (contains user_query, request_id)
Output: Updated context, current_step
"""
try:
# Validate input from state
# We explicitly create a Pydantic model here to leverage its validation
# Only extract the fields relevant for this node's input.
node_input = RetrieveContextInput(
request_id=state.request_id,
user_query=state.user_query,
existing_context=state.context # Pass existing context if any
)
print(f"[{node_input.request_id}] Node 'retrieve_context' received: {node_input.model_dump()}")
# Simulate context retrieval
new_context = ["latest news", "user preferences", f"query keywords: {node_input.user_query.split()}"]
if node_input.existing_context:
new_context.extend(node_input.existing_context)
# Prepare output
output = RetrieveContextOutput(context=new_context, current_step="retrieve_context")
print(f"[{node_input.request_id}] Node 'retrieve_context' returning: {output.model_dump()}")
return output.model_dump() # Return as dict for state update
except ValidationError as e:
print(f"[{state.request_id if state else 'N/A'}] Validation error in retrieve_context_node: {e}")
return {"error_message": str(e), "current_step": "finish"}
except Exception as e:
print(f"[{state.request_id if state else 'N/A'}] Error in retrieve_context_node: {e}")
return {"error_message": str(e), "current_step": "finish"}
# Node 2: Call LLM
def call_llm_node(state: AgentState) -> Dict[str, Any]:
"""
Calls the LLM with the user query and retrieved context.
Input: AgentState (contains user_query, context, request_id)
Output: Updated llm_response, current_step
"""
try:
node_input = CallLLMInput(
request_id=state.request_id,
user_query=state.user_query,
context=state.context
)
print(f"[{node_input.request_id}] Node 'call_llm' received: {node_input.model_dump()}")
llm_response = mock_llm_call(node_input.user_query, node_input.context)
output = CallLLMOutput(llm_response=llm_response, current_step="call_llm")
print(f"[{node_input.request_id}] Node 'call_llm' returning: {output.model_dump()}")
return output.model_dump()
except ValidationError as e:
print(f"[{state.request_id if state else 'N/A'}] Validation error in call_llm_node: {e}")
return {"error_message": str(e), "current_step": "finish"}
except Exception as e:
print(f"[{state.request_id if state else 'N/A'}] Error in call_llm_node: {e}")
return {"error_message": str(e), "current_step": "finish"}
# Node 3: Use Tool (Conditional)
def use_tool_node(state: AgentState) -> Dict[str, Any]:
"""
Uses a tool based on the LLM's response.
Input: AgentState (contains llm_response, user_query, request_id)
Output: Updated tool_output, current_step
"""
try:
node_input = UseToolInput(
request_id=state.request_id,
llm_response=state.llm_response,
user_query=state.user_query
)
print(f"[{node_input.request_id}] Node 'use_tool' received: {node_input.model_dump()}")
# Simple logic to decide tool use
if "potential action: search" in node_input.llm_response.lower():
tool_output = mock_tool_call(f"search for {node_input.user_query}")
else:
tool_output = "No specific tool action identified."
output = UseToolOutput(tool_output=tool_output, current_step="use_tool")
print(f"[{node_input.request_id}] Node 'use_tool' returning: {output.model_dump()}")
return output.model_dump()
except ValidationError as e:
print(f"[{state.request_id if state else 'N/A'}] Validation error in use_tool_node: {e}")
return {"error_message": str(e), "current_step": "finish"}
except Exception as e:
print(f"[{state.request_id if state else 'N/A'}] Error in use_tool_node: {e}")
return {"error_message": str(e), "current_step": "finish"}
# Conditional Edge: Decide next step based on LLM response
def decide_next_step(state: AgentState) -> str:
"""
Decides whether to use a tool or finish based on the LLM's response or error.
"""
if state.error_message:
return "end_with_error"
if state.llm_response and "potential action: search" in state.llm_response.lower():
return "use_tool"
return "finish"
最后,构建并运行 LangGraph。
# main.py
import uuid
from langgraph.graph import StateGraph, START, END
from graph_definitions import AgentState
from nodes import retrieve_context_node, call_llm_node, use_tool_node, decide_next_step
# Build the graph
workflow = StateGraph(AgentState)
workflow.add_node("retrieve_context", retrieve_context_node)
workflow.add_node("call_llm", call_llm_node)
workflow.add_node("use_tool", use_tool_node)
workflow.set_entry_point("retrieve_context")
# Define edges
workflow.add_edge("retrieve_context", "call_llm")
workflow.add_conditional_edges(
"call_llm",
decide_next_step,
{
"use_tool": "use_tool",
"finish": END,
"end_with_error": END # If error occurs, directly end
}
)
workflow.add_edge("use_tool", END)
# Compile the graph
app = workflow.compile()
print("--- Running valid request ---")
initial_state_1 = AgentState(request_id=str(uuid.uuid4()), user_query="What's the weather like today?")
final_state_1 = app.invoke(initial_state_1)
print("nFinal State 1:")
print(final_state_1.model_dump_json(indent=2))
print("n" + "="*50 + "n")
print("--- Running request with simulated LLM error ---")
initial_state_2 = AgentState(request_id=str(uuid.uuid4()), user_query="Simulate an error from LLM")
final_state_2 = app.invoke(initial_state_2)
print("nFinal State 2:")
print(final_state_2.model_dump_json(indent=2))
print("n" + "="*50 + "n")
print("--- Running request with simulated Tool failure ---")
# To simulate tool failure, we need LLM to trigger tool, then tool fails.
# Let's adjust mock_llm_call to always trigger tool for this specific query.
# For simplicity, we'll just force the LLM output here.
initial_state_3 = AgentState(
request_id=str(uuid.uuid4()),
user_query="How to 'fail tool'?",
llm_response="LLM responded to 'How to fail tool?' with context: .... Potential action: search for 'fail tool'" # Manual override for demo
)
final_state_3 = app.invoke(initial_state_3)
print("nFinal State 3:")
print(final_state_3.model_dump_json(indent=2))
解释:
AgentState(BaseModel): 我们将整个图的状态定义为AgentState,它是一个BaseModel。这意味着整个图的所有数据都将遵循这个严格的模式。NodeInput/NodeOutput(BaseModel): 每个节点都有明确的Input和Output模型。retrieve_context_node期望RetrieveContextInput的结构,并返回RetrieveContextOutput的结构。- 在节点函数内部,我们通过
node_input = NodeInputModel(**state.model_dump())(或者只传入相关字段)来从AgentState中提取并校验输入。 - 节点函数返回一个字典 (
output.model_dump()),LangGraph 会使用这个字典来更新AgentState。由于AgentState也是BaseModel,它会尝试将这些更新映射到自身的字段上,并进行校验。
- 错误处理: 每个节点都包含
try...except ValidationError块。如果传入的数据不符合NodeInput模型的定义,或者节点内部逻辑出错,它会捕获异常并返回一个包含error_message的状态更新,将current_step设置为 "finish",从而引导图提前结束。
这种模式极大地增强了每个节点的健壮性,并使数据流向一目了然。
3.3 示例 2:复杂数据结构与嵌套模型
当节点需要处理更复杂的数据时,Pydantic 的嵌套模型就派上用场了。假设我们有一个节点,用于分析用户提供的产品列表。
# graph_definitions.py (continuing from previous, or new file)
from pydantic import BaseModel, Field, ValidationError
from typing import List, Optional, Literal, Dict, Any
# Assuming AgentState is already defined in graph_definitions.py
# class AgentState(...):
# Nested Models for product analysis
class ProductDetails(BaseModel):
name: str = Field(..., description="Name of the product")
category: str = Field(..., description="Category of the product")
price: float = Field(gt=0, description="Price of the product, must be positive")
quantity: int = Field(gt=0, description="Quantity in cart, must be positive")
product_id: str = Field(..., description="Unique product identifier")
class AnalyzeProductsInput(BaseModel):
"""Input schema for the analyze_products node."""
request_id: str
products: List[ProductDetails] = Field(..., min_length=1, description="List of products to analyze")
user_preferences: Optional[Dict[str, Any]] = Field(None, description="User's known preferences")
class AnalyzeProductsOutput(BaseModel):
"""Output schema for the analyze_products node."""
analysis_report: Dict[str, Any] = Field(..., description="Summary of product analysis")
total_cost: float = Field(..., description="Total cost of all products")
recommended_actions: List[str] = Field([], description="Recommended actions based on analysis")
current_step: Literal["analyze_products"] = "analyze_products"
# nodes.py (add this node)
# ... other nodes ...
def analyze_products_node(state: AgentState) -> Dict[str, Any]:
"""
Analyzes a list of products provided in the state.
Input: AgentState (expects products in a specific format)
Output: Analysis report, total cost, recommended actions
"""
try:
# For this example, let's assume AgentState has a 'products_to_analyze' field
# We need to ensure AgentState can hold this. Let's update AgentState temporarily for this demo.
# In a real app, AgentState would be designed upfront.
# For demonstration, we'll manually craft input or update AgentState in main.py
# We'll create a dummy state for this node's input for demonstration convenience
# In a real LangGraph, state.products_to_analyze would be populated by a previous node.
# Assuming `state` contains a `products_to_analyze` field (which would need to be added to AgentState for real use)
temp_input_dict = {
"request_id": state.request_id,
"products": state.products_to_analyze # This field needs to be added to AgentState for a real scenario
}
node_input = AnalyzeProductsInput(**temp_input_dict) # This will validate the products list
print(f"[{node_input.request_id}] Node 'analyze_products' received: {node_input.model_dump()}")
total_cost = sum(p.price * p.quantity for p in node_input.products)
num_items = sum(p.quantity for p in node_input.products)
unique_categories = list(set(p.category for p in node_input.products))
report = {
"num_products": len(node_input.products),
"total_items_quantity": num_items,
"unique_categories": unique_categories,
"average_price_per_item": total_cost / num_items if num_items > 0 else 0
}
recommendations = []
if total_cost > 1000:
recommendations.append("Consider reviewing large spending.")
if len(unique_categories) > 3:
recommendations.append("Diverse product selection, check for bundle opportunities.")
output = AnalyzeProductsOutput(
analysis_report=report,
total_cost=total_cost,
recommended_actions=recommendations,
current_step="analyze_products"
)
print(f"[{node_input.request_id}] Node 'analyze_products' returning: {output.model_dump()}")
return output.model_dump()
except ValidationError as e:
print(f"[{state.request_id if state else 'N/A'}] Validation error in analyze_products_node: {e}")
return {"error_message": str(e), "current_step": "finish"}
except Exception as e:
print(f"[{state.request_id if state else 'N/A'}] Error in analyze_products_node: {e}")
return {"error_message": str(e), "current_step": "finish"}
为了运行这个例子,我们需要修改 AgentState 来包含 products_to_analyze 字段,并在 main.py 中添加这个节点和相应的边。
修改 AgentState (在 graph_definitions.py 中):
# ... existing AgentState ...
class AgentState(BaseModel):
# ... existing fields ...
products_to_analyze: List[ProductDetails] = Field([], description="List of products for analysis")
analysis_report: Optional[Dict[str, Any]] = Field(None, description="Report from product analysis")
total_cost: Optional[float] = Field(None, description="Total cost from product analysis")
recommended_actions: List[str] = Field([], description="Actions recommended by analysis")
# ... existing fields ...
在 main.py 中添加节点和测试:
# main.py (add to existing)
# ... imports ...
from graph_definitions import AgentState, ProductDetails, AnalyzeProductsInput, AnalyzeProductsOutput
from nodes import analyze_products_node # Import the new node
# ... existing workflow setup ...
workflow.add_node("analyze_products", analyze_products_node)
# Let's add an edge after call_llm or context retrieval for this example
# For simplicity, we'll make it another conditional path or a direct one for demo.
# Let's assume after call_llm, if LLM suggests product analysis, we go there.
# Modify decide_next_step or add a new one. For this demo, we'll just add a direct path from START for a new test case.
# Temporarily, for demo, let's create a separate graph or path for product analysis.
# Or, simpler: invoke with a state that directly triggers this node.
# For a realistic scenario, you'd add a conditional edge:
# workflow.add_conditional_edges(
# "call_llm",
# lambda state: "analyze_products" if "analyze products" in state.llm_response.lower() else decide_next_step(state),
# {
# "analyze_products": "analyze_products",
# "use_tool": "use_tool",
# "finish": END,
# "end_with_error": END
# }
# )
# workflow.add_edge("analyze_products", END) # Or to another node
# For this demo, we will invoke `analyze_products_node` directly or via a simple graph.
# Let's just create a new graph for this specific feature demo to avoid complexity.
workflow_products = StateGraph(AgentState)
workflow_products.add_node("analyze_products", analyze_products_node)
workflow_products.set_entry_point("analyze_products")
workflow_products.add_edge("analyze_products", END)
app_products = workflow_products.compile()
print("n" + "="*50 + "n")
print("--- Running product analysis request ---")
products_data = [
ProductDetails(product_id="P1", name="Laptop", category="Electronics", price=1200.0, quantity=1),
ProductDetails(product_id="P2", name="Mouse", category="Electronics", price=25.50, quantity=2),
ProductDetails(product_id="P3", name="Book", category="Books", price=15.0, quantity=3)
]
initial_state_products_1 = AgentState(
request_id=str(uuid.uuid4()),
user_query="Analyze my product list",
products_to_analyze=products_data # This field is now part of AgentState
)
final_state_products_1 = app_products.invoke(initial_state_products_1)
print("nFinal Product Analysis State 1:")
print(final_state_products_1.model_dump_json(indent=2))
print("n" + "="*50 + "n")
print("--- Running product analysis request with invalid data ---")
invalid_products_data = [
ProductDetails(product_id="P4", name="Monitor", category="Electronics", price=300.0, quantity=1),
{"product_id": "P5", "name": "Keyboard", "category": "Electronics", "price": -50.0, "quantity": 1} # Invalid price
]
initial_state_products_2 = AgentState(
request_id=str(uuid.uuid4()),
user_query="Analyze invalid product list",
products_to_analyze=invalid_products_data # Pydantic will catch this during input validation
)
final_state_products_2 = app_products.invoke(initial_state_products_2)
print("nFinal Product Analysis State 2 (Invalid Input):")
print(final_state_products_2.model_dump_json(indent=2))
这个例子展示了如何使用嵌套 Pydantic 模型来处理更复杂的数据结构,确保列表中的每个元素也符合预期的模式。当 products_to_analyze 中的某个产品数据不符合 ProductDetails 模型(如价格为负数)时,AnalyzeProductsInput 的校验就会失败,从而捕获错误。
3.4 示例 3:结合 LangChain 工具
LangChain 的工具(Tools)也广泛使用 Pydantic BaseModel 来定义其输入模式 (args_schema)。这使得 Pydantic 在 LangGraph 中与 LangChain 工具的集成变得天衣无缝。
# graph_definitions.py (add to AgentState)
class AgentState(BaseModel):
# ... existing fields ...
tool_name: Optional[str] = Field(None, description="Name of the tool to be called")
tool_args: Optional[Dict[str, Any]] = Field(None, description="Arguments for the tool call")
# ... existing fields ...
# tools.py
from langchain_core.tools import tool
from pydantic import BaseModel, Field
# Define a Pydantic model for the tool's input arguments
class SearchToolInput(BaseModel):
query: str = Field(description="The search query to execute")
num_results: int = Field(default=3, description="Number of search results to return")
@tool("internet_search", args_schema=SearchToolInput)
def internet_search(query: str, num_results: int) -> str:
"""
Performs an internet search and returns the top results.
"""
print(f"Executing internet_search with query='{query}', num_results={num_results}")
# Simulate a search API call
if "fail search" in query.lower():
raise ValueError("Simulated search API failure")
results = [
f"Result 1 for '{query}'",
f"Result 2 for '{query}'",
f"Result 3 for '{query}'"
]
return "n".join(results[:num_results])
# Another tool example
class CalculatorToolInput(BaseModel):
expression: str = Field(description="Mathematical expression to evaluate (e.g., '2 + 2 * 3')")
@tool("calculator", args_schema=CalculatorToolInput)
def calculator(expression: str) -> str:
"""
Evaluates a mathematical expression.
"""
print(f"Executing calculator with expression='{expression}'")
try:
return str(eval(expression)) # WARNING: eval is dangerous in real apps, for demo only
except Exception as e:
return f"Error evaluating expression: {e}"
# nodes.py (add a new node for tool execution)
from tools import internet_search, calculator
from langchain_core.tools import BaseTool
def execute_tool_node(state: AgentState) -> Dict[str, Any]:
"""
Executes the tool specified in the state.
"""
try:
if not state.tool_name or not state.tool_args:
raise ValueError("Tool name or arguments are missing from state.")
tool_map: Dict[str, BaseTool] = {
"internet_search": internet_search,
"calculator": calculator
}
selected_tool = tool_map.get(state.tool_name)
if not selected_tool:
raise ValueError(f"Unknown tool: {state.tool_name}")
print(f"[{state.request_id}] Node 'execute_tool' calling tool '{state.tool_name}' with args: {state.tool_args}")
# LangChain tools with args_schema will automatically validate state.tool_args
# against the Pydantic schema defined for the tool.
tool_output_result = selected_tool.invoke(state.tool_args)
output = UseToolOutput(tool_output=tool_output_result, current_step="use_tool") # Reuse UseToolOutput
print(f"[{state.request_id}] Node 'execute_tool' returning: {output.model_dump()}")
return output.model_dump()
except ValidationError as e:
print(f"[{state.request_id if state else 'N/A'}] Validation error in execute_tool_node (tool args): {e}")
return {"error_message": str(e), "current_step": "finish"}
except Exception as e:
print(f"[{state.request_id if state else 'N/A'}] Error in execute_tool_node: {e}")
return {"error_message": str(e), "current_step": "finish"}
# main.py (integrating tool execution)
# ... imports ...
from nodes import execute_tool_node
# ... workflow setup ...
workflow.add_node("execute_tool", execute_tool_node)
# Modify decide_next_step to also route to execute_tool if LLM suggests a tool call
def decide_next_step_with_tools(state: AgentState) -> str:
if state.error_message:
return "end_with_error"
if state.llm_response:
# Simple heuristic: if LLM suggests a tool, extract it.
# In a real agent, this would be a more sophisticated LLM parsing step.
if "call tool internet_search" in state.llm_response.lower():
state.tool_name = "internet_search"
state.tool_args = {"query": "current events", "num_results": 2} # Hardcoded for demo
return "execute_tool"
if "call tool calculator" in state.llm_response.lower():
state.tool_name = "calculator"
state.tool_args = {"expression": "15 * 3 + 7"} # Hardcoded for demo
return "execute_tool"
# Original logic
if state.llm_response and "potential action: search" in state.llm_response.lower():
# This path can still exist if we have a generic "use_tool" node
return "use_tool" # This would be our old generic tool node if kept
return "finish"
workflow.add_conditional_edges(
"call_llm",
decide_next_step_with_tools, # Use the new conditional logic
{
"execute_tool": "execute_tool",
"use_tool": "use_tool", # Keep if you have a generic use_tool node
"finish": END,
"end_with_error": END
}
)
workflow.add_edge("execute_tool", END) # Tool execution directly ends for this demo
app_with_tools = workflow.compile()
print("n" + "="*50 + "n")
print("--- Running request triggering internet_search tool ---")
initial_state_tool_1 = AgentState(
request_id=str(uuid.uuid4()),
user_query="What's happening in the world?",
llm_response="Here's what I found. I suggest you call tool internet_search to get current events." # Simulate LLM suggestion
)
final_state_tool_1 = app_with_tools.invoke(initial_state_tool_1)
print("nFinal State (Internet Search):")
print(final_state_tool_1.model_dump_json(indent=2))
print("n" + "="*50 + "n")
print("--- Running request triggering calculator tool with invalid args ---")
initial_state_tool_2 = AgentState(
request_id=str(uuid.uuid4()),
user_query="Calculate something",
llm_response="I suggest you call tool calculator for calculation.",
tool_name="calculator",
tool_args={"expression": 100} # Invalid type for expression
)
final_state_tool_2 = app_with_tools.invoke(initial_state_tool_2)
print("nFinal State (Invalid Calculator Args):")
print(final_state_tool_2.model_dump_json(indent=2))
在这个例子中:
- LangChain 的
@tool装饰器通过args_schema参数接受一个 PydanticBaseModel来定义工具的输入。 execute_tool_node接收AgentState,从中提取tool_name和tool_args。- 当调用
selected_tool.invoke(state.tool_args)时,LangChain 会自动使用tool定义的args_schema来校验state.tool_args。如果state.tool_args不符合SearchToolInput或CalculatorToolInput的模式(例如,expression期望字符串但得到整数),Pydantic 会抛出ValidationError,并被节点捕获。 - 这确保了即使是工具的输入,也得到了强类型校验,进一步提升了系统的健壮性。
第四部分:高级模式与最佳实践
4.1 统一的 GraphState 作为 BaseModel
我们已经看到了将 AgentState 定义为 BaseModel 的好处。这是实现端到端强类型校验的关键。
优势:
- 单一事实来源: 整个图的状态结构清晰定义。
- 自动校验: 每次状态更新(通过节点返回的字典合并)都会触发对
AgentState模型的隐式校验。 - 可序列化:
BaseModel实例可以方便地序列化为 JSON,这对于持久化状态、日志记录和调试非常有用。 - 文档化:
AgentState本身就是自文档化的数据契约。
状态更新的考虑:
当节点返回一个字典来更新 AgentState 时,LangGraph 会尝试合并这些更新。如果 AgentState 是一个 BaseModel,这意味着它会用返回的字典更新其字段。
# AgentState defined as BaseModel
class AgentState(BaseModel):
# ... fields ...
# Inside a node function
def my_node(state: AgentState) -> Dict[str, Any]:
# ... logic ...
new_data = {"llm_response": "Processed response", "current_step": "done_llm"}
return new_data # LangGraph will merge this into the AgentState instance
LangGraph 内部会类似 state.model_copy(update=new_data) 来更新状态,这会触发 Pydantic 的校验。
4.2 节点签名设计
为了最大化 Pydantic 的效益和代码清晰度,推荐以下节点签名设计:
- 输入: 节点函数应接收
AgentState作为其第一个参数。 - 输出: 节点函数应返回一个
Dict[str, Any],其中键是AgentState中要更新的字段名,值是新的数据。
from typing import Dict, Any
from graph_definitions import AgentState, MyNodeInput, MyNodeOutput
def my_node_function(state: AgentState) -> Dict[str, Any]:
"""
A typical LangGraph node function using Pydantic for input/output validation.
"""
try:
# 1. 提取并校验输入 (从 AgentState 中挑选相关字段)
node_input = MyNodeInput(
field1=state.some_field_1,
field2=state.another_field
)
print(f"Node input: {node_input.model_dump()}")
# 2. 执行核心逻辑
processed_data = "some_result_based_on_input"
computed_value = 123
# 3. 构建并校验输出
node_output = MyNodeOutput(
result_field=processed_data,
new_value=computed_value,
status="success"
)
print(f"Node output: {node_output.model_dump()}")
# 4. 返回字典以更新 AgentState
return node_output.model_dump()
except ValidationError as e:
# 捕获输入或输出校验错误
print(f"Validation Error in node: {e}")
return {"error_message": str(e), "current_step": "failure"}
except Exception as e:
# 捕获其他运行时错误
print(f"Runtime Error in node: {e}")
return {"error_message": str(e), "current_step": "failure"}
这种模式确保了每个节点都明确其输入和输出的契约,并且在数据流动的每个阶段都受到 Pydantic 的保护。
4.3 错误处理与图路由
当 Pydantic 校验失败时,我们不希望整个应用崩溃。在节点内部捕获 ValidationError 并返回一个特定的状态更新是最佳实践。
error_message: Optional[str]: 在AgentState中添加一个字段来存储错误信息。current_step: Literal[...]: 利用状态中的一个字段来指示当前处理阶段或是否发生错误。- 条件边: 使用条件边根据
error_message或current_step的值将图路由到错误处理路径或直接结束。
# Conditional edge example (from our earlier example)
def decide_next_step(state: AgentState) -> str:
if state.error_message: # If any node set an error_message
return "end_with_error"
# ... other conditions ...
return "finish"
# In workflow definition:
workflow.add_conditional_edges(
"call_llm", # Or any node that might set an error
decide_next_step,
{
# ... normal paths ...
"end_with_error": END # Route to END if an error occurred
}
)
4.4 文档与可维护性
Pydantic 模型本身就是极佳的文档。
- 清晰的接口: 任何查看
AgentState、NodeInput或NodeOutput模型的人都能立即理解数据结构、类型、默认值和约束。 - 减少歧义: 无需猜测字段的含义或类型。
- 便于协作: 团队成员在开发不同节点时,可以基于明确的 Pydantic 模型进行接口约定。
- 重构安全: 修改模型会立即暴露所有不兼容的代码,而不是在运行时才发现。
4.5 测试
强类型校验极大地简化了单元测试和集成测试。
- 单元测试: 可以针对每个 Pydantic 模型编写测试,确保它们正确地校验有效和无效数据。
- 节点测试: 测试每个节点时,可以轻松地构造符合
NodeInput模型的有效输入,以及故意构造不符合模型的无效输入,验证错误处理逻辑。 - 集成测试: 整个图的数据流在每个阶段都有类型保证,减少了因数据格式问题导致的集成错误。
第五部分:性能考量与权衡
Pydantic 在进行数据校验时确实会引入一定的运行时开销。每次创建 BaseModel 实例或更新状态时,都会执行类型检查、默认值设置、转换和自定义校验器。
然而,在大多数 LangGraph 应用中,这种性能开销通常是可以忽略不计的,特别是与以下因素相比:
- LLM 调用的延迟: LLM 模型的推理时间通常是几十毫秒到几秒,Pydantic 的校验时间通常在微秒到毫秒级别。
- I/O 操作(数据库、API 调用): 网络延迟和磁盘 I/O 也是主要的时间消耗点。
- 业务逻辑的复杂性: 节点内部的复杂计算也可能远超 Pydantic 校验的开销。
权衡:
虽然 Pydantic 增加了微小的运行时开销,但它带来的开发效率、系统健壮性、可维护性和调试成本降低的收益是巨大的。在绝大多数场景下,这种权衡是值得的。
只有在极端的高吞吐量、超低延迟的系统中,且 Pydantic 校验成为显著瓶颈时,才需要考虑优化(例如,缓存已校验数据,或在某些路径上跳过校验)。但对于 LLM 编排应用,这种情况非常罕见。
通过将 Pydantic BaseModel 深度集成到 LangGraph 的状态和节点定义中,我们能够构建出更健壮、可维护且易于理解的 AI 应用程序。这种强类型数据校验的方法不仅提升了代码质量,也为团队协作和未来扩展奠定了坚实的基础。它将帮助我们从容应对复杂数据流带来的挑战,专注于核心业务逻辑的实现。