JS `WebGPU Compute Shaders` `Workgroup Memory` 与 `Global Memory` 优化

咳咳,各位听众朋友们,大家好!今天咱们来聊点硬核的,关于WebGPU里Compute Shaders的优化,特别是Workgroup Memory和Global Memory这俩兄弟。这俩货用好了,能让你的Compute Shader跑得飞起,用不好,那就是蜗牛爬,甚至直接原地爆炸。

咱们先来明确下概念,免得有人迷路。

什么是Compute Shader?

简单来说,Compute Shader就是WebGPU里用来做通用计算的,它能利用GPU的并行能力,处理各种各样的计算任务,比如图像处理、物理模拟、机器学习等等。你可以把它想象成一个超级强大的计算器,只不过这个计算器有很多很多个小计算器同时工作。

什么是Workgroup Memory?

Workgroup Memory,也叫Local Memory,是每个Workgroup里的线程共享的内存。它的特点是速度非常快,但是容量很小。你可以把它想象成一个每个小组内部的草稿纸,小组里的每个人都可以往上面写写画画,速度很快,但是地方不大。

什么是Global Memory?

Global Memory,也叫Device Memory,是所有线程都可以访问的内存。它的特点是容量很大,但是速度相对较慢。你可以把它想象成一块巨大的黑板,所有小组都可以往上面写写画画,地方很大,但是要排队,速度就慢了。

OK,概念清楚了,咱们就开始今天的正题:如何利用Workgroup Memory和Global Memory来优化Compute Shader。

优化原则:能用Workgroup Memory的,坚决不用Global Memory!

记住这个黄金法则!Workgroup Memory比Global Memory快得多,所以我们要尽可能地把数据放在Workgroup Memory里,减少对Global Memory的访问。

下面我们通过几个例子来说明如何进行优化。

例子1:矩阵乘法

矩阵乘法是Compute Shader里一个非常常见的操作,也是一个非常适合用Workgroup Memory优化的场景。

先来看看一个简单的矩阵乘法的Compute Shader代码(未优化版本):

struct Matrix {
  values: array<f32>,
};

@group(0) @binding(0) var<storage, read> a : Matrix;
@group(0) @binding(1) var<storage, read> b : Matrix;
@group(0) @binding(2) var<storage, write> c : Matrix;

@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>,
        @builtin(local_invocation_id) local_id : vec3<u32>) {
  let row = global_id.x;
  let col = global_id.y;
  let a_width = 256u; // 假设矩阵A的宽度是256
  let b_width = 256u; // 假设矩阵B的宽度是256

  var sum = 0.0;
  for (var k = 0u; k < a_width; k++) {
    sum = sum + a.values[row * a_width + k] * b.values[k * b_width + col];
  }

  c.values[row * b_width + col] = sum;
}

这段代码看起来没啥问题,但是效率非常低。为什么?因为每个线程都要从Global Memory里读取矩阵A和矩阵B的元素,而Global Memory的速度又很慢,这就造成了瓶颈。

现在我们来用Workgroup Memory优化一下:

struct Matrix {
  values: array<f32>,
};

@group(0) @binding(0) var<storage, read> a : Matrix;
@group(0) @binding(1) var<storage, read> b : Matrix;
@group(0) @binding(2) var<storage, write> c : Matrix;

@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>,
        @builtin(local_invocation_id) local_id : vec3<u32>,
        @builtin(workgroup_id) workgroup_id : vec3<u32>) {
  let row = global_id.x;
  let col = global_id.y;
  let a_width = 256u; // 假设矩阵A的宽度是256
  let b_width = 256u; // 假设矩阵B的宽度是256
  let tile_size = 16u;

  var a_tile : array<array<f32, 16>, 16>;
  var b_tile : array<array<f32, 16>, 16>;

  // Load data into workgroup memory
  let a_tile_row = local_id.x;
  let a_tile_col = local_id.y;
  let b_tile_row = local_id.x;
  let b_tile_col = local_id.y;

  a_tile[a_tile_row][a_tile_col] = a.values[(workgroup_id.x * tile_size + a_tile_row) * a_width + (workgroup_id.y * tile_size + a_tile_col)];
  b_tile[b_tile_row][b_tile_col] = b.values[(workgroup_id.x * tile_size + b_tile_row) * b_width + (workgroup_id.y * tile_size + b_tile_col)];

  workgroupBarrier(); // Make sure all threads have loaded their data

  var sum = 0.0;
  for (var k = 0u; k < tile_size; k++) {
    sum = sum + a_tile[local_id.x][k] * b_tile[k][local_id.y];
  }

  c.values[row * b_width + col] = sum;
}

这段代码做了什么?

  1. 定义了a_tileb_tile两个Workgroup Memory里的二维数组,用来存储矩阵A和矩阵B的一个小块(tile)。
  2. 每个线程从Global Memory里读取自己负责的那个元素,放到a_tileb_tile里。
  3. 使用workgroupBarrier()函数,确保所有线程都完成了数据的加载。这个函数非常重要,它可以保证在进行下一步计算之前,Workgroup Memory里的数据是完整的。
  4. 在Workgroup Memory里进行矩阵乘法计算。

这样一来,每个线程只需要从Global Memory里读取一次数据,然后就可以在Workgroup Memory里进行多次计算,大大减少了对Global Memory的访问,提高了效率。

代码解释:

  • @builtin(workgroup_id) workgroup_id : vec3<u32>: 获取当前Workgroup的ID。
  • var a_tile : array<array<f32, 16>, 16>;: 在Workgroup Memory里定义一个16×16的二维数组。
  • workgroupBarrier();: 一个同步点,确保Workgroup里的所有线程都执行到这里才能继续往下执行。
  • a.values[(workgroup_id.x * tile_size + a_tile_row) * a_width + (workgroup_id.y * tile_size + a_tile_col)];: 计算线程需要读取的Global Memory的索引。

表格对比:

特性 未优化版本 (Global Memory) 优化版本 (Workgroup Memory)
内存访问频率
性能
代码复杂度

例子2:图像模糊

图像模糊也是一个常见的图像处理操作,同样可以使用Workgroup Memory进行优化。

先来看看一个简单的图像模糊的Compute Shader代码(未优化版本):

struct Image {
  width: u32,
  height: u32,
  pixels: array<f32>,
};

@group(0) @binding(0) var<storage, read> input_image : Image;
@group(0) @binding(1) var<storage, write> output_image : Image;

@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
  let x = global_id.x;
  let y = global_id.y;
  let width = input_image.width;
  let height = input_image.height;

  if (x >= width || y >= height) {
    return;
  }

  var sum = 0.0;
  var count = 0.0;
  for (var i = -1; i <= 1; i++) {
    for (var j = -1; j <= 1; j++) {
      let nx = i32(x) + i;
      let ny = i32(y) + j;

      if (nx >= 0 && nx < i32(width) && ny >= 0 && ny < i32(height)) {
        sum = sum + input_image.pixels[u32(nx) + u32(ny) * width];
        count = count + 1.0;
      }
    }
  }

  output_image.pixels[x + y * width] = sum / count;
}

这段代码的问题在于,每个线程都要多次访问Global Memory来读取周围像素的值,而这些周围像素的值很可能被相邻的线程重复读取,造成了浪费。

现在我们来用Workgroup Memory优化一下:

struct Image {
  width: u32,
  height: u32,
  pixels: array<f32>,
};

@group(0) @binding(0) var<storage, read> input_image : Image;
@group(0) @binding(1) var<storage, write> output_image : Image;

@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>,
        @builtin(local_invocation_id) local_id : vec3<u32>,
        @builtin(workgroup_id) workgroup_id : vec3<u32>) {
  let x = global_id.x;
  let y = global_id.y;
  let width = input_image.width;
  let height = input_image.height;
  let tile_size = 16u;

  if (x >= width || y >= height) {
    return;
  }

  var tile : array<array<f32, 18>, 18>; // 16x16 + 2 pixels border

  // Load data into workgroup memory
  let tile_x = local_id.x + 1;
  let tile_y = local_id.y + 1;

  // Handle boundary conditions
  let global_x = workgroup_id.x * tile_size + local_id.x;
  let global_y = workgroup_id.y * tile_size + local_id.y;

  // Read with boundary check
  var pixel_value = 0.0;
  if (global_x < width && global_y < height) {
    pixel_value = input_image.pixels[global_x + global_y * width];
  }

  tile[tile_x][tile_y] = pixel_value;

  // Load the border pixels
  if (local_id.x == 0) {
    let left_global_x = i32(global_x) - 1;
    let left_global_y = global_y;
    var left_pixel_value = 0.0;

    if (left_global_x >= 0 && left_global_y < height) {
      left_pixel_value = input_image.pixels[u32(left_global_x) + left_global_y * width];
    }
    tile[0][tile_y] = left_pixel_value;
  }

  if (local_id.x == 15) {
    let right_global_x = global_x + 1;
    let right_global_y = global_y;
    var right_pixel_value = 0.0;
    if (right_global_x < width && right_global_y < height) {
      right_pixel_value = input_image.pixels[right_global_x + right_global_y * width];
    }
    tile[17][tile_y] = right_pixel_value;
  }

  if (local_id.y == 0) {
    let top_global_x = global_x;
    let top_global_y = i32(global_y) - 1;
    var top_pixel_value = 0.0;
    if (top_global_x < width && top_global_y >= 0) {
      top_pixel_value = input_image.pixels[top_global_x + u32(top_global_y) * width];
    }
    tile[tile_x][0] = top_pixel_value;
  }

  if (local_id.y == 15) {
    let bottom_global_x = global_x;
    let bottom_global_y = global_y + 1;
    var bottom_pixel_value = 0.0;
    if (bottom_global_x < width && bottom_global_y < height) {
      bottom_pixel_value = input_image.pixels[bottom_global_x + bottom_global_y * width];
    }
    tile[tile_x][17] = bottom_pixel_value;
  }

  workgroupBarrier(); // Make sure all threads have loaded their data

  var sum = 0.0;
  var count = 0.0;
  for (var i = -1; i <= 1; i++) {
    for (var j = -1; j <= 1; j++) {
      sum = sum + tile[tile_x + i][tile_y + j];
      count = count + 1.0;
    }
  }

  output_image.pixels[x + y * width] = sum / count;
}

这段代码做了什么?

  1. 定义了一个18×18的Workgroup Memory里的二维数组tile,用来存储当前Workgroup负责的16×16的像素以及周围一圈像素(为了计算模糊)。
  2. 每个线程从Global Memory里读取自己负责的像素以及周围的像素,放到tile里。这里需要处理边界情况,确保不会越界访问。
  3. 使用workgroupBarrier()函数,确保所有线程都完成了数据的加载。
  4. 在Workgroup Memory里进行模糊计算。

这样一来,每个线程只需要从Global Memory里读取有限的几个像素,然后就可以在Workgroup Memory里进行多次计算,大大减少了对Global Memory的访问,提高了效率。

代码解释:

  • var tile : array<array<f32, 18>, 18>;: 在Workgroup Memory里定义一个18×18的二维数组。
  • tile[tile_x][tile_y] = input_image.pixels[global_x + global_y * width];: 将Global Memory里的像素值加载到Workgroup Memory里。
  • 处理边界像素的代码是用来读取边缘像素,保证模糊计算的正确性。

表格对比:

特性 未优化版本 (Global Memory) 优化版本 (Workgroup Memory)
内存访问频率
性能
代码复杂度

Workgroup Memory的局限性

虽然Workgroup Memory很强大,但是它也有一些局限性:

  1. 容量限制: Workgroup Memory的容量非常有限,通常只有几十KB,所以不能存储太多的数据。
  2. 同步问题: 在使用Workgroup Memory时,需要特别注意同步问题,确保所有线程都完成了数据的加载和计算,才能进行下一步操作。workgroupBarrier()函数是解决同步问题的关键。

Global Memory的优化

虽然我们一直强调要减少对Global Memory的访问,但是有些情况下,我们不得不使用Global Memory。这时候,我们也可以通过一些技巧来优化Global Memory的访问。

  1. 合并访问: 尽量让相邻的线程访问相邻的内存地址,这样可以提高Global Memory的访问效率。
  2. 使用缓存: WebGPU会自动缓存Global Memory里的数据,所以我们可以尽量重复使用Global Memory里的数据,减少不必要的访问。

总结

Workgroup Memory和Global Memory是Compute Shader里非常重要的两个概念,合理地利用它们可以大大提高Compute Shader的性能。记住,能用Workgroup Memory的,坚决不用Global Memory!同时,也要注意Workgroup Memory的局限性,以及Global Memory的优化技巧。

最后,希望今天的讲座对大家有所帮助。记住,优化是一个持续的过程,需要不断地尝试和改进。祝大家写出高性能的WebGPU Compute Shader! 散会!

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注