"""Count grid configurations up to row and column permutations.

Each of the ``w * h`` cells of a grid takes one of ``s`` states; two grids
are equivalent when one can be turned into the other by permuting rows
and/or columns.  The number of equivalence classes is the average, over
every (row permutation, column permutation) pair, of ``s ** c`` where ``c``
is the number of cycles that pair induces on the cells (Burnside's lemma).
"""
import itertools
import math

import numpy as np


def solution(w, h, s):
    """Return the number of inequivalent ``w`` x ``h`` grids over ``s`` states.

    Args:
        w: Number of rows.
        h: Number of columns.
        s: Number of states each cell can take.

    Returns:
        The count as a decimal string, e.g. ``'430'``.
    """
    total = sum(s ** cycles for cycles in generate_cycles(w, h))
    # Burnside's lemma makes ``total`` an exact multiple of the group order,
    # so floor division is exact.  True division would return a float
    # (``'430.0'``) and lose precision once ``total`` exceeds 2**53.
    return str(total // (math.factorial(w) * math.factorial(h)))


def generate_cycles(w, h):
    """Yield the cell-cycle count for every (row perm, column perm) pair.

    Yields ``w! * h!`` integers, one per pair, in the order produced by
    ``generate_permutation``.
    """
    identity = np.arange(w * h).reshape((w, h))
    for rows in generate_permutation(list(range(w))):
        for cols in generate_permutation(list(range(h))):
            # ``permuted[x, y] == rows[x] * h + cols[y]``: the flat index of
            # the cell that (x, y) is sent to by this pair.
            yield count_cycle(identity[rows][:, cols])


def generate_permutation(numbers):
    """Yield every permutation of ``numbers`` as a new list.

    Permutations come out in lexicographic order of positions.  An empty
    input yields one empty permutation, so a zero-length side contributes
    the factor 0! = 1.
    """
    for perm in itertools.permutations(numbers):
        yield list(perm)


def count_cycle(arr):
    """Return the number of cycles of the cell permutation encoded by ``arr``.

    Expects a 2-D integer array where ``arr[x, y]`` is the row-major flat
    index (``row * arr.shape[1] + col``) of the cell that ``(x, y)`` maps to.
    """
    visited = np.zeros(arr.shape, dtype=bool)
    rows, cols = arr.shape
    cycles = 0
    for i in range(rows):
        for j in range(cols):
            if visited[i, j]:
                continue
            cycles += 1
            x, y = i, j
            while not visited[x, y]:
                visited[x, y] = True
                # Floor-divide back to (row, col); ``/`` would produce a
                # float, which NumPy rejects as an index on Python 3.
                x, y = divmod(int(arr[x, y]), cols)
    return cycles


tests = [
    ([2, 3, 4], '430'),
    ([2, 2, 2], '7'),
    ([1, 1, 2], '2'),
    ([1, 1, 3], '3'),
    ([2, 1, 2], '3'),
    ([1, 2, 2], '3'),
    ([3, 1, 2], '4'),
    ([1, 3, 2], '4'),
    ([2, 3, 2], '13'),
]


if __name__ == "__main__":
    for args, expected in tests:
        result = solution(*args)
        print(args, result == expected, result, expected)