一文搞懂 CountDownLatch 用法和原始碼!

程式設計師cxuan發表於2020-12-21

CountDownLatch 是多執行緒控制的一種工具,它被稱為 門閥計數器或者 閉鎖。這個工具經常用來用來協調多個執行緒之間的同步,或者說起到執行緒之間的通訊(而不是用作互斥的作用)。下面我們就來一起認識一下 CountDownLatch

我把自己以往的文章彙總成為了 Github ,歡迎各位大佬 star
https://github.com/crisxuan/bestJavaer

認識 CountDownLatch

CountDownLatch 能夠使一個執行緒在等待另外一些執行緒完成各自工作之後,再繼續執行。它相當於是一個計數器,這個計數器的初始值就是執行緒的數量,每當一個任務完成後,計數器的值就會減一,當計數器的值為 0 時,表示所有的執行緒都已經任務了,然後在 CountDownLatch 上等待的執行緒就可以恢復執行接下來的任務。

CountDownLatch 的使用

CountDownLatch 提供了一個構造方法,你必須指定其初始值,還指定了 countDown 方法,這個方法的作用主要用來減小計數器的值,當計數器變為 0 時,在 CountDownLatch 上 await 的執行緒就會被喚醒,繼續執行其他任務。當然也可以延遲喚醒,給 CountDownLatch 加一個延遲時間就可以實現。

其主要方法如下

CountDownLatch 主要有下面這幾個應用場景

CountDownLatch 應用場景

典型的應用場景就是當一個服務啟動時,同時會載入很多元件和服務,這時候主執行緒會等待元件和服務的載入。當所有的元件和服務都載入完畢後,主執行緒和其他執行緒在一起完成某個任務。

CountDownLatch 還可以實現學生一起比賽跑步的程式,CountDownLatch 初始化為學生數量的執行緒,鳴槍後,每個學生就是一條執行緒,來完成各自的任務,當第一個學生跑完全程後,CountDownLatch 就會減一,直到所有的學生完成後,CountDownLatch 會變為 0 ,接下來再一起宣佈跑步成績。

順著這個場景,你自己就可以延伸、擴充出來很多其他任務場景。

CountDownLatch 用法

下面我們通過一個簡單的計數器來演示一下 CountDownLatch 的用法

public class TCountDownLatch {

    public static void main(String[] args) {
        CountDownLatch latch = new CountDownLatch(5);
        Increment increment = new Increment(latch);
        Decrement decrement = new Decrement(latch);

        new Thread(increment).start();
        new Thread(decrement).start();

        try {
            Thread.sleep(6000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
}

class Decrement implements Runnable {

    CountDownLatch countDownLatch;

    public Decrement(CountDownLatch countDownLatch){
        this.countDownLatch = countDownLatch;
    }

    @Override
    public void run() {
        try {

            for(long i = countDownLatch.getCount();i > 0;i--){
                Thread.sleep(1000);
                System.out.println("countdown");
                this.countDownLatch.countDown();
            }

        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
}


class Increment implements Runnable {

    CountDownLatch countDownLatch;

    public Increment(CountDownLatch countDownLatch){
        this.countDownLatch = countDownLatch;
    }

    @Override
    public void run() {
        try {
            System.out.println("await");
            countDownLatch.await();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.println("Waiter Released");
    }
}

在 main 方法中我們初始化了一個計數器為 5 的 CountDownLatch,在 Decrement 方法中我們使用 countDown 執行減一操作,然後睡眠一段時間,同時在 Increment 類中進行等待,直到 Decrement 中的執行緒完成計數減一的操作後,喚醒 Increment 類中的 run 方法,使其繼續執行。

下面我們再來通過學生賽跑這個例子來演示一下 CountDownLatch 的具體用法

public class StudentRunRace {

    CountDownLatch stopLatch = new CountDownLatch(1);
    CountDownLatch runLatch = new CountDownLatch(10);

    public void waitSignal() throws Exception{
        System.out.println("選手" + Thread.currentThread().getName() + "正在等待裁判釋出口令");
        stopLatch.await();
        System.out.println("選手" + Thread.currentThread().getName() + "已接受裁判口令");
        Thread.sleep((long) (Math.random() * 10000));
        System.out.println("選手" + Thread.currentThread().getName() + "到達終點");
        runLatch.countDown();
    }

    public void waitStop() throws Exception{
        Thread.sleep((long) (Math.random() * 10000));
        System.out.println("裁判"+Thread.currentThread().getName()+"即將釋出口令");
        stopLatch.countDown();
        System.out.println("裁判"+Thread.currentThread().getName()+"已傳送口令,正在等待所有選手到達終點");
        runLatch.await();
        System.out.println("所有選手都到達終點");
        System.out.println("裁判"+Thread.currentThread().getName()+"彙總成績排名");
    }

    public static void main(String[] args) {
        ExecutorService service = Executors.newCachedThreadPool();
        StudentRunRace studentRunRace = new StudentRunRace();
        for (int i = 0; i < 10; i++) {
            Runnable runnable = () -> {
                try {
                    studentRunRace.waitSignal();
                } catch (Exception e) {
                    e.printStackTrace();
                }
            };
            service.execute(runnable);
        }
        try {
            studentRunRace.waitStop();
        } catch (Exception e) {
            e.printStackTrace();
        }
        service.shutdown();
    }
}

下面我們就來一起分析一下 CountDownLatch 的原始碼

CountDownLatch 原始碼分析

CountDownLatch 使用起來比較簡單,但是卻非常有用,現在你可以在你的工具箱中加上 CountDownLatch 這個工具類了。下面我們就來深入認識一下 CountDownLatch。

CountDownLatch 的底層是由 AbstractQueuedSynchronizer 支援,而 AQS 的資料結構的核心就是兩個佇列,一個是 同步佇列(sync queue),一個是條件佇列(condition queue)

Sync 內部類

CountDownLatch 在其內部是一個 Sync ,它繼承了 AQS 抽象類。

private static final class Sync extends AbstractQueuedSynchronizer {...}

CountDownLatch 其實其內部只有一個 sync 屬性,並且是 final 的

private final Sync sync;

CountDownLatch 只有一個帶引數的構造方法

public CountDownLatch(int count) {
  if (count < 0) throw new IllegalArgumentException("count < 0");
  this.sync = new Sync(count);
}

也就是說,初始化的時候必須指定計數器的數量,如果數量為負會直接丟擲異常。

然後把 count 初始化為 Sync 內部的 count,也就是

Sync(int count) {
  setState(count);
}

注意這裡有一個 setState(count),這是什麼意思呢?見聞知意這只是一個設定狀態的操作,但是實際上不單單是,還有一層意思是 state 的值代表著待達到條件的執行緒數。這個我們在聊 countDown 方法的時候再討論。

getCount() 方法的返回值是 getState() 方法,它是 AbstractQueuedSynchronizer 中的方法,這個方法會返回當前執行緒計數,具有 volatile 讀取的記憶體語義。

// ---- CountDownLatch ----

int getCount() {
  return getState();
}

// ---- AbstractQueuedSynchronizer ----

protected final int getState() {
  return state;
}

tryAcquireShared() 方法用於獲取·共享狀態下物件的狀態,判斷物件是否為 0 ,如果為 0 返回 1 ,表示能夠嘗試獲取,如果不為 0,那麼返回 -1,表示無法獲取。

protected int tryAcquireShared(int acquires) {
  return (getState() == 0) ? 1 : -1;
}

// ----  getState() 方法和上面的方法相同 ----

這個 共享狀態 屬於 AQS 中的概念,在 AQS 中分為兩種模式,一種是 獨佔模式,一種是 共享模式

  • tryAcquire 獨佔模式,嘗試獲取資源,成功則返回 true,失敗則返回 false。
  • tryAcquireShared 共享方式,嘗試獲取資源。負數表示失敗;0 表示成功,但沒有剩餘可用資源;正數表示成功,且有剩餘資源。

tryReleaseShared() 方法用於共享模式下的釋放

protected boolean tryReleaseShared(int releases) {
  // 減小數量,變為 0 的時候進行通知。
  for (;;) {
    int c = getState();
    if (c == 0)
      return false;
    int nextc = c-1;
    if (compareAndSetState(c, nextc))
      return nextc == 0;
  }
}

這個方法是一個無限迴圈,獲取執行緒狀態,如果執行緒狀態是 0 則表示沒有被執行緒佔有,沒有佔有的話那麼直接返回 false ,表示已經釋放;然後下一個狀態進行 - 1 ,使用 compareAndSetState CAS 方法進行和記憶體值的比較,如果記憶體值也是 1 的話,就會更新記憶體值為 0 ,判斷 nextc 是否為 0 ,如果 CAS 比較不成功的話,會再次進行迴圈判斷。

如果 CAS 用法不清楚的話,讀者朋友們可以參考這篇文章 告訴你一個 AtomicInteger 的驚天大祕密!

await 方法

await() 方法是 CountDownLatch 一個非常重要的方法,基本上可以說只有 countDown 和 await 方法才是 CountDownLatch 的精髓所在,這個方法將會使當前執行緒在 CountDownLatch 計數減至零之前一直等待,除非執行緒被中斷。

CountDownLatch 中的 await 方法有兩種,一種是不帶任何引數的 await(),一種是可以等待一段時間的await(long timeout, TimeUnit unit)。下面我們先來看一下 await() 方法。

public void await() throws InterruptedException {
  sync.acquireSharedInterruptibly(1);
}

await 方法內部會呼叫 acquireSharedInterruptibly 方法,這個 acquireSharedInterruptibly 是 AQS 中的方法,以共享模式進行中斷。

public final void acquireSharedInterruptibly(int arg)
  throws InterruptedException {
  if (Thread.interrupted())
    throw new InterruptedException();
  if (tryAcquireShared(arg) < 0)
    doAcquireSharedInterruptibly(arg);
}

可以看到,acquireSharedInterruptibly 方法的內部會首先判斷執行緒是否中斷,如果執行緒中斷,則直接丟擲執行緒中斷異常。如果沒有中斷,那麼會以共享的方式獲取。如果能夠在共享的方式下不能獲取鎖,那麼就會以共享的方式斷開連結。

private void doAcquireSharedInterruptibly(int arg)
  throws InterruptedException {
  final Node node = addWaiter(Node.SHARED);
  boolean failed = true;
  try {
    for (;;) {
      final Node p = node.predecessor();
      if (p == head) {
        int r = tryAcquireShared(arg);
        if (r >= 0) {
          setHeadAndPropagate(node, r);
          p.next = null; // help GC
          failed = false;
          return;
        }
      }
      if (shouldParkAfterFailedAcquire(p, node) &&
          parkAndCheckInterrupt())
        throw new InterruptedException();
    }
  } finally {
    if (failed)
      cancelAcquire(node);
  }
}

這個方法有些長,我們分開來看

  • 首先,會先構造一個共享模式的 Node 入隊
  • 然後使用無限迴圈判斷新構造 node 的前驅節點,如果 node 節點的前驅節點是頭節點,那麼就會判斷執行緒的狀態,這裡呼叫了一個 setHeadAndPropagate ,其原始碼如下
private void setHeadAndPropagate(Node node, int propagate) {
  Node h = head; 
  setHead(node);
  if (propagate > 0 || h == null || h.waitStatus < 0 ||
      (h = head) == null || h.waitStatus < 0) {
    Node s = node.next;
    if (s == null || s.isShared())
      doReleaseShared();
  }
}

首先會設定頭節點,然後進行一系列的判斷,獲取節點的獲取節點的後繼,以共享模式進行釋放,就會呼叫 doReleaseShared 方法,我們再來看一下 doReleaseShared 方法

private void doReleaseShared() {

  for (;;) {
    Node h = head;
    if (h != null && h != tail) {
      int ws = h.waitStatus;
      if (ws == Node.SIGNAL) {
        if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
          continue;            // loop to recheck cases
        unparkSuccessor(h);
      }
      else if (ws == 0 &&
               !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
        continue;                // loop on failed CAS
    }
    if (h == head)                   // loop if head changed
      break;
  }
}

這個方法會以無限迴圈的方式首先判斷頭節點是否等於尾節點,如果頭節點等於尾節點的話,就會直接退出。如果頭節點不等於尾節點,會判斷狀態是否為 SIGNAL,不是的話就繼續迴圈 compareAndSetWaitStatus,然後斷開後繼節點。如果狀態不是 SIGNAL,也會呼叫 compareAndSetWaitStatus 設定狀態為 PROPAGATE,狀態為 0 並且不成功,就會繼續迴圈。

也就是說 setHeadAndPropagate 就是設定頭節點並且釋放後繼節點的一系列過程。

  • 我們來看下面的 if 判斷,也就是 shouldParkAfterFailedAcquire(p, node) 這裡
if (shouldParkAfterFailedAcquire(p, node) &&
    parkAndCheckInterrupt())
  throw new InterruptedException();

如果上面 Node p = node.predecessor() 獲取前驅節點不是頭節點,就會進行 park 斷開操作,判斷此時是否能夠斷開,判斷的標準如下

private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
  int ws = pred.waitStatus;
  if (ws == Node.SIGNAL)
    return true;
  if (ws > 0) {
    do {
      node.prev = pred = pred.prev;
    } while (pred.waitStatus > 0);
    pred.next = node;
  } else {
    compareAndSetWaitStatus(pred, ws, Node.SIGNAL);
  }
  return false;
}

這個方法會判斷 Node p 的前驅節點的結點狀態(waitStatus),節點狀態一共有五種,分別是

  1. CANCELLED(1):表示當前結點已取消排程。當超時或被中斷(響應中斷的情況下),會觸發變更為此狀態,進入該狀態後的結點將不會再變化。

  2. SIGNAL(-1):表示後繼結點在等待當前結點喚醒。後繼結點入隊時,會將前繼結點的狀態更新為 SIGNAL。

  3. CONDITION(-2):表示結點等待在 Condition 上,當其他執行緒呼叫了 Condition 的 signal() 方法後,CONDITION狀態的結點將從等待佇列轉移到同步佇列中,等待獲取同步鎖。

  4. PROPAGATE(-3):共享模式下,前繼結點不僅會喚醒其後繼結點,同時也可能會喚醒後繼的後繼結點。

  5. 0:新結點入隊時的預設狀態。

如果前驅節點是 SIGNAL 就會返回 true 表示可以斷開,如果前驅節點的狀態大於 0 (此時為什麼不用 ws == Node.CANCELLED ) 呢?因為 ws 大於 0 的條件只有 CANCELLED 狀態了。然後就是一系列的查詢遍歷操作直到前驅節點的 waitStatus > 0。如果 ws <= 0 ,而且還不是 SIGNAL 狀態的話,就會使用 CAS 替換前驅節點的 ws 為 SIGNAL 狀態。

如果檢查判斷是中斷狀態的話,就會返回 false。

private final boolean parkAndCheckInterrupt() {
  LockSupport.park(this);
  return Thread.interrupted();
}

這個方法使用 LockSupport.park 斷開連線,然後返回執行緒是否中斷的標誌。

  • cancelAcquire() 用於取消等待佇列,如果等待過程中沒有成功獲取資源(如timeout,或者可中斷的情況下被中斷了),那麼取消結點在佇列中的等待。
private void cancelAcquire(Node node) {
  if (node == null)
    return;

  node.thread = null;
  
  Node pred = node.prev;
  while (pred.waitStatus > 0)
    node.prev = pred = pred.prev;

  Node predNext = pred.next;

  node.waitStatus = Node.CANCELLED;

  if (node == tail && compareAndSetTail(node, pred)) {
    compareAndSetNext(pred, predNext, null);
  } else {
    int ws;
    if (pred != head &&
        ((ws = pred.waitStatus) == Node.SIGNAL ||
         (ws <= 0 && compareAndSetWaitStatus(pred, ws, Node.SIGNAL))) &&
        pred.thread != null) {
      Node next = node.next;
      if (next != null && next.waitStatus <= 0)
        compareAndSetNext(pred, predNext, next);
    } else {
      unparkSuccessor(node);
    }
    node.next = node; // help GC
  }
}

所以,對 CountDownLatch 的 await 呼叫大致會有如下的呼叫過程。

一個和 await 過載的方法是 await(long timeout, TimeUnit unit),這個方法和 await 最主要的區別就是這個方法能夠可以等待計數器一段時間再執行後續操作。

countDown 方法

countDown 是和 await 同等重要的方法,countDown 用於減少計數器的數量,如果計數減為 0 的話,就會釋放所有的執行緒。

public void countDown() {
  sync.releaseShared(1);
}

這個方法會呼叫 releaseShared 方法,此方法用於共享模式下的釋放操作,首先會判斷是否能夠進行釋放,判斷的方法就是 CountDownLatch 內部類 Sync 的 tryReleaseShared 方法

public final boolean releaseShared(int arg) {
  if (tryReleaseShared(arg)) {
    doReleaseShared();
    return true;
  }
  return false;
}

// ---- CountDownLatch ----

protected boolean tryReleaseShared(int releases) {
  for (;;) {
    int c = getState();
    if (c == 0)
      return false;
    int nextc = c-1;
    if (compareAndSetState(c, nextc))
      return nextc == 0;
  }
}

tryReleaseShared 會進行 for 迴圈判斷執行緒狀態值,使用 CAS 不斷嘗試進行替換。

如果能夠釋放,就會呼叫 doReleaseShared 方法

private void doReleaseShared() {
  for (;;) {
    Node h = head;
    if (h != null && h != tail) {
      int ws = h.waitStatus;
      if (ws == Node.SIGNAL) {
        if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
          continue;            // loop to recheck cases
        unparkSuccessor(h);
      }
      else if (ws == 0 &&
               !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
        continue;                // loop on failed CAS
    }
    if (h == head)                   // loop if head changed
      break;
  }
}

可以看到,doReleaseShared 其實也是一個無限迴圈不斷使用 CAS 嘗試替換的操作。

總結

本文是 CountDownLatch 的基本使用和原始碼分析,CountDownLatch 就是一個基於 AQS 的計數器,它內部的方法都是圍繞 AQS 框架來談的,除此之外還有其他比如 ReentrantLock、Semaphore 等都是 AQS 的實現,所以要研究併發的話,離不開對 AQS 的探討。CountDownLatch 的原始碼看起來很少,比較簡單,但是其內部比如 await 方法的呼叫鏈路卻很長,也值得花費時間深入研究。

我是 cxuan,一枚技術創作的程式設計師。如果本文你覺得不錯的話,跪求讀者點贊、在看、分享!

另外,我自己肝了六本 PDF,微信搜尋「程式設計師cxuan」關注公眾號後,在後臺回覆 cxuan ,領取全部 PDF,這些 PDF 如下

六本 PDF 連結

相關文章