Python中定制学习率调度器(Scheduler):基于余弦、多项式衰减的理论设计与实现

Python定制学习率调度器:基于余弦、多项式衰减的理论设计与实现

大家好,今天我们来深入探讨如何在Python中定制学习率调度器,重点关注余弦退火和多项式衰减这两种常用的学习率调整策略。学习率调度器在深度学习模型的训练过程中扮演着至关重要的角色,它能够根据训练的进度动态地调整学习率,从而帮助模型更快、更稳定地收敛,并最终达到更好的性能。

1. 学习率调度器的重要性

在深度学习中,学习率直接影响模型的收敛速度和最终性能。一个合适的学习率能够在训练初期快速下降,而在训练后期进行微调,从而避免震荡和陷入局部最小值。学习率调度器正是为了实现这种动态调整而设计的。

使用固定学习率的弊端:

  • 学习率过大: 可能导致训练不稳定,甚至无法收敛。
  • 学习率过小: 可能导致训练速度过慢,或者模型陷入局部最小值。

学习率调度器通过在训练过程中动态调整学习率,可以有效地解决这些问题。常见的学习率调度策略包括:

  • Step Decay: 每隔一定步数或epoch将学习率降低一个固定的比例。
  • Exponential Decay: 学习率按照指数函数衰减。
  • Cosine Annealing: 学习率按照余弦函数周期性地变化。
  • Polynomial Decay: 学习率按照多项式函数衰减。

2. 余弦退火 (Cosine Annealing) 学习率调度器

余弦退火是一种周期性的学习率调整策略,其核心思想是模拟余弦函数的变化规律,使学习率在训练过程中先缓慢下降,然后快速下降,最后又缓慢下降,如此循环往复。这种周期性的变化有助于模型跳出局部最小值,探索更广阔的参数空间。

2.1 余弦退火的数学原理

余弦退火的学习率计算公式如下:

lr = eta_min + (eta_max - eta_min) * (1 + cos(pi * current_step / T)) / 2

其中:

  • lr:当前的学习率。
  • eta_max:学习率的最大值(初始学习率)。
  • eta_min:学习率的最小值。
  • current_step:当前的训练步数。
  • T:一个周期(epoch)的长度。
  • cos():余弦函数。
  • pi:圆周率。

这个公式描述了学习率在一个周期内从eta_max下降到eta_min,然后再回到eta_max的过程。eta_maxeta_min定义了学习率的上下界,T定义了学习率变化的周期。

2.2 Python实现余弦退火调度器

下面是一个使用PyTorch实现的余弦退火调度器:

import torch
from torch.optim.lr_scheduler import _LRScheduler
import math

class CosineAnnealingWarmRestarts(_LRScheduler):
    """
    余弦退火调度器,带有Warm Restarts。
    """
    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1):
        """
        Args:
            optimizer (Optimizer): Wrapped optimizer.
            T_0 (int): 第一个周期的长度。
            T_mult (int): 每个周期长度的倍数。
            eta_min (float): 最小学习率。
            last_epoch (int): 上一个epoch的索引。
        """
        self.T_0 = T_0
        self.T_i = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min
        super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch)

        self.base_lrs = [group['lr'] for group in optimizer.param_groups]  # 记录初始学习率

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        return [self.eta_min + (base_lr - self.eta_min) *
                (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
                for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Step could be called after every batch update"""

        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.T_cur = self.T_cur - self.T_i
                self.T_i = self.T_i * self.T_mult
        else:
            if epoch < 0:
                raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                else:
                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch

        self.last_epoch = math.floor(epoch)

        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

        self._last_lr = self.get_lr() # 为了兼容PyTorch 2.0+

    def __repr__(self):
        return self.__class__.__name__ + ' (' + ',n'.join(
            ['{}: {}'.format(k, self.__dict__[k]) for k in self.__dict__.keys()]) + ')'

代码解释:

  • __init__():初始化函数,接收优化器、周期长度、周期倍数和最小学习率等参数。
  • get_lr():计算当前的学习率。
  • step():更新学习率,并根据周期长度调整T_curT_i

2.3 使用余弦退火调度器

# 假设你已经定义了模型和优化器
model = ...
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# 创建余弦退火调度器
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=0.001)

# 训练循环
for epoch in range(num_epochs):
    for i, (inputs, labels) in enumerate(dataloader):
        # 前向传播、计算损失、反向传播、更新参数
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 更新学习率
        scheduler.step()

    # 打印学习率
    print(f'Epoch [{epoch+1}/{num_epochs}], Learning Rate: {scheduler.get_last_lr()[0]}')

2.4 余弦退火的优点和缺点

优点:

  • 能够跳出局部最小值,探索更广阔的参数空间。
  • 在训练后期能够进行更精细的微调。
  • 适用于各种不同的模型和数据集。

缺点:

  • 需要仔细调整参数,特别是周期长度T
  • 可能会在训练初期造成学习率过小,导致训练速度较慢。

3. 多项式衰减 (Polynomial Decay) 学习率调度器

多项式衰减是一种随着训练进程逐渐降低学习率的策略。它使用一个多项式函数来控制学习率的衰减速度。

3.1 多项式衰减的数学原理

多项式衰减的学习率计算公式如下:

lr = (initial_lr - end_lr) * (1 - current_step / max_decay_steps) ^ power + end_lr

其中:

  • lr:当前的学习率。
  • initial_lr:初始学习率。
  • end_lr:最终学习率。
  • current_step:当前训练步数。
  • max_decay_steps:最大衰减步数(通常是总的训练步数)。
  • power:多项式衰减的指数。

这个公式描述了学习率从initial_lr逐渐降低到end_lr的过程。power控制了衰减的速度。当power越大时,衰减速度越慢。

3.2 Python实现多项式衰减调度器

下面是一个使用PyTorch实现的多项式衰减调度器:

import torch
from torch.optim.lr_scheduler import LambdaLR

class PolynomialLRDecay(LambdaLR):
    """
    多项式学习率衰减调度器。
    """
    def __init__(self, optimizer, max_decay_steps, initial_learning_rate, end_learning_rate=0.0, power=1.0, last_epoch=-1):
        """
        Args:
            optimizer (Optimizer): Wrapped optimizer.
            max_decay_steps (int): 最大衰减步数。
            initial_learning_rate (float): 初始学习率。
            end_learning_rate (float): 最终学习率。
            power (float): 多项式衰减的指数。
            last_epoch (int): 上一个epoch的索引。
        """
        self.max_decay_steps = max_decay_steps
        self.initial_learning_rate = initial_learning_rate
        self.end_learning_rate = end_learning_rate
        self.power = power
        super(PolynomialLRDecay, self).__init__(optimizer, self.lr_lambda, last_epoch)

    def lr_lambda(self, current_step):
        if current_step > self.max_decay_steps:
            return 1.0 # 保持最终学习率

        return (self.initial_learning_rate - self.end_learning_rate) * 
               ((1 - current_step / self.max_decay_steps) ** self.power) + 
               self.end_learning_rate / self.initial_learning_rate

代码解释:

  • __init__():初始化函数,接收优化器、最大衰减步数、初始学习率、最终学习率和多项式衰减的指数等参数。
  • lr_lambda():计算学习率的比例因子。这个函数会被LambdaLR调用,从而更新学习率。

3.3 使用多项式衰减调度器

# 假设你已经定义了模型和优化器
model = ...
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# 计算总的训练步数
total_steps = num_epochs * len(dataloader)

# 创建多项式衰减调度器
scheduler = PolynomialLRDecay(optimizer, max_decay_steps=total_steps, initial_learning_rate=0.01, end_learning_rate=0.001, power=2.0)

# 训练循环
for epoch in range(num_epochs):
    for i, (inputs, labels) in enumerate(dataloader):
        # 前向传播、计算损失、反向传播、更新参数
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 更新学习率
        scheduler.step()

    # 打印学习率
    print(f'Epoch [{epoch+1}/{num_epochs}], Learning Rate: {scheduler.get_last_lr()[0]}')

3.4 多项式衰减的优点和缺点

优点:

  • 能够平滑地降低学习率。
  • 可以通过调整power参数来控制衰减的速度。
  • 实现简单,易于使用。

缺点:

  • 需要预先知道总的训练步数。
  • 可能在训练后期学习率过小,导致训练速度较慢。

4. 两种调度器的比较

为了更清晰地了解余弦退火和多项式衰减的特点,我们将其优缺点总结如下表:

特性 余弦退火 多项式衰减
学习率变化 周期性变化,模拟余弦函数 单调递减,按照多项式函数
参数调整 主要调整周期长度T 主要调整powerend_lr
适用场景 适用于需要跳出局部最小值的场景 适用于需要平滑降低学习率的场景
优点 能够跳出局部最小值,在训练后期进行精细微调 平滑降低学习率,实现简单,易于使用
缺点 需要仔细调整参数,可能在训练初期学习率过小 需要预先知道总的训练步数,训练后期学习率可能过小

5. 其他定制化考量

除了上述两种常用的调度器外,我们还可以根据实际需求进行更进一步的定制。例如:

  • Warmup阶段: 在训练初期使用一个较小的学习率进行预热,然后再使用余弦退火或多项式衰减。这可以帮助模型更稳定地学习到初始的特征。
  • 混合使用: 将不同的学习率调度策略组合起来使用。例如,先使用多项式衰减降低学习率,然后再使用余弦退火进行微调。
  • 自定义衰减函数: 可以根据自己的需求设计学习率衰减函数,例如基于损失函数的变化情况来动态调整学习率。

以下是加入Warmup阶段的多项式衰减调度器示例:

import torch
from torch.optim.lr_scheduler import LambdaLR

class WarmupPolynomialLRDecay(LambdaLR):
    """
    带有Warmup的多项式学习率衰减调度器。
    """
    def __init__(self, optimizer, warmup_steps, max_decay_steps, initial_learning_rate, end_learning_rate=0.0, power=1.0, last_epoch=-1):
        """
        Args:
            optimizer (Optimizer): Wrapped optimizer.
            warmup_steps (int): Warmup步数。
            max_decay_steps (int): 最大衰减步数。
            initial_learning_rate (float): 初始学习率。
            end_learning_rate (float): 最终学习率。
            power (float): 多项式衰减的指数。
            last_epoch (int): 上一个epoch的索引。
        """
        self.warmup_steps = warmup_steps
        self.max_decay_steps = max_decay_steps
        self.initial_learning_rate = initial_learning_rate
        self.end_learning_rate = end_learning_rate
        self.power = power
        super(WarmupPolynomialLRDecay, self).__init__(optimizer, self.lr_lambda, last_epoch)

    def lr_lambda(self, current_step):
        if current_step < self.warmup_steps:
            return current_step / self.warmup_steps
        elif current_step > self.max_decay_steps:
            return self.end_learning_rate / self.initial_learning_rate # 保持最终学习率
        else:
            return (self.initial_learning_rate - self.end_learning_rate) * 
                   ((1 - (current_step - self.warmup_steps) / (self.max_decay_steps - self.warmup_steps)) ** self.power) + 
                   self.end_learning_rate / self.initial_learning_rate

6. 如何选择合适的学习率调度器

选择合适的学习率调度器需要根据具体的模型、数据集和训练任务进行综合考虑。以下是一些建议:

  • 如果模型容易陷入局部最小值: 可以尝试使用余弦退火等周期性的学习率调整策略。
  • 如果需要平滑地降低学习率: 可以尝试使用多项式衰减。
  • 如果训练初期不稳定: 可以尝试使用Warmup阶段。
  • 进行实验: 最好的方法是进行大量的实验,比较不同学习率调度器的性能,并选择最适合自己的策略。

总而言之,理解学习率调度器的原理、实现和选择,是提升深度学习模型性能的关键一步。

灵活运用调度器,提升模型性能

我们深入探讨了余弦退火和多项式衰减这两种学习率调度器的理论基础和Python实现。通过理解它们的优缺点,并结合实际需求进行定制,可以有效地提升深度学习模型的性能。

更多IT精英技术系列讲座,到智猿学院

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注