一、背景
项目中我们经常会处理这样一种业务场景。启动多个线程去计算一段业务逻辑,等待所有线程全部执行完毕之后再向下做业务逻辑处理。在java中为我们提供了ExecutorCompletionService可以轻松的实现这样的业务场景。当然,还有其他中办法可以实现,比如使用CountDownLatch也可以达到同样的目的。
二、代码实战
先说一下代码具体实现的思路。 定义一个类MyExecutorCompletionService继承ExecutorCompletionService。并定义submittedTasks表示已经提交的任务,completedTasks表示已经完成的任务数,因为是多线程执行,所以这两个变量定义为AtomicLong类型,以确保线程安全访问。利用Executors创建一个大小为5的固定线程池,模拟启动20个任务执行。每次提交任务都调用MyExecutorCompletionService的submitTask,在submitTask会调用ExecutorCompletionService的submit方法执行任务,并将submittedTasks加1。循环判断任务是否完成,若未完成则调用getEleByTake一直阻塞等待线程完成,并将completedTasks加1。当完成线程数等于完成线程数,则表示所有线程都已经执行完毕。
请看代码:
package concurrent; import java.util.concurrent.BlockingQueue; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorCompletionService; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; public class ExecutorCompletionServiceTest { public static void main(String[] args) throws ExecutionException, InterruptedException { ExecutorCompletionServiceTest executorCompletionServiceTest = new ExecutorCompletionServiceTest(); executorCompletionServiceTest.test(); } private void test() throws ExecutionException, InterruptedException { int numThread = 5; int taskNum = 20; ExecutorService executor = Executors.newFixedThreadPool(numThread); MyExecutorCompletionService myExecutorCompletionService = new MyExecutorCompletionService<String>(executor); for(int i = 0;i<taskNum;i++ ){ myExecutorCompletionService.submitTask(new ExecutorCompletionServiceTest.Task(i)); } while(myExecutorCompletionService.isTasksCompleted()) { System.out.println("blocking================"); /* if(myExecutorCompletionService.getEleByPoll()==null) { continue; }*/ myExecutorCompletionService.getEleByTake(); if(myExecutorCompletionService.completedTasks.get()==myExecutorCompletionService.submittedTasks.get()) { break; } } executor.shutdown(); System.out.println("end==============="); } static class Task implements Callable<String>{ private int i; public Task(int i){ this.i = i; } public String call() throws Exception { System.out.println(Thread.currentThread().getName() + "执行完任务:" + i); return Thread.currentThread().getName() + "执行完任务:" + i; } } class MyExecutorCompletionService<V> extends ExecutorCompletionService<V> { //提交的任务数量 private final AtomicLong submittedTasks = new AtomicLong(); //已经执行完成的任务数量 private final AtomicLong completedTasks = new AtomicLong(); public MyExecutorCompletionService(Executor executor) { super(executor); } public MyExecutorCompletionService(Executor executor, BlockingQueue<Future<V>> queue) { super(executor, queue); } public Future<V> submitTask(Callable<V> task) { Future<V> future = super.submit(task); submittedTasks.incrementAndGet(); System.out.println("submit()===================="); return future; } public Future<V> submitTask(Runnable task, V result) { Future<V> future = super.submit(task, result); submittedTasks.incrementAndGet(); return future; } /** * 阻塞方法,等待返回下一个执行完成任务的Future */ public Future<V> getEleByTake() throws InterruptedException { System.out.println("take()===================="); Future<V> future = super.take(); completedTasks.incrementAndGet(); return future; } /** * 非阻塞方法,如果有执行完成的任务,返回Future,如果无执行完成的任务,返回null; */ public Future<V> getEleByPoll() { Future<V> future = super.poll(); System.out.println("poll()================"); if (future != null) completedTasks.incrementAndGet(); return future; } public Future<V> getEleByPoll(long timeout, TimeUnit unit) throws InterruptedException { Future<V> future = super.poll(timeout, unit); if (future != null) completedTasks.incrementAndGet(); return future; } public long getNumberOfCompletedTasks() { return completedTasks.get(); } public long getNumberOfSubmittedTasks() { return submittedTasks.get(); } public boolean isTasksCompleted() { return completedTasks.get() < submittedTasks.get(); } } }