hihocoder 1260 String Problem I (Trie樹 好題)

_TCgogogo_發表於2016-03-10
時間限制:10000ms
單點時限:1000ms
記憶體限制:256MB

描述

我們有一個字串集合S,其中有N個兩兩不同的字串。

還有M個詢問,每個詢問給出一個字串w,求有多少S中的字串可以由w新增恰好一個字母得到。

字母可以新增在包括開頭結尾在內的任意位置,比如在"abc"中新增"x",就可能得到"xabc", "axbc", "abxc", "abcx".這4種串。

輸入

第一行兩個數N和M,表示集合S中字串的數量和詢問的數量。

接下來N行,其中第i行給出S中第i個字串。

接下來M行,其中第i行給出第i個詢問串。

所有字串只由小寫字母構成。

資料範圍:

N,M<=10000。

S中字串長度和<=100000。

所有詢問中字串長度和<=100000。

輸出

對每個詢問輸出一個數表示答案。

樣例輸入
3 3
tourist
petr
rng
toosimple
rg
ptr
樣例輸出
0
1
1

題目連結:http://hihocoder.com/problemset/problem/1260

題目分析:感覺正解不是Trie但是還是用Trie搞過了,其實就是在trie樹上DFS的時候有一個點可以“失配”而直接跳到下一層,Search(char *s, int pos, int p, bool ok),s表示要查詢的字串,pos表示列舉到s的第pos位,p表示當前父親編號,ok表示之前是否已經跳過一次,true表示沒跳過,這裡有個問題就是比如

2 1
aa
ab
a

錯誤答案跑出來會是3,因為對於aa來說,他先跳和後跳的情況都被計算了,因此還需要記錄一個單詞是否已經滿足情況,由於點的個數太多但是n比較小,所以考慮將單詞結尾標號用map離散化,還有一點要注意的就是當pos==len的時候不能直接return,因為有可能在最後一個單詞那裡跳,比如

1 1

ab

b

所以當pos==len時要先判斷跳不跳,再return

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <map>
using namespace std;
int const MAX = 1e5 + 5;
int n, m, ans, len;
char s[MAX];

struct Trie
{
	int next[MAX * 26][26], end[MAX * 26], tot, root, num;
	bool vis[10005];
	map <int, int> mp;

	inline int Newnode()
	{
		memset(next[tot], -1, sizeof(next[tot]));
		end[tot] = false;
		return tot ++;
	} 

	inline void Init()
	{
		num = 0;
		tot = 0;
		root = Newnode();
	}

	inline void Insert(char *s)
	{
		int p = root;
		for(int i = 0; i < (int)strlen(s); i++)
		{
			int idx = s[i] - 'a';
			if(next[p][idx] == -1)
				next[p][idx] = Newnode();
			p = next[p][idx];
		}
		end[p] = true;
		mp[p] = num ++;
	}

	inline void Search(char *s, int pos, int p, bool ok)
	{
		if(end[p] && !vis[mp[p]] && pos == len)
		{
			vis[mp[p]] = true;
			ans ++;
			return;
		}
		if(ok)
			for(int i = 0; i < 26; i++)
				if(next[p][i] != -1)
					Search(s, pos, next[p][i], false);
		if(pos == len)
			return;
		int idx = s[pos] - 'a';
		if(next[p][idx] != -1)
			Search(s, pos + 1, next[p][idx], ok);

	}

}t;

int main()
{
	t.Init();
	scanf("%d %d",&n, &m);
	for(int i = 0; i < n; i++)
	{
		scanf("%s", s);
		t.Insert(s);
	}
	for(int i = 0; i < m; i++)
	{
		scanf("%s", s);
		ans = 0;
		len = strlen(s);
		memset(t.vis, false, sizeof(t.vis));
		t.Search(s, 0, t.root, true);
		printf("%d\n", ans);
	}
}

相關文章