小美的樹上染色(美團2024屆秋招筆試第一場程式設計真題)

athenanevergiveup發表於2024-04-02

題面

核心思想

樹形DP
dp[1]表示以當前節點為根節點所包含的子樹 且 當前節點能染色的最大染色數量
dp[0]表示以當前節點為根節點所包含的子樹 且 當前節點不染色的最大染色數量
詳情看註釋~

程式碼

import java.util.*;

public class Main {
    public static void main(String[] args) {
        final long MOD = (long) (1e9 + 7);
        Scanner scanner = new Scanner(System.in);

        int n = scanner.nextInt();
        long[] value = new long[n + 1];
        List<Integer>[] next = new List[n + 1];
        //存放value
        for(int i = 1; i <= n; i++){
            value[i] = scanner.nextInt();
            next[i] = new ArrayList<>();
        }
        //建樹
        for(int i = 1; i < n; i++){
            int x = scanner.nextInt();
            int y = scanner.nextInt();
            next[x].add(y);
            next[y].add(x);
        }
        int[] res = dpOnTheTree(1, -1, value, next);
        System.out.println(Math.max(res[0], res[1]));
    }

    //dp[0] 表示當前節點為根  當前節點不染色的最大染色數量 dp[1]則表示當前節點染色的最大染色數量
    static int[] dpOnTheTree(int cur, int pre, long[] value, List<Integer>[] next){
        int[] dp = new int[2];
        //存放孩子節點的dp結果
        HashMap<Integer, int[]> res = new HashMap<>();

        //當前節點的dp結果分步做
        //dp[0]
        for(int nxt: next[cur]){
            if(nxt == pre)
                continue;
            int[] child = dpOnTheTree(nxt, cur, value, next);
            res.put(nxt, child);

            // 當前節點不染色 那就是所有孩子節點的最大值和
            dp[0] += Math.max(child[0], child[1]);
        }

        //dp[1]
        for(int nxt: next[cur]){
            if(nxt == pre)
                continue;
            long mul = value[cur] * value[nxt];
            long sqrt = (long) Math.sqrt(mul);
            // 可以和孩子節點染色
            if(sqrt * sqrt == mul){
                // dp[0] 存放的是所有孩子節點染色或不然染色的最大值和
                // Math.max(res.get(nxt)[0], res.get(nxt)[1]) 取需要染色的孩子節點nxt的dp[1], dp[0]的最大值
                // 用當前節點的dp[0]減去就剩下了其他孩子的最大值和
                // 那麼當前節點和孩子節點nxt染色 dp[1]就等於 1. nxt這個孩子節點的dp[0] + 2. 其他孩子節點的dp最大值的和
                dp[1] = Math.max(dp[1], dp[0] - Math.max(res.get(nxt)[0], res.get(nxt)[1]) + res.get(nxt)[0] + 2);
            }
        }
        return dp;
    }
}

相關文章