[AtCoder ABC 351] [F - Double Sum] [線段樹]

fishcanfly發表於2024-05-01

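解法：使用動態開點線段樹。所求為 $\sum_{i<j} \max(A_j - A_i, 0)$。依序處理每個 $A_j$：將 $A_j$ 插入線段樹後，查詢值域 $[0, A_j - 1]$ 中的個數 $c$ 與總和 $s$（此區間只含先前出現且嚴格小於 $A_j$ 的數），則 $A_j$ 對答案的貢獻為 $c \cdot A_j - s$。由於值域達 $10^8$，子節點只在有數值插入時才建立。

請看程式碼：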


import java.io.BufferedReader;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.math.BigInteger;
import java.util.StringTokenizer;

public class Main {
    /** Largest value the tree can store; the root covers {@code [0, MAX_VALUE]}. */
    static final int MAX_VALUE = 100_000_000;

    /**
     * Node of a segment tree over an integer value range {@code [left, right]} whose children
     * are created lazily: a child exists only once some value has been inserted into its half.
     * Each node tracks how many inserted values fall in its range and their total.
     */
    static class SegmentNode {
        final int left;
        final int right;
        int cnt;  // number of inserted values in [left, right]
        long sum; // sum of all inserted values in [left, right]

        SegmentNode leftNode = null;  // covers [left, mid]; null until a value lands there
        SegmentNode rightNode = null; // covers [mid + 1, right]; null until a value lands there

        public SegmentNode(int l, int r) {
            this.left = l;
            this.right = r;
        }

        /**
         * Returns the count and sum of inserted values lying in {@code [x, y]}.
         * Bounds that extend past this node's range are effectively clipped to it.
         *
         * @param x lower bound (inclusive)
         * @param y upper bound (inclusive)
         * @return {@code {count, sum}}; {@code {0, 0}} when {@code y < x} or nothing matches
         */
        public long[] query(int x, int y) {
            if (y < x) {
                return new long[]{0, 0};
            }
            if (x <= left && right <= y) {
                return new long[]{cnt, sum};
            }

            int mid = (left + right) >>> 1;
            if (x > mid) {
                // Entire query lies in the right half.
                return rightNode == null ? new long[]{0, 0} : rightNode.query(x, y);
            }
            if (y <= mid) {
                // Entire query lies in the left half.
                return leftNode == null ? new long[]{0, 0} : leftNode.query(x, y);
            }

            // Query straddles mid: combine both halves.
            long[] lo = leftNode == null ? new long[2] : leftNode.query(x, mid);
            long[] hi = rightNode == null ? new long[2] : rightNode.query(mid + 1, y);
            return new long[]{lo[0] + hi[0], lo[1] + hi[1]};
        }

        /**
         * Inserts value {@code n}, creating missing child nodes along its path.
         *
         * @param n value to insert; must lie in {@code [left, right]}
         * @throws IllegalArgumentException if {@code n} is outside this node's range
         */
        public void add(int n) {
            if (n < left || n > right) {
                // An out-of-range value never reaches a matching leaf, so descending with it
                // would keep creating children until StackOverflowError; reject it up front.
                throw new IllegalArgumentException(
                        "value " + n + " outside [" + left + ", " + right + "]");
            }
            cnt++;
            sum += n;

            if (left == right) {
                return; // leaf reached
            }

            int mid = (left + right) >>> 1;
            if (n > mid) {
                if (rightNode == null) {
                    rightNode = new SegmentNode(mid + 1, right);
                }
                rightNode.add(n);
            } else {
                if (leftNode == null) {
                    leftNode = new SegmentNode(left, mid);
                }
                leftNode.add(n);
            }
        }
    }

    static SegmentNode root = new SegmentNode(0, MAX_VALUE);

    /**
     * Reads N and then N values A_1..A_N, and prints the sum over i < j of max(A_j - A_i, 0).
     *
     * <p>For each A_j, the values inserted so far that are strictly less than A_j (count c,
     * total s) contribute c * A_j - s. Inserting A_j before querying is harmless because the
     * query range [0, A_j - 1] excludes A_j itself.
     */
    public static void main(String[] args) throws IOException {
        int n = rd.nextInt();

        long ans = 0;
        for (int i = 0; i < n; i++) {
            int num = rd.nextInt();
            root.add(num);
            long[] below = root.query(0, num - 1);
            ans += below[0] * num - below[1];
        }
        System.out.println(ans);
    }
}
/**
 * Static whitespace-token reader over standard input.
 *
 * <p>{@link #next()} and the {@code nextXxx} parsers share one tokenizer, so several values on
 * the same line are consumed one at a time. {@link #nextLine()} reads a fresh raw line from the
 * underlying reader; tokens still buffered from the current line are neither returned nor
 * discarded by it.
 */
class rd {
    static BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
    static StringTokenizer tokenizer = new StringTokenizer("");

    /**
     * Reads the next raw line directly from the reader, bypassing buffered tokens.
     *
     * @return the line without its terminator, or {@code null} at end of input
     */
    static String nextLine() throws IOException {
        return reader.readLine();
    }

    /**
     * Returns the next whitespace-separated token, reading more lines as needed (blank lines
     * are skipped).
     *
     * @throws EOFException if input ends before another token is found
     */
    static String next() throws IOException {
        while (!tokenizer.hasMoreTokens()) {
            String line = reader.readLine();
            if (line == null) {
                // new StringTokenizer(null) would otherwise fail with a bare NullPointerException.
                throw new EOFException("unexpected end of input while reading a token");
            }
            tokenizer = new StringTokenizer(line);
        }
        return tokenizer.nextToken();
    }

    /** Reads the next token and parses it as an {@code int}. */
    static int nextInt() throws IOException {
        return Integer.parseInt(next());
    }

    /** Reads the next token and parses it as a {@code double}. */
    static double nextDouble() throws IOException {
        return Double.parseDouble(next());
    }

    /** Reads the next token and parses it as a {@code long}. */
    static long nextLong() throws IOException {
        return Long.parseLong(next());
    }

    /**
     * Reads the next token and parses it as a {@link BigInteger}. This is token-based like the
     * other parsers, so the value may share a line with other tokens.
     */
    static BigInteger nextBigInteger() throws IOException {
        return new BigInteger(next());
    }
}