# 112 lines | 3.5 KiB | Python
import math
|
|
import numpy as np
|
|
|
|
def solution(w, h, s):
    """Return, as a decimal string, the weighted cycle-index average.

    For every pair of cycle types produced by ``generate_cycle_count`` the
    term ``weight * s ** n_cycles`` is accumulated, and the total is divided
    by ``w! * h!``.  (This has the shape of a Burnside's-lemma count of
    w-by-h grids over ``s`` symbols up to row/column permutation.)

    Args:
        w: first dimension (non-negative int).
        h: second dimension (non-negative int).
        s: number of symbols raised to the cycle count.

    Returns:
        str: the exact integer quotient, formatted without a decimal point.
    """
    # Factorial cache shared with the helpers; pre-seeded with w! and h!.
    memo_f = {w: math.factorial(w), h: math.factorial(h)}
    # Cache of s ** n_cycles, since many cycle-type pairs share a cycle count.
    memo_sp = {}
    cases = 0
    for weight, n_cycles in generate_cycle_count(w, h, memo_f):
        if n_cycles not in memo_sp:
            memo_sp[n_cycles] = s ** n_cycles
        # Weights are whole numbers; coerce so the sum stays an exact int
        # even if a float-valued weight is yielded.
        cases += int(weight) * memo_sp[n_cycles]
    # Floor division keeps the result an exact integer on Python 3, where
    # "/" would yield a float ("430.0") and lose precision on large counts.
    return str(cases // (memo_f[w] * memo_f[h]))
|
|
|
|
def generate_cycle_count(w, h, memo_f=None):
    """Yield ``(variations, cycle_count)`` for every pair of cycle types.

    Each cycle type of ``w`` (from ``generate_cycles(w)``) is paired with
    each cycle type of ``h``.

    Args:
        w: size whose cycle types form the outer loop.
        h: size whose cycle types form the inner loop.
        memo_f: optional dict mapping n -> n!; it is updated in place.  When
            omitted a fresh dict is used for this call.

    Yields:
        tuple: ``variations`` is the product of ``count_permutation`` for
        the two cycle types; ``cycle_count`` is the sum of ``gcd(i, j)``
        over every cycle length ``i`` of the first type and ``j`` of the
        second.
    """
    # A fresh dict per call avoids sharing one default dict between calls.
    if memo_f is None:
        memo_f = {}
    # count_permutation looks up memo_f[w] / memo_f[h]; make sure they exist.
    # Sizes below 1 produce no cycle types, so nothing is needed for them.
    for n in (w, h):
        if n >= 1 and n not in memo_f:
            memo_f[n] = math.factorial(n)

    memo_c = {}
    memo_p = {}
    memo_gcd = {}
    for pw in generate_cycles(w, memo=memo_c):
        # Invariant over the inner loop: split and count pw only once.
        pw_parts = pw.split()
        pw_variations = count_permutation(pw, memo_f, memo_p)
        for ph in generate_cycles(h, memo=memo_c):
            ph_parts = ph.split()
            cycle_count = sum(
                gcd(i, j, memo_gcd) for i in pw_parts for j in ph_parts
            )
            variations = pw_variations * count_permutation(ph, memo_f, memo_p)
            yield variations, cycle_count
|
|
|
|
def generate_cycles(remain, start=1, memo={}):
    """Yield every way to split ``remain`` into cycle sizes, as strings.

    Each result is a space-separated list of sizes in non-decreasing order,
    every size being at least ``start``.  For example 5 can be split as
    '1 2 2': three cycles of sizes 1, 2 and 2.

    Args:
        remain: total size still to be split.
        start: smallest size allowed for the next cycle.
        memo: cache mapping ``(remain, start)`` to the list of results.  The
            default dict is created once and therefore shared by all calls
            that do not pass their own.

    Yields:
        str: one split per iteration; nothing when ``remain < start``.
    """
    if remain < start:
        return

    key = (remain, start)
    if key not in memo:
        # Build the complete list for this key before yielding anything.
        splits = memo[key] = []
        for first in range(start, remain + 1):
            if first == remain:
                splits.append(str(first))
            head = str(first) + ' '
            # Later sizes may not be smaller than the one just chosen.
            splits.extend(
                head + tail
                for tail in generate_cycles(remain - first, first, memo)
            )

    for split in memo[key]:
        yield split
|
|
|
|
def gcd(x, y, memo={}):
    """Return the greatest common divisor of ``int(x)`` and ``int(y)``.

    Args:
        x: first value; anything ``int()`` accepts (e.g. '12' or 12).
        y: second value, same convention as ``x``.
        memo: cache keyed by the argument pair exactly as passed.  A result
            stored under ``(x, y)`` is also returned for ``(y, x)``.  The
            default dict is shared by all calls that do not pass their own.

    Returns:
        int: the gcd computed with Euclid's algorithm.
    """
    # The gcd is symmetric, so either argument order may already be cached.
    for key in ((x, y), (y, x)):
        if key in memo:
            return memo[key]

    a, b = int(x), int(y)
    while b != 0:
        a, b = b, a % b

    memo[(x, y)] = a
    return a
|
|
|
|
def count_permutation(s, memo_f, memo_p):
    """Count the permutations that have the cycle type described by ``s``.

    With ``n`` the sum of the cycle sizes and each distinct size occurring
    ``freq`` times, the count is ``n! / prod(freq! * size ** freq)``.

    Examples:
        '1 1'   -> 1   (only the identity on two elements)
        '3'     -> 2   (the two 3-cycles on three elements)
        '1 2 2' -> 15

    Args:
        s: space-separated cycle sizes, e.g. '1 2 2'.
        memo_f: dict mapping k -> k!; missing entries are computed and
            stored in place.
        memo_p: dict caching the result for each string ``s``; updated in
            place.

    Returns:
        int: the exact number of permutations with that cycle type.
    """
    if s not in memo_p:
        # Tally how many cycles there are of each size.
        group = {}
        length = 0
        for c in s.split():
            size = int(c)
            group[size] = group.get(size, 0) + 1
            length += size

        if length not in memo_f:
            memo_f[length] = math.factorial(length)

        denominator = 1
        for size, freq in group.items():
            if freq not in memo_f:
                memo_f[freq] = math.factorial(freq)
            denominator *= memo_f[freq] * size ** freq

        # The quotient is always a whole number; floor division keeps it an
        # exact int on Python 3 instead of drifting into floats.
        memo_p[s] = memo_f[length] // denominator
    return memo_p[s]
|
|
|
|
# Regression table: each entry pairs the [w, h, s] arguments for solution()
# with the expected result as a decimal string.
tests = [
    ([2, 3, 4], '430'),
    ([2, 2, 2], '7'),
    ([1, 1, 2], '2'),
    ([1, 1, 3], '3'),
    ([2, 1, 2], '3'),
    ([1, 2, 2], '3'),
    ([3, 1, 2], '4'),
    ([1, 3, 2], '4'),
    ([4, 1, 2], '5'),
    ([1, 4, 2], '5'),
    ([5, 1, 2], '6'),
    ([1, 5, 2], '6'),
    ([6, 1, 2], '7'),
    ([1, 6, 2], '7'),
    ([12, 1, 2], '13'),
    ([1, 12, 2], '13'),
    ([2, 3, 2], '13'),
    ([3, 3, 2], '36'),
    ([4, 4, 2], '317'),
    ([5, 4, 2], '1053'),
    ([5, 5, 2], '5624'),
    ([6, 5, 2], '28576'),
    ([6, 6, 2], '251610'),
    ([12, 12, 20], '97195340925396730736950973830781340249131679073592360856141700148734207997877978005419735822878768821088343977969209139721682171487959967012286474628978470487193051591840'),
]

# Runs at module load: for each case print the arguments, whether the result
# matched, the actual result and the expected result.
for args, expected in tests:
    actual = solution(*args)
    print(args, actual == expected, actual, expected)
|