
I have the following sample dataframe:

    fruit_list = ['apple', 'apple', 'orange', 'apple']
    qty_list = [16, 2, 3, 1]
    spark_df = spark.createDataFrame([(101, 'Mark', fruit_list, qty_list)], ['ID', 'name', 'fruit', 'qty'])

and I would like to create another column containing a result similar to what I would get with a pandas groupby('fruit').sum():

            qty
    fruits
    apple    19
    orange    3

The above result could be stored in the new column in any form (either a string, dictionary, list of tuples...).

I've tried an approach similar to the following, which does not work:

    sum_cols = udf(lambda x: pd.DataFrame({'fruits': x[0], 'qty': x[1]}).groupby('fruits').sum())
    spark_df.withColumn('Result', sum_cols(F.struct('fruit', 'qty'))).show()

One example of the resulting dataframe could be:

    +---+----+--------------------+-------------+-------------------------+
    | ID|name|               fruit|          qty|                   Result|
    +---+----+--------------------+-------------+-------------------------+
    |101|Mark|[apple, apple, or...|[16, 2, 3, 1]|[(apple,19), (orange,3)] |
    +---+----+--------------------+-------------+-------------------------+

Do you have any suggestion on how I could achieve that?

Thanks

Edit: running on Spark 2.4.3

  • What is your desired output? It's unclear from the description; please show it explicitly. Commented Jul 31, 2019 at 13:30
  • Thanks for your comment, done! Commented Jul 31, 2019 at 13:37
  • What version of Spark? If it's Spark 2.4+ you can use arrays_zip. Older versions make this a little more difficult. Commented Jul 31, 2019 at 13:40
  • I'm running on 2.4.3; could you kindly provide an example of how to use it in my case? Commented Jul 31, 2019 at 13:52
  • In my (limited) experience, I've seen "native" PySpark code perform 10x faster than UDFs (and especially UDAFs), even when an explode was being used. Just something to keep in mind. Commented Aug 1, 2019 at 21:38

3 Answers


As @pault mentioned, as of Spark 2.4+ you can use Spark SQL built-in functions to handle your task. Here is one way with array_distinct + transform + aggregate:

    from pyspark.sql.functions import expr

    # set up data
    spark_df = spark.createDataFrame([
            (101, 'Mark', ['apple', 'apple', 'orange', 'apple'], [16, 2, 3, 1])
          , (102, 'Twin', ['apple', 'banana', 'avocado', 'banana', 'avocado'], [5, 2, 11, 3, 1])
          , (103, 'Smith', ['avocado'], [10])
        ], ['ID', 'name', 'fruit', 'qty']
    )

    >>> spark_df.show(5,0)
    +---+-----+-----------------------------------------+----------------+
    |ID |name |fruit                                    |qty             |
    +---+-----+-----------------------------------------+----------------+
    |101|Mark |[apple, apple, orange, apple]            |[16, 2, 3, 1]   |
    |102|Twin |[apple, banana, avocado, banana, avocado]|[5, 2, 11, 3, 1]|
    |103|Smith|[avocado]                                |[10]            |
    +---+-----+-----------------------------------------+----------------+

    >>> spark_df.printSchema()
    root
     |-- ID: long (nullable = true)
     |-- name: string (nullable = true)
     |-- fruit: array (nullable = true)
     |    |-- element: string (containsNull = true)
     |-- qty: array (nullable = true)
     |    |-- element: long (containsNull = true)

Set up the SQL statement:

    stmt = '''
        transform(array_distinct(fruit), x -> (x, aggregate(
              transform(sequence(0,size(fruit)-1), i -> IF(fruit[i] = x, qty[i], 0))
            , 0
            , (y,z) -> int(y + z)
        ))) AS sum_fruit
    '''

    >>> spark_df.withColumn('sum_fruit', expr(stmt)).show(10,0)
    +---+-----+-----------------------------------------+----------------+----------------------------------------+
    |ID |name |fruit                                    |qty             |sum_fruit                               |
    +---+-----+-----------------------------------------+----------------+----------------------------------------+
    |101|Mark |[apple, apple, orange, apple]            |[16, 2, 3, 1]   |[[apple, 19], [orange, 3]]              |
    |102|Twin |[apple, banana, avocado, banana, avocado]|[5, 2, 11, 3, 1]|[[apple, 5], [banana, 5], [avocado, 12]]|
    |103|Smith|[avocado]                                |[10]            |[[avocado, 10]]                         |
    +---+-----+-----------------------------------------+----------------+----------------------------------------+

Explanation:

  1. Use array_distinct(fruit) to find all distinct entries in the array fruit.
  2. Use transform on this new array (with element x) to map each x to (x, aggregate(..x..)).
  3. The function aggregate(..x..) above takes the simple form of summing up all elements in array_T:

    aggregate(array_T, 0, (y,z) -> y + z) 

    where array_T comes from the following transformation:

    transform(sequence(0,size(fruit)-1), i -> IF(fruit[i] = x, qty[i], 0)) 

    which iterates through the array fruit: if fruit[i] = x, it returns the corresponding qty[i], otherwise 0. For example, for ID=101 when x = 'orange', it returns the array [0, 0, 3, 0]. Note that the actual statement wraps the sum as int(y + z): qty holds long values while the initial value 0 is an int, and aggregate requires the merge result to have the same type as the initial value.
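To make the steps above concrete, here is a plain-Python emulation of the same expression (no Spark involved, so it only mirrors the logic, not Spark's execution). The names mask_qty and sum_fruit are hypothetical helpers; each line is annotated with the SQL function it stands in for:

```python
from functools import reduce


def mask_qty(fruit, qty, x):
    # transform(sequence(0, size(fruit)-1), i -> IF(fruit[i] = x, qty[i], 0))
    return [qty[i] if fruit[i] == x else 0 for i in range(len(fruit))]


def sum_fruit(fruit, qty):
    # array_distinct(fruit): dict.fromkeys keeps first-occurrence order,
    # which matches the order seen in the sum_fruit output above
    distinct = list(dict.fromkeys(fruit))
    # transform(distinct, x -> (x, aggregate(array_T, 0, (y, z) -> y + z)))
    return [(x, reduce(lambda y, z: y + z, mask_qty(fruit, qty, x), 0)) for x in distinct]


print(mask_qty(['apple', 'apple', 'orange', 'apple'], [16, 2, 3, 1], 'orange'))  # [0, 0, 3, 0]
print(sum_fruit(['apple', 'apple', 'orange', 'apple'], [16, 2, 3, 1]))  # [('apple', 19), ('orange', 3)]
```

The per-x masking pass makes the SQL version O(distinct × size) per row, which is fine for short arrays like these.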


4 Comments

Wow, that's a very nice answer, thanks @jxc! Do you think this performs better than @pault's solution?
@crash I would imagine this is better than using a udf. +1
@pault, you're an experienced user, do you think this should be the accepted solution?
Yes, but it's your decision to make.

There may be a fancy way to do this using only the API functions on Spark 2.4+, perhaps with some combination of arrays_zip and aggregate, but I can't think of any that don't involve an explode step followed by a groupBy. With that in mind, using a udf may actually be better for you in this case.

I think creating a pandas DataFrame just for the purpose of calling .groupby().sum() is overkill. Furthermore, even if you did do it that way, you'd need to convert the final output to a different data structure because a udf can't return a pandas DataFrame.

Here's one way with a udf using collections.defaultdict:

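    from collections import defaultdict
    from pyspark.sql.functions import udf
    from pyspark.sql.types import ArrayType, StructType, StructField, StringType, IntegerType

    def sum_cols_func(frt, qty):
        d = defaultdict(int)
        # map(int, ...) makes the quantities plain Python ints, whatever type they arrive as
        for x, y in zip(frt, map(int, qty)):
            d[x] += y
        # build plain (str, int) tuples so the udf returns only native Python types
        return [(k, int(v)) for k, v in d.items()]

    sum_cols = udf(
        lambda x: sum_cols_func(*x),
        ArrayType(
            StructType([
                StructField("fruit", StringType()),
                StructField("qty", IntegerType())
            ])
        )
    )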

Then call this by passing in the fruit and qty columns:

    from pyspark.sql.functions import array, col

    spark_df.withColumn(
        "Result",
        sum_cols(array([col("fruit"), col("qty")]))
    ).show(truncate=False)
    #+---+----+-----------------------------+-------------+--------------------------+
    #|ID |name|fruit                        |qty          |Result                    |
    #+---+----+-----------------------------+-------------+--------------------------+
    #|101|Mark|[apple, apple, orange, apple]|[16, 2, 3, 1]|[[orange, 3], [apple, 19]]|
    #+---+----+-----------------------------+-------------+--------------------------+
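Because sum_cols_func is ordinary Python, its grouping logic can be checked without a Spark session. The sketch below re-declares a variant that returns a list of plain (str, int) tuples. The quantities are passed as strings on the assumption that array() may coerce both arrays to a common element type; map(int, ...) handles ints as well:

```python
from collections import defaultdict


def sum_cols_func(frt, qty):
    d = defaultdict(int)
    # accumulate quantities per fruit; int() normalises str or int inputs
    for x, y in zip(frt, map(int, qty)):
        d[x] += y
    # list of native (str, int) tuples, suitable as a udf return value
    return [(k, int(v)) for k, v in d.items()]


print(sum_cols_func(['apple', 'apple', 'orange', 'apple'], ['16', '2', '3', '1']))
```

Note that Spark does not guarantee any particular order for the resulting structs. The output shown above lists orange first, so compare results as sets or sorted lists if order matters.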

7 Comments

I like your solution, pault, thanks for your time. However, I'm getting this error: Py4JJavaError: An error occurred while calling o3564.showString....Caused by: net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.dtype). Does it tell you anything?
Just added from pyspark.sql.types import ArrayType, StructType, StringType, IntegerType, StructField
Also try forcing the Python int type in the output, return [(k, int(sum(v))) for k, v in d.items()], so that the result of the udf is definitely a native Python type.
@RichardNemeth that's a great point, which makes me wonder: are you sure you're using __builtin__.sum and not numpy.sum or pyspark.sql.functions.sum? See "Why you shouldn't use import *". Edit: if you're getting a numpy object, it suggests that you're using numpy.sum.
Yep, Richard's solution seemed to solve the problem!

If you have Spark < 2.4, use the following to explode (otherwise check this answer):

    df_split = (spark_df.rdd
                .flatMap(lambda row: [(row.ID, row.name, f, q) for f, q in zip(row.fruit, row.qty)])
                .toDF(["ID", "name", "fruit", "qty"]))
    df_split.show()

Output:

    +---+----+------+---+
    | ID|name| fruit|qty|
    +---+----+------+---+
    |101|Mark| apple| 16|
    |101|Mark| apple|  2|
    |101|Mark|orange|  3|
    |101|Mark| apple|  1|
    +---+----+------+---+
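The flatMap step pairs each fruit with its quantity and repeats the row's ID and name for every pair. A plain-Python sketch of the same row expansion, using a hypothetical Row namedtuple as a stand-in for Spark's Row objects (no Spark involved):

```python
from collections import namedtuple

# Hypothetical stand-in for a Spark Row with the same field names
Row = namedtuple("Row", ["ID", "name", "fruit", "qty"])


def split_row(row):
    # one output tuple per (fruit, qty) pair, mirroring the flatMap lambda
    return [(row.ID, row.name, f, q) for f, q in zip(row.fruit, row.qty)]


rows = [Row(101, "Mark", ["apple", "apple", "orange", "apple"], [16, 2, 3, 1])]
exploded = [t for row in rows for t in split_row(row)]
for t in exploded:
    print(t)
```

zip stops at the shorter list, so mismatched fruit/qty lengths would silently drop the extra elements, in the RDD version as well as here.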

Then prepare the result you want. First find the aggregated dataframe:

    from pyspark.sql import functions as F

    df_aggregated = df_split.groupby('ID', 'fruit').agg(F.sum('qty').alias('qty'))
    df_aggregated.show()

Output:

    +---+------+---+
    | ID| fruit|qty|
    +---+------+---+
    |101|orange|  3|
    |101| apple| 19|
    +---+------+---+

And finally change it to the desired format:

    df_aggregated.groupby('ID').agg(
        F.collect_list(F.struct(F.col('fruit'), F.col('qty'))).alias('Result')
    ).show()

Output:

    +---+--------------------------+
    |ID |Result                    |
    +---+--------------------------+
    |101|[[orange, 3], [apple, 19]]|
    +---+--------------------------+
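The two groupBy steps amount to summing quantities per (ID, fruit) key and then collecting the (fruit, qty) pairs per ID. A plain-Python sketch of that logic over already-exploded tuples (the data is copied from the df_split output above). Spark's collect_list makes no ordering guarantee, whereas this sketch keeps first-seen order:

```python
from collections import defaultdict

exploded = [
    (101, "Mark", "apple", 16),
    (101, "Mark", "apple", 2),
    (101, "Mark", "orange", 3),
    (101, "Mark", "apple", 1),
]

# Step 1: groupby('ID', 'fruit').agg(sum('qty'))
totals = defaultdict(int)
for id_, _name, fruit, qty in exploded:
    totals[(id_, fruit)] += qty

# Step 2: groupby('ID').agg(collect_list(struct(fruit, qty)))
result = defaultdict(list)
for (id_, fruit), qty in totals.items():
    result[id_].append((fruit, qty))

print(dict(result))
```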

5 Comments

udf likely has better performance than explode in this case
I do not know, but I changed the answer to RDD instead of explode.
I'm trying this out. I think distributing the code over multiple lines would help readability, though.
udf would almost surely (definitely?) be better than rdd as well
This works as well, but pault's solution is much cleaner and easier to understand.
