I am trying to precompute partitions for some SparkSQL queries. If I compute and persist the partitions, Spark uses them. If I save the partitioned data to Parquet and reload it later, the partition information is gone and Spark recomputes it.
The actual data is large enough that a significant amount of time is spent partitioning it, but the code below is enough to demonstrate the problem. test2() is currently the only approach I can get to work, but I would like to jump-start the actual processing, which is what test3() is attempting to do.
Does anyone know what I'm doing wrong, or whether this is something Spark can even do?
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.types import *

# NOTE: Need to have python in PATH, SPARK_HOME set to location of spark,
# HADOOP_HOME set to location of winutils

if __name__ == "__main__":
    sc = SparkContext(appName="PythonPartitionBug")

    sql_text = "select foo, bar from customer c, orders o where foo < 300 and c.custkey=o.custkey"

    def setup():
        sqlContext = SQLContext(sc)
        fields1 = [StructField(name, IntegerType()) for name in ['custkey', 'foo']]
        data1 = [(1, 110), (2, 210), (3, 310), (4, 410), (5, 510)]
        df1 = sqlContext.createDataFrame(data1, StructType(fields1))
        df1.persist()
        fields2 = [StructField(name, IntegerType()) for name in ['orderkey', 'custkey', 'bar']]
        data2 = [(1, 1, 10), (2, 1, 20), (3, 2, 30), (4, 3, 40), (5, 4, 50)]
        df2 = sqlContext.createDataFrame(data2, StructType(fields2))
        df2.persist()
        return sqlContext, df1, df2

    def test1():
        # Without repartition the final plan includes hashpartitioning
        # == Physical Plan ==
        # Project [foo#1,bar#14]
        # +- SortMergeJoin [custkey#0], [custkey#13]
        #    :- Sort [custkey#0 ASC], false, 0
        #    :  +- TungstenExchange hashpartitioning(custkey#0,200), None
        #    :     +- Filter (foo#1 < 300)
        #    :        +- InMemoryColumnarTableScan [custkey#0,foo#1], [(foo#1 < 300)], InMemoryRelation [custkey#0,foo#1], true, 10000, StorageLevel(false, true, false, false, 1), ConvertToUnsafe, None
        #    +- Sort [custkey#13 ASC], false, 0
        #       +- TungstenExchange hashpartitioning(custkey#13,200), None
        #          +- InMemoryColumnarTableScan [bar#14,custkey#13], InMemoryRelation [orderkey#12,custkey#13,bar#14], true, 10000, StorageLevel(false, true, false, false, 1), ConvertToUnsafe, None
        sqlContext, df1, df2 = setup()
        df1.registerTempTable("customer")
        df2.registerTempTable("orders")
        df3 = sqlContext.sql(sql_text)
        df3.collect()
        df3.explain(True)

    def test2():
        # With repartition the final plan does not include hashpartitioning
        # == Physical Plan ==
        # Project [foo#56,bar#69]
        # +- SortMergeJoin [custkey#55], [custkey#68]
        #    :- Sort [custkey#55 ASC], false, 0
        #    :  +- Filter (foo#56 < 300)
        #    :     +- InMemoryColumnarTableScan [custkey#55,foo#56], [(foo#56 < 300)], InMemoryRelation [custkey#55,foo#56], true, 10000, StorageLevel(false, true, false, false, 1), TungstenExchange hashpartitioning(custkey#55,4), None, None
        #    +- Sort [custkey#68 ASC], false, 0
        #       +- InMemoryColumnarTableScan [bar#69,custkey#68], InMemoryRelation [orderkey#67,custkey#68,bar#69], true, 10000, StorageLevel(false, true, false, false, 1), TungstenExchange hashpartitioning(custkey#68,4), None, None
        sqlContext, df1, df2 = setup()
        df1a = df1.repartition(4, 'custkey').persist()
        df1a.registerTempTable("customer")
        df2a = df2.repartition(4, 'custkey').persist()
        df2a.registerTempTable("orders")
        df3 = sqlContext.sql(sql_text)
        df3.collect()
        df3.explain(True)

    def test3():
        # After round tripping the partitioned data, the partitioning is lost and spark repartitions
        # == Physical Plan ==
        # Project [foo#223,bar#284]
        # +- SortMergeJoin [custkey#222], [custkey#283]
        #    :- Sort [custkey#222 ASC], false, 0
        #    :  +- TungstenExchange hashpartitioning(custkey#222,200), None
        #    :     +- Filter (foo#223 < 300)
        #    :        +- InMemoryColumnarTableScan [custkey#222,foo#223], [(foo#223 < 300)], InMemoryRelation [custkey#222,foo#223], true, 10000, StorageLevel(false, true, false, false, 1), Scan ParquetRelation[custkey#222,foo#223] InputPaths: file:/E:/.../df1.parquet, None
        #    +- Sort [custkey#283 ASC], false, 0
        #       +- TungstenExchange hashpartitioning(custkey#283,200), None
        #          +- InMemoryColumnarTableScan [bar#284,custkey#283], InMemoryRelation [orderkey#282,custkey#283,bar#284], true, 10000, StorageLevel(false, true, false, false, 1), Scan ParquetRelation[orderkey#282,custkey#283,bar#284] InputPaths: file:/E:/.../df2.parquet, None
        sqlContext, df1, df2 = setup()
        df1a = df1.repartition(4, 'custkey').persist()
        df1a.write.parquet("df1.parquet", mode='overwrite')
        df1a = sqlContext.read.parquet("df1.parquet")
        df1a.persist()
        df1a.registerTempTable("customer")
        df2a = df2.repartition(4, 'custkey').persist()
        df2a.write.parquet("df2.parquet", mode='overwrite')
        df2a = sqlContext.read.parquet("df2.parquet")
        df2a.persist()
        df2a.registerTempTable("orders")
        df3 = sqlContext.sql(sql_text)
        df3.collect()
        df3.explain(True)

    test1()
    test2()
    test3()
    sc.stop()
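One thing I have seen mentioned but have not been able to verify for my version: newer Spark releases (the 2.x SparkSession API, with bucketBy available in PySpark from roughly 2.3) apparently let you save "bucketed" tables whose hash layout is recorded as table metadata, so a later join on the bucket column can skip the exchange. If that is the intended way to persist a precomputed partitioning, a minimal sketch of what I imagine it would look like is below; the table names and the session setup here are my own placeholders, not part of my current code, and I may be misunderstanding the API.

    # Hypothetical sketch, assuming Spark 2.3+ with SparkSession and bucketed-table
    # support via a Hive-backed catalog; not what the 1.6-style code above uses.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder \
        .appName("BucketedTables") \
        .enableHiveSupport() \
        .getOrCreate()

    # Write each side bucketed (and sorted) by the join key so the layout is
    # recorded with the table rather than lost on reload.
    df1a.write.bucketBy(4, "custkey").sortBy("custkey") \
        .mode("overwrite").saveAsTable("customer_bucketed")
    df2a.write.bucketBy(4, "custkey").sortBy("custkey") \
        .mode("overwrite").saveAsTable("orders_bucketed")

    # Later (possibly in a new session) the join would ideally reuse the bucket
    # metadata instead of re-shuffling both sides.
    df3 = spark.sql("select foo, bar from customer_bucketed c, orders_bucketed o "
                    "where foo < 300 and c.custkey = o.custkey")
    df3.explain(True)

If that is the right direction, is there any equivalent for plain Parquet files read back with sqlContext.read.parquet, or does the partitioning information have to live in a catalog?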