7.27考試總結(NOIP模擬25)[random·string·queue]

Varuxn發表於2021-07-28

死亡的盡頭,沒有神

T1 random

解題思路

這波是找規律完勝了。。

lby dalao根據樣例找出了正確的式子:\(\dfrac{n^2-1}{9}\)

然而,我這個菜雞卻推出了這樣一個錯誤的式子:\(\dfrac{(n-1)^2\times 2^n}{n^2\times (n+1)}\)

那麼哪個像正解呢,當然是我的這個了(雖然他一點道理沒有)。。。

別的啥也不想說了,看一下官方題解吧。。。

code

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int mod=998244353,INV_9=443664157;
int n,T;
signed main()
{
	scanf("%lld",&T);
	while(T--){
		scanf("%lld",&n);
		n=n%mod;
		printf("%lld\n",(n%mod*n%mod-1)*INV_9%mod);
	}
	return 0;
}

T2 string

解題思路

首先明確一下:是把字首接在字尾後面。

這樣就可以直接拼接對於每一個拼好的串維護一個 Hash 值,然後匹配就好了(40pts到手)

對於官方題解裡的 Subtask3 其實是可以卡到 90pts 的,只要剪一下枝就好了。

只可惜我太菜只卡到了 80pts 。

思路還是和題解一樣的維護每個串的 Hash 後,計算出在大串上作為字尾結尾和字首開頭的數量,分別記到 f 和 g 陣列裡。

\[Ans=\sum\limits_{i=1}^{n}f_i\times g_{i+1} \]

正解還是在求 f 和 g 陣列,對於之前的操作有了一個優化。

先把 n 個串分別正反壓入兩個 Tire 樹,然後就可以處理出每個位置的字尾或者字首的數量。

然後在 Tire 樹上 DFS 一遍就可以求出每一個的 Hash 值,然後的操作就與前面的差不多了,不過是加了一個二分優化。

注意要用不會自動排序的 unordered_map 以及在查詢是否有值的時候用 find 函式,不要直接呼叫值。

code

40pts

#include<bits/stdc++.h>
#define int long long
using namespace std;
inline int read()
{
	int x=0,f=1;
	char ch=getchar();
	while(ch>'9'||ch<'0')
	{
		if(ch=='-')	f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		x=(x<<1)+(x<<3)+(ch^48);
		ch=getchar();
	}
	return x*f;
}
const int N=2e5+10,M=1e3+10;
const unsigned long long base=13331ull;
int n,len,ans;
unsigned long long has[N],ha[N],p[N];
string s,ch[N];
signed main()
{
	cin>>s;
	len=s.size();
	s=" "+s;
	n=read();
	p[0]=1;
	for(int i=1;i<=len+1;i++)
		p[i]=p[i-1]*base;
	for(int i=1;i<=n;i++)
		cin>>ch[i];
	for(int i=1;i<s.size();i++)
		has[i]=has[i-1]*base+s[i];
	for(int i=1;i<=n;i++)
		for(int j=1;j<=n;j++)
		{
			string c=" "+ch[i]+ch[j];
			int fjx=ch[i].size(),le=c.size()-1;
			for(int k=1;k<=le;k++)
				ha[k]=ha[k-1]*base+c[k];
			for(int l=1;l<=fjx;l++)
				for(int r=fjx+1;r<=le;r++)
				{
					int lent=r-l+1;
					for(int k=1;k+lent-1<=len;k++)
					{
						int pos=k+lent-1;
						unsigned long long temp=has[pos]-has[k-1]*p[pos-k+1]+base*10;
						unsigned long long tmp=ha[r]-ha[l-1]*p[r-l+1]+base*10;
						if(temp==tmp)	ans++;
					}
				}
		}
	printf("%lld",ans);
	return 0;
}

80pts(最高90pts)

#include<bits/stdc++.h>
#define int long long
#define ull unsigned long long
using namespace std;
inline int read()
{
	int x=0,f=1;
	char ch=getchar();
	while(ch>'9'||ch<'0')
	{
		if(ch=='-')	f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		x=(x<<1)+(x<<3)+(ch^48);
		ch=getchar();
	}
	return x*f;
}
const int N=1e5+10;
const ull base=131ull;
int n,len,ans,lent[N],f[N],g[N];
ull has[N],p[N];
vector<ull > ha[N];
char s[N],ch[N];
signed main()
{
	scanf("%s",s+1);
	len=strlen(s+1);
	scanf("%lld",&n);
	p[0]=1;
	for(int i=1;i<=len;i++)
		p[i]=p[i-1]*base;
	for(int i=1;i<=len;i++)
		has[i]=has[i-1]*base+s[i];
	for(int i=1;i<=n;i++)
	{
		scanf("%s",ch+1);
		lent[i]=strlen(ch+1);
		ha[i].push_back(0);
		for(int j=1;j<=lent[i];j++)
			ha[i].push_back(ha[i][j-1]*base+ch[j]);
	}
	for(int i=1;i<=len;i++)
		for(int j=1;j<=n;j++)
		{
			int le=lent[j];
			for(int k=1;k<=le;k++)
			{
				ull tmp1=ha[j][le]-ha[j][le-k]*p[k];
				ull tmp2=has[i]-has[i-k]*p[k];
				if(tmp1==tmp2)	f[i]++;
				else	break;
			}
			for(int k=1;k<=le;k++)
			{
				ull tmp1=ha[j][k];
				ull tmp2=has[i+k-1]-has[i-1]*p[k];
				if(tmp1==tmp2)	g[i]++;
				else	break;
			}
		}
	for(int i=1;i<=len;i++)
		ans+=f[i]*g[i+1];
	printf("%lld",ans);
	return 0;
}

正解

#include<bits/stdc++.h>
#define int long long
#define ull unsigned long long
#define f() cout<<"Pass"<<endl
using namespace std;
inline int read()
{
	int x=0,f=1;
	char ch=getchar();
	while(ch>'9'||ch<'0')
	{
		if(ch=='-')	f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		x=(x<<1)+(x<<3)+(ch^48);
		ch=getchar();
	}
	return x*f;
}
const int N=1e5+10,M=5e5+10;
const ull base=1331;
vector<int> ch[N];
char s[N],c[N];
int len[N],n,lent,ans,f[N],g[N];
ull p[N],preh[N],sufh[N];
struct Tire
{
	int all,tre[M][30],val[M*30];
	unordered_map<ull,int> mp;
	Tire(){all=1;}
	void insert(int pos)
	{
		int rt=1;
		for(int i=0;i<len[pos];i++)
		{
			int num=ch[pos][i];
			if(!tre[rt][num])	tre[rt][num]=++all;
			rt=tre[rt][num];
			val[rt]++;
		}
	}
	void dfs(int x,ull cnt)
	{
		mp[cnt]=val[x];
		for(int i=1;i<=26;i++)
			if(tre[x][i])
			{
				val[tre[x][i]]+=val[x];
				dfs(tre[x][i],cnt*base+i);
			}
	}
}pre,suf;
signed main()
{
	scanf("%s",s+1);
	lent=strlen(s+1);
	n=read();
	for(int i=1;i<=n;i++)
	{
		scanf("%s",c+1);
		len[i]=strlen(c+1);
		for(int j=1;j<=len[i];j++)
			ch[i].push_back(c[j]-'a'+1);
	}
	for(int i=1;i<=n;i++)
		pre.insert(i);
	for(int i=1;i<=n;i++)
		reverse(ch[i].begin(),ch[i].end());
	for(int i=1;i<=n;i++)
		suf.insert(i);
	pre.dfs(1,0);
	suf.dfs(1,0);
	p[0]=1;
	for(int i=1;i<=lent;i++)
		p[i]=p[i-1]*base;
	for(int i=1;i<=lent;i++)
		preh[i]=preh[i-1]*base+s[i]-'a'+1;
	for(int i=lent;i>=1;i--)
		sufh[i]=sufh[i+1]*base+s[i]-'a'+1;
	for(int i=1;i<=lent;i++)
	{
		int temp=0,l=1,r=lent-i+1;
		if(pre.mp.find(s[i]-'a'+1)!=pre.mp.end())
		{
			while(l<=r)
			{
				int mid=(l+r)>>1;
				if(pre.mp.find(preh[i+mid-1]-preh[i-1]*p[mid])!=pre.mp.end())	l=mid+1,temp=mid;
				else	r=mid-1;
			}
			g[i]=pre.mp[preh[i+temp-1]-preh[i-1]*p[temp]];	
		}
		temp=0;l=0,r=i;
		if(suf.mp.find(s[i]-'a'+1)!=suf.mp.end())
		{
			while(l<=r)
			{
				int mid=(l+r)>>1;
				if(suf.mp.find(sufh[i-mid+1]-sufh[i+1]*p[mid])!=suf.mp.end())	l=mid+1,temp=mid;
				else	r=mid-1;
			}
			f[i]=suf.mp[sufh[i-temp+1]-sufh[i+1]*p[temp]];	
		}
	}
	for(int i=1;i<=lent;i++)
		ans+=f[i]*g[i+1];
	printf("%lld",ans);
	return 0;
}

T3 queen

解題思路

其實就是推式子,然後敲就行了。。

比較重要的就是一個柿子:\(\sum\limits_{i=1}^{n}i^2=\dfrac{n\times (n+1)\times (2n+1)}{6}\)

考場上推了一下 k=3 的以為 4以及以上的有非常難的一些東西就直接棄掉了。

考完之後看了一下 k=1 的情況沒有取\(\bmod\),掛了20pts

然後就是公式亂用,然後就是非常噁心的邊界問題,官方題解寫的就挺好:

code

80pts

#include<bits/stdc++.h>
#define int long long
using namespace std;
inline int read()
{
	int x=0,f=1;
	char ch=getchar();
	while(ch>'9'||ch<'0')
	{
		if(ch=='-')	f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		x=(x<<1)+(x<<3)+(ch^48);
		ch=getchar();
	}
	return x*f;
}
const int N=1e3+10,mod=3e5+7;
int T,n,m,k,ans,c[N][N];
void get_C()
{
	c[0][0]=1;
	for(int i=1;i<N;i++)
	{
		c[i][0]=c[i][i]=1;
		for(int j=1;j<i;j++)
		{
			c[i][j]=(c[i-1][j-1]+c[i-1][j])%mod;
		}
	}
}
void solve()
{
	n=read();
	m=read();
	k=read();
	ans=0;
	if(k>m&&k>n)
	{
		cout<<0<<endl;
		return ;
	}
	if(k==1)
	{
		cout<<n*m%mod<<endl;
		return ;
	}
	if(m>n)	swap(n,m);
	ans=(c[n][k]*m%mod+c[m][k]*n%mod)%mod;
	for(int i=k;i<m;i++)
		ans=(ans+c[i][k]*4%mod)%mod;
	ans=(ans+2*c[m][k]%mod*(n-m+1)%mod)%mod;
	if(k==3)
	{
		for(int len=2;len<=m;len++)
			ans=(ans+(n-len+1)*(m-len+1)%mod*4%mod)%mod;
		for(int len=2;len<=n&&2*len-1<=m;len++)
			ans=(ans+(n-len+1)*(m-2*len+2)%mod*2%mod)%mod;
		for(int len=2;len<=m&&2*len-1<=n;len++)
			ans=(ans+(m-len+1)*(n-2*len+2)%mod*2%mod)%mod;
	}
	if(k==4)
	{
		for(int len=2;len<=m;len++)
			ans=(ans+(n-len+1)*(m-len+1)%mod)%mod;
		for(int len=2;len<=n&&2*len-1<=m;len++)
			ans=(ans+(n-len+1)*(m-2*len+2)%mod*2%mod)%mod;
		for(int len=2;len<=m&&2*len-1<=n;len++)
			ans=(ans+(m-len+1)*(n-2*len+2)%mod*2%mod)%mod;
		for(int len=1;2*len+1<=m;len++)
			ans=(ans+(m-2*len)*(n-2*len)%mod*5%mod)%mod;
	}
	if(k==5)
	{
		for(int len=1;2*len+1<=m;len++)
			ans=(ans+(m-2*len)*(n-2*len)%mod*2%mod)%mod;
	}
	printf("%lld\n",ans%mod);
}
signed main()
{
	T=read();
	get_C();
	while(T--)	solve();
	return 0;
}

正解(程式碼略醜)

#include<bits/stdc++.h>
#define int long long
#define f() cout<<"Pass"<<endl
using namespace std;
inline int read()
{
	int x=0,f=1;
	char ch=getchar();
	while(ch>'9'||ch<'0')
	{
		if(ch=='-')	f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9')
	{
		x=(x<<1)+(x<<3)+(ch^48);
		ch=getchar();
	}
	return x*f;
}
const int N=1e3+10,mod=3e5+7;
int T,n,m,pn,pm,k,ans,jc[mod+10],inv[mod+10];
int ksm(int x,int y)
{
	int temp=1;
	while(y)
	{
		if(y&1)	temp=temp*x%mod;
		x=x*x%mod;
		y>>=1;
	}
	return temp%mod;
}
int work(int x,int y)
{
	if(y>x)	return 0;
	return jc[x]*ksm(jc[y],mod-2)%mod*ksm(jc[x-y],mod-2)%mod;
}
int C(int x,int y)
{
	if(y>x)	return 0;
	if(!x||!y)	return 1;
	return C(x/mod,y/mod)*work(x%mod,y%mod)%mod;
}
void solve()
{
	n=read();
	m=read();
	k=read();
	if(m<n)	swap(n,m);
	pn=(n-1)%mod+1;
	pm=(m-1)%mod+1;
	if(k>m&&k>n)
	{
		cout<<0<<endl;
		return ;
	}
	if(k==1)
	{
		cout<<pn*pm%mod<<endl;
		return ;
	}
	ans=(C(n,k)*pm%mod+C(m,k)*pn%mod+2*(m-n+1)%mod*C(n,k)%mod+4*C(n,k+1)%mod)%mod;
	if(k==3)
	{
		int tmp1=(min(m,n/2)-1)%mod+1,tmp2=(min(n,m/2)-1)%mod+1;
		n=pn;m=pm;
		ans=(ans+4ll*(m*n%mod*(n-1)%mod-(m+n)*(n-1)%mod*n%mod*ksm(2,mod-2)%mod+(n-1)*n%mod*(2*n-1)%mod*ksm(6,mod-2)%mod+2ll*mod)%mod)%mod;
		ans=(ans+2ll*(-tmp1*(tmp1+1)%mod*ksm(2,mod-2)%mod*(n+m*2)%mod+(tmp1+1)*tmp1%mod*(2*tmp1+1)%mod*ksm(6,mod-2)%mod*2ll%mod+n*m%mod*tmp1%mod+mod*2)%mod)%mod;
		ans=(ans+2ll*(-tmp2*(tmp2+1)%mod*ksm(2,mod-2)%mod*(m+n*2)%mod+(tmp2+1)*tmp2%mod*(2*tmp2+1)%mod*ksm(6,mod-2)%mod*2ll%mod+n*m%mod*tmp2%mod+mod*2)%mod)%mod;
	}
	if(k==4)
	{
		int tmp1=(min(m,n/2)-1)%mod+1,tmp2=(min(n,m/2)-1)%mod+1,temp1=(n-1)/2%mod,temp2=n/2%mod;
		n=pn;m=pm;
		ans=(ans+(m*n%mod*(n-1)%mod-(m+n)*(n-1)%mod*n%mod*ksm(2,mod-2)%mod+(n-1)*n%mod*(2*n-1)%mod*ksm(6,mod-2)%mod+2ll*mod)%mod)%mod;
		ans=(ans+2ll*(-tmp1*(tmp1+1)%mod*ksm(2,mod-2)%mod*(n+m*2)%mod+(tmp1+1)*tmp1%mod*(2*tmp1+1)%mod*ksm(6,mod-2)%mod*2ll%mod+n*m%mod*tmp1%mod+mod*2)%mod)%mod;
		ans=(ans+2ll*(-tmp2*(tmp2+1)%mod*ksm(2,mod-2)%mod*(m+n*2)%mod+(tmp2+1)*tmp2%mod*(2*tmp2+1)%mod*ksm(6,mod-2)%mod*2ll%mod+n*m%mod*tmp2%mod+mod*2)%mod)%mod;
		ans=(ans+4*(n*m%mod*temp1%mod+4*(temp1+1)%mod*temp1%mod*(2*temp1+1)%mod*ksm(6,mod-2)%mod-2*(n+m)%mod*temp1%mod*(temp1+1)%mod*ksm(2,mod-2)%mod+2*mod)%mod)%mod;
		ans=(ans+n*m%mod*temp2%mod+4*(temp2+1)%mod*temp2%mod*(2*temp2+1)%mod*ksm(6,mod-2)%mod-2*(n+m)%mod*temp2%mod*(temp2+1)%mod*ksm(2,mod-2)%mod+2*mod)%mod;
	}
	if(k==5)
	{
		int temp1=(n-1)/2%mod,temp2=n/2%mod;
		n=pn;m=pm;
		ans=(ans+n*m%mod*temp1%mod+4*(temp1+1)%mod*temp1%mod*(2*temp1+1)%mod*ksm(6,mod-2)%mod-2*(n+m)%mod*temp1%mod*(temp1+1)%mod*ksm(2,mod-2)%mod+2*mod)%mod;
		ans=(ans+n*m%mod*temp2%mod+4*(temp2+1)%mod*temp2%mod*(2*temp2+1)%mod*ksm(6,mod-2)%mod-2*(n+m)%mod*temp2%mod*(temp2+1)%mod*ksm(2,mod-2)%mod+2*mod)%mod;
	}
	printf("%lld\n",ans%mod);
}
void init()
{
	jc[0]=1;
	for(int i=1;i<=mod;i++)
		jc[i]=jc[i-1]*i%mod;
	inv[mod-1]=ksm(jc[mod-1],mod-2);
	inv[0]=inv[1]=1;
	for(int i=mod-2;i>=1;i--)
		inv[i]=inv[i+1]*(i+1)%mod;
}
signed main()
{
	T=read();
	init();
	while(T--)	solve();
	return 0;
}

相關文章