如何在 Java 中實現最小生成樹演算法

之一Yo發表於2022-04-06

定義

在一幅無向圖 \(G=(V,E)\) 中,\((u, v)\) 為連線頂點 \(u\) 和頂點 \(v\) 的邊,\(w(u,v)\) 為邊的權重,若存在邊的子集 \(T\subseteq E\)\((V,T)\) 為樹,使得

\[w(T)=\sum_{(u,v)\in T}w(u,v) \]

最小,這稱 \(T\) 為圖 \(G\) 的最小生成樹。

說的通俗點,最小生成樹就是帶權無向圖中權值和最小的樹。下圖中黑色邊所標識的就是一棵最小生成樹(圖片來自《演算法第四版》),對於權值各不相同的連通圖來說最小生成樹只會有一棵:

最小生成樹

帶權圖的實現

《如何在 Java 中實現無向圖》 中我們使用鄰接表陣列實現了無向圖,其中鄰接表上的每個節點的資料域只是一個整數,代表著一個頂點。為了方便最小生成樹的迭代,我們將資料域換成 Edge 例項。Edge 有三個成員:頂點 v、頂點 w 和權重 weight,為了比較每一條邊的權重,需要實現 Comparable 介面。程式碼如下所示:

package com.zhiyiyo.graph;

/**
 * 圖中的邊
 */
public class Edge implements Comparable<Edge> {
    private final int v, w;
    private final double weight;

    public Edge(int v, int w, double weight) {
        this.v = v;
        this.w = w;
        this.weight = weight;
    }

    /**
     * 返回邊中的一個頂點
     */
    int either() {
        return v;
    }

    /**
     * 返回邊中的拎一個頂點
     *
     * @param v 頂點 v
     * @return 另一個頂點
     */
    int another(int v) {
        if (this.v == v) {
            return w;
        } else if (w == v) {
            return this.v;
        } else {
            throw new RuntimeException("邊中不存在該頂點");
        }
    }

    public double getWeight() {
        return weight;
    }

    @Override
    public String toString() {
        return String.format("Edge{%d-%d %f}", v, w, weight);
    }

    @Override
    public int compareTo(Edge edge) {
        return Double.compare(weight, edge.weight);
    }
}

之後只要照貓畫虎,將 LinkGraph 的泛型從 Integer 換成 Edge 就行了:

package com.zhiyiyo.graph;

import com.zhiyiyo.collection.stack.LinkStack;
import com.zhiyiyo.collection.stack.Stack;

/**
 * 帶權無向圖
 */
public class WeightedGraph {
    private final int V;
    protected int E;
    protected LinkStack<Edge>[] adj;

    public WeightedGraph(int V) {
        this.V = V;
        adj = (LinkStack<Edge>[]) new LinkStack[V];
        for (int i = 0; i < V; i++) {
            adj[i] = new LinkStack<>();
        }
    }

    public int V() {
        return V;
    }

    public int E() {
        return E;
    }

    public void addEdge(Edge edge) {
        int v = edge.either();
        int w = edge.another(v);
        adj[v].push(edge);
        adj[w].push(edge);
        E++;
    }

    public Iterable<Edge> adj(int v) {
        return adj[v];
    }

    /**
     * 獲取所有邊
     */
    public Iterable<Edge> edges() {
        Stack<Edge> edges = new LinkStack<>();
        for (int v = 0; v < V; ++v) {
            for (Edge edge : adj(v)) {
                if (edge.another(v) > v) {
                    edges.push(edge);
                }
            }
        }

        return edges;
    }
}

同時給出最小生成樹的 API:

package com.zhiyiyo.graph;

/**
 * 最小生成樹
 */
public interface MST {
    /**
     * 獲取最小生成樹中的所有邊
     */
    Iterable<Edge> edges();

    /**
     * 獲取最小生成樹的權重
     */
    double weight();
}

Kruskal 演算法

假設 \(E\) 是圖 \(G\) 中所有邊的集合,\(T\) 是最小生成樹的邊集合,kruskal 演算法的思想是每次從 \(E\)出權值最小的邊 \(e_m\),如果 \(e_m\) 不會和 \(T\) 中的邊構成環,就將其加入 \(T\) 中,直到 \(|T|=|V|-1\) 也就是 \(T\) 中邊的個數是圖 \(G\) 的頂點個數 -1 時,就得到了最小生成樹。

對於上一幅圖,使用 kruskal 演算法得到最小生成樹的過程如下圖所示:

kruskal 演算法到最小生成樹的過程

首先將 \(E\) 中最小的邊 0-7 彈出並加到 \(T\) 中,此時的 \(E\) 中最小邊為 2-3,雖然 2-3 和 0-7 無法構成連通圖,但是沒關係,只要貪心地將其加入 \(T\) 中即可,因為後續其他邊的新增總會將二者連通起來。接著按照權值的升序依次把邊 1-7、0-2、5-7 加到 \(T\) 中,直到碰到邊 1-3,如果把 1-3 加入 \(T\) 中,就會出現環 1-3-2-0-7-1,所以直接將 1-3 捨棄,1-5、2-7 也同理被丟棄掉。由於邊 4-5 不會在 \(T\) 中構成環,所以將其加入 \(T\)。重複上述步驟,直到 \(|T|=|V|-1\)

上述過程中有兩個影響效能的地方,一個是找出 \(E\) 中權值最小的邊 \(e_m\),一個是判斷將 \(e_m\) 加到 \(T\) 中是否會出現環。

二叉堆

二叉堆是一棵完全二叉樹,且每個父節點總是大於等於(最大堆)或者小於等於(最小堆)他的子節點。《演算法第四版》中給出了使用陣列儲存的最大堆的結構,其中陣列下標為 0 的地方不儲存元素,假設下標為 \(i\) 出存放的是父節點,那麼 \(2i\)\(2i+1\) 處就是子節點:

最大堆

由於最小堆的堆頂節點總是最小的,所以只需將 \(E\) 變為一個最小堆,每次取出堆頂的元素即可,時間複雜度為 \(O(\log N)\)。下面來看下如何實現最小堆。

API

對於一個二叉堆,我們關心以下操作:

package com.zhiyiyo.collection.queue;

public interface PriorQueue<T extends Comparable<T>> {
    /**
     * 向堆中插入一個元素
     * @param item 插入的元素
     */
    void insert(T item);

    /**
     * 彈出堆頂的元素
     * @return 堆頂元素
     */
    T pop();

    /**
     * 獲取堆中的元素個數
     */
    int size();

    /**
     * 堆是否為空
     */
    boolean isEmpty();
}

插入

為了保證二叉堆是一棵完全二叉樹,每次都將新節點插到陣列的末尾,也就是二叉樹的最後一個節點。如下圖所示,假設插入的節點為 A,它的父節點為 P,兄弟節點為 S,由於 P > A,這就打破了二叉堆的有序性,所以需要對堆進行調整。具體流程就是將兄弟節點中的較小者(A)選為父節點,而先前的父節點 P 則退位變為子節點。如果此時 A 的父節點小於 A,則無需繼續調整。但是下圖中只交換了 A、P 之後還是沒將二叉樹調整為堆有序狀態,因為父節點 D > A,接著將兄弟節點中較小的 A 變為父節點,而 D 則變成 A 的子節點,至此完成最小堆的調整。

最小堆的插入

上述過程的程式碼如下所示,為了保證後續插入操作,每當陣列滿員時就對其進行擴容操作:

package com.zhiyiyo.collection.queue;

import java.util.Arrays;

public class MinPriorQueue<T extends Comparable<T>> implements PriorQueue<T>{
    private T[] array;
    private int N;

    public MinPriorQueue() {
        this(3);
    }

    public MinPriorQueue(int maxSize) {
        array = (T[]) new Comparable[maxSize + 1];
    }

    @Override
    public boolean isEmpty() {
        return N == 0;
    }

    @Override
    public int size() {
        return N;
    }

    @Override
    public void insert(T item) {
        array[++N] = item;
        swim(N);
        if (N == array.length - 1) resize(1 + 2 * N);
    }

    /**
     * 元素上浮
     *
     * @param k 元素的索引
     */
    private void swim(int k) {
        while (k > 1 && less(k, k / 2)) {
            swap(k, k / 2);
            k /= 2;
        }
    }

    private void swap(int a, int b) {
        T tmp = array[a];
        array[a] = array[b];
        array[b] = tmp;
    }

    private boolean less(int a, int b) {
        return array[a].compareTo(array[b]) < 0;
    }

    private void resize(int size) {
        array = Arrays.copyOf(array, size);
    }
}

刪除最小元素

假設我們需要刪除下圖中的 A 元素,這時候就需要將 A 和最小堆的最後一個元素 P 交換位置,並將陣列的最後一個元素置為 null,使得 A 的引用次數變為 0,能被垃圾回收機制自動回收掉。交換之後最小堆的有序性被破壞了,因為父節點 P > 子節點 D,這時候和插入元素的操作一樣,將較小的子節點和父節點交換位置,使得較大的父節點能夠下沉,而較小的子節點上位,這個過程持續到沒有子節點被 P 更小為止。

最小堆刪除最小元素

實現程式碼如下:

@Override
public T pop() {
    T item = array[1];
    swap(1, N);
    array[N--] = null;
    sink(1);
    if (N < (array.length - 1) / 4) resize((array.length - 1) / 2);
    return item;
}

/**
 * 元素下沉
 *
 * @param k 元素的索引
 */
private void sink(int k) {
    while (2 * k <= N) {
        int j = 2 * k;
        // 檢查是否有兩個子節點
        if (j < N && less(j + 1, j)) j++;
        if (less(k, j)) break;
        swap(k, j);
        k = j;
    }
}

並查集

假設 \(T\) 中的頂點的集合為 \(V'\),則有圖 \(G'=(V', T)\)。我們可以將 \(G'\) 劃分為 \(n\) 個連通分量,每個連通分量有一個標識 \(id\in [0, n-1]\)。要想判斷將邊 \(e_m\) 加入 \(T\) 後是否會構成環,只需判斷 \(e_m\) 的兩個頂點是都屬於同一個連通分量即可。

判斷是否連通

由於每個連通分量都不存在環,可以看作一棵小樹,所以可以用一個陣列 int[] ids 的索引表示樹中的節點(圖中的頂點),而索引處的元素值為父節點的索引值,陣列中 ids[i] == i 的位置就是每棵樹的根節點,i 就是這個連通分量的標識。而我們想要知道兩個節點之間是否連通,只需判斷他們所屬的樹的根節點是否相同即可。

並查集的表示方式

假設從樹底的葉節點 6 出發,一路向上直到樹頂 1,中間需要經過 5 和 0 兩個節點,如果節點 6 的根節點查詢得比較頻繁,那麼這種查詢效率是比較低的。由於我們只需知道根節點是誰即可,樹的結構無關緊要,那麼為何不想個辦法把節點 5、6 直接掛到根節點 1,這樣只要一步就能知道根節點。實現這種想法的的方式就是路徑壓縮:當從節點 6 走到父節點 5 時,就將節點 6 掛到節點 5 的父節點 0 上;而從節點 0 走到根節點 1 時,就將子節點 6 和 5 掛到根節點 1 下,樹高被壓縮為 1。

實現上述過程的程式碼如下所示:

package com.zhiyiyo.collection.tree;

public class UnionFind {
    private int[] ids;
    private int[] ranks;	// 每棵樹的高度
    private int N;			// 樹的數量

    public UnionFind(int N) {
        this.N = N;
        ids = new int[N];
        ranks = new int[N];
        for (int i = 0; i < N; i++) {
            ids[i] = i;
            ranks[i] = 1;
        }
    }

    /**
     * 獲取連通分量個數
     *
     * @return 連通分量個數
     */
    public int count() {
        return N;
    }

    /**
     * 獲得連通分量的 id
     *
     * @param p 觸點 id
     * @return 連通分量 id
     */
    public int find(int p) {
        while (p != ids[p]) {
            ids[p] = ids[ids[p]];   // 路徑壓縮
            p = ids[p];
        }
        return p;
    }

    /**
     * 判斷兩個觸點是否連通
     *
     * @param p 觸點 p 的 id
     * @param q 觸點 q 的 id
     * @return 是否連通
     */
    public boolean isConnected(int p, int q) {
        return find(p) == find(q);
    }
}

合併連通分量

我們將 \(E\) 中的 \(e_m\) 新增到 \(T\) 中時,\(e_m\) 的兩個節點肯定分屬於兩個連通分量,加入 \(T\) 之後就需要將這兩個分量合併,也就是將兩棵小樹合併為一顆大樹。假設兩棵樹的高度分別為 \(h_1\)\(h_2\),如果直接將一顆樹的根節點接到另一棵樹的葉節點上,會導致新樹高度為 \(h_1+h_2\),降低尋找根節點的效率。解決方式是按秩歸併,將矮樹的根節點接到高樹的根節點上,會出現兩種情況:

  • 如果 \(h_1 \neq h_2\),新樹高度會是 \(\max\{h_1, h_2\}\)
  • 如果 \(h_1=h_2=c\),新樹高度會是 \(c+1\)

上述過程的程式碼如下所示:

/**
 * 如果兩個觸點不處於同一個連通分量中,則連線兩個觸點
 *
 * @param p 觸點 p 的 id
 * @param q 觸點 q 的 id
 */
public void union(int p, int q) {
    int pId = find(p);
    int qId = find(q);
    if (qId == pId) return;

    // 將小樹併到大樹
    if (ranks[qId] > ranks[pId]) {
        ids[pId] = qId;
    } else if (ranks[qId] < ranks[pId]) {
        ids[qId] = pId;
    } else {
        ids[qId] = pId;
        ranks[pId]++;
    }

    N--;
}

實現演算法

實現 kruskal 演算法時,先將所有邊加入最小堆中,每次取出堆頂的元素 \(e_m\),然後使用並查集判斷邊的兩個頂點是否連通,如果不連通就將 \(e_m\) 加入 \(T\),重複這個過程直至 \(|T|=|V|-1\),時間複雜度為 \(O(|E|\log |E|)\)

package com.zhiyiyo.graph;

import com.zhiyiyo.collection.queue.LinkQueue;
import com.zhiyiyo.collection.queue.MinPriorQueue;
import com.zhiyiyo.collection.queue.Queue;
import com.zhiyiyo.collection.tree.UnionFind;

import java.util.stream.Stream;
import java.util.stream.StreamSupport;


public class KruskalMST implements MST {
    private Queue<Edge> mst;

    public KruskalMST(WeightedGraph graph) {
        mst = new LinkQueue<>();
        UnionFind uf = new UnionFind(graph.V());

        MinPriorQueue<Edge> pq = new MinPriorQueue<>();
        for (Edge e : graph.edges()) {
            pq.insert(e);
        }

        while (mst.size() < graph.V() - 1 && !pq.isEmpty()) {
            Edge edge = pq.pop();
            int v = edge.either();
            int w = edge.another(v);
            if (!uf.isConnected(v, w)) {
                mst.enqueue(edge);
                uf.union(v, w);
            }
        }
    }

    @Override
    public Iterable<Edge> edges() {
        return mst;
    }

    @Override
    public double weight() {
        Stream<Edge> stream = StreamSupport.stream(mst.spliterator(), false);
        return stream.map(Edge::getWeight).reduce(0d, Double::sum);
    }
}

Prim 演算法

Prim 演算法的思想是初始化最小生成樹為一個根節點 0,然後將根節點的所有鄰邊加入最小堆中,從最小堆中彈出最小的邊 \(e_m\),如果 \(e_m\) 不會使得樹中出現環,將將其併入樹中。每當有新的節點 \(v\) 被併入樹中時,就得將 \(v\) 的所有鄰邊加入最小堆中。重複上述過程直到 \(|T|=|V|-1\),時間複雜度為 \(O(|E|\log|E|)\)。程式碼如下所示:

package com.zhiyiyo.graph;

import com.zhiyiyo.collection.queue.LinkQueue;
import com.zhiyiyo.collection.queue.MinPriorQueue;
import com.zhiyiyo.collection.queue.Queue;

import java.util.stream.Stream;
import java.util.stream.StreamSupport;

/**
 * 延時版本 Prim 演算法
 */
public class PrimMST implements MST {
    private boolean[] marked;
    private MinPriorQueue<Edge> pq;
    private Queue<Edge> mst;

    public LazyPrimMST(WeightedGraph graph) {
        marked = new boolean[graph.V()];
        pq = new MinPriorQueue<>();
        mst = new LinkQueue<>();

        mark(graph, 0);
        while (mst.size() < graph.V() - 1 && !pq.isEmpty()) {
            Edge edge = pq.pop();
            int v = edge.either();
            int w = edge.another(v);

            // 構成環則捨棄
            if (marked[v] && marked[w]) continue;
            mst.enqueue(edge);

            if (!marked[v]) mark(graph, v);
            else if (!marked[w]) mark(graph, w);
        }
    }

    private void mark(WeightedGraph graph, int v) {
        marked[v] = true;
        for (Edge edge : graph.adj(v)) {
            if (!marked[edge.another(v)]) {
                pq.insert(edge);
            }
        }
    }

    @Override
    public Iterable<Edge> edges() {
        return mst;
    }

    @Override
    public double weight() {
        Stream<Edge> stream = StreamSupport.stream(mst.spliterator(), false);
        return stream.map(Edge::getWeight).reduce(0d, Double::sum);
    }
}

由於每次都是把新節點的所有鄰邊都加到了最小堆中,會引入許多無用的邊,所以《演算法第四版》中給出了使用索引優先佇列實現的即時版 Prim 演算法,時間複雜度能達到 \(O(|E|\log |V|)\),但是這裡寫不下了,大家可以自行查閱,以上~~

相關文章