33 lines
954 B
Python
33 lines
954 B
Python
def solution(n, s, a, b, fares):
    """Return the minimum total fare for two riders who share a taxi from ``s``.

    The riders travel together from ``s`` to some split node ``t`` and then
    continue separately, one to ``a`` and the other to ``b``.  Every node
    ``1..n`` is tried as the split node ``t``.

    Args:
        n: Number of nodes; nodes are labelled ``1..n``.
        s: Shared starting node.
        a: Destination of the first rider.
        b: Destination of the second rider.
        fares: Iterable of ``(u, v, cost)`` undirected edges.  If the same
            pair of nodes appears more than once, the cheapest cost is kept.

    Returns:
        The smallest ``dist(s, t) + dist(a, t) + dist(b, t)`` over all ``t``,
        or ``float('inf')`` if no ``t`` is reachable from all three nodes.
    """
    graph = {node: {} for node in range(1, n + 1)}
    for u, v, cost in fares:
        # Collapse parallel edges to the cheapest one; store both directions
        # because the road network is undirected.
        cheapest = min(cost, graph[u].get(v, cost))
        graph[u][v] = cheapest
        graph[v][u] = cheapest

    # One search per endpoint.  Edges are stored in both directions, so the
    # distance from a to t equals the distance from t to a.
    from_s, from_a, from_b = (shortest_path(src, graph) for src in (s, a, b))

    return min(
        (from_s[t] + from_a[t] + from_b[t] for t in range(1, n + 1)),
        default=float('inf'),
    )
|
|
|
|
def shortest_path(x, nodes):
    """Return the shortest-path distance from ``x`` to every node in ``nodes``.

    Uses Dijkstra's algorithm with a binary heap.  Stale heap entries are
    skipped on pop rather than removed eagerly.  Edge costs are assumed to be
    non-negative (they are fares in this module).

    Args:
        x: Source node; must be a key of ``nodes``.
        nodes: Adjacency mapping ``{node: {neighbour: edge_cost, ...}}``.
            Every neighbour must itself be a key of ``nodes``.  Node labels
            must be mutually orderable, because they break cost ties in the
            heap.

    Returns:
        Dict mapping each node of ``nodes`` to its distance from ``x``.
        Unreachable nodes keep ``float('inf')``.
    """
    import heapq

    memo = dict.fromkeys(nodes, float('inf'))
    memo[x] = 0
    heap = [(0, x)]
    while heap:
        dist, curr = heapq.heappop(heap)
        # A cheaper path to curr was already settled; this entry is stale.
        if dist > memo[curr]:
            continue
        for nxt, cost in nodes[curr].items():
            total_cost = dist + cost
            if total_cost < memo[nxt]:
                memo[nxt] = total_cost
                heapq.heappush(heap, (total_cost, nxt))
    return memo
|