RAG 应用中 Embedding 模型升级引发整体召回波动的工程化回滚机制
大家好,今天我们来深入探讨一个在实际 RAG (Retrieval Augmented Generation) 应用中经常遇到的问题:Embedding 模型升级后,可能引发整体召回波动,以及如何设计一套工程化的回滚机制来应对这种风险。
Embedding 模型在 RAG 应用中扮演着至关重要的角色,它负责将文本数据转化为向量表示,从而实现语义层面的相似度搜索。升级 Embedding 模型通常是为了提升向量的表达能力,进而提高召回的准确性和相关性。然而,在实际操作中,新模型可能会改变向量空间的分布,导致与原有索引的兼容性问题,最终造成召回结果的质量下降。
问题根源:向量空间偏移
Embedding 模型升级导致召回波动的根本原因在于 向量空间偏移。不同的 Embedding 模型,即使训练数据相似,其输出的向量在空间中的分布也可能存在显著差异。这种差异体现在以下几个方面:
- 向量维度: 新旧模型的向量维度可能不同。
- 向量尺度: 新旧模型的向量长度范围可能不同。
- 向量方向: 语义相似的文本,在新旧模型中对应的向量方向可能不同。
这些差异会导致使用旧模型构建的索引,在查询时,无法准确匹配新模型生成的查询向量,从而影响召回结果。
工程化回滚机制的设计原则
为了应对 Embedding 模型升级带来的风险,我们需要设计一套完善的回滚机制。该机制应遵循以下原则:
- 可观测性: 能够实时监控召回效果,及时发现波动。
- 可控性: 能够快速切换回旧模型,恢复原有召回效果。
- 自动化: 尽可能减少人工干预,降低操作风险。
- 兼容性: 确保新旧模型能够共存,平滑过渡。
回滚机制的实现方案
以下是一个基于 Python 的示例,展示了如何实现一个简单的回滚机制。该方案主要包含以下几个模块:
- 模型管理模块: 负责加载和管理 Embedding 模型。
- 索引管理模块: 负责构建和管理向量索引。
- 监控模块: 负责监控召回效果。
- 回滚模块: 负责切换回旧模型。
1. 模型管理模块
import torch
from transformers import AutoModel, AutoTokenizer
class EmbeddingModelManager:
def __init__(self, model_name_or_path):
self.model_name_or_path = model_name_or_path
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
self.model = AutoModel.from_pretrained(self.model_name_or_path)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
self.model.eval()
def encode(self, text):
encoded_input = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt').to(self.device)
with torch.no_grad():
model_output = self.model(**encoded_input)
# Perform pooling. In this case, mean pooling.
token_embeddings = model_output[0] #First element of model_output contains all token embeddings
input_mask_expanded = encoded_input['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float()
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
mean_pooled = sum_embeddings / sum_mask
return mean_pooled.cpu().numpy()[0]
def get_model_info(self):
return {
"model_name": self.model_name_or_path,
"embedding_dimension": self.model.config.hidden_size
}
# Example usage:
# model_manager = EmbeddingModelManager("sentence-transformers/all-mpnet-base-v2")
# embedding = model_manager.encode("This is an example sentence.")
# model_info = model_manager.get_model_info()
# print(f"Embedding dimension: {model_info['embedding_dimension']}")
这个模块负责加载指定的 Embedding 模型,并提供 encode 方法用于生成文本向量。get_model_info 方法返回模型的相关信息,例如模型名称和向量维度。我们使用 sentence-transformers 库来简化模型加载和向量生成的流程。
2. 索引管理模块
import faiss
import numpy as np
class IndexManager:
def __init__(self, embedding_dimension, index_path=None):
self.embedding_dimension = embedding_dimension
self.index = faiss.IndexFlatIP(embedding_dimension) # 使用内积作为相似度度量
self.index_path = index_path
if index_path and os.path.exists(index_path):
self.load_index(index_path)
def add_embeddings(self, embeddings):
embeddings = np.array(embeddings).astype('float32')
self.index.add(embeddings)
def search(self, query_embedding, top_k=10):
query_embedding = np.array([query_embedding]).astype('float32')
D, I = self.index.search(query_embedding, top_k) # D: distances, I: indices
return D[0].tolist(), I[0].tolist()
def save_index(self, index_path):
faiss.write_index(self.index, index_path)
def load_index(self, index_path):
self.index = faiss.read_index(index_path)
self.index_path = index_path
def get_index_info(self):
return {
"index_type": type(self.index).__name__,
"indexed_vectors": self.index.ntotal
}
# Example usage:
# index_manager = IndexManager(embedding_dimension=768)
# index_manager.add_embeddings([embedding1, embedding2, embedding3])
# distances, indices = index_manager.search(query_embedding, top_k=5)
# index_manager.save_index("my_index.faiss")
这个模块负责构建和管理向量索引。我们使用 faiss 库来实现高效的相似度搜索。add_embeddings 方法用于将文本向量添加到索引中,search 方法用于执行相似度搜索,save_index 和 load_index 方法用于保存和加载索引。
3. 监控模块
import time
class Monitor:
def __init__(self, index_manager, model_manager, ground_truth_data):
self.index_manager = index_manager
self.model_manager = model_manager
self.ground_truth_data = ground_truth_data # 格式: {query: [relevant_doc_ids]}
def evaluate(self, top_k=10):
correct_retrievals = 0
total_queries = 0
for query, relevant_doc_ids in self.ground_truth_data.items():
query_embedding = self.model_manager.encode(query)
_, retrieved_doc_ids = self.index_manager.search(query_embedding, top_k=top_k)
retrieved_doc_ids = set(retrieved_doc_ids)
relevant_doc_ids = set(relevant_doc_ids)
if relevant_doc_ids.intersection(retrieved_doc_ids):
correct_retrievals += 1
total_queries += 1
recall = correct_retrievals / total_queries if total_queries > 0 else 0
return recall
def monitor_recall(self, interval=60, duration=3600, threshold=0.9):
start_time = time.time()
while time.time() - start_time < duration:
recall = self.evaluate()
print(f"Time: {time.strftime('%Y-%m-%d %H:%M:%S')}, Recall: {recall:.4f}")
if recall < threshold:
print("Recall dropped below threshold!")
return False # Indicate a problem
time.sleep(interval)
return True # Indicate success
# Example usage:
# ground_truth_data = {
# "query1": [0, 1, 2],
# "query2": [3, 4, 5],
# "query3": [6, 7, 8]
# }
# monitor = Monitor(index_manager, model_manager, ground_truth_data)
# success = monitor.monitor_recall(interval=60, duration=300, threshold=0.8)
# if not success:
# # Trigger rollback mechanism
# print("Rollback triggered!")
这个模块负责监控召回效果。evaluate 方法计算召回率,monitor_recall 方法定期评估召回率,并在召回率低于阈值时发出警报。 我们使用一个 ground_truth_data 字典来存储查询和相关文档的对应关系。
4. 回滚模块
import os
import shutil
class RollbackManager:
def __init__(self, model_manager, index_manager, old_model_name, old_index_path, new_model_name, new_index_path):
self.model_manager = model_manager
self.index_manager = index_manager
self.old_model_name = old_model_name
self.old_index_path = old_index_path
self.new_model_name = new_model_name
self.new_index_path = new_index_path
def rollback(self):
print("Starting rollback...")
# 1. Switch Model
print(f"Switching from model {self.new_model_name} to {self.old_model_name}")
self.model_manager.model = AutoModel.from_pretrained(self.old_model_name)
self.model_manager.tokenizer = AutoTokenizer.from_pretrained(self.old_model_name)
self.model_manager.model.to(self.model_manager.device)
self.model_manager.model.eval()
# 2. Switch Index
print(f"Switching from index {self.new_index_path} to {self.old_index_path}")
self.index_manager.load_index(self.old_index_path)
print("Rollback completed.")
# Example Usage:
# rollback_manager = RollbackManager(
# model_manager=model_manager,
# index_manager=index_manager,
# old_model_name="sentence-transformers/all-mpnet-base-v2",
# old_index_path="old_index.faiss",
# new_model_name="sentence-transformers/all-MiniLM-L6-v2",
# new_index_path="my_index.faiss"
# )
# rollback_manager.rollback()
这个模块负责执行回滚操作。rollback 方法将模型和索引切换回旧版本。在执行回滚之前,我们需要确保旧模型和旧索引已经保存好。
完整的示例代码
import torch
from transformers import AutoModel, AutoTokenizer
import faiss
import numpy as np
import time
import os
import shutil
# 模型管理模块
class EmbeddingModelManager:
def __init__(self, model_name_or_path):
self.model_name_or_path = model_name_or_path
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
self.model = AutoModel.from_pretrained(self.model_name_or_path)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
self.model.eval()
def encode(self, text):
encoded_input = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt').to(self.device)
with torch.no_grad():
model_output = self.model(**encoded_input)
# Perform pooling. In this case, mean pooling.
token_embeddings = model_output[0] #First element of model_output contains all token embeddings
input_mask_expanded = encoded_input['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float()
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
mean_pooled = sum_embeddings / sum_mask
return mean_pooled.cpu().numpy()[0]
def get_model_info(self):
return {
"model_name": self.model_name_or_path,
"embedding_dimension": self.model.config.hidden_size
}
# 索引管理模块
class IndexManager:
def __init__(self, embedding_dimension, index_path=None):
self.embedding_dimension = embedding_dimension
self.index = faiss.IndexFlatIP(embedding_dimension) # 使用内积作为相似度度量
self.index_path = index_path
if index_path and os.path.exists(index_path):
self.load_index(index_path)
def add_embeddings(self, embeddings):
embeddings = np.array(embeddings).astype('float32')
self.index.add(embeddings)
def search(self, query_embedding, top_k=10):
query_embedding = np.array([query_embedding]).astype('float32')
D, I = self.index.search(query_embedding, top_k) # D: distances, I: indices
return D[0].tolist(), I[0].tolist()
def save_index(self, index_path):
faiss.write_index(self.index, index_path)
def load_index(self, index_path):
self.index = faiss.read_index(index_path)
self.index_path = index_path
def get_index_info(self):
return {
"index_type": type(type(self.index).__name__),
"indexed_vectors": self.index.ntotal
}
# 监控模块
class Monitor:
def __init__(self, index_manager, model_manager, ground_truth_data):
self.index_manager = index_manager
self.model_manager = model_manager
self.ground_truth_data = ground_truth_data # 格式: {query: [relevant_doc_ids]}
def evaluate(self, top_k=10):
correct_retrievals = 0
total_queries = 0
for query, relevant_doc_ids in self.ground_truth_data.items():
query_embedding = self.model_manager.encode(query)
_, retrieved_doc_ids = self.index_manager.search(query_embedding, top_k=top_k)
retrieved_doc_ids = set(retrieved_doc_ids)
relevant_doc_ids = set(relevant_doc_ids)
if relevant_doc_ids.intersection(retrieved_doc_ids):
correct_retrievals += 1
total_queries += 1
recall = correct_retrievals / total_queries if total_queries > 0 else 0
return recall
def monitor_recall(self, interval=60, duration=3600, threshold=0.9):
start_time = time.time()
while time.time() - start_time < duration:
recall = self.evaluate()
print(f"Time: {time.strftime('%Y-%m-%d %H:%M:%S')}, Recall: {recall:.4f}")
if recall < threshold:
print("Recall dropped below threshold!")
return False # Indicate a problem
time.sleep(interval)
return True # Indicate success
# 回滚模块
class RollbackManager:
def __init__(self, model_manager, index_manager, old_model_name, old_index_path, new_model_name, new_index_path):
self.model_manager = model_manager
self.index_manager = index_manager
self.old_model_name = old_model_name
self.old_index_path = old_index_path
self.new_model_name = new_model_name
self.new_index_path = new_index_path
def rollback(self):
print("Starting rollback...")
# 1. Switch Model
print(f"Switching from model {self.new_model_name} to {self.old_model_name}")
self.model_manager.model = AutoModel.from_pretrained(self.old_model_name)
self.model_manager.tokenizer = AutoTokenizer.from_pretrained(self.old_model_name)
self.model_manager.model.to(self.model_manager.device)
self.model_manager.model.eval()
# 2. Switch Index
print(f"Switching from index {self.new_index_path} to {self.old_index_path}")
self.index_manager.load_index(self.old_index_path)
print("Rollback completed.")
if __name__ == '__main__':
# 1. 初始化旧模型和索引
old_model_name = "sentence-transformers/all-mpnet-base-v2"
old_index_path = "old_index.faiss"
old_model_manager = EmbeddingModelManager(old_model_name)
old_index_manager = IndexManager(embedding_dimension=old_model_manager.get_model_info()["embedding_dimension"])
# 2. 创建一些示例数据并构建旧索引
documents = [
"This is the first document about cats.",
"This is the second document about dogs.",
"This is the third document about birds.",
"The fourth document talks about cats and dogs.",
"The fifth document is about different types of birds.",
"The sixth document is a general text."
]
embeddings = [old_model_manager.encode(doc) for doc in documents]
old_index_manager.add_embeddings(embeddings)
old_index_manager.save_index(old_index_path)
# 3. 初始化新模型和索引
new_model_name = "sentence-transformers/all-MiniLM-L6-v2"
new_index_path = "new_index.faiss"
new_model_manager = EmbeddingModelManager(new_model_name)
new_index_manager = IndexManager(embedding_dimension=new_model_manager.get_model_info()["embedding_dimension"])
# 4. 使用新模型重新索引数据
new_embeddings = [new_model_manager.encode(doc) for doc in documents]
new_index_manager.add_embeddings(new_embeddings)
new_index_manager.save_index(new_index_path)
# 5. 定义 ground truth 数据
ground_truth_data = {
"what is about cats": [0, 3],
"tell me about dogs": [1, 3],
"something about birds": [2, 4]
}
# 6. 初始化监控模块
monitor = Monitor(new_index_manager, new_model_manager, ground_truth_data)
# 7. 模拟监控过程
print("Starting monitoring with the new model...")
success = monitor.monitor_recall(interval=5, duration=20, threshold=0.7) # 缩短监控时间
# 8. 如果召回率下降,则触发回滚
if not success:
print("Rollback triggered!")
rollback_manager = RollbackManager(
model_manager=new_model_manager, # 注意:这里传入的是 new_model_manager,因为回滚需要修改它
index_manager=new_index_manager, # 注意:这里传入的是 new_index_manager,因为回滚需要修改它
old_model_name=old_model_name,
old_index_path=old_index_path,
new_model_name=new_model_name,
new_index_path=new_index_path
)
rollback_manager.rollback()
# 9. 回滚后,使用旧模型再次进行评估
monitor = Monitor(new_index_manager, new_model_manager, ground_truth_data)
print("Evaluating after rollback...")
recall_after_rollback = monitor.evaluate()
print(f"Recall after rollback: {recall_after_rollback:.4f}")
这个示例代码演示了如何使用上述模块来实现一个简单的回滚机制。请注意,这只是一个示例,实际应用中需要根据具体情况进行调整和优化。
其他考虑因素
- AB 测试: 在正式升级 Embedding 模型之前,可以先进行 AB 测试,对比新旧模型的效果。
- 灰度发布: 逐步将新模型应用到一部分用户,观察效果后再全面推广。
- 数据版本控制: 对用于构建索引的数据进行版本控制,以便在回滚时能够恢复到旧版本的数据。
- 自动化部署: 使用自动化部署工具来简化模型升级和回滚的流程。
- 监控指标: 除了召回率之外,还可以监控其他指标,例如准确率、F1 值等。
表格总结
| 模块 | 功能 | 实现方式 |
|---|---|---|
| 模型管理 | 加载和管理 Embedding 模型 | 使用 transformers 库加载预训练模型,提供 encode 方法生成文本向量。 |
| 索引管理 | 构建和管理向量索引 | 使用 faiss 库构建向量索引,提供 add_embeddings 方法添加向量,search 方法执行相似度搜索。 |
| 监控 | 监控召回效果 | 定义 ground truth 数据,计算召回率,定期评估召回率,并在召回率低于阈值时发出警报。 |
| 回滚 | 切换回旧模型 | 将模型和索引切换回旧版本。 |
| 其他 | AB 测试,灰度发布,数据版本控制,自动化部署 | AB 测试对比新旧模型效果,灰度发布逐步推广新模型,数据版本控制确保回滚时能够恢复旧版本数据,自动化部署简化模型升级和回滚流程。 |
应对模型升级风险,保障 RAG 系统稳定
Embedding 模型升级是 RAG 应用持续优化的重要手段,但同时也伴随着一定的风险。通过建立完善的工程化回滚机制,我们可以有效地应对这些风险,确保 RAG 系统的稳定性和可靠性。