import heapq


def solution(n, s, a, b, fares):
    """Return the cheapest total fare for a shared-then-split taxi ride.

    The graph is undirected with nodes ``1..n``; each entry of *fares* is
    ``(u, v, cost)``. When several fares connect the same pair of nodes,
    only the cheapest one is kept.

    Riders start together at *s*, share a ride to some meeting point ``t``
    (possibly ``s`` itself), then travel separately to *a* and *b*. Because
    edges are undirected, ``dist(t, a) == dist(a, t)``. The answer is
    therefore ``min_t dist(s,t) + dist(a,t) + dist(b,t)``, computed from
    three single-source shortest-path runs.

    Returns ``float('inf')`` when no node ``t`` is reachable from all three
    of *s*, *a* and *b* (or when ``n == 0``).
    """
    nodes = {i + 1: {} for i in range(n)}
    for u, v, cost in fares:
        # Keep only the cheapest of any parallel edges between u and v.
        best = min(nodes[u].get(v, cost), cost)
        nodes[u][v] = best
        nodes[v][u] = best

    dist_s = shortest_path(s, nodes)
    dist_a = shortest_path(a, nodes)
    dist_b = shortest_path(b, nodes)

    return min(
        (dist_s[t] + dist_a[t] + dist_b[t] for t in range(1, n + 1)),
        default=float('inf'),
    )


def shortest_path(x, nodes):
    """Return shortest-path distances from *x* to every node ``1..len(nodes)``.

    *nodes* maps each node to a ``{neighbor: cost}`` dict. Costs are assumed
    non-negative (Dijkstra's algorithm). Unreachable nodes keep the distance
    ``float('inf')``.
    """
    memo = {i + 1: float('inf') for i in range(len(nodes))}
    memo[x] = 0
    heap = [(0, x)]
    while heap:
        dist, curr = heapq.heappop(heap)
        # Lazy deletion: a node may be pushed several times; only the entry
        # matching its current best distance is expanded.
        if dist > memo[curr]:
            continue
        for nxt, cost in nodes[curr].items():
            cand = dist + cost
            if cand < memo[nxt]:
                memo[nxt] = cand
                heapq.heappush(heap, (cand, nxt))
    return memo