如何对企业私有语料构建多阶段 embedding 训练以提升 RAG 召回质量

企业私有语料 RAG 召回优化:多阶段 Embedding 训练实战

大家好,今天我们来聊聊如何利用多阶段 Embedding 训练,提升企业私有语料的 RAG (Retrieval Augmented Generation) 召回质量。RAG 架构的核心在于有效召回与用户查询相关的上下文,而 Embedding 的质量直接决定了召回的准确性。针对企业私有语料,我们往往需要针对特定领域进行 Embedding 训练,才能获得更好的效果。

RAG 系统与 Embedding 的重要性

在深入多阶段训练之前,我们先回顾一下 RAG 系统和 Embedding 在其中的作用。

RAG 系统的核心流程如下:

  1. Query Embedding: 将用户查询转换为 Embedding 向量。
  2. Retrieval: 基于 Query Embedding,在知识库中检索最相关的文档片段。
  3. Augmentation: 将检索到的文档片段与原始查询拼接,形成增强的 Prompt。
  4. Generation: 将增强的 Prompt 输入 LLM,生成最终答案。

Embedding 模型负责将文本数据(查询、文档)映射到高维向量空间,使得语义相似的文本在向量空间中距离更近。常见的 Embedding 模型包括:

  • Sentence Transformers: 预训练的 Sentence Embedding 模型,如 all-mpnet-base-v2
  • Text Embedding API: OpenAI 的 text-embedding-ada-002 等。
  • 自训练 Embedding 模型: 基于企业私有语料训练的定制化模型。

选择合适的 Embedding 模型,并针对特定领域进行优化,是提升 RAG 召回效果的关键。

多阶段 Embedding 训练策略

针对企业私有语料,我们可以采用多阶段 Embedding 训练策略,逐步提升模型性能。以下是一种常用的训练流程:

阶段一:通用预训练模型微调 (Fine-tuning on General Domain Data)

  • 目的: 利用通用领域的语料,使模型具备一定的文本理解能力,并适应下游任务。
  • 数据: 选择与目标领域相关的公开数据集,例如 Wikipedia、Stack Overflow 等。
  • 方法: 基于预训练的 Sentence Transformers 模型,进行微调。
  • 优点: 快速启动,利用现有资源。
  • 缺点: 对特定领域的适应性有限。

阶段二:领域知识增强 (Domain-Specific Fine-tuning)

  • 目的: 使模型更好地理解和表示特定领域的文本。
  • 数据: 企业私有语料,包括文档、FAQ、用户问答等。
  • 方法: 继续在阶段一的模型基础上,使用领域数据进行微调。
  • 优点: 显著提升领域内的召回效果。
  • 缺点: 需要准备充足的领域数据。

阶段三:对比学习优化 (Contrastive Learning)

  • 目的: 进一步优化 Embedding 空间,使得相似文本的距离更近,不相似文本的距离更远。
  • 数据: 构建正负样本对,例如:
    • 正样本: 同一个文档的不同片段,或者语义相关的文档。
    • 负样本: 随机抽取的文档,或者语义不相关的文档。
  • 方法: 使用对比学习损失函数,如 Margin Ranking Loss、Multiple Negatives Ranking Loss。
  • 优点: 提高 Embedding 的区分度,增强召回的准确性。
  • 缺点: 需要设计合适的负样本策略。

阶段四:查询优化 (Query Optimization)

  • 目的: 针对用户查询进行优化,提升查询 Embedding 的质量。
  • 数据: 用户查询日志,以及对应的相关文档。
  • 方法: 训练一个 Query Rewriting 模型,将原始查询改写成更清晰、更具表达力的形式。
  • 优点: 提升用户意图的识别能力,增强召回的鲁棒性。
  • 缺点: 需要收集和处理用户查询日志。

下面我们分别对每个阶段进行详细讲解,并给出相应的代码示例。

阶段一:通用预训练模型微调

我们选择 all-mpnet-base-v2 作为预训练模型,并使用 sentence-transformers 库进行微调。

from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader

# 1. 加载预训练模型
model = SentenceTransformer('all-mpnet-base-v2')

# 2. 准备训练数据 (示例数据,需要替换成实际的通用领域数据)
train_examples = [
    InputExample(texts=["What is machine learning?", "Machine learning is a subset of AI."]),
    InputExample(texts=["How does RAG work?", "RAG retrieves relevant documents and generates answers."])
]

# 3. 定义 DataLoader
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

# 4. 定义损失函数
train_loss = losses.CosineSimilarityLoss(model)

# 5. 模型训练
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=100)

# 6. 保存模型
model.save('models/stage1_general_domain')

代码解释:

  • SentenceTransformer('all-mpnet-base-v2'): 加载预训练的 Sentence Transformer 模型。
  • InputExample(texts=["text1", "text2"]): 定义训练样本,text1text2 是语义相似的文本。
  • CosineSimilarityLoss(model): 使用 Cosine Similarity Loss 作为损失函数,目标是使相似文本的 Embedding 向量的余弦相似度更高。
  • model.fit(...): 开始模型训练,epochs 定义训练轮数,warmup_steps 定义学习率预热步数。
  • model.save(...): 保存训练好的模型。

阶段二:领域知识增强

在阶段一的基础上,我们使用企业私有语料进行微调。

from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
import pandas as pd

# 1. 加载阶段一的模型
model = SentenceTransformer('models/stage1_general_domain')

# 2. 准备训练数据 (从 CSV 文件加载,假设 CSV 文件包含 "text1" 和 "text2" 两列,表示相似的文本)
df = pd.read_csv('data/domain_data.csv') # 替换成你的领域数据文件
train_examples = []
for index, row in df.iterrows():
    train_examples.append(InputExample(texts=[row['text1'], row['text2']]))

# 3. 定义 DataLoader
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

# 4. 定义损失函数
train_loss = losses.CosineSimilarityLoss(model)

# 5. 模型训练
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=3, warmup_steps=500)

# 6. 保存模型
model.save('models/stage2_domain_specific')

代码解释:

  • pd.read_csv('data/domain_data.csv'): 从 CSV 文件加载领域数据。
  • 假设 CSV 文件包含 text1text2 两列,表示相似的文本。你需要根据实际的数据格式进行调整。
  • 其他部分与阶段一类似,只是使用了领域数据进行训练。

阶段三:对比学习优化

我们使用 Margin Ranking Loss 进行对比学习优化。

from sentence_transformers import SentenceTransformer, losses, InputExample
from torch.utils.data import DataLoader
import random

# 1. 加载阶段二的模型
model = SentenceTransformer('models/stage2_domain_specific')

# 2. 准备训练数据 (构建正负样本对)
# 假设我们有一个文档列表 documents = ["doc1", "doc2", "doc3", ...]
documents = ["This is document 1 about topic A.", "Document 2 also discusses topic A.", "Document 3 is about unrelated topic B."] # 替换成你的文档列表

train_examples = []
for i in range(len(documents)):
    # 正样本:同一个文档的不同片段 (这里简化为同一个文档)
    train_examples.append(InputExample(texts=[documents[i], documents[i]], label=1.0))

    # 负样本:随机选择一个不同的文档
    negative_index = random.choice([j for j in range(len(documents)) if j != i])
    train_examples.append(InputExample(texts=[documents[i], documents[negative_index]], label=0.0))

# 3. 定义 DataLoader
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

# 4. 定义损失函数 (MarginRankingLoss)
train_loss = losses.MarginRankingLoss(model=model, margin=0.5)

# 5. 模型训练
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=2, warmup_steps=200)

# 6. 保存模型
model.save('models/stage3_contrastive_learning')

代码解释:

  • documents: 你的文档列表。
  • MarginRankingLoss(model=model, margin=0.5): 使用 Margin Ranking Loss 作为损失函数。margin 定义正负样本之间的距离阈值。
  • label=1.0 表示正样本对,label=0.0 表示负样本对。
  • 负样本的选择策略非常重要,需要根据实际情况进行调整。例如,可以使用 BM25 或 TF-IDF 等方法选择更具挑战性的负样本。

阶段四:查询优化

我们训练一个简单的 Query Rewriting 模型,使用 T5 模型进行微调。

from transformers import T5Tokenizer, T5ForConditionalGeneration, AdamW
from torch.utils.data import Dataset, DataLoader
import torch

# 1. 定义数据集
class QueryRewriteDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=128):
        self.data = data  # List of tuples: (original_query, rewritten_query)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        original_query, rewritten_query = self.data[idx]
        source_text = "rewrite: " + original_query
        target_text = rewritten_query

        source = self.tokenizer.batch_encode_plus(
            [source_text],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt",
        )
        target = self.tokenizer.batch_encode_plus(
            [target_text],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt",
        )

        return {
            'source_ids': source['input_ids'].squeeze(),
            'source_mask': source['attention_mask'].squeeze(),
            'target_ids': target['input_ids'].squeeze(),
            'target_mask': target['attention_mask'].squeeze()
        }

# 2. 准备数据 (示例数据,需要替换成实际的用户查询日志)
training_data = [
    ("What is the return policy?", "What is your company's return policy and procedure?"),
    ("How to reset password?", "What are the steps to reset my account password?")
]

# 3. 加载 T5 模型和 Tokenizer
model_name = 't5-small'
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# 4. 创建数据集和 DataLoader
train_dataset = QueryRewriteDataset(training_data, tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# 5. 定义优化器
optimizer = AdamW(model.parameters(), lr=5e-5)

# 6. 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

model.train()
for epoch in range(3):
    for batch in train_dataloader:
        source_ids = batch['source_ids'].to(device)
        source_mask = batch['source_mask'].to(device)
        target_ids = batch['target_ids'].to(device)
        target_mask = batch['target_mask'].to(device)

        outputs = model(
            input_ids=source_ids,
            attention_mask=source_mask,
            labels=target_ids
        )

        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

# 7. 保存模型
model.save_pretrained("models/stage4_query_rewriting")
tokenizer.save_pretrained("models/stage4_query_rewriting")

# 8. 使用 Query Rewriting 模型
def rewrite_query(query, model, tokenizer, device="cpu"):
    model.to(device)
    model.eval()
    source_text = "rewrite: " + query
    source = tokenizer.batch_encode_plus(
        [source_text],
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors="pt",
    )

    source_ids = source['input_ids'].to(device)
    source_mask = source['attention_mask'].to(device)

    with torch.no_grad():
        output = model.generate(
            input_ids=source_ids,
            attention_mask=source_mask,
            max_length=128,
            num_beams=4,
            early_stopping=True
        )

    rewritten_query = tokenizer.decode(output[0], skip_special_tokens=True)
    return rewritten_query

# 示例使用
rewritten_model = T5ForConditionalGeneration.from_pretrained("models/stage4_query_rewriting")
rewritten_tokenizer = T5Tokenizer.from_pretrained("models/stage4_query_rewriting")
example_query = "broken printer"
rewritten_query = rewrite_query(example_query, rewritten_model, rewritten_tokenizer)
print(f"Original Query: {example_query}")
print(f"Rewritten Query: {rewritten_query}")

代码解释:

  • T5Tokenizer.from_pretrained('t5-small'): 加载 T5 模型和 Tokenizer。
  • QueryRewriteDataset: 定义数据集,将原始查询和改写后的查询作为训练样本。
  • rewrite: 前缀:告诉模型执行改写任务。
  • model.generate(...): 使用 T5 模型生成改写后的查询。
  • num_beams=4: 使用 Beam Search,提高生成质量。
  • rewrite_query 函数:用于将原始查询改写成更清晰、更具表达力的形式。
  • 需要根据实际情况调整 T5 模型的大小和训练参数。

RAG 召回效果评估

训练完成后,我们需要评估 RAG 系统的召回效果。常用的评估指标包括:

  • Recall@K: 在 Top K 个检索结果中,有多少个是相关的。
  • Precision@K: 在 Top K 个检索结果中,有多少个是正确的。
  • NDCG@K: 考虑检索结果的排序,给予更相关的结果更高的权重。

我们可以使用以下代码进行评估:

from sentence_transformers import SentenceTransformer
import numpy as np

# 1. 加载 Embedding 模型 (选择经过多阶段训练后的模型)
model = SentenceTransformer('models/stage3_contrastive_learning') # 或者 stage4,如果使用了 query rewrite

# 2. 定义评估数据集 (包含查询和对应的相关文档)
queries = ["What is the company's holiday policy?", "How do I submit an expense report?"]
relevant_documents = [["Document about holiday policy"], ["Document about expense reports submission"]]
all_documents = ["Document about holiday policy", "Document about expense reports submission", "Some other unrelated document"] # 所有的文档

# 3. 计算 Embedding 向量
query_embeddings = model.encode(queries)
document_embeddings = model.encode(all_documents)

# 4. 定义召回函数
def retrieve_top_k(query_embedding, document_embeddings, k=5):
    """
    Retrieves the top k most similar documents based on cosine similarity.
    """
    scores = np.dot(query_embedding, document_embeddings.T)
    ranked_indices = np.argsort(scores)[::-1]  # Sort in descending order
    return ranked_indices[:k]

# 5. 评估 Recall@K
def calculate_recall_at_k(queries, relevant_documents, query_embeddings, document_embeddings, all_documents, k=5):
    """
    Calculates Recall@K for a set of queries.
    """
    total_recall = 0
    for i, query in enumerate(queries):
        ranked_indices = retrieve_top_k(query_embeddings[i], document_embeddings, k)
        retrieved_documents = [all_documents[index] for index in ranked_indices]

        # Check if any of the relevant documents are in the retrieved documents
        recall = 0
        for relevant_doc in relevant_documents[i]:
            if relevant_doc in retrieved_documents:
                recall = 1
                break  # Found at least one relevant document

        total_recall += recall

    return total_recall / len(queries)

# 6. 计算 Recall@K
recall_at_5 = calculate_recall_at_k(queries, relevant_documents, query_embeddings, document_embeddings, all_documents, k=5)
print(f"Recall@5: {recall_at_5}")

代码解释:

  • retrieve_top_k 函数:根据余弦相似度,检索 Top K 个最相关的文档。
  • calculate_recall_at_k 函数:计算 Recall@K 指标。
  • 需要准备包含查询和对应相关文档的评估数据集。

优化策略与注意事项

  • 数据质量: 训练数据的质量直接影响模型的效果,需要进行清洗和标注。
  • 负样本选择: 对比学习中,负样本的选择策略非常重要,需要根据实际情况进行调整。
  • 超参数调优: 需要仔细调整模型训练的超参数,例如学习率、Batch Size、Epochs 等。
  • 模型选择: 可以选择不同的预训练模型,例如 BERT、RoBERTa 等,进行尝试。
  • Prompt Engineering: 除了优化 Embedding 模型,还可以通过优化 Prompt 来提升 RAG 系统的效果。
  • 持续迭代: RAG 系统的优化是一个持续迭代的过程,需要不断收集数据、训练模型、评估效果,并进行改进。
  • 冷启动问题: 如果企业数据量较少,可以考虑使用数据增强技术,或者迁移学习的方法。
  • 模型监控: 需要对 RAG 系统进行监控,及时发现和解决问题。

总结:多阶段训练,提升RAG质量

我们探讨了如何通过多阶段 Embedding 训练,提升企业私有语料的 RAG 召回质量。从通用领域的微调,到领域知识增强,再到对比学习优化和查询优化,每个阶段都有其特定的目的和方法。通过合理的策略和持续的迭代,我们可以构建一个高效、准确的 RAG 系统,为企业提供更好的知识服务。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注