什么是 ‘Cross-modal Embedding Alignment’:在 LangGraph 中如何实现文本记忆与图像记忆的联合检索?

跨模态嵌入对齐:在LangGraph中实现文本与图像记忆的联合检索

在人工智能领域,我们正在从单一模态的理解走向多模态的融合。传统上,我们处理文本时使用文本模型,处理图像时使用图像模型,它们各自在自己的领域内表现出色。然而,人类的认知并非如此割裂,我们通过语言描述图像,通过图像理解语言,这是一种天然的跨模态交互。

“跨模态嵌入对齐”(Cross-modal Embedding Alignment)正是为了弥合这种模态间的鸿沟而生。它的核心思想是将来自不同模态(如文本、图像、音频、视频等)的数据映射到一个共同的、低维的向量空间中。在这个共享的潜在空间里,语义上相似的文本和图像(或其它模态数据)其对应的向量表示会彼此靠近,而语义上不相关的向量则会相互远离。这种对齐使得我们能够用一种模态的查询(例如一段文本描述)去检索另一种模态的数据(例如相关的图像),反之亦然,甚至能够实现模态间的联合检索和推理。

在复杂的AI系统中,特别是那些需要模拟人类认知和记忆的智能体(Agents)中,联合检索能力至关重要。一个智能体需要能够根据用户的文本描述,回忆起相关的文本知识点,同时也能联想到相关的视觉记忆。LangGraph,作为LangChain的强大扩展,提供了一种灵活且强大的框架来构建有状态、多步、具有循环和条件逻辑的智能体工作流。将跨模态嵌入对齐技术融入LangGraph,可以为智能体赋予更丰富、更接近人类的记忆和感知能力。

一、跨模态嵌入对齐的核心概念与技术基石

要理解和实现跨模态嵌入对齐,我们需要从以下几个基础概念入手:

1. 嵌入(Embeddings)

嵌入是机器学习中的一个基本概念,指的是将高维、离散的数据(如单词、句子、图像、整个文档)转换为低维、连续的向量表示。这些向量捕捉了原始数据的语义或结构信息,使得机器能够进行数值计算和比较。

  • 文本嵌入(Text Embeddings)
    • 原理:通过深度学习模型(如Transformer架构的BERT、RoBERTa、GPT系列、Sentence-BERT等)处理文本序列,将其转换为固定长度的浮点数向量。这些向量的距离(例如余弦相似度)可以反映文本之间的语义相似性。
    • 常用模型
      • Sentence-BERT (SBERT):针对句子级别相似性任务进行优化,生成语义丰富的句子嵌入。
      • OpenAI Embeddings (如text-embedding-ada-002):通过大规模预训练,提供高质量的文本嵌入服务,广泛应用于各种下游任务。
      • Cohere Embeddings:与OpenAI类似,提供商用的文本嵌入API。
  • 图像嵌入(Image Embeddings)
    • 原理:通过卷积神经网络(CNN,如ResNet、VGG)或视觉Transformer(ViT)等深度学习模型处理图像像素数据,将其转换为固定长度的浮点数向量。这些向量的距离反映了图像在内容或风格上的相似性。
    • 常用模型
      • ResNet, VGG等CNNs:传统上用于图像分类和特征提取。
      • Vision Transformers (ViT):将Transformer架构引入图像处理,在许多视觉任务上取得SOTA。
  • 跨模态嵌入(Cross-modal Embeddings)
    • 关键:这不是简单地分别生成文本和图像嵌入,而是通过特定的训练策略,使得这两种模态的嵌入在同一个向量空间中具有可比性。
    • 代表模型CLIP (Contrastive Language–Image Pre-training) 是目前最成功的跨模态嵌入模型之一。它通过在大规模的“图像-文本对”数据集上进行对比学习(Contrastive Learning)训练,使得给定一个图像,其嵌入向量与描述该图像的文本嵌入向量距离更近,而与不相关文本的嵌入向量距离更远。反之亦然。
2. 对齐策略与训练目标

实现跨模态嵌入对齐的核心在于训练过程中的目标函数设计。

  • 对比学习(Contrastive Learning)
    • 原理:给定一个锚点(anchor)样本,将其与一个“正样本”(positive sample,语义相关)拉近,同时与多个“负样本”(negative samples,语义不相关)推远。
    • CLIP的实现:CLIP模型训练时,输入是N个图像-文本对。它会计算N个图像嵌入和N个文本嵌入。然后,它会尝试最大化每个图像与其对应的文本之间的余弦相似度,同时最小化该图像与其他N-1个不相关文本之间的余弦相似度。同样,对于每个文本,它会最大化与对应图像的相似度,最小化与不相关图像的相似度。这种对称的对比损失(通常是InfoNCE损失的变体)强制模型学习到一个共享的、语义对齐的潜在空间。
  • 其他对齐方法
    • Triplet Loss:需要三元组 (anchor, positive, negative),目标是使anchor与positive的距离小于anchor与negative的距离,并留有一定间隔。
    • Joint Embedding/Canonical Correlation Analysis (CCA):尝试找到线性变换,最大化两组多变量数据(如文本特征和图像特征)之间的相关性。在深度学习时代,通常通过非线性映射实现。
3. 相似性度量(Similarity Metrics)

在嵌入空间中,我们通过计算向量之间的距离或相似度来判断其语义关联性。

  • 余弦相似度(Cosine Similarity)
    • 原理:衡量两个向量方向的相似性,范围在-1到1之间。值越接近1表示方向越一致,语义越相似。
    • 优势:对向量的L2范数(长度)不敏感,更关注方向。在许多嵌入模型中,向量通常被归一化到单位长度,此时余弦相似度与点积是等价的。
  • 欧几里得距离(Euclidean Distance)
    • 原理:衡量两个向量在欧几里得空间中的直线距离。距离越小表示越相似。
    • 劣势:对向量的绝对大小敏感。
  • 点积(Dot Product)
    • 原理:两个向量对应元素乘积之和。当向量是单位向量时,点积与余弦相似度等价。在某些检索系统中,点积被用作相似度度量。
4. 向量数据库(Vector Databases)

当我们需要存储和检索大量的嵌入向量时,传统的数据库无法高效地处理相似性搜索。向量数据库应运而生,它们专门设计用于存储高维向量,并提供高效的近似最近邻(Approximate Nearest Neighbor, ANN)搜索算法。

  • 作用
    • 高效检索:通过ANN算法(如Faiss、HNSW、IVF等),即使在数百万甚至数十亿向量中也能快速找到与查询向量最相似的Top-K向量。
    • 扩展性:支持大规模数据集和高并发查询。
    • 元数据管理:通常允许与向量一起存储额外的元数据(如原始文本、图像路径、标签等),以便检索后能获取完整信息。
  • 常用产品
    • ChromaDB:轻量级,易于本地部署和使用,适合小型项目和原型开发。
    • Faiss (Facebook AI Similarity Search):一个高性能的C++库,提供了多种ANN算法,Python接口可用。常作为本地向量索引。
    • Pinecone, Weaviate, Milvus, Qdrant:云原生或分布式向量数据库,提供更强的扩展性、高可用性和丰富的功能,适合生产环境。

二、LangGraph在联合检索中的作用

LangGraph 是 LangChain 的一个强大扩展,专注于构建健壮、有状态、多步的智能体工作流。它通过有向无环图(DAG)或带有循环的图来定义智能体的行为,使得复杂逻辑、条件分支和工具使用变得易于管理。

1. 为什么选择LangGraph进行跨模态联合检索?
  • 流程编排能力:联合检索不仅仅是简单的查询。它可能涉及:
    • 识别查询类型(是文本还是图像?)。
    • 根据类型选择不同的嵌入模型。
    • 并行查询多个模态的向量存储。
    • 合并和重排序检索结果。
    • 使用LLM对结果进行摘要或推理。
      LangGraph能够清晰地定义这些步骤为节点,并用边连接它们,实现复杂的流程控制。
  • 状态管理:在多步检索过程中,我们需要在不同节点之间传递信息,例如原始查询、生成的嵌入、检索到的文本片段、图像路径等。LangGraph的图状态(Graph State)机制可以优雅地管理这些信息,确保上下文的连贯性。
  • 条件逻辑与动态性:用户查询可以是文本,也可以是图像。LangGraph的条件边(Conditional Edges)允许我们根据查询类型动态地路由到不同的处理分支,例如“如果查询是文本,则调用文本嵌入器;如果是图像,则调用图像嵌入器”。
  • 模块化与可扩展性:每个检索或处理步骤都可以封装成一个独立的节点。这意味着我们可以轻松地替换不同的嵌入模型、向量数据库或LLM,而无需修改整个工作流。
  • 智能体决策: LangGraph构建的智能体可以根据检索到的跨模态信息做出更明智的决策,例如,如果检索到的图像与文本内容矛盾,智能体可以发起澄清性提问。
2. LangGraph工作流的典型结构

一个LangGraph工作流通常包含以下要素:

  • Graph State (图状态):一个字典或对象,定义了在整个工作流中需要共享和传递的信息。
  • Nodes (节点):图中的基本处理单元,每个节点执行一个特定的任务(如嵌入、检索、LLM调用等)。节点接收当前图状态作为输入,并返回更新后的状态。
  • Edges (边):连接节点,定义了信息的流动路径。
  • Conditional Edges (条件边):根据节点返回的特定值,动态选择下一个要执行的节点。
  • Entry/Exit Points (入口/出口):定义工作流的开始和结束。

三、在LangGraph中实现文本与图像记忆的联合检索

现在,我们将深入探讨如何在LangGraph中具体实现跨模态嵌入对齐,以进行文本与图像记忆的联合检索。我们将使用CLIP模型进行跨模态嵌入,ChromaDB作为向量存储,并构建一个LangGraph工作流。

1. 环境准备与依赖安装

首先,确保您的Python环境已安装所需库:

pip install torch transformers pillow langchain_community langchain langgraph chromadb sentence_transformers
  • torch: PyTorch深度学习框架,CLIP模型需要。
  • transformers: Hugging Face Transformers库,用于加载CLIP模型。
  • pillow: Python图像处理库,用于加载和处理图像。
  • langchain_community, langchain: LangChain核心库。
  • langgraph: LangGraph框架。
  • chromadb: 轻量级向量数据库。
  • sentence_transformers: 用于生成独立的文本嵌入(如果需要与CLIP的文本部分区分开,或者使用更专用的文本嵌入模型)。但在此场景下,CLIP的文本编码器本身就能提供对齐的文本嵌入。
2. 核心组件:嵌入模型与向量存储

我们将定义用于生成嵌入的类和函数,并初始化两个ChromaDB集合,一个用于存储文本记忆,一个用于存储图像记忆。

import os
import io
import base64
from typing import List, Dict, Union, Any, Literal

import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_community.embeddings import HuggingFaceEmbeddings # For text if not using CLIP text encoder directly for text-only queries

# --- 1. 定义嵌入模型 ---
class CrossModalEmbedder:
    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        """
        初始化CLIP模型和处理器。
        CLIP模型用于生成文本和图像的跨模态对齐嵌入。
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        print(f"CLIP model '{model_name}' loaded on {self.device}.")

    def embed_text(self, texts: Union[str, List[str]]) -> List[List[float]]:
        """
        将文本转换为嵌入向量。
        """
        if isinstance(texts, str):
            texts = [texts]

        inputs = self.processor(text=texts, return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            text_features = self.model.get_text_features(**inputs)
        return text_features.cpu().numpy().tolist()

    def embed_image(self, images: Union[Image.Image, List[Image.Image]]) -> List[List[float]]:
        """
        将PIL Image对象转换为嵌入向量。
        """
        if isinstance(images, Image.Image):
            images = [images]

        inputs = self.processor(images=images, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            image_features = self.model.get_image_features(**inputs)
        return image_features.cpu().numpy().tolist()

    def get_embedding_dimension(self) -> int:
        """获取嵌入向量的维度"""
        return self.model.config.projection_dim

# 初始化跨模态嵌入器
cross_modal_embedder = CrossModalEmbedder()

# --- 2. 准备数据并创建向量存储 ---

# 模拟一些数据
text_memories_data = [
    ("巴黎铁塔是法国的标志性建筑,通常在夜晚点亮。", {"source": "wikipedia", "category": "landmark"}),
    ("一只可爱的猫咪在阳光下打盹,毛发蓬松。", {"source": "photo_album", "category": "animal"}),
    ("编程是一门艺术,需要逻辑思维和创造力。", {"source": "blog", "category": "technology"}),
    ("夏日海滩风景优美,适合度假和游泳。", {"source": "travel_guide", "category": "nature"}),
    ("古老的图书馆藏书丰富,充满了知识的气息。", {"source": "news_article", "category": "education"}),
]

# 模拟一些图像数据(实际应为图像文件路径,此处为占位符)
# 为了演示,我们假设这些路径指向真实图片,并且我们有对应的PIL Image对象
# 实际应用中,你需要加载真实的图片文件
image_paths = [
    "images/eiffel_tower.jpg",
    "images/cat_sleeping.jpg",
    "images/coding_setup.jpg",
    "images/beach_sunset.jpg",
    "images/library_interior.jpg",
]

# 为了简化演示,我们使用一个简单的函数来模拟图像加载
# 在实际应用中,你需要从文件系统加载图像
def load_dummy_image(path: str) -> Image.Image:
    # 这是一个占位符,实际应加载真实图片
    # 例如:Image.open(path)
    # 这里我们创建一个纯色图片作为演示
    print(f"Loading dummy image for path: {path}")
    if "eiffel_tower" in path:
        return Image.new('RGB', (224, 224), color = 'red') # 模拟巴黎铁塔
    elif "cat_sleeping" in path:
        return Image.new('RGB', (224, 224), color = 'blue') # 模拟猫咪
    elif "coding_setup" in path:
        return Image.new('RGB', (224, 224), color = 'green') # 模拟编程
    elif "beach_sunset" in path:
        return Image.new('RGB', (224, 224), color = 'orange') # 模拟海滩
    elif "library_interior" in path:
        return Image.new('RGB', (224, 224), color = 'purple') # 模拟图书馆
    return Image.new('RGB', (224, 224), color = 'white') # 默认

# 为图像数据创建文档结构,并生成嵌入
image_docs = []
for i, path in enumerate(image_paths):
    # 假设我们有图像的简短描述作为元数据
    description = text_memories_data[i][0] # 使用对应的文本描述作为图像的元数据
    image_content = load_dummy_image(path) # 模拟加载图像
    image_embedding = cross_modal_embedder.embed_image(image_content)[0]

    # 将图像路径和描述存储为元数据
    # 注意:ChromaDB不支持直接存储PIL Image对象。
    # 我们存储嵌入和元数据,元数据中包含图像路径,检索时再根据路径加载图像。
    image_docs.append(Document(
        page_content=description, # 存储描述作为主要内容
        metadata={
            "image_path": path, 
            "original_description": description, # 原始描述
            "category": text_memories_data[i][1]["category"] # 其他元数据
        }
    ))

# 为文本数据创建文档结构
text_docs = [Document(page_content=text, metadata=meta) for text, meta in text_memories_data]

# ChromaDB的嵌入函数需要一个可调用的对象,它接受文本列表并返回嵌入列表
# 为此,我们封装cross_modal_embedder的embed_text方法
class ChromaEmbeddingsAdapter:
    def __init__(self, embedder: CrossModalEmbedder):
        self.embedder = embedder
        self.dimension = embedder.get_embedding_dimension()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return self.embedder.embed_text(texts)

    def embed_query(self, text: str) -> List[float]:
        return self.embedder.embed_text([text])[0]

    def __call__(self, texts: List[str]) -> List[List[float]]:
        return self.embed_documents(texts)

chroma_embedder_adapter = ChromaEmbeddingsAdapter(cross_modal_embedder)

# 初始化ChromaDB向量存储
# Text Memory Store
text_vector_store = Chroma.from_documents(
    documents=text_docs,
    embedding=chroma_embedder_adapter,
    collection_name="text_memories",
    persist_directory="./chroma_db"
)
print(f"Text memories stored. Count: {text_vector_store._collection.count()}")

# Image Memory Store (存储图像嵌入,元数据包含图像路径和描述)
# 这里我们手动添加嵌入,因为from_documents默认只对page_content进行嵌入
image_vector_store = Chroma(
    client_settings=text_vector_store._client_settings, # 使用相同的客户端设置
    collection_name="image_memories",
    embedding_function=chroma_embedder_adapter, # 必须提供embedding_function
    persist_directory="./chroma_db"
)

# 逐个添加图像文档和其对应的嵌入
# 注意:Chroma.add_documents 方法可以接收自定义嵌入。
# 但对于我们已经生成好嵌入的情况,直接使用add_embeddings更直接。
# 这里为了与from_documents的流程保持一致,我们手动构造ids, embeddings, metadatas, documents
image_ids = [f"image_doc_{i}" for i in range(len(image_docs))]
image_embeddings = [cross_modal_embedder.embed_image(load_dummy_image(doc.metadata["image_path"]))[0] for doc in image_docs]
image_metadatas = [doc.metadata for doc in image_docs]
image_contents = [doc.page_content for doc in image_docs] # page_content is the description

image_vector_store._collection.add(
    embeddings=image_embeddings,
    metadatas=image_metadatas,
    documents=image_contents, # documents here refers to the actual text content stored in Chroma
    ids=image_ids
)

print(f"Image memories stored. Count: {image_vector_store._collection.count()}")

# 清理持久化目录
# import shutil
# if os.path.exists("./chroma_db"):
#     shutil.rmtree("./chroma_db")

代码说明:

  1. CrossModalEmbedder:封装了CLIP模型。embed_textembed_image 方法分别用于生成文本和图像的嵌入。这两个方法生成的嵌入位于同一个语义空间中。
  2. 数据准备:我们创建了模拟的文本描述和图像路径。关键点:对于图像记忆,我们存储的是图像的描述page_content)和图像路径作为元数据。当检索到图像记忆时,我们实际上是检索到了其描述和路径,再根据路径加载图像。
  3. ChromaEmbeddingsAdapter:ChromaDB需要一个实现了embed_documentsembed_query方法的Embeddings类实例。我们创建了一个适配器,将CrossModalEmbedder的方法包装起来,使其符合ChromaDB的要求。
  4. 向量存储
    • text_vector_store:存储文本描述的嵌入。
    • image_vector_store:存储图像的嵌入。即使查询是文本,由于CLIP的对齐,我们也能用文本查询在image_vector_store中找到相关图像。
3. LangGraph工作流设计

接下来,我们将使用LangGraph构建一个智能体,该智能体能够处理文本或图像查询,并从文本和图像记忆中联合检索。

Graph State (图状态定义)

from typing import TypedDict, Annotated, List, Literal
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages

# 定义图的状态
class AgentState(TypedDict):
    query: Union[str, bytes] # 用户查询,可以是文本或图像的base64编码
    query_type: Literal["text", "image"] # 查询类型
    text_query_embedding: List[float] # 文本查询的嵌入
    image_query_embedding: List[float] # 图像查询的嵌入
    retrieved_text_docs: Annotated[List[Document], add_messages] # 检索到的文本文档
    retrieved_image_docs: Annotated[List[Document], add_messages] # 检索到的图像文档
    final_response: str # 最终响应
    error_message: str # 错误信息

节点(Nodes)定义

# 假设我们有一个LLM客户端
# from langchain_openai import ChatOpenAI
# llm = ChatOpenAI(model="gpt-4", temperature=0) # 实际应用中使用
# 简化LLM部分,直接返回拼接结果
class MockLLM:
    def invoke(self, prompt: str) -> str:
        return f"LLM processed: {prompt}"
mock_llm = MockLLM()

# --- LangGraph 节点定义 ---

def determine_query_type(state: AgentState) -> AgentState:
    """
    根据查询内容判断是文本查询还是图像查询。
    图像查询假定为base64编码的字符串或bytes。
    """
    query = state["query"]
    if isinstance(query, bytes):
        # 假设bytes是图像的base64解码后的原始字节数据
        # 或直接判断如果 query 是一个文件路径字符串,我们也可以认为是图像查询
        print("Determined query type: image (from bytes).")
        return {"query_type": "image"}
    elif isinstance(query, str) and (query.startswith("data:image/") or query.endswith((".jpg", ".png", ".jpeg", ".gif"))):
         # 如果是data URI 或者文件路径
        print(f"Determined query type: image (from string path/uri). Query: {query[:50]}...")
        return {"query_type": "image"}
    else:
        print(f"Determined query type: text. Query: {query[:50]}...")
        return {"query_type": "text"}

def embed_text_query(state: AgentState) -> AgentState:
    """
    对文本查询进行嵌入。
    """
    query = state["query"]
    print(f"Embedding text query: {query[:50]}...")
    embedding = cross_modal_embedder.embed_text(query)[0]
    return {"text_query_embedding": embedding}

def embed_image_query(state: AgentState) -> AgentState:
    """
    对图像查询进行嵌入。
    """
    query = state["query"]
    try:
        if isinstance(query, bytes):
            image = Image.open(io.BytesIO(query))
        elif isinstance(query, str):
            if query.startswith("data:image/"):
                # 假设是data URI,例如 "data:image/jpeg;base64,..."
                header, base64_str = query.split(",", 1)
                image_bytes = base64.b64decode(base64_str)
                image = Image.open(io.BytesIO(image_bytes))
            else:
                # 假设是文件路径,加载本地图像
                image = Image.open(query)
        else:
            raise ValueError("Unsupported image query format.")

        print(f"Embedding image query (size: {image.size})...")
        embedding = cross_modal_embedder.embed_image(image)[0]
        return {"image_query_embedding": embedding}
    except Exception as e:
        print(f"Error embedding image: {e}")
        return {"error_message": f"Failed to embed image: {e}"}

def retrieve_memories(state: AgentState) -> AgentState:
    """
    根据可用的嵌入(文本或图像)从两个向量存储中检索记忆。
    由于CLIP的跨模态对齐,无论是文本查询嵌入还是图像查询嵌入,
    都可以在两个模态的向量存储中进行检索。
    """
    text_query_embedding = state.get("text_query_embedding")
    image_query_embedding = state.get("image_query_embedding")

    # 优先使用图像查询嵌入,如果没有则使用文本查询嵌入
    query_embedding = image_query_embedding if image_query_embedding else text_query_embedding

    if not query_embedding:
        print("No embedding available for retrieval.")
        return {"error_message": "No query embedding generated."}

    print("Retrieving text memories...")
    # 从文本向量存储中检索
    retrieved_texts = text_vector_store.similarity_search_by_vector(query_embedding, k=3)

    print("Retrieving image memories...")
    # 从图像向量存储中检索
    retrieved_images = image_vector_store.similarity_search_by_vector(query_embedding, k=3)

    print(f"Retrieved {len(retrieved_texts)} text docs and {len(retrieved_images)} image docs.")
    return {
        "retrieved_text_docs": retrieved_texts,
        "retrieved_image_docs": retrieved_images
    }

def synthesize_response(state: AgentState) -> AgentState:
    """
    使用LLM综合检索到的信息,生成最终响应。
    """
    query = state["query"]
    retrieved_texts = state.get("retrieved_text_docs", [])
    retrieved_images = state.get("retrieved_image_docs", [])

    prompt_parts = [f"用户查询: {query}"]

    if retrieved_texts:
        prompt_parts.append("nn相关文本记忆:")
        for i, doc in enumerate(retrieved_texts):
            prompt_parts.append(f"  {i+1}. 内容: {doc.page_content} (来源: {doc.metadata.get('source', '未知')})")

    if retrieved_images:
        prompt_parts.append("nn相关图像记忆:")
        for i, doc in enumerate(retrieved_images):
            # 这里的doc.page_content是图像的描述
            # doc.metadata['image_path']是图像路径,可以用来加载实际图像
            prompt_parts.append(f"  {i+1}. 描述: {doc.page_content} (路径: {doc.metadata.get('image_path', '未知')})")
            # 实际应用中,这里可以加载图像并显示,或进行进一步分析
            # Image.open(doc.metadata['image_path']).show()

    prompt_parts.append("nn请根据以上信息,为用户生成一个全面的回答。")

    full_prompt = "n".join(prompt_parts)
    print("Synthesizing response with LLM...")
    final_response = mock_llm.invoke(full_prompt) # 替换为真实的LLM调用

    return {"final_response": final_response}

def handle_error(state: AgentState) -> AgentState:
    """
    处理工作流中的错误。
    """
    error = state.get("error_message", "未知错误")
    print(f"Error occurred: {error}")
    return {"final_response": f"抱歉,处理您的请求时发生错误:{error}"}

构建LangGraph

# 构建图
workflow = StateGraph(AgentState)

# 添加节点
workflow.add_node("determine_query_type", determine_query_type)
workflow.add_node("embed_text_query", embed_text_query)
workflow.add_node("embed_image_query", embed_image_query)
workflow.add_node("retrieve_memories", retrieve_memories)
workflow.add_node("synthesize_response", synthesize_response)
workflow.add_node("handle_error", handle_error) # 错误处理节点

# 设置入口点
workflow.set_entry_point("determine_query_type")

# 添加条件边
workflow.add_conditional_edges(
    "determine_query_type",
    lambda state: state["query_type"], # 根据query_type进行路由
    {
        "text": "embed_text_query",
        "image": "embed_image_query",
    }
)

# 嵌入节点指向检索节点
workflow.add_edge("embed_text_query", "retrieve_memories")

# 图像嵌入节点需要检查是否有错误
workflow.add_conditional_edges(
    "embed_image_query",
    lambda state: "handle_error" if state.get("error_message") else "retrieve_memories",
    {
        "retrieve_memories": "retrieve_memories",
        "handle_error": "handle_error",
    }
)

# 检索节点指向综合响应节点
workflow.add_edge("retrieve_memories", "synthesize_response")

# 综合响应节点指向结束
workflow.add_edge("synthesize_response", END)

# 错误处理节点指向结束
workflow.add_edge("handle_error", END)

# 编译图
app = workflow.compile()

print("LangGraph workflow compiled successfully.")

# --- 运行LangGraph ---

# 1. 文本查询示例
print("n--- 文本查询示例 ---")
text_query = "关于建筑和风景的记忆"
result_text_query = app.invoke({"query": text_query})
print("n最终响应:")
print(result_text_query["final_response"])

# 2. 图像查询示例 (模拟图像文件路径)
print("n--- 图像查询示例 (通过路径) ---")
image_query_path = "images/eiffel_tower.jpg" # 假定存在这个图片文件
result_image_query_path = app.invoke({"query": image_query_path})
print("n最终响应:")
print(result_image_query_path["final_response"])

# 3. 图像查询示例 (模拟base64编码的图像数据)
print("n--- 图像查询示例 (通过base64编码) ---")
# 实际中你需要将图片转换为base64编码的字符串
# 这里我们创建一个简单的占位符图片并编码
dummy_img = Image.new('RGB', (100, 100), color = 'yellow')
buffered = io.BytesIO()
dummy_img.save(buffered, format="PNG")
base64_img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
base64_image_query = f"data:image/png;base64,{base64_img_str}" # 完整的data URI

result_image_query_base64 = app.invoke({"query": base64_image_query})
print("n最终响应:")
print(result_image_query_base64["final_response"])

# 4. 包含不存在的图像路径的查询(模拟错误)
print("n--- 错误处理示例 (图像路径不存在) ---")
non_existent_image_path = "images/non_existent.jpg"
result_error_query = app.invoke({"query": non_existent_image_path})
print("n最终响应:")
print(result_error_query["final_response"])

工作流解释:

  1. AgentState: 定义了工作流中传递的所有状态变量,包括查询、嵌入、检索结果和最终响应。Annotated[List[Document], add_messages] 是LangGraph的语法,表示这是一个列表,并且每次更新时会添加到现有列表而不是完全覆盖(对于检索结果通常是合适的)。
  2. determine_query_type 节点: 这是入口节点。它检查 query 的类型。如果 querybytes类型(通常是图像文件的原始字节数据)或者以data:image/开头或以常见图片扩展名结尾的字符串,则认为是图像查询;否则认为是文本查询。
  3. embed_text_query / embed_image_query 节点: 根据 query_type 的结果,流程会被路由到这两个节点之一。它们分别调用 cross_modal_embedder 的相应方法来生成查询嵌入。图像嵌入节点还包含了基本的错误处理,例如当图像无法加载时。
  4. retrieve_memories 节点: 这是跨模态联合检索的核心。 无论查询是文本还是图像,我们都使用 CLIP 生成的嵌入(text_query_embeddingimage_query_embedding)去同时查询 text_vector_storeimage_vector_store。由于 CLIP 训练时就对齐了两种模态的嵌入,所以一个文本查询的嵌入可以找到相关的图像,一个图像查询的嵌入也可以找到相关的文本。
  5. synthesize_response 节点: 这个节点负责收集所有检索到的文本文档和图像文档的元数据(主要是描述和路径),然后构建一个详细的提示,将其发送给一个(模拟的)LLM。LLM将综合这些信息,生成最终的用户响应。在实际应用中,这里会集成如OpenAI GPT-4等真实的LLM。
  6. handle_error 节点: 当 embed_image_query 节点发生错误时,流程会转向此节点,提供一个友好的错误信息。
  7. 图的构建: 通过 StateGraph 定义节点、设置入口、添加边和条件边,最终编译成可执行的 app。条件边使得工作流具有动态决策能力。

四、高级考量与未来展望

1. 结果重排序(Re-ranking)

简单的相似性搜索可能返回一些相关性不高的结果。可以通过以下方法改进:

  • 多模态融合重排:结合文本相似度和图像相似度得分进行加权平均。
  • 跨编码器(Cross-encoders):对于检索到的Top-K结果,使用一个专门的跨模态模型(如CLIP的完整模型或更复杂的匹配模型)来计算查询与每个检索结果之间的精确匹配分数,然后重新排序。
  • LLM重排:让LLM阅读查询和检索到的文档/图像描述,然后根据语义相关性进行排序。
2. 多跳检索与推理

当前的流程是单次检索。更高级的智能体可以进行多跳(multi-hop)检索:

  • 首次检索提供初步信息。
  • LLM根据初步信息生成新的查询,或者从检索结果中提取关键词进行再次检索。
  • 结合多轮检索结果进行更深入的推理。
3. 用户反馈与持续学习

系统可以通过收集用户对检索结果的反馈(例如,“这个图片相关吗?”)来持续改进。这些反馈可以用于:

  • 调整嵌入模型的训练。
  • 优化相似度度量的权重。
  • 改进LLM的提示工程。
4. 性能与可扩展性
  • 向量数据库选择:对于大规模应用,应考虑使用Pinecone、Weaviate、Milvus等分布式向量数据库。
  • 嵌入模型优化:选择计算效率高且性能良好的嵌入模型,或对模型进行量化、蒸馏等优化。
  • LangGraph的并发性:LangGraph本身支持异步操作,可以在节点中集成异步函数以提高吞吐量。
5. 伦理与偏见
  • 训练数据偏见:CLIP等预训练模型在大量数据上训练,可能继承数据中的社会偏见。这可能导致检索结果带有刻板印象或不公平。在使用时需要警惕并采取缓解措施。
  • 隐私:存储和处理用户数据(尤其是图像)时,必须严格遵守隐私法规。

五、迈向多模态智能的必由之路

跨模态嵌入对齐是构建真正智能、能够理解和推理世界的AI系统的关键一步。它使得机器能够以更接近人类的方式感知和处理信息,打破了模态间的壁垒。LangGraph为这种复杂的、多步骤的跨模态交互提供了一个优雅的编排框架,使得我们可以将强大的深度学习模型和灵活的智能体逻辑结合起来。随着多模态AI技术的不断发展,我们可以预见,未来的智能体将能够更自然地理解我们的世界,实现更丰富、更智能的交互。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注