原題連結:https://www.luogu.com.cn/problem/CF25E https://codeforces.com/contest/25/problem/E
題意解讀:給定a,b,c三個字串,求包含a、b、c的最短字串長度。
解題思路:
要得到包含a、b、c的字串,可以透過a、b、c連線形成,而要使得連線後的字串最短,可以儘可能的利用重疊部分
如a、b字串連線的情況可能有三種:連續、交叉、包含
因此,對於a、b字串連線的最短長度,可以先計算a是否包含b,如果包含b,連線後的長度即a的長度;再計算a字尾與b字首公共長度ab,連線後長度為a.size()+b.size()-ab。
而a、b、c連線的方式一共有六種:
a->b->c
a->c->b
b->a->c
b->c->a
c->a->b
c->b->a
列舉所有情況下連線後得到的字串長度,取最短即可。
這裡有一個關鍵函式:int longest(string &x, string &y)
用來計算x字尾與y字首的最長公共長度,且如果x包含y函式返回-1
STL String暴力列舉:
int longest(string &x, string &y)
{
if(x.find(y) != string::npos) return -1; //如果x包含y
int len = min(x.size(), y.size());
int res = 0;
for(int i = 1; i <= len; i++) //列舉前字尾的公共長度
{
string post = x.substr(x.size() - i, i); //x的字尾
string pre = y.substr(0, i); //y的字首
if(post == pre) res = i; //判斷是否相等
}
return res;
}
以上方法是O(n^2)的,必須最佳化
KMP最佳化實現:
int longest_kmp(string &x, string &y)
{
memset(Next, 0, sizeof(Next));
//利用y計算Next陣列
for(int i = 1, j = 0; i < y.size(); i++)
{
while(j && y[i] != y[j]) j = Next[j - 1];
if(y[i] == y[j]) j++;
Next[i] = j;
}
int j = 0;
//在x中找y,如果找到返回-1,如果沒找到,返回y最後匹配的位置即最大公共前字尾長度
for(int i = 0; i < x.size(); i++)
{
while(j && x[i] != y[j]) j = Next[j - 1];
if(x[i] == y[j]) j++;
if(j == y.size()) return -1;
}
return j;
}
100分程式碼:
#include <bits/stdc++.h>
using namespace std;
string a, b, c;
int Next[100005];
//計算相同的x的字尾和y的字首長度,x包含y時返回-1
int longest(string &x, string &y)
{
if(x.find(y) != string::npos) return -1; //如果x包含y
int len = min(x.size(), y.size());
int res = 0;
for(int i = 1; i <= len; i++) //列舉前字尾的公共長度
{
string post = x.substr(x.size() - i, i); //x的字尾
string pre = y.substr(0, i); //y的字首
if(post == pre) res = i; //判斷是否相等
}
return res;
}
//用KMP最佳化上面的longest函式
int longest_kmp(string &x, string &y)
{
memset(Next, 0, sizeof(Next));
//利用y計算Next陣列
for(int i = 1, j = 0; i < y.size(); i++)
{
while(j && y[i] != y[j]) j = Next[j - 1];
if(y[i] == y[j]) j++;
Next[i] = j;
}
int j = 0;
//在x中找y,如果找到返回-1,如果沒找到,返回y最後匹配的位置即最大公共前字尾長度
for(int i = 0; i < x.size(); i++)
{
while(j && x[i] != y[j]) j = Next[j - 1];
if(x[i] == y[j]) j++;
if(j == y.size()) return -1;
}
return j;
}
//計算x->y->z連線順序下的最短字串長度
long long calcu(string &x, string &y, string &z, int xy, int yz, int xz)
{
long long res = x.size(); //加上x的長度
if(xy >= 0) //如果y不是x的子串
{
res += y.size() - xy; //加上y的長度,減去y字首和x字尾重疊部分的長度
if(yz >= 0) res += z.size() - yz; //如果z不是y的子串,加上z的長度,減去z字首和y字尾重疊部分的長度
}
else //如果y是x的子串
{
if(xz >= 0) res += z.size() - xz; //如果z不是x的子串,加上z的長度,減去z字首和x字尾重疊部分的長度
}
return res;
}
int main()
{
cin >> a >> b >> c;
int ab = longest_kmp(a, b);
int ba = longest_kmp(b, a);
int bc = longest_kmp(b, c);
int cb = longest_kmp(c, b);
int ca = longest_kmp(c, a);
int ac = longest_kmp(a, c);
long long ans = 1e18, t;
//abc
t = calcu(a, b, c, ab, bc, ac);
ans = min(ans, t);
//acb
t = calcu(a, c, b, ac, cb, ab);
ans = min(ans, t);
//bac
t = calcu(b, a, c, ba, ac, bc);
ans = min(ans, t);
//bca
t = calcu(b, c, a, bc, ca, ba);
ans = min(ans, t);
//cab
t = calcu(c, a, b ,ca, ab, cb);
ans = min(ans, t);
//cba
t = calcu(c, b, a, cb, ba, ca);
ans = min(ans, t);
cout << ans;
return 0;
}