手写一个高性能的Java线程池:参数调优、任务窃取(Work Stealing)实现
大家好,今天我们来深入探讨如何手写一个高性能的Java线程池,并重点关注参数调优和任务窃取(Work Stealing)的实现。线程池是并发编程中至关重要的组件,它可以有效地管理线程资源,提高程序的性能和稳定性。虽然Java提供了ExecutorService
接口和ThreadPoolExecutor
类,但了解其内部机制并能够自定义线程池,可以让我们更好地掌控并发行为,针对特定场景进行优化。
1. 线程池的基本原理
线程池的核心思想是复用线程。它维护一个线程集合,当有任务需要执行时,从线程池中取出一个空闲线程来执行任务,而不是每次都创建新的线程。任务执行完毕后,线程并不销毁,而是返回到线程池中,等待执行下一个任务。
线程池通常包含以下几个关键组件:
- 任务队列 (Task Queue): 用于存放等待执行的任务。常见的任务队列有
ArrayBlockingQueue
(有界阻塞队列)、LinkedBlockingQueue
(无界阻塞队列)、PriorityBlockingQueue
(优先级队列)等。 - 线程管理器 (Thread Manager): 负责创建、销毁和管理线程。
- 工作线程 (Worker Threads): 执行任务的线程。
- 拒绝策略 (Rejected Execution Handler): 当任务队列已满且线程池中的线程都在忙碌时,用于处理新提交的任务。
2. 手写线程池的基本框架
首先,我们来构建一个基本的线程池框架。这个框架包含核心的线程池管理逻辑,但不包含任务窃取功能,后续我们会逐步完善。
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
public class SimpleThreadPool {
private final int corePoolSize;
private final int maximumPoolSize;
private final long keepAliveTime;
private final BlockingQueue<Runnable> workQueue;
private final ThreadFactory threadFactory;
private final RejectedExecutionHandler rejectedExecutionHandler;
private final Worker[] workers;
private final AtomicInteger workerCount = new AtomicInteger(0);
public SimpleThreadPool(int corePoolSize, int maximumPoolSize, long keepAliveTime,
BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory,
RejectedExecutionHandler rejectedExecutionHandler) {
this.corePoolSize = corePoolSize;
this.maximumPoolSize = maximumPoolSize;
this.keepAliveTime = keepAliveTime;
this.workQueue = workQueue;
this.threadFactory = threadFactory;
this.rejectedExecutionHandler = rejectedExecutionHandler;
this.workers = new Worker[maximumPoolSize]; //初始化worker数组
initializeWorkers();
}
private void initializeWorkers() {
for (int i = 0; i < corePoolSize; i++) {
addWorker(null, false); // 创建核心线程
}
}
public void execute(Runnable task) {
if (task == null) {
throw new NullPointerException();
}
if (workerCount.get() < corePoolSize) {
if (addWorker(task, true)) {
return;
}
}
if (!workQueue.offer(task)) {
if (!addWorker(task, false)) {
rejectedExecutionHandler.rejectedExecution(task, this);
}
}
}
private boolean addWorker(Runnable firstTask, boolean core) {
int wc = workerCount.get();
if (wc >= maximumPoolSize || (core && wc >= corePoolSize)) {
return false;
}
Worker worker = new Worker(firstTask);
Thread thread = worker.thread;
if (thread == null) {
return false; // 创建线程失败
}
workers[wc] = worker;
workerCount.incrementAndGet();
thread.start();
return true;
}
private final class Worker implements Runnable {
private final Thread thread;
private Runnable firstTask;
Worker(Runnable firstTask) {
this.firstTask = firstTask;
this.thread = threadFactory.newThread(this);
}
@Override
public void run() {
Runnable task = firstTask;
try {
while (task != null || (task = getTask()) != null) {
try {
task.run();
} catch (Throwable t) {
// 处理任务执行期间的异常
t.printStackTrace();
} finally {
task = null;
}
}
} finally {
workerDone(this);
}
}
private Runnable getTask() {
try {
return workQueue.take(); // 阻塞式获取任务
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
return null;
}
}
}
private void workerDone(Worker w) {
workerCount.decrementAndGet();
removeWorker(w);
if(workerCount.get() < corePoolSize) {
addWorker(null, false);
}
}
private void removeWorker(Worker w) {
for (int i = 0; i < workers.length; i++) {
if (workers[i] == w) {
workers[i] = null;
break;
}
}
}
// 默认的线程工厂
public static class DefaultThreadFactory implements ThreadFactory {
private static final AtomicInteger poolNumber = new AtomicInteger(1);
private final ThreadGroup group;
private final AtomicInteger threadNumber = new AtomicInteger(1);
private final String namePrefix;
DefaultThreadFactory() {
SecurityManager s = System.getSecurityManager();
group = (s != null) ? s.getThreadGroup() :
Thread.currentThread().getThreadGroup();
namePrefix = "pool-" + poolNumber.getAndIncrement() + "-thread-";
}
@Override
public Thread newThread(Runnable r) {
Thread t = new Thread(group, r,
namePrefix + threadNumber.getAndIncrement(),
0);
if (t.isDaemon())
t.setDaemon(false);
if (t.getPriority() != Thread.NORM_PRIORITY)
t.setPriority(Thread.NORM_PRIORITY);
return t;
}
}
// 默认的拒绝策略
public static class DefaultRejectedExecutionHandler implements RejectedExecutionHandler {
@Override
public void rejectedExecution(Runnable r, SimpleThreadPool executor) {
throw new RuntimeException("Task " + r.toString() +
" rejected from " +
executor.toString());
}
}
@Override
public String toString() {
return "SimpleThreadPool{" +
"corePoolSize=" + corePoolSize +
", maximumPoolSize=" + maximumPoolSize +
", keepAliveTime=" + keepAliveTime +
", workQueue=" + workQueue +
", workerCount=" + workerCount +
'}';
}
public static void main(String[] args) {
SimpleThreadPool threadPool = new SimpleThreadPool(
5, // corePoolSize
10, // maximumPoolSize
60, // keepAliveTime (seconds)
new LinkedBlockingQueue<>(100), // workQueue
new DefaultThreadFactory(), // threadFactory
new DefaultRejectedExecutionHandler() // rejectedExecutionHandler
);
for (int i = 0; i < 20; i++) {
final int taskNumber = i;
threadPool.execute(() -> {
System.out.println("Task " + taskNumber + " is running in thread: " + Thread.currentThread().getName());
try {
Thread.sleep(100);
} catch (InterruptedException e) {
e.printStackTrace();
}
});
}
}
}
这个例子实现了一个简单的线程池,包括了核心线程数,最大线程数,任务队列,线程工厂和拒绝策略。
3. 线程池参数调优
线程池的性能很大程度上取决于其参数的配置。合理的参数设置可以提高线程池的吞吐量、降低延迟,并避免资源浪费。
corePoolSize
(核心线程数): 线程池中保持存活的最小线程数。即使线程处于空闲状态,也不会被回收,除非设置了allowCoreThreadTimeOut
。 如果任务提交速度很快,且任务执行时间很短,可以适当增加corePoolSize
,减少线程创建和销毁的开销。maximumPoolSize
(最大线程数): 线程池中允许存在的最大线程数。当任务队列已满,且当前线程数小于maximumPoolSize
时,线程池会创建新的线程来执行任务。 如果任务是CPU密集型的,maximumPoolSize
可以设置得与CPU核心数相等,甚至略微大于CPU核心数。 如果任务是IO密集型的,maximumPoolSize
可以设置得更大,因为线程在等待IO时不会占用CPU。keepAliveTime
(线程空闲存活时间): 当线程池中的线程数量大于corePoolSize
时,多余的空闲线程在超过keepAliveTime
时间后会被终止。 长时间运行的应用程序可以适当增加keepAliveTime
,避免频繁创建和销毁线程。 对于需要快速响应的应用程序,可以适当缩短keepAliveTime
,及时释放资源。workQueue
(任务队列): 用于存放等待执行的任务。ArrayBlockingQueue
:有界队列,可以防止任务过多导致内存溢出。适用于任务数量可控的场景。LinkedBlockingQueue
:无界队列,如果任务提交速度远大于处理速度,可能导致OOM。PriorityBlockingQueue
:优先级队列,可以根据任务的优先级来执行任务。适用于需要保证高优先级任务优先执行的场景。 选择合适的队列类型需要根据实际的应用场景和任务特性进行权衡。ThreadFactory
(线程工厂): 用于创建新的线程。 可以自定义ThreadFactory
来设置线程的名称、优先级、是否为守护线程等。 默认的DefaultThreadFactory
会创建一个非守护线程,并设置线程的优先级为Thread.NORM_PRIORITY
。RejectedExecutionHandler
(拒绝策略): 当任务队列已满且线程池中的线程都在忙碌时,用于处理新提交的任务。AbortPolicy
:直接抛出RejectedExecutionException
异常。CallerRunsPolicy
:由提交任务的线程来执行该任务。DiscardPolicy
:直接丢弃该任务,不抛出异常。DiscardOldestPolicy
:丢弃队列中最老的任务,然后尝试将新任务加入队列。 可以自定义RejectedExecutionHandler
来处理被拒绝的任务,例如记录日志、持久化任务等。
参数调优示例:
假设我们有一个Web服务器,需要处理大量的HTTP请求。每个请求的处理包括读取数据库、处理业务逻辑和返回响应。
- CPU核心数: 8
- 平均请求处理时间: 100ms
- IO等待时间: 80ms (假设大部分时间都在等待数据库响应)
在这种情况下,我们可以考虑以下参数配置:
参数 | 值 | 说明 |
---|---|---|
corePoolSize |
8 | 与CPU核心数相等,保证CPU的利用率。 |
maximumPoolSize |
32 | 由于存在大量的IO等待时间,可以适当增加maximumPoolSize ,提高并发处理能力。 经验公式: 最佳线程数 = CPU核心数 * (1 + IO耗时 / CPU耗时) , 这里是 8 * (1 + 80/20) = 40 考虑到实际情况,设置为32。 |
keepAliveTime |
60 秒 | 保持线程存活一段时间,避免频繁创建和销毁线程。 |
workQueue |
LinkedBlockingQueue(1000) |
使用有界队列,防止任务过多导致内存溢出。队列的大小需要根据实际的请求量和处理速度进行调整。 |
ThreadFactory |
DefaultThreadFactory |
使用默认的线程工厂。 |
RejectedExecutionHandler |
CallerRunsPolicy |
当任务队列已满时,由提交任务的线程来执行该任务,避免丢失请求。 |
调优注意事项:
- 监控: 在调整线程池参数之前,需要对线程池的运行状态进行监控,例如线程池的活跃线程数、任务队列的长度、拒绝任务的数量等。
- 测试: 在调整线程池参数之后,需要进行压力测试,验证参数的调整是否达到了预期的效果。
- 逐步调整: 避免一次性调整过多的参数,而是应该逐步调整,每次调整一个参数,并观察其对性能的影响。
4. 任务窃取(Work Stealing)的实现
任务窃取是一种用于优化线程池性能的技术,尤其是在工作负载不均衡的情况下。其核心思想是:当一个线程的任务队列为空时,它可以从其他线程的任务队列中“窃取”任务来执行,从而提高整体的资源利用率。
任务窃取的原理:
- 每个线程拥有自己的任务队列(通常是双端队列
Deque
)。 - 线程优先从自己的队列头部获取任务执行(LIFO,Last-In-First-Out)。 这样可以提高缓存命中率。
- 当线程自己的队列为空时,从其他线程的队列尾部窃取任务执行(FIFO,First-In-First-Out)。 这样可以减少任务饥饿的概率。
实现步骤:
- 将任务队列改为
Deque
。 我们需要使用双端队列,以便从队列头部获取任务,从队列尾部窃取任务。 - 修改
Worker
类。Worker
类需要增加窃取任务的逻辑。 - 增加窃取任务的方法。 实现一个
stealTask()
方法,用于从其他线程的队列中窃取任务。
代码实现:
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;
import java.util.Random;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
public class WorkStealingThreadPool {
private final int corePoolSize;
private final int maximumPoolSize;
private final long keepAliveTime;
private final ThreadFactory threadFactory;
private final RejectedExecutionHandler rejectedExecutionHandler;
private final Worker[] workers;
private final AtomicInteger workerCount = new AtomicInteger(0);
private final Random random = new Random();
public WorkStealingThreadPool(int corePoolSize, int maximumPoolSize, long keepAliveTime,
ThreadFactory threadFactory,
RejectedExecutionHandler rejectedExecutionHandler) {
this.corePoolSize = corePoolSize;
this.maximumPoolSize = maximumPoolSize;
this.keepAliveTime = keepAliveTime;
this.threadFactory = threadFactory;
this.rejectedExecutionHandler = rejectedExecutionHandler;
this.workers = new Worker[maximumPoolSize];
initializeWorkers();
}
private void initializeWorkers() {
for (int i = 0; i < corePoolSize; i++) {
addWorker(null, false);
}
}
public void execute(Runnable task) {
if (task == null) {
throw new NullPointerException();
}
int workerIndex = random.nextInt(workerCount.get()); // 随机选择一个worker
Worker worker = workers[workerIndex];
if (worker != null && worker.offer(task)) {
return; // 优先放入随机worker的队列
}
if (workerCount.get() < maximumPoolSize) {
if (addWorker(task, true)) {
return;
}
}
rejectedExecutionHandler.rejectedExecution(task, this); // 实在不行,拒绝任务
}
private boolean addWorker(Runnable firstTask, boolean core) {
int wc = workerCount.get();
if (wc >= maximumPoolSize || (core && wc >= corePoolSize)) {
return false;
}
Worker worker = new Worker(firstTask);
Thread thread = worker.thread;
if (thread == null) {
return false;
}
workers[wc] = worker;
workerCount.incrementAndGet();
thread.start();
return true;
}
private final class Worker implements Runnable {
private final Thread thread;
private final Deque<Runnable> taskQueue = new ArrayDeque<>(); // 使用双端队列
Worker(Runnable firstTask) {
if (firstTask != null) {
taskQueue.offer(firstTask);
}
this.thread = threadFactory.newThread(this);
}
public boolean offer(Runnable task) {
return taskQueue.offer(task);
}
@Override
public void run() {
try {
while (true) {
Runnable task = pollTask(); // 优先从自己的队列获取任务
if (task != null) {
try {
task.run();
} catch (Throwable t) {
t.printStackTrace();
}
continue;
}
task = stealTask(); // 尝试窃取任务
if (task != null) {
try {
task.run();
} catch (Throwable t) {
t.printStackTrace();
}
continue;
}
if (workerCount.get() > corePoolSize) { // 如果线程数量大于核心线程数,且没有任务可做,则退出
if (taskQueue.isEmpty()) {
workerCount.decrementAndGet();
removeWorker(this);
return;
}
}
try {
Thread.sleep(1); // 避免空循环消耗CPU
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
} finally {
removeWorker(this);
}
}
private Runnable pollTask() {
return taskQueue.pollFirst(); // 从队列头部获取任务
}
private Runnable stealTask() {
List<Worker> otherWorkers = IntStream.range(0, workers.length)
.mapToObj(i -> workers[i])
.filter(w -> w != this && w != null)
.collect(Collectors.toList());
if (otherWorkers.isEmpty()) {
return null;
}
int index = random.nextInt(otherWorkers.size());
Worker victim = otherWorkers.get(index);
return victim.taskQueue.pollLast(); // 从队列尾部窃取任务
}
}
private void removeWorker(Worker w) {
for (int i = 0; i < workers.length; i++) {
if (workers[i] == w) {
workers[i] = null;
break;
}
}
}
// 默认的线程工厂
public static class DefaultThreadFactory implements ThreadFactory {
private static final AtomicInteger poolNumber = new AtomicInteger(1);
private final ThreadGroup group;
private final AtomicInteger threadNumber = new AtomicInteger(1);
private final String namePrefix;
DefaultThreadFactory() {
SecurityManager s = System.getSecurityManager();
group = (s != null) ? s.getThreadGroup() :
Thread.currentThread().getThreadGroup();
namePrefix = "pool-" + poolNumber.getAndIncrement() + "-thread-";
}
@Override
public Thread newThread(Runnable r) {
Thread t = new Thread(group, r,
namePrefix + threadNumber.getAndIncrement(),
0);
if (t.isDaemon())
t.setDaemon(false);
if (t.getPriority() != Thread.NORM_PRIORITY)
t.setPriority(Thread.NORM_PRIORITY);
return t;
}
}
// 默认的拒绝策略
public static class DefaultRejectedExecutionHandler implements RejectedExecutionHandler {
@Override
public void rejectedExecution(Runnable r, WorkStealingThreadPool executor) {
throw new RuntimeException("Task " + r.toString() +
" rejected from " +
executor.toString());
}
}
public static void main(String[] args) {
WorkStealingThreadPool threadPool = new WorkStealingThreadPool(
2, // corePoolSize
4, // maximumPoolSize
60, // keepAliveTime (seconds)
new DefaultThreadFactory(), // threadFactory
new DefaultRejectedExecutionHandler() // rejectedExecutionHandler
);
for (int i = 0; i < 10; i++) {
final int taskNumber = i;
threadPool.execute(() -> {
System.out.println("Task " + taskNumber + " is running in thread: " + Thread.currentThread().getName());
try {
Thread.sleep(100 * (taskNumber % 3 + 1)); // 模拟不同的任务执行时间
} catch (InterruptedException e) {
e.printStackTrace();
}
});
}
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
任务窃取的优点:
- 提高资源利用率: 能够充分利用空闲线程,避免资源浪费。
- 减少任务等待时间: 能够更快地执行任务,降低延迟。
- 适应性强: 能够适应工作负载不均衡的情况。
任务窃取的缺点:
- 增加复杂性: 实现任务窃取需要更多的代码和更复杂的逻辑。
- 可能存在竞争: 多个线程可能同时尝试从同一个队列中窃取任务,导致竞争。
- 窃取开销: 窃取任务本身也需要一定的开销,例如获取锁、复制任务等。
任务窃取的适用场景:
- 工作负载不均衡: 某些线程的任务队列为空,而其他线程的任务队列很长。
- 任务执行时间差异大: 某些任务执行时间很短,而其他任务执行时间很长。
- CPU密集型任务: 任务主要消耗CPU资源,线程空闲时间较少。
5. 高性能线程池的关键点
构建高性能线程池需要考虑以下几个关键点:
- 选择合适的任务队列: 根据任务的特性和应用场景选择合适的任务队列,例如
ArrayBlockingQueue
、LinkedBlockingQueue
、PriorityBlockingQueue
等。 - 合理设置线程池参数: 根据系统的负载和资源情况,合理设置
corePoolSize
、maximumPoolSize
、keepAliveTime
等参数。 - 使用任务窃取: 在工作负载不均衡的情况下,使用任务窃取可以提高资源利用率。
- 避免线程饥饿: 确保每个线程都有机会执行任务,避免某些线程长时间处于空闲状态。
- 处理异常: 捕获并处理任务执行期间的异常,避免线程崩溃。
- 监控和调优: 对线程池的运行状态进行监控,并根据监控结果进行调优。
- 使用非阻塞算法: 尽可能使用非阻塞算法来减少线程之间的竞争。例如使用
ConcurrentLinkedDeque
来实现任务队列。
实现WorkStealingThreadPool需要注意
- 使用随机数来选择窃取对象的worker
- 需要处理worker为null的情况
- worker队列的访问要做好同步。例子中因为是单线程访问,所以省略了同步的步骤
如何理解线程池的工作过程
线程池的工作流程可以总结为以下几个步骤:
- 任务提交: 当有新的任务需要执行时,将其提交给线程池。
- 核心线程处理: 如果线程池中的线程数量小于
corePoolSize
,则创建一个新的线程来执行该任务。 - 任务队列缓存: 如果线程池中的线程数量大于等于
corePoolSize
,则将该任务放入任务队列中。 - 扩展线程池: 如果任务队列已满,且线程池中的线程数量小于
maximumPoolSize
,则创建一个新的线程来执行该任务。 - 拒绝任务: 如果任务队列已满,且线程池中的线程数量大于等于
maximumPoolSize
,则执行拒绝策略。 - 线程复用: 当线程执行完任务后,不会被销毁,而是返回到线程池中,等待执行下一个任务。
- 任务窃取: 如果某个线程的任务队列为空,则尝试从其他线程的任务队列中窃取任务来执行。
- 线程回收: 当线程池中的线程数量大于
corePoolSize
时,多余的空闲线程在超过keepAliveTime
时间后会被终止。
总结要点
我们学习了线程池的基本原理、参数调优和任务窃取。
自定义线程池能够更好地控制并发行为,并针对特定场景进行优化,提高程序性能和稳定性。
选择合适的参数和使用任务窃取,可以构建出高性能的Java线程池。