leetcode-834. 樹中距離之和

KpLn_HJL發表於2020-10-07

題目

給定一個無向、連通的樹。樹中有 N 個標記為 0…N-1 的節點以及 N-1 條邊 。

第 i 條邊連線節點 edges[i][0] 和 edges[i][1] 。

返回一個表示節點 i 與其他所有節點距離之和的列表 ans。

示例 1:

輸入: N = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
輸出: [8,12,6,10,10,10]
解釋: 
如下為給定的樹的示意圖:
  0
 / \
1   2
   /|\
  3 4 5

我們可以計算出 dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5) 
也就是 1 + 1 + 2 + 2 + 2 = 8。 因此,answer[0] = 8,以此類推。

說明: 1 <= N <= 10000

解題思路

看了題解做的,樹形dp

首先求解簡單一些的問題,即給定根節點,求根節點到其他節點的距離和。以示例的樹為例,先求0到其他節點的距離和,應該為:1的距離和+2的距離和+0的子樹節點數量。12的距離和比較好理解,加上子樹的節點數量是因為,從0走到子樹的任意一個節點上,其距離都比從1或者2出發多了1,所以加上節點數量即可。

對每個節點儲存一個tuple,分別是以該節點為根的樹的距離和,以及以該節點為根的樹的節點數量(包括節點本身),則上述的樹可以寫為:

   0
 /    \
1       2
	  / | \
	3   4   5
(0,1) (0,1) (0,1)

下一步寫為:

	   0
	 /    \
	1       2
(0,1)     (3,4) = ((0+1)*3, 1*3+1)
		  / | \
		3   4   5
	(0,1) (0,1) (0,1)

更新到根節點有:

	   0
	  (8,6) = (0+3+1+4, 1+4+1)
	 /    \
	1       2
(0,1)     (3,4)
		  / | \
		3   4   5
	(0,1) (0,1) (0,1)

以0為根節點,最終的答案就是8

這樣計算出來的是以其中1個節點作為根節點的結果,以所有節點作為根節點,則可以計算出最終答案,這種方法的時間複雜度是 o ( n 2 ) o(n^2) o(n2)。求的時候用後序遍歷即可。

觀察0的子節點2可以發現,以2為根節點時,其實就是到2的子樹的所有節點距離都縮短了1,到除了2的子樹以外的節點都增加了1,所以ans[2] = ans[0] - num_child[2] + (N - num_child[2])。所以求其他節點時,不用重新計算,直接在已經計算出來的基礎上變換即可。由於每次變換的基礎都是父節點和子節點,所以用先序遍歷,用父節點的值調整子節點的值即可。

最終的時間複雜度是 o ( n ) o(n) o(n)

注意點
由於上述的思路都是基於樹的,而題目給出的是邊,所以要從邊中構造一棵樹。因為給定的邊的順序不一致,所以先按照圖的方式儲存相連的節點,選定一個節點用BFS/DFS調整邊的順序,使得按照這種順序從根節點開始畫一棵樹。

程式碼

class Solution:
    def sumOfDistancesInTree(self, N: int, edges: List[List[int]]) -> List[int]:
        if not edges:
            return [0]
        graph = {}
        for from_node, to_node in edges:
            if from_node not in graph:
                graph[from_node] = []
            graph[from_node].append(to_node)
            if to_node not in graph:
                graph[to_node] = []
            graph[to_node].append(from_node)
        # modify nodes order to form a tree
        node_dict = {}
        stack = [edges[0][0]]
        visited_nodes = set()
        while stack:
            node = stack.pop()
            if node in visited_nodes:
                continue
            visited_nodes.add(node)
            if list(set(graph[node]) - visited_nodes):
                node_dict[node] = list(set(graph[node]) - visited_nodes)
            stack += node_dict.get(node, [])
        # postorder
        ans = [0] * N
        child_num = [1] * N
        stack = [(edges[0][0], 0)]
        while stack:
            node, stat = stack.pop()
            if stat == 0:
                stack.append((node, 1))
                if node in node_dict:
                    for child_node in node_dict[node]:
                        stack.append((child_node, 0))
            else:
                if node not in node_dict:
                    continue
                ans[node] = sum(ans[child_node] + child_num[child_node] for child_node in node_dict[node])
                child_num[node] = 1 + sum(child_num[child_node] for child_node in node_dict[node])
        # preorder
        stack = [edges[0][0]]
        while stack:
            node = stack.pop()
            if node in node_dict:
                for child_node in node_dict[node]:
                    stack.append(child_node)
                    ans[child_node] = ans[node] - child_num[child_node] + (N - child_num[child_node])
        return ans

相關文章