POJ 3613 Cow Relays 矩陣乘法Floyd+矩陣快速冪

~hsm~發表於2019-03-05

title

POJ 3613
LUOGU 2886
JZOJ 1355
JZOJ 1932

The meaning of problem

給出一張圖,求k邊最短路,即經過k條邊的最短路。

analysis

思考一下:如果一個矩陣,表示走k條邊後,一張圖的點與點的最短路徑,(a,b)表示從a到b的最短路徑,然後我們把它與自己,按照矩陣乘法的格式“相乘”,把其中的乘改為取minc.a[i][j]=min(c.a[i][j],x.a[i][k]+y.a[k][j]);min,c.a[i][j] = min(c.a[i][j],x.a[i][k]+y.a[k][j]);看不懂先看下面。

這樣得到的是走k+k條邊的矩陣。有點抽象,下面詳細解釋下:

c中的一個點(a,b),當我們用x矩陣和y矩陣求它時,我們列舉了x矩陣的a行所有數,與y矩陣的b列所有數,並且他們的座標只能是相對應的,比如x矩陣的(a,2)這個點,相應的y矩陣點就是(2,b),那麼放到圖上去理解,即從a點經過2點到b點的距離,類似的點不只有2,把所有點列舉完後,c.a[a][b]就是從a到b的最短距離。(意會一下)

這樣下來,會得到走k+k條邊的最短路徑,對於其他的矩陣這樣操作,得到的是他們兩個,經過的邊數相加的結果。(一個經過a條邊後的矩陣 與 一個經過b條邊後的矩陣這樣操作後,是經過a+b條邊後的矩陣,矩陣中存的是最短路徑)。解釋一下:向上面的例子一樣,(a,2)(2,b),是即從a點經過2點到b點的距離,因為x矩陣和y矩陣都是走k條邊後的最短路徑,那麼x矩陣中的(a,2)是走k步後的最短路徑,(2,b)也是,那麼他們相加不就是走k+k條邊後的最短路徑嗎?其他的矩陣一樣。

然後,就可以套用快速冪的模板了,只不過將以前的乘改成加了,也就是倍增的思想的,比如對於走10條邊,它的二進位制是1010,那麼我們就讓在走2(10)邊時的矩陣 乘以 8(1000)邊的矩陣,得到走10條邊的矩陣即開始時由1->2->4->8->16……即倍增中的2次冪。
摘自hwim

code

#include<algorithm>
#include<bitset>
#include<cctype>
#include<cerrno>
#include<clocale>
#include<cmath>
#include<complex>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<ctime>
#include<deque>
#include<exception>
#include<fstream>
#include<functional>
#include<limits>
#include<list>
#include<map>
#include<iomanip>
#include<ios>
#include<iosfwd>
#include<iostream>
#include<istream>
#include<ostream>
#include<queue>
#include<set>
#include<sstream>
#include<stack>
#include<stdexcept>
#include<streambuf>
#include<string>
#include<utility>
#include<vector>
#include<cwchar>
#include<cwctype>
using namespace std;
const int maxn=210,inf=0x3f3f3f3f;
template<typename T>inline void read(T &x)
{
	x=0;
	T f=1,ch=getchar();
	while (!isdigit(ch) && ch^'-') ch=getchar();
	if (ch=='-') f=-1, ch=getchar();
	while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48), ch=getchar();
	x*=f;
}
int tot=200;
map<int,int>m;
struct Cow
{
	int a[maxn+10][maxn+10];
	inline void pre()
	{
		for (int i=1;i<=tot;++i)
			for (int j=1;j<=tot;++j)
				a[i][j]=inf;
	}
}st,ed;
inline Cow mul(Cow a,Cow b)
{
	Cow c;
	c.pre();
	for (int i=1;i<=tot;++i)
		for (int j=1;j<=tot;++j)
			for (int k=1;k<=tot;++k)
				c.a[i][j]=min(c.a[i][j],a.a[i][k]+b.a[k][j]);
	return c;
}
int main()
{
	int n,t,s,e;
	read(n);read(t);read(s);read(e);
	st.pre();
	tot=0;
	while (t--)
	{
		int z,x,y;
		read(z);read(x);read(y);
		x=m[x]?m[x]:(m[x]=++tot);
		y=m[y]?m[y]:(m[y]=++tot);
		st.a[x][y]=st.a[y][x]=z;
	}
	memcpy(ed.a,st.a,sizeof(ed.a));
	--n;
	while (n)
	{
		if (n&1)
			ed=mul(ed,st);
		st=mul(st,st);
		n>>=1;
	}
	printf("%d\n",ed.a[m[s]][m[e]]);
	return 0;
}

相關文章