Removing People 題解

XuYueming發表於2024-10-27

前言

題目連結:Atcoder洛谷

題意簡述

\(n\) 人站成一個圓圈,按順時針方向依次為 \(1, 2, \cdots, n\)

每個人面對的方向由長度為 \(n\) 的字串 \(S\) 給出。對於第 \(i\) 個人,如果 \(S_i = \texttt{L}\),則 \(i\) 面向逆時針方向。如果 \(S_i = \texttt{R}\),則面向順時針方向。

重複 \(n-1\) 次以下操作:以相等的機率從剩餘的人中選擇一個,並從圓中移除離被選中的人最近的人。這樣做的代價等於被選中的人到被移除的人的距離。

定義從 \(i\)\(j\)\(i \neq j\))的距離 \(\operatorname{dis}(i, j)\) 為,\(i\) 按照其方向行走多少步能夠到達 \(j\)

\(n-1\) 次操作後代價之和的期望值,對 \(M = 998244353\) 取模。

\(2 \leq n \leq 300\)

題目分析

期望類題目,我們學過的演算法好像只有 DP 吧?考慮 DP。狀態如何設計呢?移除一個人,難道我們要把 \(n\) 個人還在不在壓到狀態裡嗎?顯然不行。正難則反,我們考慮從只有 \(1\) 個人開始,逐漸往裡面加入一個人。

加入 \(i\),在反轉操作前是刪除 \(i\),說明我們選擇了一個 \(j\),且 \(i\)\(j\) 朝著其方向前進遇到的第一個人,將 \(j\)\(i\) 的距離累加到答案中去。

我們注意到,對於 \(j\)\(i\) 中間的 \(k\),如果加入 \(k\),只可能是選擇了 \(i\)\(j\) 其中的一個。加入 \(k\) 之後,分成了兩個區間,就又形成了規模更小的子問題,且問題僅和區間兩端有關!

考慮區間 DP。設 \(f_{l, r}\) 表示 \(l\)\(r\) 中,最初僅有 \(l\)\(r\) 加入了,經過了若干次操作,把中間的所有人都加入的期望價值。可是,如果要算期望,我們需要知道目前已經放下了多少個人,不然算不了機率,而這是我們狀態之外的東西。

那就別記期望了吧,把期望變成價值和比上方案數。方案數顯然是 \(n!\),那麼我們 DP 價值和,即記 \(f_{l, r}\) 表示把 \(l\)\(r\) 中間的放下的價值和,為了轉移需要再記一個方案數 \(g_{l, r}\)

先來考慮 \(g\) 的轉移。先列舉 \(k\) 表示第一個放下的,這裡需要注意,我們是選擇 \(l\)\(r\) 的哪一個,導致 \(k\) 被放下的呢?如果合法,都有可能,所以這裡的方案數為 \([S_l = \texttt{R}] + [S_r = \texttt{L}]\)。放下後,兩個子問題的方案數直接相乘 \(g_{l, k} g_{k, r}\) 就行了嗎?並不是,因為我們可以交叉著放置,即先放置左邊區間的某一個,再放置右邊的某一個,以此類推。這一部分的方案數是合併兩個有序序列的方案數,設 \(x = \operatorname{dis}(l, k) - 1, y = \operatorname{dis}(k, r) - 1\),即合併兩個長度分別為 \(x, y\) 的有序序列,考慮最終長度為 \(x + y\) 的序列中選出 \(x\) 個位置作為其中一個有序序列,方案數是 \(\dbinom{x + y}{x}\)

說了這麼多,其實就是一個轉移方程:

\[g_{l, r} = \sum \Big([S_l = \texttt{R}] + [S_r = \texttt{L}]\Big) \cdot g_{l, k} \cdot g_{k, r} \cdot \binom{\operatorname{dis}(l, r) - 2}{\operatorname{dis}(l, k) - 1} \]

\(f\) 的轉移很類似,如果是選擇 \(l\) 導致 \(k\) 被加入:

\[f_{l, r} \gets [S_l = \texttt{R}] \sum (\operatorname{dis}(l, k) \cdot g_{l, k} \cdot g_{k, r} + g_{l, k} \cdot f_{k, r} + g_{k, r} \cdot f_{l, k}) \cdot \binom{\operatorname{dis}(l, r) - 2}{\operatorname{dis}(l, k) - 1} \]

選擇 \(r\) 同理有:

\[f_{l, r} \gets [S_r = \texttt{L}] \sum (\operatorname{dis}(k, r) \cdot g_{l, k} \cdot g_{k, r} + g_{l, k} \cdot f_{k, r} + g_{k, r} \cdot f_{l, k}) \cdot \binom{\operatorname{dis}(l, r) - 2}{\operatorname{dis}(l, k) - 1} \]

注意到,之所以我一直避免 \(j - i\) 之類的出現,是因為這是一個環,讀者要處理好環的問題。

DP 初值考慮相鄰的兩項的 \(g = 1\)。答案即為 \(\dfrac{\sum f_{i, i + n}}{n!}\)

時間複雜度:\(\Theta(n ^ 3)\)

程式碼

#include <cstdio>
#include <iostream>
#include <limits>
using namespace std;

namespace Mod_Int_Class {
    template <typename T, typename _Tp>
    constexpr bool in_range(_Tp val) {
        return std::numeric_limits<T>::min() <= val && val <= std::numeric_limits<T>::max();
    }
    
    template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>
    static constexpr inline bool is_prime(_Tp val) {
        if (val < 2) return false;
        for (_Tp i = 2; i * i <= val; ++i)
            if (val % i == 0)
                return false;
        return true;
    }
    
    template <auto _mod = 998244353, typename T = int, typename S = long long>
    class Mod_Int {
        static_assert(in_range<T>(_mod), "mod must in the range of type T.");
        static_assert(std::is_integral<T>::value, "type T must be an integer.");
        static_assert(std::is_integral<S>::value, "type S must be an integer.");
        public:
            constexpr Mod_Int() noexcept = default;
            template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>
            constexpr Mod_Int(_Tp v) noexcept: val(0) {
                if (0 <= T(v) && T(v) < mod) val = v;
                else val = (T(v) % mod + mod) % mod;
            }
            
            constexpr T const& raw() const {
                return this -> val;
            }
            static constexpr T mod = _mod;
            
            template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>
            constexpr friend Mod_Int pow(Mod_Int a, _Tp p) {
                return a ^ p;
            }
            constexpr friend Mod_Int sub(Mod_Int a, Mod_Int b) {
                return a - b;
            }
            constexpr friend Mod_Int& tosub(Mod_Int& a, Mod_Int b) {
                return a -= b;
            }
            
            constexpr friend Mod_Int add(Mod_Int a) { return a; }
            template <typename... args_t>
            constexpr friend Mod_Int add(Mod_Int a, args_t... args) {
                return a + add(args...);
            }
            constexpr friend Mod_Int mul(Mod_Int a) { return a; }
            template <typename... args_t>
            constexpr friend Mod_Int mul(Mod_Int a, args_t... args) {
                return a * mul(args...);
            }
            template <typename... args_t>
            constexpr friend Mod_Int& toadd(Mod_Int& a, args_t... b) {
                return a = add(a, b...);
            }
            template <typename... args_t>
            constexpr friend Mod_Int& tomul(Mod_Int& a, args_t... b) {
                return a = mul(a, b...);
            }
            
            template <T __mod = mod, typename = std::enable_if_t<is_prime(__mod)>>
            static constexpr inline T inv(T a) {
                assert(a != 0);
                return _pow(a, mod - 2);
            }
            
            constexpr Mod_Int& operator + () const {
                return *this;
            }
            constexpr Mod_Int operator - () const {
                return _sub(0, val);
            }
            constexpr Mod_Int inv() const {
                return inv(val);
            }
            
            constexpr friend inline Mod_Int operator + (Mod_Int a, Mod_Int b) {
                return _add(a.val, b.val);
            }
            constexpr friend inline Mod_Int operator - (Mod_Int a, Mod_Int b) {
                return _sub(a.val, b.val);
            }
            constexpr friend inline Mod_Int operator * (Mod_Int a, Mod_Int b) {
                return _mul(a.val, b.val);
            }
            constexpr friend inline Mod_Int operator / (Mod_Int a, Mod_Int b) {
                return _mul(a.val, inv(b.val));
            }
            template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>
            constexpr friend inline Mod_Int operator ^ (Mod_Int a, _Tp p) {
                return _pow(a.val, p);
            }
            
            constexpr friend inline Mod_Int& operator += (Mod_Int& a, Mod_Int b) {
                return a = _add(a.val, b.val);
            }
            constexpr friend inline Mod_Int& operator -= (Mod_Int& a, Mod_Int b) {
                return a = _sub(a.val, b.val);
            }
            constexpr friend inline Mod_Int& operator *= (Mod_Int& a, Mod_Int b) {
                return a = _mul(a.val, b.val);
            }
            constexpr friend inline Mod_Int& operator /= (Mod_Int& a, Mod_Int b) {
                return a = _mul(a.val, inv(b.val));
            }
            template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>
            constexpr friend inline Mod_Int& operator ^= (Mod_Int& a, _Tp p) {
                return a = _pow(a.val, p);
            }
            
            constexpr friend inline bool operator == (Mod_Int a, Mod_Int b) {
                return a.val == b.val;
            }
            constexpr friend inline bool operator != (Mod_Int a, Mod_Int b) {
                return a.val != b.val;
            }
			
			constexpr Mod_Int& operator ++ () {
				this -> val + 1 == mod ? this -> val = 0 : ++this -> val;
				return *this;
			}
			constexpr Mod_Int& operator -- () {
				this -> val == 0 ? this -> val = mod - 1 : --this -> val;
				return *this;
			}
			constexpr Mod_Int operator ++ (int) {
				Mod_Int res = *this;
				this -> val + 1 == mod ? this -> val = 0 : ++this -> val;
				return res;
			}
			constexpr Mod_Int operator -- (int) {
				Mod_Int res = *this;
				this -> val == 0 ? this -> val = mod - 1 : --this -> val;
				return res;
			}
			
			friend std::istream& operator >> (std::istream& is, Mod_Int<mod, T, S>& x) {
				T ipt;
				return is >> ipt, x = ipt, is;
			}
			friend std::ostream& operator << (std::ostream& os, Mod_Int<mod, T, S> x) {
				return os << x.val;
			}
        protected:
            T val;
            
            static constexpr inline T _add(T a, T b) {
                return a >= mod - b ? a + b - mod : a + b;
            }
            static constexpr inline T _sub(T a, T b) {
                return a < b ? a - b + mod : a - b;
            }
            static constexpr inline T _mul(T a, T b) {
                return static_cast<S>(a) * b % mod;
            }
            
            template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>
            static constexpr inline T _pow(T a, _Tp p) {
                T res = 1;
                for (; p; p >>= 1, a = _mul(a, a))
                    if (p & 1) res = _mul(res, a);
                return res;
            }
    };
    using mint = Mod_Int<>;
    constexpr mint operator ""_m (unsigned long long x) {
        return mint(x);
    }
    constexpr mint operator ""_mod (unsigned long long x) {
        return mint(x);
    }
}

using namespace Mod_Int_Class;

const int N = 310;

int n;
char S[N];

mint frac[N], Inv[N], ifrac[N];
mint f[N][N], g[N][N];

inline mint C(int n, int m) {
    return frac[n] * ifrac[m] * ifrac[n - m];
}

signed main() {
    scanf("%d%s", &n, S + 1);
    for (int i = 1; i < n; ++i) g[i][i + 1] = 1;
    g[n][1] = 1, frac[0] = ifrac[0] = 1;
    for (int i = 1; i <= n; ++i) {
        frac[i] = frac[i - 1] * i;
        Inv[i] = i == 1 ? 1 : 0_mod - (mint::mod / i) * Inv[mint::mod % i];
        ifrac[i] = ifrac[i - 1] * Inv[i];
    }
    for (int len = 3; len <= n + 1; ++len)
    for (int l = 1; l <= n; ++l) {
        int r = l + len - 1;
        int rr = r > n ? r - n : r;
        for (int k = l + 1; k < r; ++k) {
            int kk = k > n ? k - n : k;
            mint o = C(r - l - 2, k - l - 1);
            g[l][rr] += g[l][kk] * g[kk][rr] * o;
            if (S[l] == 'R') {
                f[l][rr] += ((k - l) * g[l][kk] * g[kk][rr] + g[l][kk] * f[kk][rr] + g[kk][rr] * f[l][kk]) * o;
            }
            if (S[rr] == 'L') {
                f[l][rr] += ((r - k) * g[l][kk] * g[kk][rr] + g[l][kk] * f[kk][rr] + g[kk][rr] * f[l][kk]) * o;
            }
        }
        g[l][rr] *= (S[l] == 'R') + (S[rr] == 'L');
    }
    mint sum = 0;
    for (int i = 1; i <= n; ++i) sum += f[i][i];
    sum *= ifrac[n];
    printf("%d\n", sum.raw());
    return 0;
}

相關文章