Codeforces #123D: Suffix Array + Monotonic Stack

Posted by thges on 2018-05-13
D. String
 
 

You are given a string s. Each pair of numbers l and r that fulfills the condition 1 ≤ l ≤ r ≤ |s| corresponds to a substring of the string s, starting at position l and ending at position r (inclusive).

Let's define the function of two strings F(x, y) as follows. We find the list of all pairs of numbers for which the corresponding substrings of string x are equal to string y, and sort this list in increasing order of the first number of each pair. The value of F(x, y) equals the number of non-empty continuous sequences in the list.

For example: F(babbabbababbab, babb) = 6. The list of pairs is as follows:

(1, 4), (4, 7), (9, 12)

Its continuous sequences are:

  • (1, 4)
  • (4, 7)
  • (9, 12)
  • (1, 4), (4, 7)
  • (4, 7), (9, 12)
  • (1, 4), (4, 7), (9, 12)

Your task is to calculate, for the given string s, the sum of F(s, x) over all x in the set of all substrings of s.

Input

The only line contains the given string s, consisting only of lowercase Latin letters (1 ≤ |s| ≤ 10^5).

Output

Print the single number — the sought sum.

Please do not use the %lld specifier to read or write 64-bit integers in C++. It is preferred to use the cin, cout streams or the %I64d specifier.

Examples
input
aaaa
output
20
input
abcdef
output
21
input
abacabadabacaba
output
188
Note

In the first sample the function values at x equal to “a”, “aa”, “aaa” and “aaaa” equal 10, 6, 3 and 1 respectively.

In the second sample for any satisfying x the function value is 1.

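Analysis: suppose a distinct substring x occurs k times in s. Its occurrence list has k entries, and a non-empty continuous sequence is determined by choosing its first and last entries (first ≤ last). So by the definition of F, x contributes k(k+1)/2 to the answer.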

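This form is awkward to sum directly, so let us reinterpret it. Put the k occurrences of x in a row, and match every occurrence once with itself and once with each occurrence after it. The total number of matches, k + k(k-1)/2, is exactly k(k+1)/2, which is the value of the function.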
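The identity behind this reinterpretation can be written out as follows. Here x ranges over the distinct substrings of s, k_x is the number of occurrences of x, and n = |s|. The first sum counts self-matches. Every one of the n(n+1)/2 pairs (l, r) is exactly one occurrence, so that sum is fixed.

```latex
\sum_x \frac{k_x(k_x+1)}{2}
  = \underbrace{\sum_x k_x}_{\text{self-matches}}
  + \underbrace{\sum_x \binom{k_x}{2}}_{\text{pairs of distinct occurrences}},
\qquad
\sum_x k_x = \frac{n(n+1)}{2}
```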

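So we build the suffix array and the height (LCP) array. Two occurrences of the same substring of length L starting at positions i and j exist exactly when L ≤ LCP(suffix i, suffix j). Therefore every pair of distinct suffixes contributes the length of their longest common prefix. For each suffix, we take its longest common prefix with every suffix after it in sorted order and add those lengths to the answer.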

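The sum of these LCPs can be computed with a monotonic stack, because in sorted order the LCP of two suffixes is the minimum of the height values between them. Finally, each suffix also matches itself once, contributing one match per prefix, that is, its own length. Summed over all suffixes of an input string of length n, this gives ans += n(n+1)/2.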

Code:

#include <algorithm>
#include <iostream>
#include <stack>
#include <string>
#include <utility>
using namespace std;

typedef long long ll;
typedef pair<int, int> Pair;   // (height value, number of start positions it covers)

const int N = 1e6 + 5;
int sa[N];    // sa[i]: start of the i-th smallest suffix (sa[0] = n, the empty suffix)
int rk[N];    // rk[i]: rank of the suffix starting at i
int tmp[N];
int lcp[N];   // lcp[i]: LCP of suffixes sa[i] and sa[i+1]; lcp[n] stays 0
int n, k;

// Compare suffixes i and j by (rank of first k chars, rank of next k chars).
bool cmp(int i, int j) {
    if (rk[i] != rk[j]) return rk[i] < rk[j];
    int ri = i + k <= n ? rk[i + k] : -1;
    int rj = j + k <= n ? rk[j + k] : -1;
    return ri < rj;
}

// Suffix array by prefix doubling with std::sort: O(n log^2 n).
void build(const string &s, int *sa) {
    n = (int)s.size();
    for (int i = 0; i <= n; i++) {
        sa[i] = i;
        rk[i] = i < n ? s[i] : -1;
    }
    for (k = 1; k <= n; k *= 2) {
        sort(sa, sa + n + 1, cmp);
        tmp[sa[0]] = 0;
        for (int i = 1; i <= n; i++)
            tmp[sa[i]] = tmp[sa[i - 1]] + (cmp(sa[i - 1], sa[i]) ? 1 : 0);
        for (int i = 0; i <= n; i++) rk[i] = tmp[i];
    }
}

// Height array (Kasai's algorithm): O(n).
void LCP(const string &s, int *sa, int *lcp) {
    n = (int)s.size();
    for (int i = 0; i <= n; i++) rk[sa[i]] = i;
    int h = 0;
    lcp[0] = 0;
    for (int i = 0; i < n; i++) {
        int j = sa[rk[i] - 1];   // suffix just before suffix i in sorted order
        if (h > 0) h--;          // the LCP drops by at most 1 from the previous i
        while (j + h < n && i + h < n && s[j + h] == s[i + h]) h++;
        lcp[rk[i] - 1] = h;
    }
}

stack<Pair> sta;

int main() {
    string s;
    cin >> s;
    n = s.length();
    build(s, sa);
    LCP(s, sa, lcp);
    // Every occurrence matches itself once: n(n+1)/2 occurrences in total.
    ll ans = (ll)n * (ll)(n + 1) / 2;
    // cnt = sum over j <= i of min(lcp[j..i]) = sum of LCP(sa[j], sa[i+1]).
    ll cnt = 0;
    for (int i = 0; i <= n; i++) {
        Pair ins = make_pair(lcp[i], 1);   // contributes lcp[i] * count
        while (!sta.empty() && sta.top().first > ins.first) {
            cnt -= (ll)sta.top().first * sta.top().second;
            ins.second += sta.top().second;
            sta.pop();
        }
        cnt += (ll)ins.first * ins.second;
        sta.push(ins);
        ans += cnt;
    }
    cout << ans << endl;
    return 0;
}

 
