企业私有语料 RAG 召回优化:多阶段 Embedding 训练实战
大家好,今天我们来聊聊如何利用多阶段 Embedding 训练,提升企业私有语料的 RAG (Retrieval Augmented Generation) 召回质量。RAG 架构的核心在于有效召回与用户查询相关的上下文,而 Embedding 的质量直接决定了召回的准确性。针对企业私有语料,我们往往需要针对特定领域进行 Embedding 训练,才能获得更好的效果。
RAG 系统与 Embedding 的重要性
在深入多阶段训练之前,我们先回顾一下 RAG 系统和 Embedding 在其中的作用。
RAG 系统的核心流程如下:
- Query Embedding: 将用户查询转换为 Embedding 向量。
- Retrieval: 基于 Query Embedding,在知识库中检索最相关的文档片段。
- Augmentation: 将检索到的文档片段与原始查询拼接,形成增强的 Prompt。
- Generation: 将增强的 Prompt 输入 LLM,生成最终答案。
Embedding 模型负责将文本数据(查询、文档)映射到高维向量空间,使得语义相似的文本在向量空间中距离更近。常见的 Embedding 模型包括:
- Sentence Transformers: 预训练的 Sentence Embedding 模型,如
all-mpnet-base-v2。 - Text Embedding API: OpenAI 的
text-embedding-ada-002等。 - 自训练 Embedding 模型: 基于企业私有语料训练的定制化模型。
选择合适的 Embedding 模型,并针对特定领域进行优化,是提升 RAG 召回效果的关键。
多阶段 Embedding 训练策略
针对企业私有语料,我们可以采用多阶段 Embedding 训练策略,逐步提升模型性能。以下是一种常用的训练流程:
阶段一:通用预训练模型微调 (Fine-tuning on General Domain Data)
- 目的: 利用通用领域的语料,使模型具备一定的文本理解能力,并适应下游任务。
- 数据: 选择与目标领域相关的公开数据集,例如 Wikipedia、Stack Overflow 等。
- 方法: 基于预训练的 Sentence Transformers 模型,进行微调。
- 优点: 快速启动,利用现有资源。
- 缺点: 对特定领域的适应性有限。
阶段二:领域知识增强 (Domain-Specific Fine-tuning)
- 目的: 使模型更好地理解和表示特定领域的文本。
- 数据: 企业私有语料,包括文档、FAQ、用户问答等。
- 方法: 继续在阶段一的模型基础上,使用领域数据进行微调。
- 优点: 显著提升领域内的召回效果。
- 缺点: 需要准备充足的领域数据。
阶段三:对比学习优化 (Contrastive Learning)
- 目的: 进一步优化 Embedding 空间,使得相似文本的距离更近,不相似文本的距离更远。
- 数据: 构建正负样本对,例如:
- 正样本: 同一个文档的不同片段,或者语义相关的文档。
- 负样本: 随机抽取的文档,或者语义不相关的文档。
- 方法: 使用对比学习损失函数,如 Margin Ranking Loss、Multiple Negatives Ranking Loss。
- 优点: 提高 Embedding 的区分度,增强召回的准确性。
- 缺点: 需要设计合适的负样本策略。
阶段四:查询优化 (Query Optimization)
- 目的: 针对用户查询进行优化,提升查询 Embedding 的质量。
- 数据: 用户查询日志,以及对应的相关文档。
- 方法: 训练一个 Query Rewriting 模型,将原始查询改写成更清晰、更具表达力的形式。
- 优点: 提升用户意图的识别能力,增强召回的鲁棒性。
- 缺点: 需要收集和处理用户查询日志。
下面我们分别对每个阶段进行详细讲解,并给出相应的代码示例。
阶段一:通用预训练模型微调
我们选择 all-mpnet-base-v2 作为预训练模型,并使用 sentence-transformers 库进行微调。
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
# 1. 加载预训练模型
model = SentenceTransformer('all-mpnet-base-v2')
# 2. 准备训练数据 (示例数据,需要替换成实际的通用领域数据)
train_examples = [
InputExample(texts=["What is machine learning?", "Machine learning is a subset of AI."]),
InputExample(texts=["How does RAG work?", "RAG retrieves relevant documents and generates answers."])
]
# 3. 定义 DataLoader
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
# 4. 定义损失函数
train_loss = losses.CosineSimilarityLoss(model)
# 5. 模型训练
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=100)
# 6. 保存模型
model.save('models/stage1_general_domain')
代码解释:
SentenceTransformer('all-mpnet-base-v2'): 加载预训练的 Sentence Transformer 模型。InputExample(texts=["text1", "text2"]): 定义训练样本,text1和text2是语义相似的文本。CosineSimilarityLoss(model): 使用 Cosine Similarity Loss 作为损失函数,目标是使相似文本的 Embedding 向量的余弦相似度更高。model.fit(...): 开始模型训练,epochs定义训练轮数,warmup_steps定义学习率预热步数。model.save(...): 保存训练好的模型。
阶段二:领域知识增强
在阶段一的基础上,我们使用企业私有语料进行微调。
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
import pandas as pd
# 1. 加载阶段一的模型
model = SentenceTransformer('models/stage1_general_domain')
# 2. 准备训练数据 (从 CSV 文件加载,假设 CSV 文件包含 "text1" 和 "text2" 两列,表示相似的文本)
df = pd.read_csv('data/domain_data.csv') # 替换成你的领域数据文件
train_examples = []
for index, row in df.iterrows():
train_examples.append(InputExample(texts=[row['text1'], row['text2']]))
# 3. 定义 DataLoader
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
# 4. 定义损失函数
train_loss = losses.CosineSimilarityLoss(model)
# 5. 模型训练
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=3, warmup_steps=500)
# 6. 保存模型
model.save('models/stage2_domain_specific')
代码解释:
pd.read_csv('data/domain_data.csv'): 从 CSV 文件加载领域数据。- 假设 CSV 文件包含
text1和text2两列,表示相似的文本。你需要根据实际的数据格式进行调整。 - 其他部分与阶段一类似,只是使用了领域数据进行训练。
阶段三:对比学习优化
我们使用 Margin Ranking Loss 进行对比学习优化。
from sentence_transformers import SentenceTransformer, losses, InputExample
from torch.utils.data import DataLoader
import random
# 1. 加载阶段二的模型
model = SentenceTransformer('models/stage2_domain_specific')
# 2. 准备训练数据 (构建正负样本对)
# 假设我们有一个文档列表 documents = ["doc1", "doc2", "doc3", ...]
documents = ["This is document 1 about topic A.", "Document 2 also discusses topic A.", "Document 3 is about unrelated topic B."] # 替换成你的文档列表
train_examples = []
for i in range(len(documents)):
# 正样本:同一个文档的不同片段 (这里简化为同一个文档)
train_examples.append(InputExample(texts=[documents[i], documents[i]], label=1.0))
# 负样本:随机选择一个不同的文档
negative_index = random.choice([j for j in range(len(documents)) if j != i])
train_examples.append(InputExample(texts=[documents[i], documents[negative_index]], label=0.0))
# 3. 定义 DataLoader
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
# 4. 定义损失函数 (MarginRankingLoss)
train_loss = losses.MarginRankingLoss(model=model, margin=0.5)
# 5. 模型训练
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=2, warmup_steps=200)
# 6. 保存模型
model.save('models/stage3_contrastive_learning')
代码解释:
documents: 你的文档列表。MarginRankingLoss(model=model, margin=0.5): 使用 Margin Ranking Loss 作为损失函数。margin定义正负样本之间的距离阈值。label=1.0表示正样本对,label=0.0表示负样本对。- 负样本的选择策略非常重要,需要根据实际情况进行调整。例如,可以使用 BM25 或 TF-IDF 等方法选择更具挑战性的负样本。
阶段四:查询优化
我们训练一个简单的 Query Rewriting 模型,使用 T5 模型进行微调。
from transformers import T5Tokenizer, T5ForConditionalGeneration, AdamW
from torch.utils.data import Dataset, DataLoader
import torch
# 1. 定义数据集
class QueryRewriteDataset(Dataset):
def __init__(self, data, tokenizer, max_length=128):
self.data = data # List of tuples: (original_query, rewritten_query)
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
original_query, rewritten_query = self.data[idx]
source_text = "rewrite: " + original_query
target_text = rewritten_query
source = self.tokenizer.batch_encode_plus(
[source_text],
max_length=self.max_length,
padding='max_length',
truncation=True,
return_tensors="pt",
)
target = self.tokenizer.batch_encode_plus(
[target_text],
max_length=self.max_length,
padding='max_length',
truncation=True,
return_tensors="pt",
)
return {
'source_ids': source['input_ids'].squeeze(),
'source_mask': source['attention_mask'].squeeze(),
'target_ids': target['input_ids'].squeeze(),
'target_mask': target['attention_mask'].squeeze()
}
# 2. 准备数据 (示例数据,需要替换成实际的用户查询日志)
training_data = [
("What is the return policy?", "What is your company's return policy and procedure?"),
("How to reset password?", "What are the steps to reset my account password?")
]
# 3. 加载 T5 模型和 Tokenizer
model_name = 't5-small'
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
# 4. 创建数据集和 DataLoader
train_dataset = QueryRewriteDataset(training_data, tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
# 5. 定义优化器
optimizer = AdamW(model.parameters(), lr=5e-5)
# 6. 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()
for epoch in range(3):
for batch in train_dataloader:
source_ids = batch['source_ids'].to(device)
source_mask = batch['source_mask'].to(device)
target_ids = batch['target_ids'].to(device)
target_mask = batch['target_mask'].to(device)
outputs = model(
input_ids=source_ids,
attention_mask=source_mask,
labels=target_ids
)
loss = outputs.loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
print(f"Epoch {epoch+1}, Loss: {loss.item()}")
# 7. 保存模型
model.save_pretrained("models/stage4_query_rewriting")
tokenizer.save_pretrained("models/stage4_query_rewriting")
# 8. 使用 Query Rewriting 模型
def rewrite_query(query, model, tokenizer, device="cpu"):
model.to(device)
model.eval()
source_text = "rewrite: " + query
source = tokenizer.batch_encode_plus(
[source_text],
max_length=128,
padding='max_length',
truncation=True,
return_tensors="pt",
)
source_ids = source['input_ids'].to(device)
source_mask = source['attention_mask'].to(device)
with torch.no_grad():
output = model.generate(
input_ids=source_ids,
attention_mask=source_mask,
max_length=128,
num_beams=4,
early_stopping=True
)
rewritten_query = tokenizer.decode(output[0], skip_special_tokens=True)
return rewritten_query
# 示例使用
rewritten_model = T5ForConditionalGeneration.from_pretrained("models/stage4_query_rewriting")
rewritten_tokenizer = T5Tokenizer.from_pretrained("models/stage4_query_rewriting")
example_query = "broken printer"
rewritten_query = rewrite_query(example_query, rewritten_model, rewritten_tokenizer)
print(f"Original Query: {example_query}")
print(f"Rewritten Query: {rewritten_query}")
代码解释:
T5Tokenizer.from_pretrained('t5-small'): 加载 T5 模型和 Tokenizer。QueryRewriteDataset: 定义数据集,将原始查询和改写后的查询作为训练样本。rewrite:前缀:告诉模型执行改写任务。model.generate(...): 使用 T5 模型生成改写后的查询。num_beams=4: 使用 Beam Search,提高生成质量。rewrite_query函数:用于将原始查询改写成更清晰、更具表达力的形式。- 需要根据实际情况调整 T5 模型的大小和训练参数。
RAG 召回效果评估
训练完成后,我们需要评估 RAG 系统的召回效果。常用的评估指标包括:
- Recall@K: 在 Top K 个检索结果中,有多少个是相关的。
- Precision@K: 在 Top K 个检索结果中,有多少个是正确的。
- NDCG@K: 考虑检索结果的排序,给予更相关的结果更高的权重。
我们可以使用以下代码进行评估:
from sentence_transformers import SentenceTransformer
import numpy as np
# 1. 加载 Embedding 模型 (选择经过多阶段训练后的模型)
model = SentenceTransformer('models/stage3_contrastive_learning') # 或者 stage4,如果使用了 query rewrite
# 2. 定义评估数据集 (包含查询和对应的相关文档)
queries = ["What is the company's holiday policy?", "How do I submit an expense report?"]
relevant_documents = [["Document about holiday policy"], ["Document about expense reports submission"]]
all_documents = ["Document about holiday policy", "Document about expense reports submission", "Some other unrelated document"] # 所有的文档
# 3. 计算 Embedding 向量
query_embeddings = model.encode(queries)
document_embeddings = model.encode(all_documents)
# 4. 定义召回函数
def retrieve_top_k(query_embedding, document_embeddings, k=5):
"""
Retrieves the top k most similar documents based on cosine similarity.
"""
scores = np.dot(query_embedding, document_embeddings.T)
ranked_indices = np.argsort(scores)[::-1] # Sort in descending order
return ranked_indices[:k]
# 5. 评估 Recall@K
def calculate_recall_at_k(queries, relevant_documents, query_embeddings, document_embeddings, all_documents, k=5):
"""
Calculates Recall@K for a set of queries.
"""
total_recall = 0
for i, query in enumerate(queries):
ranked_indices = retrieve_top_k(query_embeddings[i], document_embeddings, k)
retrieved_documents = [all_documents[index] for index in ranked_indices]
# Check if any of the relevant documents are in the retrieved documents
recall = 0
for relevant_doc in relevant_documents[i]:
if relevant_doc in retrieved_documents:
recall = 1
break # Found at least one relevant document
total_recall += recall
return total_recall / len(queries)
# 6. 计算 Recall@K
recall_at_5 = calculate_recall_at_k(queries, relevant_documents, query_embeddings, document_embeddings, all_documents, k=5)
print(f"Recall@5: {recall_at_5}")
代码解释:
retrieve_top_k函数:根据余弦相似度,检索 Top K 个最相关的文档。calculate_recall_at_k函数:计算 Recall@K 指标。- 需要准备包含查询和对应相关文档的评估数据集。
优化策略与注意事项
- 数据质量: 训练数据的质量直接影响模型的效果,需要进行清洗和标注。
- 负样本选择: 对比学习中,负样本的选择策略非常重要,需要根据实际情况进行调整。
- 超参数调优: 需要仔细调整模型训练的超参数,例如学习率、Batch Size、Epochs 等。
- 模型选择: 可以选择不同的预训练模型,例如 BERT、RoBERTa 等,进行尝试。
- Prompt Engineering: 除了优化 Embedding 模型,还可以通过优化 Prompt 来提升 RAG 系统的效果。
- 持续迭代: RAG 系统的优化是一个持续迭代的过程,需要不断收集数据、训练模型、评估效果,并进行改进。
- 冷启动问题: 如果企业数据量较少,可以考虑使用数据增强技术,或者迁移学习的方法。
- 模型监控: 需要对 RAG 系统进行监控,及时发现和解决问题。
总结:多阶段训练,提升RAG质量
我们探讨了如何通过多阶段 Embedding 训练,提升企业私有语料的 RAG 召回质量。从通用领域的微调,到领域知识增强,再到对比学习优化和查询优化,每个阶段都有其特定的目的和方法。通过合理的策略和持续的迭代,我们可以构建一个高效、准确的 RAG 系统,为企业提供更好的知识服务。