C++ 定点数运算库:在低功耗嵌入式 AI 芯片上的高效矩阵乘法实现
尊敬的各位同仁,女士们、先生们,大家好!
今天,我们齐聚一堂,共同探讨一个在当前人工智能浪潮中至关重要的话题:如何在低功耗嵌入式AI芯片上,利用C++定点数运算库,实现高效的矩阵乘法。随着AI技术从云端走向边缘,我们面临着前所未有的机遇与挑战。在资源受限的环境中,如何在保证模型性能的同时,最大限度地提升计算效率、降低功耗,是每一位工程师必须深思熟虑的问题。定点数运算,正是解决这一难题的关键利器。
引言:嵌入式AI与定点数运算的时代机遇与挑战
近年来,人工智能,特别是深度学习,取得了突破性进展,深刻改变了我们的生活。从智能语音助手到自动驾驶,从图像识别到自然语言处理,AI的应用场景无处不在。然而,随着模型规模的不断扩大和计算复杂度的急剧提升,将这些强大的AI能力部署到终端设备,如智能手机、物联网设备、可穿戴设备乃至微型传感器上,面临着严峻的挑战。这些“边缘侧”设备通常受限于:
- 低功耗要求: 电池供电或有限的电源供应,要求芯片在极低的功耗下运行。
- 实时性要求: 许多应用需要即时响应,如自动驾驶的决策,人脸识别的验证。
- 小尺寸与低成本: 嵌入式设备通常对物理尺寸和制造成本有严格限制。
- 有限的内存与带宽: 无法承载大型模型和高带宽的数据传输。
在深度学习模型中,矩阵乘法(Matrix Multiplication)是核心且计算量最大的操作之一,占据了高达80%甚至更多的计算资源。传统的AI模型训练和推理大多采用浮点数(如FP32、FP16)进行运算。浮点数提供了宽广的数值范围和较高的精度,但其代价是:
- 硬件复杂度高: 浮点运算单元(FPU)比整数运算单元(ALU)面积更大,功耗更高。
- 计算效率低: 浮点运算需要处理指数和尾数,指令周期更长。
- 内存带宽占用大: 存储和传输浮点数需要更多的位,增加了内存访问的开销。
为了克服这些局限,定点数运算应运而生。定点数运算的核心思想是用整数来近似表示实数,通过预先确定小数点的位置来管理精度和范围。这使得定点数运算能够:
- 降低硬件成本与功耗: 利用简单的整数ALU即可实现,无需复杂FPU。
- 提升计算效率: 整数运算速度快,指令周期短。
- 减少内存占用与带宽需求: 通常使用16位、8位甚至更低的位宽,极大压缩数据量。
因此,构建一个高效、灵活且易于使用的C++定点数运算库,并将其应用于矩阵乘法优化,对于推动嵌入式AI芯片的发展,具有不可估量的价值。本文将深入探讨定点数的基础理论、库设计原则、矩阵乘法的优化策略,以及精度与性能的平衡之道。
II. 定点数基础理论与表示
要设计一个定点数库,我们首先需要理解定点数的基本概念和表示方法。
什么是定点数?
定点数,顾名思义,是小数点位置固定的数。与浮点数通过指数部分动态调整小数点位置不同,定点数的小数点位置在设计时就已确定。它由一个整数部分和一个小数部分组成。例如,一个数123.456,如果我们规定小数点后有三位,那么它就可以被视为123456这个整数,只是我们知道它的实际值需要除以1000。
表示方法:Q格式与S.M.F格式
在数字信号处理和嵌入式领域,常用的定点数表示方法有两种:
-
Q格式 (Qm.n):
m表示整数部分的位数(不包括符号位)。n表示小数部分的位数。- 总位宽通常为
m + n + 1(如果是有符号数,加1是符号位)。 - 例如,Q1.15表示一个16位有符号定点数,其中1位是符号位,1位是整数位,15位是小数位。它的取值范围为
[-2^1, 2^1 - 2^-15],最小精度为2^-15。 - 实际存储的是
Value * 2^n的整数形式。
-
S.M.F格式:
S表示符号位 (1位)。M表示整数部分的位数。F表示小数部分的位数。- 总位宽为
S + M + F。 - 例如,S1.15表示一个16位有符号定点数,1位符号位,1位整数位,15位小数位,与Q1.15含义相同。
在本文中,我们将主要采用类似Q格式的思路,通过模板参数来定义总位宽和小数位宽。
符号位:有符号/无符号
定点数可以是无符号的(只表示非负数),也可以是有符号的(表示正负数)。有符号数通常采用二进制补码(Two’s Complement)表示,这使得加减法运算与无符号数保持一致,简化了硬件实现。
定点数选择的考量:精度与范围的权衡
选择合适的定点数格式是定点量化的核心挑战。我们需要在精度(由小数位数决定)和数值范围(由整数位数决定)之间进行权衡:
- 增加小数位数 (
n): 提升精度,但会缩小数值范围(在总位宽固定的情况下)。 - 增加整数位数 (
m): 扩大数值范围,但会降低精度。
在AI模型中,激活值(如ReLU输出)通常是非负的,可以使用无符号定点数以获取更大的正数范围。而权重和某些中间计算结果可能为负值,则必须使用有符号定点数。
溢出与饱和处理
定点数在运算过程中极易发生溢出。例如,两个Q1.15的数相乘,结果可能需要Q2.30来表示。如果直接存储回Q1.15,就会丢失信息。
- 溢出 (Overflow): 运算结果超出了当前定点数格式能表示的最大或最小范围。
- 饱和 (Saturation): 当发生溢出时,将结果限制在当前定点数格式所能表示的最大值或最小值。这是嵌入式系统中常用的处理方式,因为它能避免结果突变,保持一定的信号特性。
- 环绕 (Wrap-around): 溢出时结果“回卷”,例如
MAX_INT + 1 = MIN_INT。这通常会导致结果的显著失真,不适用于大多数AI场景。
C++中的定点数模拟:基于整数类型
C++标准库并没有内置的定点数类型。我们通常通过封装C++的整数类型(int8_t, int16_t, int32_t, int64_t等)来模拟定点数。其核心思想是,将一个实数 X 乘以一个比例因子 2^n(n 为小数位数),然后存储这个整数结果。
例如,如果我们的定点数格式是Q1.15,总共16位,其中1位符号位,1位整数位,15位小数位。那么一个浮点数f_val在转换为定点数时,其内部存储的整数值 raw_val 就是 round(f_val * 2^15)。
#include <cstdint>
#include <cmath> // For round, pow
#include <limits> // For numeric_limits
#include <iostream>
#include <type_traits> // For std::is_signed
// --- Code Example 1: Basic Fixed-Point Representation (struct/class design) ---
template <int TOTAL_BITS, int FRACTIONAL_BITS, bool IS_SIGNED = true>
class FixedPoint {
public:
// 内部存储类型,根据总位宽选择合适的整数类型
// 使用 std::conditional 来选择 int16_t, int32_t, int64_t
using internal_type = typename std::conditional<
TOTAL_BITS <= 16, int16_t,
typename std::conditional<
TOTAL_BITS <= 32, int32_t,
int64_t
>::type
>::type;
internal_type raw_value;
static constexpr int s_total_bits = TOTAL_BITS;
static constexpr int s_fractional_bits = FRACTIONAL_BITS;
static constexpr int s_integer_bits = TOTAL_BITS - FRACTIONAL_BITS - (IS_SIGNED ? 1 : 0);
static constexpr bool s_is_signed = IS_SIGNED;
static constexpr internal_type s_scale = internal_type(1) << FRACTIONAL_BITS;
// 构造函数
FixedPoint() : raw_value(0) {}
explicit FixedPoint(internal_type raw_val) : raw_value(raw_val) {}
// 从浮点数转换
static FixedPoint from_float(float f_val) {
// 防止溢出,这里只是简单实现,后续会加入饱和处理
float scaled_val = f_val * s_scale;
if (s_is_signed) {
if (scaled_val > std::numeric_limits<internal_type>::max()) {
scaled_val = std::numeric_limits<internal_type>::max();
} else if (scaled_val < std::numeric_limits<internal_type>::min()) {
scaled_val = std::numeric_limits<internal_type>::min();
}
} else { // Unsigned
if (scaled_val > std::numeric_limits<internal_type>::max()) {
scaled_val = std::numeric_limits<internal_type>::max();
} else if (scaled_val < 0) { // Unsigned can't be negative
scaled_val = 0;
}
}
return FixedPoint(static_cast<internal_type>(std::round(scaled_val)));
}
// 转换为浮点数
float to_float() const {
return static_cast<float>(raw_value) / s_scale;
}
// 获取最大值和最小值 (作为浮点数)
static float max_float() {
if (s_is_signed) {
return static_cast<float>(std::numeric_limits<internal_type>::max()) / s_scale;
} else {
return static_cast<float>(std::numeric_limits<internal_type>::max()) / s_scale;
}
}
static float min_float() {
if (s_is_signed) {
return static_cast<float>(std::numeric_limits<internal_type>::min()) / s_scale;
} else { // Unsigned min is always 0
return 0.0f;
}
}
// 打印
friend std::ostream& operator<<(std::ostream& os, const FixedPoint& fp) {
os << fp.to_float() << " (raw: " << fp.raw_value << ")";
return os;
}
// 为了示例,定义一个简单的相等运算符
bool operator==(const FixedPoint& other) const {
return raw_value == other.raw_value;
}
// 更多运算符将在下一节中实现
};
// 示例用法
/*
int main() {
// 定义一个Q1.15格式的定点数 (16位总位宽, 15位小数位, 有符号)
using Q1_15 = FixedPoint<16, 15, true>;
float f1 = 0.5f;
float f2 = -0.125f;
float f3 = 1.99f;
float f4 = -1.99f; // 会饱和到最小值
Q1_15 q1 = Q1_15::from_float(f1);
Q1_15 q2 = Q1_15::from_float(f2);
Q1_15 q3 = Q1_15::from_float(f3);
Q1_15 q4 = Q1_15::from_float(f4);
std::cout << "Q1.15 Max Float: " << Q1_15::max_float() << std::endl;
std::cout << "Q1.15 Min Float: " << Q1_15::min_float() << std::endl;
std::cout << "Float " << f1 << " -> Q1.15 " << q1 << std::endl;
std::cout << "Float " << f2 << " -> Q1.15 " << q2 << std::endl;
std::cout << "Float " << f3 << " -> Q1.15 " << q3 << std::endl; // 接近最大值
std::cout << "Float " << f4 << " -> Q1.15 " << q4 << std::endl; // 接近最小值
// 定义一个Q7.8格式的定点数 (16位总位宽, 8位小数位, 有符号)
using Q7_8 = FixedPoint<16, 8, true>;
float f_large = 123.45f;
Q7_8 q_large = Q7_8::from_float(f_large);
std::cout << "Q7.8 Max Float: " << Q7_8::max_float() << std::endl;
std::cout << "Q7.8 Min Float: " << Q7_8::min_float() << std::endl;
std::cout << "Float " << f_large << " -> Q7.8 " << q_large << std::endl;
// 无符号定点数 (例如 Q8.8, 16位总位宽, 8位小数位, 无符号)
using UQ8_8 = FixedPoint<16, 8, false>;
float uf1 = 123.45f;
float uf2 = -5.0f; // 无符号会饱和到0
UQ8_8 uq1 = UQ8_8::from_float(uf1);
UQ8_8 uq2 = UQ8_8::from_float(uf2);
std::cout << "UQ8.8 Max Float: " << UQ8_8::max_float() << std::endl;
std::cout << "UQ8.8 Min Float: " << UQ8_8::min_float() << std::endl;
std::cout << "Float " << uf1 << " -> UQ8.8 " << uq1 << std::endl;
std::cout << "Float " << uf2 << " -> UQ8.8 " << uq2 << std::endl;
return 0;
}
*/
在上述代码中,我们定义了一个FixedPoint模板类,它通过模板参数TOTAL_BITS、FRACTIONAL_BITS和IS_SIGNED来灵活配置定点数的格式。internal_type根据总位宽自动选择合适的C++整数类型,确保了存储效率。s_scale是2的FRACTIONAL_BITS次方,用于浮点数和定点数之间的转换。
III. 定点数运算原理与实现
实现定点数库的核心在于正确地实现其基本算术运算。
加法与减法:对齐小数点
定点数的加减法非常直观,只需像整数一样直接对内部的raw_value进行加减即可,因为它们的小数点位置是隐式对齐的。
Q(A) + Q(B) = (A * 2^n) + (B * 2^n) = (A+B) * 2^n = Q(A+B)
但需要注意的是,加减法可能会导致溢出,特别是当两个大数值相加时。
乘法:精度扩张与截断
定点数乘法是所有运算中最需要细致处理的。两个定点数相乘,其结果的小数位数会增加。
假设我们有两个定点数 $FP_1$ (Q$m_1.n_1$) 和 $FP_2$ (Q$m_2.n_2$)。
它们的内部存储值分别为 $R_1 = FP_1 times 2^{n_1}$ 和 $R_2 = FP_2 times 2^{n2}$。
它们的乘积 $FP{res} = FP_1 times FP2$。
那么结果的内部存储值 $R{res}$ 应该是 $FP{res} times 2^{n{res}}$。
如果直接对内部整数值进行乘法:
$R_1 times R_2 = (FP_1 times 2^{n_1}) times (FP_2 times 2^{n_2}) = (FP_1 times FP_2) times 2^{n_1 + n_2}$
这意味着,直接相乘后的整数结果,其隐含的小数位数是 $n_1 + n2$。如果我们希望将结果存储回一个具有 $n{res}$ 小数位数的定点数格式,就需要进行定标(scaling),即将结果右移 (n_1 + n_2) - n_{res} 位。
通常,为了避免中间结果溢出,乘法操作会先将两个定点数的内部整数值相乘,这会得到一个位宽更长的结果(例如,两个16位整数相乘可能得到32位结果)。然后,再根据目标定点数格式进行右移和舍入。
除法:精度提升与舍入
定点数除法比乘法更复杂一些,因为它可能导致精度损失或需要额外的精度提升。
$FP_1 / FP_2 = (R_1 / 2^{n_1}) / (R_2 / 2^{n_2}) = (R_1 / R_2) times 2^{n_2 – n_1}$
如果直接进行整数除法 $R_1 / R_2$,得到的结果隐含的小数位数是 $n_1 – n_2$。
为了保持精度或将其转换为目标小数位数,我们通常会在进行整数除法之前,先将被除数 $R_1$ 左移一定的位数,以提升中间结果的精度。
例如,如果 (n_1 - n_2) 是负数,意味着结果的小数位数会减少,精度下降。为了得到一个 n_{res} 小数位的结果,我们需要将被除数 $R1$ 左移 `n{res} – (n_1 – n_2)` 位,然后再进行整数除法。
类型转换与定标 (Scaling)
float到FixedPoint:raw_value = round(f_val * 2^n)FixedPoint到float:f_val = raw_value / 2^n- 不同
FixedPoint格式之间转换:
例如,从Q$m_1.n_1$ 转换为 Q$m_2.n_2$。
如果n_1 > n_2,需要右移n_1 - n_2位(舍弃精度)。
如果n_1 < n_2,需要左移n_2 - n_1位(提升精度,但不会增加实际有效位)。
这其中同样需要考虑溢出和饱和。
舍入策略
在定点数运算中,特别是在乘法、除法和类型转换中,当结果需要截断时,如何处理被舍弃的位决定了最终的精度。常见的舍入策略包括:
- 截断 (Truncation / Round Towards Zero): 直接丢弃小数部分,向零方向取整。最简单,但可能引入较大偏差。
- 最近偶数 (Round Half To Even / Banker’s Rounding): 当小数部分恰好为0.5时,向最近的偶数取整。这是IEEE 754浮点标准默认的舍入方式,能有效减少累积误差。
- 向上取整 (Round Up / Ceiling): 向正无穷方向取整。
- 向下取整 (Round Down / Floor): 向负无穷方向取整。
对于嵌入式系统,通常会选择截断或最近偶数,因为它们在硬件实现上相对简单或精度累积误差较小。
溢出策略:饱和、环绕
如前所述,当运算结果超出目标定点数格式的表示范围时:
- 饱和 (Saturation): 将结果限制在最大值或最小值。这是AI推理中最常用的策略,因为它能保持数值的合理性,避免模型崩溃。
- 环绕 (Wrap-around): 结果回卷。在大多数AI场景中应避免。
*— Code Example 2: Fixed-Point Arithmetic Operators (+, -, , /) —
— Code Example 3: Saturation and Rounding Implementations —**
我们将这些策略集成到FixedPoint类中。为了简化,这里先实现一个默认的饱和和截断/近似舍入版本。
// 假设这是 FixedPoint 类定义内部
template <int TOTAL_BITS, int FRACTIONAL_BITS, bool IS_SIGNED = true>
class FixedPoint {
public:
// ... (previous definitions: internal_type, raw_value, static constexpr members) ...
// Saturated helper function
static internal_type saturate(int64_t value) {
if (s_is_signed) {
if (value > std::numeric_limits<internal_type>::max()) {
return std::numeric_limits<internal_type>::max();
}
if (value < std::numeric_limits<internal_type>::min()) {
return std::numeric_limits<internal_type>::min();
}
} else { // Unsigned
if (value < 0) return 0; // Saturate to 0 for unsigned
if (value > std::numeric_limits<internal_type>::max()) {
return std::numeric_limits<internal_type>::max();
}
}
return static_cast<internal_type>(value);
}
// Rounding helper function (round to nearest, half up)
// For fixed-point, right shifting usually implies truncation.
// To implement round-to-nearest, we add 0.5 (represented as 1 << (shift - 1))
// before the shift.
static internal_type round_and_shift(int64_t value, int shift_amount) {
if (shift_amount <= 0) return static_cast<internal_type>(value); // No shift or left shift
// Round to nearest: add half of the LSB of the part being shifted out
// For positive numbers, add (1 << (shift_amount - 1))
// For negative numbers, subtract (1 << (shift_amount - 1))
// A simpler way for symmetric rounding is (value + (1 << (shift_amount - 1))) >> shift_amount for positive
// and (value - (1 << (shift_amount - 1))) >> shift_amount for negative.
// A common technique for signed values:
if (value >= 0) {
value = (value + (internal_type(1) << (shift_amount - 1)));
} else {
value = (value - (internal_type(1) << (shift_amount - 1)));
}
return static_cast<internal_type>(value >> shift_amount);
}
// 加法运算符
FixedPoint operator+(const FixedPoint& other) const {
// 中间结果需要更宽的位宽以避免溢出,然后进行饱和
return FixedPoint(saturate(static_cast<int64_t>(raw_value) + other.raw_value));
}
// 减法运算符
FixedPoint operator-(const FixedPoint& other) const {
return FixedPoint(saturate(static_cast<int64_t>(raw_value) - other.raw_value));
}
// 乘法运算符
FixedPoint operator*(const FixedPoint& other) const {
// 两个Q(m.n)相乘,结果是Q(m+m'+1.n+n'),这里假设m=m', n=n'
// 乘积的原始小数位是 s_fractional_bits + other.s_fractional_bits
// 但我们想将结果存回 s_fractional_bits
// 因此需要右移 s_fractional_bits 位
// 使用更宽的类型进行乘法,防止中间结果溢出
int64_t prod = static_cast<int64_t>(raw_value) * other.raw_value;
// 舍入并右移,将小数位调整回 s_fractional_bits
// 这里的 round_and_shift 已经包含了饱和逻辑
return FixedPoint(saturate(round_and_shift(prod, s_fractional_bits)));
}
// 除法运算符 (简化版,可能精度损失较大)
FixedPoint operator/(const FixedPoint& other) const {
// 为了提高除法精度,可以在除法前将被除数左移 s_fractional_bits 位
// 然后再进行整数除法
if (other.raw_value == 0) {
// 处理除零错误,这里简化为返回最大值或最小值
if (raw_value > 0) return FixedPoint(std::numeric_limits<internal_type>::max());
else if (raw_value < 0) return FixedPoint(std::numeric_limits<internal_type>::min());
else return FixedPoint(0); // 0 / 0
}
int64_t numerator = static_cast<int64_t>(raw_value) << s_fractional_bits;
int64_t result_raw = numerator / other.raw_value;
return FixedPoint(saturate(result_raw));
}
// 赋值运算符
FixedPoint& operator+=(const FixedPoint& other) { *this = *this + other; return *this; }
FixedPoint& operator-=(const FixedPoint& other) { *this = *this - other; return *this; }
FixedPoint& operator*=(const FixedPoint& other) { *this = *this * other; return *this; }
FixedPoint& operator/=(const FixedPoint& other) { *this = *this / other; return *this; }
// 比较运算符 (仅示例)
bool operator<(const FixedPoint& other) const { return raw_value < other.raw_value; }
bool operator>(const FixedPoint& other) const { return raw_value > other.raw_value; }
bool operator<=(const FixedPoint& other) const { return raw_value <= other.raw_value; }
bool operator>=(const FixedPoint& other) const { return raw_value >= other.raw_value; }
bool operator!=(const FixedPoint& other) const { return raw_value != other.raw_value; }
// 显式类型转换到其他 FixedPoint 类型 (需要处理位宽和小数位调整)
template <int OTHER_TOTAL_BITS, int OTHER_FRACTIONAL_BITS, bool OTHER_IS_SIGNED>
explicit operator FixedPoint<OTHER_TOTAL_BITS, OTHER_FRACTIONAL_BITS, OTHER_IS_SIGNED>() const {
using TargetFixedPoint = FixedPoint<OTHER_TOTAL_BITS, OTHER_FRACTIONAL_BITS, OTHER_IS_SIGNED>;
int64_t temp_raw = raw_value;
int shift_diff = TargetFixedPoint::s_fractional_bits - s_fractional_bits;
if (shift_diff > 0) { // Target has more fractional bits, left shift
temp_raw <<= shift_diff;
} else if (shift_diff < 0) { // Target has fewer fractional bits, right shift (round)
temp_raw = round_and_shift(temp_raw, -shift_diff);
}
// 最后进行饱和处理,因为位宽可能也变了
return TargetFixedPoint(TargetFixedPoint::saturate(temp_raw));
}
};
/*
int main() {
using Q1_15 = FixedPoint<16, 15, true>;
Q1_15 a = Q1_15::from_float(0.75f); // raw: 0.75 * 2^15 = 24576
Q1_15 b = Q1_15::from_float(0.5f); // raw: 0.5 * 2^15 = 16384
std::cout << "a: " << a << std::endl;
std::cout << "b: " << b << std::endl;
Q1_15 sum = a + b; // raw: 24576 + 16384 = 40960. Should be 1.25
std::cout << "a + b = " << sum << std::endl; // Expect ~1.25
Q1_15 diff = a - b; // raw: 24576 - 16384 = 8192. Should be 0.25
std::cout << "a - b = " << diff << std::endl; // Expect ~0.25
Q1_15 prod = a * b; // raw: (24576 * 16384) >> 15 = (402653184) >> 15 = 12288. Should be 0.375
std::cout << "a * b = " << prod << std::endl; // Expect ~0.375
Q1_15 div_res = a / b; // raw: (24576 << 15) / 16384 = (805306368) / 16384 = 49152. Should be 1.5
std::cout << "a / b = " << div_res << std::endl; // Expect ~1.5
// 溢出测试
using Q0_7 = FixedPoint<8, 7, true>; // 范围约 [-1.0, 0.99...]
Q0_7 x = Q0_7::from_float(0.75f);
Q0_7 y = Q0_7::from_float(0.5f);
Q0_7 ov_sum = x + y; // 0.75 + 0.5 = 1.25, 会饱和到最大值 (0.99...)
std::cout << "Q0_7 x: " << x << ", y: " << y << std::endl;
std::cout << "Q0_7 x + y (overflow): " << ov_sum << std::endl;
std::cout << "Q0_7 max float: " << Q0_7::max_float() << std::endl;
// 类型转换
using Q7_8 = FixedPoint<16, 8, true>;
Q7_8 converted_a = static_cast<Q7_8>(a); // 从 Q1.15 转 Q7.8
std::cout << "Q1.15 a: " << a << " converted to Q7.8: " << converted_a << std::endl;
return 0;
}
*/
IV. 定点数库设计与架构
一个健壮且高效的定点数库需要精心设计的架构。我们将利用C++的模板元编程和策略模式来构建这个库。
模板元编程的应用:泛化类型、位宽、小数位
模板元编程允许我们在编译期根据模板参数生成不同的代码。这对于定点数库非常有用,因为它能让我们:
- 灵活定义定点数格式: 通过模板参数
TOTAL_BITS,FRACTIONAL_BITS,IS_SIGNED,用户可以根据需求创建任意位宽、小数位的定点数类型,如FixedPoint<16, 8, true>(Q7.8),FixedPoint<32, 24, true>(Q7.24) 等。 - 编译期优化: 许多计算,如
s_scale的计算,可以在编译期完成,避免运行时开销。 - 类型安全: 编译器会检查不同定点数类型之间的操作是否合法,例如,不允许将一个
FixedPoint<16,8>直接赋值给FixedPoint<8,4>而不进行显式转换。
类设计:FixedPoint 类
我们已经初步构建了FixedPoint类,现在进一步完善其设计:
- 内部存储类型 (
internal_type): 必须足够大以容纳指定位宽的整数。使用std::conditional来选择最小合适的类型(int8_t,int16_t,int32_t,int64_t),从而节省内存。 - 模板参数:
TOTAL_BITS:总位数(包括符号位)。FRACTIONAL_BITS:小数部分的位数。IS_SIGNED:布尔值,表示是否有符号。
- 构造函数、转换函数: 提供从整数、浮点数构造,以及转换为浮点数的接口。也应提供不同
FixedPoint类型之间的显式转换操作符。 - 运算符重载:
+,-,*,/,+=,-=,*=,/=,==,!=,<,>,<=,>=等,使得定点数操作像内置类型一样自然。 - 静态辅助函数:
from_float,to_float,max_value,min_value等。
策略模式的应用:舍入策略、溢出策略可插拔
在嵌入式系统中,对舍入和溢出处理有不同的需求。策略模式允许我们将这些行为封装到独立的策略类中,并在FixedPoint类中作为模板参数引入,从而实现行为的动态切换(在编译期)。
// --- Code Example 4: FixedPoint Class Template Structure (Refined with Policies) ---
// --- Code Example 5: Policy-Based Rounding/Saturation ---
// 1. 溢出策略基类 (或概念)
template <typename InternalType>
struct SaturationPolicy {
static InternalType saturate(int64_t value) {
// 默认实现:饱和到目标类型最大最小值
if (value > std::numeric_limits<InternalType>::max()) {
return std::numeric_limits<InternalType>::max();
}
if (value < std::numeric_limits<InternalType>::min()) {
return std::numeric_limits<InternalType>::min();
}
return static_cast<InternalType>(value);
}
};
// 无符号数的饱和策略
template <typename InternalType>
struct UnsignedSaturationPolicy {
static InternalType saturate(int64_t value) {
if (value < 0) return 0;
if (value > std::numeric_limits<InternalType>::max()) {
return std::numeric_limits<InternalType>::max();
}
return static_cast<InternalType>(value);
}
};
// 环绕溢出策略 (作为对比,实际AI中不常用)
template <typename InternalType>
struct WrapAroundPolicy {
static InternalType saturate(int64_t value) {
return static_cast<InternalType>(value); // 直接截断
}
};
// 2. 舍入策略基类 (或概念)
template <typename InternalType>
struct RoundingPolicy {
static InternalType round_and_shift(int64_t value, int shift_amount) {
if (shift_amount <= 0) return static_cast<InternalType>(value);
// 默认实现:Round Half Up (对于正数,加0.5后截断)
// 对于负数,round(x) = floor(x + 0.5)
if (value >= 0) {
value = (value + (InternalType(1) << (shift_amount - 1)));
} else {
value = (value - (InternalType(1) << (shift_amount - 1))); // 向负无穷方向舍入,需要调整
}
return static_cast<InternalType>(value >> shift_amount);
}
};
// 截断策略
template <typename InternalType>
struct TruncationPolicy {
static InternalType round_and_shift(int64_t value, int shift_amount) {
if (shift_amount <= 0) return static_cast<InternalType>(value);
return static_cast<InternalType>(value >> shift_amount); // 直接截断
}
};
// 3. 改进的 FixedPoint 类,接受策略作为模板参数
template <int TOTAL_BITS, int FRACTIONAL_BITS,
bool IS_SIGNED = true,
template <typename> class Saturation = std::conditional_t<IS_SIGNED, SaturationPolicy, UnsignedSaturationPolicy>,
template <typename> class Rounding = RoundingPolicy>
class FixedPoint {
public:
using internal_type = typename std::conditional<
TOTAL_BITS <= 16, int16_t,
typename std::conditional<
TOTAL_BITS <= 32, int32_t,
int64_t
>::type
>::type;
internal_type raw_value;
static constexpr int s_total_bits = TOTAL_BITS;
static constexpr int s_fractional_bits = FRACTIONAL_BITS;
static constexpr bool s_is_signed = IS_SIGNED;
static constexpr internal_type s_scale = internal_type(1) << FRACTIONAL_BITS;
// 策略的应用
internal_type apply_saturation(int64_t val) const {
return Saturation<internal_type>::saturate(val);
}
internal_type apply_rounding_and_shift(int64_t val, int shift_amount) const {
return Rounding<internal_type>::round_and_shift(val, shift_amount);
}
// 构造函数、to_float, from_float等保持不变,但要调用 apply_saturation
FixedPoint() : raw_value(0) {}
explicit FixedPoint(internal_type raw_val) : raw_value(raw_val) {}
static FixedPoint from_float(float f_val) {
float scaled_val = f_val * s_scale;
// 这里的饱和逻辑需要根据 IS_SIGNED 和 Saturation Policy 调整
// 简化起见,这里直接调用默认的 SaturationPolicy::saturate
return FixedPoint(Saturation<internal_type>::saturate(static_cast<int64_t>(std::round(scaled_val))));
}
float to_float() const {
return static_cast<float>(raw_value) / s_scale;
}
// 重载运算符,现在调用策略函数
FixedPoint operator+(const FixedPoint& other) const {
return FixedPoint(apply_saturation(static_cast<int64_t>(raw_value) + other.raw_value));
}
FixedPoint operator-(const FixedPoint& other) const {
return FixedPoint(apply_saturation(static_cast<int64_t>(raw_value) - other.raw_value));
}
FixedPoint operator*(const FixedPoint& other) const {
int64_t prod = static_cast<int64_t>(raw_value) * other.raw_value;
return FixedPoint(apply_saturation(apply_rounding_and_shift(prod, s_fractional_bits)));
}
FixedPoint operator/(const FixedPoint& other) const {
if (other.raw_value == 0) { /* handle division by zero */ return FixedPoint(apply_saturation(std::numeric_limits<int64_t>::max())); }
int64_t numerator = static_cast<int64_t>(raw_value) << s_fractional_bits;
int64_t result_raw = numerator / other.raw_value;
return FixedPoint(apply_saturation(result_raw));
}
// 类型转换操作符也使用策略
template <int OTHER_TOTAL_BITS, int OTHER_FRACTIONAL_BITS, bool OTHER_IS_SIGNED,
template <typename> class OTHER_SAT, template <typename> class OTHER_ROUND>
explicit operator FixedPoint<OTHER_TOTAL_BITS, OTHER_FRACTIONAL_BITS, OTHER_IS_SIGNED, OTHER_SAT, OTHER_ROUND>() const {
using TargetFixedPoint = FixedPoint<OTHER_TOTAL_BITS, OTHER_FRACTIONAL_BITS, OTHER_IS_SIGNED, OTHER_SAT, OTHER_ROUND>;
int64_t temp_raw = raw_value;
int shift_diff = TargetFixedPoint::s_fractional_bits - s_fractional_bits;
if (shift_diff > 0) { // Target has more fractional bits, left shift
temp_raw <<= shift_diff;
} else if (shift_diff < 0) { // Target has fewer fractional bits, right shift (round)
temp_raw = apply_rounding_and_shift(temp_raw, -shift_diff); // 使用当前对象的舍入策略
}
return TargetFixedPoint(TargetFixedPoint().apply_saturation(temp_raw)); // 使用目标类型的饱和策略
}
// 打印等其他成员保持不变
friend std::ostream& operator<<(std::ostream& os, const FixedPoint& fp) {
os << fp.to_float() << " (raw: " << fp.raw_value << ")";
return os;
}
};
/*
int main() {
// 使用默认策略的Q1.15 (有符号,饱和,Round Half Up)
using Q1_15_Default = FixedPoint<16, 15, true>;
Q1_15_Default a_def = Q1_15_Default::from_float(0.75f);
Q1_15_Default b_def = Q1_15_Default::from_float(0.5f);
std::cout << "Default Q1.15 a*b: " << (a_def * b_def) << std::endl;
// 使用截断舍入策略的Q1.15
using Q1_15_Trunc = FixedPoint<16, 15, true, SaturationPolicy, TruncationPolicy>;
Q1_15_Trunc a_trunc = Q1_15_Trunc::from_float(0.75f);
Q1_15_Trunc b_trunc = Q1_15_Trunc::from_float(0.5f);
std::cout << "Truncation Q1.15 a*b: " << (a_trunc * b_trunc) << std::endl;
// 无符号Q8.8,使用无符号饱和策略
using UQ8_8_Default = FixedPoint<16, 8, false>;
UQ8_8_Default ua = UQ8_8_Default::from_float(100.0f);
UQ8_8_Default ub = UQ8_8_Default::from_float(-10.0f); // 会被饱和到0
std::cout << "Unsigned UQ8_8 a: " << ua << ", b: " << ub << std::endl;
std::cout << "Unsigned UQ8_8 a+b: " << (ua + ub) << std::endl;
return 0;
}
*/
性能优化考量
在嵌入式AI芯片上,性能是第一位的。我们的库设计需要充分考虑以下优化:
- 避免动态内存分配:
FixedPoint对象应完全存储在栈上或静态存储区,避免堆内存分配,因为堆操作开销大且可能导致内存碎片。 - 内联函数: 关键的运算符和辅助函数应声明为
inline,鼓励编译器将其展开,消除函数调用开销。 - 位运算: 充分利用位移(
<<,>>)进行定标,比浮点乘除法快得多。 constexpr: 尽可能将可在编译期确定的值(如s_scale)声明为constexpr,进一步优化。- 避免不必要的类型转换: 尽量在同一种
FixedPoint类型之间进行运算,减少定标和饱和的开销。 - 编译器优化标志: 配合
O2/O3等编译器优化选项。
V. 矩阵乘法在定点数域的优化实现
现在,我们将定点数库应用于核心任务:矩阵乘法。
矩阵乘法基础:$C = A times B$
给定一个 $M times K$ 矩阵 $A$ 和一个 $K times N$ 矩阵 $B$,它们的乘积 $C$ 是一个 $M times N$ 矩阵,其中 $C{ij} = sum{p=0}^{K-1} A{ip} times B{pj}$。
定点数矩阵乘法的挑战
- 中间结果的精度累积 (乘积和): 矩阵乘法涉及大量的乘加操作。每个乘法结果都需要定标,然后累加。如果累加器位宽不足,连续的加法很容易导致溢出。
- 溢出风险管理: 如何选择合适的累加器位宽,以及何时进行饱和处理,是关键。
- 效率: 即使是定点数,大量的内存访问和计算仍然需要高效的实现。
累加器位宽选择:防止中间结果溢出
这是定点数矩阵乘法中最关键的设计点。
假设 $A$ 的元素是 Q$m_A.n_A$, $B$ 的元素是 Q$m_B.nB$。
一个乘积 $A{ip} times B_{pj}$ 的原始小数位是 $n_A + n_B$。
累加 $K$ 个这样的乘积,最大值可能达到 $K times (text{MaxA} times text{MaxB})$。
为了避免溢出,累加器的位宽必须足够大。
通常,累加器的位宽至少要能容纳 Max(A) * Max(B) * K 的值。
如果 $A, B$ 是16位定点数,它们的乘积是32位。如果 $K=256$,那么累加器至少需要 32 + log2(256) = 32 + 8 = 40 位。因此,使用 int64_t 作为累加器是常见的做法。
分块乘法 (Tiling):缓存优化,减少内存访问
在现代处理器中,内存访问速度远低于CPU计算速度。分块乘法是一种有效的缓存优化技术。它将大矩阵分成小块,每次只处理一小块数据,确保这些数据能够完全放入高速缓存(L1/L2 Cache),从而减少对主内存的访问,提高数据局部性。
例如,对于 $C{ij} = sum{p=0}^{K-1} A{ip} times B{pj}$,我们可以将 $A, B, C$ 矩阵都分割成若干个小块。在一个块内完成乘加操作,再处理下一个块。
SIMD指令集利用 (SSE/NEON):针对嵌入式AI芯片的硬件加速
大多数低功耗嵌入式AI芯片都集成了SIMD(Single Instruction, Multiple Data)指令集,如ARM的NEON或x86的SSE/AVX。SIMD指令允许处理器同时对多个数据元素执行相同的操作,极大提升并行计算能力。
将定点数运算映射到SIMD指令是实现高效矩阵乘法的关键。
-
ARM NEON:
vmla_s16:16位有符号整数乘积累加。vmull_s16:16位有符号整数乘法,结果扩展到32位。vadd_s32:32位有符号整数加法。vqrdmulh_s16:16位有符号整数乘法,结果取高16位并舍入。vqadd_s16:16位有符号整数饱和加法。vqmovn_s32:32位饱和截断到16位。
-
x86 SSE/AVX:
_mm_madd_epi16:对16位有符号整数进行成对乘法并累加到32位。_mm_mullo_epi16:16位有符号整数乘法,结果截断到16位。_mm_adds_epi16:16位有符号整数饱和加法。
如何将定点数操作映射到SIMD指令?
- 数据打包: 将多个
FixedPoint<16, N>(或FixedPoint<8, N>)的raw_value打包到SIMD寄存器中(例如,NEON的int16x8_t或int8x16_t)。 - 并行乘法: 使用SIMD乘法指令对打包的数据进行并行乘法。这通常会产生位宽更大的中间结果(例如,两个16位乘法得到32位结果)。
- 并行累加: 使用SIMD累加指令将这些中间结果累加到位宽更大的累加器SIMD寄存器中。
- 定标与饱和: 在累加完成后,对累加器结果进行并行右移(定标)和饱和操作,将其转换回目标定点数格式。
循环展开 (Loop Unrolling):减少循环开销
手动或让编译器进行循环展开,可以减少循环控制指令的开销,并为编译器提供更多的指令调度机会,进一步提升性能。
内存布局优化:行主序/列主序,对齐
- 行主序(Row-major)与列主序(Column-major): C++默认是行主序存储。确保矩阵访问模式与内存布局一致,可以提高缓存命中率。
- 内存对齐: SIMD指令通常要求数据按特定边界(如16字节、32字节)对齐。对齐内存可以避免性能下降,甚至避免某些SIMD指令无法使用。可以使用
alignas关键字或自定义内存分配器。
— Code Example 6: Basic Fixed-Point Matrix Multiplication (naive) —
首先,我们来看一个不带任何优化的定点数矩阵乘法基础实现。
#include <vector>
#include <chrono> // For performance measurement
// 假设 FixedPoint 类以及其运算符已经定义好
// using Q_TYPE = FixedPoint<16, 8, true>; // 示例定点数类型
template <typename FP_TYPE>
std::vector<FP_TYPE> multiply_matrices_naive(
const std::vector<FP_TYPE>& A, int M, int K,
const std::vector<FP_TYPE>& B, int N)
{
// A: M x K, B: K x N, C: M x N
if (A.size() != M * K || B.size() != K * N) {
throw std::runtime_error("Matrix dimensions mismatch or invalid sizes.");
}
std::vector<FP_TYPE> C(M * N);
for (int i = 0; i < M; ++i) {
for (int j = 0; j < N; ++j) {
// 使用 Q_ACCUMULATOR_TYPE 作为累加器,防止中间结果溢出
// 假设乘法结果是 2 * FP_TYPE::s_total_bits 位,且小数位是 2 * FP_TYPE::s_fractional_bits
// 累加器需要能容纳 K 个这样的结果,并最终转换回 FP_TYPE
// 这里我们创建一个临时的高精度累加器,确保不溢出
// 假设 FP_TYPE 是 FixedPoint<16, 8, true>,那么其 raw_value 是 int16_t
// 乘法结果的 raw_value 是 int32_t,小数位是 16
// 累加 K 次 int32_t,如果 K 很大,需要 int64_t
// 最终的累加器需要在内部处理好精度和饱和
// 为了简化,我们暂时假设 FP_TYPE 内部的乘法运算符已经处理了精度和饱和
// 但更严谨的做法是,这里使用一个更高精度的定点数类型作为累加器
// Option 1: 使用更高精度的定点数作为累加器
// 假设 FP_TYPE 是 FixedPoint<16, 8, true>
// 乘法结果 Q(15.16)
// 累加器需要 Q(15 + log2(K) . 16)
// 我们可以定义一个临时的累加器类型,例如 FixedPoint<32, 16, true>
// 但为了简化,这里暂时直接使用 FP_TYPE::internal_type 作为原始累加器
// 并手动管理其定点数含义
// 更优的累加器类型选择,例如:
// A是 Qm.n, B是 Qm.n, 乘积是 Q(2m+1).2n
// 累加 K 次,需要 (2m+1) + log2(K) 位整数部分
// 小数部分保持 2n
// 假设我们的 FP_TYPE 是 FixedPoint<16, 8, true> (Q7.8)
// 乘积是 Q(15).16
// 如果 K=64 (log2(K)=6),整数位需要 15+6=21位
// 总位宽 21 + 16 + 1 (符号位) = 38位
// 所以累加器需要至少 FixedPoint<40, 16, true> (使用 int64_t 内部存储)
// 为了演示,这里直接使用 int64_t 作为裸整数累加器,并手动管理其定点数含义
// 内部存储的原始值为 int64_t
int64_t acc_raw = 0;
// 每个乘法的结果,其小数位数是 FP_TYPE::s_fractional_bits * 2
// 假设 FP_TYPE 是 Qm.n
// A_raw * B_raw 结果的小数位是 2n
// 如果我们希望累加器的小数位也是 2n,那么直接累加即可
// 最终将累加器结果右移 n 位,转换回 Qm.n
// 这里我们假设 FP_TYPE 的乘法运算符返回的结果已经定标到 FP_TYPE::s_fractional_bits
// 那么累加器的小数位数就和 FP_TYPE 相同
// 这是一个简化的累加器,假设 FP_TYPE 的乘法结果已经定标到 FP_TYPE::s_fractional_bits
// 并且我们希望累加器的小数位数也保持 FP_TYPE::s_fractional_bits
// 这种方式会导致精度损失,但简化了累加器类型选择
// 更优的方法是使用一个具有更长小数位的中间累加器类型
// 重新思考累加器:
// C_ij = SUM( A_ip * B_pj )
// A_ip 的 raw_value 是 val_A * 2^N_f
// B_pj 的 raw_value 是 val_B * 2^N_f
// A_ip * B_pj 的 raw_value 是 (val_A * val_B) * 2^(2*N_f)
// 所以每个乘积的原始整数值是 `A[A_idx].raw_value * B[B_idx].raw_value`
// 其隐含的小数位是 `2 * FP_TYPE::s_fractional_bits`
// 我们用 `int64_t` 来存储这个原始乘积,然后累加
for (int p = 0; p < K; ++p) {
int A_idx = i * K + p;
int B_idx = p * N + j;
// 执行原始整数乘法,结果是 2 * TOTAL_BITS 位,小数位是 2 * FRACTIONAL_BITS
acc_raw += static_cast<int64_t>(A[A_idx].raw_value) * B[B_idx].raw_value;
}
// 累加完成后,将累加器的原始值定标回 FP_TYPE 格式
// 原始小数位是 2 * FP_TYPE::s_fractional_bits
// 目标小数位是 FP_TYPE::s_fractional_bits
// 需要右移 FP_TYPE::s_fractional_bits 位
C[i * N + j] = FP_TYPE(FP_TYPE().apply_saturation(
FP_TYPE().apply_rounding_and_shift(acc_raw, FP_TYPE::s_fractional_bits)
));
}
}
return C;
}
/*
// 示例 main 函数
int main() {
using Q7_8 = FixedPoint<16, 8, true>; // Q7.8 格式
int M = 2, K = 3, N = 2;
std::vector<Q7_8> A_fp = {
Q7_8::from_float(1.0f), Q7_8::from_float(2.0f), Q7_8::from_float(3.0f),
Q7_8::from_float(4.0f), Q7_8::from_float(5.0f), Q7_8::from_float(6.0f)
}; // 2x3 matrix
std::vector<Q7_8> B_fp = {
Q7_8::from_float(7.0f), Q7_8::from_float(8.0f),
Q7_8::from_float(9.0f), Q7_8::from_float(10.0f),
Q7_8::from_float(11.0f), Q7_8::from_float(12.0f)
}; // 3x2 matrix
std::cout << "Matrix A (Q7.8):" << std::endl;
for (int i = 0; i < M; ++i) {
for (int j = 0; j < K; ++j) {
std::cout << A_fp[i * K + j].to_float() << "t";
}
std::cout << std::endl;
}
std::cout << "nMatrix B (Q7.8):" << std::endl;
for (int i = 0; i < K; ++i) {
for (int j = 0; j < N; ++j) {
std::cout << B_fp[i * N + j].to_float() << "t";
}
std::cout << std::endl;
}
auto start = std::chrono::high_resolution_clock::now();
std::vector<Q7_8> C_fp = multiply_matrices_naive(A_fp, M, K, B_fp, N);
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
std::cout << "nMatrix C (Q7.8) - Naive (Time: " << diff.count() << "s):" << std::endl;
for (int i = 0; i < M; ++i) {
for (int j = 0; j < N; ++j) {
std::cout << C_fp[i * N + j].to_float() << "t";
}
std::cout << std::endl;
}
// Expected result (float for comparison):
// C[0][0] = 1*7 + 2*9 + 3*11 = 7 + 18 + 33 = 58
// C[0][1] = 1*8 + 2*10 + 3*12 = 8 + 20 + 36 = 64
// C[1][0] = 4*7 + 5*9 + 6*11 = 28 + 45 + 66 = 139
// C[1][1] = 4*8 + 5*10 + 6*12 = 32 + 50 + 72 = 154
// Output should be close to:
// 58.0 64.0
// 139.0 154.0
return 0;
}
*/
上述multiply_matrices_naive函数展示了定点数矩阵乘法的基本逻辑,其中累加器acc_raw被声明为int64_t,以确保在累加过程中不会发生溢出。在所有乘加操作完成后,再将最终的累加结果进行一次定标和饱和,转换回目标定点数格式。
— Code Example 7: Optimized Fixed-Point Matrix Multiplication (Accumulator management, basic tiling idea) —
针对嵌入式AI芯片的优化,我们可以在上述基础上引入分块和更精细的累加器管理,并为SIMD集成预留接口。
#include <vector>
#include <chrono>
#include <stdexcept>
// #include <arm_neon.h> // For ARM NEON intrinsics, uncomment for ARM targets
// #include <immintrin.h> // For x86 SSE/AVX intrinsics, uncomment for x86 targets
// 定义一个专门用于矩阵乘法累加的FixedPoint类型
// 假设输入是 FixedPoint<16, 8, true> (Q7.8)
// 乘积的原始小数位是 8+8=16位。
// 如果 K 最大为 256,则整数位需要 7+7+1(符号位) + log2(256) = 15 + 8 = 23位
// 所以累加器总位宽至少 23 + 16 (小数位) + 1 (符号位) = 40位
// 我们可以使用 FixedPoint<40, 16, true>,但为了方便,直接用 FixedPoint<64, 16, true>
// 这样内部存储就是 int64_t,且小数位为 16
template <typename FP_TYPE>
using AccumulatorFP = FixedPoint<64, 2 * FP_TYPE::s_fractional_bits, true>;
template <typename FP_TYPE>
std::vector<FP_TYPE> multiply_matrices_optimized(
const std::vector<FP_TYPE>& A, int M, int K,
const std::vector<FP_TYPE>& B, int N,
int BLOCK_SIZE_M, int BLOCK_SIZE_K, int BLOCK_SIZE_N)
{
if (A.size() != M * K || B.size() != K * N) {
throw std::runtime_error("Matrix dimensions mismatch or invalid sizes.");
}
std::vector<FP_TYPE> C(M * N);
// 假设矩阵是行主序存储
// 循环分块
for (int i_block = 0; i_block < M; i_block += BLOCK_SIZE_M) {
for (int j_block = 0; j_block < N; j_block += BLOCK_SIZE_N) {
for (int p_block = 0; p_block < K; p_block += BLOCK_SIZE_K) {
// 处理当前块内的元素
for (int i = i_block; i < std::min(i_block + BLOCK_SIZE_M, M); ++i) {
for (int j = j_block; j < std::min(j_block + BLOCK_SIZE_N, N); ++j) {
// 初始化高精度累加器
AccumulatorFP<FP_TYPE> acc(0); // 内部 raw_value = 0
for (int p = p_block; p < std::min(p_block + BLOCK_SIZE_K, K); ++p) {
int A_idx = i * K + p;
int B_idx = p * N + j;
// 这里直接使用 FP_TYPE 之间的乘法运算符
// 它的结果已经是定标到 FP_TYPE::s_fractional_bits
// 再次强调,如果 FP_TYPE 的乘法结果不直接定标,这里需要手动管理
// 为了简化,我们假设 FP_TYPE 的乘法结果内部是 FP_TYPE::raw_value
// 但我们希望累加器的内部 raw_value 是 `original_product_raw_value`
// 实际的优化应该是:
// 1. 获取 A 和 B 的原始整数值
// 2. 将它们相乘 (结果是 int64_t, 小数位是 2 * FP_TYPE::s_fractional_bits)
// 3. 将这个结果累加到 AccumulatorFP 的 raw_value 中
// 4. 最终将 AccumulatorFP 的 raw_value 定标回 FP_TYPE
// 改进的乘加逻辑:直接操作原始整数值,并使用 AccumulatorFP 累加
int64_t prod_raw = static_cast<int64_t>(A[A_idx].raw_value) * B[B_idx].raw_value;
acc.raw_value += prod_raw;
}
// 累加完成后,将 AccumulatorFP 的结果转换回 FP_TYPE
// AccumulatorFP 的小数位是 2 * FP_TYPE::s_fractional_bits
// FP_TYPE 的小数位是 FP_TYPE::s_fractional_bits
// 因此需要右移 FP_TYPE::s_fractional_bits 位
int shift_amount = FP_TYPE::s_fractional_bits;
C[i * N + j] = FP_TYPE(FP_TYPE().apply_saturation(
FP_TYPE().apply_rounding_and_shift(acc.raw_value, shift_amount)
));
// 注意:这里调用 FP_TYPE 的 apply_saturation 和 apply_rounding_and_shift
// 因为最终结果要存储为 FP_TYPE
}
}
}
}
}
return C;
}
/*
// 示例 main 函数 (与 naive 类似,只是调用不同的乘法函数)
int main() {
using Q7_8 = FixedPoint<16, 8, true>; // Q7.8 格式
int M = 64, K = 64, N = 64; // 更大的矩阵,以便观察性能差异
// 填充随机数据
std::vector<Q7_8> A_fp(M * K);
std::vector<Q7_8> B_fp(K * N);
for (size_t i = 0; i < A_fp.size(); ++i) A_fp[i] = Q7_8::from_float(static_cast<float>(rand() % 256 - 128) / 10.0f);
for (size_t i = 0; i < B_fp.size(); ++i) B_fp[i] = Q7_8::from_float(static_cast<float>(rand() % 256 - 128) / 10.0f);
std::cout << "Running naive matrix multiplication for " << M << "x" << K << " * " << K << "x" << N << std::endl;
auto start_naive = std::chrono::high_resolution_clock::now();
std::vector<Q7_8> C_naive = multiply_matrices_naive(A_fp, M, K, B_fp, N);
auto end_naive = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff_naive = end_naive - start_naive;
std::cout << "Naive time: " << diff_naive.count() << "s" << std::endl;
std::cout << "Running optimized matrix multiplication (blocking) for " << M << "x" << K << " * " << K << "x" << N << std::endl;
int BLOCK_M = 16, BLOCK_K = 16, BLOCK_N = 16; // 示例块大小
auto start_opt = std::chrono::high_resolution_clock::now();
std::vector<Q7_8> C_opt = multiply_matrices_optimized(A_fp, M, K, B_fp, N, BLOCK_M, BLOCK_K, BLOCK_N);
auto end_opt = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff_opt = end_opt - start_opt;
std::cout << "Optimized (blocking) time: " << diff_opt.count() << "s" << std::endl;
// 简单验证结果是否一致
bool results_match = true;
for(size_t i=0; i < C_naive.size(); ++i) {
if (C_naive[i].raw_value != C_opt[i].raw_value) {
results_match = false;
// std::cout << "Mismatch at " << i << ": Naive=" << C_naive[i] << ", Optimized=" << C_opt[i] << std::endl;
break;
}
}
std::cout << "Results match: " << (results_match ? "Yes" : "No") << std::endl;
return 0;
}
*/
— Code Example 8: SIMD Intrinsics (conceptual sketch for fixed-point dot product) —
SIMD实现将是性能提升的关键。由于SIMD指令集是平台相关的,这里提供一个概念性的骨架,展示如何集成NEON指令。实际应用中需要根据具体芯片架构和编译器进行调整。
// 假设这是 NEON 版本的 FixedPoint 矩阵乘法的一个小片段
// 仅为示意,完整的 SIMD 优化需要考虑数据加载、存储、循环展开、对齐等
// 这是一个针对 16位定点数 (Q7.8) 的点积函数示例
#ifdef __ARM_NEON
#include <arm_neon.h>
// 假设 FP_TYPE 是 FixedPoint<16, 8, true>
// 并且矩阵数据已经对齐
void fixed_point_dot_product_neon_q7_8(
const int16_t* A_row, // 指向 A 矩阵一行的原始数据
const int16_t* B_col, // 指向 B 矩阵一列的原始数据
int K, // 点积长度
int16_t* result_ptr) // 存储最终 Q7.8 结果的指针
{
// NEON 累加器,存储 4个 32位整数,对应 4个 Q15.16 乘积的累加
int32x4_t sum_vec = vdupq_n_s32(0);
// 假设 K 是 4 的倍数,以便一次处理 4 个元素
for (int p = 0; p < K; p += 4) {
// 加载 A 的 4 个 16 位元素
int16x4_t a_vec = vld1_s16(A_row + p);
// 加载 B 的 4 个 16 位元素
int16x4_t b_vec = vld1_s16(B_col + p);
// 执行 16位 x 16位 -> 32位 乘法,并累加到 32位累加器
// vmull_s16 得到 32位乘积,vmlaq_s32 累加
// 注意:vmlaq_s32 是 32位乘积累加,这里我们需要 16x16=32位的乘积
// NEON通常提供 vmull_s16 (long multiply) 将 16位结果扩展为 32位
// 然后再用 vaddq_s32 累加。或者直接用 vmlal_s16 (multiply and accumulate long)
// 假设 A_row 和 B_col 是 Q7.8 (16位),乘积是 Q15.16 (32位)
int32x4_t prod_low = vmull_s16(vget_low_s16(a_vec), vget_low_s16(b_vec)); // 乘法,结果是 32位
int32x4_t prod_high = vmull_s16(vget_high_s16(a_vec), vget_high_s16(b_vec));
sum_vec = vaddq_s32(sum_vec, prod_low);
sum_vec = vaddq_s32(sum_vec, prod_high);
}
// 将 4个 32位累加结果水平求和,得到一个 32位总和
int32_t final_sum_raw = vgetq_lane_s32(sum_vec, 0) +
vgetq_lane_s32(sum_vec, 1) +
vgetq_lane_s32(sum_vec, 2) +
vgetq_lane_s32(sum_vec, 3);
// 定标与饱和:将 Q15.16 的结果转换为 Q7.8
// 需要右移 8 位 (16 - 8)
// vqrshrn_n_s32 是 NEON 的饱和舍入右移指令,将 32位结果右移并饱和到 16位
// 这里的 final_sum_raw 是一个标量,需要手动处理
int shift_amount = FixedPoint<16,8,true>::s_fractional_bits; // 8
// 舍入 (round to nearest, half up)
int64_t rounded_shifted_val = (static_cast<int64_t>(final_sum_raw) + (1LL << (shift_amount - 1))) >> shift_amount;
// 饱和到 16位
if (rounded_shifted_val > std::numeric_limits<int16_t>::max()) {
*result_ptr = std::numeric_limits<int16_t>::max();
} else if (rounded_shifted_val < std::numeric_limits<int16_t>::min()) {
*result_ptr = std::numeric_limits<int16_t>::min();
} else {
*result_ptr = static_cast<int16_t>(rounded_shifted_val);
}
}
#endif // __ARM_NEON
这样的SIMD优化能够显著提升矩阵乘法的性能,特别是在处理较大矩阵时。它将原本顺序执行的指令,并行化到多个数据通道上,充分利用了硬件资源。
VI. 精度与性能的平衡与评估
在定点化过程中,精度和性能往往是一对矛盾。我们需要仔细评估并找到最佳平衡点。
精度分析:量化误差来源,SNR (信噪比)
定点化引入的误差主要来源于:
- 量化误差: 浮点数转换为定点数时,由于精度有限而产生的误差。
- **截