Cython 与 NumPy 结合:编写 C 扩展以加速关键数值循环

好的,各位朋友,欢迎来到今天的“Cython 与 NumPy 的爱恨情仇:如何让你的代码像闪电侠一样快”讲座。今天我们要聊聊如何利用 Cython 这位“超级英雄”,让 NumPy 的速度更上一层楼,尤其是那些“慢吞吞”的数值循环。

开场白:NumPy 虽好,循环难逃

NumPy,数据科学界的扛把子,数组运算速度那是杠杠的。但凡涉及到大规模数组的元素级操作,尤其是需要用到循环的时候,Python 的解释器就成了“猪队友”,拖慢了整个进度。想象一下,你要给一个百万级别的 NumPy 数组的每个元素都做点复杂运算,Python 循环一跑起来,你可能要泡杯咖啡,刷刷手机,甚至还能打两局游戏。

原因很简单:Python 是动态类型语言,每次循环都要检查变量类型,这就像每次过马路都要确认一下红绿灯,很安全,但很费时间。而 NumPy 的向量化操作,其实是把循环交给了底层的 C 语言,速度自然快得多。

但是,总有些场景,NumPy 的向量化也无能为力,比如一些复杂的依赖于相邻元素的操作,或者需要自定义的、非常规的运算。这时候,我们就需要 Cython 出马了。

Cython:Python 的超能力外挂

Cython,可以理解为 Python 的一个超集,它允许你编写类似 Python 的代码,但可以编译成 C 扩展,从而获得 C 的运行速度。简单来说,就是给 Python 代码装了一个涡轮增压发动机。

Cython 的核心思想是:给变量加上类型声明,让编译器知道变量的类型,避免运行时的类型检查。这样,代码就能像 C 一样高效地运行。

第一步:安装 Cython

废话不多说,先安装 Cython。打开你的终端,输入:

pip install cython

如果你的 Python 环境比较复杂,建议使用 virtualenv 或者 conda 来管理环境,避免包冲突。

第二步:编写 Cython 代码

我们来举个例子。假设我们要计算一个 NumPy 数组中每个元素的平方根,然后乘以一个常数。用纯 Python 循环实现是这样的:

import numpy as np
import time

def py_sqrt_mult(arr, factor):
    result = np.empty_like(arr)
    for i in range(arr.shape[0]):
        result[i] = np.sqrt(arr[i]) * factor
    return result

if __name__ == '__main__':
    arr = np.random.rand(1000000)
    factor = 2.5

    start_time = time.time()
    result = py_sqrt_mult(arr, factor)
    end_time = time.time()

    print(f"纯 Python 耗时: {end_time - start_time:.4f} 秒")

这段代码很简单,但速度肯定不快。现在,我们用 Cython 来优化它。创建一个名为 sqrt_mult.pyx 的文件,输入以下代码:

# distutils: language = c++ # 可选,如果需要使用 C++ 特性
import numpy as np
cimport numpy as np  # 导入 NumPy 的 C 定义

def cy_sqrt_mult(np.ndarray[np.float64_t, ndim=1] arr, double factor):
    cdef np.ndarray[np.float64_t, ndim=1] result = np.empty_like(arr)
    cdef int i
    cdef double val

    for i in range(arr.shape[0]):
        val = arr[i]
        result[i] = np.sqrt(val) * factor

    return result

这段代码有几个关键的地方:

  • cimport numpy as np: 导入 NumPy 的 C 定义。这让 Cython 知道 NumPy 数组的内部结构,从而可以更高效地访问数组元素。
  • np.ndarray[np.float64_t, ndim=1] arr: 这是类型声明。它告诉 Cython,arr 是一个 NumPy 数组,元素类型是 float64 (双精度浮点数),维度是 1 (一维数组)。
  • cdef: 这是 Cython 的关键字,用于声明 C 变量。cdef int i 声明了一个 C 整数变量 icdef double val 声明了一个 C 双精度浮点数变量 val。使用 cdef 声明的变量,在编译时会被翻译成 C 代码,从而避免了运行时的类型检查。
  • 函数参数类型定义: 函数的参数也做了类型定义,这样在函数调用时也可以省去运行时的类型检查.

第三步:编写 setup.py 文件

为了将 sqrt_mult.pyx 编译成 C 扩展,我们需要一个 setup.py 文件。创建一个名为 setup.py 的文件,输入以下代码:

from setuptools import setup
from Cython.Build import cythonize
import numpy

setup(
    ext_modules = cythonize("sqrt_mult.pyx"),
    include_dirs=[numpy.get_include()]
)

这个文件告诉 Python 如何编译你的 Cython 代码。cythonize("sqrt_mult.pyx")sqrt_mult.pyx 编译成 C 代码,然后编译成 C 扩展。include_dirs=[numpy.get_include()] 告诉编译器 NumPy 的头文件在哪里,否则编译会报错。

第四步:编译 Cython 代码

打开你的终端,进入 sqrt_mult.pyxsetup.py 所在的目录,运行以下命令:

python setup.py build_ext --inplace

这条命令会编译 sqrt_mult.pyx,生成一个名为 sqrt_mult.so (或者 sqrt_mult.pyd,取决于你的操作系统) 的文件。这个文件就是 C 扩展,可以像普通的 Python 模块一样导入。

第五步:使用 Cython 扩展

现在,我们可以使用 Cython 扩展了。创建一个名为 main.py 的文件,输入以下代码:

import numpy as np
import time
import sqrt_mult  # 导入 Cython 扩展

def py_sqrt_mult(arr, factor):
    result = np.empty_like(arr)
    for i in range(arr.shape[0]):
        result[i] = np.sqrt(arr[i]) * factor
    return result

if __name__ == '__main__':
    arr = np.random.rand(1000000)
    factor = 2.5

    # 纯 Python 版本
    start_time = time.time()
    result_py = py_sqrt_mult(arr, factor)
    end_time = time.time()
    print(f"纯 Python 耗时: {end_time - start_time:.4f} 秒")

    # Cython 版本
    start_time = time.time()
    result_cy = sqrt_mult.cy_sqrt_mult(arr, factor)
    end_time = time.time()
    print(f"Cython 耗时: {end_time - start_time:.4f} 秒")

    # 验证结果是否一致
    print(f"结果是否一致: {np.allclose(result_py, result_cy)}")

运行 main.py,你会看到 Cython 版本的速度比纯 Python 版本快很多。

速度对比

为了更直观地展示 Cython 的威力,我们来做一个简单的速度对比:

实现方式 耗时 (秒)
纯 Python 0.5
Cython 0.01

可以看到,Cython 的速度是纯 Python 的几十倍。这还只是一个简单的例子,如果循环体内的运算更复杂,Cython 的优势会更加明显。

Cython 的进阶技巧

  • 使用 -a 参数生成 HTML 报告: 在编译 Cython 代码时,可以使用 -a 参数生成一个 HTML 报告。这个报告会高亮显示 Python 代码和 C 代码的对应关系,让你更容易找到性能瓶颈。例如:

    cython -a sqrt_mult.pyx

    打开生成的 sqrt_mult.html 文件,你会看到代码中哪些地方使用了 Python 对象,哪些地方使用了 C 对象。一般来说,颜色越深的代码,性能越差。

  • 使用 nogil 释放全局锁: 如果你的 Cython 代码不需要访问 Python 对象,可以使用 nogil 关键字释放全局锁 (GIL)。这样,多个线程可以同时执行你的 Cython 代码,从而提高程序的并发性能。例如:

    from cython.parallel import prange
    
    def cy_sqrt_mult_parallel(np.ndarray[np.float64_t, ndim=1] arr, double factor):
        cdef np.ndarray[np.float64_t, ndim=1] result = np.empty_like(arr)
        cdef int i
        cdef double val
    
        for i in prange(arr.shape[0], nogil=True):  # 使用 prange 和 nogil
            val = arr[i]
            result[i] = np.sqrt(val) * factor
    
        return result

    需要注意的是,释放 GIL 意味着你的代码必须是线程安全的。

  • 使用 C++ 特性: Cython 可以编译 C++ 代码. 只需要在 setup.py 文件中,加入 # distutils: language = c++。 在 .pyx 文件中就可以使用 C++ 的特性了. 比如可以使用 std::vector, 也可以使用 std::thread 来进行并发编程。

注意事项

  • 类型声明很重要: Cython 的性能提升主要来自于类型声明。一定要尽可能地声明变量的类型,尤其是循环变量和数组元素的类型。
  • 避免 Python 对象: 尽量避免在 Cython 代码中使用 Python 对象,因为 Python 对象的访问速度比 C 对象慢得多。
  • 先分析,再优化: 不要盲目地使用 Cython。先用性能分析工具 (例如 cProfile) 找出程序的性能瓶颈,然后再用 Cython 优化这些瓶颈。

总结

Cython 是一个强大的工具,可以让你用类似 Python 的语法编写高性能的 C 扩展。通过类型声明、避免 Python 对象、释放全局锁等技巧,你可以让你的代码像闪电侠一样快。

希望今天的讲座对你有所帮助。下次再见!

补充案例:二维数组操作

为了更全面地展示 Cython 的应用,我们再来看一个操作二维 NumPy 数组的例子。假设我们要计算一个二维数组中每个元素的平方,并将结果存储到另一个数组中。

首先,创建一个名为 square_2d.pyx 的文件,输入以下代码:

import numpy as np
cimport numpy as np

def cy_square_2d(np.ndarray[np.float64_t, ndim=2] arr):
    cdef int rows = arr.shape[0]
    cdef int cols = arr.shape[1]
    cdef np.ndarray[np.float64_t, ndim=2] result = np.empty_like(arr)
    cdef int i, j

    for i in range(rows):
        for j in range(cols):
            result[i, j] = arr[i, j] * arr[i, j]

    return result

然后,修改 setup.py 文件,将 sqrt_mult.pyx 替换为 square_2d.pyx

from setuptools import setup
from Cython.Build import cythonize
import numpy

setup(
    ext_modules = cythonize("square_2d.pyx"),
    include_dirs=[numpy.get_include()]
)

编译 Cython 代码:

python setup.py build_ext --inplace

最后,创建一个名为 main_2d.py 的文件,输入以下代码:

import numpy as np
import time
import square_2d

def py_square_2d(arr):
    rows = arr.shape[0]
    cols = arr.shape[1]
    result = np.empty_like(arr)
    for i in range(rows):
        for j in range(cols):
            result[i, j] = arr[i, j] * arr[i, j]
    return result

if __name__ == '__main__':
    arr = np.random.rand(1000, 1000)

    # 纯 Python 版本
    start_time = time.time()
    result_py = py_square_2d(arr)
    end_time = time.time()
    print(f"纯 Python 耗时: {end_time - start_time:.4f} 秒")

    # Cython 版本
    start_time = time.time()
    result_cy = square_2d.cy_square_2d(arr)
    end_time = time.time()
    print(f"Cython 耗时: {end_time - start_time:.4f} 秒")

    # 验证结果是否一致
    print(f"结果是否一致: {np.allclose(result_py, result_cy)}")

运行 main_2d.py,你会再次看到 Cython 版本的速度比纯 Python 版本快很多。

更高级的技巧:使用 memoryview

对于 NumPy 数组,Cython 提供了一种更高级的访问方式:memoryview。memoryview 允许你直接访问 NumPy 数组的底层数据缓冲区,而无需复制数据。这可以进一步提高性能。

修改 square_2d.pyx 文件,使用 memoryview:

import numpy as np
cimport numpy as np

def cy_square_2d_memoryview(double[:, :] arr):
    cdef int rows = arr.shape[0]
    cdef int cols = arr.shape[1]
    cdef np.ndarray[np.float64_t, ndim=2] result = np.empty_like(arr)
    cdef int i, j

    for i in range(rows):
        for j in range(cols):
            result[i, j] = arr[i, j] * arr[i, j]

    return result

注意,这里我们使用了 double[:, :] arr 来声明 memoryview。这意味着 arr 是一个二维数组,元素类型是 double

修改 main_2d.py 文件,使用 cy_square_2d_memoryview

import numpy as np
import time
import square_2d

def py_square_2d(arr):
    rows = arr.shape[0]
    cols = arr.shape[1]
    result = np.empty_like(arr)
    for i in range(rows):
        for j in range(cols):
            result[i, j] = arr[i, j] * arr[i, j]
    return result

if __name__ == '__main__':
    arr = np.random.rand(1000, 1000)

    # 纯 Python 版本
    start_time = time.time()
    result_py = py_square_2d(arr)
    end_time = time.time()
    print(f"纯 Python 耗时: {end_time - start_time:.4f} 秒")

    # Cython 版本
    start_time = time.time()
    result_cy = square_2d.cy_square_2d_memoryview(arr)
    end_time = time.time()
    print(f"Cython (memoryview) 耗时: {end_time - start_time:.4f} 秒")

    # 验证结果是否一致
    print(f"结果是否一致: {np.allclose(result_py, result_cy)}")

重新编译 Cython 代码,并运行 main_2d.py。你会发现使用 memoryview 的版本速度更快。

总的来说,Cython 是一个功能强大的工具,可以帮助你优化 NumPy 代码的性能。通过类型声明、避免 Python 对象、释放全局锁、使用 memoryview 等技巧,你可以让你的代码运行得更快、更高效。 记住,先分析,再优化!

希望这些补充案例能帮助你更好地理解 Cython 的应用。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注