For an undirected graph that is a tree (connected and without cycles), we can choose any node as the root, and the resulting graph is a rooted tree. Among all possible rooted trees, those with minimum height are called minimum height trees (MHTs). Given such a graph, write a function to find all the MHTs and return a list of their root labels.
Format
The graph contains n nodes labeled from 0 to n - 1. You will be given the number n and a list of undirected edges (each edge is a pair of labels). You can assume that no duplicate edges will appear in edges. Since all edges are undirected, [0, 1] is the same as [1, 0], so the two will not appear together in edges.
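Because each edge is unordered, a program usually stores it in both directions. The sketch below shows that step under one assumption: `buildAdjacency` and `GraphInput` are hypothetical names, not part of the original. It is applied to the edges of Example 1 below.

```java
import java.util.ArrayList;
import java.util.List;

class GraphInput {
    // Builds an adjacency list for nodes 0..n-1. Each undirected edge [a, b]
    // is stored in both directions, so [0, 1] and [1, 0] describe the same edge.
    static List<List<Integer>> buildAdjacency(int n, int[][] edges) {
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        for (int[] e : edges) {
            adj.get(e[0]).add(e[1]);
            adj.get(e[1]).add(e[0]);
        }
        return adj;
    }

    public static void main(String[] args) {
        // Edges of Example 1: node 1 is adjacent to 0, 2 and 3.
        List<List<Integer>> adj = buildAdjacency(4, new int[][]{{1, 0}, {1, 2}, {1, 3}});
        System.out.println(adj); // prints [[1], [0, 2, 3], [1], [1]]
    }
}
```

A valid tree on n nodes has exactly n - 1 edges, which gives a cheap sanity check on the input.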
Example 1:
Given n = 4, edges = [[1, 0], [1, 2], [1, 3]]

        0
        |
        1
       / \
      2   3

return [1]
Example 2:
Given n = 6, edges = [[0, 3], [1, 3], [2, 3], [4, 3], [5, 4]]

     0  1  2
      \ | /
        3
        |
        4
        |
        5

return [3, 4]
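One way to check both examples is to try every node as the root, measure its height with a breadth-first search, and keep the minimum. This brute-force baseline is not one of the solutions below because it costs O(n^2) time. The class name `BruteForceMht` is hypothetical and used only for illustration. For Example 2 it measures heights 3, 3, 3, 2, 2, 3 for roots 0 through 5.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class BruteForceMht {
    // Height of the tree rooted at `root`: BFS distance to the farthest node, in edges.
    static int heightFrom(int root, List<List<Integer>> adj) {
        int[] dist = new int[adj.size()];
        Arrays.fill(dist, -1);
        ArrayDeque<Integer> queue = new ArrayDeque<>();
        dist[root] = 0;
        queue.add(root);
        int height = 0;
        while (!queue.isEmpty()) {
            int v = queue.poll();
            height = Math.max(height, dist[v]);
            for (int u : adj.get(v)) {
                if (dist[u] == -1) {
                    dist[u] = dist[v] + 1;
                    queue.add(u);
                }
            }
        }
        return height;
    }

    // Tries every node as the root and keeps those with the smallest height.
    static List<Integer> findMinHeightTrees(int n, int[][] edges) {
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        for (int[] e : edges) {
            adj.get(e[0]).add(e[1]);
            adj.get(e[1]).add(e[0]);
        }
        List<Integer> result = new ArrayList<>();
        int best = Integer.MAX_VALUE;
        for (int root = 0; root < n; root++) {
            int h = heightFrom(root, adj);
            if (h < best) {
                best = h;
                result.clear();
            }
            if (h == best) result.add(root);
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(findMinHeightTrees(4, new int[][]{{1, 0}, {1, 2}, {1, 3}}));               // prints [1]
        System.out.println(findMinHeightTrees(6, new int[][]{{0, 3}, {1, 3}, {2, 3}, {4, 3}, {5, 4}})); // prints [3, 4]
    }
}
```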
Note:
Solution 1:
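A first approach: choose any node as the root to turn the graph into a rooted tree. A first depth-first pass computes the height of every node's subtree and records each node's largest and second-largest child branch. A second pass from the root then combines each node's subtree height with the height of the part of the tree above it, which is reached through its parent. The larger of the two is the node's height when it is the root, and the nodes with the minimum are returned. Each pass visits every node once, so the method runs in O(n) time.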
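The two passes amount to the recurrences below. The notation is introduced here for illustration. For a node $v$ with children $C(v)$, $\mathrm{down}(v)$ is the longest downward path from $v$ in edges (the code's `firstDep`), $\mathrm{up}(v)$ is the longest path leaving $v$ through its parent (the code's `parentHeight`), and $h(v)$ is the height when $v$ is the root (`treeDep`). A maximum over an empty set is taken as 0.

$$
\begin{aligned}
\mathrm{down}(v) &= \max_{c \in C(v)} \bigl(\mathrm{down}(c) + 1\bigr),\\
\mathrm{up}(\mathrm{root}) &= 0,\\
\mathrm{up}(c) &= 1 + \max\Bigl(\mathrm{up}(v),\ \max_{c' \in C(v),\ c' \ne c} \bigl(\mathrm{down}(c') + 1\bigr)\Bigr) \quad \text{for } c \in C(v),\\
h(v) &= \max\bigl(\mathrm{down}(v),\ \mathrm{up}(v)\bigr).
\end{aligned}
$$

In the code, the inner maximum over siblings is `firstDep`. When $c$'s own branch equals `firstDep`, the code uses `secondDep` instead. On a tie the two values are equal, so either choice is correct.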
```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Solution {
    // Per-node bookkeeping once the graph has been rooted.
    public class MhtNode {
        Map<Integer, Integer> childMap; // child label -> longest downward path from this node through that child
        int firstDep;                   // longest downward path from this node (subtree height, in edges)
        int secondDep;                  // longest downward path through another child (equals firstDep on ties)
        int treeDep;                    // height of the whole tree when rooted at this node

        public MhtNode() {
            childMap = new HashMap<Integer, Integer>();
            firstDep = 0;
            secondDep = 0;
            treeDep = 0;
        }
    }

    // Running minimum height and the roots that achieve it.
    public class Result {
        int minDep;
        List<Integer> nodeList;

        public Result() {
            minDep = Integer.MAX_VALUE;
            nodeList = new ArrayList<Integer>();
        }
    }

    public List<Integer> findMinHeightTrees(int n, int[][] edges) {
        Result res = new Result();
        if (n == 0)
            return res.nodeList;
        if (n == 1) {
            res.nodeList.add(0);
            return res.nodeList;
        }

        // Build graph: every neighbor starts in childMap with a placeholder height of -1.
        HashMap<Integer, MhtNode> treeMap = new HashMap<Integer, MhtNode>();
        for (int[] edge : edges) {
            treeMap.computeIfAbsent(edge[0], k -> new MhtNode()).childMap.put(edge[1], -1);
            treeMap.computeIfAbsent(edge[1], k -> new MhtNode()).childMap.put(edge[0], -1);
        }

        // Root the tree at an arbitrary node and compute every subtree height.
        int root = treeMap.keySet().iterator().next();
        buildTreeRecur(root, -1, treeMap);

        // Second pass: compute each node's height as a root and collect the minimum.
        getMHTRecur(root, 0, treeMap, res);

        return res.nodeList;
    }

    // Returns firstDep + 1, i.e. the longest downward path from cur as seen from its parent.
    public int buildTreeRecur(int cur, int parent, HashMap<Integer, MhtNode> treeMap) {
        // Get current node, remove parent from child map.
        MhtNode curNode = treeMap.get(cur);
        curNode.childMap.remove(parent);

        // Get height of every child tree, tracking the two largest.
        for (int child : curNode.childMap.keySet()) {
            int height = buildTreeRecur(child, cur, treeMap);
            // Replacing the value of an existing key is not a structural change,
            // so iterating over keySet() stays valid.
            curNode.childMap.put(child, height);
            if (height > curNode.firstDep) {
                curNode.secondDep = curNode.firstDep;
                curNode.firstDep = height;
            } else if (height > curNode.secondDep) {
                curNode.secondDep = height;
            }
        }

        return curNode.firstDep + 1;
    }

    // parentHeight: longest path that leaves cur through its parent (0 for the root).
    public void getMHTRecur(int cur, int parentHeight, HashMap<Integer, MhtNode> treeMap, Result res) {
        // Get current node's tree height.
        MhtNode curNode = treeMap.get(cur);
        curNode.treeDep = Math.max(parentHeight, curNode.firstDep);
        if (res.minDep == curNode.treeDep) {
            res.nodeList.add(cur);
        } else if (curNode.treeDep < res.minDep) {
            res.minDep = curNode.treeDep;
            res.nodeList.clear();
            res.nodeList.add(cur);
        }

        // Move to each of cur's children. A child on cur's deepest branch must use
        // the second-best branch (or the path through cur's parent) as its upward path.
        for (int child : curNode.childMap.keySet()) {
            int childDep = curNode.childMap.get(child);
            if (childDep == curNode.firstDep) {
                getMHTRecur(child, Math.max(parentHeight, curNode.secondDep) + 1, treeMap, res);
            } else {
                getMHTRecur(child, curNode.treeDep + 1, treeMap, res);
            }
        }
    }
}
```
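Both recursive passes can reach a depth of n on a path-shaped tree, which may exhaust the Java call stack for large inputs. The following is a minimal sketch of the same two-pass idea without recursion. It assumes valid input (a connected tree with n - 1 edges) and uses a BFS order from node 0 so that parents are processed before children. The class name `RerootMht` and the arrays `down1`, `down2` and `up` are illustrative and do not come from the original.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class RerootMht {
    static List<Integer> findMinHeightTrees(int n, int[][] edges) {
        List<Integer> result = new ArrayList<>();
        if (n <= 0) return result;

        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        for (int[] e : edges) {
            adj.get(e[0]).add(e[1]);
            adj.get(e[1]).add(e[0]);
        }

        // BFS from node 0: parent[] roots the tree, order[] lists parents before children.
        int[] parent = new int[n];
        int[] order = new int[n];
        boolean[] seen = new boolean[n];
        Arrays.fill(parent, -1);
        int head = 0, tail = 0;
        order[tail++] = 0;
        seen[0] = true;
        while (head < tail) {
            int v = order[head++];
            for (int u : adj.get(v)) {
                if (!seen[u]) {
                    seen[u] = true;
                    parent[u] = v;
                    order[tail++] = u;
                }
            }
        }

        // Bottom-up pass: longest (down1) and second-longest (down2) downward paths, in edges.
        int[] down1 = new int[n];
        int[] down2 = new int[n];
        for (int i = n - 1; i > 0; i--) {
            int v = order[i], p = parent[v];
            int branch = down1[v] + 1;
            if (branch > down1[p]) {
                down2[p] = down1[p];
                down1[p] = branch;
            } else if (branch > down2[p]) {
                down2[p] = branch;
            }
        }

        // Top-down pass: up[v] is the longest path that leaves v through its parent.
        int[] up = new int[n];
        for (int i = 1; i < n; i++) {
            int v = order[i], p = parent[v];
            // Best branch of p that does not go back down into v.
            int other = (down1[v] + 1 == down1[p]) ? down2[p] : down1[p];
            up[v] = 1 + Math.max(up[p], other);
        }

        // The height rooted at v is max(down1[v], up[v]); keep every node with the minimum.
        int best = Integer.MAX_VALUE;
        for (int v = 0; v < n; v++) {
            int h = Math.max(down1[v], up[v]);
            if (h < best) {
                best = h;
                result.clear();
            }
            if (h == best) result.add(v);
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(findMinHeightTrees(4, new int[][]{{1, 0}, {1, 2}, {1, 3}}));               // prints [1]
        System.out.println(findMinHeightTrees(6, new int[][]{{0, 3}, {1, 3}, {2, 3}, {4, 3}, {5, 4}})); // prints [3, 4]
    }
}
```

The BFS order plays the role of the recursion. Walking it backwards gives children before parents for the first pass, and walking it forwards gives parents before children for the second.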
Solution 2:
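The answer is always the one or two nodes at the center of the longest path in the tree. We therefore remove the leaves (nodes with exactly one neighbor) layer by layer; the nodes that remain when no new leaves appear are the answer. Every edge is removed at most once, so this also runs in O(n) time.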
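A short argument for why the center of the longest path is optimal follows. Take a longest path $v_0, v_1, \dots, v_D$ with $D$ edges, and let $h(r)$ be the height of the tree rooted at $r$. Since heights are integers and $\operatorname{dist}(r, v_0) + \operatorname{dist}(r, v_D) \ge D$, every root $r$ satisfies:

$$
h(r) \;\ge\; \max\bigl(\operatorname{dist}(r, v_0),\ \operatorname{dist}(r, v_D)\bigr) \;\ge\; \left\lceil \frac{\operatorname{dist}(r, v_0) + \operatorname{dist}(r, v_D)}{2} \right\rceil \;\ge\; \left\lceil \frac{D}{2} \right\rceil .
$$

Equality requires two conditions:
1. $r$ must lie on the path, because a node off the path has $\operatorname{dist}(r, v_0) + \operatorname{dist}(r, v_D) \ge D + 2$.
2. $r$ must split the path as evenly as possible. This leaves only $v_{D/2}$ when $D$ is even, and $v_{(D-1)/2}$ and $v_{(D+1)/2}$ when $D$ is odd.

These middle nodes do reach $\lceil D/2 \rceil$. A node farther than that from a middle node would, together with the far endpoint of the path, form a path longer than $D$.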
```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class Solution {
    public List<Integer> findMinHeightTrees(int n, int[][] edges) {
        List<Integer> cur = new ArrayList<Integer>();
        if (n == 0) return cur;
        if (n == 1) {
            cur.add(0);
            return cur;
        }

        // Build graph as adjacency sets so a neighbor can be removed in O(1).
        List<Set<Integer>> graph = new ArrayList<Set<Integer>>();
        for (int i = 0; i < n; i++) graph.add(new HashSet<Integer>());
        for (int[] edge : edges) {
            graph.get(edge[0]).add(edge[1]);
            graph.get(edge[1]).add(edge[0]);
        }

        // Get all leaves to start with.
        for (int i = 0; i < n; i++)
            if (graph.get(i).size() == 1) {
                cur.add(i);
            }

        // Remove every layer of leaves.
        while (true) {
            List<Integer> next = new ArrayList<Integer>();
            for (int node : cur)
                for (int neighbor : graph.get(node)) {
                    // Detach the leaf; a neighbor left with one edge is a leaf of the next layer.
                    graph.get(neighbor).remove(node);
                    if (graph.get(neighbor).size() == 1) {
                        next.add(neighbor);
                    }
                }

            // No new leaves: the current layer is the center (one or two nodes).
            if (next.isEmpty()) return cur;
            cur = next;
        }
    }
}
```
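Because the answer is the center of a longest path, it can also be found directly with two breadth-first searches. This is a standard way to find a longest path in a tree: the farthest node from any start is one end of a longest path, and the farthest node from that end is the other. This technique is swapped in and is not the author's method. The sketch assumes valid tree input, and the class name `DiameterCenterMht` is hypothetical.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

class DiameterCenterMht {
    // BFS from `start`; fills parent[] with BFS-tree parents and returns the farthest node found.
    static int bfsFarthest(int start, List<List<Integer>> adj, int[] parent) {
        int[] dist = new int[adj.size()];
        Arrays.fill(dist, -1);
        Arrays.fill(parent, -1);
        ArrayDeque<Integer> queue = new ArrayDeque<>();
        dist[start] = 0;
        queue.add(start);
        int farthest = start;
        while (!queue.isEmpty()) {
            int v = queue.poll();
            if (dist[v] > dist[farthest]) farthest = v;
            for (int u : adj.get(v)) {
                if (dist[u] == -1) {
                    dist[u] = dist[v] + 1;
                    parent[u] = v;
                    queue.add(u);
                }
            }
        }
        return farthest;
    }

    static List<Integer> findMinHeightTrees(int n, int[][] edges) {
        List<Integer> result = new ArrayList<>();
        if (n <= 0) return result;
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        for (int[] e : edges) {
            adj.get(e[0]).add(e[1]);
            adj.get(e[1]).add(e[0]);
        }
        int[] parent = new int[n];
        int a = bfsFarthest(0, adj, parent); // one end of a longest path
        int b = bfsFarthest(a, adj, parent); // the other end; parent[] now leads back to a
        List<Integer> path = new ArrayList<>();
        for (int v = b; v != -1; v = parent[v]) path.add(v);
        int d = path.size() - 1; // length of the longest path, in edges
        // Middle node(s): one when d is even, two when d is odd.
        result.add(path.get(d / 2));
        if (d % 2 == 1) result.add(path.get(d / 2 + 1));
        Collections.sort(result);
        return result;
    }

    public static void main(String[] args) {
        System.out.println(findMinHeightTrees(4, new int[][]{{1, 0}, {1, 2}, {1, 3}}));               // prints [1]
        System.out.println(findMinHeightTrees(6, new int[][]{{0, 3}, {1, 3}, {2, 3}, {4, 3}, {5, 4}})); // prints [3, 4]
    }
}
```

Like the leaf-trimming loop, this runs in O(n) time. It trades the in-place set removals for two linear scans and an explicit path.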