C++23 多维数组切片(std::mdspan):在 C++ 高性能计算中通过多维视图提升矩阵运算的代码表达力

C++23 多维数组切片(std::mdspan):在 C++ 高性能计算中通过多维视图提升矩阵运算的代码表达力

在高性能计算(HPC)领域,矩阵和多维数组的运算是核心。从数值模拟到机器学习,从图像处理到科学计算,高效、清晰地处理这些数据结构至关重要。长期以来,C++ 开发者在处理多维数组时面临着表达力、性能和安全性的权衡。C++23 引入的 std::mdspan 为这一挑战带来了革命性的解决方案,它提供了一个非拥有的多维数组视图,显著提升了代码的表达力,同时保持了零开销的性能特性,并为优化提供了更多可能。

1. 传统多维数组处理的困境与挑战

在 C++ 中,处理多维数组,特别是大型矩阵或张量时,我们通常会遇到以下几种传统方法及其固有的局限性:

1.1. 原始指针与手动内存管理

最底层的方法是使用原始指针和 new/delete 来模拟多维数组。例如,一个动态的 2D 矩阵通常会通过一个指向指针的指针(T**)或一个大的一维数组来表示。

// 方法一:指向指针的指针 (T**)
template<typename T>
T** allocate_2d_array_ptr_ptr(size_t rows, size_t cols) {
    T** arr = new T*[rows];
    for (size_t i = 0; i < rows; ++i) {
        arr[i] = new T[cols];
    }
    return arr;
}

template<typename T>
void deallocate_2d_array_ptr_ptr(T** arr, size_t rows) {
    for (size_t i = 0; i < rows; ++i) {
        delete[] arr[i];
    }
    delete[] arr;
}

// 访问元素: arr[row][col]
// 缺点: 内存不连续,缓存效率低,多次内存分配和释放开销大。

// 方法二:单块连续内存 (T*)
template<typename T>
T* allocate_2d_array_contiguous(size_t rows, size_t cols) {
    return new T[rows * cols];
}

template<typename T>
void deallocate_2d_array_contiguous(T* arr) {
    delete[] arr;
}

// 访问元素: arr[row * cols + col] (行主序)
// 优点: 内存连续,缓存效率高。
// 缺点: 访问语法不直观,容易出错,缺乏维度信息和边界检查。
//       无法方便地传递子区域(切片)而不进行复制。

这种方法虽然提供了最大的灵活性和底层控制,但存在诸多问题:

  • 内存管理复杂且易错:手动 newdelete 很容易导致内存泄漏或双重释放。
  • 缺乏类型安全:原始指针无法携带维度信息,编译器无法帮助检查访问越界或维度不匹配的问题。
  • 代码表达力差:对于连续内存的方案,访问元素 arr[row * cols + col] 不直观,且难以在函数签名中表达多维数组的形状。
  • 切片困难:要获取子矩阵或切片,通常需要手动计算偏移量和新的维度,或者复制数据,这会带来性能开销。

1.2. std::vector<std::vector<T>>

这是 C++ 中更常见的、看似直观的多维数组表示方式。

std::vector<std::vector<double>> matrix(rows, std::vector<double>(cols));
// 访问元素: matrix[row][col]

这种方法虽然提供了更好的内存管理(由 std::vector 自动处理),但它通常不适用于 HPC:

  • 内存不连续std::vector<std::vector<T>> 的每一行 std::vector<T> 都在堆上独立分配,导致整个矩阵的内存是不连续的。这严重影响了 CPU 缓存的效率,对于大型矩阵运算会带来显著的性能下降。
  • 性能开销:多层 std::vector 的构造和析构开销较大,且每次访问 matrix[row] 都会涉及一次间接寻址。
  • 切片依然困难:虽然可以返回 std::vector<T>& 作为行切片,但无法直接表示列切片或子矩阵,除非进行数据复制。

1.3. 自定义封装类

为了解决上述问题,许多 HPC 库和项目会实现自己的多维数组封装类。这些类通常会:

  • 在内部使用一个 std::vector<T> 或原始指针管理一块连续内存。
  • 提供 operator()at() 方法来模拟多维索引访问,例如 matrix(row, col)
  • 封装维度信息和布局(行主序/列主序)。
  • 可能提供一些基本的切片功能,但往往需要自定义的辅助类或函数。
// 示例:一个简化的自定义矩阵类
template<typename T>
class MyMatrix {
public:
    MyMatrix(size_t rows, size_t cols) :
        _rows(rows), _cols(cols), _data(rows * cols) {}

    T& operator()(size_t r, size_t c) {
        return _data[r * _cols + c]; // 行主序
    }
    const T& operator()(size_t r, size_t c) const {
        return _data[r * _cols + c];
    }

    size_t rows() const { return _rows; }
    size_t cols() const { return _cols; }

private:
    size_t _rows, _cols;
    std::vector<T> _data;
};

// 优点: 内存连续,访问语法直观,封装了维度信息。
// 缺点: 仍需大量样板代码,切片功能通常需要为每个自定义类单独实现,
//       无法形成通用的、标准化的多维视图。

这种方法虽然改善了性能和表达力,但缺乏标准化。每个库或项目都可能有一套自己的实现,导致互操作性差,且每次需要切片或视图时,都可能需要复制数据或者创建新的轻量级视图对象,而这些视图对象本身也需要自定义实现。

1.4. 缺乏标准化的视图机制

所有上述传统方法的核心问题在于,C++ 标准库中一直缺乏一个通用的、零开销的、非拥有的多维数组视图。这种视图应该能够:

  • 引用任意一块连续内存作为多维数组。
  • 携带维度信息和布局信息。
  • 提供直观的多维索引访问。
  • 支持高效的切片操作,生成新的视图,而无需复制数据。
  • 与现有 C++ 容器(如 std::vector)和原始数组无缝协作。

这种缺失使得 C++ 在处理矩阵运算时,要么牺牲表达力(原始指针),要么牺牲性能(std::vector<std::vector>),要么陷入无尽的自定义封装类和样板代码的泥沼。这与 Fortran 等语言在处理多维数组方面的简洁和高效形成了鲜明对比。

2. std::mdspan 的诞生:多维视图的革命

C++23 引入的 std::mdspan(多维 span)正是为了填补这一空白。它是一个非拥有的(non-owning)多维数组视图,这意味着它不管理底层数据的所有权,只提供一个访问现有内存的方式。std::mdspan 的设计目标是零开销抽象,即在运行时不引入额外的性能负担,所有维度和布局信息都尽可能在编译时确定。

2.1. std::mdspan 是什么?

std::mdspan 可以被理解为 C++17 std::span 的多维版本。std::span 提供了一个连续一维内存区域的视图,而 std::mdspan 则将这个概念扩展到了任意维度。它通过封装原始数据指针、维度信息和布局映射,使得我们可以将一维的连续内存块“看作”一个多维数组。

std::mdspan 的完整类型签名是:

template<class ElementType,
         class Extents,
         class LayoutPolicy = std::layout_right,
         class AccessorPolicy = std::default_accessor<ElementType>>
class mdspan;

让我们逐一剖析这些模板参数:

  • ElementType:视图中元素的类型,例如 double, float, int 等。
  • Extents:一个 std::extents 类型,用于描述多维数组的维度(例如 3×4 矩阵,或 2x3x5 张量)。它可以在编译时或运行时确定维度。
  • LayoutPolicy:布局策略,用于将多维索引(如 (row, col))映射到一维的内存偏移量。C++23 提供了三种标准布局:std::layout_right (行主序), std::layout_left (列主序), std::layout_stride (自定义步长)。
  • AccessorPolicy:访问器策略,定义了如何访问底层数据。默认是 std::default_accessor,它直接使用指针解引用。可以自定义访问器来处理特殊内存(如原子操作、GPU 内存等),但通常情况下不需要修改。

2.2. std::mdspan 的核心优势

  1. 零开销抽象std::mdspan 自身不持有数据,只持有指针、维度和布局信息。这意味着它的创建、复制和销毁几乎没有运行时开销。所有的索引计算和边界检查(如果启用)都尽可能在编译时完成。
  2. 内存连续性mdspan 视图总是基于一块连续的内存。这保证了最佳的缓存局部性,对于 HPC 而言至关重要。
  3. 强大的表达力:通过 operator() 语法,可以像访问原生多维数组一样直观地访问元素,例如 matrix(i, j)
  4. 标准化的切片与子视图std::submdspan 函数是 mdspan 的核心功能之一,它允许我们高效地创建现有 mdspan 的子视图,无需数据复制。这使得实现分块算法、传递子区域给函数变得异常简单。
  5. 类型安全与维度检查ExtentsLayoutPolicy 在编译时提供了丰富的类型信息。如果维度不匹配或索引越界(在调试模式下),可以提供编译时错误或运行时断言。
  6. 与现有代码集成mdspan 可以轻松地从原始指针、std::vector<T>std::array 等现有数据源构造。
  7. 灵活的布局策略:支持行主序(std::layout_right,C/C++ 风格)、列主序(std::layout_left,Fortran 风格)以及自定义步长(std::layout_stride),这对于与不同语言或库进行互操作性、以及针对特定算法优化缓存访问模式非常有用。

3. std::mdspan 的基本使用

3.1. std::extents:定义维度

std::extents 定义了多维数组的每个维度的长度。它支持编译时已知维度和运行时已知维度。

// 编译时已知维度 (静态维度)
// 一个 3x4 的矩阵
std::extents<size_t, 3, 4> static_extents; // 编译器知道维度是 3 和 4

// 运行时已知维度 (动态维度)
// 一个 NxM 的矩阵
size_t rows = 5;
size_t cols = 10;
std::extents<size_t, std::dynamic_extent, std::dynamic_extent> dynamic_extents(rows, cols);
// std::dynamic_extent 是一个占位符,表示该维度在运行时确定

// 混合维度
// 一个 2xNx3 的张量
size_t N = 7;
std::extents<size_t, 2, std::dynamic_extent, 3> mixed_extents(N); // 只需要提供动态维度的值

使用 std::dynamic_extent 可以极大地增加 mdspan 的灵活性。如果所有维度在编译时都已知,那么编译器可以进行更多的优化,例如消除一些运行时检查。

3.2. 布局策略 (std::layout_right, std::layout_left, std::layout_stride)

布局策略决定了多维索引如何映射到一维内存中的偏移量。

策略类型 描述 内存访问顺序(对于 2D 矩阵) 典型应用
std::layout_right 行主序 (Row-Major):最右侧的维度变化最快。C/C++ 默认的二维数组布局。 arr[i][j]j 变化最快。 (0,0), (0,1), ..., (0,N-1), (1,0), ..., (M-1,N-1) C/C++ 程序,图像处理
std::layout_left 列主序 (Column-Major):最左侧的维度变化最快。Fortran 和 BLAS/LAPACK 库常用的布局。 arr[i][j]i 变化最快。 (0,0), (1,0), ..., (M-1,0), (0,1), ..., (M-1,N-1) 与 Fortran 库互操作,BLAS/LAPACK 风格的矩阵运算,某些线性代数算法
std::layout_stride 自定义步长:允许为每个维度指定一个步长,提供最大的灵活性。 完全由用户定义,可以实现非连续、跳跃访问,或者转置视图等复杂布局。 高级优化,稀疏矩阵,内存不规则的数据结构,转置视图

示例:创建 mdspan

#include <vector>
#include <mdspan> // C++23 header
#include <iostream>
#include <numeric> // For std::iota

void print_mdspan(const auto& mds) {
    std::cout << "MDSPAN (" << mds.extents().extent(0) << "x" << mds.extents().extent(1) << "):n";
    for (size_t i = 0; i < mds.extents().extent(0); ++i) {
        for (size_t j = 0; j < mds.extents().extent(1); ++j) {
            std::cout << mds(i, j) << "t";
        }
        std::cout << "n";
    }
    std::cout << "n";
}

int main() {
    // 1. 从 std::vector 创建一个 3x4 的 mdspan (行主序)
    std::vector<int> data_vec(3 * 4);
    std::iota(data_vec.begin(), data_vec.end(), 0); // 填充 0, 1, ..., 11

    // 运行时维度
    std::mdspan<int, std::extents<size_t, std::dynamic_extent, std::dynamic_extent>>
        matrix_right(data_vec.data(), 3, 4); // 3 rows, 4 columns
    print_mdspan(matrix_right);
    /* Output:
    MDSPAN (3x4):
    0       1       2       3
    4       5       6       7
    8       9       10      11
    */

    // 2. 从原始数组创建一个 3x4 的 mdspan (列主序)
    int raw_data[3 * 4];
    std::iota(raw_data, raw_data + (3 * 4), 0);

    // 静态维度,指定列主序
    std::mdspan<int, std::extents<size_t, 3, 4>, std::layout_left>
        matrix_left(raw_data); // 维度已在 extents 中指定,无需额外参数

    // 访问元素 (i,j) 在列主序中是不同的
    // matrix_left(0,0) -> raw_data[0]
    // matrix_left(1,0) -> raw_data[1]
    // matrix_left(2,0) -> raw_data[2]
    // matrix_left(0,1) -> raw_data[3]
    // ...
    print_mdspan(matrix_left); // 注意输出顺序会不同
    /* Output:
    MDSPAN (3x4):
    0       3       6       9
    1       4       7       10
    2       5       8       11
    */

    // 3. 使用 std::layout_stride 创建一个转置视图
    // 假设我们有一个行主序的 3x4 矩阵,我们想把它看作一个 4x3 的矩阵 (转置)
    // 原始数据: matrix_right 的数据
    // 原始行步长是 4 * sizeof(int), 列步长是 1 * sizeof(int)
    // 转置后:新的行是原来的列,新的列是原来的行
    // 新的 (r, c) 对应原始的 (c, r)
    // 新的维度: 4行,3列
    // 新的步长:
    //   对于新的行索引 r (对应原始的列索引 c): 步长是原始的列步长 (1)
    //   对于新的列索引 c (对应原始的行索引 r): 步长是原始的行步长 (4)
    std::array<size_t, 2> strides = {1, 4}; // {new_row_stride, new_col_stride}
    std::mdspan<int, std::extents<size_t, 4, 3>, std::layout_stride>
        matrix_transposed(data_vec.data(), std::layout_stride::mapping<std::extents<size_t, 4, 3>>(
                                                std::extents<size_t, 4, 3>(), strides));
    print_mdspan(matrix_transposed);
    /* Output:
    MDSPAN (4x3):
    0       4       8
    1       5       9
    2       6       10
    3       7       11
    */
    // 验证:matrix_transposed(0,1) 应该是原始 matrix_right(1,0) = 4
    // matrix_transposed(0,1) 实际上是原始的 data_vec[1 * 4 + 0] = 4
    // matrix_transposed(1,0) 应该是原始 matrix_right(0,1) = 1
    // matrix_transposed(1,0) 实际上是原始的 data_vec[0 * 4 + 1] = 1
    // 这里的 stride 定义的是对应于原始数据的一维索引的步长
    // 对于转置后的 (new_row, new_col)
    // 原始索引 = new_col * original_cols + new_row
    // 那么,layout_stride::mapping 的 strides 应该是 {original_cols, 1}
    // 对于 4x3 的转置矩阵,新的 extents 是 {4, 3}
    // mapping(extents, {stride_dim0, stride_dim1})
    // 步长参数是 `std::array<size_t, Rank>`
    // 对于一个 M x N 的原始矩阵 (行主序),其步长为 {N, 1}
    // 转置后是 N x M 矩阵 (新的 extents 是 N, M)
    // 它的第一个维度 (原列) 的步长是 1 (原列索引每增加1,原始数据指针移动1)
    // 它的第二个维度 (原行) 的步长是 N (原行索引每增加1,原始数据指针移动N)
    // 所以,对于 4x3 的转置矩阵,layout_stride 的步长应该是 {1, 4}
    // matrix_transposed(r, c) => data_vec[r * stride[0] + c * stride[1]]
    //                       => data_vec[r * 1 + c * 4]
    // 这与原始 data_vec[c * original_cols + r] 是等价的,因为 original_cols = 4
}

3.3. 访问器策略 (std::default_accessor)

std::default_accessor<ElementType> 是默认的访问器,它简单地通过指针解引用来访问元素。

template<class ElementType>
struct default_accessor {
  using element_type = ElementType;
  using reference = ElementType&;
  using pointer = ElementType*;

  constexpr default_accessor() noexcept = default;

  template<class OtherElementType>
  constexpr default_accessor(default_accessor<OtherElementType>) noexcept {}

  constexpr reference access(pointer p, size_t i) const noexcept {
    return p[i];
  }
};

你可以自定义访问器来实现更复杂的行为,例如:

  • 对原子类型进行线程安全访问。
  • 将数据存储在特殊的内存区域(如 GPU 内存、共享内存)。
  • 在访问时进行加密/解密。

但在绝大多数 HPC 场景中,std::default_accessor 已经足够。

4. std::mdspan 的核心:多维数组切片 (std::submdspan)

std::mdspan 最强大的功能之一就是其对切片(slicing)的天然支持,通过 std::submdspan 函数实现。std::submdspan 允许从一个现有的 mdspan 创建一个新的 mdspan 视图,指向原始数据的一个子区域,而无需复制数据。这对于实现分块算法(如分块矩阵乘法)和将数据子集传递给函数非常有用。

std::submdspan 的基本用法是:

template<class ElementType, class Extents, class LayoutPolicy, class AccessorPolicy,
         class... SliceSpecs>
constexpr auto submdspan(const mdspan<ElementType, Extents, LayoutPolicy, AccessorPolicy>& src,
                         SliceSpecs... slices);

SliceSpecs 是一个可变参数包,每个参数指定了对应维度如何进行切片。可以使用的切片参数类型包括:

  • std::full_extent_t:表示选取整个维度。
  • std::pair<IndexType, IndexType>:表示选取一个范围 [start, end)
  • IndexType:表示选取一个单个索引(这会降低视图的维度)。

下面通过具体例子来展示 std::submdspan 的强大之处。

首先,我们定义一个示例矩阵:

#include <vector>
#include <mdspan>
#include <iostream>
#include <numeric> // For std::iota
#include <array>   // For std::array as stride in layout_stride

// 辅助函数:打印 mdspan
template<typename MDS>
void print_mdspan_info(const std::string& name, const MDS& mds) {
    std::cout << "--- " << name << " ---n";
    std::cout << "  Rank: " << mds.rank() << "n";
    std::cout << "  Extents: (";
    for (size_t i = 0; i < mds.rank(); ++i) {
        std::cout << mds.extent(i) << (i == mds.rank() - 1 ? "" : ", ");
    }
    std::cout << ")n";
    std::cout << "  First element: " << mds(mds.extents().index_type(), mds.extents().index_type(), mds.extents().index_type()) << "n"; // Example for 3D
    std::cout << "  Elements:n";
    if (mds.rank() == 2) {
        for (size_t i = 0; i < mds.extent(0); ++i) {
            for (size_t j = 0; j < mds.extent(1); ++j) {
                std::cout << mds(i, j) << "t";
            }
            std::cout << "n";
        }
    } else if (mds.rank() == 1) {
        for (size_t i = 0; i < mds.extent(0); ++i) {
            std::cout << mds(i) << "t";
        }
        std::cout << "n";
    } else {
        std::cout << "  (Too complex to print generically)n";
    }
    std::cout << "n";
}

int main() {
    std::vector<int> data(4 * 5); // 4 rows, 5 columns
    std::iota(data.begin(), data.end(), 1); // 填充 1, 2, ..., 20

    // 原始 4x5 矩阵
    std::mdspan<int, std::extents<size_t, 4, 5>> matrix(data.data());
    print_mdspan_info("Original Matrix (4x5)", matrix);
    /* Output:
    --- Original Matrix (4x5) ---
      Rank: 2
      Extents: (4, 5)
      First element: 1
      Elements:
      1       2       3       4       5
      6       7       8       9       10
      11      12      13      14      15
      16      17      18      19      20
    */

4.1. 行切片 (Row Slicing)

获取矩阵的某一行。这将把一个 2D 视图切片成一个 1D 视图。

    // 获取第 1 行 (索引为 1)
    auto row1_view = std::submdspan(matrix, 1, std::full_extent); // 1 表示行索引,std::full_extent 表示列全部
    print_mdspan_info("Row 1 View (1D)", row1_view);
    /* Output:
    --- Row 1 View (1D) ---
      Rank: 1
      Extents: (5)
      First element: 6
      Elements:
      6       7       8       9       10
    */

    // 获取第 0 行到第 2 行 (不包含 2)
    auto rows_0_to_1_view = std::submdspan(matrix, std::pair{0, 2}, std::full_extent);
    print_mdspan_info("Rows 0-1 View (2x5)", rows_0_to_1_view);
    /* Output:
    --- Rows 0-1 View (2x5) ---
      Rank: 2
      Extents: (2, 5)
      First element: 1
      Elements:
      1       2       3       4       5
      6       7       8       9       10
    */

4.2. 列切片 (Column Slicing)

获取矩阵的某一列。

    // 获取第 2 列 (索引为 2)
    auto col2_view = std::submdspan(matrix, std::full_extent, 2); // std::full_extent 表示行全部,2 表示列索引
    print_mdspan_info("Column 2 View (1D)", col2_view);
    /* Output:
    --- Column 2 View (1D) ---
      Rank: 1
      Extents: (4)
      First element: 3
      Elements:
      3       8       13      18
    */

    // 获取第 1 列到第 3 列 (不包含 3)
    auto cols_1_to_2_view = std::submdspan(matrix, std::full_extent, std::pair{1, 3});
    print_mdspan_info("Cols 1-2 View (4x2)", cols_1_to_2_view);
    /* Output:
    --- Cols 1-2 View (4x2) ---
      Rank: 2
      Extents: (4, 2)
      First element: 2
      Elements:
      2       3
      7       8
      12      13
      17      18
    */

4.3. 子矩阵/块切片 (Sub-matrix/Block Slicing)

获取矩阵的一个矩形子区域。

    // 获取从 (1, 1) 开始,大小为 2x3 的子矩阵
    auto sub_matrix_view = std::submdspan(matrix, std::pair{1, 3}, std::pair{1, 4}); // 行 1-2, 列 1-3
    print_mdspan_info("Sub-Matrix View (2x3)", sub_matrix_view);
    /* Output:
    --- Sub-Matrix View (2x3) ---
      Rank: 2
      Extents: (2, 3)
      First element: 7
      Elements:
      7       8       9
      12      13      14
    */

4.4. 任意维度切片

std::mdspan 不仅限于 2D。假设我们有一个 3D 张量。

    std::vector<int> data_3d(2 * 3 * 4); // 2 pages, 3 rows, 4 columns
    std::iota(data_3d.begin(), data_3d.end(), 100);

    std::mdspan<int, std::extents<size_t, 2, 3, 4>> tensor(data_3d.data());
    // print_mdspan_info("Original Tensor (2x3x4)", tensor); // This print helper is for 2D/1D only

    // 获取第二页 (索引为 1) 的所有数据,这将得到一个 2D 视图
    auto page1_view = std::submdspan(tensor, 1, std::full_extent, std::full_extent);
    std::cout << "--- Page 1 View (3x4) ---n";
    std::cout << "  Rank: " << page1_view.rank() << "n";
    std::cout << "  Extents: (" << page1_view.extent(0) << ", " << page1_view.extent(1) << ")n";
    std::cout << "  Elements:n";
    for (size_t i = 0; i < page1_view.extent(0); ++i) {
        for (size_t j = 0; j < page1_view.extent(1); ++j) {
            std::cout << page1_view(i, j) << "t";
        }
        std::cout << "n";
    }
    std::cout << "n";
    /* Output:
    --- Page 1 View (3x4) ---
      Rank: 2
      Extents: (3, 4)
      Elements:
      112     113     114     115
      116     117     118     119
      120     121     122     123
    */

    // 获取所有页的第 0 行第 1 列,这将得到一个 1D 视图
    auto column_vector_view = std::submdspan(tensor, std::full_extent, 0, 1);
    std::cout << "--- All pages, Row 0, Col 1 View (1D) ---n";
    std::cout << "  Rank: " << column_vector_view.rank() << "n";
    std::cout << "  Extents: (" << column_vector_view.extent(0) << ")n";
    std::cout << "  Elements:n";
    for (size_t i = 0; i < column_vector_view.extent(0); ++i) {
        std::cout << column_vector_view(i) << "t";
    }
    std::cout << "nn";
    /* Output:
    --- All pages, Row 0, Col 1 View (1D) ---
      Rank: 1
      Extents: (2)
      Elements:
      101     113
    */

4.5. 步长切片 (Strided Slicing)

std::submdspan 也支持通过 std::tuplestd::integer_sequence 来实现更复杂的步长切片,但更常见和灵活的方式是通过 std::layout_stride 布局策略来构建具有特定步长的 mdspan。然而,std::submdspan 本身支持 strided_slice 对象来直接生成步长切片。

    // 获取矩阵的偶数行
    // submdspan 的第三个参数可以是 std::tuple<IndexType, IndexType, IndexType>
    // 代表 {start, end, stride}
    auto even_rows_view = std::submdspan(matrix, std::tuple{0, 4, 2}, std::full_extent);
    print_mdspan_info("Even Rows View (2x5)", even_rows_view);
    /* Output:
    --- Even Rows View (2x5) ---
      Rank: 2
      Extents: (2, 5)
      First element: 1
      Elements:
      1       2       3       4       5
      11      12      13      14      15
    */

    // 获取矩阵的奇数列
    auto odd_cols_view = std::submdspan(matrix, std::full_extent, std::tuple{1, 5, 2});
    print_mdspan_info("Odd Columns View (4x2)", odd_cols_view);
    /* Output:
    --- Odd Columns View (4x2) ---
      Rank: 2
      Extents: (4, 2)
      First element: 2
      Elements:
      2       4
      7       9
      12      14
      17      19
    */

5. std::mdspan 在高性能计算中的应用

std::mdspan 不仅仅是一个语法糖,它为 HPC 带来了实实在在的性能和表达力提升。

5.1. 矩阵乘法示例

矩阵乘法是 HPC 中的一个经典问题,也是展示 mdspan 优势的绝佳场景。

考虑经典的 $C = A times B$ 矩阵乘法,其中 $A$ 是 $M times K$ 矩阵,$B$ 是 $K times N$ 矩阵,$C$ 是 $M times N$ 矩阵。

传统 C++ 实现 (假设数据是行主序的连续内存)

// 传统 C++ 矩阵乘法 (使用原始指针)
void matrix_multiply_raw_ptr(const double* A, const double* B, double* C,
                             size_t M, size_t K, size_t N) {
    for (size_t i = 0; i < M; ++i) {
        for (size_t j = 0; j < N; ++j) {
            double sum = 0.0;
            for (size_t l = 0; l < K; ++l) {
                sum += A[i * K + l] * B[l * N + j];
            }
            C[i * N + j] = sum;
        }
    }
}

这种写法直观上理解困难,索引 i * K + ll * N + j 很容易混淆,且每次访问都需要乘法和加法运算。

使用 std::mdspan 的矩阵乘法

#include <vector>
#include <mdspan>
#include <iostream>
#include <numeric>
#include <chrono>

// 使用 mdspan 的矩阵乘法
// 这里使用 auto& 接收 mdspan,可以处理不同布局和维度类型的 mdspan
template<typename MA, typename MB, typename MC>
void matrix_multiply_mdspan(MA A, MB B, MC C) {
    // 检查维度匹配 (运行时断言或编译时 if constexpr)
    if (A.extent(1) != B.extent(0) || A.extent(0) != C.extent(0) || B.extent(1) != C.extent(1)) {
        throw std::runtime_error("Matrix dimensions mismatch for multiplication.");
    }

    const size_t M = A.extent(0);
    const size_t K = A.extent(1);
    const size_t N = B.extent(1);

    for (size_t i = 0; i < M; ++i) {
        for (size_t j = 0; j < N; ++j) {
            double sum = 0.0;
            for (size_t l = 0; l < K; ++l) {
                sum += A(i, l) * B(l, j); // 直观的多维索引访问
            }
            C(i, j) = sum;
        }
    }
}

int main_mm() {
    const size_t M = 500, K = 300, N = 400;

    std::vector<double> dataA(M * K);
    std::vector<double> dataB(K * N);
    std::vector<double> dataC(M * N, 0.0); // Initialize C with zeros

    // 填充数据
    std::iota(dataA.begin(), dataA.end(), 0.1);
    std::iota(dataB.begin(), dataB.end(), 0.01);

    // 创建 mdspan 视图 (这里使用行主序)
    std::mdspan<double, std::extents<size_t, std::dynamic_extent, std::dynamic_extent>>
        A_mds(dataA.data(), M, K);
    std::mdspan<double, std::extents<size_t, std::dynamic_extent, std::dynamic_extent>>
        B_mds(dataB.data(), K, N);
    std::mdspan<double, std::extents<size_t, std::dynamic_extent, std::dynamic_extent>>
        C_mds(dataC.data(), M, N);

    auto start = std::chrono::high_resolution_clock::now();
    matrix_multiply_mdspan(A_mds, B_mds, C_mds);
    auto end = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> diff = end - start;
    std::cout << "Matrix multiplication (mdspan) took: " << diff.count() << " sn";

    // 验证结果 (示例)
    // std::cout << "C(0,0) = " << C_mds(0,0) << "n";
    // std::cout << "C(M-1,N-1) = " << C_mds(M-1,N-1) << "n";

    // 此时 C_mds(i,j) 访问的是 dataC[i*N+j]
    // 我们可以用一个传统的原始指针函数来验证
    std::vector<double> dataC_raw(M * N, 0.0);
    start = std::chrono::high_resolution_clock::now();
    matrix_multiply_raw_ptr(dataA.data(), dataB.data(), dataC_raw.data(), M, K, N);
    end = std::chrono::high_resolution_clock::now();
    diff = end - start;
    std::cout << "Matrix multiplication (raw ptr) took: " << diff.count() << " sn";

    // 检查结果是否一致
    bool identical = true;
    for(size_t i = 0; i < M * N; ++i) {
        if (std::abs(dataC[i] - dataC_raw[i]) > 1e-9) {
            identical = false;
            break;
        }
    }
    std::cout << "Results are identical: " << (identical ? "Yes" : "No") << "n";

    return 0;
}

注意mdspan 本身不提供性能优化,它提供的是一个 工具,让开发者更容易编写出高性能的代码。上述简单的三重循环矩阵乘法,无论是用原始指针还是 mdspan,其渐进时间复杂度都是 $O(MKN)$,且缓存访问模式不佳。

5.2. 分块矩阵乘法 (Block Matrix Multiplication)

真正的性能提升往往来自于算法优化,比如分块矩阵乘法。std::mdspanstd::submdspan 使得实现这种分块算法变得异常简洁和安全。

分块矩阵乘法的基本思想是将大矩阵划分为若干个小块(子矩阵),然后以这些小块为单位进行乘法和加法运算。这样做可以显著改善缓存局部性,因为每个小块可以完全加载到缓存中进行计算。

假设 $A$ 是 $M times K$,$B$ 是 $K times N$,$C$ 是 $M times N$。我们将 $A$ 分为 $M_b times K_b$ 的块,$B$ 分为 $K_b times N_b$ 的块,$C$ 分为 $M_b times N_b$ 的块。

template<typename MA, typename MB, typename MC>
void block_matrix_multiply_mdspan(MA A, MB B, MC C, size_t block_size) {
    if (A.extent(1) != B.extent(0) || A.extent(0) != C.extent(0) || B.extent(1) != C.extent(1)) {
        throw std::runtime_error("Matrix dimensions mismatch for multiplication.");
    }

    const size_t M = A.extent(0);
    const size_t K = A.extent(1);
    const size_t N = B.extent(1);

    for (size_t i = 0; i < M; i += block_size) {
        for (size_t j = 0; j < N; j += block_size) {
            for (size_t l = 0; l < K; l += block_size) {
                // 定义当前块的范围
                size_t current_M_block = std::min(block_size, M - i);
                size_t current_K_block_A = std::min(block_size, K - l);
                size_t current_K_block_B = std::min(block_size, K - l);
                size_t current_N_block = std::min(block_size, N - j);

                // 使用 submdspan 获取子视图
                auto A_block = std::submdspan(A,
                                            std::pair{i, i + current_M_block},
                                            std::pair{l, l + current_K_block_A});
                auto B_block = std::submdspan(B,
                                            std::pair{l, l + current_K_block_B},
                                            std::pair{j, j + current_N_block});
                auto C_block = std::submdspan(C,
                                            std::pair{i, i + current_M_block},
                                            std::pair{j, j + current_N_block});

                // 对子块进行朴素矩阵乘法累加
                // 注意:这里需要一个 C_block += A_block * B_block 的操作
                // 简化起见,我们直接在 C_block 上执行一个朴素的乘法,并累加到 C_block
                // 实际中应该有一个单独的 kernel 来执行 C += A*B
                for (size_t ii = 0; ii < C_block.extent(0); ++ii) {
                    for (size_t jj = 0; jj < C_block.extent(1); ++jj) {
                        double sum = 0.0;
                        for (size_t ll = 0; ll < A_block.extent(1); ++ll) {
                            sum += A_block(ii, ll) * B_block(ll, jj);
                        }
                        C_block(ii, jj) += sum; // 累加到 C_block
                    }
                }
            }
        }
    }
}

// ... 在 main_mm 中调用 ...
int main_block_mm() {
    const size_t M = 500, K = 300, N = 400;
    const size_t BLOCK_SIZE = 64; // 块大小

    std::vector<double> dataA(M * K);
    std::vector<double> dataB(K * N);
    std::vector<double> dataC_block(M * N, 0.0);

    std::iota(dataA.begin(), dataA.end(), 0.1);
    std::iota(dataB.begin(), dataB.end(), 0.01);

    std::mdspan<double, std::extents<size_t, std::dynamic_extent, std::dynamic_extent>>
        A_mds(dataA.data(), M, K);
    std::mdspan<double, std::extents<size_t, std::dynamic_extent, std::dynamic_extent>>
        B_mds(dataB.data(), K, N);
    std::mdspan<double, std::extents<size_t, std::dynamic_extent, std::dynamic_extent>>
        C_mds_block(dataC_block.data(), M, N);

    auto start = std::chrono::high_resolution_clock::now();
    block_matrix_multiply_mdspan(A_mds, B_mds, C_mds_block, BLOCK_SIZE);
    auto end = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> diff = end - start;
    std::cout << "Block matrix multiplication (mdspan) took: " << diff.count() << " sn";

    // 验证结果 (与朴素乘法结果比较)
    // ... 需要运行 main_mm 得到 C_mds 结果再比较 ...
    // 为了简化,这里只展示了计时
    return 0;
}

通过 std::submdspan,我们可以以非常简洁和安全的方式定义和操作这些子块,而无需手动计算复杂的指针偏移量。这不仅提高了代码的可读性,也降低了引入 bug 的风险。

5.3. BLAS-like 操作

BLAS (Basic Linear Algebra Subprograms) 库提供了高度优化的向量和矩阵操作。std::mdspan 是设计用于与 BLAS 等库无缝衔接的理想工具。

例如,实现一个简单的 axpy 操作:$y = alpha x + y$,其中 $x, y$ 是向量,$alpha$ 是标量。

// 传统的 axpy (使用原始指针)
void axpy_raw_ptr(double alpha, const double* x, double* y, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        y[i] = alpha * x[i] + y[i];
    }
}

// 使用 mdspan 的 axpy
template<typename VX, typename VY>
void axpy_mdspan(double alpha, VX x, VY y) {
    if (x.extent(0) != y.extent(0)) {
        throw std::runtime_error("Vector dimensions mismatch for axpy.");
    }
    for (size_t i = 0; i < x.extent(0); ++i) {
        y(i) = alpha * x(i) + y(i);
    }
}

int main_axpy() {
    const size_t N = 1000000;
    double alpha = 2.5;

    std::vector<double> vecX(N);
    std::vector<double> vecY(N);
    std::vector<double> vecY_raw(N);

    std::iota(vecX.begin(), vecX.end(), 1.0);
    std::iota(vecY.begin(), vecY.end(), 100.0);
    vecY_raw = vecY; // 复制一份用于原始指针比较

    // 创建 mdspan 视图 (1D 向量)
    std::mdspan<double, std::extents<size_t, std::dynamic_extent>>
        X_mds(vecX.data(), N);
    std::mdspan<double, std::extents<size_t, std::dynamic_extent>>
        Y_mds(vecY.data(), N);

    auto start = std::chrono::high_resolution_clock::now();
    axpy_mdspan(alpha, X_mds, Y_mds);
    auto end = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> diff = end - start;
    std::cout << "AXPY (mdspan) took: " << diff.count() << " sn";

    start = std::chrono::high_resolution_clock::now();
    axpy_raw_ptr(alpha, vecX.data(), vecY_raw.data(), N);
    end = std::chrono::high_resolution_clock::now();
    diff = end - start;
    std::cout << "AXPY (raw ptr) took: " << diff.count() << " sn";

    // 验证结果
    bool identical = true;
    for(size_t i = 0; i < N; ++i) {
        if (std::abs(vecY[i] - vecY_raw[i]) > 1e-9) {
            identical = false;
            break;
        }
    }
    std::cout << "AXPY results are identical: " << (identical ? "Yes" : "No") << "n";

    return 0;
}

mdspan 使得函数签名更清晰,编译器可以利用其维度信息进行更多的优化。当需要处理向量的子区域时,std::submdspan 同样能发挥作用,例如 axpy_mdspan(alpha, std::submdspan(X_mds, std::pair{0, N/2}), std::submdspan(Y_mds, std::pair{0, N/2}));

5.4. 与 Fortran/BLAS 互操作性

许多高性能计算库(如 Intel MKL、OpenBLAS)是用 Fortran 编写的,它们通常采用列主序(std::layout_left)。std::mdspan 对列主序的原生支持使得 C++ 代码可以更方便、更高效地与这些库进行互操作。你可以直接创建一个 std::layout_leftmdspan 视图,并将其底层指针传递给 Fortran 兼容的函数,而无需进行数据复制或转置操作。

// 假设有一个外部 Fortran/BLAS 函数
extern "C" {
    void dgemm_(char* transa, char* transb, int* m, int* n, int* k,
                double* alpha, const double* A, int* lda,
                const double* B, int* ldb, double* beta,
                double* C, int* ldc);
}

// C++ 封装函数,接受 mdspan 参数
template<typename MA, typename MB, typename MC>
void call_dgemm(MA A, MB B, MC C, double alpha = 1.0, double beta = 0.0) {
    char transa = 'N'; // No transpose
    char transb = 'N';
    int m = A.extent(0);
    int k = A.extent(1);
    int n = B.extent(1); // C's columns

    // Fortran BLAS 期望列主序,所以 A, B, C 都应该是 layout_left
    // 如果 A 是 layout_right,需要先转置成 layout_left,或者设置 transa='T'
    // 这里为了简化,假设 A,B,C 都是列主序
    if (!A.is_layout_left() || !B.is_layout_left() || !C.is_layout_left()) {
        std::cerr << "Warning: BLAS functions typically expect column-major layout. Provided mdspan might not be column-major.n";
    }

    // Leading dimensions (ldA, ldB, ldC) for column-major matrices
    // ldA = rows of A
    // ldB = rows of B
    // ldC = rows of C
    int lda = A.extent(0);
    int ldb = B.extent(0);
    int ldc = C.extent(0);

    dgemm_(&transa, &transb, &m, &n, &k,
           &alpha, A.data_handle(), &lda,
           B.data_handle(), &ldb, &beta,
           C.data_handle(), &ldc);
}

// 注意:实际使用时需要链接到 BLAS 库,例如 -lblas
// 这是一个演示其潜力的代码片段,不能直接运行

mdspandata_handle() 方法返回底层数据的原始指针,可以直接传递给 C 风格的 API。这种零开销的集成方式是其在 HPC 中不可或缺的原因。

6. 高级主题与最佳实践

6.1. 性能考量

std::mdspan 本身是零开销的,其性能优势体现在:

  • 缓存局部性:通过 std::layout_rightstd::layout_left 确保内存访问模式与 CPU 缓存行对齐,从而最大化缓存命中率。对于 std::layout_stride,需要开发者自行确保步长选择有利于缓存。
  • 消除复制:通过 std::submdspan 传递子区域,避免了不必要的数据复制,减少了内存带宽消耗和内存分配/释放开销。
  • 编译器优化mdspan 提供了丰富的编译时信息(如维度、布局),有助于编译器生成更优化的机器码。例如,如果所有维度都是静态的,索引计算可以在编译时完成。

调试模式下的边界检查:在调试模式下,mdspan 可以启用边界检查(例如通过 _GLIBCXX_DEBUG_MDSPAN 宏,或通过自定义 accessor)。这会引入运行时开销,但在开发阶段对于发现错误非常有价值。在发布版本中,这些检查通常会被禁用,以达到零开销。

6.2. std::mdspanstd::span

std::mdspan 可以看作是 std::span 的多维泛化。一个 1D 的 mdspanstd::span 功能相似,但 mdspan 提供了更丰富的布局和访问器选项。通常,对于简单的一维连续数据,std::span 仍然是更简洁的选择。

6.3. 与其他库的集成

std::mdspan 的标准化意味着它将成为 C++ 生态系统中处理多维数据的事实标准。未来的线性代数库、张量库以及并行计算框架(如 SYCL、CUDA C++)都可能原生支持或提供与 std::mdspan 的高效接口。例如,一些 GPU 编程模型可以通过自定义 accessor 来直接操作 GPU 内存上的 mdspan

6.4. 选择合适的布局

  • std::layout_right:如果你是 C/C++ 开发者,且你的代码主要以行遍历为主,或者需要与 C 风格的二维数组兼容,这是自然的选择。它通常与 C++ 中 std::vector<T> 存储一维数据的内存布局一致。
  • std::layout_left:如果你需要与 Fortran、BLAS 或某些数值库(它们通常使用列主序)进行互操作,或者你的算法天然更适合列主序访问(例如一些线性代数算法),那么 std::layout_left 是最佳选择。
  • std::layout_stride:当你需要非常规的内存访问模式,例如跳跃访问、处理稀疏数据、或者实现转置视图而不复制数据时,std::layout_stride 提供了最大的灵活性。但它也要求你对内存布局有更深入的理解,并手动计算步长。

6.5. 线程安全

std::mdspan 本身不提供线程安全。如果多个线程同时访问同一个 mdspan 视图所指向的底层数据,你仍然需要使用互斥量、原子操作或其他同步机制来确保数据一致性。自定义 accessor 可以用于封装原子操作,从而在 mdspan 访问层面上提供一定程度的线程安全。

7. 提升 C++ 高性能计算表达力的新范式

std::mdspan 是 C++23 标准库中一个里程碑式的特性。它通过引入一个标准化的、零开销的多维数组视图,极大地提升了 C++ 在高性能计算中处理矩阵和多维数组的表达力、安全性和效率。开发者可以编写出更清晰、更易于维护、同时又能充分利用硬件性能的代码。随着 std::mdspan 的普及和更广泛的库支持,它将成为 C++ HPC 领域不可或缺的工具,推动 C++ 在科学计算、机器学习等前沿领域的发展。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注