插頭dp初探

添雅發表於2019-02-22

問題描述

插頭dp用於解決一類可基於圖連通性遞推的問題。用插頭來表示輪廓線上的連通性,然後根據連通性與下一位結合討論進行轉移。

表示連通性的方法

<最小表示法> 與字串迴圈最小表示不同,這種方法用於給輪廓線上的聯通情況確定一個唯一對應的標號序列,做法是從左至右輪廓線掃描,每掃描到一個未標號的位置就新建一個標號,並將輪廓線以後與這一位聯通的位置都標上此號,不被包含的點標號為0。舉例本質相同的連通性((3,3,2,1,3))((2,2,3,1,2))都會被標記為(1,1,2,3,1)

<括號表示法> 用於解決路徑(、迴路)相關的連通性。做法是將輪廓線上方的迴路連結到輪廓向上插頭區別為左插頭與右插頭,逐格轉移時討論格子上邊有左邊的插頭;求解任意路徑問題時保留左右插頭但不合並,並且引入“獨立插頭”表示只有一端連結到輪廓線路徑的連結端;

<其它> 引入一些插頭,然後直接轉換為進位制數表示的不清楚怎麼分類的方法。

具體操作

(f[x,s])為考慮到位置(x) 輪廓線狀態為(s)的解。轉移是個費腦子的事情,按下不表。連通性轉換為進位制數時選擇(2^t)作為進位制數可以更快速的取出、修改輪廓線上某一位的值,但時需要把所有的狀態扔進hash表裡。

練習題 (7/7)

luogu5056 【模板】插頭dp

我就不造輪子了 講解可以參考ladylex 的例題2(雖然不是一道題,但分類討論差不多的而且有圖解)

#include <bits/stdc++.h>
#define LL long long

const int mod=299987;

int n,m,endx,endy;
LL ans;
char a[20][20];

struct hash_set {
    LL val[mod];
    int siz,key[mod],hsh[mod];
    void clear() {
        memset(val,0,sizeof val);
        memset(key,-1,sizeof key);
        memset(hsh,0,sizeof hsh);
        siz=0;
    }
    void newhsh(int id,int vl) {
        hsh[id]=++siz,key[siz]=vl;
    }
    LL&operator[](const int &sta) {
        for(int i=sta%mod; ;i=(i+1==mod?0:i+1)) {
            if(!hsh[i]) newhsh(i,sta);
            if(key[hsh[i]]==sta) return val[hsh[i]];
        }
    }
} f[2];

int find(int sta,int id) {
    return (sta>>((id-1)<<1))&3;
}
void set(int &sta,int bit,int val) {
    bit=(bit-1)<<1;
    sta|=3<<bit;
    sta^=3<<bit;
    sta|=val<<bit; 
}
int link(int sta,int pos) {
    int cnt=0,dlt=(find(sta,pos)==1?1:-1);
    for(int i=pos; i&&i<=m+1; i+=dlt) {
        int plg=find(sta,i);
        if(plg==1) cnt++;
        else if(plg==2) cnt--;
        if(!cnt) return i;
    }
    return -1;
}
void p_dp(int x,int y) {
    int now=((x-1)*m+y)&1,lst=now^1,tot=f[lst].siz;
    f[now].clear();
    for(int i=1; i<=tot; ++i) {
        int sta=f[lst].key[i];
        LL val=f[lst].val[i];
        if(link(sta,y)==-1||link(sta,y+1)==-1) 
            continue; // 狀態不可用 
        int p1=find(sta,y),p2=find(sta,y+1);
        if(a[x][y]!=`.`) {
            if(!p1&&!p2) f[now][sta]+=val;
        } else if(!p1&&!p2) {
            if(a[x+1][y]==`.`&&a[x][y+1]==`.`) {
                set(sta,y,1);
                set(sta,y+1,2);
                f[now][sta]+=val;
            }
        } else if(p1&&!p2) {
            if(a[x+1][y]==`.`) f[now][sta]+=val;
            if(a[x][y+1]==`.`) {
                set(sta,y,0);
                set(sta,y+1,p1);
                f[now][sta]+=val;
            }           
        } else if(!p1&&p2) {
            if(a[x][y+1]==`.`) f[now][sta]+=val;
            if(a[x+1][y]==`.`) {
                set(sta,y,p2);
                set(sta,y+1,0);
                f[now][sta]+=val;
            }
        } else if(p1==1&&p2==1) { // `((` ))
            set(sta,link(sta,y+1),1);
            set(sta,y,0);
            set(sta,y+1,0);
            f[now][sta]+=val;
        } else if(p1==1&&p2==2) { // `()`
            if(x==endx&&y==endy) ans+=val;
        } else if(p1==2&&p2==1) { // `)(` => merge
            set(sta,y,0);
            set(sta,y+1,0);
            f[now][sta]+=val;
        } else if(p1==2&&p2==2) { //(( `))`
            set(sta,link(sta,y),2);
            set(sta,y,0);
            set(sta,y+1,0);
            f[now][sta]+=val;
        }
    }
}

int main() {
    scanf("%d%d",&n,&m);
    for(int i=1; i<=n; ++i) {
        scanf("%s",a[i]+1); 
        for(int j=1; j<=m; ++j) {
            if(a[i][j]==`.`) endx=i,endy=j;
        }
    }
    f[0].clear();
    f[0][0]=1;
    for(int i=1; i<=n; ++i) {
        for(int j=1; j<=m; ++j) p_dp(i,j);
        if(i!=n) {
            int now=(i*m)&1,tot=f[now].siz;
            for(int j=1; j<=tot; ++j) 
                f[now].key[j]<<=2;
        }
    }
    printf("%lld
",ans);
    return 0;
}
luogu2289 郵遞員

容易發現從((1,1))出發再回到((1,1))且所有點都恰號經過一次的方案數正是途中的曼哈頓迴路數目*2(正著走和逆著走),高精度。

#include <bits/stdc++.h>
//using namespace std;

struct cint {
    static const int P=1e9;
    int bit[10];
    cint() { clear();}
    void clear() {
        memset(bit,0,sizeof bit);
    }
    void set(int t) {
        for(clear(); t; bit[++bit[0]]=t%P,t/=P);
    }
    int&operator[](const int &d) {
        return bit[d];
    }
    void print(char ed=`
`) {
        printf("%d",bit[bit[0]]);
        for(int i=bit[0]-1; i>0; --i) printf("%09d",bit[i]);
        putchar(ed);
    }
    cint operator+(cint b) {
        cint c;
        c.clear();
        c[0]=std::max(bit[0],b[0])+1;
        for(int i=1; i<=c[0]; ++i) {
            c[i]+=bit[i]+b[i];
            c[i+1]+=c[i]/P;
            c[i]%=P;
        }
        while(!c[c[0]]) c[0]--;
        return c;
    }
    void operator+=(cint b) {
        *this=*this+b;
    }
    void operator=(int x) {
        set(x);
    }
} ans;
struct hash_map {
    static const int P=299987;
    cint val[P];
    int siz,key[P],hsh[P];
    void clear() {
        siz=0;
        memset(val,0,sizeof val);
        memset(key,-1,sizeof key);
        memset(hsh,0,sizeof hsh);
    }
    void new_hsh(int id,int vl) {
        hsh[id]=++siz,key[siz]=vl;
    }
    cint &operator[](const int &s) {
        for(int i=s%P; ; i=(i+1==P?0:i+1)) {
            if(!hsh[i]) new_hsh(i,s);
            if(key[hsh[i]]==s) return val[hsh[i]];
        } 
    }
} f[2];

int n,m;
int find(int sta,int id) {
    return (sta>>((id-1)<<1))&3;
}
void set(int&sta,int bit,int val) {
    bit=(bit-1)<<1;
    sta|=3<<bit;
    sta^=3<<bit;
    sta|=val<<bit;
}
int link(int sta,int pos) {
    int cnt=0,dlt=(find(sta,pos)==1?1:-1);
    for(int i=pos; i&&i<=m+1; i+=dlt) {
        int plg=find(sta,i);
        if(plg==1) cnt++;
        else if(plg==2) cnt--;
        if(!cnt) return i;
    }
    return -1;
}
void p_dp(int x,int y) {
    int now=((x-1)*m+y)&1,lst=now^1;
    f[now].clear();
    for(int i=1; i<=f[lst].siz; ++i) {
        int sta=f[lst].key[i];
        cint val=f[lst].val[i];
        if(link(sta,y)==-1||link(sta,y+1)==-1) continue;
        int p1=find(sta,y),p2=find(sta,y+1);
        if(!p1&&!p2) {
            if(x!=n&&y!=m) {
                set(sta,y,1);
                set(sta,y+1,2);
                f[now][sta]+=val;
            }
        } else if(p1&&!p2) {
            if(x!=n) f[now][sta]+=val;
            if(y!=m) {
                set(sta,y,0);
                set(sta,y+1,p1);
                f[now][sta]+=val;
            } 
        } else if(!p1&&p2) {
            if(y!=m) f[now][sta]+=val;
            if(x!=n) {
                set(sta,y,p2);
                set(sta,y+1,0);
                f[now][sta]+=val;
            } 
        } else if(p1==1&&p2==1) {
            set(sta,link(sta,y+1),1);
            set(sta,y,0);
            set(sta,y+1,0);
            f[now][sta]+=val;
        } else if(p1==1&&p2==2) {
            if(x==n&&y==m) ans+=val;
        } else if(p1==2&&p2==1) {
            set(sta,y,0);
            set(sta,y+1,0);
            f[now][sta]+=val;
        } else {
            set(sta,link(sta,y),2);
            set(sta,y,0);
            set(sta,y+1,0);
            f[now][sta]+=val;
        }
    }
}

int main() {
    scanf("%d%d",&n,&m);
    if(n==1||m==1) {
        puts("1");
        return 0;
    }
    if(n<m) std::swap(n,m);
    f[0].clear();
    f[0][0]=1;
    for(int i=1; i<=n; ++i) {
        for(int j=1; j<=m; ++j) p_dp(i,j);
        if(i!=n) {
            int now=(i*m)&1;
            for(int j=1; j<=f[now].siz; ++j) 
                f[now].key[j]<<=2;
        }
    }
    ans+=ans;
    ans.print();
    return 0;
}
bzoj2310 ParkII

求最大權任意路徑,引入了獨立插頭

#include <bits/stdc++.h>
#define upd(sta,x) f[now][sta]=max(f[now][sta],(x))
using std::max;

struct hash_map {
    static const int P=23333;
    int siz,hsh[P],val[P],key[P]; 
    void clear() {
        siz=0;
        memset(hsh,0,sizeof hsh);
        memset(key,-1,sizeof key);
        memset(val,-0x3f,sizeof val);
    }
    void new_hsh(int id,int sta) {
        hsh[id]=++siz,key[siz]=sta;
    }
    int &operator[](int sta) {
        for(int i=sta%P; ; i=(i+1==P?0:i+1)) {
            if(!hsh[i]) new_hsh(i,sta);
            if(key[hsh[i]]==sta) return val[hsh[i]];
        }
    }
} f[2];

int n,m,ans=-0x3f3f3f3f,a[101][101];
int find(int sta,int id) {
    return (sta>>((id-1)<<1))&3;
}
void set(int &sta,int bit,int val) {
    bit=(bit-1)<<1;
    sta|=3<<bit;
    sta^=3<<bit;
    sta|=val<<bit;
}
int link(int sta,int pos) {
    int cnt=0,dlt=(find(sta,pos)==1?1:-1);
    for(int i=pos; i&&i<=m+1; i+=dlt) {
        int plg=find(sta,i);
        if(plg==1) cnt++;
        else if(plg==2) cnt--;
        if(cnt==0) return i;
    }
    return -1;
}
bool check(int sta) {
    int cnt=0,cnt1=0;
    for(int i=1; i<=m+1; ++i) {
        int plg=find(sta,i);
        if(plg==3) cnt++;
        else if(plg==1) cnt1++;
        else if(plg==2) cnt1--;
        if(cnt>2/*||cnt1<0*/) break;
    } 
    return cnt<=2&&cnt1==0;
} 
void p_dp(int x,int y) {
    int now=((x-1)*m+y)&1,lst=now^1;
    f[now].clear();
    for(int i=1; i<=f[lst].siz; ++i) {
        int sta=f[lst].key[i];
        int val=f[lst].val[i];
        if(!check(sta)||sta>=(1<<((m+1)<<1))) continue;
        int p1=find(sta,y);
        int p2=find(sta,y+1);
        int idl=sta;
        set(idl,y,0);
        set(idl,y+1,0);
        int ept1=idl,ept2=idl;
        if(!p1&&!p2) {
            upd(idl,val); //跳過這個格子 
            if(x<n&&y<m) set(sta,y,1),set(sta,y+1,2),upd(sta,val+a[x][y]); //新建一對括號 
            if(x<n) set(ept1,y,3),upd(ept1,val+a[x][y]); //新建向下的獨立插頭 
            if(y<m) set(ept2,y+1,3),upd(ept2,val+a[x][y]); //新建向右的獨立插頭
        } else if(p1&&!p2) {
            if(x<n) upd(sta,val+a[x][y]); //向下擴充套件p1 
            if(y<m) set(ept1,y+1,p1),upd(ept1,val+a[x][y]); //向右擴充套件p1 
            if(p1==3) {if(!idl) ans=max(ans,val+a[x][y]);} 
            else set(ept2,link(sta,y),3),upd(ept2,val+a[x][y]); //停止擴充套件p1,p1的另一頭改為獨立插頭 
        } else if(!p1&&p2) { 
            if(y<m) upd(sta,val+a[x][y]); //向右擴充套件p2 
            if(x<n) set(ept2,y,p2),upd(ept2,val+a[x][y]); //向下擴充套件p2
            if(p2==3) {if(!idl) ans=max(ans,val+a[x][y]);} 
            else set(ept1,link(sta,y+1),3),upd(ept1,val+a[x][y]); //停止擴充套件p2,p2的另一頭改為獨立插頭 
        } 
        else if(p1==1&&p2==1) set(ept1,link(sta,y+1),1),upd(ept1,val+a[x][y]); //`((`)) 
        else if(p1==1&&p2==2) continue; //形成迴路,不合法 
        else if(p1==2&&p2==1) upd(idl,val+a[x][y]); //(`)(`) 連線 
        else if(p1==2&&p2==2) set(ept2,link(sta,y),2),upd(ept2,val+a[x][y]); //((`))` 
        else if(p1==3&&p2==3) {if(!idl) ans=max(ans,val+a[x][y]);} 
        else if(p2==3) set(ept1,link(sta,y),3),upd(ept1,val+a[x][y]); //連線
        else if(p1==3) set(ept2,link(sta,y+1),3),upd(ept2,val+a[x][y]); //連線 
    }
}

int main() {
    scanf("%d%d",&n,&m);
    for(int i=1; i<=n; ++i) {
        for(int j=1; j<=m; ++j) {
            scanf("%d",&a[i][j]);
            ans=max(ans,a[i][j]);
        }
    }
    f[0].clear();
    f[0][0]=0;
    for(int i=1; i<=n; ++i) {
        for(int j=1; j<=m; ++j) p_dp(i,j);
        if(i!=n) {
            int now=(i*m)&1;
            for(int j=1; j<=f[now].siz; ++j) 
                f[now].key[j]<<=2;
        } 
    }
    printf("%d
",ans);
    return 0;
}
bzoj2331 [SCOI2011]地板

輪廓線上的狀態0表示無插頭,1表示有一個沒有拐彎的插頭,2表示拐過彎的插頭。

#include <bits/stdc++.h>
const int mod=20110520;

struct hash_map {
    static const int P=233333;
    int siz,hsh[P],val[P],key[P]; 
    void clear() {
        siz=0;
        memset(hsh,0,sizeof hsh);
        memset(key,-1,sizeof key);
        memset(val,0,sizeof val);
    }
    void new_hsh(int id,int sta) {
        hsh[id]=++siz,key[siz]=sta;
    }
    int &operator[](int sta) {
        for(int i=sta%P; ; i=(i+1==P?0:i+1)) {
            if(!hsh[i]) new_hsh(i,sta);
            if(key[hsh[i]]==sta) return val[hsh[i]];
        }
    }
} f[2];

int n,m,edx,edy,ans;
char a[102][102];

int find(int sta,int id) {
    return (sta>>((id-1)<<1))&3;
}
void set(int &sta,int bit,int val) {
    bit=(bit-1)<<1;
    sta|=3<<bit;
    sta^=3<<bit;
    sta|=val<<bit;
}
#define upd(val) f[now][sta]=(f[now][sta]+(val))%mod;
void p_dp(int x,int y) {
    int now=((x-1)*m+y)&1,lst=now^1;
    f[now].clear();
    for(int i=1; i<=f[lst].siz; ++i) {
        int sta=f[lst].key[i];
        int val=f[lst].val[i];
        int p1=find(sta,y);
        int p2=find(sta,y+1);
        if(sta>=(1<<((m+1)<<1))) continue;
        if(a[x][y]!=`_`) {
            if(!p1&&!p2) upd(val);
        } else if(!p1&&!p2) {
            if(a[x+1][y]==`_`) set(sta,y,1),set(sta,y+1,0),upd(val);
            if(a[x][y+1]==`_`) set(sta,y,0),set(sta,y+1,1),upd(val);
            if(a[x][y+1]==`_`&&a[x+1][y]==`_`) set(sta,y,2),set(sta,y+1,2),upd(val); 
        } else if(!p1&&p2) {
            if(p2==1) {
                if(a[x+1][y]==`_`) set(sta,y,p2),set(sta,y+1,0),upd(val);
                if(a[x][y+1]==`_`) set(sta,y,0),set(sta,y+1,2),upd(val);
            } else {
                set(sta,y,0),set(sta,y+1,0),upd(val);
                if(x==edx&&y==edy&&!sta) ans=(ans+val)%mod;
                if(a[x+1][y]==`_`) set(sta,y,2),upd(val); 
            }
        } else if(p1&&!p2) {
            if(p1==1) {
                if(a[x][y+1]==`_`) set(sta,y,0),set(sta,y+1,1),upd(val);
                if(a[x+1][y]==`_`) set(sta,y,2),set(sta,y+1,0),upd(val);
            } else {
                set(sta,y,0),set(sta,y+1,0),upd(val);
                if(x==edx&&y==edy&&!sta) ans=(ans+val)%mod;
                if(a[x][y+1]==`_`) set(sta,y+1,2),upd(val);
            }
        } else if(p1==1&&p2==1) {
            set(sta,y,0),set(sta,y+1,0),upd(val);
            if(x==edx&&y==edy&&!sta) ans=(ans+val)%mod;
        } //其餘情況不合法 
    } 
}

int main() {
    scanf("%d%d",&n,&m);
    for(int i=1; i<=n; ++i) {
        scanf("%s",a[i]+1);
    } 
    if(n<m) { //轉置 
        for(int i=1; i<=n; ++i) {
            for(int j=i+1; j<=m; ++j) 
                std::swap(a[i][j],a[j][i]);
        }
        std::swap(n,m);
    }
    for(int i=1; i<=n; ++i) {
        for(int j=1; j<=m; ++j) {
            if(a[i][j]!=`*`) a[i][j]=`_`;
            if(a[i][j]==`_`) edx=i,edy=j; 
        }
    }
    f[0].clear();
    f[0][0]=1;
    for(int i=1; i<=n; ++i) {
        for(int j=1; j<=m; ++j) p_dp(i,j);
        if(i!=n) {  
            int now=(i*m)&1; 
            for(int j=1; j<=f[now].siz; ++j) 
                f[now].key[j]<<=2;
        } 
    }
    printf("%d",ans);
    return 0;
}
luogu3886 [JLOI2009]神祕的生物

很裸的最小表示法的題目。

#include <bits/stdc++.h>
//using namespace std;
const int inf=0x3f3f3f3f;

struct hash_map {
    static const int P=23333;
    int siz,hsh[P],key[P],val[P];
    void clear() {
        siz=0;
        memset(hsh,0,sizeof hsh);
        memset(key,-1,sizeof key);
        memset(val,-inf,sizeof val);
    }
    void new_hsh(int id,int sta) {
        hsh[id]=++siz,key[siz]=sta;
    }
    int&operator[](int sta) {
        for(int i=sta%P; ; i=(i+1==P?0:i+1)) {
            if(!hsh[i]) new_hsh(i,sta);
            if(key[hsh[i]]==sta) return val[hsh[i]];
        }
    }
} f[2];

int n,ans=-inf;
int now,lst=1,a[10][10];
int find(int sta,int bit) {
    if(!bit) return 0;
    return (sta>>(3*(bit-1)))&7;
} 
void set(int&sta,int bit,int val) {
    bit=3*(bit-1);
    sta|=7<<bit;
    sta^=7<<bit;
    sta|=val<<bit;
}
int count(int sta,int val) {
    int c=0;
    for(int i=1; i<=n; ++i,sta>>=3) 
        if((sta&7)==val) c++;
    return c;
}
int relabel(int sta) {
    static int hsh,cnt,id[10],w[10];
    memset(id,-1,sizeof id);
    hsh=cnt=id[0]=0;
    for(int i=1; i<=n; ++i,sta>>=3) w[i]=sta&7;
    for(int i=n; i; --i) {
        if(id[w[i]]==-1) id[w[i]]=++cnt;
        hsh=hsh<<3|id[w[i]];
    } 
    return hsh;
}
bool unicom(int sta) {
    bool exi=0;
    for(int i=1; i<=n; ++i,sta>>=3) {
        if((sta&7)>1) return 0;
        if((sta&7)==1) exi=1;
    }
    return exi;
}

#define upd(sta,val) f[now][sta]=std::max(f[now][sta],(val))

void p_dp(int x,int y) {
    now=lst,lst^=1;
    f[now].clear();
    for(int i=1; i<=f[lst].siz; ++i) {
        int sta=f[lst].key[i];
        int val=f[lst].val[i];
        int p1=find(sta,y-1);
        int p2=find(sta,y);
        if(!p1&&!p2) {
            upd(sta,val);
            set(sta,y,7);
            upd(relabel(sta),val+a[x][y]);
        } else if(!p1&&p2) {
            if(count(sta,p2)==1) upd(sta,val+a[x][y]);
            else {
                upd(sta,val+a[x][y]);
                set(sta,y,0);
                upd(relabel(sta),val);
            }
        } else if(p1&&!p2) {
            upd(sta,val);
            set(sta,y,p1);
            upd(relabel(sta),val+a[x][y]);
        } else if(p1==p2) {
            upd(sta,val+a[x][y]);
            set(sta,y,0);
            upd(relabel(sta),val);
        } else {
            if(count(sta,p2)==1) {
                for(int j=1,tmp=sta; j<=n; ++j,tmp>>=3) 
                    if((tmp&7)==p1) set(sta,j,p2);
                upd(relabel(sta),val+a[x][y]);
            } else {
                int tmp=sta;
                set(tmp,y,0);
                upd(relabel(tmp),val);
                tmp=sta;
                for(int j=1; j<=n; ++j,tmp>>=3) 
                    if((tmp&7)==p1) set(sta,j,p2);
                upd(relabel(sta),val+a[x][y]);
            }
        }
    }
    for(int i=1; i<=f[now].siz; ++i) {
        if(unicom(f[now].key[i])) ans=std::max(ans,f[now].val[i]);
    }
}

int main() {
    scanf("%d",&n);
    f[now].clear();
    f[now][0]=0;
    for(int i=1; i<=n; ++i) {
        for(int j=1; j<=n; ++j) {
            scanf("%d",&a[i][j]);
            p_dp(i,j);
        }
    }
    printf("%d
",ans);
}

bzoj2595 [WC2008]遊覽計劃

斯坦納樹的解法詳見最小斯坦納樹初探
插頭dp(最小表示法)詳見QAQ

bzoj1494 [NOI2007]生成樹計數

假設已經考慮了前i-1個點,此時輪廓線定義為i-k到i-1的連通性(狀態設為(f[i-1,s]))。k很小,搜尋可知連續k位的連通性表示(最小表示法)不會超過55個。而且在i>=k時的狀態的轉移顯然可以矩乘優化,只需要處理考慮i=k時前k位的連通性(s)的方案數。

可以參考論文

#include <bits/stdc++.h>
using namespace std;
const int N=55;
const int mod=65521;

int k,cnt,expr[N][6];
long long n;

struct mtr {
    int a[N][N];
    int*operator[](int x) {return a[x];}
    mtr operator*(mtr b) {
        static mtr c;
        memset(&c,0,sizeof c);
        for(int i=1; i<=cnt; ++i) {
            for(int k=1; k<=cnt; ++k) {
                for(int j=1; j<=cnt; ++j) {
                    c[i][j]=(c[i][j]+1u*a[i][k]*b[k][j])%mod;
                }
            }
        } 
        return c;
    }
    mtr pow(long long y) {
        static mtr x,c;
        x=*this;
        memset(&c,0,sizeof c);
        for(int i=1; i<=cnt; ++i) c[i][i]=1;
        for(; y; y>>=1,x=x*x) if(y&1) c=c*x;
        return c; 
    }
} ans,A,B;

int id[8000],t[7],tmp[7],vis[7],cpl[6];
int qpow(int x,int y) {
    int c=1;
    for(; y>0; y>>=1,x=1u*x*x%mod) 
        if(y&1) c=1u*c*x%mod;
    return c;
}
void dfs(int dep,int mx) {
    if(dep==k+1) {
        cnt++;
        int hs=0;
        memset(vis,0,sizeof vis);
        for(int i=1; i<=k; ++i) {
            expr[cnt][i]=t[i];
            hs=hs*6+t[i];
            vis[t[i]]++;
        }
        id[hs]=cnt;
        B[cnt][1]=1;
        for(int i=1; vis[i]; ++i) {
            B[cnt][1]=1u*B[cnt][1]*cpl[vis[i]]%mod;
        }
        return;
    }
    for(int i=1; i<=mx; ++i) {
        t[dep]=i,dfs(dep+1,mx);
    }
    t[dep]=mx+1;
    dfs(dep+1,mx+1);
}
void init() {
    for(int i=1; i<=k; ++i) cpl[i]=qpow(i,i-2);
    dfs(1,0);
    for(int i=1; i<=cnt; ++i) {
        copy(expr[i]+1,expr[i]+k+1,t), t[k]=6;
        copy(t,t+k+1,tmp);
        for(int j=0; j<(1<<k); ++j) {
            bool ok=1;
            for(int p=0; p<k; ++p) if((j>>p)&1) {
                int c=t[p];
                if(c==6) { ok=0; break;}
                for(int q=0; q<k; ++q) if(t[q]==c) t[q]=6;
            }
            if(ok) {
                int tot=0,hs=0;
                memset(vis,0,sizeof vis);
                for(int p=1; p<=k; ++p) {
                    if(!vis[t[p]]) vis[t[p]]=++tot;
                    hs=hs*6+(t[p]=vis[t[p]]);
                }
                if(vis[t[0]]) (A[id[hs]][i]+=1)%=mod;
            }
            copy(tmp,tmp+k+1,t);
        }
    }
}

int main() {
    scanf("%d%lld",&k,&n);
    if(n<=k) {
        printf("%d
",qpow(n,n-2));
        return 0;
    }
    init();
    ans=A.pow(n-k)*B;
    printf("%d
",ans[1][1]);
    return 0;
}

相關文章