手写一个高性能的Java线程池:参数调优、任务窃取(Work Stealing)实现

手写一个高性能的Java线程池:参数调优、任务窃取(Work Stealing)实现

大家好,今天我们来深入探讨如何手写一个高性能的Java线程池,并重点关注参数调优和任务窃取(Work Stealing)的实现。线程池是并发编程中至关重要的组件,它可以有效地管理线程资源,提高程序的性能和稳定性。虽然Java提供了ExecutorService接口和ThreadPoolExecutor类,但了解其内部机制并能够自定义线程池,可以让我们更好地掌控并发行为,针对特定场景进行优化。

1. 线程池的基本原理

线程池的核心思想是复用线程。它维护一个线程集合,当有任务需要执行时,从线程池中取出一个空闲线程来执行任务,而不是每次都创建新的线程。任务执行完毕后,线程并不销毁,而是返回到线程池中,等待执行下一个任务。

线程池通常包含以下几个关键组件:

  • 任务队列 (Task Queue): 用于存放等待执行的任务。常见的任务队列有ArrayBlockingQueue(有界阻塞队列)、LinkedBlockingQueue(无界阻塞队列)、PriorityBlockingQueue(优先级队列)等。
  • 线程管理器 (Thread Manager): 负责创建、销毁和管理线程。
  • 工作线程 (Worker Threads): 执行任务的线程。
  • 拒绝策略 (Rejected Execution Handler): 当任务队列已满且线程池中的线程都在忙碌时,用于处理新提交的任务。

2. 手写线程池的基本框架

首先,我们来构建一个基本的线程池框架。这个框架包含核心的线程池管理逻辑,但不包含任务窃取功能,后续我们会逐步完善。

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;

public class SimpleThreadPool {

    private final int corePoolSize;
    private final int maximumPoolSize;
    private final long keepAliveTime;
    private final BlockingQueue<Runnable> workQueue;
    private final ThreadFactory threadFactory;
    private final RejectedExecutionHandler rejectedExecutionHandler;
    private final Worker[] workers;
    private final AtomicInteger workerCount = new AtomicInteger(0);

    public SimpleThreadPool(int corePoolSize, int maximumPoolSize, long keepAliveTime,
                           BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory,
                           RejectedExecutionHandler rejectedExecutionHandler) {
        this.corePoolSize = corePoolSize;
        this.maximumPoolSize = maximumPoolSize;
        this.keepAliveTime = keepAliveTime;
        this.workQueue = workQueue;
        this.threadFactory = threadFactory;
        this.rejectedExecutionHandler = rejectedExecutionHandler;
        this.workers = new Worker[maximumPoolSize]; //初始化worker数组
        initializeWorkers();
    }

    private void initializeWorkers() {
        for (int i = 0; i < corePoolSize; i++) {
            addWorker(null, false); // 创建核心线程
        }
    }

    public void execute(Runnable task) {
        if (task == null) {
            throw new NullPointerException();
        }

        if (workerCount.get() < corePoolSize) {
            if (addWorker(task, true)) {
                return;
            }
        }

        if (!workQueue.offer(task)) {
            if (!addWorker(task, false)) {
                rejectedExecutionHandler.rejectedExecution(task, this);
            }
        }
    }

    private boolean addWorker(Runnable firstTask, boolean core) {
        int wc = workerCount.get();
        if (wc >= maximumPoolSize || (core && wc >= corePoolSize)) {
            return false;
        }

        Worker worker = new Worker(firstTask);
        Thread thread = worker.thread;
        if (thread == null) {
            return false; // 创建线程失败
        }

        workers[wc] = worker;
        workerCount.incrementAndGet();
        thread.start();
        return true;
    }

    private final class Worker implements Runnable {
        private final Thread thread;
        private Runnable firstTask;

        Worker(Runnable firstTask) {
            this.firstTask = firstTask;
            this.thread = threadFactory.newThread(this);
        }

        @Override
        public void run() {
            Runnable task = firstTask;
            try {
                while (task != null || (task = getTask()) != null) {
                    try {
                        task.run();
                    } catch (Throwable t) {
                        // 处理任务执行期间的异常
                        t.printStackTrace();
                    } finally {
                        task = null;
                    }
                }
            } finally {
                workerDone(this);
            }
        }

        private Runnable getTask() {
            try {
                return workQueue.take(); // 阻塞式获取任务
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return null;
            }
        }
    }

    private void workerDone(Worker w) {
        workerCount.decrementAndGet();
        removeWorker(w);
        if(workerCount.get() < corePoolSize) {
             addWorker(null, false);
        }
    }

    private void removeWorker(Worker w) {
        for (int i = 0; i < workers.length; i++) {
            if (workers[i] == w) {
                workers[i] = null;
                break;
            }
        }
    }

    // 默认的线程工厂
    public static class DefaultThreadFactory implements ThreadFactory {
        private static final AtomicInteger poolNumber = new AtomicInteger(1);
        private final ThreadGroup group;
        private final AtomicInteger threadNumber = new AtomicInteger(1);
        private final String namePrefix;

        DefaultThreadFactory() {
            SecurityManager s = System.getSecurityManager();
            group = (s != null) ? s.getThreadGroup() :
                                  Thread.currentThread().getThreadGroup();
            namePrefix = "pool-" + poolNumber.getAndIncrement() + "-thread-";
        }

        @Override
        public Thread newThread(Runnable r) {
            Thread t = new Thread(group, r,
                                  namePrefix + threadNumber.getAndIncrement(),
                                  0);
            if (t.isDaemon())
                t.setDaemon(false);
            if (t.getPriority() != Thread.NORM_PRIORITY)
                t.setPriority(Thread.NORM_PRIORITY);
            return t;
        }
    }

    // 默认的拒绝策略
    public static class DefaultRejectedExecutionHandler implements RejectedExecutionHandler {
        @Override
        public void rejectedExecution(Runnable r, SimpleThreadPool executor) {
            throw new RuntimeException("Task " + r.toString() +
                                       " rejected from " +
                                       executor.toString());
        }
    }

    @Override
    public String toString() {
        return "SimpleThreadPool{" +
               "corePoolSize=" + corePoolSize +
               ", maximumPoolSize=" + maximumPoolSize +
               ", keepAliveTime=" + keepAliveTime +
               ", workQueue=" + workQueue +
               ", workerCount=" + workerCount +
               '}';
    }

    public static void main(String[] args) {
        SimpleThreadPool threadPool = new SimpleThreadPool(
                5, // corePoolSize
                10, // maximumPoolSize
                60, // keepAliveTime (seconds)
                new LinkedBlockingQueue<>(100), // workQueue
                new DefaultThreadFactory(), // threadFactory
                new DefaultRejectedExecutionHandler()  // rejectedExecutionHandler
        );

        for (int i = 0; i < 20; i++) {
            final int taskNumber = i;
            threadPool.execute(() -> {
                System.out.println("Task " + taskNumber + " is running in thread: " + Thread.currentThread().getName());
                try {
                    Thread.sleep(100);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            });
        }
    }
}

这个例子实现了一个简单的线程池,包括了核心线程数,最大线程数,任务队列,线程工厂和拒绝策略。

3. 线程池参数调优

线程池的性能很大程度上取决于其参数的配置。合理的参数设置可以提高线程池的吞吐量、降低延迟,并避免资源浪费。

  • corePoolSize (核心线程数): 线程池中保持存活的最小线程数。即使线程处于空闲状态,也不会被回收,除非设置了allowCoreThreadTimeOut。 如果任务提交速度很快,且任务执行时间很短,可以适当增加corePoolSize,减少线程创建和销毁的开销。
  • maximumPoolSize (最大线程数): 线程池中允许存在的最大线程数。当任务队列已满,且当前线程数小于maximumPoolSize时,线程池会创建新的线程来执行任务。 如果任务是CPU密集型的,maximumPoolSize可以设置得与CPU核心数相等,甚至略微大于CPU核心数。 如果任务是IO密集型的,maximumPoolSize可以设置得更大,因为线程在等待IO时不会占用CPU。
  • keepAliveTime (线程空闲存活时间): 当线程池中的线程数量大于corePoolSize时,多余的空闲线程在超过keepAliveTime时间后会被终止。 长时间运行的应用程序可以适当增加keepAliveTime,避免频繁创建和销毁线程。 对于需要快速响应的应用程序,可以适当缩短keepAliveTime,及时释放资源。
  • workQueue (任务队列): 用于存放等待执行的任务。 ArrayBlockingQueue:有界队列,可以防止任务过多导致内存溢出。适用于任务数量可控的场景。 LinkedBlockingQueue:无界队列,如果任务提交速度远大于处理速度,可能导致OOM。 PriorityBlockingQueue:优先级队列,可以根据任务的优先级来执行任务。适用于需要保证高优先级任务优先执行的场景。 选择合适的队列类型需要根据实际的应用场景和任务特性进行权衡。
  • ThreadFactory (线程工厂): 用于创建新的线程。 可以自定义ThreadFactory来设置线程的名称、优先级、是否为守护线程等。 默认的DefaultThreadFactory会创建一个非守护线程,并设置线程的优先级为Thread.NORM_PRIORITY
  • RejectedExecutionHandler (拒绝策略): 当任务队列已满且线程池中的线程都在忙碌时,用于处理新提交的任务。 AbortPolicy:直接抛出RejectedExecutionException异常。 CallerRunsPolicy:由提交任务的线程来执行该任务。 DiscardPolicy:直接丢弃该任务,不抛出异常。 DiscardOldestPolicy:丢弃队列中最老的任务,然后尝试将新任务加入队列。 可以自定义RejectedExecutionHandler来处理被拒绝的任务,例如记录日志、持久化任务等。

参数调优示例:

假设我们有一个Web服务器,需要处理大量的HTTP请求。每个请求的处理包括读取数据库、处理业务逻辑和返回响应。

  • CPU核心数: 8
  • 平均请求处理时间: 100ms
  • IO等待时间: 80ms (假设大部分时间都在等待数据库响应)

在这种情况下,我们可以考虑以下参数配置:

参数 说明
corePoolSize 8 与CPU核心数相等,保证CPU的利用率。
maximumPoolSize 32 由于存在大量的IO等待时间,可以适当增加maximumPoolSize,提高并发处理能力。 经验公式: 最佳线程数 = CPU核心数 * (1 + IO耗时 / CPU耗时), 这里是 8 * (1 + 80/20) = 40 考虑到实际情况,设置为32。
keepAliveTime 60 秒 保持线程存活一段时间,避免频繁创建和销毁线程。
workQueue LinkedBlockingQueue(1000) 使用有界队列,防止任务过多导致内存溢出。队列的大小需要根据实际的请求量和处理速度进行调整。
ThreadFactory DefaultThreadFactory 使用默认的线程工厂。
RejectedExecutionHandler CallerRunsPolicy 当任务队列已满时,由提交任务的线程来执行该任务,避免丢失请求。

调优注意事项:

  • 监控: 在调整线程池参数之前,需要对线程池的运行状态进行监控,例如线程池的活跃线程数、任务队列的长度、拒绝任务的数量等。
  • 测试: 在调整线程池参数之后,需要进行压力测试,验证参数的调整是否达到了预期的效果。
  • 逐步调整: 避免一次性调整过多的参数,而是应该逐步调整,每次调整一个参数,并观察其对性能的影响。

4. 任务窃取(Work Stealing)的实现

任务窃取是一种用于优化线程池性能的技术,尤其是在工作负载不均衡的情况下。其核心思想是:当一个线程的任务队列为空时,它可以从其他线程的任务队列中“窃取”任务来执行,从而提高整体的资源利用率。

任务窃取的原理:

  1. 每个线程拥有自己的任务队列(通常是双端队列Deque)。
  2. 线程优先从自己的队列头部获取任务执行(LIFO,Last-In-First-Out)。 这样可以提高缓存命中率。
  3. 当线程自己的队列为空时,从其他线程的队列尾部窃取任务执行(FIFO,First-In-First-Out)。 这样可以减少任务饥饿的概率。

实现步骤:

  1. 将任务队列改为Deque 我们需要使用双端队列,以便从队列头部获取任务,从队列尾部窃取任务。
  2. 修改Worker类。 Worker类需要增加窃取任务的逻辑。
  3. 增加窃取任务的方法。 实现一个stealTask()方法,用于从其他线程的队列中窃取任务。

代码实现:

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;
import java.util.Random;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class WorkStealingThreadPool {

    private final int corePoolSize;
    private final int maximumPoolSize;
    private final long keepAliveTime;
    private final ThreadFactory threadFactory;
    private final RejectedExecutionHandler rejectedExecutionHandler;
    private final Worker[] workers;
    private final AtomicInteger workerCount = new AtomicInteger(0);
    private final Random random = new Random();

    public WorkStealingThreadPool(int corePoolSize, int maximumPoolSize, long keepAliveTime,
                                ThreadFactory threadFactory,
                                RejectedExecutionHandler rejectedExecutionHandler) {
        this.corePoolSize = corePoolSize;
        this.maximumPoolSize = maximumPoolSize;
        this.keepAliveTime = keepAliveTime;
        this.threadFactory = threadFactory;
        this.rejectedExecutionHandler = rejectedExecutionHandler;
        this.workers = new Worker[maximumPoolSize];
        initializeWorkers();
    }

    private void initializeWorkers() {
        for (int i = 0; i < corePoolSize; i++) {
            addWorker(null, false);
        }
    }

    public void execute(Runnable task) {
        if (task == null) {
            throw new NullPointerException();
        }

        int workerIndex = random.nextInt(workerCount.get()); // 随机选择一个worker
        Worker worker = workers[workerIndex];
        if (worker != null && worker.offer(task)) {
            return; // 优先放入随机worker的队列
        }

        if (workerCount.get() < maximumPoolSize) {
            if (addWorker(task, true)) {
                return;
            }
        }

        rejectedExecutionHandler.rejectedExecution(task, this); // 实在不行,拒绝任务
    }

    private boolean addWorker(Runnable firstTask, boolean core) {
        int wc = workerCount.get();
        if (wc >= maximumPoolSize || (core && wc >= corePoolSize)) {
            return false;
        }

        Worker worker = new Worker(firstTask);
        Thread thread = worker.thread;
        if (thread == null) {
            return false;
        }

        workers[wc] = worker;
        workerCount.incrementAndGet();
        thread.start();
        return true;
    }

    private final class Worker implements Runnable {
        private final Thread thread;
        private final Deque<Runnable> taskQueue = new ArrayDeque<>(); // 使用双端队列

        Worker(Runnable firstTask) {
            if (firstTask != null) {
                taskQueue.offer(firstTask);
            }
            this.thread = threadFactory.newThread(this);
        }

        public boolean offer(Runnable task) {
            return taskQueue.offer(task);
        }

        @Override
        public void run() {
            try {
                while (true) {
                    Runnable task = pollTask(); // 优先从自己的队列获取任务
                    if (task != null) {
                        try {
                            task.run();
                        } catch (Throwable t) {
                            t.printStackTrace();
                        }
                        continue;
                    }

                    task = stealTask(); // 尝试窃取任务
                    if (task != null) {
                        try {
                            task.run();
                        } catch (Throwable t) {
                            t.printStackTrace();
                        }
                        continue;
                    }

                    if (workerCount.get() > corePoolSize) { // 如果线程数量大于核心线程数,且没有任务可做,则退出
                        if (taskQueue.isEmpty()) {
                            workerCount.decrementAndGet();
                            removeWorker(this);
                            return;
                        }
                    }

                    try {
                        Thread.sleep(1); // 避免空循环消耗CPU
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                    }
                }
            } finally {
                removeWorker(this);
            }
        }

        private Runnable pollTask() {
            return taskQueue.pollFirst(); // 从队列头部获取任务
        }

        private Runnable stealTask() {
            List<Worker> otherWorkers = IntStream.range(0, workers.length)
                    .mapToObj(i -> workers[i])
                    .filter(w -> w != this && w != null)
                    .collect(Collectors.toList());

            if (otherWorkers.isEmpty()) {
                return null;
            }

            int index = random.nextInt(otherWorkers.size());
            Worker victim = otherWorkers.get(index);
            return victim.taskQueue.pollLast(); // 从队列尾部窃取任务
        }
    }

     private void removeWorker(Worker w) {
        for (int i = 0; i < workers.length; i++) {
            if (workers[i] == w) {
                workers[i] = null;
                break;
            }
        }
    }

    // 默认的线程工厂
    public static class DefaultThreadFactory implements ThreadFactory {
        private static final AtomicInteger poolNumber = new AtomicInteger(1);
        private final ThreadGroup group;
        private final AtomicInteger threadNumber = new AtomicInteger(1);
        private final String namePrefix;

        DefaultThreadFactory() {
            SecurityManager s = System.getSecurityManager();
            group = (s != null) ? s.getThreadGroup() :
                                  Thread.currentThread().getThreadGroup();
            namePrefix = "pool-" + poolNumber.getAndIncrement() + "-thread-";
        }

        @Override
        public Thread newThread(Runnable r) {
            Thread t = new Thread(group, r,
                                  namePrefix + threadNumber.getAndIncrement(),
                                  0);
            if (t.isDaemon())
                t.setDaemon(false);
            if (t.getPriority() != Thread.NORM_PRIORITY)
                t.setPriority(Thread.NORM_PRIORITY);
            return t;
        }
    }

    // 默认的拒绝策略
    public static class DefaultRejectedExecutionHandler implements RejectedExecutionHandler {
        @Override
        public void rejectedExecution(Runnable r, WorkStealingThreadPool executor) {
            throw new RuntimeException("Task " + r.toString() +
                                       " rejected from " +
                                       executor.toString());
        }
    }

    public static void main(String[] args) {
        WorkStealingThreadPool threadPool = new WorkStealingThreadPool(
                2, // corePoolSize
                4, // maximumPoolSize
                60, // keepAliveTime (seconds)
                new DefaultThreadFactory(), // threadFactory
                new DefaultRejectedExecutionHandler()  // rejectedExecutionHandler
        );

        for (int i = 0; i < 10; i++) {
            final int taskNumber = i;
            threadPool.execute(() -> {
                System.out.println("Task " + taskNumber + " is running in thread: " + Thread.currentThread().getName());
                try {
                    Thread.sleep(100 * (taskNumber % 3 + 1)); // 模拟不同的任务执行时间
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            });
        }

        try {
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
}

任务窃取的优点:

  • 提高资源利用率: 能够充分利用空闲线程,避免资源浪费。
  • 减少任务等待时间: 能够更快地执行任务,降低延迟。
  • 适应性强: 能够适应工作负载不均衡的情况。

任务窃取的缺点:

  • 增加复杂性: 实现任务窃取需要更多的代码和更复杂的逻辑。
  • 可能存在竞争: 多个线程可能同时尝试从同一个队列中窃取任务,导致竞争。
  • 窃取开销: 窃取任务本身也需要一定的开销,例如获取锁、复制任务等。

任务窃取的适用场景:

  • 工作负载不均衡: 某些线程的任务队列为空,而其他线程的任务队列很长。
  • 任务执行时间差异大: 某些任务执行时间很短,而其他任务执行时间很长。
  • CPU密集型任务: 任务主要消耗CPU资源,线程空闲时间较少。

5. 高性能线程池的关键点

构建高性能线程池需要考虑以下几个关键点:

  • 选择合适的任务队列: 根据任务的特性和应用场景选择合适的任务队列,例如ArrayBlockingQueueLinkedBlockingQueuePriorityBlockingQueue等。
  • 合理设置线程池参数: 根据系统的负载和资源情况,合理设置corePoolSizemaximumPoolSizekeepAliveTime等参数。
  • 使用任务窃取: 在工作负载不均衡的情况下,使用任务窃取可以提高资源利用率。
  • 避免线程饥饿: 确保每个线程都有机会执行任务,避免某些线程长时间处于空闲状态。
  • 处理异常: 捕获并处理任务执行期间的异常,避免线程崩溃。
  • 监控和调优: 对线程池的运行状态进行监控,并根据监控结果进行调优。
  • 使用非阻塞算法: 尽可能使用非阻塞算法来减少线程之间的竞争。例如使用ConcurrentLinkedDeque来实现任务队列。

实现WorkStealingThreadPool需要注意

  • 使用随机数来选择窃取对象的worker
  • 需要处理worker为null的情况
  • worker队列的访问要做好同步。例子中因为是单线程访问,所以省略了同步的步骤

如何理解线程池的工作过程

线程池的工作流程可以总结为以下几个步骤:

  1. 任务提交: 当有新的任务需要执行时,将其提交给线程池。
  2. 核心线程处理: 如果线程池中的线程数量小于corePoolSize,则创建一个新的线程来执行该任务。
  3. 任务队列缓存: 如果线程池中的线程数量大于等于corePoolSize,则将该任务放入任务队列中。
  4. 扩展线程池: 如果任务队列已满,且线程池中的线程数量小于maximumPoolSize,则创建一个新的线程来执行该任务。
  5. 拒绝任务: 如果任务队列已满,且线程池中的线程数量大于等于maximumPoolSize,则执行拒绝策略。
  6. 线程复用: 当线程执行完任务后,不会被销毁,而是返回到线程池中,等待执行下一个任务。
  7. 任务窃取: 如果某个线程的任务队列为空,则尝试从其他线程的任务队列中窃取任务来执行。
  8. 线程回收: 当线程池中的线程数量大于corePoolSize时,多余的空闲线程在超过keepAliveTime时间后会被终止。

总结要点

我们学习了线程池的基本原理、参数调优和任务窃取。
自定义线程池能够更好地控制并发行为,并针对特定场景进行优化,提高程序性能和稳定性。
选择合适的参数和使用任务窃取,可以构建出高性能的Java线程池。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注