在训练平台中使用 DAG 编排管理 RAG 模型训练与评估复杂流程
大家好,今天我将为大家讲解如何利用 DAG (Directed Acyclic Graph,有向无环图) 编排工具,在训练平台上高效地管理和自动化 RAG (Retrieval-Augmented Generation,检索增强生成) 模型的训练与评估流程。RAG 模型的训练和评估涉及多个步骤,包括数据预处理、索引构建、模型训练、评估指标计算等。这些步骤之间存在复杂的依赖关系,手动管理容易出错且效率低下。DAG 编排可以帮助我们清晰地定义这些依赖关系,并自动化执行整个流程。
一、RAG 模型训练与评估流程概述
在深入 DAG 编排之前,我们先来回顾一下 RAG 模型的典型训练与评估流程。
- 数据准备与预处理:
- 数据收集: 收集用于训练和评估的文档数据。这些数据可以是文本文件、网页内容、数据库记录等。
- 文本清洗: 去除 HTML 标签、特殊字符、停用词等,并将文本转换为小写。
- 文本分割: 将长文本分割成较小的段落或句子,以便更好地进行检索。
- 知识库构建 (索引构建):
- 文本嵌入: 使用预训练的语言模型 (例如,Sentence Transformers) 将文本段落转换为向量表示。
- 索引创建: 将文本嵌入向量存储到向量数据库 (例如,FAISS, Milvus, Pinecone) 中,以便进行快速相似性搜索。
- 模型训练 (可选):
- 微调语言模型: 如果需要,可以使用 RAG 数据集微调预训练的语言模型,以提高生成质量。
- 训练检索器: 训练或微调检索器模型,以提高检索的准确性。
- 模型评估:
- 生成答案: 对于给定的问题,使用 RAG 模型生成答案。
- 评估指标计算: 使用各种评估指标 (例如,ROUGE, BLEU, 准确率, 召回率, 上下文相关性) 评估生成答案的质量。
二、DAG 编排工具的选择与技术栈
目前有很多优秀的 DAG 编排工具可供选择,例如 Apache Airflow, Argo Workflows, Prefect 等。选择合适的工具取决于你的具体需求和技术栈。这里我们以 Apache Airflow 为例进行讲解,因为它是一个非常流行的开源平台,具有强大的功能和灵活的扩展性。
我们的技术栈包括:
- 编程语言: Python
- DAG 编排工具: Apache Airflow
- 向量数据库: FAISS (或者其他如 Milvus, Pinecone)
- 语言模型: Hugging Face Transformers (用于文本嵌入和生成)
- 评估指标: ROUGE, BLEU, 自定义指标
三、使用 Airflow 定义 RAG 模型训练与评估 DAG
下面我们将使用 Airflow 来定义一个 RAG 模型训练与评估的 DAG。
1. 安装 Airflow:
pip install apache-airflow
2. 初始化 Airflow:
airflow db init
3. 启动 Airflow Web 服务器和调度器:
airflow webserver -p 8080 # 默认端口是 8080
airflow scheduler
4. 创建 DAG 文件 (例如,rag_pipeline.py):
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.operators.bash import BashOperator
from datetime import datetime
import os
import faiss
import numpy as np
from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import train_test_split
# 定义一些常量
DATA_DIR = 'data'
INDEX_PATH = 'index.faiss'
MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'
EVAL_FILE = 'eval_data.csv' #包含 question, ground_truth 的 csv 文件
# 确保数据目录存在
os.makedirs(DATA_DIR, exist_ok=True)
# 函数:数据准备与预处理
def prepare_data(**kwargs):
"""
模拟数据准备和预处理。
实际应用中,这里会读取原始数据,进行清洗、分割等操作。
"""
print("Preparing data...")
# 创建一些虚拟数据
data = [
"Airflow is a platform to programmatically author, schedule and monitor workflows.",
"Airflow is open source and community supported.",
"RAG combines retrieval and generation for improved results.",
"FAISS is a library for efficient similarity search.",
"Transformers are powerful models for NLP tasks."
]
# 保存到文件 (模拟)
with open(os.path.join(DATA_DIR, 'data.txt'), 'w') as f:
for item in data:
f.write(item + 'n')
eval_data = [("What is Airflow?", "Airflow is a platform for workflows."),
("What does RAG combine?", "Retrieval and generation."),
("What is FAISS used for?", "Efficient similarity search.")]
import csv
with open(os.path.join(DATA_DIR, EVAL_FILE), 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(['question', 'ground_truth']) # header
for q, a in eval_data:
writer.writerow([q, a])
print("Data preparation complete.")
return 'data_prepared' # 用于 XCom
# 函数:构建知识库 (索引构建)
def build_index(**kwargs):
"""
构建知识库 (索引构建)。
使用 FAISS 存储文本嵌入向量。
"""
print("Building index...")
# 加载数据
with open(os.path.join(DATA_DIR, 'data.txt'), 'r') as f:
data = [line.strip() for line in f]
# 加载预训练模型和 tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
# 计算文本嵌入向量
embeddings = []
for text in data:
inputs = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
outputs = model(**inputs)
# 使用平均池化获得句子的向量表示
embedding = outputs.last_hidden_state.mean(dim=1).detach().numpy()
embeddings.append(embedding)
embeddings = np.concatenate(embeddings, axis=0)
# 创建 FAISS 索引
dimension = embeddings.shape[1] # 向量维度
index = faiss.IndexFlatL2(dimension) # 使用 L2 距离
index.add(embeddings)
# 保存索引
faiss.write_index(index, INDEX_PATH)
print("Index building complete.")
return 'index_built'
# 函数:RAG 模型评估
def evaluate_rag_model(**kwargs):
"""
RAG 模型评估。
使用 ROUGE 和 BLEU 指标评估生成答案的质量。
"""
print("Evaluating RAG model...")
# 加载索引
index = faiss.read_index(INDEX_PATH)
# 加载预训练模型和 tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
# 加载评估数据
import csv
eval_data = []
with open(os.path.join(DATA_DIR, EVAL_FILE), 'r') as csvfile:
reader = csv.reader(csvfile)
next(reader) # 跳过 header
for row in reader:
eval_data.append(row)
from rouge import Rouge
from nltk.translate.bleu_score import sentence_bleu
rouge = Rouge()
total_rouge_score = {'rouge-1': {'f': 0, 'p': 0, 'r': 0},
'rouge-2': {'f': 0, 'p': 0, 'r': 0},
'rouge-l': {'f': 0, 'p': 0, 'r': 0}}
total_bleu_score = 0.0
num_eval_samples = len(eval_data)
for question, ground_truth in eval_data:
# 1. 检索相关文档
question_embedding = get_embedding(question, tokenizer, model)
k = 3 # 检索 top k 个文档
D, I = index.search(question_embedding, k) # D: 距离, I: 索引
# 假设我们从索引中检索到了相关文档,这里只是模拟
retrieved_documents = [
"Airflow is a platform to programmatically author, schedule and monitor workflows.",
"Airflow is open source and community supported.",
"RAG combines retrieval and generation for improved results."
] # 假设 retrieved_documents 是从索引中检索到的
# 2. 生成答案 (这里只是模拟,实际应用中需要使用语言模型生成答案)
generated_answer = f"The answer to '{question}' is related to: " + " ".join(retrieved_documents) # 简化生成过程
# 3. 计算 ROUGE 指标
try:
scores = rouge.get_scores(generated_answer, ground_truth)[0]
for rouge_type in total_rouge_score:
total_rouge_score[rouge_type]['f'] += scores[rouge_type]['f']
total_rouge_score[rouge_type]['p'] += scores[rouge_type]['p']
total_rouge_score[rouge_type]['r'] += scores[rouge_type]['r']
except ValueError as e:
print(f"ROUGE ValueError: {e}, skipping this sample.")
num_eval_samples -= 1 # 减少样本数,防止除以0
# 4. 计算 BLEU 指标
reference = ground_truth.split()
candidate = generated_answer.split()
total_bleu_score += sentence_bleu([reference], candidate)
# 计算平均 ROUGE 和 BLEU 指标
if num_eval_samples > 0: # 防止除以 0
for rouge_type in total_rouge_score:
total_rouge_score[rouge_type]['f'] /= num_eval_samples
total_rouge_score[rouge_type]['p'] /= num_eval_samples
total_rouge_score[rouge_type]['r'] /= num_eval_samples
average_bleu_score = total_bleu_score / num_eval_samples
else:
average_bleu_score = 0 # 如果没有有效的评估样本,设置为0
print("No valid evaluation samples found.")
print("ROUGE Score:", total_rouge_score)
print("BLEU Score:", average_bleu_score)
return {'rouge': total_rouge_score, 'bleu': average_bleu_score}
def get_embedding(text, tokenizer, model):
"""
计算文本的嵌入向量。
"""
inputs = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
outputs = model(**inputs)
return outputs.last_hidden_state.mean(dim=1).detach().numpy()
# 定义 DAG
with DAG(
dag_id='rag_pipeline',
schedule=None, # 手动触发
start_date=datetime(2023, 1, 1),
catchup=False,
tags=['rag'],
) as dag:
# 任务:数据准备
prepare_data_task = PythonOperator(
task_id='prepare_data',
python_callable=prepare_data,
provide_context=True, # 允许访问 Airflow 上下文
)
# 任务:构建索引
build_index_task = PythonOperator(
task_id='build_index',
python_callable=build_index,
provide_context=True,
)
# 任务:评估 RAG 模型
evaluate_rag_task = PythonOperator(
task_id='evaluate_rag',
python_callable=evaluate_rag_model,
provide_context=True,
)
# 定义任务依赖关系
prepare_data_task >> build_index_task >> evaluate_rag_task
代码解释:
prepare_data: 模拟数据准备,创建data.txt和eval_data.csv文件。实际应用中,需要从数据源读取数据并进行预处理。build_index: 构建知识库,使用sentence-transformers/all-mpnet-base-v2模型计算文本嵌入向量,并使用 FAISS 存储索引。evaluate_rag_model: 评估 RAG 模型,加载索引,对于评估数据集中的每个问题,检索相关文档并生成答案,然后计算 ROUGE 和 BLEU 指标。这里为了简化,直接将检索到的文本拼接起来作为生成的答案。实际应用中,需要使用语言模型生成答案。DAG 定义: 使用DAG上下文管理器定义 DAG,设置dag_id,schedule,start_date,catchup,tags等参数。PythonOperator: 使用PythonOperator定义 Python 任务,指定task_id和python_callable。任务依赖关系: 使用>>运算符定义任务依赖关系,确保任务按照正确的顺序执行。
5. 将 DAG 文件上传到 Airflow 的 dags 目录:
将 rag_pipeline.py 文件复制到 Airflow 的 dags 目录。 默认情况下,这个目录位于 ~/airflow/dags。 可以通过 Airflow 的配置文件 airflow.cfg 来修改 dags 目录。
6. 启动 Airflow Web 服务器和调度器:
确保 Airflow Web 服务器和调度器正在运行。
7. 在 Airflow Web 界面中触发 DAG:
在 Airflow Web 界面中,找到 rag_pipeline DAG,点击 "Trigger DAG" 按钮手动触发 DAG 运行。
四、扩展 DAG 以支持模型训练
上述 DAG 只包含了数据准备、索引构建和评估,并没有包含模型训练。如果需要微调语言模型或训练检索器,可以扩展 DAG 以包含相应的任务。
例如,可以添加一个 train_model 任务:
# 函数:训练语言模型
def train_model(**kwargs):
"""
微调语言模型。
使用 RAG 数据集微调预训练的语言模型,以提高生成质量。
"""
print("Training model...")
# TODO: 实现模型训练逻辑
print("Model training complete.")
return 'model_trained'
然后,将 train_model_task 添加到 DAG 中,并修改任务依赖关系:
# 任务:训练模型
train_model_task = PythonOperator(
task_id='train_model',
python_callable=train_model,
provide_context=True,
)
# 定义任务依赖关系
prepare_data_task >> build_index_task >> train_model_task >> evaluate_rag_task
五、使用 XCom 传递任务间的数据
Airflow 使用 XCom (Cross-Communication) 机制在任务之间传递数据。在上面的例子中,我们使用了 provide_context=True 允许任务访问 Airflow 上下文,然后使用 kwargs['ti'].xcom_push() 和 kwargs['ti'].xcom_pull() 方法来推送和拉取数据。
例如,可以在 prepare_data 任务中将预处理后的数据推送到 XCom:
def prepare_data(**kwargs):
# ... (数据准备逻辑) ...
kwargs['ti'].xcom_push(key='preprocessed_data', value=data)
return 'data_prepared'
然后在 build_index 任务中从 XCom 拉取数据:
def build_index(**kwargs):
data = kwargs['ti'].xcom_pull(task_ids='prepare_data', key='preprocessed_data')
# ... (使用数据构建索引) ...
return 'index_built'
六、使用 BashOperator 执行外部脚本
除了 PythonOperator,Airflow 还提供了 BashOperator,可以用来执行 Bash 命令或外部脚本。这对于调用一些非 Python 的工具或执行一些系统级别的操作非常有用。
例如,可以使用 BashOperator 来执行一个用于数据清洗的 Shell 脚本:
from airflow.operators.bash import BashOperator
clean_data_task = BashOperator(
task_id='clean_data',
bash_command='sh /path/to/clean_data.sh',
)
prepare_data_task >> clean_data_task >> build_index_task
七、监控与告警
Airflow 提供了丰富的监控功能,可以实时查看 DAG 的运行状态、任务的执行日志等。同时,还可以配置告警规则,当 DAG 运行失败或任务执行超时时,自动发送邮件或短信通知。
八、总结
本文详细介绍了如何使用 Apache Airflow 编排 RAG 模型的训练与评估流程。通过 DAG 编排,可以清晰地定义任务之间的依赖关系,自动化执行整个流程,提高效率并减少出错的可能性。我们涵盖了数据准备、索引构建、模型训练、评估等关键步骤,并提供了详细的代码示例。利用 Airflow 的强大功能,我们可以轻松地构建、管理和监控复杂的 RAG 模型训练与评估流水线。
九、流程要点回顾
回顾本文,我们首先确定了 RAG 模型的训练与评估过程,然后选择了合适的 DAG 编排工具,最后详细地展示了如何使用 Airflow 构建 RAG 模型训练与评估 DAG,涵盖了数据准备、索引构建、模型训练和评估等关键步骤,并提供了详细的代码示例。
十、进一步思考
实际应用中,RAG 模型的训练与评估流程会更加复杂,例如,需要处理大规模数据、使用分布式训练、进行超参数调优等。可以根据实际需求,进一步扩展 DAG,利用 Airflow 的强大功能,构建更加完善的 RAG 模型训练与评估流水线。