As @pault mentioned, as of Spark 2.4 you can use Spark SQL built-in functions to handle this task. Here is one way with array_distinct + transform + aggregate:
    from pyspark.sql.functions import expr

    # set up data
    spark_df = spark.createDataFrame([
          (101, 'Mark', ['apple', 'apple', 'orange', 'apple'], [16, 2, 3, 1])
        , (102, 'Twin', ['apple', 'banana', 'avocado', 'banana', 'avocado'], [5, 2, 11, 3, 1])
        , (103, 'Smith', ['avocado'], [10])
        ], ['ID', 'name', 'fruit', 'qty']
    )

    >>> spark_df.show(5,0)
    +---+-----+-----------------------------------------+----------------+
    |ID |name |fruit                                    |qty             |
    +---+-----+-----------------------------------------+----------------+
    |101|Mark |[apple, apple, orange, apple]            |[16, 2, 3, 1]   |
    |102|Twin |[apple, banana, avocado, banana, avocado]|[5, 2, 11, 3, 1]|
    |103|Smith|[avocado]                                |[10]            |
    +---+-----+-----------------------------------------+----------------+

    >>> spark_df.printSchema()
    root
     |-- ID: long (nullable = true)
     |-- name: string (nullable = true)
     |-- fruit: array (nullable = true)
     |    |-- element: string (containsNull = true)
     |-- qty: array (nullable = true)
     |    |-- element: long (containsNull = true)
Set up the SQL statement:
    stmt = '''
        transform(array_distinct(fruit), x -> (x, aggregate(
              transform(sequence(0,size(fruit)-1), i -> IF(fruit[i] = x, qty[i], 0))
            , 0
            , (y,z) -> int(y + z)
        ))) AS sum_fruit
    '''

    >>> spark_df.withColumn('sum_fruit', expr(stmt)).show(10,0)
    +---+-----+-----------------------------------------+----------------+----------------------------------------+
    |ID |name |fruit                                    |qty             |sum_fruit                               |
    +---+-----+-----------------------------------------+----------------+----------------------------------------+
    |101|Mark |[apple, apple, orange, apple]            |[16, 2, 3, 1]   |[[apple, 19], [orange, 3]]              |
    |102|Twin |[apple, banana, avocado, banana, avocado]|[5, 2, 11, 3, 1]|[[apple, 5], [banana, 5], [avocado, 12]]|
    |103|Smith|[avocado]                                |[10]            |[[avocado, 10]]                         |
    +---+-----+-----------------------------------------+----------------+----------------------------------------+
Explanation:
- Use array_distinct(fruit) to find all distinct entries in the array fruit.
- Transform this new array (with element x) from x to (x, aggregate(..x..)).
- The function aggregate(..x..) above takes the simple form of summing up all elements in array_T:

      aggregate(array_T, 0, (y,z) -> y + z)

  where array_T comes from the following transformation:

      transform(sequence(0,size(fruit)-1), i -> IF(fruit[i] = x, qty[i], 0))

  which iterates through the array fruit: if fruit[i] = x, it returns the corresponding qty[i], otherwise it returns 0. For example, for ID=101, when x = 'orange', it returns the array [0, 0, 3, 0].

  In the actual statement the sum is wrapped in int(...) so that the result type of the merge function matches the initial value 0, since qty holds longs.
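To check the logic of the expression above without a Spark session, the same steps can be mirrored in plain Python. This is only an illustrative sketch: the function name sum_by_fruit is hypothetical and not part of Spark or of the answer's code, and Python lists stand in for the Spark array columns.

```python
def sum_by_fruit(fruit, qty):
    """Plain-Python mirror of the Spark SQL expression (illustration only)."""
    # array_distinct(fruit): unique values, keeping first-occurrence order
    distinct = list(dict.fromkeys(fruit))
    result = []
    for x in distinct:
        # transform(sequence(0, size(fruit)-1), i -> IF(fruit[i] = x, qty[i], 0))
        array_t = [qty[i] if fruit[i] == x else 0 for i in range(len(fruit))]
        # aggregate(array_T, 0, (y, z) -> y + z)
        total = 0
        for z in array_t:
            total = total + z
        result.append((x, total))
    return result


print(sum_by_fruit(['apple', 'apple', 'orange', 'apple'], [16, 2, 3, 1]))
# [('apple', 19), ('orange', 3)]
```

The inner list plays the role of array_T, so for x = 'orange' it is [0, 0, 3, 0], matching the example for ID=101.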