C++ `std::latch` 与 `std::barrier` 高级用法:复杂同步场景

好的,咱们今天来聊聊C++里两个挺有意思的同步工具:std::latchstd::barrier。这俩家伙,单看名字可能觉得挺高大上,但其实用好了,能让你的并发程序更优雅、更可控。

开场白:并发世界的坑和甜头

话说,并发编程就像是同时耍好几个盘子。耍好了,效率嗖嗖地往上涨;耍不好,盘子噼里啪啦碎一地,debug到怀疑人生。所以,我们需要一些“魔术道具”来保证盘子不掉,std::latchstd::barrier 就是其中两种。

第一幕:std::latch – “关门放狗”

std::latch,你可以把它想象成一个门闩。一开始门是打开的,你可以设置一个计数器,代表需要多少个“人”来把门闩上。每来一个人,计数器减一。当计数器归零时,门闩就彻底锁死,后面的“狗”(指那些等待的线程)就可以放出来了。

基本用法:

#include <iostream>
#include <thread>
#include <latch>

int main() {
  std::latch door(3); // 门闩需要3个人才能锁上

  auto worker = [&](int id) {
    std::cout << "Worker " << id << " arrived." << std::endl;
    door.count_down(); // 计数器减一
    door.wait(); // 等待门闩锁死
    std::cout << "Worker " << id << " is running!" << std::endl;
  };

  std::thread t1(worker, 1);
  std::thread t2(worker, 2);
  std::thread t3(worker, 3);

  t1.join();
  t2.join();
  t3.join();

  std::cout << "All workers finished." << std::endl;

  return 0;
}

代码解读:

  1. std::latch door(3);: 创建一个 latch 对象,初始计数器为3。
  2. door.count_down();:每个worker线程到达后,调用 count_down(),计数器减一。
  3. door.wait();:每个worker线程调用 wait(),如果计数器不为零,线程会阻塞在这里,直到计数器归零。

运行结果:

Worker 1 arrived.
Worker 2 arrived.
Worker 3 arrived.
Worker 1 is running!
Worker 2 is running!
Worker 3 is running!
All workers finished.

进阶用法:确保初始化完成

std::latch 最常见的用法之一就是确保所有线程都完成初始化,然后再开始执行主逻辑。

#include <iostream>
#include <thread>
#include <vector>
#include <latch>

std::vector<int> data;
std::latch init_latch(5); // 假设有5个线程需要初始化

void initialize_data(int thread_id) {
  // 模拟初始化过程
  std::this_thread::sleep_for(std::chrono::milliseconds(100 * thread_id));
  data.push_back(thread_id);
  std::cout << "Thread " << thread_id << " initialized." << std::endl;
  init_latch.count_down();
}

void process_data(int thread_id) {
  init_latch.wait(); // 等待所有线程初始化完成
  std::cout << "Thread " << thread_id << " is processing data." << std::endl;
  // ... 对 data 进行处理 ...
}

int main() {
  std::vector<std::thread> threads;
  for (int i = 0; i < 5; ++i) {
    threads.emplace_back(initialize_data, i);
  }

  for (int i = 0; i < 5; ++i) {
    threads.emplace_back(process_data, i);
  }

  for (auto& t : threads) {
    t.join();
  }

  std::cout << "All threads finished." << std::endl;
  return 0;
}

代码解读:

  1. initialize_data 函数模拟了初始化过程,每个线程初始化完成后,调用 init_latch.count_down()
  2. process_data 函数在执行数据处理之前,调用 init_latch.wait(),确保所有线程都完成了初始化。

std::latch 的特点:

  • 一次性使用: 计数器一旦归零,就不能再重置了。 也就是说门闩一旦锁死,就不能再打开了。
  • 简单易用: 使用起来非常简单,只需要 count_down()wait() 两个函数。

第二幕:std::barrier – “集体舞步”

std::barrier,你可以把它想象成一个舞台上的“集体舞”。所有舞者(线程)必须同时到达舞台上的某个特定位置(barrier),然后才能一起开始跳舞。如果有的舞者提前到了,就必须在原地等待,直到所有舞者都到齐。

基本用法:

#include <iostream>
#include <thread>
#include <barrier>

int main() {
  std::barrier sync_point(3, []() { // 3个线程,completion function
    std::cout << "All threads reached the barrier!" << std::endl;
  });

  auto worker = [&](int id) {
    std::cout << "Worker " << id << " is working..." << std::endl;
    std::this_thread::sleep_for(std::chrono::milliseconds(100 * id)); // 模拟不同线程的工作时间
    std::cout << "Worker " << id << " reached the barrier." << std::endl;
    sync_point.arrive_and_wait(); // 到达barrier并等待其他线程
    std::cout << "Worker " << id << " continues working." << std::endl;
  };

  std::thread t1(worker, 1);
  std::thread t2(worker, 2);
  std::thread t3(worker, 3);

  t1.join();
  t2.join();
  t3.join();

  std::cout << "All workers finished." << std::endl;

  return 0;
}

代码解读:

  1. std::barrier sync_point(3, []() { ... });:创建一个 barrier 对象,需要 3 个线程同步。 第二个参数是一个 completion function,当所有线程都到达 barrier 时,这个函数会被执行。
  2. sync_point.arrive_and_wait();:每个 worker 线程到达 barrier 后,调用 arrive_and_wait(),线程会阻塞在这里,直到所有线程都到达 barrier。

运行结果:

Worker 1 is working...
Worker 2 is working...
Worker 3 is working...
Worker 1 reached the barrier.
Worker 2 reached the barrier.
Worker 3 reached the barrier.
All threads reached the barrier!
Worker 1 continues working.
Worker 2 continues working.
Worker 3 continues working.
All workers finished.

进阶用法:多阶段计算

std::barrier 非常适合用于多阶段计算,例如图像处理、机器学习等。每个阶段都需要所有线程同步完成,才能进入下一个阶段。

#include <iostream>
#include <thread>
#include <vector>
#include <barrier>

const int NUM_THREADS = 4;
const int NUM_STAGES = 3;

std::vector<int> data(100); // 共享数据

std::barrier stage_barrier(NUM_THREADS, []() {
  std::cout << "Entering next stage..." << std::endl;
});

void worker(int thread_id) {
  for (int stage = 0; stage < NUM_STAGES; ++stage) {
    // 模拟每个阶段的计算
    std::cout << "Thread " << thread_id << " is working on stage " << stage << std::endl;
    for (int i = thread_id * 25; i < (thread_id + 1) * 25; ++i) {
      data[i] += stage; // 简单的数据处理
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(50 * thread_id)); // 模拟不同线程的工作时间

    stage_barrier.arrive_and_wait(); // 等待所有线程完成当前阶段
  }
}

int main() {
  std::vector<std::thread> threads;
  for (int i = 0; i < NUM_THREADS; ++i) {
    threads.emplace_back(worker, i);
  }

  for (auto& t : threads) {
    t.join();
  }

  std::cout << "All threads finished." << std::endl;
  return 0;
}

代码解读:

  1. stage_barrier 用于在每个阶段结束后同步所有线程。
  2. worker 函数模拟了每个线程在每个阶段的工作。
  3. stage_barrier.arrive_and_wait() 确保所有线程都完成了当前阶段的计算,才能进入下一个阶段。

Completion Function 的妙用

std::barrier 的构造函数可以接受一个 completion function。这个函数会在所有线程都到达 barrier 时被调用。 你可以在这个函数里做一些全局性的操作,比如更新共享数据、打印日志等等。

#include <iostream>
#include <thread>
#include <barrier>

int main() {
  int shared_value = 0;

  std::barrier sync_point(3, [&]() {
    shared_value++; // 所有线程到达 barrier 后,shared_value 加一
    std::cout << "Shared value is now: " << shared_value << std::endl;
  });

  auto worker = [&](int id) {
    std::cout << "Worker " << id << " is working..." << std::endl;
    sync_point.arrive_and_wait();
    std::cout << "Worker " << id << " continues working." << std::endl;
  };

  std::thread t1(worker, 1);
  std::thread t2(worker, 2);
  std::thread t3(worker, 3);

  t1.join();
  t2.join();
  t3.join();

  std::cout << "All workers finished." << std::endl;

  return 0;
}

std::barrier 的特点:

  • 可重复使用: 计数器可以重置,可以用于多个阶段的同步。
  • 灵活的 Completion Function: 可以在所有线程到达 barrier 时执行一些额外的操作。

第三幕:std::latch vs std::barrier – “门闩” vs “舞台”

现在,我们来对比一下 std::latchstd::barrier,看看它们各自的适用场景。

特性 std::latch std::barrier
用途 一次性同步,等待所有线程到达 多阶段同步,周期性同步所有线程
计数器 递减到零,不可重置 递减到零,自动重置
Completion Function 有,在所有线程到达 barrier 时执行
适用场景 初始化完成,单次同步 多阶段计算,迭代算法
类比 关门放狗 集体舞步

总结:

  • 如果你只需要一次性地等待所有线程完成某个任务,例如初始化,那么 std::latch 是一个不错的选择。
  • 如果你需要周期性地同步所有线程,例如多阶段计算,那么 std::barrier 更适合你。

代码示例:组合使用 std::latchstd::barrier

#include <iostream>
#include <thread>
#include <vector>
#include <latch>
#include <barrier>

const int NUM_THREADS = 4;
const int NUM_STAGES = 3;

std::vector<int> data(100);

std::latch init_latch(NUM_THREADS); // 用于初始化同步
std::barrier stage_barrier(NUM_THREADS, []() {
  std::cout << "Entering next stage..." << std::endl;
});

void worker(int thread_id) {
  // 初始化阶段
  std::cout << "Thread " << thread_id << " is initializing..." << std::endl;
  std::this_thread::sleep_for(std::chrono::milliseconds(50 * thread_id));
  init_latch.count_down();
  init_latch.wait(); // 等待所有线程初始化完成

  // 多阶段计算
  for (int stage = 0; stage < NUM_STAGES; ++stage) {
    std::cout << "Thread " << thread_id << " is working on stage " << stage << std::endl;
    for (int i = thread_id * 25; i < (thread_id + 1) * 25; ++i) {
      data[i] += stage + thread_id;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(50 * thread_id));
    stage_barrier.arrive_and_wait(); // 等待所有线程完成当前阶段
  }
}

int main() {
  std::vector<std::thread> threads;
  for (int i = 0; i < NUM_THREADS; ++i) {
    threads.emplace_back(worker, i);
  }

  for (auto& t : threads) {
    t.join();
  }

  std::cout << "All threads finished." << std::endl;
  return 0;
}

代码解读:

  1. init_latch 用于确保所有线程都完成了初始化,然后才能开始多阶段计算。
  2. stage_barrier 用于在每个阶段结束后同步所有线程。

第四幕:高级用法 – 异常处理和超时

并发编程里,异常处理是个大问题。如果一个线程抛出了异常,可能会导致其他线程一直阻塞,或者数据不一致。 让我们看看怎么用 std::latchstd::barrier 来处理异常。

std::latch 和异常处理:

#include <iostream>
#include <thread>
#include <latch>
#include <exception>

std::latch my_latch(2);

void worker(int id) {
  try {
    if (id == 1) {
      throw std::runtime_error("Something went wrong in thread 1!");
    }
    std::cout << "Worker " << id << " is working..." << std::endl;
  } catch (const std::exception& e) {
    std::cerr << "Exception in worker " << id << ": " << e.what() << std::endl;
  }
  my_latch.count_down();
}

int main() {
  std::thread t1(worker, 1);
  std::thread t2(worker, 2);

  t1.join();
  t2.join();

  my_latch.wait(); // 等待所有线程完成 (即使有异常)

  std::cout << "Main thread continues." << std::endl;
  return 0;
}

在这个例子中,即使线程 1 抛出了异常,my_latch.count_down() 仍然会被调用,所以 my_latch.wait() 最终会返回,主线程可以继续执行。

std::barrier 和异常处理:

std::barrier 的异常处理稍微复杂一些,因为如果一个线程在 arrive_and_wait() 之前抛出了异常,可能会导致其他线程永远阻塞。

#include <iostream>
#include <thread>
#include <barrier>
#include <exception>

std::barrier my_barrier(2);

void worker(int id) {
  try {
    if (id == 1) {
      throw std::runtime_error("Something went wrong in thread 1!");
    }
    std::cout << "Worker " << id << " is working..." << std::endl;
    my_barrier.arrive_and_wait();
    std::cout << "Worker " << id << " continues..." << std::endl;
  } catch (const std::exception& e) {
    std::cerr << "Exception in worker " << id << ": " << e.what() << std::endl;
    // 关键:调用 arrive_and_drop,让 barrier 继续工作
    my_barrier.arrive_and_drop();
  }
}

int main() {
  std::thread t1(worker, 1);
  std::thread t2(worker, 2);

  t1.join();
  t2.join();

  std::cout << "Main thread continues." << std::endl;
  return 0;
}

关键点:

  • arrive_and_drop() 如果线程在 arrive_and_wait() 之前抛出了异常,应该调用 arrive_and_drop()。这个函数会让 barrier 忽略这个线程,并继续工作。

超时等待:

std::latchstd::barrier 都可以设置超时时间,避免线程永远阻塞。

std::latch 超时:

#include <iostream>
#include <thread>
#include <latch>
#include <chrono>

int main() {
  std::latch my_latch(1);
  std::cout << "Waiting for latch..." << std::endl;
  if (my_latch.wait_for(std::chrono::seconds(2))) {
    std::cout << "Latch released!" << std::endl;
  } else {
    std::cout << "Timeout waiting for latch!" << std::endl;
  }
  return 0;
}

std::barrier 超时:

std::barrier 没有直接的超时等待函数。但是,你可以使用 std::condition_variablestd::mutex 来实现类似的功能。

第五幕:真实案例 – 并行图像处理

假设我们要并行处理一张图像,可以使用 std::barrier 将处理过程分成多个阶段:

  1. 读取图像数据: 每个线程读取图像的一部分数据。
  2. 图像滤波: 每个线程对自己的数据进行滤波处理。
  3. 图像合成: 将所有线程处理后的数据合成一张完整的图像。
#include <iostream>
#include <thread>
#include <vector>
#include <barrier>

// 假设这是一个简单的图像数据结构
struct Image {
  int width;
  int height;
  std::vector<int> pixels; // 简化:每个像素用一个整数表示
};

// 模拟读取图像数据
void read_image_data(Image& image, int thread_id, int num_threads) {
  // ... (根据 thread_id 分配读取任务) ...
  std::cout << "Thread " << thread_id << " read image data." << std::endl;
}

// 模拟图像滤波
void apply_filter(Image& image, int thread_id, int num_threads) {
  // ... (根据 thread_id 分配滤波任务) ...
  std::cout << "Thread " << thread_id << " applied filter." << std::endl;
}

// 模拟图像合成
void compose_image(Image& image, int thread_id, int num_threads) {
  // ... (根据 thread_id 分配合成任务) ...
  std::cout << "Thread " << thread_id << " composed image data." << std::endl;
}

void process_image(Image& image, int thread_id, int num_threads, std::barrier& barrier) {
  read_image_data(image, thread_id, num_threads);
  barrier.arrive_and_wait();

  apply_filter(image, thread_id, num_threads);
  barrier.arrive_and_wait();

  compose_image(image, thread_id, num_threads);
  barrier.arrive_and_wait();
}

int main() {
  const int NUM_THREADS = 4;
  Image my_image{1024, 768, std::vector<int>(1024 * 768)}; // 初始化图像数据

  std::barrier image_barrier(NUM_THREADS, []() {
    std::cout << "Image processing stage completed." << std::endl;
  });

  std::vector<std::thread> threads;
  for (int i = 0; i < NUM_THREADS; ++i) {
    threads.emplace_back(process_image, std::ref(my_image), i, NUM_THREADS, std::ref(image_barrier));
  }

  for (auto& t : threads) {
    t.join();
  }

  std::cout << "Image processing complete." << std::endl;
  return 0;
}

总结:

std::latchstd::barrier 是 C++ 并发编程中非常有用的同步工具。 std::latch 适用于一次性同步,而 std::barrier 适用于多阶段同步。 通过合理地使用它们,你可以编写出更高效、更可靠的并发程序。 记住,并发编程是一门艺术,需要不断地学习和实践才能掌握。

希望今天的讲解对你有所帮助! 祝你编程愉快!

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注