# 112 lines
# 3.5 KiB
# Python

import math
import numpy as np
def solution(w, h, s):
    """Count w x h grids with s possible states per cell, up to row/column
    permutation, and return the count as a decimal string.

    Two grids count as the same if one becomes the other by permuting its
    w columns and its h rows. By Burnside's lemma the number of distinct
    grids is

        (1 / (w! * h!)) * sum over permutation pairs of s ** cycles(pair)

    where cycles(pair) is the number of cycles the pair induces on the
    w * h cells. ``generate_cycle_count`` groups the pairs by cycle type and
    yields (number of pairs with that type, cycles per pair).

    Example: ``solution(2, 3, 4) == '430'``.
    """
    memo_f = {w: math.factorial(w), h: math.factorial(h)}
    powers = {}  # n_cycles -> s ** n_cycles; many cycle types share a count
    cases = 0
    for weight, n_cycles in generate_cycle_count(w, h, memo_f):
        if n_cycles not in powers:
            powers[n_cycles] = s ** n_cycles
        # `weight` is an exact permutation count; coerce it to int so the sum
        # stays in exact arbitrary-precision integer arithmetic even if the
        # helper hands back an integral float.
        cases += int(weight) * powers[n_cycles]
    # Burnside's lemma makes `cases` an exact multiple of w! * h!, so floor
    # division is exact. True division (`/`) yields a float under Python 3:
    # '430.0' instead of '430', and a lossy scientific-notation value for
    # large answers.
    return str(cases // (memo_f[w] * memo_f[h]))
def generate_cycle_count(w, h, memo_f=None):
    """Yield ``(weight, cycle_count)`` for every pair of cycle types on w and h.

    For each cycle type ``pw`` of a permutation of w elements and ``ph`` of a
    permutation of h elements (both as produced by ``generate_cycles``):

    * ``weight`` is count_permutation(pw) * count_permutation(ph), the number
      of (w-permutation, h-permutation) pairs with exactly those cycle types;
    * ``cycle_count`` is the sum of gcd(i, j) over every cycle length i in
      ``pw`` and j in ``ph``. A cycle of length i on one axis and one of
      length j on the other split their i * j cells into gcd(i, j) cycles.

    Args:
        w, h: sizes of the two sets being permuted.
        memo_f: optional factorial cache {n: n!}. It is read and extended by
            ``count_permutation``. A fresh one is created when omitted.
    """
    # The former `memo_f={}` default was one dict shared (and mutated) by
    # every call. It also lacked the w!/h! entries that count_permutation
    # looks up, so calling without memo_f raised KeyError.
    if memo_f is None:
        memo_f = {}
    for n in (w, h):
        if n not in memo_f:
            memo_f[n] = math.factorial(n)
    memo_c = {}  # partition cache shared by both generate_cycles walks
    memo_p = {}
    memo_gcd = {}
    for pw in generate_cycles(w, memo=memo_c):
        # Invariant across the inner loop: compute once per pw.
        pw_sizes = pw.split()
        pw_count = count_permutation(pw, memo_f, memo_p)
        for ph in generate_cycles(h, memo=memo_c):
            ph_sizes = ph.split()
            cycle_count = 0
            for i in pw_sizes:
                for j in ph_sizes:
                    cycle_count += gcd(i, j, memo_gcd)
            variations = pw_count * count_permutation(ph, memo_f, memo_p)
            yield variations, cycle_count
def generate_cycles(remain, start=1, memo=None):
    """Yield every partition of ``remain`` into parts of size >= ``start``.

    Each partition is a space-separated string of part sizes in
    non-decreasing order. Read as a cycle type, ``'1 2 2'`` describes a
    permutation of 5 elements with one fixed point and two 2-cycles.
    Partitions are yielded in lexicographic order of their part sequences,
    e.g. ``generate_cycles(4)`` yields '1 1 1 1', '1 1 2', '1 3', '2 2', '4'.
    Nothing is yielded when ``remain < start`` (including ``remain == 0``).

    Args:
        remain: the integer to partition.
        start: smallest allowed part size.
        memo: optional cache {(remain, start): [partition strings]}. Pass the
            same dict across calls to reuse results. A fresh one is used when
            omitted; the former ``memo={}`` default was silently shared by
            every call in the process and never freed.
    """
    if memo is None:
        memo = {}
    if remain < start:
        return
    key = (remain, start)
    if key not in memo:
        found = []
        for i in range(start, remain + 1):
            if i == remain:
                found.append(str(i))
            # Every later part is >= i, which keeps each partition sorted so
            # it is produced exactly once. remain - i < remain, so the
            # recursion never revisits `key`.
            for tail in generate_cycles(remain - i, i, memo):
                found.append(str(i) + ' ' + tail)
        memo[key] = found
    for partition in memo[key]:
        yield partition
def gcd(x, y, memo=None):
    """Return the greatest common divisor of ``x`` and ``y``.

    Arguments may be ints or numeric strings (e.g. the cycle sizes produced
    by ``str.split``). Both are converted with ``int`` before Euclid's
    algorithm runs. The result is cached in ``memo`` under the original
    ``(x, y)`` key, and an existing ``(y, x)`` entry is reused because gcd is
    symmetric.

    Args:
        x, y: the two values.
        memo: optional cache dict to share across calls. A fresh one is used
            when omitted; the former ``memo={}`` default was a process-wide
            cache that grew without bound.
    """
    if memo is None:
        memo = {}
    for key in ((x, y), (y, x)):
        if key in memo:
            return memo[key]
    m, n = int(x), int(y)
    while n != 0:
        m, n = n, m % n
    memo[(x, y)] = m
    return m
def count_permutation(s, memo_f, memo_p):
    """Return how many permutations have the cycle type ``s``.

    ``s`` is a space-separated list of cycle sizes (the format produced by
    ``generate_cycles``) describing a permutation of n = sum(sizes)
    elements. The count is the conjugacy-class size

        n! / prod_k (k ** m_k * m_k!)

    where m_k is the number of cycles of size k. Examples: '1 1' -> 1 (the
    identity on 2 elements), '3' -> 2, '2 2 1' -> 15.

    Args:
        s: cycle-type string, e.g. '1 2 2'.
        memo_f: factorial cache {n: n!}. Missing entries (including n itself)
            are computed and added instead of raising KeyError.
        memo_p: result cache keyed by ``s``.

    Returns:
        The count as an exact ``int``.
    """
    if s in memo_p:
        return memo_p[s]

    def fact(n):
        # Fill the shared factorial cache on demand.
        if n not in memo_f:
            memo_f[n] = math.factorial(n)
        return memo_f[n]

    group = {}  # cycle size -> number of cycles with that size
    length = 0
    for c in s.split():
        size = int(c)
        group[size] = group.get(size, 0) + 1
        length += size

    denominator = 1
    for size, freq in group.items():
        denominator *= size ** freq * fact(freq)
    # The denominator always divides n!, so floor division is exact and the
    # result stays an int. True division (`/=`) returned a float in Python 3.
    result = fact(length) // denominator
    memo_p[s] = result
    return result
# Known answers: each entry is ([w, h, s], expected solution(w, h, s)).
tests = [
    ([2, 3, 4], '430'),
    ([2, 2, 2], '7'),
    ([1, 1, 2], '2'),
    ([1, 1, 3], '3'),
    ([2, 1, 2], '3'),
    ([1, 2, 2], '3'),
    ([3, 1, 2], '4'),
    ([1, 3, 2], '4'),
    ([4, 1, 2], '5'),
    ([1, 4, 2], '5'),
    ([5, 1, 2], '6'),
    ([1, 5, 2], '6'),
    ([6, 1, 2], '7'),
    ([1, 6, 2], '7'),
    ([12, 1, 2], '13'),
    ([1, 12, 2], '13'),
    ([2, 3, 2], '13'),
    ([3, 3, 2], '36'),
    ([4, 4, 2], '317'),
    ([5, 4, 2], '1053'),
    ([5, 5, 2], '5624'),
    ([6, 5, 2], '28576'),
    ([6, 6, 2], '251610'),
    ([12, 12, 20], '97195340925396730736950973830781340249131679073592360856141700148734207997877978005419735822878768821088343977969209139721682171487959967012286474628978470487193051591840'),
]

if __name__ == '__main__':
    # Run the self-check only when executed as a script. Importing this
    # module (e.g. to call `solution`) should not execute every case as a
    # side effect.
    for i, o in tests:
        result = solution(*i)
        print(i, result == o, result, o)