
I was explaining to a friend that I expected non-tail-recursive functions in Scala to be slower than tail-recursive ones, so I decided to verify it. I wrote a good old factorial function both ways and attempted to compare the results. Here's the code:

```scala
import scala.annotation.tailrec

// Enclosing object assumed; the question shows only its members.
object FactorialBenchmark {
  def main(args: Array[String]): Unit = {
    val N = 2000 // not too large, or the non-tail version overflows the stack
    var spent1: Long = 0
    var spent2: Long = 0
    for (i <- 1 to 100) { // repeat to average the results
      val t0 = System.nanoTime
      factorial(N)
      val t1 = System.nanoTime
      tailRecFact(N)
      val t2 = System.nanoTime
      spent1 += t1 - t0
      spent2 += t2 - t1
    }
    println(spent1 / 1000000f) // get milliseconds
    println(spent2 / 1000000f)
  }

  @tailrec
  def tailRecFact(n: BigInt, s: BigInt = 1): BigInt =
    if (n == 1) s else tailRecFact(n - 1, s * n)

  def factorial(n: BigInt): BigInt =
    if (n == 1) 1 else n * factorial(n - 1)
}
```

The results confuse me; I get this kind of output:

578.2985

870.22125

That means the non-tail-recursive function is about 30% faster than the tail-recursive one, even though the number of operations is the same!

What would explain those results?

2 Answers


It's actually not where you would first look. The reason is that in your tail-recursive method you are doing more work in the multiplication. Try swapping the order of n and s in the multiplication in the recursive call (n * s instead of s * n) and the timings will even out.

```scala
def tailRecFact(n: BigInt, s: BigInt): BigInt =
  if (n == 1) s else tailRecFact(n - 1, n * s)
```
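To look at the operand-order effect in isolation, here is a minimal timing sketch. The names `OperandOrder`, `accOnLeft`, `accOnRight` and `timeMs` are hypothetical; the two methods differ only in whether the accumulator sits on the left or the right of `*`. Single-shot `System.nanoTime` measurements like this are noisy, so treat the printed numbers as rough indications only.

```scala
import scala.annotation.tailrec

// Hypothetical names, for illustration only.
object OperandOrder {
  // Accumulator on the left: s * n (big * small), as in the question.
  @tailrec
  def accOnLeft(n: BigInt, s: BigInt = 1): BigInt =
    if (n == 1) s else accOnLeft(n - 1, s * n)

  // Accumulator on the right: n * s (small * big), as in the answer.
  @tailrec
  def accOnRight(n: BigInt, s: BigInt = 1): BigInt =
    if (n == 1) s else accOnRight(n - 1, n * s)

  // Evaluates `body` `reps` times and returns the elapsed wall-clock time in ms.
  def timeMs(reps: Int)(body: => Any): Double = {
    val t0 = System.nanoTime
    var i = 0
    while (i < reps) { body; i += 1 }
    (System.nanoTime - t0) / 1e6
  }

  def main(args: Array[String]): Unit = {
    val N = 2000
    println(f"s * n: ${timeMs(100)(accOnLeft(N))}%.1f ms")
    println(f"n * s: ${timeMs(100)(accOnRight(N))}%.1f ms")
  }
}
```

Both methods compute the same value; only the receiver of `BigInt.*` changes.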

Moreover, most of the time in this sample is spent in the BigInt operations, which dwarf the cost of the recursive call itself. If we switch these over to Ints (compiled to Java primitives), you can see how tail recursion (compiled to a goto) compares to method invocation. The Int results overflow for N = 2000, but only the timings matter here.

```scala
import scala.annotation.tailrec

object Test extends App {
  val N = 2000
  val t0 = System.nanoTime()
  for (i <- 1 to 1000) { factorial(N) }
  val t1 = System.nanoTime
  for (i <- 1 to 1000) { tailRecFact(N, 1) }
  val t2 = System.nanoTime
  println((t1 - t0) / 1000000f) // get milliseconds
  println((t2 - t1) / 1000000f)

  def factorial(n: Int): Int =
    if (n == 1) 1 else n * factorial(n - 1)

  @tailrec
  final def tailRecFact(n: Int, s: Int): Int =
    if (n == 1) s else tailRecFact(n - 1, s * n)
}
```

Output:

```
95.16733
3.987605
```

For interest, here is the disassembled bytecode of the BigInt versions:

```
public final scala.math.BigInt tailRecFact(scala.math.BigInt, scala.math.BigInt);
  Code:
     0: aload_1
     1: iconst_1
     2: invokestatic  #16 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
     5: invokestatic  #20 // Method scala/runtime/BoxesRunTime.equalsNumObject:(Ljava/lang/Number;Ljava/lang/Object;)Z
     8: ifeq          13
    11: aload_2
    12: areturn
    13: aload_1
    14: getstatic     #26 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
    17: iconst_1
    18: invokevirtual #30 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
    21: invokevirtual #36 // Method scala/math/BigInt.$minus:(Lscala/math/BigInt;)Lscala/math/BigInt;
    24: aload_1
    25: aload_2
    26: invokevirtual #39 // Method scala/math/BigInt.$times:(Lscala/math/BigInt;)Lscala/math/BigInt;
    29: astore_2
    30: astore_1
    31: goto          0

public scala.math.BigInt factorial(scala.math.BigInt);
  Code:
     0: aload_1
     1: iconst_1
     2: invokestatic  #16 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
     5: invokestatic  #20 // Method scala/runtime/BoxesRunTime.equalsNumObject:(Ljava/lang/Number;Ljava/lang/Object;)Z
     8: ifeq          21
    11: getstatic     #26 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
    14: iconst_1
    15: invokevirtual #30 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
    18: goto          40
    21: aload_1
    22: aload_0
    23: aload_1
    24: getstatic     #26 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
    27: iconst_1
    28: invokevirtual #30 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
    31: invokevirtual #36 // Method scala/math/BigInt.$minus:(Lscala/math/BigInt;)Lscala/math/BigInt;
    34: invokevirtual #47 // Method factorial:(Lscala/math/BigInt;)Lscala/math/BigInt;
    37: invokevirtual #39 // Method scala/math/BigInt.$times:(Lscala/math/BigInt;)Lscala/math/BigInt;
    40: areturn
```
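The `goto 0` at offset 31 is where the tail call went: the two new arguments are computed, stored back into the parameter slots (`astore_2`, `astore_1`), and control jumps to the top. As a hedged illustration, not actual compiler output, the BigInt `tailRecFact` above behaves roughly like this hand-written loop; the names `LoopEquivalent` and `loopFact` are hypothetical.

```scala
// Hypothetical names; a source-level sketch of the loop that @tailrec produces.
object LoopEquivalent {
  def loopFact(n0: BigInt, s0: BigInt = 1): BigInt = {
    var n = n0
    var s = s0
    // Each iteration stands in for one "recursive call".
    while (n != 1) {
      // Both new arguments are computed before either parameter is
      // overwritten, mirroring the stack order in the bytecode.
      val nextN = n - 1
      val nextS = n * s
      n = nextN
      s = nextS
    }
    s
  }
}
```

The non-tail `factorial`, by contrast, keeps a real `invokevirtual #47` call to itself, so every step pushes a new stack frame.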

Comments

Please elaborate. I don't see the "twice as many" invocations?
Why would it require 2 implicit conversions? 's' and 'n' are already BigInt. In "factorial", you have n * factorial(n - 1), so you have the multiplication and the subtraction of 1. That is 2 BigInt operations per recursive step. In tailRecFact, you have n - 1 and n * s, which is again 2 BigInt operations per recursive step.
FYI, I deleted my comment because I did not explain what I meant properly. I will update my answer.
Updated; it actually was not what I expected at all.
@monkjack, do you know why swapping n and s affects the time so much? In the bytecode there is not much difference: only two aload instructions are swapped.

In addition to the problem shown by @monkjack (i.e. multiplying small * big is faster than big * small, which accounts for the larger share of the difference), your algorithm is different in each case, so the two versions are not really comparable.

In the tail-recursive version you're multiplying big-to-small:

n * (n - 1) * (n - 2) * ... * 2 * 1 (evaluated left to right)

In the non-tail-recursive version you're multiplying small-to-big:

n * ((n - 1) * ((n - 2) * (... * (2 * 1))))
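One way to see why the order matters is to track the size of the running product after each multiplication. This sketch, with hypothetical names, records the intermediate products for both orders so their `bitLength`s can be compared; summing those bit lengths is only a rough proxy for the amount of BigInt work, not a precise cost model.

```scala
// Hypothetical names, for illustration only.
object ProductGrowth {
  // Running products for n, n-1, ..., 1 (big-to-small, like the tail version).
  def bigToSmall(n: Int): Seq[BigInt] =
    (n to 1 by -1).scanLeft(BigInt(1))((acc, k) => acc * BigInt(k)).tail

  // Running products for 1, 2, ..., n (small-to-big, which is what the
  // non-tail version computes as its calls return).
  def smallToBig(n: Int): Seq[BigInt] =
    (1 to n).scanLeft(BigInt(1))((acc, k) => acc * BigInt(k)).tail

  def main(args: Array[String]): Unit = {
    val n = 2000
    val big = bigToSmall(n).map(_.bitLength).sum
    val small = smallToBig(n).map(_.bitLength).sum
    // Total size, in bits, of all intermediate products in each order.
    println(s"big-to-small: $big bits")
    println(s"small-to-big: $small bits")
  }
}
```

Both sequences end at the same value, n!, but the big-to-small one reaches large intermediate values much earlier.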

If you alter the tail-recursive version so it multiplies small-to-big:

```scala
def tailRecFact2(n: BigInt) = {
  def loop(x: BigInt, out: BigInt): BigInt =
    if (x > n) out else loop(x + 1, x * out)
  loop(1, 1)
}
```

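then tail recursion is about 20% faster than normal recursion, rather than 10% slower as it is if you only make monkjack's correction. This is because multiplying small BigInts together is faster than multiplying large ones: after k steps, the big-to-small accumulator holds n * (n - 1) * ... * (n - k + 1), while the small-to-big one holds only k!, which is far smaller for the same k.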

Comments

This is a bit weird to me. Is it because you can keep the smaller BigInts in cache memory?
This answer helps explain BigInteger's odd/poor performance: stackoverflow.com/a/17590529/770361
