I have a large parquet dataset that I am reading with Spark. Once read, I filter for a subset of rows, which is then used in a number of functions that apply different transformations.

The following is similar, though not identical, to the logic I'm trying to accomplish:
```python
from pyspark.sql.functions import col, lit

df = spark.read.parquet(file)
special_rows = df.filter(col('special') > 0)

# Thinking about adding the following line
special_rows.cache()

def f1(df):
    new_df_1 = df.withColumn('foo', lit(0))
    return new_df_1

def f2(df):
    new_df_2 = df.withColumn('foo', lit(1))
    return new_df_2

new_df_1 = f1(special_rows)
new_df_2 = f2(special_rows)

output_df = new_df_1.union(new_df_2)
output_df.write.parquet(location)
```

Because a number of functions might use this filtered subset of rows, I'd like to cache or persist it in order to potentially reduce execution time / memory consumption. I understand that in the above example, no action is called until my final write to parquet.
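To make the question concrete, here is roughly how I was planning to check whether the cache actually gets used. I'm assuming that the storage level and the physical plan from explain() are reasonable indicators, but I'm not sure that's the right way to verify it:

```python
# Rough sketch of how I'd try to verify the cache (assumption:
# storageLevel and the physical plan are reasonable indicators).

# Shows the storage level that cache() registered, even before
# anything has actually been materialized.
print(special_rows.storageLevel)

# If the cache is picked up, I'd expect the plan for the downstream
# dataframe to reference the in-memory relation rather than
# re-reading the parquet source twice.
output_df.explain()
```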
My question is: do I need to insert some sort of call, such as count(), in order to trigger the caching, or will Spark, during that final write to parquet, be able to see that this dataframe is used in both f1 and f2 and cache it itself?
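In other words, the variant I'm considering would look something like this, with count() used purely as a throwaway action to materialize the cache before f1 and f2 run:

```python
special_rows = df.filter(col('special') > 0).cache()

# Throwaway action whose only purpose is to materialize the cache
# before the subsequent transformations reuse special_rows.
special_rows.count()

new_df_1 = f1(special_rows)
new_df_2 = f2(special_rows)
```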
If so, is this an idiomatic approach? Does it mean that production, large-scale Spark jobs that rely on caching frequently use otherwise-arbitrary operations, such as a call to count(), just to force an action on the dataframe pre-emptively?