Slurm 集群中的 LLM 调度:处理节点故障与自动断点续训的脚本工程
大家好!今天我们来深入探讨一个在大型语言模型(LLM)训练中至关重要的话题:在 Slurm 集群环境下,如何应对节点故障并实现自动断点续训。LLM 的训练通常需要耗费大量的计算资源和时间,任何意外中断都可能导致巨大的损失。因此,一个健壮的调度系统和一套有效的故障恢复机制是必不可少的。
1. 背景:LLM 训练的挑战与 Slurm 的作用
LLM 的训练面临着诸多挑战:
- 资源需求巨大: 需要大量的 GPU 资源、内存和存储空间。
- 训练时间长: 通常需要数周甚至数月才能完成一次训练。
- 分布式训练复杂: 需要高效的数据并行和模型并行策略。
- 容错性要求高: 节点故障可能导致训练中断,浪费大量资源。
Slurm 作为一款流行的集群资源管理器,提供了强大的作业调度、资源分配和监控功能。它可以帮助我们有效地管理集群资源,并为 LLM 训练提供稳定可靠的运行环境。
2. 节点故障检测与处理策略
节点故障是分布式训练中不可避免的问题。我们需要一套机制来及时检测故障并采取相应的处理措施。
2.1 节点故障检测
Slurm 提供了多种方式来检测节点状态:
- sinfo 命令: 可以查看集群中所有节点的状态,包括
IDLE,ALLOCATED,DOWN,DRAIN等。 - scontrol 命令: 可以获取更详细的节点信息,包括节点上的作业列表、CPU 利用率、内存使用情况等。
- Slurm 监控工具: 可以实时监控节点状态,并在节点发生故障时发送告警。
我们可以编写一个脚本,定期检查节点状态,并根据状态采取相应的处理措施。例如:
#!/usr/bin/env python3
import subprocess
import time
def get_node_status():
"""获取所有节点的状态"""
try:
result = subprocess.run(['sinfo', '-Nh', '-o', '%N %T'], capture_output=True, text=True, check=True)
node_statuses = {}
for line in result.stdout.strip().split('n'):
node, status = line.split()
node_statuses[node] = status
return node_statuses
except subprocess.CalledProcessError as e:
print(f"Error running sinfo: {e}")
return None
def handle_node_failure(node):
"""处理节点故障,这里可以添加你的故障处理逻辑"""
print(f"Node {node} is DOWN. Implement your failure handling logic here.")
#例如:可以尝试重新启动节点,或者将该节点上的作业迁移到其他节点
#subprocess.run(['scontrol', 'reboot', node], capture_output=True, text=True) # 需要管理员权限
#subprocess.run(['scontrol', 'release', node], capture_output=True, text=True) # 释放节点资源
if __name__ == "__main__":
while True:
node_statuses = get_node_status()
if node_statuses:
for node, status in node_statuses.items():
if status == 'DOWN' or status == 'DRAIN':
handle_node_failure(node)
time.sleep(60) # 每隔 60 秒检查一次
2.2 故障处理策略
当检测到节点故障时,我们需要采取相应的处理策略,以尽可能减少训练中断带来的损失。常见的策略包括:
- 节点重启: 尝试重启故障节点,看是否能够恢复。
- 作业迁移: 将故障节点上的作业迁移到其他可用节点。这需要你的训练框架支持作业迁移功能。
- 断点续训: 从最近的检查点(Checkpoint)恢复训练。这是最常用的方法,也是我们接下来重点讨论的内容。
- 资源回收: 将故障节点从集群中移除,避免影响其他作业的运行。
选择哪种策略取决于具体的应用场景和训练框架。
3. 自动断点续训的实现
自动断点续训是 LLM 训练中最重要的容错机制。它允许我们在训练中断后,从最近的检查点恢复训练,避免从头开始。
3.1 检查点(Checkpoint)机制
检查点机制是指在训练过程中,定期将模型的状态(包括模型参数、优化器状态、学习率等)保存到磁盘上。这些保存的状态就是检查点。
大多数深度学习框架(如 PyTorch, TensorFlow, DeepSpeed)都提供了检查点机制。我们需要在训练代码中添加相应的逻辑,定期保存检查点。
例如,在 PyTorch 中,可以使用 torch.save() 函数保存检查点:
import torch
# 训练循环
for epoch in range(start_epoch, num_epochs):
for i, (inputs, labels) in enumerate(train_loader):
# 前向传播、计算损失、反向传播、更新参数
...
# 定期保存检查点
if (i + 1) % checkpoint_interval == 0:
checkpoint_path = f'checkpoint_epoch_{epoch}_step_{i+1}.pth'
torch.save({
'epoch': epoch,
'step': i+1,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss.item(),
# 其他需要保存的状态
}, checkpoint_path)
print(f'Checkpoint saved to {checkpoint_path}')
3.2 自动恢复逻辑
当训练中断后,我们需要一套自动恢复逻辑,从最近的检查点恢复训练。这需要在 Slurm 脚本中添加相应的逻辑。
一个典型的 Slurm 脚本如下:
#!/bin/bash
#SBATCH --job-name=llm_training
#SBATCH --nodes=8
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8
#SBATCH --time=72:00:00
#SBATCH --output=llm_training.log
#SBATCH --error=llm_training.err
# 定义训练脚本和检查点目录
TRAIN_SCRIPT="train.py"
CHECKPOINT_DIR="checkpoints"
# 创建检查点目录
mkdir -p $CHECKPOINT_DIR
# 查找最近的检查点
LATEST_CHECKPOINT=$(ls -t $CHECKPOINT_DIR/checkpoint_*.pth | head -n 1)
# 如果找到检查点,则从检查点恢复训练
if [ -n "$LATEST_CHECKPOINT" ]; then
echo "Resuming training from checkpoint: $LATEST_CHECKPOINT"
srun python $TRAIN_SCRIPT --checkpoint $LATEST_CHECKPOINT
else
echo "Starting training from scratch"
srun python $TRAIN_SCRIPT
fi
在 train.py 脚本中,需要添加加载检查点的逻辑:
import torch
import argparse
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', type=str, default=None, help='Path to checkpoint file')
args = parser.parse_args()
# 初始化模型、优化器等
model = ...
optimizer = ...
start_epoch = 0
start_step = 0
# 如果指定了检查点,则加载检查点
if args.checkpoint:
print(f'Loading checkpoint from {args.checkpoint}')
checkpoint = torch.load(args.checkpoint)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
start_step = checkpoint['step']
# 加载其他状态
# 训练循环
for epoch in range(start_epoch, num_epochs):
for i, (inputs, labels) in enumerate(train_loader):
if epoch == start_epoch and i < start_step:
continue # 从上次中断的地方开始
# 前向传播、计算损失、反向传播、更新参数
...
# 定期保存检查点
if (i + 1) % checkpoint_interval == 0:
checkpoint_path = f'checkpoints/checkpoint_epoch_{epoch}_step_{i+1}.pth'
torch.save({
'epoch': epoch,
'step': i+1,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss.item(),
# 其他需要保存的状态
}, checkpoint_path)
print(f'Checkpoint saved to {checkpoint_path}')
if __name__ == "__main__":
main()
3.3 检查点保存策略
检查点的保存策略对训练效率和容错性都有重要影响。常见的策略包括:
- 定期保存: 每隔一段时间或一定数量的迭代次数保存一次检查点。
- 差异保存: 只保存与上次检查点不同的部分,可以减少存储空间。
- 多副本保存: 将检查点保存到多个存储位置,提高可靠性。
选择哪种策略取决于具体的应用场景和存储资源。
3.4 检查点清理策略
随着训练的进行,会生成大量的检查点文件。我们需要一套清理策略,定期删除旧的检查点,释放存储空间。
一个简单的清理策略是只保留最近的 N 个检查点:
#!/bin/bash
# 保留最近的 5 个检查点
NUM_CHECKPOINTS_TO_KEEP=5
# 查找所有检查点文件
CHECKPOINT_FILES=$(ls -t checkpoints/checkpoint_*.pth)
# 如果检查点文件数量超过了要保留的数量,则删除旧的检查点
if [ $(echo "$CHECKPOINT_FILES" | wc -l) -gt $NUM_CHECKPOINTS_TO_KEEP ]; then
echo "Cleaning up old checkpoints"
OLD_CHECKPOINTS=$(echo "$CHECKPOINT_FILES" | tail -n +$((NUM_CHECKPOINTS_TO_KEEP+1)))
for CHECKPOINT in $OLD_CHECKPOINTS; do
rm $CHECKPOINT
echo "Deleted: $CHECKPOINT"
done
fi
可以将这个脚本添加到 Slurm 脚本中,定期执行。
4. 深度集成:结合 Slurm 的作业依赖与自动重提交
为了进一步提高系统的可靠性和自动化程度,我们可以将检查点机制、自动恢复逻辑与 Slurm 的作业依赖和自动重提交功能结合起来。
4.1 Slurm 作业依赖
Slurm 允许我们定义作业之间的依赖关系。例如,我们可以让一个作业只有在前一个作业成功完成后才能运行。
我们可以利用作业依赖,实现更复杂的恢复逻辑。例如,可以创建一个单独的作业,负责检查训练作业是否成功完成。如果训练作业失败,则该作业可以自动提交一个新的训练作业,从最近的检查点恢复训练。
4.2 Slurm 自动重提交
Slurm 提供了自动重提交功能,可以在作业失败后自动重新提交作业。我们可以将自动重提交与检查点机制结合起来,实现更强的容错性。
例如,我们可以设置 Slurm 在作业失败后自动重提交作业,并从最近的检查点恢复训练。如果作业仍然失败,则 Slurm 会再次重提交作业,直到达到最大重试次数。
一个结合作业依赖和自动重提交的 Slurm 脚本示例:
#!/bin/bash
#SBATCH --job-name=llm_training_wrapper
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --time=00:10:00
#SBATCH --output=llm_training_wrapper.log
#SBATCH --error=llm_training_wrapper.err
# 定义训练作业的 Slurm 脚本
TRAIN_SCRIPT="train_slurm.sh"
# 提交训练作业
JOB_ID=$(sbatch $TRAIN_SCRIPT)
# 提取作业 ID
JOB_ID=$(echo $JOB_ID | awk '{print $4}')
# 循环检查训练作业的状态
while true; do
# 获取作业状态
JOB_STATUS=$(squeue -j $JOB_ID -h -t)
# 如果作业不存在,则说明作业已完成(成功或失败)
if [ -z "$JOB_STATUS" ]; then
echo "Training job $JOB_ID finished."
# 检查训练是否成功,查看是否有训练完成标志文件,或者日志中是否有成功标志
if grep -q "Training completed successfully" llm_training.log; then
echo "Training completed successfully."
exit 0
else
echo "Training failed. Resubmitting job."
# 重新提交训练作业,依赖于当前包装器作业完成
JOB_ID=$(sbatch --dependency=afterany:$SLURM_JOB_ID $TRAIN_SCRIPT)
JOB_ID=$(echo $JOB_ID | awk '{print $4}') # 提取JOB ID
fi
fi
# 等待一段时间后再次检查
sleep 60
done
train_slurm.sh 是之前提到的包含检查点恢复逻辑的训练脚本。
4.3 完整性校验
在从检查点恢复训练之前,我们需要对检查点的完整性进行校验,以确保检查点文件没有损坏。
可以使用哈希算法(如 MD5, SHA256)对检查点文件进行校验。在保存检查点时,同时保存检查点文件的哈希值。在恢复训练时,计算检查点文件的哈希值,并与保存的哈希值进行比较。如果哈希值不一致,则说明检查点文件已损坏,不能用于恢复训练。
5. 总结:稳定LLM训练的关键要素
LLM 训练的稳定性和效率至关重要。本文探讨了在 Slurm 集群环境下,如何应对节点故障并实现自动断点续训,包括节点故障检测、故障处理策略、检查点机制、自动恢复逻辑、检查点保存与清理策略,以及与 Slurm 作业依赖和自动重提交功能的集成。通过这些技术手段,我们可以构建一个健壮的 LLM 训练系统,减少因节点故障导致的中断,提高训练效率。