"""Exact sum of the Beatty sequence floor(i * sqrt(2)) for i = 1..n.

The sum is computed with the complementary-Beatty-sequence identity using
only integer arithmetic. Inputs with hundreds of digits (e.g. n = 10**100)
are therefore handled exactly.
"""

import math
from decimal import *  # NOTE(review): no longer used here; kept in case other code relies on names it exports


def solution(s):
    """Return sum_{i=1}^{n} floor(i * sqrt(2)) as a decimal string.

    Args:
        s: The non-negative integer n written in base 10 (e.g. "77").

    Returns:
        The exact sum as a base-10 integer string (e.g. "4208").

    Raises:
        ValueError: If ``s`` is not an integer literal or is negative.
    """
    return str(beatty_sum_sqrt2(int(s)))


def beatty_sum_sqrt2(n):
    """Return S(n) = sum_{i=1}^{n} floor(i * sqrt(2)) as an exact int.

    Let m = floor(n * sqrt(2)). The Beatty sequences floor(i * sqrt(2)) and
    floor(j * (2 + sqrt(2))) partition the positive integers. Exactly m - n
    terms of the second sequence are <= m. Because
    floor(j * (2 + sqrt(2))) = 2*j + floor(j * sqrt(2)), this gives

        S(n) = T(m) - 2 * T(m - n) - S(m - n),    T(k) = k * (k + 1) / 2.

    The identity is unrolled into a loop with an alternating sign instead of
    recursing. Each step shrinks the argument to about 0.414 * n, so the step
    count grows with the digit count of n. As recursion, inputs with a few
    hundred digits would exceed Python's default recursion limit.

    Args:
        n: Non-negative upper bound of the sum.

    Returns:
        The exact sum; 0 when n == 0.

    Raises:
        ValueError: If n is negative.
    """
    if n < 0:
        raise ValueError("n must be non-negative, got %r" % (n,))
    total = 0
    sign = 1
    while n > 0:
        # isqrt(2 * n * n) == floor(n * sqrt(2)) exactly. A limited-precision
        # Decimal sqrt is correctly *rounded*, so it could round up past the
        # true floor.
        m = math.isqrt(2 * n * n)
        k = m - n  # 0 <= k < n because 1 <= sqrt(2) < 2, so the loop ends
        total += sign * (natural_number_sum(m) - 2 * natural_number_sum(k))
        sign = -sign
        n = k
    return total


def natural_number_sum(n):
    """Return the triangular number T(n) = 1 + 2 + ... + n as an exact int.

    Floor division is required. n * (n + 1) is always even, but true
    division ("/") returns a float. That loses precision for large n and
    turns results into strings like "4208.0".
    """
    return n * (n + 1) // 2


def ten(n):
    """Return the decimal string of 10**n (used to build large test inputs)."""
    return str(10 ** n)


# (input, expected output) pairs; run this file directly to check them.
tests = [
    ['77', '4208'],
    ['1', '1'],
    ['2', '3'],
    ['3', '7'],
    ['4', '12'],
    ['5', '19'],
    ['6', '27'],
    ['7', '36'],
    ['8', '47'],
    ['9', '59'],
    ['10', '73'],
    [ten(100), '70710678118654752440084436210484903928483593768847403658833986899536623923105351942519376716382078638821760123411090095254685423841027253480565451739737157454059823250037671948325191776995310741236436'],
]

if __name__ == "__main__":
    for test in tests:
        result = solution(test[0])
        print(test[0], result == test[1], result, test[1])