ForkJoinPool线程池

ForkJoinPool简介

在ExecutorService接口的实现类中有一个是ForkJoinPool,该线程池的思想是分而治之,可以将一个任务分解为多个小任务,然后计算每个小任务,最终再汇总。

工作窃取Work-stealing

当我们向线程池中提交了2个任务,任务1的执行时间较长,任务2的执行较短,这样就会导致一个线程比较忙,另一个线程比较空闲,无法更好的提高cpu的使用率。对于这种场景,可以通过工作窃取算法解决。

工作窃取核心思想:自己的任务执行结束后查看别人是否有未开启的任务,如果有的话,就帮着执行这个任务。

大多数实现机制是:为每个线程分配一个双向队列用于存放需要执行的任务,当自己的队列没有任务的时候会从其它线程队列中获得一个任务继续执行。由于是双向队列,一般是线程自己的本地队列采取LIFO(后进先出),别的线程偷取时采用FIFO(先进先出),一个从头开始执行,一个从尾部开始执行,这样极大的降低了并发安全问题。

ForkJoinPool的使用

创建ForkJoinPool对象的时候,默认创建的线程数量会取决于当前的cpu核数。我们可以向ForkJoinPool中提交Runnable或Callable任务,如果要想利用工作窃取算法的话,需要向ForkJoinPool中提交RecursiveAction或者RecursiveTask任务,这两个类都是ForkJoinTask的子类,在使用上还是有一些区别的。

  • RecursiveAction适合没有返回值的操作,在compute方法中没有返回值。
  • RecursiveTask适合有返回值的操作,在compute方法中有返回值。

下面通过ForkJoinPool来计算1~10之前的数字求和

/*
    计算任务
 */
public class CountTask extends RecursiveTask<Long> {

    //任务分解的阈值
    private int threshold = 5;
    //计算起始值
    private long begin;
    //计算结束值
    private long end;

    public CountTask(long begin, long end) {
        this.begin = begin;
        this.end = end;
    }

    @Override
    protected Long compute() {
        long sum = 0;
        //如果数据少于limit,则无需分解
        if (end - begin <= threshold) {
            for (long i = begin; i <= end; i++) {
                sum += i;
            }
        } else {
            //每个中间值
            long middle = (begin + end) / 2;

            //分解任务
            CountTask leftTask = new CountTask(begin, middle);
            CountTask rightTask = new CountTask(middle + 1, end);

            //执行任务
            invokeAll(leftTask,rightTask);

            //获得计算结果
            Long leftResult = leftTask.join();
            Long rightResult = rightTask.join();
            //将结果求和
            sum = leftResult + rightResult;

        }
        return sum;
    }
}

测试类

public class TestForkJoin {
    public static void main(String[] args) throws ExecutionException, InterruptedException {

        ForkJoinPool forkJoinPool = new ForkJoinPool();
        CountTask task = new CountTask(1L,100L);
        //向线程池中提交任务,获取计算结果
        Long num = forkJoinPool.invoke(task);
        
        System.out.println(num);

        forkJoinPool.shutdown();

    }
}

需要注意的是,上面使用了invokeAll方法来执行任务,网上有些文章中使用的是fork方法,这种用法是有问题的,会造成线程的浪费。

比如向使用invokeAll来执行两个任务A,B,系统会使用当前执行invokeAll的线程来执行A,然后再使用一个新的线程来执行B,这样总共是使用了2个线程。

对A和B两个任务使用fork方法,会占用2个线程,再加上之前线程,总共使用了3个线程。