java原始碼-CountDownLatch

晴天哥發表於2018-09-02

開篇

  • CountDownLatch是一個同步工具類,用來協調多個執行緒之間的同步,或者說起到執行緒之間的通訊(而不是用作互斥的作用)。

  • CountDownLatch能夠使一個執行緒在等待另外一些執行緒完成各自工作之後,再繼續執行。使用一個計數器進行實現。計數器初始值為執行緒的數量。當每一個執行緒完成自己任務後,計數器的值就會減一。當計數器的值為0時,表示所有的執行緒都已經完成了任務,然後在CountDownLatch上等待的執行緒就可以恢復執行任務。

  • CountDownLatch是一次性的,計數器的值只能在構造方法中初始化一次,之後沒有任何機制再次對其設定值,當CountDownLatch使用完畢後,它不能再次被使用。

CountDownLatch的用法

  • CountDownLatch典型用法1:某一執行緒在開始執行前等待n個執行緒執行完畢。將CountDownLatch的計數器初始化為n new CountDownLatch(n) ,每當一個任務執行緒執行完畢,就將計數器減1 countdownlatch.countDown(),當計數器的值變為0時,在CountDownLatch上 await() 的執行緒就會被喚醒。一個典型應用場景就是啟動一個服務時,主執行緒需要等待多個元件載入完畢,之後再繼續執行。

  • CountDownLatch典型用法2:實現多個執行緒開始執行任務的最大並行性。注意是並行性,不是併發,強調的是多個執行緒在某一時刻同時開始執行。類似於賽跑,將多個執行緒放到起點,等待發令槍響,然後同時開跑。做法是初始化一個共享的CountDownLatch(1),將其計數器初始化為1,多個執行緒在開始執行任務前首先 coundownlatch.await(),當主執行緒呼叫 countDown() 時,計數器變為0,多個執行緒同時被喚醒。

CountDownLatch的demo

public class CountDownLatchDemo {
    public static void main(String[] args) throws InterruptedException{
        CountDownLatch countDownLatch = new CountDownLatch(2){
            @Override
            public void await() throws InterruptedException {
                super.await();
                System.out.println(Thread.currentThread().getName() +  " count down is ok");
            }
        };
        
        Thread thread1 = new Thread(new Runnable() {
            @Override
            public void run() {
                //do something
                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                System.out.println(Thread.currentThread().getName() + " is done");
                countDownLatch.countDown();
            }
        }, "thread1");
        
        Thread thread2 = new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    Thread.sleep(2000);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                System.out.println(Thread.currentThread().getName() + " is done");
                countDownLatch.countDown();
            }
        }, "thread2");
        
        thread1.start();
        thread2.start();
        
        countDownLatch.await();
    }

CountDownLatch的類定義

  • CountDownLatch內部包含Sync類。
  • CountDownLatch內部包含Sync類的物件sync。
  • Sync類繼承自AQS(神奇的AQS),建構函式設定AQS的state值為等待值。
public class CountDownLatch {

    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        Sync(int count) {
            setState(count);
        }

        int getCount() {
            return getState();
        }

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

    private final Sync sync;

    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }
}

CountDownLatch的等待過程

  • CountDownLatch通過await()進入等待。
  • CountDownLatch通過await(long timeout, TimeUnit unit)進入超時等待。
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

CountDownLatch的await()過程

  • await()通過sync.acquireSharedInterruptibly()獲鎖。
  • acquireSharedInterruptibly通過tryAcquireShared()嘗試獲鎖。
  • tryAcquireShared()判斷獲鎖成功與否的依據是AQS的state的值是否為零。
  • 獲鎖失敗後通過doAcquireSharedInterruptibly()進入鎖等待佇列CLH。
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }


    public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        // 嘗試獲鎖失敗
        if (tryAcquireShared(arg) < 0)
            // 
            doAcquireSharedInterruptibly(arg);
    }


    protected int tryAcquireShared(int acquires) {
        return (getState() == 0) ? 1 : -1;
    }


    private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
        final Node node = addWaiter(Node.SHARED);
        boolean failed = true;
        try {
            for (;;) {
                final Node p = node.predecessor();
                if (p == head) {
                    int r = tryAcquireShared(arg);
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        failed = false;
                        return;
                    }
                }
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } finally {
            if (failed)
                cancelAcquire(node);
        }
    }

CountDownLatch的await(long timeout, TimeUnit unit)過程

  • await(long timeout, TimeUnit unit)通過sync.tryAcquireSharedNanos()獲鎖。
  • tryAcquireSharedNanos()通過doAcquireSharedNanos()嘗試獲鎖。
  • tryAcquireShared()判斷獲鎖成功與否的依據是AQS的state的值是否為零。
  • 獲鎖失敗後通過doAcquireSharedNanos()進入鎖等待佇列CLH,和doAcquireSharedInterruptibly()方法相比增加了超時檢測機制,通過LockSupport.parkNanos()實現超時。
    public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }



    public final boolean tryAcquireSharedNanos(int arg, long nanosTimeout)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        return tryAcquireShared(arg) >= 0 ||
            doAcquireSharedNanos(arg, nanosTimeout);
    }



    private boolean doAcquireSharedNanos(int arg, long nanosTimeout)
            throws InterruptedException {
        if (nanosTimeout <= 0L)
            return false;
        final long deadline = System.nanoTime() + nanosTimeout;
        final Node node = addWaiter(Node.SHARED);
        boolean failed = true;
        try {
            for (;;) {
                final Node p = node.predecessor();
                if (p == head) {
                    int r = tryAcquireShared(arg);
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        failed = false;
                        return true;
                    }
                }
                nanosTimeout = deadline - System.nanoTime();
                if (nanosTimeout <= 0L)
                    return false;
                if (shouldParkAfterFailedAcquire(p, node) &&
                    nanosTimeout > spinForTimeoutThreshold)
                    LockSupport.parkNanos(this, nanosTimeout);
                if (Thread.interrupted())
                    throw new InterruptedException();
            }
        } finally {
            if (failed)
                cancelAcquire(node);
        }
    }

CountDownLatch的喚醒過程

  • CountDownLatch通過sync.releaseShared(1)釋放鎖實現state的遞減
  • tryReleaseShared()方法判斷鎖狀態state==0,遞減後值為0說明鎖已經被釋放。
  • releaseShared()釋放鎖成功後通過doReleaseShared()方法喚醒所有等待執行緒。
  • doReleaseShared()喚醒鎖的過程是一個傳播性的喚醒,通過執行緒A喚醒執行緒B,然後由執行緒B喚醒執行緒C的傳播性依次喚醒所有等待執行緒。
    public void countDown() {
        sync.releaseShared(1);
    }

    public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            doReleaseShared();
            return true;
        }
        return false;
    }

    protected boolean tryReleaseShared(int releases) {
        for (;;) {
            int c = getState();
            if (c == 0)
                return false;
            int nextc = c-1;
            if (compareAndSetState(c, nextc))
                return nextc == 0;
        }
    }

    private void doReleaseShared() {
        for (;;) {
            Node h = head;
            if (h != null && h != tail) {
                int ws = h.waitStatus;
                if (ws == Node.SIGNAL) {
                    if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                        continue;            // loop to recheck cases
                    unparkSuccessor(h);
                }
                else if (ws == 0 &&
                         !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                    continue;                // loop on failed CAS
            }
            if (h == head)                   // loop if head changed
                break;
        }
    }

總結

CountDownLatch的工作原理,總結起來就兩點(基於AQS實現):

  • 初始化鎖狀態的值為需要等待的執行緒數。
  • 判斷鎖狀態是否已經釋放,如果鎖未釋放所有等待鎖的執行緒就會進入等待的CLH佇列。
  • 如果鎖狀態已經釋放,那麼就會通過傳播性喚醒所有的等待執行緒。


相關文章