Use DP to enhance the performance
Verifying solution... Test 1 passed! Test 2 passed! Test 3 passed! [Hidden] Test 4 passed! [Hidden] Test 5 failed [Hidden] Test 6 failed [Hidden] Test 7 passed! [Hidden] Test 8 failed [Hidden] Test 9 passed! [Hidden] Test 10 failed [Hidden]
This commit is contained in:
parent
e8f69a8743
commit
7b386ff851
@ -2,23 +2,106 @@ import math
|
||||
import numpy as np
|
||||
|
||||
def solution(w, h, s):
|
||||
return str(sum([s**cycle for cycle in generate_cycles(w, h)])/(math.factorial(w)*math.factorial(h)))
|
||||
memo_f = {w:math.factorial(w), h:math.factorial(h)}
|
||||
memo_sp = {}
|
||||
cases = 0
|
||||
for c in generate_cycle_count(w, h, memo_f):
|
||||
if c not in memo_sp:
|
||||
memo_sp[c] = s ** c
|
||||
cases += memo_sp[c]
|
||||
return str(cases / (memo_f[w] * memo_f[h]))
|
||||
|
||||
def generate_cycles(w, h):
|
||||
def generate_cycle_count(w, h, memo_f = {}):
|
||||
memo_c = {}
|
||||
memo_p = {}
|
||||
memo_gcd = {}
|
||||
for pw in generate_cycles(w, memo = memo_c):
|
||||
for ph in generate_cycles(h, memo = memo_c):
|
||||
cycle_count = 0
|
||||
for i in pw.split():
|
||||
for j in ph.split():
|
||||
cycle_count += gcd(i, j, memo_gcd)
|
||||
cases = count_permutation(pw, memo_f, memo_p) * count_permutation(ph, memo_f, memo_p)
|
||||
#print pw, ph, cases, cycle_count
|
||||
for i in range(cases):
|
||||
yield cycle_count
|
||||
|
||||
# generate cycle list in a string format
|
||||
# e.g., a string with length 5 can be composed with '2 2 1' which means there are 3 cycles, and foreach size is 2, 2 and 1.
|
||||
def generate_cycles(remain, start = 1, memo = {}):
|
||||
if remain < start:
|
||||
return
|
||||
if (remain, start) in memo:
|
||||
for j in memo[(remain, start)]:
|
||||
yield j
|
||||
else:
|
||||
memo[(remain, start)] = []
|
||||
for i in range(start, remain+1):
|
||||
if i == remain:
|
||||
memo[(remain, start)] += [str(i)]
|
||||
for j in generate_cycles(remain - i, i, memo):
|
||||
memo[(remain, start)] += [str(i) + ' ' + j]
|
||||
for j in memo[(remain, start)]:
|
||||
yield j
|
||||
|
||||
def gcd(x, y, memo = {}):
|
||||
if (x, y) in memo:
|
||||
return memo[(x, y)]
|
||||
if (y, x) in memo:
|
||||
return memo[(y, x)]
|
||||
m, n = int(x), int(y)
|
||||
while n != 0:
|
||||
t = m % n
|
||||
m, n = n, t
|
||||
memo[(x, y)] = m
|
||||
return m
|
||||
|
||||
# count number of possible permutation
|
||||
# e.g., '1 1' => 12
|
||||
# e.g., '2 2 1' => 12233, 21233, 22133, 22313, 22331, 12323, 21323, 23123, 23213, 23231, 12332, 21332, 23132, 23312, 23321
|
||||
# e.g., '3' => 111(*2)
|
||||
def count_permutation(s, memo_f, memo_p):
|
||||
if s not in memo_p:
|
||||
group = {}
|
||||
length = 0
|
||||
for c in s.split():
|
||||
size = int(c)
|
||||
if size in group:
|
||||
group[size] += 1
|
||||
else:
|
||||
group[size] = 1
|
||||
length += size
|
||||
result = memo_f[length]
|
||||
for size, freq in group.items():
|
||||
if freq not in memo_f:
|
||||
memo_f[freq] = math.factorial(freq)
|
||||
result /= memo_f[freq] * (size ** freq)
|
||||
memo_p[s] = result
|
||||
return memo_p[s]
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def _solution(w, h, s):
|
||||
return str(sum([s**cycle for cycle in _generate_cycles(w, h)])/(math.factorial(w)*math.factorial(h)))
|
||||
|
||||
def _generate_cycles(w, h):
|
||||
identity = np.arange(w*h).reshape((w, h))
|
||||
for i in generate_permutation(list(range(w))):
|
||||
for j in generate_permutation(list(range(h))):
|
||||
yield count_cycle(identity[i][:, j])
|
||||
for i in _generate_permutation(list(range(w))):
|
||||
for j in _generate_permutation(list(range(h))):
|
||||
#print i, j, _count_cycle(identity[i][:, j])
|
||||
yield _count_cycle(identity[i][:, j])
|
||||
|
||||
def generate_permutation(numbers):
|
||||
def _generate_permutation(numbers):
|
||||
if len(numbers) == 1:
|
||||
yield numbers
|
||||
else:
|
||||
for i, n in enumerate(numbers):
|
||||
for j in generate_permutation(numbers[:i] + numbers[i+1:]):
|
||||
for j in _generate_permutation(numbers[:i] + numbers[i+1:]):
|
||||
yield [n] + j
|
||||
|
||||
def count_cycle(arr):
|
||||
def _count_cycle(arr):
|
||||
cycle = np.zeros(arr.shape, dtype=int)
|
||||
w, h = arr.shape
|
||||
cycle_id = 0
|
||||
@ -33,19 +116,32 @@ def count_cycle(arr):
|
||||
return cycle_id
|
||||
|
||||
tests = [
|
||||
([2, 3, 4], '430'),
|
||||
([2, 2, 2], '7'),
|
||||
([1, 1, 2], '2'),
|
||||
([1, 1, 3], '3'),
|
||||
([2, 1, 2], '3'),
|
||||
([1, 2, 2], '3'),
|
||||
([3, 1, 2], '4'),
|
||||
([1, 3, 2], '4'),
|
||||
([2, 3, 2], '13'),
|
||||
#([2, 3, 4], '430'),
|
||||
#([2, 2, 2], '7'),
|
||||
#([1, 1, 2], '2'),
|
||||
#([1, 1, 3], '3'),
|
||||
#([2, 1, 2], '3'),
|
||||
#([1, 2, 2], '3'),
|
||||
#([3, 1, 2], '4'),
|
||||
#([1, 3, 2], '4'),
|
||||
#([4, 1, 2], '5'),
|
||||
#([1, 4, 2], '5'),
|
||||
#([5, 1, 2], '6'),
|
||||
#([1, 5, 2], '6'),
|
||||
#([6, 1, 2], '7'),
|
||||
#([1, 6, 2], '7'),
|
||||
([12, 1, 2], '13'),
|
||||
#([1, 12, 2], '13'),
|
||||
#([2, 3, 2], '13'),
|
||||
#([3, 3, 2], '36'),
|
||||
#([4, 4, 2], '317'),
|
||||
#([5, 4, 2], '1053'),
|
||||
#([5, 5, 2], '5624'),
|
||||
#([6, 5, 2], '28576'),
|
||||
#([6, 6, 2], '251610'),
|
||||
#([12, 12, 20], '251610'),
|
||||
]
|
||||
|
||||
for i, o in tests:
|
||||
result = solution(*i)
|
||||
print (i, result == o, result, o)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user