程式碼隨想錄演算法訓練營第63天 | SPFA演算法最佳化+變式

哆啦**發表於2024-08-08

94.城市間貨物運輸 I
https://kamacoder.com/problempage.php?pid=1152
Bellman_ford 佇列最佳化演算法(又名SPFA)
https://www.programmercarl.com/kamacoder/0094.城市間貨物運輸I-SPFA.html
95.城市間貨物運輸 II
https://kamacoder.com/problempage.php?pid=1153
bellman_ford之判斷負權迴路
https://www.programmercarl.com/kamacoder/0095.城市間貨物運輸II.html
96.城市間貨物運輸 III
https://kamacoder.com/problempage.php?pid=1154
bellman_ford之單源有限最短路
https://www.programmercarl.com/kamacoder/0096.城市間貨物運輸III.html

94. 城市間貨物運輸 I — SPFA

  • 最佳化方式

    • 採用佇列
    • 只按照邊和節點進行遍歷(不需要n-1遍)
    點選檢視程式碼
    from collections import deque
    class edge:
    	def __init__(self,t,v):
    		self.to = t
    		self.val = v
    
    def spfa(n,m,edges,start,end):
    	isqueue = [False]*(n+1)
    	mindist = [float("inf")]*(n+1)
    	mindist[start] = 0
    	que = deque([start])
    
    	while que:
    		curr = que.popleft()
    		isqueue[curr] = False
    		if len(edges[curr])>0:
    			for ed in edges[curr]:
    				to = ed.to
    				val = ed.val
    				if mindist[to]>mindist[curr]+val:
    					mindist[to] = mindist[curr]+val
    					if not isqueue[to]:
    						que.append(to)
    						isqueue[to] = True
    
    	if mindist[end]==float("inf"):
    		return "unconnected"
    	else:
    		return mindist[end]
    
    
    def main():
    	n,m = map(int,input().split())
    	edges = [[] for _ in range(n+1)]
    	# print(edges)
    	for _ in range(m):
    		s,t,v = map(int,input().split())
    		edges[s].append(edge(t,v))
    	res = spfa(n,m,edges,1,n)
    	print(res)
    
    if __name__ == '__main__':
    	main()
    

95. 城市間貨物運輸 II - 判斷負權迴路

  • 問題:判斷是否存在負權迴路

  • 在SPFA中

  • 鬆弛n-1次

    • 如果有沒有負權迴路,n-1次鬆弛所有的邊能得到起點到節點的最短路徑;n次以上minDist陣列中的結果不會有改變;
    • 但是如果第n次也改變:說明存在負權迴路;
    點選檢視程式碼
    def bellman(n,m,edges,start,end):
    	mindist = [float("inf")]*(n+1)
    	mindist[start] = 0
    	count = [0]*(n+1)
    
    	for i in range(n):
    		update = False
    		for s,t,v in edges:
    			if mindist[s]!=float("inf") and mindist[t]>mindist[s]+v:
    				mindist[t] = mindist[s]+v
    				update = True
    		if not update:
    			break
    
    	for s,t,v in edges:
    		if mindist[s]!=float("inf") and mindist[t]>mindist[s]+v:
    			return "circle"
    
    	if mindist[end]==float("inf"):
    		return "unconnected"
    	else:
    		return mindist[end]
    
    def main():
    	n,m = map(int,input().split())
    	grid = [[float("inf")]*(n+1) for _ in range(n+1)]
    	# print(edges)
    	edges = []
    	for _ in range(m):
    		s,t,v = map(int,input().split())
    		edges.append((s,t,v))
    	res = bellman(n,m,edges,1,n)
    	print(res)
    if __name__ == '__main__':
    	main()
    
  • 佇列最佳化

    • 如果某個點更新超過n-1次,說明存在負權迴路;
    點選檢視程式碼
    from collections import deque
    class edge:
    	def __init__(self,t,v):
    		self.to = t
    		self.val = v
    
    def spfa(n,m,edges,start,end):
    	isqueue = [False]*(n+1)
    	mindist = [float("inf")]*(n+1)
    	mindist[start] = 0
    	que = deque([start])
    	count = [0]*(n+1)
    
    	while que:
    		curr = que.popleft()
    		isqueue[curr] = False
    		if len(edges[curr])>0:
    			for ed in edges[curr]:
    				to = ed.to
    				val = ed.val
    				if mindist[to]>mindist[curr]+val:
    					mindist[to] = mindist[curr]+val
    					count[to]+=1
    					if count[to]==n:
    						return "circle"
    					if not isqueue[to]:
    						que.append(to)
    						isqueue[to] = True
    
    	if mindist[end]==float("inf"):
    		return "unconnected"
    	else:
    		return mindist[end]
    
    def main():
    	n,m = map(int,input().split())
    	edges = [[] for _ in range(n+1)]
    	# print(edges)
    	for _ in range(m):
    		s,t,v = map(int,input().split())
    		edges[s].append(edge(t,v))
    	res = spfa(n,m,edges,1,n)
    	print(res)
    if __name__ == '__main__':
    	main()
    

96. 城市間貨物運輸 III

  • 限制條件:在最多經過 k 個城市的條件下,從城市 src 到城市 dst 的最低運輸成本。

  • 分析:

    • 等價於:起點最多經過k + 1 條邊到達終點的最短距離
  • 鬆弛n-1次

    • 限制迴圈k-1次
    • 複製mindist保證每次更新一步
    點選檢視程式碼
    from collections import deque
    
    class edge:
    	def __init__(self,t,v):
    		self.to = t
    		self.val = v
    def spfa(n,m,edges,start,end,k):
    	mindist = [float("inf")]*(n+1)
    	mindist[start]=0
    	isqueue = [False]*(n+1)
    	queue = deque([start])
    	k = k+1
    	while k and queue:
    		k -= 1
    		mindist_copy = mindist[:]
    		visited = [False]*(n+1)
    		que_size = len(queue)
    		for _ in range(que_size):
    			curr = queue.popleft()
    			if len(edges[curr])>0:
    				for ed in edges[curr]:
    					from_ = curr
    					to = ed.to
    					val = ed.val
    					if mindist[to]>mindist_copy[curr]+val:
    						mindist[to] = mindist_copy[curr]+val
    						if not visited[to]:
    							queue.append(to)
    							visited[to] = True
    		# print(mindist)
    	if mindist[end]!=float("inf"):
    		return mindist[end]
    	else:
    		return "unconnected"
    def main():
    	n,m = map(int,input().split())
    	edges = [[] for _ in range(n+1)]
    	# print(edges)
    	for _ in range(m):
    		s,t,v = map(int,input().split())
    		edges[s].append(edge(t,v))
    	src,dst,k = map(int,input().split())
    	res = spfa(n,m,edges,src,dst,k)
    	print(res)
    if __name__ == '__main__':
    	main()
    
  • 佇列最佳化

    • 使用visited陣列記錄遍歷過的點
    • 限制遍歷次數k
    點選檢視程式碼
    def bellman(n, m, edges, start, end, k):
    	mindist = [float("inf")] * (n + 1)
    	mindist[start] = 0
    	for i in range(1, k + 2):
    		mindist_copy = mindist[:]
    		update = False
    		for s, t, v in edges:
    			if mindist_copy[s] != float("inf") and mindist[t] > mindist_copy[s] + v:
    				mindist[t] = mindist_copy[s] + v
    				update = True
    		if not update:
    			break
    	if mindist[end] == float("inf"):
    		return "unreachable"
    	else:
    		return mindist[end]
    def main():
    	n, m = map(int, input().split())
    	edges = []
    	for _ in range(m):
    		s, t, v = map(int, input().split())
    		edges.append((s, t, v))
    	src, dst, k = map(int, input().split())
    	res = bellman(n, m, edges, src, dst, k)
    	print(res)
    if __name__ == '__main__':
    	main()
    

相關文章