Python JAX 自定义 VJP:实现新的自动微分规则
大家好,今天我们深入探讨 JAX 中自定义 Vector-Jacobian Product (VJP),这是实现新的自动微分规则的关键技术。JAX 强大的自动微分能力很大程度上依赖于对基本操作的 VJP 和 Jacobian-Vector Product (JVP) 的定义。虽然 JAX 已经提供了大量内置的 VJP 和 JVP,但有时候我们需要为自定义函数或操作定义自己的规则,以提高效率或处理 JAX 默认无法处理的情况。
1. 自动微分基础:VJP 和 JVP
在深入自定义 VJP 之前,我们先回顾一下自动微分的核心概念:VJP 和 JVP。 它们是两种不同的计算导数的方式。
-
JVP (Jacobian-Vector Product): 给定函数
f(x)和方向向量v,JVP 计算J @ v,其中J是f在x处的 Jacobian 矩阵。 可以理解为,JVP 计算了f(x)在方向v上的方向导数。 -
VJP (Vector-Jacobian Product): 给定函数
f(x)和向量v,VJP 计算v @ J,其中J是f在x处的 Jacobian 矩阵。 可以理解为,VJP 计算了f(x)的梯度与向量v的点积,或者反向传播过程中,从输出到输入的梯度。
JAX 提供了 jax.jvp 和 jax.vjp 函数来分别计算 JVP 和 VJP。 对于大多数情况,我们只需要定义 VJP 或 JVP 其中一个,另一个可以通过 JAX 自动推导得到。 然而,在某些情况下,直接定义 VJP 或 JVP 可能会更有效或更方便。
2. 为什么需要自定义 VJP?
JAX 已经提供了非常强大的自动微分功能,为什么还需要自定义 VJP 呢? 以下是一些主要原因:
- 性能优化: 对于某些函数,JAX 默认的自动微分可能效率不高。通过自定义 VJP,我们可以利用特定函数的数学特性,实现更高效的梯度计算。
- 处理不可微操作: 有些函数在某些点或区域是不可微的。 通过自定义 VJP,我们可以定义在这些点的梯度,使得自动微分可以顺利进行。这通常涉及到使用次梯度或广义梯度。
- 处理数值稳定性问题: 某些操作在数值上不稳定,导致梯度计算出现问题。 通过自定义 VJP,我们可以使用更稳定的公式来计算梯度,避免数值问题。
- 融合多个操作: 有时候,将多个操作融合到一个自定义操作中,并为其定义 VJP,可以减少中间变量的存储和计算,提高整体性能。
- 与外部库集成: 如果我们想将 JAX 与一些不支持自动微分的外部库集成,可以自定义 VJP 来桥接这些库。
3. 使用 jax.custom_vjp 定义 VJP
JAX 提供了 jax.custom_vjp 装饰器来定义自定义 VJP。 jax.custom_vjp 的使用方式如下:
import jax
import jax.numpy as jnp
@jax.custom_vjp
def my_function(x):
# 前向计算
return ...
def my_function_fwd(x):
# 前向模式计算,返回原始结果和用于反向传播的信息 (residuals)
return my_function(x), (x,) # 返回值必须是一个tuple
def my_function_bwd(residuals, grad_output):
# 反向模式计算,residuals 是从 fwd 传递过来的信息,grad_output 是输出的梯度
x, = residuals
# 计算关于输入的梯度,返回值必须是一个tuple
grad_x = ... # 根据链式法则计算梯度
return (grad_x,)
my_function.defvjp(my_function_fwd, my_function_bwd)
@jax.custom_vjp: 这个装饰器告诉 JAX 我们要为my_function定义自定义 VJP。my_function_fwd(x): 这个函数定义了前向计算,并且返回原始结果和一个包含反向传播所需信息的元组(residuals)。 residuals 可以包含任何需要在反向传播中使用的值,例如输入、中间变量或其他信息。 前向模式计算主要目的是为了缓存反向传播需要用到的中间变量。my_function_bwd(residuals, grad_output): 这个函数定义了反向计算。 它接收 residuals(从my_function_fwd传递过来)和输出的梯度grad_output,然后计算关于输入的梯度grad_x。grad_output相当于链式法则中的dLoss/dOutput,而我们需要计算的是dLoss/dInput。 务必遵循链式法则。my_function.defvjp(my_function_fwd, my_function_bwd): 这个语句将my_function_fwd和my_function_bwd注册为my_function的前向和反向计算函数。
4. 示例:自定义 relu 函数的 VJP
让我们通过一个具体的例子来说明如何自定义 VJP。 我们将为 ReLU (Rectified Linear Unit) 函数定义自定义 VJP。 ReLU 函数定义如下:
relu(x) = max(0, x)
ReLU 函数在 x=0 处不可微,但我们可以为其定义一个次梯度。 通常,我们定义 ReLU 在 x=0 处的导数为 0 或 1。
import jax
import jax.numpy as jnp
@jax.custom_vjp
def relu(x):
return jnp.maximum(0, x)
def relu_fwd(x):
# 返回 relu(x) 和 x 的值,以便在反向传播中使用
return relu(x), x
def relu_bwd(x, grad_output):
# 如果 x > 0,则梯度为 grad_output,否则为 0
x = x[0] # 从tuple取值
grad_x = grad_output * (x > 0)
return (grad_x,)
relu.defvjp(relu_fwd, relu_bwd)
# 测试
x = jnp.array([-1.0, 0.0, 1.0])
y = relu(x)
print(f"relu({x}) = {y}")
grad_fn = jax.grad(lambda x: jnp.sum(relu(x)))
grad_x = grad_fn(x)
print(f"grad(relu({x})) = {grad_x}")
# 使用 jax.jit 加速
jit_grad_fn = jax.jit(grad_fn)
grad_x_jit = jit_grad_fn(x)
print(f"jit_grad(relu({x})) = {grad_x_jit}")
在这个例子中,relu_fwd 函数返回 ReLU 的结果和输入 x。 relu_bwd 函数接收输入 x 和输出的梯度 grad_output,然后计算关于输入的梯度。 当 x > 0 时,梯度为 grad_output;当 x <= 0 时,梯度为 0。
5. 示例:处理数值不稳定性的 VJP
考虑 sigmoid 函数:
sigmoid(x) = 1 / (1 + exp(-x))
当 x 很大时,exp(-x) 可能会溢出导致数值不稳定。 为了解决这个问题,我们可以使用以下公式来计算 sigmoid 函数的导数:
sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x))
但是,当 x 很大时,sigmoid(x) 接近 1,1 - sigmoid(x) 接近 0,直接计算可能导致数值下溢。更好的方法是使用 log-sigmoid 技巧。 然而,为了演示自定义 VJP,我们假设无法直接使用 log-sigmoid,而是要通过自定义 VJP 来提高数值稳定性。
import jax
import jax.numpy as jnp
@jax.custom_vjp
def sigmoid(x):
return 1 / (1 + jnp.exp(-x))
def sigmoid_fwd(x):
y = sigmoid(x)
return y, y # 保存 sigmoid(x) 的值,以便在反向传播中使用
def sigmoid_bwd(y, grad_output):
# y is sigmoid(x)
grad_x = grad_output * y[0] * (1 - y[0])
return (grad_x,)
sigmoid.defvjp(sigmoid_fwd, sigmoid_bwd)
# 测试
x = jnp.array([-100.0, 0.0, 100.0])
y = sigmoid(x)
print(f"sigmoid({x}) = {y}")
grad_fn = jax.grad(lambda x: jnp.sum(sigmoid(x)))
grad_x = grad_fn(x)
print(f"grad(sigmoid({x})) = {grad_x}")
在这个例子中,sigmoid_fwd 函数计算 sigmoid 函数的值,并将其保存下来。 sigmoid_bwd 函数使用保存的 sigmoid 值来计算梯度,避免重复计算。 虽然这个例子没有完全解决数值稳定性问题,但它展示了如何通过自定义 VJP 来优化梯度计算。实际上,更有效的做法是使用 log-sigmoid 技巧,或者使用 JAX 提供的 jax.nn.log_sigmoid 函数。
6. 高阶导数和自定义 VJP
自定义 VJP 也可以用于计算高阶导数。 由于我们已经为自定义函数定义了 VJP,JAX 可以自动地对 VJP 进行微分,从而计算高阶导数。
import jax
import jax.numpy as jnp
# 使用前面定义的 relu 函数
# 计算二阶导数
hessian_fn = jax.grad(jax.grad(lambda x: jnp.sum(relu(x))))
x = jnp.array([-1.0, 0.0, 1.0])
hessian_x = hessian_fn(x)
print(f"hessian(relu({x})) = {hessian_x}")
在这个例子中,我们使用 jax.grad 两次来计算 ReLU 函数的二阶导数。 由于我们已经为 ReLU 函数定义了自定义 VJP,JAX 可以自动地计算其高阶导数。
7. 注意事项和最佳实践
- 确保 VJP 的正确性: 自定义 VJP 的正确性至关重要。 错误的 VJP 会导致错误的梯度计算,从而影响模型的训练和预测。 可以使用有限差分法或符号微分来验证 VJP 的正确性。
- 选择合适的 residuals: residuals 应该包含所有在反向传播中需要使用的信息。 选择合适的 residuals 可以提高 VJP 的效率和数值稳定性。
- 遵循链式法则: 在
bwd函数中,务必遵循链式法则来计算梯度。 确保梯度计算是正确的。 - 考虑性能: 自定义 VJP 的目的是为了提高性能。 在定义 VJP 时,要考虑性能因素,避免不必要的计算和内存分配。
- 利用 JAX 的特性: JAX 提供了许多有用的特性,例如
jax.jit和jax.vmap。 可以利用这些特性来进一步提高自定义 VJP 的性能。 - 单元测试: 编写单元测试来验证自定义 VJP 的正确性非常重要。
8. 高级技巧:使用 jax.pure_callback 与外部函数集成
jax.pure_callback 允许你将 Python 函数包装成 JAX 可以自动微分的形式。这在与不直接支持 JAX 的外部库集成时非常有用。你需要提供一个纯函数(即没有副作用的函数),并且需要显式地指定输入和输出的 jax.ShapeDtypeStruct。
例如,假设你有一个外部 Python 函数 external_function:
def external_function(x):
# 假设这是一个调用外部库的函数
return x * 2 # 简单的例子
你可以使用 jax.pure_callback 将其包装成 JAX 函数:
import jax
import jax.numpy as jnp
def external_function(x):
# 假设这是一个调用外部库的函数
return x * 2 # 简单的例子
def jax_external_function(x):
return jax.pure_callback(lambda x: external_function(x),
jax.ShapeDtypeStruct(x.shape, x.dtype),
x)
# 测试
x = jnp.array(5.0)
y = jax_external_function(x)
print(f"jax_external_function({x}) = {y}")
grad_fn = jax.grad(jax_external_function)
grad_x = grad_fn(x)
print(f"grad(jax_external_function({x})) = {grad_x}")
然后,你可以使用 jax.custom_vjp 为 jax_external_function 定义 VJP。 需要注意的是,jax.pure_callback 本身不支持自动微分,所以你必须显式地提供 VJP 定义。
9. 总结
今天我们学习了如何在 JAX 中自定义 VJP,这是实现新的自动微分规则的关键技术。我们讨论了自定义 VJP 的必要性、使用 jax.custom_vjp 的方法,并通过几个示例演示了如何为 ReLU 函数和 sigmoid 函数定义自定义 VJP。我们还讨论了如何使用自定义 VJP 来计算高阶导数。最后,我们强调了自定义 VJP 的一些注意事项和最佳实践。
10. 关键点的回顾:自定义 VJP 的要点
jax.custom_vjp是定义自定义 VJP 的核心工具。fwd函数返回原始结果和反向传播所需的信息(residuals)。bwd函数接收 residuals 和输出的梯度,然后计算关于输入的梯度。- 自定义 VJP 可以用于性能优化、处理不可微操作、处理数值稳定性问题等。
- 确保 VJP 的正确性至关重要。
更多IT精英技术系列讲座,到智猿学院