「題解」小 R 打怪獸 monster

Lu_Anlai發表於2021-06-14

本文將同步釋出於:

題目

題目描述

小 R 最近在玩一款遊戲。在遊戲中,小 R 要依次打 \(n\) 個怪獸,他需要打敗至少 \(k\) 個怪獸才能通關。小 R 有兩個屬性值,分別是攻擊力 \(A\) 和耐力 \(R\),每個怪獸也有兩個屬性值,分別是防禦力 \(D\) 和生命值 \(H\)(不同的怪獸屬性值可能不同)。小 R 每攻擊一次怪獸,可以讓怪獸的生命值減少 \(\max(A-D,1)\)點,同時小 R 的耐力會減少 \(1\)。怪獸不會攻擊。若在一次攻擊之後,怪獸的生命值 \(\leq 0\),則小 R 勝利。若在一次攻擊之後,在怪獸的生命值大於 \(0\) 的條件下,小 \(R\) 的耐力值降低到了 \(0\),則怪獸勝利。在和一個怪獸戰鬥結束後,無論輸贏,小 R 都會恢復全部的耐力值。

現在,小 R 的攻擊 \(A\) 和每個怪獸的防禦力 \(D\) 是確定的,小 R 的耐力值 \(R\) 是一個在 \([1,m]\) 區間內的整數,第 \(i\) 個怪獸的生命值是一個在 \([X_i,Y_i]\) 區間內的整數。求有多少種情況使得小 R 能通關,你只需要輸出答案模 \(10^9+7\) 的值就可以了。

兩種情況不同當且僅當這兩種情況下小 R 的耐力值不同或者其中一個怪獸的生命值不同。

\(1\leq k\leq n\leq 50\)\(1\leq m,A,D_i\leq 10^9\)

題解

簡單暴力

考慮列舉耐力值 \(R\in [1,m]\),那麼我們可以輕鬆得到一個關於方案數的動態規劃:

\(d_i=\max(A-D_i,1)\)

\(f_{i,j}\) 表示考慮前 \(i\) 個怪獸,打敗恰好 \(j\) 個的方案數,我們不難得到轉移。

\[\begin{cases} f_{i-1,j-1}(Y_i-X_i+1)&\to f_{i,j},Rd_i\geq Y_i\\ f_{i-1,j}(Y_i-X_i+1)&\to f_{i,j},Rd_i<X_i\\ f_{i-1,j-1}(Rd_i-X_i+1)&\to f_{i,j},X_i\leq Rd_i<Y_i\\ f_{i-1,j}(Y_i-Rd_i)&\to f_{i,j},X_i\leq Rd_i<Y_i\\ \end{cases} \]

這個演算法的時間複雜度為 \(\Theta(n^2m)\)

多項式

我們把 \(f_i\) 看作一個多項式,\(f_{i,j}\)\(x^j\) 的係數。

那麼上面的轉移方程可以看作 \(f_i=f_{i-1}\times (ax+b)\)

最後的結果就是對 \(x^k,x^{k+1},\cdots,x^n\) 的係數求和。

拉格朗日插值

考慮到上述式子可以表示為一個關於 \(R\)\(n\) 次多項式,那麼我們不妨用拉格朗日插值求出 \(n+2\) 個點值,求出字首和的多項式表示,然後做差求解。

求出這個函式的分段點,插值即可。

參考程式

#pragma GCC optimize("Ofast")
#include<bits/stdc++.h>
using namespace std;
#define reg register
typedef long long ll;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
static char buf[1<<21],*p1=buf,*p2=buf;
inline int read(void){
    reg char ch=getchar();
    reg int res=0;
    while(!isdigit(ch)) ch=getchar();
    while(isdigit(ch)) res=10*res+(ch^'0'),ch=getchar();
    return res;
}
 
inline int max(reg int a,reg int b){
    return a>b?a:b;
}
 
inline int min(reg int a,reg int b){
    return a<b?a:b;
}
 
const int MAXN=50+5;
const int mod=1e9+7;
 
struct modInt{
    int x;
    inline modInt(reg int x=0):x(x){
        x=(x%mod+mod)%mod;
        assert(0<=x&&x<mod);
        return;
    }
    inline modInt operator+(const modInt& a)const{
        reg int sum=x+a.x;
        return sum>=mod?sum-mod:sum;
    }
    inline modInt operator-(const modInt& a)const{
        reg int sum=x-a.x;
        return sum<0?sum+mod:sum;
    }
    inline modInt operator*(const modInt& a)const{
        return 1ll*x*a.x%mod;
    }
    inline void operator+=(const modInt& a){
        x+=a.x;
        if(x>=mod) x-=mod;
        return;
    }
    inline void operator-=(const modInt& a){
        x-=a.x;
        if(x<0) x+=mod;
        return;
    }
    inline void operator*=(const modInt& a){
        x=1ll*x*a.x%mod;
        return;
    }
};
 
inline modInt fpow(modInt x,reg int exp){
    modInt res=1;
    while(exp){
        if(exp&1)
            res*=x;
        x*=x,exp>>=1;
    }
    return res;
}
 
inline modInt operator/(const modInt& a,const modInt& b){
    return a*fpow(b,mod-2);
}
 
inline void operator/=(modInt& a,const modInt& b){
    a*=fpow(b,mod-2);
    return;
}
 
struct Node{
    int delta,l,r;
};
 
int n,m,k,A;
Node a[MAXN];
modInt f[MAXN][MAXN];
 
inline modInt getVal(reg int R){
    for(reg int i=0;i<=n;++i)
        for(reg int j=0;j<=i;++j)
            f[i][j]=0;
    f[0][0]=1;
    for(reg int i=0;i<n;++i){
        for(reg int j=0;j<=i;++j)
            if(f[i][j].x){
                reg ll val=1ll*a[i+1].delta*R;
                if(val>=a[i+1].r)
                    f[i+1][j+1]+=f[i][j]*(a[i+1].r-a[i+1].l+1);
                    //(len) x
                else if(val<a[i+1].l)
                    f[i+1][j]+=f[i][j]*(a[i+1].r-a[i+1].l+1);
                    //(len)
                else{
                    f[i+1][j]+=f[i][j]*(a[i+1].r-val);
                    f[i+1][j+1]+=f[i][j]*(val-a[i+1].l+1);
                    //(rig)+(lef)*x
                }
            }
    }
    modInt res=0;
    for(reg int i=k;i<=n;++i)
        res+=f[n][i];
    return res;
}
 
int B;
 
inline modInt Lagrange(reg int lef,reg int rig,modInt x[],modInt y[],reg int X){
    modInt res=0;
    for(reg int i=lef;i<=rig;++i){
        modInt pod=1;
        for(reg int j=lef;j<=rig;++j)
            if(i!=j)
                pod*=(modInt(X)-x[j])/(x[i]-x[j]);
        res+=y[i]*pod;
    }
    return res;
}
 
inline modInt getAns(reg int lef,reg int rig){
    if(rig-lef+1<=B){
        modInt res=0;
        for(reg int i=lef;i<=rig;++i)
            res+=getVal(i);
        return res;
    }
    else{
        modInt x[B],y[B];
        for(reg int i=0;i<B;++i)
            x[i]=lef+i,y[i]=getVal(lef+i);
        for(reg int i=1;i<B;++i)
            y[i]+=y[i-1];
        return Lagrange(0,B-1,x,y,rig)-Lagrange(0,B-1,x,y,lef-1);
    }
}
 
int main(void){
    n=read(),m=read(),k=read(),A=read();
    B=n+2;
    for(reg int i=1;i<=n;++i){
        static int d,x,y;
        d=read(),x=read(),y=read();
        a[i].delta=max(A-d,1),a[i].l=x,a[i].r=y;
    }
    vector<int> V;
    V.push_back(1),V.push_back(m+1);
    for(reg int i=1;i<=n;++i){
        V.push_back((a[i].l+a[i].delta-1)/a[i].delta);
        V.push_back(a[i].r/a[i].delta+1);
    }
    sort(V.begin(),V.end()),V.erase(unique(V.begin(),V.end()),V.end());
    while(V.back()>m+1) V.pop_back();
    modInt ans=0;
    for(reg int i=1,siz=V.size();i<siz;++i)
        ans+=getAns(V[i-1],V[i]-1);
    printf("%d\n",ans.x);
    return 0;
}

相關文章