Java併發包原始碼學習系列:同步元件CountDownLatch原始碼解析

天喬巴夏丶發表於2021-02-20

CountDownLatch概述

日常開發中,經常會遇到類似場景:主執行緒開啟多個子執行緒執行任務,需要等待所有子執行緒執行完畢後再進行彙總。

在同步元件CountDownLatch出現之前,我們可以使用join方法來完成,簡單實現如下:

public class JoinTest {
    public static void main(String[] args) throws InterruptedException {
        Thread A = new Thread(() -> {
            try {
                Thread.sleep(1000);
                System.out.println("A finish!");
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        });
        Thread B = new Thread(() -> {
            try {
                Thread.sleep(1000);
                System.out.println("B finish!");

            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        });
        System.out.println("main thread wait ..");
        A.start();
        B.start();
        A.join(); // 等待A執行結束
        B.join(); // 等待B執行結束
        System.out.println("all thread finish !");
    }
}

但使用join方法並不是很靈活,並不能很好地滿足某些場景的需要,而CountDownLatch則能夠很好地代替它,並且相比之下,提供了更多靈活的特性:

CountDownLatch相比join方法對執行緒同步有更靈活的控制,原因如下:

  1. 呼叫子執行緒的join()方法後,該執行緒會一直被阻塞直到子執行緒執行完畢,而CountDownLatch使用計數器來允許子執行緒執行完畢或者執行中遞減計數,await方法返回不一定必須等待執行緒結束。
  2. 使用執行緒池管理執行緒時,新增Runnable到執行緒池,沒有辦法再呼叫執行緒的join方法了。

使用案例與基本思路

public class TestCountDownLatch {
    
    public static volatile CountDownLatch countDownLatch = new CountDownLatch(2);
    
    public static void main (String[] args) throws InterruptedException {
        ExecutorService executorService = Executors.newFixedThreadPool(2);
        executorService.submit(() -> {
            try {
                Thread.sleep(1000);
                System.out.println("A finish!");

            } catch (InterruptedException e) {
                e.printStackTrace();
            } finally {
                countDownLatch.countDown();
            }
        });
        executorService.submit(() -> {
            try {
                Thread.sleep(1000);
                System.out.println("B finish!");

            } catch (InterruptedException e) {
                e.printStackTrace();
            } finally {
                countDownLatch.countDown();
            }
        });
        System.out.println("main thread wait ..");
        countDownLatch.await();
        System.out.println("all thread finish !");
        executorService.shutdown();
    }
}
// 結果
main thread wait ..
B finish!
A finish!
all thread finish !
  • 構建CountDownLatch例項,構造引數傳參為2,內部計數初始值為2。
  • 主執行緒構建執行緒池,提交兩個任務,接著呼叫countDownLatch.await()陷入阻塞。
  • 子執行緒執行完畢之後呼叫countDownLatch.countDown(),內部計數器減1。
  • 所有子執行緒執行完畢之後,計數為0,此時主執行緒的await方法返回。

類圖與基本結構

public class CountDownLatch {
    /**
     * Synchronization control For CountDownLatch.
     * Uses AQS state to represent count.
     */
    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        Sync(int count) {
            setState(count);
        }
        //...
    }

    private final Sync sync;

    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

    public void countDown() {
        sync.releaseShared(1);
    }

    public long getCount() {
        return sync.getCount();
    }

    public String toString() {
        return super.toString() + "[Count = " + sync.getCount() + "]";
    }
}

CountDownLatch基於AQS實現,內部維護一個Sync變數,繼承了AQS。

在AQS中,最重要的就是state狀態的表示,在CountDownLatch中使用state表示計數器的值,在初始化的時候,為state賦值。

幾個同步方法實現比較簡單,如果你不熟悉AQS,推薦你瞅一眼前置文章:

接下來我們簡單看一看實現,主要學習兩個方法:await()和countdown()。

void await()

當執行緒呼叫CountDownLatch的await方法後,執行緒會被阻塞,除非發生下面兩種情況:

  1. 內部計數器值為0,getState() == 0
  2. 被其他執行緒中斷,丟擲異常,也就是currThread.interrupt()
    // CountDownLatch.java
	public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }
	// AQS.java
    public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        // 如果執行緒中斷, 則丟擲異常
        if (Thread.interrupted())
            throw new InterruptedException();
        // 由子類實現,這裡再Sync中實現,計數器為0就可以返回,否則進入AQS佇列等待
        if (tryAcquireShared(arg) < 0)
            doAcquireSharedInterruptibly(arg);
    }
	// Sync
	// 計數器為0 返回1, 否則返回-1
    private static final class Sync extends AbstractQueuedSynchronizer {
        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }
    }

boolean await(long timeout, TimeUnit unit)

當執行緒呼叫CountDownLatch的await方法後,執行緒會被阻塞,除非發生下面三種情況:

  1. 內部計數器值為0,getState() == 0,返回true。
  2. 被其他執行緒中斷,丟擲異常,也就是currThread.interrupt()
  3. 設定的timeout時間到了,超時返回false。
    // CountDownLatch.java
	public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }
	// AQS.java
    public final boolean tryAcquireSharedNanos(int arg, long nanosTimeout)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        return tryAcquireShared(arg) >= 0 ||
            doAcquireSharedNanos(arg, nanosTimeout);
    }

void countDown()

呼叫該方法,內部計數值減1,遞減後如果計數器值為0,喚醒所有因呼叫await方法而被阻塞的執行緒,否則跳過。

    // CountDownLatch.java
	public void countDown() {
        sync.releaseShared(1);
    }
	// AQS.java
    public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            doReleaseShared();
            return true;
        }
        return false;
    }
	// Sync
    private static final class Sync extends AbstractQueuedSynchronizer {
        protected boolean tryReleaseShared(int releases) {
            // 迴圈進行CAS操作
            for (;;) {
                int c = getState();
                // 一旦為0,就返回false
                if (c == 0)
                    return false;
                int nextc = c-1;
                // CAS嘗試將state-1,只有這一步CAS成功且將state變成0的執行緒才會返回true
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

總結

  • CountDownLatch相比於join方法更加靈活且方便地實現執行緒間同步,體現在以下幾點:

    • 呼叫子執行緒的join()方法後,該執行緒會一直被阻塞直到子執行緒執行完畢,而CountDownLatch使用計數器來允許子執行緒執行完畢或者執行中遞減計數,await方法返回不一定必須等待執行緒結束。
    • 使用執行緒池管理執行緒時,新增Runnable到執行緒池,沒有辦法再呼叫執行緒的join方法了。
  • CountDownLatch使用state表示內部計數器的值,初始化傳入count。

  • 執行緒呼叫countdown方法將會原子性地遞減AQS的state值,執行緒呼叫await方法後將會置入AQS阻塞佇列中,直到計數器為0,或被打斷,或超時等才會返回,計數器為0時,當前執行緒還需要喚醒由於await()被阻塞的執行緒。

參考閱讀

  • 《Java併發程式設計之美》

相關文章