diff --git a/level5/disorderly-escape/solution.py b/level5/disorderly-escape/solution.py index 2921b7b..3c8b92e 100644 --- a/level5/disorderly-escape/solution.py +++ b/level5/disorderly-escape/solution.py @@ -2,23 +2,106 @@ import math import numpy as np def solution(w, h, s): - return str(sum([s**cycle for cycle in generate_cycles(w, h)])/(math.factorial(w)*math.factorial(h))) + memo_f = {w:math.factorial(w), h:math.factorial(h)} + memo_sp = {} + cases = 0 + for c in generate_cycle_count(w, h, memo_f): + if c not in memo_sp: + memo_sp[c] = s ** c + cases += memo_sp[c] + return str(cases / (memo_f[w] * memo_f[h])) -def generate_cycles(w, h): +def generate_cycle_count(w, h, memo_f = {}): + memo_c = {} + memo_p = {} + memo_gcd = {} + for pw in generate_cycles(w, memo = memo_c): + for ph in generate_cycles(h, memo = memo_c): + cycle_count = 0 + for i in pw.split(): + for j in ph.split(): + cycle_count += gcd(i, j, memo_gcd) + cases = count_permutation(pw, memo_f, memo_p) * count_permutation(ph, memo_f, memo_p) + #print pw, ph, cases, cycle_count + for i in range(cases): + yield cycle_count + +# generate cycle list in a string format +# e.g., a string with length 5 can be composed with '2 2 1' which means there are 3 cycles, and foreach size is 2, 2 and 1. +def generate_cycles(remain, start = 1, memo = {}): + if remain < start: + return + if (remain, start) in memo: + for j in memo[(remain, start)]: + yield j + else: + memo[(remain, start)] = [] + for i in range(start, remain+1): + if i == remain: + memo[(remain, start)] += [str(i)] + for j in generate_cycles(remain - i, i, memo): + memo[(remain, start)] += [str(i) + ' ' + j] + for j in memo[(remain, start)]: + yield j + +def gcd(x, y, memo = {}): + if (x, y) in memo: + return memo[(x, y)] + if (y, x) in memo: + return memo[(y, x)] + m, n = int(x), int(y) + while n != 0: + t = m % n + m, n = n, t + memo[(x, y)] = m + return m + +# count number of possible permutation +# e.g., '1 1' => 12 +# e.g., '2 2 1' => 12233, 21233, 22133, 22313, 22331, 12323, 21323, 23123, 23213, 23231, 12332, 21332, 23132, 23312, 23321 +# e.g., '3' => 111(*2) +def count_permutation(s, memo_f, memo_p): + if s not in memo_p: + group = {} + length = 0 + for c in s.split(): + size = int(c) + if size in group: + group[size] += 1 + else: + group[size] = 1 + length += size + result = memo_f[length] + for size, freq in group.items(): + if freq not in memo_f: + memo_f[freq] = math.factorial(freq) + result /= memo_f[freq] * (size ** freq) + memo_p[s] = result + return memo_p[s] + + + + + +def _solution(w, h, s): + return str(sum([s**cycle for cycle in _generate_cycles(w, h)])/(math.factorial(w)*math.factorial(h))) + +def _generate_cycles(w, h): identity = np.arange(w*h).reshape((w, h)) - for i in generate_permutation(list(range(w))): - for j in generate_permutation(list(range(h))): - yield count_cycle(identity[i][:, j]) + for i in _generate_permutation(list(range(w))): + for j in _generate_permutation(list(range(h))): + #print i, j, _count_cycle(identity[i][:, j]) + yield _count_cycle(identity[i][:, j]) -def generate_permutation(numbers): +def _generate_permutation(numbers): if len(numbers) == 1: yield numbers else: for i, n in enumerate(numbers): - for j in generate_permutation(numbers[:i] + numbers[i+1:]): + for j in _generate_permutation(numbers[:i] + numbers[i+1:]): yield [n] + j -def count_cycle(arr): +def _count_cycle(arr): cycle = np.zeros(arr.shape, dtype=int) w, h = arr.shape cycle_id = 0 @@ -32,20 +115,33 @@ def count_cycle(arr): x, y = arr[x, y]/h, arr[x, y]%h return cycle_id -tests = [ - ([2, 3, 4], '430'), - ([2, 2, 2], '7'), - ([1, 1, 2], '2'), - ([1, 1, 3], '3'), - ([2, 1, 2], '3'), - ([1, 2, 2], '3'), - ([3, 1, 2], '4'), - ([1, 3, 2], '4'), - ([2, 3, 2], '13'), +tests = [ + #([2, 3, 4], '430'), + #([2, 2, 2], '7'), + #([1, 1, 2], '2'), + #([1, 1, 3], '3'), + #([2, 1, 2], '3'), + #([1, 2, 2], '3'), + #([3, 1, 2], '4'), + #([1, 3, 2], '4'), + #([4, 1, 2], '5'), + #([1, 4, 2], '5'), + #([5, 1, 2], '6'), + #([1, 5, 2], '6'), + #([6, 1, 2], '7'), + #([1, 6, 2], '7'), + ([12, 1, 2], '13'), + #([1, 12, 2], '13'), + #([2, 3, 2], '13'), + #([3, 3, 2], '36'), + #([4, 4, 2], '317'), + #([5, 4, 2], '1053'), + #([5, 5, 2], '5624'), + #([6, 5, 2], '28576'), + #([6, 6, 2], '251610'), + #([12, 12, 20], '251610'), ] for i, o in tests: result = solution(*i) print (i, result == o, result, o) - -