Most likely you have already perused the source code:
    class OrderedRDDFunctions {
      // <snip>
      def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = {
        val part = new RangePartitioner(numPartitions, self, ascending)
        val shuffled = new ShuffledRDD[K, V, P](self, part)
        shuffled.mapPartitions(iter => {
          val buf = iter.toArray
          if (ascending) {
            buf.sortWith((x, y) => x._1 < y._1).iterator
          } else {
            buf.sortWith((x, y) => x._1 > y._1).iterator
          }
        }, preservesPartitioning = true)
      }
    }
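Stripped of Spark, the snippet does two things: it routes every record to a partition by key range (the shuffle), then sorts each partition locally, so reading the partitions in order gives a globally sorted result. Here is a minimal Spark-free sketch of that idea on plain Scala collections, under a few assumptions: `SortByKeySketch`, `rangeBounds` and `partitionOf` are hypothetical names, keys are `Int`s, and the bounds are taken directly from the sorted keys, whereas Spark's `RangePartitioner` estimates them from a sample. For descending order the sketch simply reads the partitions back to front.

```scala
// Spark-free simulation of sortByKey: range-partition every record ("shuffle"),
// then sort each partition locally. All names here are hypothetical.
object SortByKeySketch {

  // Choose numPartitions - 1 upper bounds directly from the sorted keys.
  // (Spark's RangePartitioner estimates its bounds from a sample instead.)
  def rangeBounds(keys: Seq[Int], numPartitions: Int): Vector[Int] = {
    val sorted = keys.sorted.toVector
    if (sorted.isEmpty) Vector.empty
    else (1 until numPartitions).map(i => sorted(i * sorted.length / numPartitions)).toVector
  }

  // Index of the first bound >= key; keys above every bound go to the last partition.
  def partitionOf(key: Int, bounds: Vector[Int]): Int = {
    val i = bounds.indexWhere(b => key <= b)
    if (i == -1) bounds.length else i
  }

  def sortByKey[V](data: Seq[(Int, V)], numPartitions: Int,
                   ascending: Boolean = true): Vector[Seq[(Int, V)]] = {
    val bounds = rangeBounds(data.map(_._1), numPartitions)
    // "Shuffle": every single record is routed to its target partition.
    val buckets = Vector.tabulate(numPartitions) { p =>
      data.filter { case (k, _) => partitionOf(k, bounds) == p }
    }
    // Per-partition sort, mirroring the mapPartitions step in the snippet.
    val sorted = buckets.map { buf =>
      if (ascending) buf.sortWith((x, y) => x._1 < y._1)
      else buf.sortWith((x, y) => x._1 > y._1)
    }
    // For descending order, read the partitions back to front.
    if (ascending) sorted else sorted.reverse
  }

  def main(args: Array[String]): Unit = {
    val data = Seq(5 -> "e", 1 -> "a", 4 -> "d", 2 -> "b", 3 -> "c", 6 -> "f")
    sortByKey(data, numPartitions = 3).foreach(println)
  }
}
```

Every record passes through the bucketing step no matter how many results you eventually want, which is why the whole data set has to be shuffled.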
And, as you say, all of the data must go through the shuffle stage, as the snippet shows: the RangePartitioner routes every record to a partition before any sorting happens.
However, your concern about subsequently invoking take(K) may be unfounded. That operation does NOT cycle through all N items:
    /**
     * Take the first num elements of the RDD. It works by first scanning one partition, and use the
     * results from that partition to estimate the number of additional partitions needed to satisfy
     * the limit.
     */
    def take(num: Int): Array[T] = {
      // ...
    }
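The strategy described in that doc comment can be simulated without Spark. The sketch below treats an RDD as a hypothetical `IndexedSeq` of partitions: it scans one partition, then extrapolates from what it found how many more partitions to scan. `TakeSketch` is a hypothetical name, and the 50% overestimate and the "try all remaining partitions if nothing was found" fallback are illustrative choices, not a claim about Spark's exact code.

```scala
import scala.collection.mutable.ArrayBuffer

// Spark-free simulation of an incremental take(num): scan one partition first,
// then estimate how many more partitions are needed. Names are hypothetical.
object TakeSketch {

  // Returns the first `num` elements and the number of partitions scanned.
  def take[T](partitions: IndexedSeq[Seq[T]], num: Int): (Vector[T], Int) = {
    val buf = ArrayBuffer.empty[T]
    val totalParts = partitions.length
    var partsScanned = 0
    while (buf.size < num && partsScanned < totalParts) {
      val numPartsToTry =
        if (partsScanned == 0) 1                        // first round: a single partition
        else if (buf.isEmpty) totalParts - partsScanned // nothing found yet: try the rest
        // Extrapolate the total partitions needed, overestimate by 50%,
        // subtract those already scanned, and always try at least one more.
        else math.max(1, (1.5 * num * partsScanned / buf.size).toInt - partsScanned)
      val upTo = math.min(partsScanned + numPartsToTry, totalParts)
      // Each partition in this round contributes at most the number still missing.
      for (p <- partsScanned until upTo) buf ++= partitions(p).take(num - buf.size)
      partsScanned = upTo
    }
    (buf.toVector, partsScanned)
  }

  def main(args: Array[String]): Unit = {
    val parts = Vector(Seq(1, 2), Seq(3, 4), Seq(5, 6), Seq(7, 8), Seq(9, 10))
    val (firstThree, scanned) = take(parts, 3)
    println(s"$firstThree after scanning $scanned of ${parts.length} partitions")
  }
}
```

The sketch only saves work on partitions it never reaches. With `sortByKey().take(K)`, as I read the snippet, the shuffle still has to process every record first, and each partition that does get scanned is fully materialized and sorted (`iter.toArray`), so for small K the cost is dominated by the sort itself.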
So then, it would seem:
    O(myRdd.take(K)) << O(myRdd.sortByKey()) ~= O(myRdd.sortByKey().take(K)) << O(myRdd.sortByKey().collect())

where the ~= holds at least for small K.