解析 ‘CRAG (Corrective RAG)’:如何在检索结果质量不佳时,自动切换到网络搜索或知识图谱补救?

各位技术同仁,下午好!

今天,我们将深入探讨一个在大型语言模型(LLM)应用领域日益凸显的关键议题:如何构建一个更健壮、更智能的检索增强生成(RAG)系统。具体来说,我们将聚焦于一个创新概念——CRAG(Corrective RAG),即纠错型RAG

RAG的出现无疑是LLM应用领域的一大突破,它通过将外部知识库与LLM相结合,有效缓解了LLM的“幻觉”问题,并使其能够访问实时或领域特定的信息。然而,RAG并非万能药。当检索到的信息质量不佳时,RAG系统依然可能给出不准确、不完整乃至误导性的答案。这正是CRAG诞生的初衷:如何在检索结果质量不佳时,系统能够自动感知并采取补救措施,例如切换到网络搜索或知识图谱查询?

作为一名编程专家,我的目标是为大家剖析CRAG的核心机制、技术挑战以及具体的实现策略,并辅以大量的代码示例,帮助大家将这些理论转化为实际可操作的系统。

1. RAG的困境:当“检索”不再可靠

首先,让我们快速回顾一下RAG的基本工作流程:

  1. 用户提交查询(Query)。
  2. 系统在预设的知识库(通常是向量数据库)中检索最相关的文档片段(Documents)。
  3. 将这些文档片段与用户查询一起作为上下文,输入给LLM。
  4. LLM基于上下文生成答案。

这个流程简单高效,但它的性能高度依赖于第二步——检索结果的质量。当检索结果出现以下情况时,RAG系统就会陷入困境:

  • 知识库覆盖不足(Out-of-Domain):用户查询的内容超出了当前知识库的范围。例如,一个关于公司内部政策的RAG系统被问及最新的全球新闻。
  • 信息过时(Stale Information):知识库中的信息未能及时更新,导致检索到的内容与最新事实不符。
  • 语义不匹配(Semantic Mismatch):用户的查询与知识库中的文档虽然语义相关,但由于嵌入模型或查询本身表达的问题,导致检索到的文档并不是最精准的。
  • 知识稀疏(Sparse Knowledge):知识库中关于某个主题的信息非常少,不足以支撑一个完整的答案。
  • 歧义查询(Ambiguous Query):用户查询本身存在歧义,导致系统检索到多个看似相关但实际上不准确的文档。

在这些情况下,即使是强大的LLM,也只能“巧妇难为无米之炊”,基于错误的上下文生成出错误的答案,甚至可能放大错误信息,导致更严重的“幻觉”。CRAG正是为了解决这些问题而生。

2. CRAG的核心思想:检测、纠正与集成

CRAG的核心可以概括为三个阶段:检测(Detection)检索质量、纠正(Correction)低质量检索结果,以及将纠正后的信息集成(Integration)到答案生成过程中。

2.1. 检测机制:如何判断检索结果不佳?

这是CRAG最关键的一步。我们需要一套机制来评估从向量数据库中检索到的文档是否能够有效回答用户的问题。以下是几种常用的检测策略:

2.1.1. 基于相似度分数与阈值的检测

最直接的方法是利用向量检索本身提供的相似度分数(如余弦相似度、点积等)。如果所有检索到的文档的相似度分数都低于某个预设阈值,则可以认为检索质量不佳。

优点:实现简单,计算成本低。
缺点:单一的相似度分数不总是能准确反映语义相关性,特别是对于复杂查询。阈值设定困难,过高可能导致误判,过低可能漏判。

from typing import List, Dict

def check_similarity_threshold(retrieved_docs: List[Dict], threshold: float = 0.7) -> bool:
    """
    检查检索到的文档中是否存在相似度分数高于阈值的文档。
    如果所有文档分数都低于阈值,则认为检索不佳。

    Args:
        retrieved_docs: 包含 'score' 键的文档列表。
        threshold: 相似度分数阈值。

    Returns:
        True 如果检索质量可能不佳 (所有分数低于阈值),False 否则。
    """
    if not retrieved_docs:
        return True # 没有检索到任何文档,肯定不佳

    for doc in retrieved_docs:
        if doc.get('score', 0.0) >= threshold:
            return False # 找到一个足够相关的文档
    return True # 所有文档分数都低于阈值

# 示例
mock_retrieved_docs_good = [
    {"text": "Python是一种高级编程语言。", "score": 0.85},
    {"text": "Python广泛应用于Web开发和数据科学。", "score": 0.78}
]

mock_retrieved_docs_bad = [
    {"text": "C++是一种编译型语言。", "score": 0.62},
    {"text": "Java是面向对象的。", "score": 0.55}
]

print(f"Good retrieval check: {check_similarity_threshold(mock_retrieved_docs_good)}")
print(f"Bad retrieval check: {check_similarity_threshold(mock_retrieved_docs_bad)}")
2.1.2. 基于交叉编码器(Cross-Encoder)的重排与置信度评估

向量数据库通常使用双编码器(Bi-Encoder)模型进行检索,即查询和文档分别编码,然后计算它们的相似度。这种方式计算效率高,但牺牲了一部分精度。交叉编码器(Cross-Encoder)模型则将查询和文档拼接后一起输入模型,让模型判断它们的关联性,这种方式精度更高,但计算成本也更高。

我们可以先用双编码器快速检索出Top-K文档,然后使用交叉编码器对这些文档进行重排,并获取它们与查询的更准确的关联性分数。如果重排后的Top-N文档分数依然很低,则认为检索不佳。

优点:比单一相似度分数更准确地评估相关性。
缺点:引入额外的计算成本和延迟。

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

class CrossEncoderReranker:
    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model.eval() # 设置为评估模式

    def rerank(self, query: str, documents: List[Dict]) -> List[Dict]:
        """
        使用交叉编码器对文档进行重排,并返回新的分数。

        Args:
            query: 用户查询。
            documents: 包含 'text' 键的文档列表。

        Returns:
            重排并更新分数后的文档列表。
        """
        if not documents:
            return []

        # 构建输入对:(query, document_text)
        pairs = [[query, doc['text']] for doc in documents]

        # 编码输入
        inputs = self.tokenizer(pairs, padding=True, truncation=True, return_tensors='pt')

        # 进行预测
        with torch.no_grad():
            scores = self.model(**inputs).logits.squeeze() # 模型的输出通常是 logits

        # 如果只有一个文档,logits是标量,需要转换为列表
        if scores.dim() == 0:
            scores = scores.unsqueeze(0)

        # 将 logits 转换为概率(可选,但通常用于解释)
        # 这里我们直接使用 logits 作为分数,因为它们反映了相关性强度
        # scores = torch.sigmoid(scores) # 如果模型输出是二分类的logits

        # 更新文档分数并排序
        reranked_docs = []
        for i, doc in enumerate(documents):
            new_doc = doc.copy()
            new_doc['rerank_score'] = scores[i].item() # 将 Tensor 转换为 Python 浮点数
            reranked_docs.append(new_doc)

        # 按照 rerank_score 降序排序
        reranked_docs.sort(key=lambda x: x['rerank_score'], reverse=True)

        return reranked_docs

    def check_rerank_threshold(self, reranked_docs: List[Dict], threshold: float = 0.0) -> bool:
        """
        检查重排后的文档中是否存在 rerank_score 高于阈值的文档。
        交叉编码器的分数通常是负数(对于相似性),因此阈值可能需要根据模型特性调整。
        对于 cross-encoder/ms-marco-MiniLM-L-6-v2,分数越高表示越相关,通常是正值或接近0。
        """
        if not reranked_docs:
            return True # 没有检索到任何文档,肯定不佳

        # 检查最高分是否达到阈值
        if reranked_docs[0].get('rerank_score', -float('inf')) >= threshold:
            return False
        return True

# 示例使用
reranker = CrossEncoderReranker()
query = "Python在哪些领域有广泛应用?"

docs_from_vector_db = [
    {"text": "Python是一种通用编程语言。", "score": 0.8},
    {"text": "Python在Web开发、数据科学、人工智能和自动化脚本方面有广泛应用。", "score": 0.75},
    {"text": "Java主要用于企业级应用。", "score": 0.6},
    {"text": "R语言在统计分析中很流行。", "score": 0.5}
]

reranked_docs = reranker.rerank(query, docs_from_vector_db)
print("n重排后的文档:")
for doc in reranked_docs:
    print(f"Text: {doc['text']}, Original Score: {doc['score']:.2f}, Rerank Score: {doc['rerank_score']:.2f}")

# 假设一个差的检索情况
query_bad_retrieval = "最新的宇宙大爆炸理论是什么?"
docs_bad_retrieval = [
    {"text": "Python的历史和发展。", "score": 0.7},
    {"text": "关于机器学习算法的介绍。", "score": 0.65}
]
reranked_docs_bad = reranker.rerank(query_bad_retrieval, docs_bad_retrieval)
print("n重排后的差文档:")
for doc in reranked_docs_bad:
    print(f"Text: {doc['text']}, Original Score: {doc['score']:.2f}, Rerank Score: {doc['rerank_score']:.2f}")

# 检查重排阈值 (需要根据模型实际输出调整阈值)
# 对于 'cross-encoder/ms-marco-MiniLM-L-6-v2' 模型,更高的分数表示更相关。
# 实际阈值可能需要通过实验确定,例如,-2.0 可能是一个合理的低相关性阈值。
print(f"nGood retrieval (rerank) check: {reranker.check_rerank_threshold(reranked_docs, threshold=-2.0)}")
print(f"Bad retrieval (rerank) check: {reranker.check_rerank_threshold(reranked_docs_bad, threshold=-2.0)}")
2.1.3. 基于LLM的文档相关性评估

我们可以利用另一个LLM(可以是较小的模型)来直接评估检索到的文档与用户查询的相关性。这通常通过设计一个合适的提示(Prompt)来实现。LLM可以判断文档是否直接回答了问题、是否包含关键信息、是否存在无关信息等。

优点:评估结果更接近人类判断,能够处理复杂的语义关系。
缺点:引入LLM调用,增加延迟和成本。

import os
from openai import OpenAI # 或者其他LLM提供商

# 假设您已经设置了OPENAI_API_KEY环境变量
# os.environ["OPENAI_API_KEY"] = "YOUR_API_KEY"
# client = OpenAI()

class LLMRelevanceChecker:
    def __init__(self, llm_client, model_name: str = "gpt-3.5-turbo"):
        self.llm_client = llm_client
        self.model_name = model_name

    def evaluate_relevance(self, query: str, document_text: str) -> bool:
        """
        使用LLM评估单个文档与查询的相关性。
        """
        prompt = f"""
        用户查询: "{query}"

        检索到的文档片段:
        ---
        {document_text}
        ---

        请判断上述文档片段是否与用户查询高度相关,并且能够直接或间接支持回答该查询。
        请只回答 '是' 或 '否'。
        """
        try:
            response = self.llm_client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": "你是一个评估文档相关性的助手。"},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.0, # 评估类任务通常使用低温度
                max_tokens=5
            )
            answer = response.choices[0].message.content.strip().lower()
            return answer == '是'
        except Exception as e:
            print(f"LLM调用失败: {e}")
            return False # 失败时默认不相关

    def check_llm_relevance(self, query: str, documents: List[Dict], min_relevant_docs: int = 1) -> bool:
        """
        检查检索到的文档中是否有足够数量的文档被LLM判断为相关。

        Args:
            query: 用户查询。
            documents: 包含 'text' 键的文档列表。
            min_relevant_docs: 至少需要多少个文档被判定为相关才算通过。

        Returns:
            True 如果相关文档数量不足,False 否则。
        """
        if not documents:
            return True # 没有文档,肯定不佳

        relevant_count = 0
        for doc in documents:
            if self.evaluate_relevance(query, doc['text']):
                relevant_count += 1

        return relevant_count < min_relevant_docs

# 示例 (需要实际的OpenAI API客户端)
# client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) # 确保设置了API Key
# llm_checker = LLMRelevanceChecker(llm_client=client)

# # 模拟LLM调用,实际运行时需要替换为真实的OpenAI客户端
class MockLLMClient:
    def chat(self):
        return self
    def completions(self):
        return self
    def create(self, model, messages, temperature, max_tokens):
        # 模拟根据查询和文档判断相关性
        query = messages[1]['content'].split('用户查询: "')[1].split('"nn检索到的文档片段:')[0]
        doc_text = messages[1]['content'].split('检索到的文档片段:n---n')[1].split('n---')[0]

        if "Python" in query and ("Python" in doc_text or "Web开发" in doc_text or "数据科学" in doc_text):
            response_content = "是"
        elif "宇宙大爆炸" in query and ("Python" in doc_text or "机器学习" in doc_text):
            response_content = "否"
        else:
            response_content = "否" # 默认不相关

        class MockChoice:
            def __init__(self, content):
                self.message = type('MockMessage', (object,), {'content': content})()

        class MockResponse:
            def __init__(self, content):
                self.choices = [MockChoice(content)]
        return MockResponse(response_content)

mock_llm_client = MockLLMClient()
llm_checker = LLMRelevanceChecker(llm_client=mock_llm_client)

query_good_retrieval = "Python在哪些领域有广泛应用?"
docs_good_retrieval = [
    {"text": "Python是一种通用编程语言。", "score": 0.8},
    {"text": "Python在Web开发、数据科学、人工智能和自动化脚本方面有广泛应用。", "score": 0.75}
]

query_bad_retrieval = "最新的宇宙大爆炸理论是什么?"
docs_bad_retrieval = [
    {"text": "Python的历史和发展。", "score": 0.7},
    {"text": "关于机器学习算法的介绍。", "score": 0.65}
]

print(f"nGood retrieval (LLM check) - expected False: {llm_checker.check_llm_relevance(query_good_retrieval, docs_good_retrieval)}")
print(f"Bad retrieval (LLM check) - expected True: {llm_checker.check_llm_relevance(query_bad_retrieval, docs_bad_retrieval)}")
2.1.4. 基于查询可回答性(Query Answerability)的检测

这种方法更进一步,它不仅判断文档是否相关,还判断这些文档是否足以回答用户的问题。这通常需要LLM具备一定的推理能力。

class LLMAnswerabilityChecker:
    def __init__(self, llm_client, model_name: str = "gpt-3.5-turbo"):
        self.llm_client = llm_client
        self.model_name = model_name

    def check_answerability(self, query: str, documents: List[Dict]) -> bool:
        """
        使用LLM判断给定文档是否足以回答用户查询。
        """
        if not documents:
            return False # 没有文档,肯定无法回答

        context = "n---n".join([doc['text'] for doc in documents])

        prompt = f"""
        用户查询: "{query}"

        以下是检索到的信息片段:
        ---
        {context}
        ---

        请判断根据上述信息片段,是否能够充分回答用户查询。
        如果信息足以回答,请回答 '是'。
        如果信息不足、不完整或不相关,无法回答,请回答 '否'。
        请只回答 '是' 或 '否'。
        """
        try:
            response = self.llm_client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": "你是一个判断信息可回答性的助手。"},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.0,
                max_tokens=5
            )
            answer = response.choices[0].message.content.strip().lower()
            return answer == '是'
        except Exception as e:
            print(f"LLM调用失败: {e}")
            return False

# 示例 (使用MockLLMClient)
llm_answerability_checker = LLMAnswerabilityChecker(llm_client=mock_llm_client)

query_answerable = "Python在Web开发中常用的框架有哪些?"
docs_answerable = [
    {"text": "Python在Web开发中常用的框架包括Django、Flask和FastAPI。", "score": 0.9}
]

query_not_answerable = "谁是世界上最富有的人?"
docs_not_answerable = [
    {"text": "Python是一种编程语言。", "score": 0.7},
    {"text": "特斯拉是一家电动汽车公司。", "score": 0.6}
]

# 为了模拟,我们需要修改mock_llm_client的逻辑
class MockLLMClientForAnswerability:
    def chat(self):
        return self
    def completions(self):
        return self
    def create(self, model, messages, temperature, max_tokens):
        query = messages[1]['content'].split('用户查询: "')[1].split('"nn以下是检索到的信息片段:')[0]
        context = messages[1]['content'].split('以下是检索到的信息片段:n---n')[1].split('n---')[0]

        if "Python在Web开发中常用的框架有哪些" in query and ("Django" in context and "Flask" in context):
            response_content = "是"
        elif "最富有的人" in query and ("Python" in context or "特斯拉" in context):
            response_content = "否"
        else:
            response_content = "否"

        class MockChoice:
            def __init__(self, content):
                self.message = type('MockMessage', (object,), {'content': content})()

        class MockResponse:
            def __init__(self, content):
                self.choices = [MockChoice(content)]
        return MockResponse(response_content)

mock_llm_client_ans = MockLLMClientForAnswerability()
llm_answerability_checker = LLMAnswerabilityChecker(llm_client=mock_llm_client_ans)

print(f"nAnswerability check (good) - expected True: {llm_answerability_checker.check_answerability(query_answerable, docs_answerable)}")
print(f"Answerability check (bad) - expected False: {llm_answerability_checker.check_answerability(query_not_answerable, docs_not_answerable)}")

检测机制对比表格:

检测方法 优点 缺点 成本 复杂度 最佳应用场景
相似度分数阈值 实现简单,计算成本低 精度有限,阈值难设 初步筛选,快速判断明显无关查询
交叉编码器重排 精度高,能更准确评估相关性 增加计算延迟,模型选择和阈值调整 对相关性要求高的场景,作为二级筛选
LLM相关性评估 接近人类判断,处理复杂语义 成本高,延迟高 对语义理解和准确性要求极高的场景
LLM查询可回答性评估 直接判断是否能生成答案,更实用 成本最高,延迟最高,依赖LLM推理能力 很高 最终判断检索结果是否“可用”

在实际应用中,通常会结合多种检测方法,形成一个级联或加权的判断逻辑。例如,先用相似度阈值快速过滤,再用交叉编码器重排,最后用LLM进行最终确认。

2.2. 纠正策略:当检索不佳时如何补救?

一旦系统检测到初始检索结果不佳,就需要采取补救措施。CRAG主要关注两种强大的外部知识来源:网络搜索知识图谱

2.2.1. 策略一:网络搜索增强(Web Search Augmentation)

当知识库覆盖不足、信息过时或需要通用、实时信息时,网络搜索是极佳的补救手段。

工作机制

  1. 查询生成:根据用户原始查询和/或初始检索到的少量(可能不相关)文档,生成一个或多个优化过的网络搜索查询。LLM在这方面非常有用。
  2. 执行搜索:调用搜索引擎API(如Google Custom Search, Bing Search API, Brave Search API, SerpAPI, DuckDuckGo Search等)。
  3. 结果处理:从搜索结果中提取关键信息。这可能涉及网页抓取、文本提取、摘要生成、实体识别等。
  4. 信息集成:将处理后的网络信息作为新的上下文,与用户查询一起提供给LLM生成答案。

挑战

  • 延迟:网络请求和网页解析会增加显著延迟。
  • 成本:搜索引擎API通常按调用次数计费。
  • 信息过载与噪声:网络信息量巨大,如何筛选高质量、相关的部分是关键。
  • 信任度与偏见:网络信息来源复杂,需要考虑其可信度。
import requests
import json
from bs4 import BeautifulSoup # 用于解析HTML
from duckduckgo_search import DDGS # 一个轻量级的免费网络搜索库

# 假设一个简单的LLM查询生成器
class QueryGenerator:
    def __init__(self, llm_client, model_name: str = "gpt-3.5-turbo"):
        self.llm_client = llm_client
        self.model_name = model_name

    def generate_web_search_query(self, user_query: str, initial_context: str = "") -> str:
        """
        使用LLM根据用户查询和初始上下文生成一个网络搜索查询。
        """
        prompt = f"""
        用户原始查询: "{user_query}"
        初始检索到的信息 (如果存在且不相关):
        ---
        {initial_context}
        ---

        请根据上述信息,生成一个最适合在搜索引擎中使用的、简洁明了的查询字符串,以便找到相关且最新的信息。
        请只输出查询字符串,不要包含任何额外说明。
        """
        try:
            response = self.llm_client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": "你是一个生成网络搜索查询的助手。"},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.0,
                max_tokens=50
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            print(f"LLM生成搜索查询失败: {e}")
            return user_query # 失败时退回原始查询

# 模拟LLM客户端 for QueryGenerator
class MockLLMClientForQueryGen:
    def chat(self):
        return self
    def completions(self):
        return self
    def create(self, model, messages, temperature, max_tokens):
        user_query = messages[1]['content'].split('用户原始查询: "')[1].split('"n初始检索到的信息')[0]

        # 简单模拟,实际LLM会更智能
        if "宇宙大爆炸理论" in user_query:
            generated_query = "最新的宇宙大爆炸理论研究"
        elif "Python框架" in user_query:
            generated_query = "Python Web开发常用框架"
        else:
            generated_query = user_query + "最新信息"

        class MockChoice:
            def __init__(self, content):
                self.message = type('MockMessage', (object,), {'content': content})()

        class MockResponse:
            def __init__(self, content):
                self.choices = [MockChoice(content)]
        return MockResponse(generated_query)

mock_query_llm_client = MockLLMClientForQueryGen()
query_generator = QueryGenerator(llm_client=mock_query_llm_client)

class WebSearchAgent:
    def __init__(self, max_results: int = 3):
        self.max_results = max_results

    def search(self, query: str) -> List[Dict]:
        """
        使用DuckDuckGo进行网络搜索。
        """
        print(f"Executing web search for: '{query}'")
        results = []
        try:
            # 使用 DDGS 库进行搜索
            with DDGS() as ddgs:
                ddgs_results = ddgs.text(keywords=query, region='wt-wt', safesearch='off', timelimit='y', max_results=self.max_results)
                for r in ddgs_results:
                    results.append({
                        "title": r.get('title'),
                        "href": r.get('href'),
                        "body": r.get('body')
                    })
        except Exception as e:
            print(f"DuckDuckGo search failed: {e}")
        return results

    def fetch_and_extract_content(self, url: str) -> str:
        """
        抓取网页内容并提取主要文本。
        """
        try:
            response = requests.get(url, timeout=5)
            response.raise_for_status() # 检查HTTP错误
            soup = BeautifulSoup(response.text, 'html.parser')
            # 尝试提取主要内容,例如段落文本
            paragraphs = soup.find_all('p')
            text_content = 'n'.join([p.get_text() for p in paragraphs if p.get_text().strip()])
            return text_content[:2000] # 限制文本长度
        except requests.exceptions.RequestException as e:
            print(f"Error fetching {url}: {e}")
            return ""
        except Exception as e:
            print(f"Error parsing {url}: {e}")
            return ""

    def get_relevant_snippets(self, web_search_results: List[Dict], user_query: str) -> List[str]:
        """
        从网络搜索结果中提取与用户查询最相关的片段。
        这里可以集成LLM进行摘要或提取,但为简洁起见,我们直接使用body或抓取内容。
        """
        snippets = []
        for result in web_search_results:
            # 优先使用搜索结果的body作为摘要
            if result.get('body'):
                snippets.append(f"标题: {result['title']}n链接: {result['href']}n内容摘要: {result['body']}")
            elif result.get('href'):
                # 如果body为空,尝试抓取网页内容
                full_content = self.fetch_and_extract_content(result['href'])
                if full_content:
                    snippets.append(f"标题: {result['title']}n链接: {result['href']}n详细内容(部分): {full_content}")

            # 限制总片段数量,避免上下文过长
            if len(snippets) >= self.max_results:
                break
        return snippets

# 示例使用
web_search_agent = WebSearchAgent(max_results=2)
user_query_for_web = "最新的宇宙大爆炸理论是什么?"
initial_context_for_web = "我们知识库中没有关于宇宙大爆炸理论的最新信息。"

# 生成搜索查询
generated_search_query = query_generator.generate_web_search_query(user_query_for_web, initial_context_for_web)

# 执行网络搜索
search_results = web_search_agent.search(generated_search_query)

# 获取相关片段
web_snippets = web_search_agent.get_relevant_snippets(search_results, user_query_for_web)

print("nWeb Search Snippets:")
for snippet in web_snippets:
    print("-" * 20)
    print(snippet)
2.2.2. 策略二:知识图谱增强(Knowledge Graph Augmentation)

当用户查询涉及结构化、实体关系或需要高精度事实性信息时,知识图谱(KG)是理想的补救方案。

工作机制

  1. 实体识别与关系提取:从用户查询中识别出关键实体(如人名、地名、组织)和潜在的关系。LLM或专门的NLP工具可以完成此任务。
  2. KG查询生成:将识别出的实体和关系转化为知识图谱可理解的查询语言(如SPARQL用于RDF图,Cypher用于Neo4j,或简单的API调用)。
  3. KG查询执行:在知识图谱中执行查询,获取结构化事实。
  4. 事实集成:将KG返回的结构化事实(通常是三元组或属性值对)转化为自然语言,作为上下文提供给LLM。

挑战

  • KG构建与维护:构建高质量的知识图谱是一个耗时耗力的工程。
  • 查询转换复杂性:将自然语言查询准确地转化为KG查询语言是难点。
  • 覆盖范围:KG可能无法覆盖所有长尾知识。
# 模拟一个简单的知识图谱
mock_knowledge_graph = {
    "Elon Musk": {
        "职业": "企业家",
        "公司": ["Tesla", "SpaceX", "Neuralink"],
        "出生地": "南非",
        "出生日期": "1971年6月28日"
    },
    "Tesla": {
        "CEO": "Elon Musk",
        "行业": "电动汽车, 能源存储",
        "总部": "奥斯汀, 德克萨斯州"
    },
    "SpaceX": {
        "CEO": "Elon Musk",
        "行业": "航空航天",
        "总部": "霍桑, 加利福尼亚州"
    }
}

class KGEntityExtractor:
    def __init__(self, llm_client, model_name: str = "gpt-3.5-turbo"):
        self.llm_client = llm_client
        self.model_name = model_name

    def extract_entities(self, query: str) -> List[str]:
        """
        使用LLM从查询中提取关键实体。
        """
        prompt = f"""
        用户查询: "{query}"

        请从上述查询中识别并列出所有重要的实体(如人名、公司名、地名等)。
        如果存在多个实体,请用逗号分隔。如果没有实体,请回答 '无'。
        """
        try:
            response = self.llm_client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": "你是一个实体提取助手。"},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.0,
                max_tokens=50
            )
            entities_str = response.choices[0].message.content.strip()
            if entities_str.lower() == '无':
                return []
            return [e.strip() for e in entities_str.split(',')]
        except Exception as e:
            print(f"LLM实体提取失败: {e}")
            return []

# 模拟LLM客户端 for KGEntityExtractor
class MockLLMClientForEntityExtraction:
    def chat(self):
        return self
    def completions(self):
        return self
    def create(self, model, messages, temperature, max_tokens):
        query = messages[1]['content'].split('用户查询: "')[1].split('"nn请从上述查询中识别并列出所有重要的实体')[0]

        if "Elon Musk" in query:
            extracted_entities = "Elon Musk"
        elif "特斯拉" in query:
            extracted_entities = "Tesla"
        elif "SpaceX" in query:
            extracted_entities = "SpaceX"
        else:
            extracted_entities = "无"

        class MockChoice:
            def __init__(self, content):
                self.message = type('MockMessage', (object,), {'content': content})()

        class MockResponse:
            def __init__(self, content):
                self.choices = [MockChoice(content)]
        return MockResponse(extracted_entities)

mock_entity_llm_client = MockLLMClientForEntityExtraction()
entity_extractor = KGEntityExtractor(llm_client=mock_entity_llm_client)

class KnowledgeGraphAgent:
    def __init__(self, kg: Dict):
        self.kg = kg

    def query_kg(self, entity: str) -> List[str]:
        """
        根据实体在模拟知识图谱中查询相关事实。
        """
        facts = []
        if entity in self.kg:
            entity_data = self.kg[entity]
            for prop, value in entity_data.items():
                if isinstance(value, list):
                    facts.append(f"{entity}的{prop}是: {', '.join(value)}。")
                else:
                    facts.append(f"{entity}的{prop}是: {value}。")
        return facts

    def get_kg_context(self, query: str) -> List[str]:
        """
        从查询中提取实体并从知识图谱中获取上下文。
        """
        entities = entity_extractor.extract_entities(query) # 使用LLM进行实体提取
        kg_context = []
        for entity in entities:
            facts = self.query_kg(entity)
            if facts:
                kg_context.extend(facts)
        return kg_context

# 示例使用
kg_agent = KnowledgeGraphAgent(mock_knowledge_graph)
user_query_for_kg = "Elon Musk的公司有哪些?"

kg_facts = kg_agent.get_kg_context(user_query_for_kg)

print("nKnowledge Graph Facts:")
for fact in kg_facts:
    print(fact)

user_query_for_kg_no_match = "比尔盖茨的公司有哪些?"
kg_facts_no_match = kg_agent.get_kg_context(user_query_for_kg_no_match)
print(f"nKnowledge Graph Facts for '{user_query_for_kg_no_match}': {kg_facts_no_match}")
2.2.3. 混合与多阶段纠正策略

在实际系统中,我们通常会结合使用这些策略,形成一个复杂的决策流程。

决策逻辑示例

  1. 初始RAG检索
  2. 检测:如果检索质量差(例如,所有文档rerank分数低于阈值或LLM判定不可回答)。
    a. 尝试KG增强:从查询中提取实体,查询知识图谱。如果KG能提供充分且相关的事实。
    i. 集成KG事实:将KG事实作为上下文,生成答案。
    b. 如果KG失败或信息不足,则尝试网络搜索:生成网络搜索查询,执行网络搜索,提取相关片段。
    i. 集成网络信息:将网络片段作为上下文,生成答案。
    c. 如果所有补救措施都失败:向用户请求澄清,或告知无法回答。

2.3. 系统架构与编排

CRAG系统的核心挑战在于如何协调这些不同的模块,形成一个无缝的决策和执行流程。

CRAG系统架构概览:

+----------------+       +-------------------+       +--------------------+       +-------------------+
|   用户查询     | ----> |   RAG检索模块     | ----> |   检索质量检测模块  | ----> |   LLM答案生成模块  |
| (User Query)   |       | (Vector DB Lookup)|       | (Retrieval Quality)|       | (Answer Generation)|
+----------------+       +-------------------+       |   (Reranker, LLM   |       +-------------------+
                                                      |   Evaluator)       |                ^
                                                      +--------------------+                | (良好检索结果)
                                                              | (检索不佳)                     |
                                                              V                              |
                                                      +--------------------+                |
                                                      |   纠正策略编排器   |                |
                                                      | (Correction Strategy) |                |
                                                      | (Orchestrator)     |                |
                                                      +--------------------+                |
                                                        |   /                             |
                                                        |  /                              |
                                                        V V              V V                 |
                                                  +----------------+    +----------------+   |
                                                  |   知识图谱增强 |    |   网络搜索增强 |   |
                                                  | (KG Augmentation)|  | (Web Search Aug.)|   |
                                                  | (Entity Extractor,|  | (Query Generator, |   |
                                                  |   KG Query)      |  |   Search API,   |   |
                                                  +----------------+    |   Content Extractor)|-+ (补救上下文)
                                                                        +----------------+

CRAGPipeline 示例代码:

# 假设已经初始化了所有代理和LLM客户端
# llm_client (for OpenAI or similar)
# reranker = CrossEncoderReranker()
# llm_relevance_checker = LLMRelevanceChecker(llm_client)
# llm_answerability_checker = LLMAnswerabilityChecker(llm_client)
# query_generator = QueryGenerator(llm_client)
# web_search_agent = WebSearchAgent()
# entity_extractor = KGEntityExtractor(llm_client)
# kg_agent = KnowledgeGraphAgent(mock_knowledge_graph)

# 再次定义MockLLMClient以供CRAGPipeline使用
class UnifiedMockLLMClient:
    def chat(self):
        return self
    def completions(self):
        return self
    def create(self, model, messages, temperature, max_tokens):
        system_role = messages[0]['content']
        user_prompt = messages[1]['content']

        # Mock for AnswerabilityChecker
        if "判断信息可回答性" in system_role:
            if "Python在Web开发中常用的框架有哪些" in user_prompt and ("Django" in user_prompt and "Flask" in user_prompt):
                return type('R', (object,), {'choices': [type('C', (object,), {'message': type('M', (object,), {'content': '是'})()})]})()
            elif "最富有的人" in user_prompt and ("Python" in user_prompt or "特斯拉" in user_prompt):
                return type('R', (object,), {'choices': [type('C', (object,), {'message': type('M', (object,), {'content': '否'})()})]})()
            else:
                return type('R', (object,), {'choices': [type('C', (object,), {'message': type('M', (object,), {'content': '否'})()})]})()

        # Mock for RelevanceChecker
        elif "评估文档相关性" in system_role:
            if "Python" in user_prompt and ("Python" in user_prompt or "Web开发" in user_prompt):
                return type('R', (object,), {'choices': [type('C', (object,), {'message': type('M', (object,), {'content': '是'})()})]})()
            elif "宇宙大爆炸" in user_prompt and ("Python" in user_prompt or "机器学习" in user_prompt):
                return type('R', (object,), {'choices': [type('C', (object,), {'message': type('M', (object,), {'content': '否'})()})]})()
            else:
                return type('R', (object,), {'choices': [type('C', (object,), {'message': type('M', (object,), {'content': '否'})()})]})()

        # Mock for QueryGenerator
        elif "生成网络搜索查询" in system_role:
            user_query = user_prompt.split('用户原始查询: "')[1].split('"n初始检索到的信息')[0]
            if "宇宙大爆炸理论" in user_query:
                return type('R', (object,), {'choices': [type('C', (object,), {'message': type('M', (object,), {'content': '最新的宇宙大爆炸理论研究'})()})]})()
            else:
                return type('R', (object,), {'choices': [type('C', (object,), {'message': type('M', (object,), {'content': user_query + "最新信息"})()})]})()

        # Mock for EntityExtractor
        elif "实体提取助手" in system_role:
            user_query = user_prompt.split('用户查询: "')[1].split('"nn请从上述查询中识别并列出所有重要的实体')[0]
            if "Elon Musk" in user_query:
                return type('R', (object,), {'choices': [type('C', (object,), {'message': type('M', (object,), {'content': 'Elon Musk'})()})]})()
            elif "特斯拉" in user_query:
                return type('R', (object,), {'choices': [type('C', (object,), {'message': type('M', (object,), {'content': 'Tesla'})()})]})()
            else:
                return type('R', (object,), {'choices': [type('C', (object,), {'message': type('M', (object,), {'content': '无'})()})]})()

        # Mock for final answer generation
        else:
            return type('R', (object,), {'choices': [type('C', (object,), {'message': type('M', (object,), {'content': '这是基于提供的上下文生成的模拟答案。'})()})]})()

mock_llm_client_unified = UnifiedMockLLMClient()

# 重新初始化所有组件以使用统一的mock_llm_client
reranker = CrossEncoderReranker() # 这个不需要LLM客户端
llm_relevance_checker = LLMRelevanceChecker(llm_client=mock_llm_client_unified)
llm_answerability_checker = LLMAnswerabilityChecker(llm_client=mock_llm_client_unified)
query_generator = QueryGenerator(llm_client=mock_llm_client_unified)
web_search_agent = WebSearchAgent() # 这个不需要LLM客户端
entity_extractor = KGEntityExtractor(llm_client=mock_llm_client_unified)
kg_agent = KnowledgeGraphAgent(mock_knowledge_graph) # 这个不需要LLM客户端

class CRAGPipeline:
    def __init__(self, llm_client, vector_db_retriever, reranker, llm_relevance_checker, 
                 llm_answerability_checker, query_generator, web_search_agent, 
                 entity_extractor, kg_agent, final_llm_model: str = "gpt-3.5-turbo"):
        self.llm_client = llm_client
        self.vector_db_retriever = vector_db_retriever # 模拟向量数据库检索器
        self.reranker = reranker
        self.llm_relevance_checker = llm_relevance_checker
        self.llm_answerability_checker = llm_answerability_checker
        self.query_generator = query_generator
        self.web_search_agent = web_search_agent
        self.entity_extractor = entity_extractor
        self.kg_agent = kg_agent
        self.final_llm_model = final_llm_model

    def _retrieve_from_vector_db(self, query: str, top_k: int = 5) -> List[Dict]:
        # 模拟向量数据库检索
        print(f"n[CRAG] Initial retrieval for: '{query}'")
        if "Python框架" in query:
            return [
                {"text": "Python在Web开发中常用的框架包括Django、Flask和FastAPI。", "score": 0.9},
                {"text": "Python是一种通用编程语言。", "score": 0.8},
                {"text": "数据科学中常用的Python库有NumPy和Pandas。", "score": 0.7}
            ]
        elif "Elon Musk" in query:
             return [
                {"text": "Elon Musk是特斯拉和SpaceX的CEO。", "score": 0.9},
                {"text": "特斯拉生产电动汽车。", "score": 0.8}
            ]
        elif "宇宙大爆炸理论" in query:
            return [ # 模拟低相关性检索
                {"text": "Python的历史和发展。", "score": 0.6},
                {"text": "关于机器学习算法的介绍。", "score": 0.55},
                {"text": "太阳系的行星。", "score": 0.5}
            ]
        else:
            return []

    def _generate_final_answer(self, query: str, context: List[str]) -> str:
        """
        使用LLM根据提供的上下文生成最终答案。
        """
        context_str = "n".join(context)
        prompt = f"""
        用户查询: "{query}"

        以下是提供的信息,请根据这些信息回答用户查询。如果信息不足以回答,请说明。
        ---
        {context_str}
        ---

        请生成一个全面、准确的答案。
        """
        try:
            response = self.llm_client.chat.completions.create(
                model=self.final_llm_model,
                messages=[
                    {"role": "system", "content": "你是一个智能问答助手。"},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.2,
                max_tokens=500
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            print(f"最终答案生成失败: {e}")
            return "抱歉,无法生成答案。"

    def run(self, query: str) -> str:
        # 1. 初始RAG检索
        initial_docs = self._retrieve_from_vector_db(query)

        # 2. 检索质量检测
        # 首先用交叉编码器重排,获取更准确的分数
        reranked_docs = self.reranker.rerank(query, initial_docs)

        # 使用LLM判断这些文档是否足以回答问题
        # 注意:这里可以根据需要调整检测策略,例如先检查rerank阈值,再检查LLM可回答性
        is_retrieval_poor = not self.llm_answerability_checker.check_answerability(query, reranked_docs)

        if not is_retrieval_poor:
            print("[CRAG] Initial retrieval is good. Generating answer from vector DB.")
            context = [doc['text'] for doc in reranked_docs]
            return self._generate_final_answer(query, context)
        else:
            print("[CRAG] Initial retrieval is poor. Attempting corrective actions.")
            corrected_context = []

            # 3. 尝试知识图谱增强
            print("[CRAG] Attempting Knowledge Graph augmentation...")
            kg_facts = self.kg_agent.get_kg_context(query)
            if kg_facts:
                print(f"[CRAG] KG found {len(kg_facts)} relevant facts.")
                corrected_context.extend(kg_facts)
                # 再次检查是否现在可以回答了
                if self.llm_answerability_checker.check_answerability(query, [{"text": f} for f in corrected_context]):
                    print("[CRAG] KG augmentation successful. Generating answer.")
                    return self._generate_final_answer(query, corrected_context)

            # 4. 如果KG未能充分补救,尝试网络搜索
            print("[CRAG] KG augmentation insufficient or failed. Attempting Web Search augmentation...")
            generated_search_query = self.query_generator.generate_web_search_query(
                query, 
                initial_context="n".join([doc['text'] for doc in initial_docs])
            )
            web_search_results = self.web_search_agent.search(generated_search_query)
            web_snippets = self.web_search_agent.get_relevant_snippets(web_search_results, query)

            if web_snippets:
                print(f"[CRAG] Web search found {len(web_snippets)} relevant snippets.")
                corrected_context.extend(web_snippets)
                print("[CRAG] Web search augmentation successful. Generating answer.")
                return self._generate_final_answer(query, corrected_context)
            else:
                print("[CRAG] All corrective actions failed.")
                return "抱歉,我无法找到足够的信息来回答您的查询,请尝试换一个问题。"

# 示例运行
crag_pipeline = CRAGPipeline(
    llm_client=mock_llm_client_unified,
    vector_db_retriever=None, # 模拟器中直接处理,这里可以留空
    reranker=reranker,
    llm_relevance_checker=llm_relevance_checker,
    llm_answerability_checker=llm_answerability_checker,
    query_generator=query_generator,
    web_search_agent=web_search_agent,
    entity_extractor=entity_extractor,
    kg_agent=kg_agent
)

print("n--- 场景一:初始RAG检索良好 ---")
answer_good = crag_pipeline.run("Python在Web开发中常用的框架有哪些?")
print(f"nFinal Answer: {answer_good}")

print("n--- 场景二:初始RAG检索不佳,KG可补救 ---")
answer_kg = crag_pipeline.run("Elon Musk的公司有哪些?")
print(f"nFinal Answer: {answer_kg}")

print("n--- 场景三:初始RAG检索不佳,需要网络搜索 ---")
answer_web = crag_pipeline.run("最新的宇宙大爆炸理论是什么?")
print(f"nFinal Answer: {answer_web}")

print("n--- 场景四:所有补救措施都失败 ---")
answer_fail = crag_pipeline.run("请告诉我一个不存在的虚构概念的最新发展。")
print(f"nFinal Answer: {answer_fail}")

3. 高级考量与最佳实践

构建一个生产级的CRAG系统需要考虑更多细节:

  • 阈值调优:检测模块的各种阈值(相似度、交叉编码器分数、相关文档数量等)需要通过实验和A/B测试进行细致调优,以平衡准确性和召回率。
  • 成本与延迟管理:LLM调用和网络搜索都是有成本和延迟的。需要仔细权衡何时触发这些高级操作,以及可以接受的最大延迟。可以设置多个LLM模型,在检测和纠正阶段使用更小、更快的模型。
  • 缓存机制:对于常见的网络搜索查询或KG查询,可以实施缓存策略,减少重复调用,降低成本和延迟。
  • 安全性与隐私:网络搜索结果可能包含恶意链接或不当内容。需要对抓取内容进行安全审查。处理用户查询时也要注意隐私保护。
  • 反馈循环与持续学习:收集用户反馈(例如,对答案的满意度、是否纠正了错误)可以用于改进检测模型和纠正策略。例如,可以训练一个分类器来预测何时需要特定类型的纠正。
  • 异构数据源集成:除了向量数据库、知识图谱和网络,还可以集成其他数据源,如结构化数据库、API接口等。
  • LLM的上下文窗口限制:无论是初始检索还是纠正后的信息,最终都要作为LLM的上下文。需要注意LLM的上下文窗口大小限制,对检索到的文档和网页内容进行摘要或截断。

4. 展望未来

CRAG代表了RAG系统向更智能、自适应方向演进的重要一步。随着LLM能力的不断提升和多模态RAG的发展,未来的CRAG系统将能够处理更复杂的查询,集成更多样化的数据源,并提供更加精准和可靠的答案。我们甚至可以展望到,CRAG能够根据用户的个性化偏好和历史交互,动态调整其纠正策略,真正实现千人千面的智能问答体验。

CRAG的出现,无疑为我们打开了一扇通往更强大、更可靠AI应用的大门。希望今天的讲座能为大家带来启发,激发大家在构建智能系统时的创新思维。谢谢大家!

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注