Use DP to enhance the performance

Verifying solution...
Test 1 passed!
Test 2 passed!
Test 3 passed! [Hidden]
Test 4 passed! [Hidden]
Test 5 failed  [Hidden]
Test 6 failed  [Hidden]
Test 7 passed! [Hidden]
Test 8 failed  [Hidden]
Test 9 passed! [Hidden]
Test 10 failed  [Hidden]
This commit is contained in:
Seongbeom Park 2022-03-21 15:56:31 +09:00
parent e8f69a8743
commit 7b386ff851

View File

@ -2,23 +2,106 @@ import math
import numpy as np
def solution(w, h, s):
return str(sum([s**cycle for cycle in generate_cycles(w, h)])/(math.factorial(w)*math.factorial(h)))
memo_f = {w:math.factorial(w), h:math.factorial(h)}
memo_sp = {}
cases = 0
for c in generate_cycle_count(w, h, memo_f):
if c not in memo_sp:
memo_sp[c] = s ** c
cases += memo_sp[c]
return str(cases / (memo_f[w] * memo_f[h]))
def generate_cycles(w, h):
def generate_cycle_count(w, h, memo_f = {}):
memo_c = {}
memo_p = {}
memo_gcd = {}
for pw in generate_cycles(w, memo = memo_c):
for ph in generate_cycles(h, memo = memo_c):
cycle_count = 0
for i in pw.split():
for j in ph.split():
cycle_count += gcd(i, j, memo_gcd)
cases = count_permutation(pw, memo_f, memo_p) * count_permutation(ph, memo_f, memo_p)
#print pw, ph, cases, cycle_count
for i in range(cases):
yield cycle_count
# generate cycle list in a string format
# e.g., a string with length 5 can be composed with '2 2 1' which means there are 3 cycles, and foreach size is 2, 2 and 1.
def generate_cycles(remain, start = 1, memo = {}):
if remain < start:
return
if (remain, start) in memo:
for j in memo[(remain, start)]:
yield j
else:
memo[(remain, start)] = []
for i in range(start, remain+1):
if i == remain:
memo[(remain, start)] += [str(i)]
for j in generate_cycles(remain - i, i, memo):
memo[(remain, start)] += [str(i) + ' ' + j]
for j in memo[(remain, start)]:
yield j
def gcd(x, y, memo = {}):
if (x, y) in memo:
return memo[(x, y)]
if (y, x) in memo:
return memo[(y, x)]
m, n = int(x), int(y)
while n != 0:
t = m % n
m, n = n, t
memo[(x, y)] = m
return m
# count number of possible permutation
# e.g., '1 1' => 12
# e.g., '2 2 1' => 12233, 21233, 22133, 22313, 22331, 12323, 21323, 23123, 23213, 23231, 12332, 21332, 23132, 23312, 23321
# e.g., '3' => 111(*2)
def count_permutation(s, memo_f, memo_p):
if s not in memo_p:
group = {}
length = 0
for c in s.split():
size = int(c)
if size in group:
group[size] += 1
else:
group[size] = 1
length += size
result = memo_f[length]
for size, freq in group.items():
if freq not in memo_f:
memo_f[freq] = math.factorial(freq)
result /= memo_f[freq] * (size ** freq)
memo_p[s] = result
return memo_p[s]
def _solution(w, h, s):
return str(sum([s**cycle for cycle in _generate_cycles(w, h)])/(math.factorial(w)*math.factorial(h)))
def _generate_cycles(w, h):
identity = np.arange(w*h).reshape((w, h))
for i in generate_permutation(list(range(w))):
for j in generate_permutation(list(range(h))):
yield count_cycle(identity[i][:, j])
for i in _generate_permutation(list(range(w))):
for j in _generate_permutation(list(range(h))):
#print i, j, _count_cycle(identity[i][:, j])
yield _count_cycle(identity[i][:, j])
def generate_permutation(numbers):
def _generate_permutation(numbers):
if len(numbers) == 1:
yield numbers
else:
for i, n in enumerate(numbers):
for j in generate_permutation(numbers[:i] + numbers[i+1:]):
for j in _generate_permutation(numbers[:i] + numbers[i+1:]):
yield [n] + j
def count_cycle(arr):
def _count_cycle(arr):
cycle = np.zeros(arr.shape, dtype=int)
w, h = arr.shape
cycle_id = 0
@ -33,19 +116,32 @@ def count_cycle(arr):
return cycle_id
tests = [
([2, 3, 4], '430'),
([2, 2, 2], '7'),
([1, 1, 2], '2'),
([1, 1, 3], '3'),
([2, 1, 2], '3'),
([1, 2, 2], '3'),
([3, 1, 2], '4'),
([1, 3, 2], '4'),
([2, 3, 2], '13'),
#([2, 3, 4], '430'),
#([2, 2, 2], '7'),
#([1, 1, 2], '2'),
#([1, 1, 3], '3'),
#([2, 1, 2], '3'),
#([1, 2, 2], '3'),
#([3, 1, 2], '4'),
#([1, 3, 2], '4'),
#([4, 1, 2], '5'),
#([1, 4, 2], '5'),
#([5, 1, 2], '6'),
#([1, 5, 2], '6'),
#([6, 1, 2], '7'),
#([1, 6, 2], '7'),
([12, 1, 2], '13'),
#([1, 12, 2], '13'),
#([2, 3, 2], '13'),
#([3, 3, 2], '36'),
#([4, 4, 2], '317'),
#([5, 4, 2], '1053'),
#([5, 5, 2], '5624'),
#([6, 5, 2], '28576'),
#([6, 6, 2], '251610'),
#([12, 12, 20], '251610'),
]
for i, o in tests:
result = solution(*i)
print (i, result == o, result, o)