
I have a PySpark dataframe. I have to do a group by and then aggregate certain columns into a list so that I can apply a UDF to the dataframe.

As an example, I have created a dataframe and then grouped by person.

df = spark.createDataFrame(a, ["Person", "Amount", "Budget", "Date"])
df = df.groupby("Person").agg(
    F.collect_list(F.struct("Amount", "Budget", "Date")).alias("data")
)
df.show(truncate=False)

+------+----------------------------------------------------------------------------+
|Person|data                                                                        |
+------+----------------------------------------------------------------------------+
|Bob   |[[85.8,Food,2017-09-13], [7.8,Household,2017-09-13], [6.52,Food,2017-06-13]]|
+------+----------------------------------------------------------------------------+

I have left out the UDF, but the resulting dataframe produced by the UDF is below.

+------+--------------------------------------------------------------+
|Person|res                                                           |
+------+--------------------------------------------------------------+
|Bob   |[[562,Food,June,1], [380,Household,Sept,4], [880,Food,Sept,2]]|
+------+--------------------------------------------------------------+

I need to convert the resulting dataframe into rows, where each element of the list becomes a new row with its own columns. This can be seen below.

+------+------+---------+-----+-------+
|Person|Amount|Budget   |Month|Cluster|
+------+------+---------+-----+-------+
|Bob   |562   |Food     |June |1      |
|Bob   |380   |Household|Sept |4      |
|Bob   |880   |Food     |Sept |2      |
+------+------+---------+-----+-------+

1 Answer


You can use explode and getItem as follows:

# starting from this form:
# +------+--------------------------------------------------------------+
# |Person|res                                                           |
# +------+--------------------------------------------------------------+
# |Bob   |[[562,Food,June,1], [380,Household,Sept,4], [880,Food,Sept,2]]|
# +------+--------------------------------------------------------------+

import pyspark.sql.functions as F

# explode res to have one row for each item in res
exploded_df = df.select("*", F.explode("res").alias("exploded_data"))
exploded_df.show(truncate=False)

# then use getItem to create separate columns
exploded_df = exploded_df.withColumn(
    "Amount",
    F.col("exploded_data").getItem("Amount")  # either get by name or by index, e.g. getItem(0)
)
exploded_df = exploded_df.withColumn(
    "Budget", F.col("exploded_data").getItem("Budget")
)
exploded_df = exploded_df.withColumn(
    "Month", F.col("exploded_data").getItem("Month")
)
exploded_df = exploded_df.withColumn(
    "Cluster", F.col("exploded_data").getItem("Cluster")
)

exploded_df.select("Person", "Amount", "Budget", "Month", "Cluster").show(10, False)

+------+------+---------+-----+-------+
|Person|Amount|Budget   |Month|Cluster|
+------+------+---------+-----+-------+
|Bob   |562   |Food     |June |1      |
|Bob   |380   |Household|Sept |4      |
|Bob   |880   |Food     |Sept |2      |
+------+------+---------+-----+-------+

You can then drop unnecessary columns. Hope this helps, good luck!
