JDK 7 中的 Fork/Join 模式

青色刀客發表於2013-10-17

介紹

隨著多核晶片逐漸成為主流,大多數軟體開發人員不可避免地需要了解並行程式設計的知識。而同時,主流程式語言正在將越來越多的並行特性合併到標準庫或者語言本身之中。我們可以看到,JDK 在這方面同樣走在潮流的前方。在 JDK 標準版 5 中,由 Doug Lea 提供的並行框架成為了標準庫的一部分(JSR-166)。隨後,在 JDK 6 中,一些新的並行特性,例如並行 collection 框架,合併到了標準庫中(JSR-166x)。直到今天,儘管 Java SE 7 還沒有正式釋出,一些並行相關的新特性已經出現在 JSR-166y 中:

  1. Fork/Join 模式;
  2. TransferQueue,它繼承自 BlockingQueue 並能在佇列滿時阻塞“生產者”;
  3. ArrayTasks/ListTasks,用於並行執行某些陣列/列表相關任務的類;
  4. IntTasks/LongTasks/DoubleTasks,用於並行處理數字型別陣列的工具類,提供了排序、查詢、求和、求最小值、求最大值等功能;

其中,對 Fork/Join 模式的支援可能是對開發並行軟體來說最通用的新特性。在 JSR-166y 中,Doug Lea 實現 ArrayTasks/ListTasks/IntTasks/LongTasks/DoubleTasks 時就大量的用到了 Fork/Join 模式。讀者還需要注意一點,因為 JDK 7 還沒有正式釋出,因此本文涉及到的功能和釋出版本有可能不一樣。

Fork/Join 模式有自己的適用範圍。如果一個應用能被分解成多個子任務,並且組合多個子任務的結果就能夠獲得最終的答案,那麼這個應用就適合用 Fork/Join 模式來解決。圖 1 給出了一個 Fork/Join 模式的示意圖,位於圖上部的 Task 依賴於位於其下的 Task 的執行,只有當所有的子任務都完成之後,呼叫者才能獲得 Task 0 的返回結果。

圖 1. Fork/Join 模式示意圖
圖 1. Fork/Join 模式示意圖

可以說,Fork/Join 模式能夠解決很多種類的並行問題。通過使用 Doug Lea 提供的 Fork/Join 框架,軟體開發人員只需要關注任務的劃分和中間結果的組合就能充分利用並行平臺的優良效能。其他和並行相關的諸多難於處理的問題,例如負載平衡、同步等,都可以由框架採用統一的方式解決。這樣,我們就能夠輕鬆地獲得並行的好處而避免了並行程式設計的困難且容易出錯的缺點。

使用 Fork/Join 模式

在開始嘗試 Fork/Join 模式之前,我們需要從 Doug Lea 主持的 Concurrency JSR-166 Interest Site 上下載 JSR-166y 的原始碼,並且我們還需要安裝最新版本的 JDK 6(下載網址請參閱 參考資源)。Fork/Join 模式的使用方式非常直觀。首先,我們需要編寫一個 ForkJoinTask 來完成子任務的分割、中間結果的合併等工作。隨後,我們將這個 ForkJoinTask 交給 ForkJoinPool 來完成應用的執行。

通常我們並不直接繼承 ForkJoinTask,它包含了太多的抽象方法。針對特定的問題,我們可以選擇 ForkJoinTask 的不同子類來完成任務。RecursiveAction 是 ForkJoinTask 的一個子類,它代表了一類最簡單的 ForkJoinTask:不需要返回值,當子任務都執行完畢之後,不需要進行中間結果的組合。如果我們從 RecursiveAction 開始繼承,那麼我們只需要過載 protected void compute() 方法。下面,我們來看看怎麼為快速排序演算法建立一個 ForkJoinTask 的子類:

清單 1. ForkJoinTask 的子類
class SortTask extends RecursiveAction {
    final long[] array;
    final int lo;
    final int hi;
    private int THRESHOLD = 30;

    public SortTask(long[] array) {
        this.array = array;
        this.lo = 0;
        this.hi = array.length - 1;
    }

    public SortTask(long[] array, int lo, int hi) {
        this.array = array;
        this.lo = lo;
        this.hi = hi;
    }

    protected void compute() {
        if (hi - lo < THRESHOLD)
            sequentiallySort(array, lo, hi);
        else {
            int pivot = partition(array, lo, hi);
            coInvoke(new SortTask(array, lo, pivot - 1), new SortTask(array,
                pivot + 1, hi));
        }
    }

    private int partition(long[] array, int lo, int hi) {
        long x = array[hi];
        int i = lo - 1;
        for (int j = lo; j < hi; j++) {
            if (array[j] <= x) {
                i++;
                swap(array, i, j);
            }
        }
        swap(array, i + 1, hi);
        return i + 1;
    }

    private void swap(long[] array, int i, int j) {
        if (i != j) {
            long temp = array[i];
            array[i] = array[j];
            array[j] = temp;
        }
    }

    private void sequentiallySort(long[] array, int lo, int hi) {
        Arrays.sort(array, lo, hi + 1);
    }
}

在 清單 1 中,SortTask 首先通過 partition() 方法將陣列分成兩個部分。隨後,兩個子任務將被生成並分別排序陣列的兩個部分。當子任務足夠小時,再將其分割為更小的任務反而引起效能的降低。因此,這裡我們使用一個THRESHOLD,限定在子任務規模較小時,使用直接排序,而不是再將其分割成為更小的任務。其中,我們用到了 RecursiveAction 提供的方法 coInvoke()。它表示:啟動所有的任務,並在所有任務都正常結束後返回。如果其中一個任務出現異常,則其它所有的任務都取消。coInvoke() 的引數還可以是任務的陣列。

現在剩下的工作就是將 SortTask 提交到 ForkJoinPool 了。ForkJoinPool() 預設建立具有與 CPU 可使用執行緒數相等執行緒個數的執行緒池。我們在一個 JUnit 的 test 方法中將 SortTask 提交給一個新建的 ForkJoinPool:

清單 2. 新建的 ForkJoinPool
@Test
public void testSort() throws Exception {
    ForkJoinTask sort = new SortTask(array);
    ForkJoinPool fjpool = new ForkJoinPool();
    fjpool.submit(sort);
    fjpool.shutdown();

    fjpool.awaitTermination(30, TimeUnit.SECONDS);

    assertTrue(checkSorted(array));
}

在上面的程式碼中,我們用到了 ForkJoinPool 提供的如下函式:

  1. submit():將 ForkJoinTask 類的物件提交給 ForkJoinPool,ForkJoinPool 將立刻開始執行 ForkJoinTask。
  2. shutdown():執行此方法之後,ForkJoinPool 不再接受新的任務,但是已經提交的任務可以繼續執行。如果希望立刻停止所有的任務,可以嘗試 shutdownNow() 方法。
  3. awaitTermination():阻塞當前執行緒直到 ForkJoinPool 中所有的任務都執行結束。

並行快速排序的完整程式碼如下所示:

清單 3. 並行快速排序的完整程式碼
package tests;

import static org.junit.Assert.*;

import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.TimeUnit;

import jsr166y.forkjoin.ForkJoinPool;
import jsr166y.forkjoin.ForkJoinTask;
import jsr166y.forkjoin.RecursiveAction;

import org.junit.Before;
import org.junit.Test;

class SortTask extends RecursiveAction {
    final long[] array;
    final int lo;
    final int hi;
    private int THRESHOLD = 0; //For demo only

    public SortTask(long[] array) {
        this.array = array;
        this.lo = 0;
        this.hi = array.length - 1;
    }

    public SortTask(long[] array, int lo, int hi) {
        this.array = array;
        this.lo = lo;
        this.hi = hi;
    }

    protected void compute() {
        if (hi - lo < THRESHOLD)
            sequentiallySort(array, lo, hi);
        else {
            int pivot = partition(array, lo, hi);
            System.out.println("\npivot = " + pivot + ", low = " + lo + ", high = " + hi);
			System.out.println("array" + Arrays.toString(array));
            coInvoke(new SortTask(array, lo, pivot - 1), new SortTask(array,
                    pivot + 1, hi));
        }
    }

    private int partition(long[] array, int lo, int hi) {
        long x = array[hi];
        int i = lo - 1;
        for (int j = lo; j < hi; j++) {
            if (array[j] <= x) {
                i++;
                swap(array, i, j);
            }
        }
        swap(array, i + 1, hi);
        return i + 1;
    }

    private void swap(long[] array, int i, int j) {
        if (i != j) {
            long temp = array[i];
            array[i] = array[j];
            array[j] = temp;
        }
    }

    private void sequentiallySort(long[] array, int lo, int hi) {
        Arrays.sort(array, lo, hi + 1);
    }
}

public class TestForkJoinSimple {
    private static final int NARRAY = 16; //For demo only
    long[] array = new long[NARRAY];
    Random rand = new Random();

    @Before
    public void setUp() {
        for (int i = 0; i < array.length; i++) {
            array[i] = rand.nextLong()%100; //For demo only
        }
        System.out.println("Initial Array: " + Arrays.toString(array));
    }

    @Test
    public void testSort() throws Exception {
        ForkJoinTask sort = new SortTask(array);
        ForkJoinPool fjpool = new ForkJoinPool();
        fjpool.submit(sort);
        fjpool.shutdown();

        fjpool.awaitTermination(30, TimeUnit.SECONDS);

        assertTrue(checkSorted(array));
    }

    boolean checkSorted(long[] a) {
        for (int i = 0; i < a.length - 1; i++) {
            if (a[i] > (a[i + 1])) {
                return false;
            }
        }
        return true;
    }
}

執行以上程式碼,我們可以得到以下結果:

Initial Array: [46, -12, 74, -67, 76, -13, -91, -96]

pivot = 0, low = 0, high = 7
array[-96, -12, 74, -67, 76, -13, -91, 46]

pivot = 5, low = 1, high = 7
array[-96, -12, -67, -13, -91, 46, 76, 74]

pivot = 1, low = 1, high = 4
array[-96, -91, -67, -13, -12, 46, 74, 76]

pivot = 4, low = 2, high = 4
array[-96, -91, -67, -13, -12, 46, 74, 76]

pivot = 3, low = 2, high = 3
array[-96, -91, -67, -13, -12, 46, 74, 76]

pivot = 2, low = 2, high = 2
array[-96, -91, -67, -13, -12, 46, 74, 76]

pivot = 6, low = 6, high = 7
array[-96, -91, -67, -13, -12, 46, 74, 76]

pivot = 7, low = 7, high = 7
array[-96, -91, -67, -13, -12, 46, 74, 76]

Fork/Join 模式高階特性

使用 RecursiveTask

除了 RecursiveAction,Fork/Join 框架還提供了其他 ForkJoinTask 子類:帶有返回值的 RecursiveTask,使用finish() 方法顯式中止的 AsyncAction 和 LinkedAsyncAction,以及可使用 TaskBarrier 為每個任務設定不同中止條件的 CyclicAction。

從 RecursiveTask 繼承的子類同樣需要過載 protected void compute() 方法。與 RecursiveAction 稍有不同的是,它可使用泛型指定一個返回值的型別。下面,我們來看看如何使用 RecursiveTask 的子類。

清單 4. RecursiveTask 的子類
class Fibonacci extends RecursiveTask<Integer> {
    final int n;

    Fibonacci(int n) {
        this.n = n;
    }

    private int compute(int small) {
        final int[] results = { 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89 };
        return results[small];
    }

    public Integer compute() {
        if (n <= 10) {
            return compute(n);
        }
        Fibonacci f1 = new Fibonacci(n - 1);
        Fibonacci f2 = new Fibonacci(n - 2);
        f1.fork();
        f2.fork();
        return f1.join() + f2.join();
    }
}

在 清單 4 中, Fibonacci 的返回值為 Integer 型別。其 compute() 函式首先建立兩個子任務,啟動子任務執行,阻塞以等待子任務的結果返回,相加後得到最終結果。同樣,當子任務足夠小時,通過查表得到其結果,以減小因過多地分割任務引起的效能降低。其中,我們用到了 RecursiveTask 提供的方法 fork() 和 join()。它們分別表示:子任務的非同步執行和阻塞等待結果完成。

現在剩下的工作就是將 Fibonacci 提交到 ForkJoinPool 了,我們在一個 JUnit 的 test 方法中作了如下處理:

清單 5. 將 Fibonacci 提交到 ForkJoinPool
@Test
public void testFibonacci() throws InterruptedException, ExecutionException {
    ForkJoinTask<Integer> fjt = new Fibonacci(45);
    ForkJoinPool fjpool = new ForkJoinPool();
    Future<Integer> result = fjpool.submit(fjt);

    // do something
    System.out.println(result.get());
}

使用 CyclicAction 來處理迴圈任務

CyclicAction 的用法稍微複雜一些。如果一個複雜任務需要幾個執行緒協作完成,並且執行緒之間需要在某個點等待所有其他執行緒到達,那麼我們就能方便的用 CyclicAction 和 TaskBarrier 來完成。圖 2 描述了使用 CyclicAction 和 TaskBarrier 的一個典型場景。

圖 2. 使用 CyclicAction 和 TaskBarrier 執行多執行緒任務
圖 2. 使用 CyclicAction 和 TaskBarrier 執行多執行緒任務

繼承自 CyclicAction 的子類需要 TaskBarrier 為每個任務設定不同的中止條件。從 CyclicAction 繼承的子類需要過載 protected void compute() 方法,定義在 barrier 的每個步驟需要執行的動作。compute() 方法將被反覆執行直到 barrier 的 isTerminated() 方法返回 True。TaskBarrier 的行為類似於 CyclicBarrier。下面,我們來看看如何使用 CyclicAction 的子類。

清單 6. 使用 CyclicAction 的子類
class ConcurrentPrint extends RecursiveAction {
    protected void compute() {
        TaskBarrier b = new TaskBarrier() {
            protected boolean terminate(int cycle, int registeredParties) {
                System.out.println("Cycle is " + cycle + ";"
                        + registeredParties + " parties");
                return cycle >= 10;
            }
        };
        int n = 3;
        CyclicAction[] actions = new CyclicAction[n];
        for (int i = 0; i < n; ++i) {
            final int index = i;
            actions[i] = new CyclicAction(b) {
                protected void compute() {
                    System.out.println("I'm working " + getCycle() + " "
                            + index);
                    try {
                        Thread.sleep(500);
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            };
        }
        for (int i = 0; i < n; ++i)
            actions[i].fork();
        for (int i = 0; i < n; ++i)
            actions[i].join();
    }
}

在 清單 6 中,CyclicAction[] 陣列建立了三個任務,列印各自的工作次數和序號。而在 b.terminate() 方法中,我們設定的中止條件表示重複 10 次計算後中止。現在剩下的工作就是將 ConcurrentPrint 提交到 ForkJoinPool 了。我們可以在 ForkJoinPool 的建構函式中指定需要的執行緒數目,例如 ForkJoinPool(4) 就表明執行緒池包含 4 個執行緒。我們在一個 JUnit 的 test 方法中執行 ConcurrentPrint 的這個迴圈任務:

清單 7. 執行 ConcurrentPrint 迴圈任務
@Test
public void testBarrier () throws InterruptedException, ExecutionException {
    ForkJoinTask fjt = new ConcurrentPrint();
    ForkJoinPool fjpool = new ForkJoinPool(4);
    fjpool.submit(fjt);
    fjpool.shutdown();
}

RecursiveTask 和 CyclicAction 兩個例子的完整程式碼如下所示:

清單 8. RecursiveTask 和 CyclicAction 兩個例子的完整程式碼
package tests;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;

import jsr166y.forkjoin.CyclicAction;
import jsr166y.forkjoin.ForkJoinPool;
import jsr166y.forkjoin.ForkJoinTask;
import jsr166y.forkjoin.RecursiveAction;
import jsr166y.forkjoin.RecursiveTask;
import jsr166y.forkjoin.TaskBarrier;

import org.junit.Test;

class Fibonacci extends RecursiveTask<Integer> {
    final int n;

    Fibonacci(int n) {
        this.n = n;
    }

    private int compute(int small) {
        final int[] results = { 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89 };
        return results[small];
    }

    public Integer compute() {
        if (n <= 10) {
            return compute(n);
        }
        Fibonacci f1 = new Fibonacci(n - 1);
        Fibonacci f2 = new Fibonacci(n - 2);
        System.out.println("fork new thread for " + (n - 1));
        f1.fork();
        System.out.println("fork new thread for " + (n - 2));
        f2.fork();
        return f1.join() + f2.join();
    }
}

class ConcurrentPrint extends RecursiveAction {
    protected void compute() {
        TaskBarrier b = new TaskBarrier() {
            protected boolean terminate(int cycle, int registeredParties) {
                System.out.println("Cycle is " + cycle + ";"
                        + registeredParties + " parties");
                return cycle >= 10;
            }
        };
        int n = 3;
        CyclicAction[] actions = new CyclicAction[n];
        for (int i = 0; i < n; ++i) {
            final int index = i;
            actions[i] = new CyclicAction(b) {
                protected void compute() {
                    System.out.println("I'm working " + getCycle() + " "
                            + index);
                    try {
                        Thread.sleep(500);
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            };
        }
        for (int i = 0; i < n; ++i)
            actions[i].fork();
        for (int i = 0; i < n; ++i)
            actions[i].join();
    }
}

public class TestForkJoin {
    @Test
    public void testBarrier () throws InterruptedException, ExecutionException {
		System.out.println("\ntesting Task Barrier ...");
        ForkJoinTask fjt = new ConcurrentPrint();
        ForkJoinPool fjpool = new ForkJoinPool(4);
        fjpool.submit(fjt);
        fjpool.shutdown();
    }

    @Test
    public void testFibonacci () throws InterruptedException, ExecutionException {
    	System.out.println("\ntesting Fibonacci ...");
		final int num = 14; //For demo only
        ForkJoinTask<Integer> fjt = new Fibonacci(num);
        ForkJoinPool fjpool = new ForkJoinPool();
        Future<Integer> result = fjpool.submit(fjt);

        // do something
        System.out.println("Fibonacci(" + num + ") = " + result.get());
    }
}

執行以上程式碼,我們可以得到以下結果:

testing Task Barrier ...
I'm working 0 2
I'm working 0 0
I'm working 0 1
Cycle is 0; 3 parties
I'm working 1 2
I'm working 1 0
I'm working 1 1
Cycle is 1; 3 parties
I'm working 2 0
I'm working 2 1
I'm working 2 2
Cycle is 2; 3 parties
I'm working 3 0
I'm working 3 2
I'm working 3 1
Cycle is 3; 3 parties
I'm working 4 2
I'm working 4 0
I'm working 4 1
Cycle is 4; 3 parties
I'm working 5 1
I'm working 5 0
I'm working 5 2
Cycle is 5; 3 parties
I'm working 6 0
I'm working 6 2
I'm working 6 1
Cycle is 6; 3 parties
I'm working 7 2
I'm working 7 0
I'm working 7 1
Cycle is 7; 3 parties
I'm working 8 1
I'm working 8 0
I'm working 8 2
Cycle is 8; 3 parties
I'm working 9 0
I'm working 9 2

testing Fibonacci ...
fork new thread for 13
fork new thread for 12
fork new thread for 11
fork new thread for 10
fork new thread for 12
fork new thread for 11
fork new thread for 10
fork new thread for 9
fork new thread for 10
fork new thread for 9
fork new thread for 11
fork new thread for 10
fork new thread for 10
fork new thread for 9
Fibonacci(14) = 610

結論

從以上的例子中可以看到,通過使用 Fork/Join 模式,軟體開發人員能夠方便地利用多核平臺的計算能力。儘管還沒有做到對軟體開發人員完全透明,Fork/Join 模式已經極大地簡化了編寫併發程式的瑣碎工作。對於符合 Fork/Join 模式的應用,軟體開發人員不再需要處理各種並行相關事務,例如同步、通訊等,以難以除錯而聞名的死鎖和 data race 等錯誤也就不會出現,提升了思考問題的層次。你可以把 Fork/Join 模式看作並行版本的 Divide and Conquer 策略,僅僅關注如何劃分任務和組合中間結果,將剩下的事情丟給 Fork/Join 框架。

在實際工作中利用 Fork/Join 模式,可以充分享受多核平臺為應用帶來的免費午餐。

相關文章