跨模态嵌入对齐:在LangGraph中实现文本与图像记忆的联合检索
在人工智能领域,我们正在从单一模态的理解走向多模态的融合。传统上,我们处理文本时使用文本模型,处理图像时使用图像模型,它们各自在自己的领域内表现出色。然而,人类的认知并非如此割裂,我们通过语言描述图像,通过图像理解语言,这是一种天然的跨模态交互。
“跨模态嵌入对齐”(Cross-modal Embedding Alignment)正是为了弥合这种模态间的鸿沟而生。它的核心思想是将来自不同模态(如文本、图像、音频、视频等)的数据映射到一个共同的、低维的向量空间中。在这个共享的潜在空间里,语义上相似的文本和图像(或其它模态数据)其对应的向量表示会彼此靠近,而语义上不相关的向量则会相互远离。这种对齐使得我们能够用一种模态的查询(例如一段文本描述)去检索另一种模态的数据(例如相关的图像),反之亦然,甚至能够实现模态间的联合检索和推理。
在复杂的AI系统中,特别是那些需要模拟人类认知和记忆的智能体(Agents)中,联合检索能力至关重要。一个智能体需要能够根据用户的文本描述,回忆起相关的文本知识点,同时也能联想到相关的视觉记忆。LangGraph,作为LangChain的强大扩展,提供了一种灵活且强大的框架来构建有状态、多步、具有循环和条件逻辑的智能体工作流。将跨模态嵌入对齐技术融入LangGraph,可以为智能体赋予更丰富、更接近人类的记忆和感知能力。
一、跨模态嵌入对齐的核心概念与技术基石
要理解和实现跨模态嵌入对齐,我们需要从以下几个基础概念入手:
1. 嵌入(Embeddings)
嵌入是机器学习中的一个基本概念,指的是将高维、离散的数据(如单词、句子、图像、整个文档)转换为低维、连续的向量表示。这些向量捕捉了原始数据的语义或结构信息,使得机器能够进行数值计算和比较。
- 文本嵌入(Text Embeddings):
- 原理:通过深度学习模型(如Transformer架构的BERT、RoBERTa、GPT系列、Sentence-BERT等)处理文本序列,将其转换为固定长度的浮点数向量。这些向量的距离(例如余弦相似度)可以反映文本之间的语义相似性。
- 常用模型:
- Sentence-BERT (SBERT):针对句子级别相似性任务进行优化,生成语义丰富的句子嵌入。
- OpenAI Embeddings (如
text-embedding-ada-002):通过大规模预训练,提供高质量的文本嵌入服务,广泛应用于各种下游任务。 - Cohere Embeddings:与OpenAI类似,提供商用的文本嵌入API。
- 图像嵌入(Image Embeddings):
- 原理:通过卷积神经网络(CNN,如ResNet、VGG)或视觉Transformer(ViT)等深度学习模型处理图像像素数据,将其转换为固定长度的浮点数向量。这些向量的距离反映了图像在内容或风格上的相似性。
- 常用模型:
- ResNet, VGG等CNNs:传统上用于图像分类和特征提取。
- Vision Transformers (ViT):将Transformer架构引入图像处理,在许多视觉任务上取得SOTA。
- 跨模态嵌入(Cross-modal Embeddings):
- 关键:这不是简单地分别生成文本和图像嵌入,而是通过特定的训练策略,使得这两种模态的嵌入在同一个向量空间中具有可比性。
- 代表模型:CLIP (Contrastive Language–Image Pre-training) 是目前最成功的跨模态嵌入模型之一。它通过在大规模的“图像-文本对”数据集上进行对比学习(Contrastive Learning)训练,使得给定一个图像,其嵌入向量与描述该图像的文本嵌入向量距离更近,而与不相关文本的嵌入向量距离更远。反之亦然。
2. 对齐策略与训练目标
实现跨模态嵌入对齐的核心在于训练过程中的目标函数设计。
- 对比学习(Contrastive Learning):
- 原理:给定一个锚点(anchor)样本,将其与一个“正样本”(positive sample,语义相关)拉近,同时与多个“负样本”(negative samples,语义不相关)推远。
- CLIP的实现:CLIP模型训练时,输入是N个图像-文本对。它会计算N个图像嵌入和N个文本嵌入。然后,它会尝试最大化每个图像与其对应的文本之间的余弦相似度,同时最小化该图像与其他N-1个不相关文本之间的余弦相似度。同样,对于每个文本,它会最大化与对应图像的相似度,最小化与不相关图像的相似度。这种对称的对比损失(通常是InfoNCE损失的变体)强制模型学习到一个共享的、语义对齐的潜在空间。
- 其他对齐方法:
- Triplet Loss:需要三元组 (anchor, positive, negative),目标是使anchor与positive的距离小于anchor与negative的距离,并留有一定间隔。
- Joint Embedding/Canonical Correlation Analysis (CCA):尝试找到线性变换,最大化两组多变量数据(如文本特征和图像特征)之间的相关性。在深度学习时代,通常通过非线性映射实现。
3. 相似性度量(Similarity Metrics)
在嵌入空间中,我们通过计算向量之间的距离或相似度来判断其语义关联性。
- 余弦相似度(Cosine Similarity):
- 原理:衡量两个向量方向的相似性,范围在-1到1之间。值越接近1表示方向越一致,语义越相似。
- 优势:对向量的L2范数(长度)不敏感,更关注方向。在许多嵌入模型中,向量通常被归一化到单位长度,此时余弦相似度与点积是等价的。
- 欧几里得距离(Euclidean Distance):
- 原理:衡量两个向量在欧几里得空间中的直线距离。距离越小表示越相似。
- 劣势:对向量的绝对大小敏感。
- 点积(Dot Product):
- 原理:两个向量对应元素乘积之和。当向量是单位向量时,点积与余弦相似度等价。在某些检索系统中,点积被用作相似度度量。
4. 向量数据库(Vector Databases)
当我们需要存储和检索大量的嵌入向量时,传统的数据库无法高效地处理相似性搜索。向量数据库应运而生,它们专门设计用于存储高维向量,并提供高效的近似最近邻(Approximate Nearest Neighbor, ANN)搜索算法。
- 作用:
- 高效检索:通过ANN算法(如Faiss、HNSW、IVF等),即使在数百万甚至数十亿向量中也能快速找到与查询向量最相似的Top-K向量。
- 扩展性:支持大规模数据集和高并发查询。
- 元数据管理:通常允许与向量一起存储额外的元数据(如原始文本、图像路径、标签等),以便检索后能获取完整信息。
- 常用产品:
- ChromaDB:轻量级,易于本地部署和使用,适合小型项目和原型开发。
- Faiss (Facebook AI Similarity Search):一个高性能的C++库,提供了多种ANN算法,Python接口可用。常作为本地向量索引。
- Pinecone, Weaviate, Milvus, Qdrant:云原生或分布式向量数据库,提供更强的扩展性、高可用性和丰富的功能,适合生产环境。
二、LangGraph在联合检索中的作用
LangGraph 是 LangChain 的一个强大扩展,专注于构建健壮、有状态、多步的智能体工作流。它通过有向无环图(DAG)或带有循环的图来定义智能体的行为,使得复杂逻辑、条件分支和工具使用变得易于管理。
1. 为什么选择LangGraph进行跨模态联合检索?
- 流程编排能力:联合检索不仅仅是简单的查询。它可能涉及:
- 识别查询类型(是文本还是图像?)。
- 根据类型选择不同的嵌入模型。
- 并行查询多个模态的向量存储。
- 合并和重排序检索结果。
- 使用LLM对结果进行摘要或推理。
LangGraph能够清晰地定义这些步骤为节点,并用边连接它们,实现复杂的流程控制。
- 状态管理:在多步检索过程中,我们需要在不同节点之间传递信息,例如原始查询、生成的嵌入、检索到的文本片段、图像路径等。LangGraph的图状态(Graph State)机制可以优雅地管理这些信息,确保上下文的连贯性。
- 条件逻辑与动态性:用户查询可以是文本,也可以是图像。LangGraph的条件边(Conditional Edges)允许我们根据查询类型动态地路由到不同的处理分支,例如“如果查询是文本,则调用文本嵌入器;如果是图像,则调用图像嵌入器”。
- 模块化与可扩展性:每个检索或处理步骤都可以封装成一个独立的节点。这意味着我们可以轻松地替换不同的嵌入模型、向量数据库或LLM,而无需修改整个工作流。
- 智能体决策: LangGraph构建的智能体可以根据检索到的跨模态信息做出更明智的决策,例如,如果检索到的图像与文本内容矛盾,智能体可以发起澄清性提问。
2. LangGraph工作流的典型结构
一个LangGraph工作流通常包含以下要素:
- Graph State (图状态):一个字典或对象,定义了在整个工作流中需要共享和传递的信息。
- Nodes (节点):图中的基本处理单元,每个节点执行一个特定的任务(如嵌入、检索、LLM调用等)。节点接收当前图状态作为输入,并返回更新后的状态。
- Edges (边):连接节点,定义了信息的流动路径。
- Conditional Edges (条件边):根据节点返回的特定值,动态选择下一个要执行的节点。
- Entry/Exit Points (入口/出口):定义工作流的开始和结束。
三、在LangGraph中实现文本与图像记忆的联合检索
现在,我们将深入探讨如何在LangGraph中具体实现跨模态嵌入对齐,以进行文本与图像记忆的联合检索。我们将使用CLIP模型进行跨模态嵌入,ChromaDB作为向量存储,并构建一个LangGraph工作流。
1. 环境准备与依赖安装
首先,确保您的Python环境已安装所需库:
pip install torch transformers pillow langchain_community langchain langgraph chromadb sentence_transformers
torch: PyTorch深度学习框架,CLIP模型需要。transformers: Hugging Face Transformers库,用于加载CLIP模型。pillow: Python图像处理库,用于加载和处理图像。langchain_community,langchain: LangChain核心库。langgraph: LangGraph框架。chromadb: 轻量级向量数据库。sentence_transformers: 用于生成独立的文本嵌入(如果需要与CLIP的文本部分区分开,或者使用更专用的文本嵌入模型)。但在此场景下,CLIP的文本编码器本身就能提供对齐的文本嵌入。
2. 核心组件:嵌入模型与向量存储
我们将定义用于生成嵌入的类和函数,并初始化两个ChromaDB集合,一个用于存储文本记忆,一个用于存储图像记忆。
import os
import io
import base64
from typing import List, Dict, Union, Any, Literal
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_community.embeddings import HuggingFaceEmbeddings # For text if not using CLIP text encoder directly for text-only queries
# --- 1. 定义嵌入模型 ---
class CrossModalEmbedder:
def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
"""
初始化CLIP模型和处理器。
CLIP模型用于生成文本和图像的跨模态对齐嵌入。
"""
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = CLIPModel.from_pretrained(model_name).to(self.device)
self.processor = CLIPProcessor.from_pretrained(model_name)
print(f"CLIP model '{model_name}' loaded on {self.device}.")
def embed_text(self, texts: Union[str, List[str]]) -> List[List[float]]:
"""
将文本转换为嵌入向量。
"""
if isinstance(texts, str):
texts = [texts]
inputs = self.processor(text=texts, return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
text_features = self.model.get_text_features(**inputs)
return text_features.cpu().numpy().tolist()
def embed_image(self, images: Union[Image.Image, List[Image.Image]]) -> List[List[float]]:
"""
将PIL Image对象转换为嵌入向量。
"""
if isinstance(images, Image.Image):
images = [images]
inputs = self.processor(images=images, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
image_features = self.model.get_image_features(**inputs)
return image_features.cpu().numpy().tolist()
def get_embedding_dimension(self) -> int:
"""获取嵌入向量的维度"""
return self.model.config.projection_dim
# 初始化跨模态嵌入器
cross_modal_embedder = CrossModalEmbedder()
# --- 2. 准备数据并创建向量存储 ---
# 模拟一些数据
text_memories_data = [
("巴黎铁塔是法国的标志性建筑,通常在夜晚点亮。", {"source": "wikipedia", "category": "landmark"}),
("一只可爱的猫咪在阳光下打盹,毛发蓬松。", {"source": "photo_album", "category": "animal"}),
("编程是一门艺术,需要逻辑思维和创造力。", {"source": "blog", "category": "technology"}),
("夏日海滩风景优美,适合度假和游泳。", {"source": "travel_guide", "category": "nature"}),
("古老的图书馆藏书丰富,充满了知识的气息。", {"source": "news_article", "category": "education"}),
]
# 模拟一些图像数据(实际应为图像文件路径,此处为占位符)
# 为了演示,我们假设这些路径指向真实图片,并且我们有对应的PIL Image对象
# 实际应用中,你需要加载真实的图片文件
image_paths = [
"images/eiffel_tower.jpg",
"images/cat_sleeping.jpg",
"images/coding_setup.jpg",
"images/beach_sunset.jpg",
"images/library_interior.jpg",
]
# 为了简化演示,我们使用一个简单的函数来模拟图像加载
# 在实际应用中,你需要从文件系统加载图像
def load_dummy_image(path: str) -> Image.Image:
# 这是一个占位符,实际应加载真实图片
# 例如:Image.open(path)
# 这里我们创建一个纯色图片作为演示
print(f"Loading dummy image for path: {path}")
if "eiffel_tower" in path:
return Image.new('RGB', (224, 224), color = 'red') # 模拟巴黎铁塔
elif "cat_sleeping" in path:
return Image.new('RGB', (224, 224), color = 'blue') # 模拟猫咪
elif "coding_setup" in path:
return Image.new('RGB', (224, 224), color = 'green') # 模拟编程
elif "beach_sunset" in path:
return Image.new('RGB', (224, 224), color = 'orange') # 模拟海滩
elif "library_interior" in path:
return Image.new('RGB', (224, 224), color = 'purple') # 模拟图书馆
return Image.new('RGB', (224, 224), color = 'white') # 默认
# 为图像数据创建文档结构,并生成嵌入
image_docs = []
for i, path in enumerate(image_paths):
# 假设我们有图像的简短描述作为元数据
description = text_memories_data[i][0] # 使用对应的文本描述作为图像的元数据
image_content = load_dummy_image(path) # 模拟加载图像
image_embedding = cross_modal_embedder.embed_image(image_content)[0]
# 将图像路径和描述存储为元数据
# 注意:ChromaDB不支持直接存储PIL Image对象。
# 我们存储嵌入和元数据,元数据中包含图像路径,检索时再根据路径加载图像。
image_docs.append(Document(
page_content=description, # 存储描述作为主要内容
metadata={
"image_path": path,
"original_description": description, # 原始描述
"category": text_memories_data[i][1]["category"] # 其他元数据
}
))
# 为文本数据创建文档结构
text_docs = [Document(page_content=text, metadata=meta) for text, meta in text_memories_data]
# ChromaDB的嵌入函数需要一个可调用的对象,它接受文本列表并返回嵌入列表
# 为此,我们封装cross_modal_embedder的embed_text方法
class ChromaEmbeddingsAdapter:
def __init__(self, embedder: CrossModalEmbedder):
self.embedder = embedder
self.dimension = embedder.get_embedding_dimension()
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return self.embedder.embed_text(texts)
def embed_query(self, text: str) -> List[float]:
return self.embedder.embed_text([text])[0]
def __call__(self, texts: List[str]) -> List[List[float]]:
return self.embed_documents(texts)
chroma_embedder_adapter = ChromaEmbeddingsAdapter(cross_modal_embedder)
# 初始化ChromaDB向量存储
# Text Memory Store
text_vector_store = Chroma.from_documents(
documents=text_docs,
embedding=chroma_embedder_adapter,
collection_name="text_memories",
persist_directory="./chroma_db"
)
print(f"Text memories stored. Count: {text_vector_store._collection.count()}")
# Image Memory Store (存储图像嵌入,元数据包含图像路径和描述)
# 这里我们手动添加嵌入,因为from_documents默认只对page_content进行嵌入
image_vector_store = Chroma(
client_settings=text_vector_store._client_settings, # 使用相同的客户端设置
collection_name="image_memories",
embedding_function=chroma_embedder_adapter, # 必须提供embedding_function
persist_directory="./chroma_db"
)
# 逐个添加图像文档和其对应的嵌入
# 注意:Chroma.add_documents 方法可以接收自定义嵌入。
# 但对于我们已经生成好嵌入的情况,直接使用add_embeddings更直接。
# 这里为了与from_documents的流程保持一致,我们手动构造ids, embeddings, metadatas, documents
image_ids = [f"image_doc_{i}" for i in range(len(image_docs))]
image_embeddings = [cross_modal_embedder.embed_image(load_dummy_image(doc.metadata["image_path"]))[0] for doc in image_docs]
image_metadatas = [doc.metadata for doc in image_docs]
image_contents = [doc.page_content for doc in image_docs] # page_content is the description
image_vector_store._collection.add(
embeddings=image_embeddings,
metadatas=image_metadatas,
documents=image_contents, # documents here refers to the actual text content stored in Chroma
ids=image_ids
)
print(f"Image memories stored. Count: {image_vector_store._collection.count()}")
# 清理持久化目录
# import shutil
# if os.path.exists("./chroma_db"):
# shutil.rmtree("./chroma_db")
代码说明:
CrossModalEmbedder类:封装了CLIP模型。embed_text和embed_image方法分别用于生成文本和图像的嵌入。这两个方法生成的嵌入位于同一个语义空间中。- 数据准备:我们创建了模拟的文本描述和图像路径。关键点:对于图像记忆,我们存储的是图像的描述(
page_content)和图像路径作为元数据。当检索到图像记忆时,我们实际上是检索到了其描述和路径,再根据路径加载图像。 ChromaEmbeddingsAdapter:ChromaDB需要一个实现了embed_documents和embed_query方法的Embeddings类实例。我们创建了一个适配器,将CrossModalEmbedder的方法包装起来,使其符合ChromaDB的要求。- 向量存储:
text_vector_store:存储文本描述的嵌入。image_vector_store:存储图像的嵌入。即使查询是文本,由于CLIP的对齐,我们也能用文本查询在image_vector_store中找到相关图像。
3. LangGraph工作流设计
接下来,我们将使用LangGraph构建一个智能体,该智能体能够处理文本或图像查询,并从文本和图像记忆中联合检索。
Graph State (图状态定义)
from typing import TypedDict, Annotated, List, Literal
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
# 定义图的状态
class AgentState(TypedDict):
query: Union[str, bytes] # 用户查询,可以是文本或图像的base64编码
query_type: Literal["text", "image"] # 查询类型
text_query_embedding: List[float] # 文本查询的嵌入
image_query_embedding: List[float] # 图像查询的嵌入
retrieved_text_docs: Annotated[List[Document], add_messages] # 检索到的文本文档
retrieved_image_docs: Annotated[List[Document], add_messages] # 检索到的图像文档
final_response: str # 最终响应
error_message: str # 错误信息
节点(Nodes)定义
# 假设我们有一个LLM客户端
# from langchain_openai import ChatOpenAI
# llm = ChatOpenAI(model="gpt-4", temperature=0) # 实际应用中使用
# 简化LLM部分,直接返回拼接结果
class MockLLM:
def invoke(self, prompt: str) -> str:
return f"LLM processed: {prompt}"
mock_llm = MockLLM()
# --- LangGraph 节点定义 ---
def determine_query_type(state: AgentState) -> AgentState:
"""
根据查询内容判断是文本查询还是图像查询。
图像查询假定为base64编码的字符串或bytes。
"""
query = state["query"]
if isinstance(query, bytes):
# 假设bytes是图像的base64解码后的原始字节数据
# 或直接判断如果 query 是一个文件路径字符串,我们也可以认为是图像查询
print("Determined query type: image (from bytes).")
return {"query_type": "image"}
elif isinstance(query, str) and (query.startswith("data:image/") or query.endswith((".jpg", ".png", ".jpeg", ".gif"))):
# 如果是data URI 或者文件路径
print(f"Determined query type: image (from string path/uri). Query: {query[:50]}...")
return {"query_type": "image"}
else:
print(f"Determined query type: text. Query: {query[:50]}...")
return {"query_type": "text"}
def embed_text_query(state: AgentState) -> AgentState:
"""
对文本查询进行嵌入。
"""
query = state["query"]
print(f"Embedding text query: {query[:50]}...")
embedding = cross_modal_embedder.embed_text(query)[0]
return {"text_query_embedding": embedding}
def embed_image_query(state: AgentState) -> AgentState:
"""
对图像查询进行嵌入。
"""
query = state["query"]
try:
if isinstance(query, bytes):
image = Image.open(io.BytesIO(query))
elif isinstance(query, str):
if query.startswith("data:image/"):
# 假设是data URI,例如 "data:image/jpeg;base64,..."
header, base64_str = query.split(",", 1)
image_bytes = base64.b64decode(base64_str)
image = Image.open(io.BytesIO(image_bytes))
else:
# 假设是文件路径,加载本地图像
image = Image.open(query)
else:
raise ValueError("Unsupported image query format.")
print(f"Embedding image query (size: {image.size})...")
embedding = cross_modal_embedder.embed_image(image)[0]
return {"image_query_embedding": embedding}
except Exception as e:
print(f"Error embedding image: {e}")
return {"error_message": f"Failed to embed image: {e}"}
def retrieve_memories(state: AgentState) -> AgentState:
"""
根据可用的嵌入(文本或图像)从两个向量存储中检索记忆。
由于CLIP的跨模态对齐,无论是文本查询嵌入还是图像查询嵌入,
都可以在两个模态的向量存储中进行检索。
"""
text_query_embedding = state.get("text_query_embedding")
image_query_embedding = state.get("image_query_embedding")
# 优先使用图像查询嵌入,如果没有则使用文本查询嵌入
query_embedding = image_query_embedding if image_query_embedding else text_query_embedding
if not query_embedding:
print("No embedding available for retrieval.")
return {"error_message": "No query embedding generated."}
print("Retrieving text memories...")
# 从文本向量存储中检索
retrieved_texts = text_vector_store.similarity_search_by_vector(query_embedding, k=3)
print("Retrieving image memories...")
# 从图像向量存储中检索
retrieved_images = image_vector_store.similarity_search_by_vector(query_embedding, k=3)
print(f"Retrieved {len(retrieved_texts)} text docs and {len(retrieved_images)} image docs.")
return {
"retrieved_text_docs": retrieved_texts,
"retrieved_image_docs": retrieved_images
}
def synthesize_response(state: AgentState) -> AgentState:
"""
使用LLM综合检索到的信息,生成最终响应。
"""
query = state["query"]
retrieved_texts = state.get("retrieved_text_docs", [])
retrieved_images = state.get("retrieved_image_docs", [])
prompt_parts = [f"用户查询: {query}"]
if retrieved_texts:
prompt_parts.append("nn相关文本记忆:")
for i, doc in enumerate(retrieved_texts):
prompt_parts.append(f" {i+1}. 内容: {doc.page_content} (来源: {doc.metadata.get('source', '未知')})")
if retrieved_images:
prompt_parts.append("nn相关图像记忆:")
for i, doc in enumerate(retrieved_images):
# 这里的doc.page_content是图像的描述
# doc.metadata['image_path']是图像路径,可以用来加载实际图像
prompt_parts.append(f" {i+1}. 描述: {doc.page_content} (路径: {doc.metadata.get('image_path', '未知')})")
# 实际应用中,这里可以加载图像并显示,或进行进一步分析
# Image.open(doc.metadata['image_path']).show()
prompt_parts.append("nn请根据以上信息,为用户生成一个全面的回答。")
full_prompt = "n".join(prompt_parts)
print("Synthesizing response with LLM...")
final_response = mock_llm.invoke(full_prompt) # 替换为真实的LLM调用
return {"final_response": final_response}
def handle_error(state: AgentState) -> AgentState:
"""
处理工作流中的错误。
"""
error = state.get("error_message", "未知错误")
print(f"Error occurred: {error}")
return {"final_response": f"抱歉,处理您的请求时发生错误:{error}"}
构建LangGraph
# 构建图
workflow = StateGraph(AgentState)
# 添加节点
workflow.add_node("determine_query_type", determine_query_type)
workflow.add_node("embed_text_query", embed_text_query)
workflow.add_node("embed_image_query", embed_image_query)
workflow.add_node("retrieve_memories", retrieve_memories)
workflow.add_node("synthesize_response", synthesize_response)
workflow.add_node("handle_error", handle_error) # 错误处理节点
# 设置入口点
workflow.set_entry_point("determine_query_type")
# 添加条件边
workflow.add_conditional_edges(
"determine_query_type",
lambda state: state["query_type"], # 根据query_type进行路由
{
"text": "embed_text_query",
"image": "embed_image_query",
}
)
# 嵌入节点指向检索节点
workflow.add_edge("embed_text_query", "retrieve_memories")
# 图像嵌入节点需要检查是否有错误
workflow.add_conditional_edges(
"embed_image_query",
lambda state: "handle_error" if state.get("error_message") else "retrieve_memories",
{
"retrieve_memories": "retrieve_memories",
"handle_error": "handle_error",
}
)
# 检索节点指向综合响应节点
workflow.add_edge("retrieve_memories", "synthesize_response")
# 综合响应节点指向结束
workflow.add_edge("synthesize_response", END)
# 错误处理节点指向结束
workflow.add_edge("handle_error", END)
# 编译图
app = workflow.compile()
print("LangGraph workflow compiled successfully.")
# --- 运行LangGraph ---
# 1. 文本查询示例
print("n--- 文本查询示例 ---")
text_query = "关于建筑和风景的记忆"
result_text_query = app.invoke({"query": text_query})
print("n最终响应:")
print(result_text_query["final_response"])
# 2. 图像查询示例 (模拟图像文件路径)
print("n--- 图像查询示例 (通过路径) ---")
image_query_path = "images/eiffel_tower.jpg" # 假定存在这个图片文件
result_image_query_path = app.invoke({"query": image_query_path})
print("n最终响应:")
print(result_image_query_path["final_response"])
# 3. 图像查询示例 (模拟base64编码的图像数据)
print("n--- 图像查询示例 (通过base64编码) ---")
# 实际中你需要将图片转换为base64编码的字符串
# 这里我们创建一个简单的占位符图片并编码
dummy_img = Image.new('RGB', (100, 100), color = 'yellow')
buffered = io.BytesIO()
dummy_img.save(buffered, format="PNG")
base64_img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
base64_image_query = f"data:image/png;base64,{base64_img_str}" # 完整的data URI
result_image_query_base64 = app.invoke({"query": base64_image_query})
print("n最终响应:")
print(result_image_query_base64["final_response"])
# 4. 包含不存在的图像路径的查询(模拟错误)
print("n--- 错误处理示例 (图像路径不存在) ---")
non_existent_image_path = "images/non_existent.jpg"
result_error_query = app.invoke({"query": non_existent_image_path})
print("n最终响应:")
print(result_error_query["final_response"])
工作流解释:
AgentState: 定义了工作流中传递的所有状态变量,包括查询、嵌入、检索结果和最终响应。Annotated[List[Document], add_messages]是LangGraph的语法,表示这是一个列表,并且每次更新时会添加到现有列表而不是完全覆盖(对于检索结果通常是合适的)。determine_query_type节点: 这是入口节点。它检查query的类型。如果query是bytes类型(通常是图像文件的原始字节数据)或者以data:image/开头或以常见图片扩展名结尾的字符串,则认为是图像查询;否则认为是文本查询。embed_text_query/embed_image_query节点: 根据query_type的结果,流程会被路由到这两个节点之一。它们分别调用cross_modal_embedder的相应方法来生成查询嵌入。图像嵌入节点还包含了基本的错误处理,例如当图像无法加载时。retrieve_memories节点: 这是跨模态联合检索的核心。 无论查询是文本还是图像,我们都使用 CLIP 生成的嵌入(text_query_embedding或image_query_embedding)去同时查询text_vector_store和image_vector_store。由于 CLIP 训练时就对齐了两种模态的嵌入,所以一个文本查询的嵌入可以找到相关的图像,一个图像查询的嵌入也可以找到相关的文本。synthesize_response节点: 这个节点负责收集所有检索到的文本文档和图像文档的元数据(主要是描述和路径),然后构建一个详细的提示,将其发送给一个(模拟的)LLM。LLM将综合这些信息,生成最终的用户响应。在实际应用中,这里会集成如OpenAI GPT-4等真实的LLM。handle_error节点: 当embed_image_query节点发生错误时,流程会转向此节点,提供一个友好的错误信息。- 图的构建: 通过
StateGraph定义节点、设置入口、添加边和条件边,最终编译成可执行的app。条件边使得工作流具有动态决策能力。
四、高级考量与未来展望
1. 结果重排序(Re-ranking)
简单的相似性搜索可能返回一些相关性不高的结果。可以通过以下方法改进:
- 多模态融合重排:结合文本相似度和图像相似度得分进行加权平均。
- 跨编码器(Cross-encoders):对于检索到的Top-K结果,使用一个专门的跨模态模型(如CLIP的完整模型或更复杂的匹配模型)来计算查询与每个检索结果之间的精确匹配分数,然后重新排序。
- LLM重排:让LLM阅读查询和检索到的文档/图像描述,然后根据语义相关性进行排序。
2. 多跳检索与推理
当前的流程是单次检索。更高级的智能体可以进行多跳(multi-hop)检索:
- 首次检索提供初步信息。
- LLM根据初步信息生成新的查询,或者从检索结果中提取关键词进行再次检索。
- 结合多轮检索结果进行更深入的推理。
3. 用户反馈与持续学习
系统可以通过收集用户对检索结果的反馈(例如,“这个图片相关吗?”)来持续改进。这些反馈可以用于:
- 调整嵌入模型的训练。
- 优化相似度度量的权重。
- 改进LLM的提示工程。
4. 性能与可扩展性
- 向量数据库选择:对于大规模应用,应考虑使用Pinecone、Weaviate、Milvus等分布式向量数据库。
- 嵌入模型优化:选择计算效率高且性能良好的嵌入模型,或对模型进行量化、蒸馏等优化。
- LangGraph的并发性:LangGraph本身支持异步操作,可以在节点中集成异步函数以提高吞吐量。
5. 伦理与偏见
- 训练数据偏见:CLIP等预训练模型在大量数据上训练,可能继承数据中的社会偏见。这可能导致检索结果带有刻板印象或不公平。在使用时需要警惕并采取缓解措施。
- 隐私:存储和处理用户数据(尤其是图像)时,必须严格遵守隐私法规。
五、迈向多模态智能的必由之路
跨模态嵌入对齐是构建真正智能、能够理解和推理世界的AI系统的关键一步。它使得机器能够以更接近人类的方式感知和处理信息,打破了模态间的壁垒。LangGraph为这种复杂的、多步骤的跨模态交互提供了一个优雅的编排框架,使得我们可以将强大的深度学习模型和灵活的智能体逻辑结合起来。随着多模态AI技术的不断发展,我们可以预见,未来的智能体将能够更自然地理解我们的世界,实现更丰富、更智能的交互。