33 lines
954 B
Python

def solution(n, s, a, b, fares):
    """Return the lowest combined cost of reaching one node from s, a and b.

    Args:
        n: Number of nodes; nodes are labelled 1..n.
        s: First source node.
        a: Second source node.
        b: Third source node.
        fares: Iterable of ``(n1, n2, c)`` triples. Each is stored in both
            directions; when the same pair appears more than once, the
            lowest cost is kept.

    Returns:
        The minimum over every node ``t`` in 1..n of
        ``cost(s, t) + cost(a, t) + cost(b, t)``, where each cost comes from
        ``shortest_path``. ``float('inf')`` if no node yields a finite sum.
    """
    graph = {node: {} for node in range(1, n + 1)}
    for u, v, cost in fares:
        # A missing edge falls back to the new cost, so min() keeps it as-is.
        cheapest = min(graph[u].get(v, cost), cost)
        graph[u][v] = cheapest
        graph[v][u] = cheapest

    from_s = shortest_path(s, graph)
    from_a = shortest_path(a, graph)
    from_b = shortest_path(b, graph)

    combined = (
        from_s[t] + from_a[t] + from_b[t]
        for t in range(1, n + 1)
        if t in from_s
    )
    return min(combined, default=float('inf'))
def shortest_path(x, nodes):
    """Return the lowest total cost from node ``x`` to every node in ``nodes``.

    Costs are relaxed repeatedly until no route can be improved, so the
    loop only terminates when the graph has no negative-cost cycle.

    Args:
        x: Start node; must be a key of ``nodes``.
        nodes: Adjacency mapping ``{node: {neighbour: edge_cost}}``. Every
            neighbour must itself be a key of ``nodes``.

    Returns:
        A dict with one entry per key of ``nodes``: 0 for ``x``, the lowest
        total cost for reachable nodes, and ``float('inf')`` for nodes that
        cannot be reached.

    Raises:
        KeyError: If ``x`` or a neighbour is not a key of ``nodes``.
    """
    # Local import: the file has no import block to extend.
    from collections import deque

    # Key the table by the graph's own labels instead of assuming 1..len(nodes).
    dist = {node: float('inf') for node in nodes}
    dist[x] = 0

    pending = deque([x])  # O(1) pops from the left, unlike list.pop(0)
    queued = {x}          # nodes currently waiting in `pending`
    while pending:
        curr = pending.popleft()
        queued.discard(curr)
        for neighbour, cost in nodes[curr].items():
            candidate = dist[curr] + cost
            if candidate < dist[neighbour]:
                dist[neighbour] = candidate
                # A node already waiting will see the improved cost when it
                # is processed, so queueing it again would be wasted work.
                if neighbour not in queued:
                    queued.add(neighbour)
                    pending.append(neighbour)
    return dist