【Java】多线程 - 线程间通信工具 CountDownLatch、CyclicBarrier 和 Phaser 类

Posted by 西维蜀黍 on 2019-03-01, Last Modified on 2021-09-21

本文将介绍常用的线程间通信工具 CountDownLatch、CyclicBarrier 和Phaser 的用法,并结合实例介绍它们各自的适用场景及相同点和不同点。

CountDownLatch 类

CountDownLatch适用场景

Java 多线程编程中经常会碰到这样一种场景——某个线程需要等待一个或多个线程操作结束(或达到某种状态)才继续执行

比如我们需要同时 parallel call 多个 dependency services(这几个dependency call没有依赖关系,也就是可以同时send requests),直到所有的 services都返回之后,才继续执行下面的操作,此时可以通过 CountDownLatch 轻松实现。

这其实类似于 Golang中的 WaitGroup

CountDownLatch实例

import java.util.concurrent.CountDownLatch;

public class Main {
    public static void main(String[] args) throws InterruptedException {
        int totalThread = 3;
        long start = System.currentTimeMillis();
        CountDownLatch countDown = new CountDownLatch(totalThread);
        for(int i = 0; i < totalThread; i++) {
            final String threadName = "Thread " + i;
            new Thread(() -> {
                System.out.println(String.format("%s\t%s %s", System.currentTimeMillis(), threadName, "started"));
                try {
                    if (threadName.equals("Thread 0"))
                        Thread.sleep(1000);
                    else if (threadName.equals("Thread 1"))
                        Thread.sleep(5000);
                    else if (threadName.equals("Thread 2"))
                        Thread.sleep(3000);
                } catch (Exception ex) {
                    ex.printStackTrace();
                }
                countDown.countDown();
                System.out.println(String.format("%s\t%s %s", System.currentTimeMillis(), threadName, "ended"));
            }).start();;
        }
        countDown.await();
        long stop = System.currentTimeMillis();
        System.out.println(String.format("Total time : %sms", (stop - start)));
    }
}

执行结果

1551413475740	Thread 2 started
1551413475740	Thread 0 started
1551413475740	Thread 1 started
1551413476761	Thread 0 ended
1551413478761	Thread 2 ended
Total time : 5076ms
1551413480762	Thread 1 ended

分析

可以看到,三个并发测试线程同时开始工作,当各线程完成测试时(分别耗时 1s、5s 和 3s ),分别调用 countDown.countDown(); 以通知耗时计算线程(这里对应主线程)。当 countDown 对象的计数值为 0 时,意味着全部的并发测试线程均完成工作,耗时计算线程收到通知,并计算耗时(忽略由于线程调度消耗而导致的误差,总测测试时长为 5s )。

CountDownLatch主要接口分析

CountDownLatch 工作原理相对简单,可以简单看成一个倒计数器,在构造方法中指定初始值,每次调用 countDown() 方法时将计数器减1,而调用 await() 会阻塞当前线程直到计数器变为 0 。CountDownLatch 关键接口如下

  • countDown() 如果当前计数器的值大于1,则将其减1;若当前值为1,则将其置为0并唤醒所有通过await等待的线程;若当前值为0,则什么也不做直接返回。
  • await() 等待计数器的值为0,若计数器的值为0则该方法返回;若等待期间该线程被中断,则抛出InterruptedException并清除该线程的中断状态。
  • await(long timeout, TimeUnit unit) 在指定的时间内等待计数器的值为0,若在指定时间内计数器的值变为0,则该方法返回true;若指定时间内计数器的值仍未变为0,则返回false;若指定时间内计数器的值变为0之前当前线程被中断,则抛出InterruptedException并清除该线程的中断状态。
  • getCount() 读取当前计数器的值,一般用于调试或者测试。

CyclicBarrier

CyclicBarrier 适用场景

**内存屏障(Memory Barrier)**能保证屏障之前的代码一定在屏障之后的代码之前被执行。

CyclicBarrier 可以译为循环屏障,也有类似的功能。

CyclicBarrier 可以在构造时指定需要在屏障前执行 await 的个数,所有线程对 await 的调用都导致自己被阻塞,直到调用await的次数达到预定值,此后所有被阻塞的线程都会立即被唤醒。

从使用场景上来说,CyclicBarrier 是让多个线程互相等待某一事件的发生,然后同时被唤醒。而上文讲的 CountDownLatch 是让某一线程等待多个线程的状态,然后该线程被唤醒。

CyclicBarrier 实例

import java.util.concurrent.CyclicBarrier;
import java.util.Date;
import java.util.concurrent.CyclicBarrier;

public class Main {
    public static void main(String[] args) {
        int totalThread = 5;
        CyclicBarrier barrier = new CyclicBarrier(totalThread);

        for(int i = 0; i < totalThread; i++) {
            String threadName = "Thread " + i;
            long start = System.currentTimeMillis();
            new Thread(() -> {
                System.out.println(String.format("%s\t%s %s",System.currentTimeMillis(), threadName, " is waiting"));
                try {
                    if (threadName.equals("Thread 0"))
                        Thread.sleep(1000);
                    else if (threadName.equals("Thread 1"))
                        Thread.sleep(5000);
                    else if (threadName.equals("Thread 2"))
                        Thread.sleep(3000);
                    barrier.await();
                } catch (Exception ex) {
                    ex.printStackTrace();
                }
                System.out.println(String.format("%s\t%s %s", System.currentTimeMillis(), threadName, "ended"));

                if (threadName.equals("Thread 0")) {
                    long stop = System.currentTimeMillis();
                    System.out.println(String.format("Total time : %sms", (stop - start)));
                }
            }).start();

        }
    }
}

执行结果

1551413441177	Thread 1  is waiting
1551413441177	Thread 3  is waiting
1551413441177	Thread 2  is waiting
1551413441177	Thread 0  is waiting
1551413441177	Thread 4  is waiting
Total time : 5076ms
1551413446196	Thread 1 ended
1551413446196	Thread 3 ended
1551413446196	Thread 0 ended
1551413446196	Thread 2 ended
1551413446196	Thread 4 ended

分析

可以看到,五个并发线程同时开始工作,(分别耗时 1s、5s、 3s、0s 和 0s )后各线程到达同步节点,这时均调用 countDown.countDown() 以设置内存屏障。此后被阻塞,阻塞的时长由达到屏障(barrier)最晚的线程决定。直到所有线程均到达屏障(barrier),所有线程恢复执行。

Phaser

Phaser适用场景

CountDownLatch和CyclicBarrier都是JDK 1.5引入的,而Phaser是JDK 1.7引入的。Phaser的功能与CountDownLatch和CyclicBarrier有部分重叠,同时也提供了更丰富的语义和更灵活的用法。

Phaser顾名思义,与阶段相关。Phaser比较适合这样一种场景,一种任务可以分为多个阶段,现希望多个线程去处理该批任务,对于每个阶段,多个线程可以并发进行,但是希望保证只有前面一个阶段的任务完成之后才能开始后面的任务。这种场景可以使用多个CyclicBarrier来实现,每个CyclicBarrier负责等待一个阶段的任务全部完成。但是使用CyclicBarrier的缺点在于,需要明确知道总共有多少个阶段,同时并行的任务数需要提前预定义好,且无法动态修改。而Phaser可同时解决这两个问题。

Phaser实例

public class PhaserDemo {

  public static void main(String[] args) throws IOException {
    int parties = 3;
    int phases = 4;
    final Phaser phaser = new Phaser(parties) {
      @Override  
      protected boolean onAdvance(int phase, int registeredParties) {  
          System.out.println("====== Phase : " + phase + " ======");  
          return registeredParties == 0;  
      }  
    };
    
    for(int i = 0; i < parties; i++) {
      int threadId = i;
      Thread thread = new Thread(() -> {
        for(int phase = 0; phase < phases; phase++) {
          System.out.println(String.format("Thread %s, phase %s", threadId, phase));
          phaser.arriveAndAwaitAdvance();
        }
      });
      thread.start();
    }
  }
}

执行结果

Thread 0, phase 0
Thread 1, phase 0
Thread 2, phase 0
====== Phase : 0 ======
Thread 2, phase 1
Thread 0, phase 1
Thread 1, phase 1
====== Phase : 1 ======
Thread 1, phase 2
Thread 2, phase 2
Thread 0, phase 2
====== Phase : 2 ======
Thread 0, phase 3
Thread 1, phase 3
Thread 2, phase 3
====== Phase : 3 ======

分析

从上面的结果可以看到,多个线程必须等到其它线程的同一阶段的任务全部完成才能进行到下一个阶段,并且每当完成某一阶段任务时,Phaser都会执行其onAdvance方法。

Phaser主要接口分析

Phaser主要接口如下:

  • arriveAndAwaitAdvance() 当前线程当前阶段执行完毕,等待其它线程完成当前阶段。如果当前线程是该阶段最后一个未到达的,则该方法直接返回下一个阶段的序号(阶段序号从0开始),同时其它线程的该方法也返回下一个阶段的序号。
  • arriveAndDeregister() 该方法立即返回下一阶段的序号,并且其它线程需要等待的个数减一,并且把当前线程从之后需要等待的成员中移除。如果该Phaser是另外一个Phaser的子Phaser(层次化Phaser会在后文中讲到),并且该操作导致当前Phaser的成员数为0,则该操作也会将当前Phaser从其父Phaser中移除。
  • arrive() 该方法不作任何等待,直接返回下一阶段的序号。
  • awaitAdvance(int phase) 该方法等待某一阶段执行完毕。如果当前阶段不等于指定的阶段或者该Phaser已经被终止,则立即返回。该阶段数一般由*arrive()方法或者arriveAndDeregister()*方法返回。返回下一阶段的序号,或者返回参数指定的值(如果该参数为负数),或者直接返回当前阶段序号(如果当前Phaser已经被终止)。
  • awaitAdvanceInterruptibly(int phase) 效果与*awaitAdvance(int phase)*相当,唯一的不同在于若该线程在该方法等待时被中断,则该方法抛出InterruptedException
  • awaitAdvanceInterruptibly(int phase, long timeout, TimeUnit unit) 效果与*awaitAdvanceInterruptibly(int phase)*相当,区别在于如果超时则抛出TimeoutException
  • bulkRegister(int parties) 注册多个party。如果当前phaser已经被终止,则该方法无效,并返回负数。如果调用该方法时,onAdvance方法正在执行,则该方法等待其执行完毕。如果该Phaser有父Phaser则指定的party数大于0,且之前该Phaser的party数为0,那么该Phaser会被注册到其父Phaser中。
  • forceTermination() 强制让该Phaser进入终止状态。已经注册的party数不受影响。如果该Phaser有子Phaser,则其所有的子Phaser均进入终止状态。如果该Phaser已经处于终止状态,该方法调用不造成任何影响。

Reference