C++ 定点数运算库:在低功耗嵌入式 AI 芯片上的高效矩阵乘法实现

C++ 定点数运算库:在低功耗嵌入式 AI 芯片上的高效矩阵乘法实现

尊敬的各位同仁,女士们、先生们,大家好!

今天,我们齐聚一堂,共同探讨一个在当前人工智能浪潮中至关重要的话题:如何在低功耗嵌入式AI芯片上,利用C++定点数运算库,实现高效的矩阵乘法。随着AI技术从云端走向边缘,我们面临着前所未有的机遇与挑战。在资源受限的环境中,如何在保证模型性能的同时,最大限度地提升计算效率、降低功耗,是每一位工程师必须深思熟虑的问题。定点数运算,正是解决这一难题的关键利器。

引言:嵌入式AI与定点数运算的时代机遇与挑战

近年来,人工智能,特别是深度学习,取得了突破性进展,深刻改变了我们的生活。从智能语音助手到自动驾驶,从图像识别到自然语言处理,AI的应用场景无处不在。然而,随着模型规模的不断扩大和计算复杂度的急剧提升,将这些强大的AI能力部署到终端设备,如智能手机、物联网设备、可穿戴设备乃至微型传感器上,面临着严峻的挑战。这些“边缘侧”设备通常受限于:

  1. 低功耗要求: 电池供电或有限的电源供应,要求芯片在极低的功耗下运行。
  2. 实时性要求: 许多应用需要即时响应,如自动驾驶的决策,人脸识别的验证。
  3. 小尺寸与低成本: 嵌入式设备通常对物理尺寸和制造成本有严格限制。
  4. 有限的内存与带宽: 无法承载大型模型和高带宽的数据传输。

在深度学习模型中,矩阵乘法(Matrix Multiplication)是核心且计算量最大的操作之一,占据了高达80%甚至更多的计算资源。传统的AI模型训练和推理大多采用浮点数(如FP32、FP16)进行运算。浮点数提供了宽广的数值范围和较高的精度,但其代价是:

  • 硬件复杂度高: 浮点运算单元(FPU)比整数运算单元(ALU)面积更大,功耗更高。
  • 计算效率低: 浮点运算需要处理指数和尾数,指令周期更长。
  • 内存带宽占用大: 存储和传输浮点数需要更多的位,增加了内存访问的开销。

为了克服这些局限,定点数运算应运而生。定点数运算的核心思想是用整数来近似表示实数,通过预先确定小数点的位置来管理精度和范围。这使得定点数运算能够:

  • 降低硬件成本与功耗: 利用简单的整数ALU即可实现,无需复杂FPU。
  • 提升计算效率: 整数运算速度快,指令周期短。
  • 减少内存占用与带宽需求: 通常使用16位、8位甚至更低的位宽,极大压缩数据量。

因此,构建一个高效、灵活且易于使用的C++定点数运算库,并将其应用于矩阵乘法优化,对于推动嵌入式AI芯片的发展,具有不可估量的价值。本文将深入探讨定点数的基础理论、库设计原则、矩阵乘法的优化策略,以及精度与性能的平衡之道。

II. 定点数基础理论与表示

要设计一个定点数库,我们首先需要理解定点数的基本概念和表示方法。

什么是定点数?

定点数,顾名思义,是小数点位置固定的数。与浮点数通过指数部分动态调整小数点位置不同,定点数的小数点位置在设计时就已确定。它由一个整数部分和一个小数部分组成。例如,一个数123.456,如果我们规定小数点后有三位,那么它就可以被视为123456这个整数,只是我们知道它的实际值需要除以1000。

表示方法:Q格式与S.M.F格式

在数字信号处理和嵌入式领域,常用的定点数表示方法有两种:

  1. Q格式 (Qm.n):

    • m 表示整数部分的位数(不包括符号位)。
    • n 表示小数部分的位数。
    • 总位宽通常为 m + n + 1 (如果是有符号数,加1是符号位)。
    • 例如,Q1.15表示一个16位有符号定点数,其中1位是符号位,1位是整数位,15位是小数位。它的取值范围为 [-2^1, 2^1 - 2^-15],最小精度为 2^-15
    • 实际存储的是 Value * 2^n 的整数形式。
  2. S.M.F格式:

    • S 表示符号位 (1位)。
    • M 表示整数部分的位数。
    • F 表示小数部分的位数。
    • 总位宽为 S + M + F
    • 例如,S1.15表示一个16位有符号定点数,1位符号位,1位整数位,15位小数位,与Q1.15含义相同。

在本文中,我们将主要采用类似Q格式的思路,通过模板参数来定义总位宽和小数位宽。

符号位:有符号/无符号

定点数可以是无符号的(只表示非负数),也可以是有符号的(表示正负数)。有符号数通常采用二进制补码(Two’s Complement)表示,这使得加减法运算与无符号数保持一致,简化了硬件实现。

定点数选择的考量:精度与范围的权衡

选择合适的定点数格式是定点量化的核心挑战。我们需要在精度(由小数位数决定)和数值范围(由整数位数决定)之间进行权衡:

  • 增加小数位数 (n): 提升精度,但会缩小数值范围(在总位宽固定的情况下)。
  • 增加整数位数 (m): 扩大数值范围,但会降低精度。

在AI模型中,激活值(如ReLU输出)通常是非负的,可以使用无符号定点数以获取更大的正数范围。而权重和某些中间计算结果可能为负值,则必须使用有符号定点数。

溢出与饱和处理

定点数在运算过程中极易发生溢出。例如,两个Q1.15的数相乘,结果可能需要Q2.30来表示。如果直接存储回Q1.15,就会丢失信息。

  • 溢出 (Overflow): 运算结果超出了当前定点数格式能表示的最大或最小范围。
  • 饱和 (Saturation): 当发生溢出时,将结果限制在当前定点数格式所能表示的最大值或最小值。这是嵌入式系统中常用的处理方式,因为它能避免结果突变,保持一定的信号特性。
  • 环绕 (Wrap-around): 溢出时结果“回卷”,例如MAX_INT + 1 = MIN_INT。这通常会导致结果的显著失真,不适用于大多数AI场景。

C++中的定点数模拟:基于整数类型

C++标准库并没有内置的定点数类型。我们通常通过封装C++的整数类型(int8_t, int16_t, int32_t, int64_t等)来模拟定点数。其核心思想是,将一个实数 X 乘以一个比例因子 2^nn 为小数位数),然后存储这个整数结果。

例如,如果我们的定点数格式是Q1.15,总共16位,其中1位符号位,1位整数位,15位小数位。那么一个浮点数f_val在转换为定点数时,其内部存储的整数值 raw_val 就是 round(f_val * 2^15)

#include <cstdint>
#include <cmath>   // For round, pow
#include <limits>  // For numeric_limits
#include <iostream>
#include <type_traits> // For std::is_signed

// --- Code Example 1: Basic Fixed-Point Representation (struct/class design) ---

template <int TOTAL_BITS, int FRACTIONAL_BITS, bool IS_SIGNED = true>
class FixedPoint {
public:
    // 内部存储类型,根据总位宽选择合适的整数类型
    // 使用 std::conditional 来选择 int16_t, int32_t, int64_t
    using internal_type = typename std::conditional<
        TOTAL_BITS <= 16, int16_t,
        typename std::conditional<
            TOTAL_BITS <= 32, int32_t,
            int64_t
        >::type
    >::type;

    internal_type raw_value;

    static constexpr int s_total_bits = TOTAL_BITS;
    static constexpr int s_fractional_bits = FRACTIONAL_BITS;
    static constexpr int s_integer_bits = TOTAL_BITS - FRACTIONAL_BITS - (IS_SIGNED ? 1 : 0);
    static constexpr bool s_is_signed = IS_SIGNED;
    static constexpr internal_type s_scale = internal_type(1) << FRACTIONAL_BITS;

    // 构造函数
    FixedPoint() : raw_value(0) {}
    explicit FixedPoint(internal_type raw_val) : raw_value(raw_val) {}

    // 从浮点数转换
    static FixedPoint from_float(float f_val) {
        // 防止溢出,这里只是简单实现,后续会加入饱和处理
        float scaled_val = f_val * s_scale;
        if (s_is_signed) {
            if (scaled_val > std::numeric_limits<internal_type>::max()) {
                scaled_val = std::numeric_limits<internal_type>::max();
            } else if (scaled_val < std::numeric_limits<internal_type>::min()) {
                scaled_val = std::numeric_limits<internal_type>::min();
            }
        } else { // Unsigned
             if (scaled_val > std::numeric_limits<internal_type>::max()) {
                scaled_val = std::numeric_limits<internal_type>::max();
            } else if (scaled_val < 0) { // Unsigned can't be negative
                scaled_val = 0;
            }
        }
        return FixedPoint(static_cast<internal_type>(std::round(scaled_val)));
    }

    // 转换为浮点数
    float to_float() const {
        return static_cast<float>(raw_value) / s_scale;
    }

    // 获取最大值和最小值 (作为浮点数)
    static float max_float() {
        if (s_is_signed) {
            return static_cast<float>(std::numeric_limits<internal_type>::max()) / s_scale;
        } else {
            return static_cast<float>(std::numeric_limits<internal_type>::max()) / s_scale;
        }
    }

    static float min_float() {
        if (s_is_signed) {
            return static_cast<float>(std::numeric_limits<internal_type>::min()) / s_scale;
        } else { // Unsigned min is always 0
            return 0.0f;
        }
    }

    // 打印
    friend std::ostream& operator<<(std::ostream& os, const FixedPoint& fp) {
        os << fp.to_float() << " (raw: " << fp.raw_value << ")";
        return os;
    }

    // 为了示例,定义一个简单的相等运算符
    bool operator==(const FixedPoint& other) const {
        return raw_value == other.raw_value;
    }
    // 更多运算符将在下一节中实现
};

// 示例用法
/*
int main() {
    // 定义一个Q1.15格式的定点数 (16位总位宽, 15位小数位, 有符号)
    using Q1_15 = FixedPoint<16, 15, true>;

    float f1 = 0.5f;
    float f2 = -0.125f;
    float f3 = 1.99f;
    float f4 = -1.99f; // 会饱和到最小值

    Q1_15 q1 = Q1_15::from_float(f1);
    Q1_15 q2 = Q1_15::from_float(f2);
    Q1_15 q3 = Q1_15::from_float(f3);
    Q1_15 q4 = Q1_15::from_float(f4);

    std::cout << "Q1.15 Max Float: " << Q1_15::max_float() << std::endl;
    std::cout << "Q1.15 Min Float: " << Q1_15::min_float() << std::endl;

    std::cout << "Float " << f1 << " -> Q1.15 " << q1 << std::endl;
    std::cout << "Float " << f2 << " -> Q1.15 " << q2 << std::endl;
    std::cout << "Float " << f3 << " -> Q1.15 " << q3 << std::endl; // 接近最大值
    std::cout << "Float " << f4 << " -> Q1.15 " << q4 << std::endl; // 接近最小值

    // 定义一个Q7.8格式的定点数 (16位总位宽, 8位小数位, 有符号)
    using Q7_8 = FixedPoint<16, 8, true>;
    float f_large = 123.45f;
    Q7_8 q_large = Q7_8::from_float(f_large);
    std::cout << "Q7.8 Max Float: " << Q7_8::max_float() << std::endl;
    std::cout << "Q7.8 Min Float: " << Q7_8::min_float() << std::endl;
    std::cout << "Float " << f_large << " -> Q7.8 " << q_large << std::endl;

    // 无符号定点数 (例如 Q8.8, 16位总位宽, 8位小数位, 无符号)
    using UQ8_8 = FixedPoint<16, 8, false>;
    float uf1 = 123.45f;
    float uf2 = -5.0f; // 无符号会饱和到0
    UQ8_8 uq1 = UQ8_8::from_float(uf1);
    UQ8_8 uq2 = UQ8_8::from_float(uf2);

    std::cout << "UQ8.8 Max Float: " << UQ8_8::max_float() << std::endl;
    std::cout << "UQ8.8 Min Float: " << UQ8_8::min_float() << std::endl;
    std::cout << "Float " << uf1 << " -> UQ8.8 " << uq1 << std::endl;
    std::cout << "Float " << uf2 << " -> UQ8.8 " << uq2 << std::endl;

    return 0;
}
*/

在上述代码中,我们定义了一个FixedPoint模板类,它通过模板参数TOTAL_BITSFRACTIONAL_BITSIS_SIGNED来灵活配置定点数的格式。internal_type根据总位宽自动选择合适的C++整数类型,确保了存储效率。s_scale是2的FRACTIONAL_BITS次方,用于浮点数和定点数之间的转换。

III. 定点数运算原理与实现

实现定点数库的核心在于正确地实现其基本算术运算。

加法与减法:对齐小数点

定点数的加减法非常直观,只需像整数一样直接对内部的raw_value进行加减即可,因为它们的小数点位置是隐式对齐的。

Q(A) + Q(B) = (A * 2^n) + (B * 2^n) = (A+B) * 2^n = Q(A+B)

但需要注意的是,加减法可能会导致溢出,特别是当两个大数值相加时。

乘法:精度扩张与截断

定点数乘法是所有运算中最需要细致处理的。两个定点数相乘,其结果的小数位数会增加。

假设我们有两个定点数 $FP_1$ (Q$m_1.n_1$) 和 $FP_2$ (Q$m_2.n_2$)。
它们的内部存储值分别为 $R_1 = FP_1 times 2^{n_1}$ 和 $R_2 = FP_2 times 2^{n2}$。
它们的乘积 $FP
{res} = FP_1 times FP2$。
那么结果的内部存储值 $R
{res}$ 应该是 $FP{res} times 2^{n{res}}$。

如果直接对内部整数值进行乘法:
$R_1 times R_2 = (FP_1 times 2^{n_1}) times (FP_2 times 2^{n_2}) = (FP_1 times FP_2) times 2^{n_1 + n_2}$

这意味着,直接相乘后的整数结果,其隐含的小数位数是 $n_1 + n2$。如果我们希望将结果存储回一个具有 $n{res}$ 小数位数的定点数格式,就需要进行定标(scaling),即将结果右移 (n_1 + n_2) - n_{res} 位。

通常,为了避免中间结果溢出,乘法操作会先将两个定点数的内部整数值相乘,这会得到一个位宽更长的结果(例如,两个16位整数相乘可能得到32位结果)。然后,再根据目标定点数格式进行右移和舍入。

除法:精度提升与舍入

定点数除法比乘法更复杂一些,因为它可能导致精度损失或需要额外的精度提升。
$FP_1 / FP_2 = (R_1 / 2^{n_1}) / (R_2 / 2^{n_2}) = (R_1 / R_2) times 2^{n_2 – n_1}$

如果直接进行整数除法 $R_1 / R_2$,得到的结果隐含的小数位数是 $n_1 – n_2$。
为了保持精度或将其转换为目标小数位数,我们通常会在进行整数除法之前,先将被除数 $R_1$ 左移一定的位数,以提升中间结果的精度。

例如,如果 (n_1 - n_2) 是负数,意味着结果的小数位数会减少,精度下降。为了得到一个 n_{res} 小数位的结果,我们需要将被除数 $R1$ 左移 `n{res} – (n_1 – n_2)` 位,然后再进行整数除法。

类型转换与定标 (Scaling)

  • floatFixedPoint raw_value = round(f_val * 2^n)
  • FixedPointfloat f_val = raw_value / 2^n
  • 不同 FixedPoint 格式之间转换:
    例如,从Q$m_1.n_1$ 转换为 Q$m_2.n_2$。
    如果 n_1 > n_2,需要右移 n_1 - n_2 位(舍弃精度)。
    如果 n_1 < n_2,需要左移 n_2 - n_1 位(提升精度,但不会增加实际有效位)。
    这其中同样需要考虑溢出和饱和。

舍入策略

在定点数运算中,特别是在乘法、除法和类型转换中,当结果需要截断时,如何处理被舍弃的位决定了最终的精度。常见的舍入策略包括:

  • 截断 (Truncation / Round Towards Zero): 直接丢弃小数部分,向零方向取整。最简单,但可能引入较大偏差。
  • 最近偶数 (Round Half To Even / Banker’s Rounding): 当小数部分恰好为0.5时,向最近的偶数取整。这是IEEE 754浮点标准默认的舍入方式,能有效减少累积误差。
  • 向上取整 (Round Up / Ceiling): 向正无穷方向取整。
  • 向下取整 (Round Down / Floor): 向负无穷方向取整。

对于嵌入式系统,通常会选择截断或最近偶数,因为它们在硬件实现上相对简单或精度累积误差较小。

溢出策略:饱和、环绕

如前所述,当运算结果超出目标定点数格式的表示范围时:

  • 饱和 (Saturation): 将结果限制在最大值或最小值。这是AI推理中最常用的策略,因为它能保持数值的合理性,避免模型崩溃。
  • 环绕 (Wrap-around): 结果回卷。在大多数AI场景中应避免。

*— Code Example 2: Fixed-Point Arithmetic Operators (+, -, , /) —
— Code Example 3: Saturation and Rounding Implementations —**

我们将这些策略集成到FixedPoint类中。为了简化,这里先实现一个默认的饱和和截断/近似舍入版本。

// 假设这是 FixedPoint 类定义内部
template <int TOTAL_BITS, int FRACTIONAL_BITS, bool IS_SIGNED = true>
class FixedPoint {
public:
    // ... (previous definitions: internal_type, raw_value, static constexpr members) ...

    // Saturated helper function
    static internal_type saturate(int64_t value) {
        if (s_is_signed) {
            if (value > std::numeric_limits<internal_type>::max()) {
                return std::numeric_limits<internal_type>::max();
            }
            if (value < std::numeric_limits<internal_type>::min()) {
                return std::numeric_limits<internal_type>::min();
            }
        } else { // Unsigned
            if (value < 0) return 0; // Saturate to 0 for unsigned
            if (value > std::numeric_limits<internal_type>::max()) {
                return std::numeric_limits<internal_type>::max();
            }
        }
        return static_cast<internal_type>(value);
    }

    // Rounding helper function (round to nearest, half up)
    // For fixed-point, right shifting usually implies truncation.
    // To implement round-to-nearest, we add 0.5 (represented as 1 << (shift - 1))
    // before the shift.
    static internal_type round_and_shift(int64_t value, int shift_amount) {
        if (shift_amount <= 0) return static_cast<internal_type>(value); // No shift or left shift

        // Round to nearest: add half of the LSB of the part being shifted out
        // For positive numbers, add (1 << (shift_amount - 1))
        // For negative numbers, subtract (1 << (shift_amount - 1))
        // A simpler way for symmetric rounding is (value + (1 << (shift_amount - 1))) >> shift_amount for positive
        // and (value - (1 << (shift_amount - 1))) >> shift_amount for negative.
        // A common technique for signed values:
        if (value >= 0) {
            value = (value + (internal_type(1) << (shift_amount - 1)));
        } else {
            value = (value - (internal_type(1) << (shift_amount - 1)));
        }
        return static_cast<internal_type>(value >> shift_amount);
    }

    // 加法运算符
    FixedPoint operator+(const FixedPoint& other) const {
        // 中间结果需要更宽的位宽以避免溢出,然后进行饱和
        return FixedPoint(saturate(static_cast<int64_t>(raw_value) + other.raw_value));
    }

    // 减法运算符
    FixedPoint operator-(const FixedPoint& other) const {
        return FixedPoint(saturate(static_cast<int64_t>(raw_value) - other.raw_value));
    }

    // 乘法运算符
    FixedPoint operator*(const FixedPoint& other) const {
        // 两个Q(m.n)相乘,结果是Q(m+m'+1.n+n'),这里假设m=m', n=n'
        // 乘积的原始小数位是 s_fractional_bits + other.s_fractional_bits
        // 但我们想将结果存回 s_fractional_bits
        // 因此需要右移 s_fractional_bits 位

        // 使用更宽的类型进行乘法,防止中间结果溢出
        int64_t prod = static_cast<int64_t>(raw_value) * other.raw_value;

        // 舍入并右移,将小数位调整回 s_fractional_bits
        // 这里的 round_and_shift 已经包含了饱和逻辑
        return FixedPoint(saturate(round_and_shift(prod, s_fractional_bits)));
    }

    // 除法运算符 (简化版,可能精度损失较大)
    FixedPoint operator/(const FixedPoint& other) const {
        // 为了提高除法精度,可以在除法前将被除数左移 s_fractional_bits 位
        // 然后再进行整数除法
        if (other.raw_value == 0) {
            // 处理除零错误,这里简化为返回最大值或最小值
            if (raw_value > 0) return FixedPoint(std::numeric_limits<internal_type>::max());
            else if (raw_value < 0) return FixedPoint(std::numeric_limits<internal_type>::min());
            else return FixedPoint(0); // 0 / 0
        }

        int64_t numerator = static_cast<int64_t>(raw_value) << s_fractional_bits;
        int64_t result_raw = numerator / other.raw_value;

        return FixedPoint(saturate(result_raw));
    }

    // 赋值运算符
    FixedPoint& operator+=(const FixedPoint& other) { *this = *this + other; return *this; }
    FixedPoint& operator-=(const FixedPoint& other) { *this = *this - other; return *this; }
    FixedPoint& operator*=(const FixedPoint& other) { *this = *this * other; return *this; }
    FixedPoint& operator/=(const FixedPoint& other) { *this = *this / other; return *this; }

    // 比较运算符 (仅示例)
    bool operator<(const FixedPoint& other) const { return raw_value < other.raw_value; }
    bool operator>(const FixedPoint& other) const { return raw_value > other.raw_value; }
    bool operator<=(const FixedPoint& other) const { return raw_value <= other.raw_value; }
    bool operator>=(const FixedPoint& other) const { return raw_value >= other.raw_value; }
    bool operator!=(const FixedPoint& other) const { return raw_value != other.raw_value; }

    // 显式类型转换到其他 FixedPoint 类型 (需要处理位宽和小数位调整)
    template <int OTHER_TOTAL_BITS, int OTHER_FRACTIONAL_BITS, bool OTHER_IS_SIGNED>
    explicit operator FixedPoint<OTHER_TOTAL_BITS, OTHER_FRACTIONAL_BITS, OTHER_IS_SIGNED>() const {
        using TargetFixedPoint = FixedPoint<OTHER_TOTAL_BITS, OTHER_FRACTIONAL_BITS, OTHER_IS_SIGNED>;

        int64_t temp_raw = raw_value;
        int shift_diff = TargetFixedPoint::s_fractional_bits - s_fractional_bits;

        if (shift_diff > 0) { // Target has more fractional bits, left shift
            temp_raw <<= shift_diff;
        } else if (shift_diff < 0) { // Target has fewer fractional bits, right shift (round)
            temp_raw = round_and_shift(temp_raw, -shift_diff);
        }

        // 最后进行饱和处理,因为位宽可能也变了
        return TargetFixedPoint(TargetFixedPoint::saturate(temp_raw));
    }
};

/*
int main() {
    using Q1_15 = FixedPoint<16, 15, true>;

    Q1_15 a = Q1_15::from_float(0.75f); // raw: 0.75 * 2^15 = 24576
    Q1_15 b = Q1_15::from_float(0.5f);  // raw: 0.5 * 2^15 = 16384

    std::cout << "a: " << a << std::endl;
    std::cout << "b: " << b << std::endl;

    Q1_15 sum = a + b; // raw: 24576 + 16384 = 40960. Should be 1.25
    std::cout << "a + b = " << sum << std::endl; // Expect ~1.25

    Q1_15 diff = a - b; // raw: 24576 - 16384 = 8192. Should be 0.25
    std::cout << "a - b = " << diff << std::endl; // Expect ~0.25

    Q1_15 prod = a * b; // raw: (24576 * 16384) >> 15 = (402653184) >> 15 = 12288. Should be 0.375
    std::cout << "a * b = " << prod << std::endl; // Expect ~0.375

    Q1_15 div_res = a / b; // raw: (24576 << 15) / 16384 = (805306368) / 16384 = 49152. Should be 1.5
    std::cout << "a / b = " << div_res << std::endl; // Expect ~1.5

    // 溢出测试
    using Q0_7 = FixedPoint<8, 7, true>; // 范围约 [-1.0, 0.99...]
    Q0_7 x = Q0_7::from_float(0.75f);
    Q0_7 y = Q0_7::from_float(0.5f);
    Q0_7 ov_sum = x + y; // 0.75 + 0.5 = 1.25, 会饱和到最大值 (0.99...)
    std::cout << "Q0_7 x: " << x << ", y: " << y << std::endl;
    std::cout << "Q0_7 x + y (overflow): " << ov_sum << std::endl;
    std::cout << "Q0_7 max float: " << Q0_7::max_float() << std::endl;

    // 类型转换
    using Q7_8 = FixedPoint<16, 8, true>;
    Q7_8 converted_a = static_cast<Q7_8>(a); // 从 Q1.15 转 Q7.8
    std::cout << "Q1.15 a: " << a << " converted to Q7.8: " << converted_a << std::endl;

    return 0;
}
*/

IV. 定点数库设计与架构

一个健壮且高效的定点数库需要精心设计的架构。我们将利用C++的模板元编程和策略模式来构建这个库。

模板元编程的应用:泛化类型、位宽、小数位

模板元编程允许我们在编译期根据模板参数生成不同的代码。这对于定点数库非常有用,因为它能让我们:

  • 灵活定义定点数格式: 通过模板参数TOTAL_BITS, FRACTIONAL_BITS, IS_SIGNED,用户可以根据需求创建任意位宽、小数位的定点数类型,如FixedPoint<16, 8, true> (Q7.8), FixedPoint<32, 24, true> (Q7.24) 等。
  • 编译期优化: 许多计算,如s_scale的计算,可以在编译期完成,避免运行时开销。
  • 类型安全: 编译器会检查不同定点数类型之间的操作是否合法,例如,不允许将一个FixedPoint<16,8>直接赋值给FixedPoint<8,4>而不进行显式转换。

类设计:FixedPoint

我们已经初步构建了FixedPoint类,现在进一步完善其设计:

  • 内部存储类型 (internal_type): 必须足够大以容纳指定位宽的整数。使用std::conditional来选择最小合适的类型(int8_t, int16_t, int32_t, int64_t),从而节省内存。
  • 模板参数:
    • TOTAL_BITS:总位数(包括符号位)。
    • FRACTIONAL_BITS:小数部分的位数。
    • IS_SIGNED:布尔值,表示是否有符号。
  • 构造函数、转换函数: 提供从整数、浮点数构造,以及转换为浮点数的接口。也应提供不同FixedPoint类型之间的显式转换操作符。
  • 运算符重载: +, -, *, /, +=, -=, *=, /=, ==, !=, <, >, <=, >=等,使得定点数操作像内置类型一样自然。
  • 静态辅助函数: from_float, to_float, max_value, min_value等。

策略模式的应用:舍入策略、溢出策略可插拔

在嵌入式系统中,对舍入和溢出处理有不同的需求。策略模式允许我们将这些行为封装到独立的策略类中,并在FixedPoint类中作为模板参数引入,从而实现行为的动态切换(在编译期)。

// --- Code Example 4: FixedPoint Class Template Structure (Refined with Policies) ---
// --- Code Example 5: Policy-Based Rounding/Saturation ---

// 1. 溢出策略基类 (或概念)
template <typename InternalType>
struct SaturationPolicy {
    static InternalType saturate(int64_t value) {
        // 默认实现:饱和到目标类型最大最小值
        if (value > std::numeric_limits<InternalType>::max()) {
            return std::numeric_limits<InternalType>::max();
        }
        if (value < std::numeric_limits<InternalType>::min()) {
            return std::numeric_limits<InternalType>::min();
        }
        return static_cast<InternalType>(value);
    }
};

// 无符号数的饱和策略
template <typename InternalType>
struct UnsignedSaturationPolicy {
    static InternalType saturate(int64_t value) {
        if (value < 0) return 0;
        if (value > std::numeric_limits<InternalType>::max()) {
            return std::numeric_limits<InternalType>::max();
        }
        return static_cast<InternalType>(value);
    }
};

// 环绕溢出策略 (作为对比,实际AI中不常用)
template <typename InternalType>
struct WrapAroundPolicy {
    static InternalType saturate(int64_t value) {
        return static_cast<InternalType>(value); // 直接截断
    }
};

// 2. 舍入策略基类 (或概念)
template <typename InternalType>
struct RoundingPolicy {
    static InternalType round_and_shift(int64_t value, int shift_amount) {
        if (shift_amount <= 0) return static_cast<InternalType>(value);
        // 默认实现:Round Half Up (对于正数,加0.5后截断)
        // 对于负数,round(x) = floor(x + 0.5)
        if (value >= 0) {
            value = (value + (InternalType(1) << (shift_amount - 1)));
        } else {
            value = (value - (InternalType(1) << (shift_amount - 1))); // 向负无穷方向舍入,需要调整
        }
        return static_cast<InternalType>(value >> shift_amount);
    }
};

// 截断策略
template <typename InternalType>
struct TruncationPolicy {
    static InternalType round_and_shift(int64_t value, int shift_amount) {
        if (shift_amount <= 0) return static_cast<InternalType>(value);
        return static_cast<InternalType>(value >> shift_amount); // 直接截断
    }
};

// 3. 改进的 FixedPoint 类,接受策略作为模板参数
template <int TOTAL_BITS, int FRACTIONAL_BITS,
          bool IS_SIGNED = true,
          template <typename> class Saturation = std::conditional_t<IS_SIGNED, SaturationPolicy, UnsignedSaturationPolicy>,
          template <typename> class Rounding = RoundingPolicy>
class FixedPoint {
public:
    using internal_type = typename std::conditional<
        TOTAL_BITS <= 16, int16_t,
        typename std::conditional<
            TOTAL_BITS <= 32, int32_t,
            int64_t
        >::type
    >::type;

    internal_type raw_value;

    static constexpr int s_total_bits = TOTAL_BITS;
    static constexpr int s_fractional_bits = FRACTIONAL_BITS;
    static constexpr bool s_is_signed = IS_SIGNED;
    static constexpr internal_type s_scale = internal_type(1) << FRACTIONAL_BITS;

    // 策略的应用
    internal_type apply_saturation(int64_t val) const {
        return Saturation<internal_type>::saturate(val);
    }

    internal_type apply_rounding_and_shift(int64_t val, int shift_amount) const {
        return Rounding<internal_type>::round_and_shift(val, shift_amount);
    }

    // 构造函数、to_float, from_float等保持不变,但要调用 apply_saturation
    FixedPoint() : raw_value(0) {}
    explicit FixedPoint(internal_type raw_val) : raw_value(raw_val) {}

    static FixedPoint from_float(float f_val) {
        float scaled_val = f_val * s_scale;
        // 这里的饱和逻辑需要根据 IS_SIGNED 和 Saturation Policy 调整
        // 简化起见,这里直接调用默认的 SaturationPolicy::saturate
        return FixedPoint(Saturation<internal_type>::saturate(static_cast<int64_t>(std::round(scaled_val))));
    }

    float to_float() const {
        return static_cast<float>(raw_value) / s_scale;
    }

    // 重载运算符,现在调用策略函数
    FixedPoint operator+(const FixedPoint& other) const {
        return FixedPoint(apply_saturation(static_cast<int64_t>(raw_value) + other.raw_value));
    }
    FixedPoint operator-(const FixedPoint& other) const {
        return FixedPoint(apply_saturation(static_cast<int64_t>(raw_value) - other.raw_value));
    }
    FixedPoint operator*(const FixedPoint& other) const {
        int64_t prod = static_cast<int64_t>(raw_value) * other.raw_value;
        return FixedPoint(apply_saturation(apply_rounding_and_shift(prod, s_fractional_bits)));
    }
    FixedPoint operator/(const FixedPoint& other) const {
        if (other.raw_value == 0) { /* handle division by zero */ return FixedPoint(apply_saturation(std::numeric_limits<int64_t>::max())); }
        int64_t numerator = static_cast<int64_t>(raw_value) << s_fractional_bits;
        int64_t result_raw = numerator / other.raw_value;
        return FixedPoint(apply_saturation(result_raw));
    }

    // 类型转换操作符也使用策略
    template <int OTHER_TOTAL_BITS, int OTHER_FRACTIONAL_BITS, bool OTHER_IS_SIGNED,
              template <typename> class OTHER_SAT, template <typename> class OTHER_ROUND>
    explicit operator FixedPoint<OTHER_TOTAL_BITS, OTHER_FRACTIONAL_BITS, OTHER_IS_SIGNED, OTHER_SAT, OTHER_ROUND>() const {
        using TargetFixedPoint = FixedPoint<OTHER_TOTAL_BITS, OTHER_FRACTIONAL_BITS, OTHER_IS_SIGNED, OTHER_SAT, OTHER_ROUND>;

        int64_t temp_raw = raw_value;
        int shift_diff = TargetFixedPoint::s_fractional_bits - s_fractional_bits;

        if (shift_diff > 0) { // Target has more fractional bits, left shift
            temp_raw <<= shift_diff;
        } else if (shift_diff < 0) { // Target has fewer fractional bits, right shift (round)
            temp_raw = apply_rounding_and_shift(temp_raw, -shift_diff); // 使用当前对象的舍入策略
        }

        return TargetFixedPoint(TargetFixedPoint().apply_saturation(temp_raw)); // 使用目标类型的饱和策略
    }

    // 打印等其他成员保持不变
    friend std::ostream& operator<<(std::ostream& os, const FixedPoint& fp) {
        os << fp.to_float() << " (raw: " << fp.raw_value << ")";
        return os;
    }
};

/*
int main() {
    // 使用默认策略的Q1.15 (有符号,饱和,Round Half Up)
    using Q1_15_Default = FixedPoint<16, 15, true>;
    Q1_15_Default a_def = Q1_15_Default::from_float(0.75f);
    Q1_15_Default b_def = Q1_15_Default::from_float(0.5f);
    std::cout << "Default Q1.15 a*b: " << (a_def * b_def) << std::endl;

    // 使用截断舍入策略的Q1.15
    using Q1_15_Trunc = FixedPoint<16, 15, true, SaturationPolicy, TruncationPolicy>;
    Q1_15_Trunc a_trunc = Q1_15_Trunc::from_float(0.75f);
    Q1_15_Trunc b_trunc = Q1_15_Trunc::from_float(0.5f);
    std::cout << "Truncation Q1.15 a*b: " << (a_trunc * b_trunc) << std::endl;

    // 无符号Q8.8,使用无符号饱和策略
    using UQ8_8_Default = FixedPoint<16, 8, false>;
    UQ8_8_Default ua = UQ8_8_Default::from_float(100.0f);
    UQ8_8_Default ub = UQ8_8_Default::from_float(-10.0f); // 会被饱和到0
    std::cout << "Unsigned UQ8_8 a: " << ua << ", b: " << ub << std::endl;
    std::cout << "Unsigned UQ8_8 a+b: " << (ua + ub) << std::endl;

    return 0;
}
*/

性能优化考量

在嵌入式AI芯片上,性能是第一位的。我们的库设计需要充分考虑以下优化:

  • 避免动态内存分配: FixedPoint对象应完全存储在栈上或静态存储区,避免堆内存分配,因为堆操作开销大且可能导致内存碎片。
  • 内联函数: 关键的运算符和辅助函数应声明为inline,鼓励编译器将其展开,消除函数调用开销。
  • 位运算: 充分利用位移(<<, >>)进行定标,比浮点乘除法快得多。
  • constexpr 尽可能将可在编译期确定的值(如s_scale)声明为constexpr,进一步优化。
  • 避免不必要的类型转换: 尽量在同一种FixedPoint类型之间进行运算,减少定标和饱和的开销。
  • 编译器优化标志: 配合O2/O3等编译器优化选项。

V. 矩阵乘法在定点数域的优化实现

现在,我们将定点数库应用于核心任务:矩阵乘法。

矩阵乘法基础:$C = A times B$

给定一个 $M times K$ 矩阵 $A$ 和一个 $K times N$ 矩阵 $B$,它们的乘积 $C$ 是一个 $M times N$ 矩阵,其中 $C{ij} = sum{p=0}^{K-1} A{ip} times B{pj}$。

定点数矩阵乘法的挑战

  1. 中间结果的精度累积 (乘积和): 矩阵乘法涉及大量的乘加操作。每个乘法结果都需要定标,然后累加。如果累加器位宽不足,连续的加法很容易导致溢出。
  2. 溢出风险管理: 如何选择合适的累加器位宽,以及何时进行饱和处理,是关键。
  3. 效率: 即使是定点数,大量的内存访问和计算仍然需要高效的实现。

累加器位宽选择:防止中间结果溢出

这是定点数矩阵乘法中最关键的设计点。
假设 $A$ 的元素是 Q$m_A.n_A$, $B$ 的元素是 Q$m_B.nB$。
一个乘积 $A
{ip} times B_{pj}$ 的原始小数位是 $n_A + n_B$。
累加 $K$ 个这样的乘积,最大值可能达到 $K times (text{MaxA} times text{MaxB})$。
为了避免溢出,累加器的位宽必须足够大。

通常,累加器的位宽至少要能容纳 Max(A) * Max(B) * K 的值。
如果 $A, B$ 是16位定点数,它们的乘积是32位。如果 $K=256$,那么累加器至少需要 32 + log2(256) = 32 + 8 = 40 位。因此,使用 int64_t 作为累加器是常见的做法。

分块乘法 (Tiling):缓存优化,减少内存访问

在现代处理器中,内存访问速度远低于CPU计算速度。分块乘法是一种有效的缓存优化技术。它将大矩阵分成小块,每次只处理一小块数据,确保这些数据能够完全放入高速缓存(L1/L2 Cache),从而减少对主内存的访问,提高数据局部性。

例如,对于 $C{ij} = sum{p=0}^{K-1} A{ip} times B{pj}$,我们可以将 $A, B, C$ 矩阵都分割成若干个小块。在一个块内完成乘加操作,再处理下一个块。

SIMD指令集利用 (SSE/NEON):针对嵌入式AI芯片的硬件加速

大多数低功耗嵌入式AI芯片都集成了SIMD(Single Instruction, Multiple Data)指令集,如ARM的NEON或x86的SSE/AVX。SIMD指令允许处理器同时对多个数据元素执行相同的操作,极大提升并行计算能力。

将定点数运算映射到SIMD指令是实现高效矩阵乘法的关键。

  • ARM NEON:

    • vmla_s16:16位有符号整数乘积累加。
    • vmull_s16:16位有符号整数乘法,结果扩展到32位。
    • vadd_s32:32位有符号整数加法。
    • vqrdmulh_s16:16位有符号整数乘法,结果取高16位并舍入。
    • vqadd_s16:16位有符号整数饱和加法。
    • vqmovn_s32:32位饱和截断到16位。
  • x86 SSE/AVX:

    • _mm_madd_epi16:对16位有符号整数进行成对乘法并累加到32位。
    • _mm_mullo_epi16:16位有符号整数乘法,结果截断到16位。
    • _mm_adds_epi16:16位有符号整数饱和加法。

如何将定点数操作映射到SIMD指令?

  1. 数据打包: 将多个FixedPoint<16, N>(或FixedPoint<8, N>)的raw_value打包到SIMD寄存器中(例如,NEON的int16x8_tint8x16_t)。
  2. 并行乘法: 使用SIMD乘法指令对打包的数据进行并行乘法。这通常会产生位宽更大的中间结果(例如,两个16位乘法得到32位结果)。
  3. 并行累加: 使用SIMD累加指令将这些中间结果累加到位宽更大的累加器SIMD寄存器中。
  4. 定标与饱和: 在累加完成后,对累加器结果进行并行右移(定标)和饱和操作,将其转换回目标定点数格式。

循环展开 (Loop Unrolling):减少循环开销

手动或让编译器进行循环展开,可以减少循环控制指令的开销,并为编译器提供更多的指令调度机会,进一步提升性能。

内存布局优化:行主序/列主序,对齐

  • 行主序(Row-major)与列主序(Column-major): C++默认是行主序存储。确保矩阵访问模式与内存布局一致,可以提高缓存命中率。
  • 内存对齐: SIMD指令通常要求数据按特定边界(如16字节、32字节)对齐。对齐内存可以避免性能下降,甚至避免某些SIMD指令无法使用。可以使用alignas关键字或自定义内存分配器。

— Code Example 6: Basic Fixed-Point Matrix Multiplication (naive) —

首先,我们来看一个不带任何优化的定点数矩阵乘法基础实现。

#include <vector>
#include <chrono> // For performance measurement

// 假设 FixedPoint 类以及其运算符已经定义好
// using Q_TYPE = FixedPoint<16, 8, true>; // 示例定点数类型

template <typename FP_TYPE>
std::vector<FP_TYPE> multiply_matrices_naive(
    const std::vector<FP_TYPE>& A, int M, int K,
    const std::vector<FP_TYPE>& B, int N)
{
    // A: M x K, B: K x N, C: M x N
    if (A.size() != M * K || B.size() != K * N) {
        throw std::runtime_error("Matrix dimensions mismatch or invalid sizes.");
    }

    std::vector<FP_TYPE> C(M * N);

    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            // 使用 Q_ACCUMULATOR_TYPE 作为累加器,防止中间结果溢出
            // 假设乘法结果是 2 * FP_TYPE::s_total_bits 位,且小数位是 2 * FP_TYPE::s_fractional_bits
            // 累加器需要能容纳 K 个这样的结果,并最终转换回 FP_TYPE

            // 这里我们创建一个临时的高精度累加器,确保不溢出
            // 假设 FP_TYPE 是 FixedPoint<16, 8, true>,那么其 raw_value 是 int16_t
            // 乘法结果的 raw_value 是 int32_t,小数位是 16
            // 累加 K 次 int32_t,如果 K 很大,需要 int64_t
            // 最终的累加器需要在内部处理好精度和饱和

            // 为了简化,我们暂时假设 FP_TYPE 内部的乘法运算符已经处理了精度和饱和
            // 但更严谨的做法是,这里使用一个更高精度的定点数类型作为累加器

            // Option 1: 使用更高精度的定点数作为累加器
            // 假设 FP_TYPE 是 FixedPoint<16, 8, true>
            // 乘法结果 Q(15.16)
            // 累加器需要 Q(15 + log2(K) . 16)
            // 我们可以定义一个临时的累加器类型,例如 FixedPoint<32, 16, true>
            // 但为了简化,这里暂时直接使用 FP_TYPE::internal_type 作为原始累加器
            // 并手动管理其定点数含义

            // 更优的累加器类型选择,例如:
            // A是 Qm.n, B是 Qm.n, 乘积是 Q(2m+1).2n
            // 累加 K 次,需要 (2m+1) + log2(K) 位整数部分
            // 小数部分保持 2n
            // 假设我们的 FP_TYPE 是 FixedPoint<16, 8, true> (Q7.8)
            // 乘积是 Q(15).16
            // 如果 K=64 (log2(K)=6),整数位需要 15+6=21位
            // 总位宽 21 + 16 + 1 (符号位) = 38位
            // 所以累加器需要至少 FixedPoint<40, 16, true> (使用 int64_t 内部存储)

            // 为了演示,这里直接使用 int64_t 作为裸整数累加器,并手动管理其定点数含义
            // 内部存储的原始值为 int64_t
            int64_t acc_raw = 0;

            // 每个乘法的结果,其小数位数是 FP_TYPE::s_fractional_bits * 2
            // 假设 FP_TYPE 是 Qm.n
            // A_raw * B_raw 结果的小数位是 2n
            // 如果我们希望累加器的小数位也是 2n,那么直接累加即可
            // 最终将累加器结果右移 n 位,转换回 Qm.n

            // 这里我们假设 FP_TYPE 的乘法运算符返回的结果已经定标到 FP_TYPE::s_fractional_bits
            // 那么累加器的小数位数就和 FP_TYPE 相同

            // 这是一个简化的累加器,假设 FP_TYPE 的乘法结果已经定标到 FP_TYPE::s_fractional_bits
            // 并且我们希望累加器的小数位数也保持 FP_TYPE::s_fractional_bits
            // 这种方式会导致精度损失,但简化了累加器类型选择
            // 更优的方法是使用一个具有更长小数位的中间累加器类型

            // 重新思考累加器:
            // C_ij = SUM( A_ip * B_pj )
            // A_ip 的 raw_value 是 val_A * 2^N_f
            // B_pj 的 raw_value 是 val_B * 2^N_f
            // A_ip * B_pj 的 raw_value 是 (val_A * val_B) * 2^(2*N_f)
            // 所以每个乘积的原始整数值是 `A[A_idx].raw_value * B[B_idx].raw_value`
            // 其隐含的小数位是 `2 * FP_TYPE::s_fractional_bits`
            // 我们用 `int64_t` 来存储这个原始乘积,然后累加

            for (int p = 0; p < K; ++p) {
                int A_idx = i * K + p;
                int B_idx = p * N + j;

                // 执行原始整数乘法,结果是 2 * TOTAL_BITS 位,小数位是 2 * FRACTIONAL_BITS
                acc_raw += static_cast<int64_t>(A[A_idx].raw_value) * B[B_idx].raw_value;
            }

            // 累加完成后,将累加器的原始值定标回 FP_TYPE 格式
            // 原始小数位是 2 * FP_TYPE::s_fractional_bits
            // 目标小数位是 FP_TYPE::s_fractional_bits
            // 需要右移 FP_TYPE::s_fractional_bits 位
            C[i * N + j] = FP_TYPE(FP_TYPE().apply_saturation(
                                       FP_TYPE().apply_rounding_and_shift(acc_raw, FP_TYPE::s_fractional_bits)
                                   ));
        }
    }
    return C;
}

/*
// 示例 main 函数
int main() {
    using Q7_8 = FixedPoint<16, 8, true>; // Q7.8 格式
    int M = 2, K = 3, N = 2;

    std::vector<Q7_8> A_fp = {
        Q7_8::from_float(1.0f), Q7_8::from_float(2.0f), Q7_8::from_float(3.0f),
        Q7_8::from_float(4.0f), Q7_8::from_float(5.0f), Q7_8::from_float(6.0f)
    }; // 2x3 matrix

    std::vector<Q7_8> B_fp = {
        Q7_8::from_float(7.0f), Q7_8::from_float(8.0f),
        Q7_8::from_float(9.0f), Q7_8::from_float(10.0f),
        Q7_8::from_float(11.0f), Q7_8::from_float(12.0f)
    }; // 3x2 matrix

    std::cout << "Matrix A (Q7.8):" << std::endl;
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < K; ++j) {
            std::cout << A_fp[i * K + j].to_float() << "t";
        }
        std::cout << std::endl;
    }

    std::cout << "nMatrix B (Q7.8):" << std::endl;
    for (int i = 0; i < K; ++i) {
        for (int j = 0; j < N; ++j) {
            std::cout << B_fp[i * N + j].to_float() << "t";
        }
        std::cout << std::endl;
    }

    auto start = std::chrono::high_resolution_clock::now();
    std::vector<Q7_8> C_fp = multiply_matrices_naive(A_fp, M, K, B_fp, N);
    auto end = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> diff = end - start;

    std::cout << "nMatrix C (Q7.8) - Naive (Time: " << diff.count() << "s):" << std::endl;
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            std::cout << C_fp[i * N + j].to_float() << "t";
        }
        std::cout << std::endl;
    }

    // Expected result (float for comparison):
    // C[0][0] = 1*7 + 2*9 + 3*11 = 7 + 18 + 33 = 58
    // C[0][1] = 1*8 + 2*10 + 3*12 = 8 + 20 + 36 = 64
    // C[1][0] = 4*7 + 5*9 + 6*11 = 28 + 45 + 66 = 139
    // C[1][1] = 4*8 + 5*10 + 6*12 = 32 + 50 + 72 = 154
    // Output should be close to:
    // 58.0    64.0
    // 139.0   154.0

    return 0;
}
*/

上述multiply_matrices_naive函数展示了定点数矩阵乘法的基本逻辑,其中累加器acc_raw被声明为int64_t,以确保在累加过程中不会发生溢出。在所有乘加操作完成后,再将最终的累加结果进行一次定标和饱和,转换回目标定点数格式。

— Code Example 7: Optimized Fixed-Point Matrix Multiplication (Accumulator management, basic tiling idea) —

针对嵌入式AI芯片的优化,我们可以在上述基础上引入分块和更精细的累加器管理,并为SIMD集成预留接口。

#include <vector>
#include <chrono>
#include <stdexcept>
// #include <arm_neon.h> // For ARM NEON intrinsics, uncomment for ARM targets
// #include <immintrin.h> // For x86 SSE/AVX intrinsics, uncomment for x86 targets

// 定义一个专门用于矩阵乘法累加的FixedPoint类型
// 假设输入是 FixedPoint<16, 8, true> (Q7.8)
// 乘积的原始小数位是 8+8=16位。
// 如果 K 最大为 256,则整数位需要 7+7+1(符号位) + log2(256) = 15 + 8 = 23位
// 所以累加器总位宽至少 23 + 16 (小数位) + 1 (符号位) = 40位
// 我们可以使用 FixedPoint<40, 16, true>,但为了方便,直接用 FixedPoint<64, 16, true>
// 这样内部存储就是 int64_t,且小数位为 16
template <typename FP_TYPE>
using AccumulatorFP = FixedPoint<64, 2 * FP_TYPE::s_fractional_bits, true>;

template <typename FP_TYPE>
std::vector<FP_TYPE> multiply_matrices_optimized(
    const std::vector<FP_TYPE>& A, int M, int K,
    const std::vector<FP_TYPE>& B, int N,
    int BLOCK_SIZE_M, int BLOCK_SIZE_K, int BLOCK_SIZE_N)
{
    if (A.size() != M * K || B.size() != K * N) {
        throw std::runtime_error("Matrix dimensions mismatch or invalid sizes.");
    }
    std::vector<FP_TYPE> C(M * N);

    // 假设矩阵是行主序存储

    // 循环分块
    for (int i_block = 0; i_block < M; i_block += BLOCK_SIZE_M) {
        for (int j_block = 0; j_block < N; j_block += BLOCK_SIZE_N) {
            for (int p_block = 0; p_block < K; p_block += BLOCK_SIZE_K) {

                // 处理当前块内的元素
                for (int i = i_block; i < std::min(i_block + BLOCK_SIZE_M, M); ++i) {
                    for (int j = j_block; j < std::min(j_block + BLOCK_SIZE_N, N); ++j) {

                        // 初始化高精度累加器
                        AccumulatorFP<FP_TYPE> acc(0); // 内部 raw_value = 0

                        for (int p = p_block; p < std::min(p_block + BLOCK_SIZE_K, K); ++p) {
                            int A_idx = i * K + p;
                            int B_idx = p * N + j;

                            // 这里直接使用 FP_TYPE 之间的乘法运算符
                            // 它的结果已经是定标到 FP_TYPE::s_fractional_bits
                            // 再次强调,如果 FP_TYPE 的乘法结果不直接定标,这里需要手动管理
                            // 为了简化,我们假设 FP_TYPE 的乘法结果内部是 FP_TYPE::raw_value
                            // 但我们希望累加器的内部 raw_value 是 `original_product_raw_value`
                            // 实际的优化应该是:
                            // 1. 获取 A 和 B 的原始整数值
                            // 2. 将它们相乘 (结果是 int64_t, 小数位是 2 * FP_TYPE::s_fractional_bits)
                            // 3. 将这个结果累加到 AccumulatorFP 的 raw_value 中
                            // 4. 最终将 AccumulatorFP 的 raw_value 定标回 FP_TYPE

                            // 改进的乘加逻辑:直接操作原始整数值,并使用 AccumulatorFP 累加
                            int64_t prod_raw = static_cast<int64_t>(A[A_idx].raw_value) * B[B_idx].raw_value;
                            acc.raw_value += prod_raw;
                        }

                        // 累加完成后,将 AccumulatorFP 的结果转换回 FP_TYPE
                        // AccumulatorFP 的小数位是 2 * FP_TYPE::s_fractional_bits
                        // FP_TYPE 的小数位是 FP_TYPE::s_fractional_bits
                        // 因此需要右移 FP_TYPE::s_fractional_bits 位
                        int shift_amount = FP_TYPE::s_fractional_bits;
                        C[i * N + j] = FP_TYPE(FP_TYPE().apply_saturation(
                                                   FP_TYPE().apply_rounding_and_shift(acc.raw_value, shift_amount)
                                               ));
                        // 注意:这里调用 FP_TYPE 的 apply_saturation 和 apply_rounding_and_shift
                        // 因为最终结果要存储为 FP_TYPE
                    }
                }
            }
        }
    }
    return C;
}

/*
// 示例 main 函数 (与 naive 类似,只是调用不同的乘法函数)
int main() {
    using Q7_8 = FixedPoint<16, 8, true>; // Q7.8 格式
    int M = 64, K = 64, N = 64; // 更大的矩阵,以便观察性能差异

    // 填充随机数据
    std::vector<Q7_8> A_fp(M * K);
    std::vector<Q7_8> B_fp(K * N);
    for (size_t i = 0; i < A_fp.size(); ++i) A_fp[i] = Q7_8::from_float(static_cast<float>(rand() % 256 - 128) / 10.0f);
    for (size_t i = 0; i < B_fp.size(); ++i) B_fp[i] = Q7_8::from_float(static_cast<float>(rand() % 256 - 128) / 10.0f);

    std::cout << "Running naive matrix multiplication for " << M << "x" << K << " * " << K << "x" << N << std::endl;
    auto start_naive = std::chrono::high_resolution_clock::now();
    std::vector<Q7_8> C_naive = multiply_matrices_naive(A_fp, M, K, B_fp, N);
    auto end_naive = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> diff_naive = end_naive - start_naive;
    std::cout << "Naive time: " << diff_naive.count() << "s" << std::endl;

    std::cout << "Running optimized matrix multiplication (blocking) for " << M << "x" << K << " * " << K << "x" << N << std::endl;
    int BLOCK_M = 16, BLOCK_K = 16, BLOCK_N = 16; // 示例块大小
    auto start_opt = std::chrono::high_resolution_clock::now();
    std::vector<Q7_8> C_opt = multiply_matrices_optimized(A_fp, M, K, B_fp, N, BLOCK_M, BLOCK_K, BLOCK_N);
    auto end_opt = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> diff_opt = end_opt - start_opt;
    std::cout << "Optimized (blocking) time: " << diff_opt.count() << "s" << std::endl;

    // 简单验证结果是否一致
    bool results_match = true;
    for(size_t i=0; i < C_naive.size(); ++i) {
        if (C_naive[i].raw_value != C_opt[i].raw_value) {
            results_match = false;
            // std::cout << "Mismatch at " << i << ": Naive=" << C_naive[i] << ", Optimized=" << C_opt[i] << std::endl;
            break;
        }
    }
    std::cout << "Results match: " << (results_match ? "Yes" : "No") << std::endl;

    return 0;
}
*/

— Code Example 8: SIMD Intrinsics (conceptual sketch for fixed-point dot product) —

SIMD实现将是性能提升的关键。由于SIMD指令集是平台相关的,这里提供一个概念性的骨架,展示如何集成NEON指令。实际应用中需要根据具体芯片架构和编译器进行调整。

// 假设这是 NEON 版本的 FixedPoint 矩阵乘法的一个小片段
// 仅为示意,完整的 SIMD 优化需要考虑数据加载、存储、循环展开、对齐等
// 这是一个针对 16位定点数 (Q7.8) 的点积函数示例

#ifdef __ARM_NEON
#include <arm_neon.h>

// 假设 FP_TYPE 是 FixedPoint<16, 8, true>
// 并且矩阵数据已经对齐
void fixed_point_dot_product_neon_q7_8(
    const int16_t* A_row, // 指向 A 矩阵一行的原始数据
    const int16_t* B_col, // 指向 B 矩阵一列的原始数据
    int K,                // 点积长度
    int16_t* result_ptr)  // 存储最终 Q7.8 结果的指针
{
    // NEON 累加器,存储 4个 32位整数,对应 4个 Q15.16 乘积的累加
    int32x4_t sum_vec = vdupq_n_s32(0);

    // 假设 K 是 4 的倍数,以便一次处理 4 个元素
    for (int p = 0; p < K; p += 4) {
        // 加载 A 的 4 个 16 位元素
        int16x4_t a_vec = vld1_s16(A_row + p);
        // 加载 B 的 4 个 16 位元素
        int16x4_t b_vec = vld1_s16(B_col + p);

        // 执行 16位 x 16位 -> 32位 乘法,并累加到 32位累加器
        // vmull_s16 得到 32位乘积,vmlaq_s32 累加
        // 注意:vmlaq_s32 是 32位乘积累加,这里我们需要 16x16=32位的乘积
        // NEON通常提供 vmull_s16 (long multiply) 将 16位结果扩展为 32位
        // 然后再用 vaddq_s32 累加。或者直接用 vmlal_s16 (multiply and accumulate long)

        // 假设 A_row 和 B_col 是 Q7.8 (16位),乘积是 Q15.16 (32位)
        int32x4_t prod_low = vmull_s16(vget_low_s16(a_vec), vget_low_s16(b_vec)); // 乘法,结果是 32位
        int32x4_t prod_high = vmull_s16(vget_high_s16(a_vec), vget_high_s16(b_vec));

        sum_vec = vaddq_s32(sum_vec, prod_low);
        sum_vec = vaddq_s32(sum_vec, prod_high);
    }

    // 将 4个 32位累加结果水平求和,得到一个 32位总和
    int32_t final_sum_raw = vgetq_lane_s32(sum_vec, 0) +
                            vgetq_lane_s32(sum_vec, 1) +
                            vgetq_lane_s32(sum_vec, 2) +
                            vgetq_lane_s32(sum_vec, 3);

    // 定标与饱和:将 Q15.16 的结果转换为 Q7.8
    // 需要右移 8 位 (16 - 8)
    // vqrshrn_n_s32 是 NEON 的饱和舍入右移指令,将 32位结果右移并饱和到 16位
    // 这里的 final_sum_raw 是一个标量,需要手动处理
    int shift_amount = FixedPoint<16,8,true>::s_fractional_bits; // 8

    // 舍入 (round to nearest, half up)
    int64_t rounded_shifted_val = (static_cast<int64_t>(final_sum_raw) + (1LL << (shift_amount - 1))) >> shift_amount;

    // 饱和到 16位
    if (rounded_shifted_val > std::numeric_limits<int16_t>::max()) {
        *result_ptr = std::numeric_limits<int16_t>::max();
    } else if (rounded_shifted_val < std::numeric_limits<int16_t>::min()) {
        *result_ptr = std::numeric_limits<int16_t>::min();
    } else {
        *result_ptr = static_cast<int16_t>(rounded_shifted_val);
    }
}
#endif // __ARM_NEON

这样的SIMD优化能够显著提升矩阵乘法的性能,特别是在处理较大矩阵时。它将原本顺序执行的指令,并行化到多个数据通道上,充分利用了硬件资源。

VI. 精度与性能的平衡与评估

在定点化过程中,精度和性能往往是一对矛盾。我们需要仔细评估并找到最佳平衡点。

精度分析:量化误差来源,SNR (信噪比)

定点化引入的误差主要来源于:

  1. 量化误差: 浮点数转换为定点数时,由于精度有限而产生的误差。
  2. **截

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注