高維字首和SOSDP

HarlemBlog發表於2024-11-16
更新日誌

概念

高維字首和的名字已經很顯然了,不做過多講解。

思路

基本形式

我們較為熟知的二維字首和,通常情況下使用了容斥的思想。事實上,更通常的二維字首和形式往往長下面這樣:

    for(int i=0;i<n;i++){
        for(int j=0;j<n;j++){
            for(int k=1;k<n;k++){
                s[i][j][k]+=s[i][j][k-1];
            }
        }
    }
    for(int i=0;i<n;i++){
        for(int j=1;j<n;j++){
            for(int k=0;k<n;k++){
                s[i][j][k]+=s[i][j-1][k];
            }
        }
    }
    for(int i=1;i<n;i++){
        for(int j=0;j<n;j++){
            for(int k=0;k<n;k++){
                s[i][j][k]+=s[i-1][j][k];
            }
        }
    }

這麼寫是正確的,但是很繁瑣。通常情況下,SOSDP 應該是不會只有這麼粗暴的解決方式的。

通常形式

一般情況下,我們都可以使用狀態壓縮來解決多維字首和問題。

更具體地,這一類問題往往都與位運算有關,比如要求你找出集合內與為 \(0\) 的數對。或者用一些形象化的描述,比如每個集合內都有一些物品,讓你找出選一些集合使得擁有所有型別的物品的方案數,等等。

我們以第一個例子為例,詳見例題1

我們可以對於每個狀態找出他可選的一種與其與為 \(0\) 的數,也就是找出為 \((1<<m)-1\) 的子集的數(用語不太標準,意思是 \(a\&b=a\)\(a\)\(b\) 子集),那麼就可以考慮字首和解決這種為什麼什麼的子集的問題。

具體的,\(a\) 可以具象化為每一位都小於等於\(b\) 的數,這就很符合多維字首和的形式了。

這時候,如果還使用基本形式,那碼量望而生畏,事實上,有一種更簡潔的寫法:

for(int i=0;i<M;i++){
    for(int j=0;j<1<<M;j++){
        if((j>>i&1)&&(f[j^(1<<i)])){
            f[j]=f[j^(1<<i)];
        }
    }
}

這段程式碼僅適用於這道題目,但是其思路是通用的,我們可以把這麼多位都壓縮到同一個數中作為狀態。

例題

Compatible Numbers

CF165E

用作了上面的例題,這裡就不多講解了,本身也很簡單。

程式碼
#include<bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
typedef pair<int,ll> pil;
typedef pair<ll,int> pli;
template <typename Type>
using vec=vector<Type>;
template <typename Type>
using grheap=priority_queue<Type>;
template <typename Type>
using lrheap=priority_queue<Type,vector<Type>,greater<Type> >;
#define fir first
#define sec second
#define pub push_back
#define pob pop_back
#define puf push_front
#define pof pop_front
#define chmax(a,b) a=max(a,b)
#define chmin(a,b) a=min(a,b)
#define dprint(x) cout<<#x<<"="<<x<<"\n"

const int inf=0x3f3f3f3f;
const ll INF=0x3f3f3f3f3f3f3f3f;
const int mod=1e9+7/*998244353*/;

const int N=1e6+5,M=22;

int n;
int a[N],f[1<<M];

int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);cout.tie(0);
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>a[i];
        f[a[i]]=a[i];
    }
    for(int i=0;i<M;i++){
        for(int j=0;j<1<<M;j++){
            if((j>>i&1)&&(f[j^(1<<i)])){
                f[j]=f[j^(1<<i)];
            }
        }
    }
    for(int i=1;i<=n;i++){
        if(f[(1<<M)-1^a[i]])cout<<f[(1<<M)-1^a[i]];
        else cout<<-1;
        cout<<" ";
    }
    return 0;
}

KOŠARE

LG6442

個人語言表述不夠清晰,附上一份個人認為講的不錯的部落格:

連結

程式碼
#include<bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef unsigned long long ull;
typedef __int128 i128;
typedef double db;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
typedef pair<int,ll> pil;
typedef pair<ll,int> pli;
template <typename Type>
using vec=vector<Type>;
template <typename Type>
using grheap=priority_queue<Type>;
template <typename Type>
using lrheap=priority_queue<Type,vector<Type>,greater<Type> >;
#define fir first
#define sec second
#define pub push_back
#define pob pop_back
#define puf push_front
#define pof pop_front
#define chmax(a,b) a=max(a,b)
#define chmin(a,b) a=min(a,b)
#define dprint(x) cout<<#x<<"="<<x<<"\n"

const int inf=0x3f3f3f3f;
const ll INF=0x3f3f3f3f3f3f3f3f;
const int mod=1e9+7/*998244353*/;

const int N=1e6+5,M=20;

int n,m;
ll s[1<<M];
int st[N];
ll ans;

ll qpow(ll a,ll b){
    ll res=1;
    while(b){
        if(b&1)res=res*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return res;
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);cout.tie(0);
    cin>>n>>m;
    int j,k;
    for(int i=1;i<=n;i++){
        cin>>k;
        st[i]=0;
        while(k--){
            cin>>j;
            st[i]|=1<<j-1;
        }
        s[st[i]]++;
    }
    for(int i=0;i<m;i++){
        for(int j=0;j<1<<m;j++){
            if(j&1<<i)s[j]=(s[j]+s[j^1<<i])%mod;
        }
    }
    for(int t=0;t<1<<m;t++){
        s[t]=(qpow(2,s[t])-1+mod)%mod;
        int pd=__builtin_popcount(t);
        if((pd&1)==(m&1))ans=(ans+s[t])%mod;
        else ans=(ans-s[t]+mod)%mod;
    }
    cout<<ans;
    return 0;
}

相關文章