I am trying to perform the following operation on a pyspark.sql.DataFrame:
    from pyspark.sql.functions import sum as spark_sum

    df = spark.createDataFrame([
        ('a', 1.0, 1.0), ('a', 1.0, 0.2), ('b', 1.0, 1.0),
        ('c', 1.0, 0.5), ('d', 0.55, 1.0), ('e', 1.0, 1.0)
    ])

    >>> df.show()
    +---+----+---+
    | _1|  _2| _3|
    +---+----+---+
    |  a| 1.0|1.0|
    |  a| 1.0|0.2|
    |  b| 1.0|1.0|
    |  c| 1.0|0.5|
    |  d|0.55|1.0|
    |  e| 1.0|1.0|
    +---+----+---+

Then, I am trying to do the following operations:
1) Select the rows where df['_2'] > df['_3'].
2) For each selected row, multiply df['_2'] * df['_3'], then sum the products.
3) Divide that result by the sum of column df['_3'] over the selected rows.
Here is what I did:
    >>> filter_df = df.where(df['_2'] > df['_3'])
    >>> filter_df.show()
    +---+---+---+
    | _1| _2| _3|
    +---+---+---+
    |  a|1.0|0.2|
    |  c|1.0|0.5|
    +---+---+---+

    >>> result = spark_sum(filter_df['_2'] * filter_df['_3']) / spark_sum(filter_df['_3'])
    >>> df.select(result).show()
    +--------------------------+
    |(sum((_2 * _3)) / sum(_3))|
    +--------------------------+
    |        0.9042553191489361|
    +--------------------------+

But the answer should be (1.0 * 0.2 + 1.0 * 0.5) / (0.2 + 0.5) = 1.0, so this result is not correct.
It seems to me that the operation is applied to the original df rather than to filter_df. Why?
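To check that suspicion without Spark, the same arithmetic can be reproduced in plain Python. This is only a sanity check on the numbers from the table above; the helper name weighted_ratio is made up for illustration and is not part of any library.

```python
# Same six rows as the DataFrame in the question: (_1, _2, _3).
rows = [
    ('a', 1.0, 1.0), ('a', 1.0, 0.2), ('b', 1.0, 1.0),
    ('c', 1.0, 0.5), ('d', 0.55, 1.0), ('e', 1.0, 1.0),
]


def weighted_ratio(selected):
    """Hypothetical helper: sum(_2 * _3) / sum(_3) over the given rows."""
    numerator = sum(c2 * c3 for _, c2, c3 in selected)
    denominator = sum(c3 for _, _, c3 in selected)
    return numerator / denominator


# Ratio over ALL rows (no filter applied).
all_rows_ratio = weighted_ratio(rows)

# Ratio over only the rows where _2 > _3.
filtered_rows = [r for r in rows if r[1] > r[2]]
filtered_ratio = weighted_ratio(filtered_rows)

print(all_rows_ratio)   # close to the value Spark printed
print(filtered_ratio)   # close to the expected value
```

The unfiltered ratio is 4.25 / 4.7, roughly 0.904, which matches what Spark printed, while the filtered ratio is 0.7 / 0.7. That is consistent with the aggregation having run over every row of df.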
When I tried result.show(), it gave me the following error:

    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    TypeError: 'Column' object is not callable

The df tag is for the Unix command by that name and has nothing to do with dataframes.