import math
from collections import Counter

import numpy as np


def solution(w, h, s):
    """Count w x h grids over s symbols that are distinct up to row/column swaps.

    Applies Burnside's lemma to the group of (row permutation, column
    permutation) pairs: the answer is the average, over all w! * h! pairs, of
    s ** (number of cell cycles of that pair).  Pairs are grouped by the cycle
    types of their row and column permutations, so only one term per pair of
    integer partitions of w and h is evaluated.

    Args:
        w: number of rows.
        h: number of columns.
        s: number of symbols (colours) per cell.

    Returns:
        The count as a decimal string.
    """
    memo_f = {w: math.factorial(w), h: math.factorial(h)}
    memo_sp = {}
    total = 0
    for cycle_count, cases in _weighted_cycle_counts(w, h, memo_f):
        if cycle_count not in memo_sp:
            memo_sp[cycle_count] = s ** cycle_count
        # `cases` permutation pairs share this cycle count: weight the term
        # instead of iterating over every pair (w! * h! of them).
        total += cases * memo_sp[cycle_count]
    # The weighted sum is a multiple of the group order (Burnside), so integer
    # division is exact; true division would produce a float string.
    return str(total // (memo_f[w] * memo_f[h]))


def _weighted_cycle_counts(w, h, memo_f):
    """Yield (cycle_count, cases) for every pair of row/column cycle types.

    cycle_count is the number of cell cycles produced by a row permutation of
    cycle type pw combined with a column permutation of cycle type ph (the sum
    of gcd over all pairs of cycle lengths); cases is how many such
    (row permutation, column permutation) pairs exist.
    """
    memo_c = {}
    memo_p = {}
    memo_gcd = {}
    for pw in generate_cycles(w, memo=memo_c):
        row_sizes = pw.split()
        row_cases = count_permutation(pw, memo_f, memo_p)
        for ph in generate_cycles(h, memo=memo_c):
            cycle_count = sum(
                gcd(i, j, memo_gcd) for i in row_sizes for j in ph.split()
            )
            cases = row_cases * count_permutation(ph, memo_f, memo_p)
            yield cycle_count, cases


def generate_cycle_count(w, h, memo_f=None):
    """Yield the cell-cycle count once for every (row perm, column perm) pair.

    This yields w! * h! values, so it is only practical for small grids;
    solution() uses the weighted form instead.

    Args:
        w: number of rows.
        h: number of columns.
        memo_f: optional factorial cache {n: n!}, shared with
            count_permutation and filled in as needed.
    """
    if memo_f is None:
        memo_f = {}
    for cycle_count, cases in _weighted_cycle_counts(w, h, memo_f):
        for _ in range(cases):
            yield cycle_count


def generate_cycles(remain, start=1, memo=None):
    """Yield every partition of `remain` into parts >= `start`, as strings.

    Each partition is a space-separated, non-decreasing list of part sizes and
    describes a permutation cycle type: e.g. '1 2 2' is a permutation of 5
    elements with three cycles, of sizes 1, 2 and 2.

    >>> list(generate_cycles(4))
    ['1 1 1 1', '1 1 2', '1 3', '2 2', '4']

    Args:
        remain: the integer to partition.
        start: smallest allowed part size.
        memo: optional cache {(remain, start): [partition, ...]} that may be
            shared across calls.
    """
    if remain < start:
        return
    if memo is None:
        memo = {}
    key = (remain, start)
    if key not in memo:
        parts = []
        for i in range(start, remain + 1):
            if i == remain:
                parts.append(str(i))
            # Remaining parts must be >= i to keep the list non-decreasing.
            for tail in generate_cycles(remain - i, i, memo):
                parts.append(str(i) + ' ' + tail)
        memo[key] = parts
    yield from memo[key]


def gcd(x, y, memo=None):
    """Return the greatest common divisor of x and y.

    x and y may be ints or decimal strings (callers pass the string parts of
    generate_cycles output directly).  Results are cached in `memo`, keyed by
    the raw (unconverted) arguments, and looked up in either order.
    """
    if memo is None:
        memo = {}
    if (x, y) in memo:
        return memo[(x, y)]
    if (y, x) in memo:
        return memo[(y, x)]
    result = math.gcd(int(x), int(y))
    memo[(x, y)] = result
    return result


def count_permutation(s, memo_f, memo_p):
    """Count the permutations having the cycle type described by `s`.

    For a cycle type with part sizes summing to n, the count is
    n! / prod(size ** freq * freq!) over each distinct size and its frequency.
    e.g. '1 1' => 1, '2 2 1' => 15, '3' => 2.

    Args:
        s: space-separated cycle sizes, e.g. '1 2 2' (order does not matter).
        memo_f: factorial cache {n: n!}; missing entries are filled in.
        memo_p: result cache {s: count}.

    Returns:
        The number of permutations, as an int.
    """
    if s not in memo_p:
        sizes = [int(c) for c in s.split()]
        length = sum(sizes)
        if length not in memo_f:
            memo_f[length] = math.factorial(length)
        denominator = 1
        for size, freq in Counter(sizes).items():
            if freq not in memo_f:
                memo_f[freq] = math.factorial(freq)
            denominator *= memo_f[freq] * size ** freq
        # The denominator always divides length!, so this is exact.
        memo_p[s] = memo_f[length] // denominator
    return memo_p[s]


def _solution(w, h, s):
    """Brute-force reference for solution(): enumerates all w! * h! pairs."""
    total = sum(s ** cycle for cycle in _generate_cycles(w, h))
    return str(total // (math.factorial(w) * math.factorial(h)))


def _generate_cycles(w, h):
    """Yield the cell-cycle count of every (row perm, column perm) pair."""
    identity = np.arange(w * h).reshape((w, h))
    for i in _generate_permutation(list(range(w))):
        for j in _generate_permutation(list(range(h))):
            yield _count_cycle(identity[i][:, j])


def _generate_permutation(numbers):
    """Yield every ordering of the list `numbers`, each as a new list."""
    if len(numbers) <= 1:
        yield list(numbers)
        return
    for i, n in enumerate(numbers):
        for rest in _generate_permutation(numbers[:i] + numbers[i + 1:]):
            yield [n] + rest


def _count_cycle(arr):
    """Count the cycles of the cell permutation described by the 2-D array `arr`.

    arr[x, y] holds the flat (row-major) index of the cell that (x, y) is
    sent to.
    """
    w, h = arr.shape
    cycle = np.zeros(arr.shape, dtype=int)
    cycle_id = 0
    for i in range(w):
        for j in range(h):
            if cycle[i, j] != 0:
                continue
            cycle_id += 1
            x, y = i, j
            while cycle[x, y] == 0:
                cycle[x, y] = cycle_id
                # Floor division: a flat index maps back to integer (row, col).
                x, y = divmod(int(arr[x, y]), h)
    return cycle_id


tests = [
    ([2, 3, 4], '430'),
    ([2, 2, 2], '7'),
    ([1, 1, 2], '2'),
    ([1, 1, 3], '3'),
    ([2, 1, 2], '3'),
    ([1, 2, 2], '3'),
    ([12, 1, 2], '13'),
    ([1, 12, 2], '13'),
    ([2, 3, 2], '13'),
    ([3, 3, 2], '36'),
    ([4, 4, 2], '317'),
    ([5, 4, 2], '1053'),
    ([5, 5, 2], '5624'),
    ([6, 5, 2], '28576'),
    ([6, 6, 2], '251610'),
]


if __name__ == "__main__":
    for i, o in tests:
        result = solution(*i)
        print(i, result == o, result, o)