I am struggling to wrap my head around this compiler technique. Say this is my factorial function:
    def factorial(value: int) -> int:
        if value == 0:
            return 1
        else:
            return factorial(value - 1) * value

It is recursive but not TCO-friendly yet, so, as the theory goes, the first thing to try is translating it to CPS:
    import typing

    T = typing.TypeVar("T")

    def factorial_cont(value: int, cont: typing.Callable[[int], T]) -> T:
        if value == 0:
            return cont(1)
        else:
            return factorial_cont(value - 1, lambda result: cont(value * result))

Now that the function is tail-recursive, I can apply the usual trick with a while loop:
    def factorial_while(value: int, cont: typing.Callable[[int], T]) -> T:
        current_cont = cont
        current_value = value
        while True:
            if current_value == 0:
                return current_cont(1)
            else:
                # Conceptually: lambda result: current_cont(current_value * result).
                # The default arguments freeze the current values; a plain closure
                # would see the rebound variables because Python binds names late.
                current_cont = lambda result, c=current_cont, v=current_value: c(v * result)
                current_value = current_value - 1

This current_cont thing effectively becomes a huge composition chain. In Haskell terms, for value == 3 that would be resulting_cont = ((initial_cont . (3*)) . (2*)) . (1*), where initial_cont can safely default to id, and sure enough resulting_cont 1 evaluates to 3! = 6.
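To make the composition chain concrete, here is a small sketch that builds the same left-nested chain in Python and applies it to 1. The helper names `compose`, `times`, `identity`, and `build_chain` are made up for illustration; `times(k)` stands in for the Haskell section `(k*)`.

```python
from functools import reduce
from typing import Callable

IntFn = Callable[[int], int]


def compose(f: IntFn, g: IntFn) -> IntFn:
    """Return f . g, i.e. the function x -> f(g(x))."""
    return lambda x: f(g(x))


def times(k: int) -> IntFn:
    """Hypothetical stand-in for the Haskell section (k*)."""
    return lambda x: k * x


def identity(x: int) -> int:
    """Plays the role of Haskell's id as the default initial continuation."""
    return x


def build_chain(value: int, initial_cont: IntFn = identity) -> IntFn:
    """Build ((initial_cont . (value*)) . ...) . (1*), nested to the left."""
    factors = (times(k) for k in range(value, 0, -1))
    return reduce(compose, factors, initial_cont)


resulting_cont = build_chain(3)
print(resulting_cont(1))  # 3 * (2 * (1 * 1)) = 6
```

The chain only produces the factorial when it is applied to 1, which is exactly what the base case `current_cont(1)` does.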
But I also know the trick with an "accumulator" value:

    def factorial_acc(value: int, acc: int = 1) -> int:
        current_acc = acc
        current_value = value
        while True:
            # Same base case as the versions above, so that value == 0 terminates.
            if current_value == 0:
                return current_acc
            else:
                current_acc = current_acc * current_value
                current_value = current_value - 1

which looks pretty much identical to the CPS version after the introduction of the while loop.
The question is: how exactly do I massage the continuation resulting_cont = ((initial_cont . (3*)) . (2*)) . (1*) into a form resembling the accumulator version?
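One possible direction, offered as a sketch rather than a definitive answer: every continuation the loop builds has the shape `lambda result: cont(k * result)` for some integer `k`, so the closure can be represented by `k` alone (this is usually called defunctionalization), and associativity of multiplication lets each new layer be folded into `k` immediately. The names `apply_cont` and `factorial_defun` are hypothetical, chosen only for this illustration.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def apply_cont(acc: int, initial_cont: Callable[[int], T], result: int) -> T:
    """Run the continuation represented by acc: result -> initial_cont(acc * result)."""
    return initial_cont(acc * result)


def factorial_defun(value: int, cont: Callable[[int], T]) -> T:
    # Invariant: the closure version's current_cont equals
    #     lambda result: cont(acc * result)
    acc = 1  # represents cont itself, since cont(1 * result) == cont(result)
    current_value = value
    while current_value != 0:
        # Closure version: lambda result: current_cont(current_value * result)
        #   == lambda result: cont(acc * (current_value * result))
        #   == lambda result: cont((acc * current_value) * result)  # associativity
        # so the new continuation is represented by acc * current_value.
        acc = acc * current_value
        current_value -= 1
    # Base case of the closure version: current_cont(1)
    return apply_cont(acc, cont, 1)


print(factorial_defun(5, lambda x: x))  # 120
```

Under these assumptions the integer `acc` is the whole continuation in compressed form, which is why the loop ends up looking like the accumulator version. The step relies on multiplication being associative; for an operation without that property the continuation could not be collapsed into a single value this way.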