from pyspark.sql import Row from pyspark.sql.types import StructField, StructType, StringType, IntegerType from pyspark.sql.window import Window from pyspark.sql.functions import create_map, explode, struct, split, row_number, to_json from functools import reduce
# DataFrame schema
# Schema for the input data: one record per (RowID, Name) pair, where Place
# holds a comma-separated list of countries (nullable).
dfSchema = StructType(
    [StructField(col_name, col_type())
     for col_name, col_type in (
         ('RowID', IntegerType),
         ('Name', StringType),
         ('Place', StringType),
     )]
)
# Raw data
# Raw data: (RowID, Name, Place) triples; Place is a comma-separated country
# list and may be None.
_raw_records = [
    (1, 'Gaga', 'India,US,UK'),
    (1, 'Katy', 'UK,India,Europe'),
    (1, 'Bey', 'Europe'),
    (2, 'Gaga', None),
    (2, 'Katy', 'India,Europe'),
    (2, 'Bey', 'US'),
    (3, 'Gaga', 'Europe'),
    (3, 'Katy', 'US'),
    (3, 'Bey', None),
]
rowList = [Row(*record) for record in _raw_records]
# Create the initial DataFrame
# Materialize the raw rows under the declared schema and display the result.
df = spark.createDataFrame(data=rowList, schema=dfSchema)
df.show()
# +-----+----+---------------+
# |RowID|Name|          Place|
# +-----+----+---------------+
# |    1|Gaga|    India,US,UK|
# |    1|Katy|UK,India,Europe|
# |    1| Bey|         Europe|
# |    2|Gaga|           null|
# |    2|Katy|   India,Europe|
# |    2| Bey|             US|
# |    3|Gaga|         Europe|
# |    3|Katy|             US|
# |    3| Bey|           null|
# +-----+----+---------------+
# Use create_map, struct and to_json to create the intermediate (pivoted) output
# Serialize each input row to a one-key JSON document of the shape
# {"<Name>": {"RowID": ..., "Place": ...}} so each person becomes a column
# after re-reading the JSON.
jsonDFCol = df.select(
    to_json(create_map('Name', struct('RowID', 'Place'))).alias('name_place')
)

# Pull the JSON strings to the driver and re-read them as a DataFrame whose
# columns are the person names. NOTE(review): collect + parallelize round-trips
# through the driver — acceptable for this toy dataset.
jsonList = [record[0] for record in jsonDFCol.rdd.collect()]
jsonDF = spark.read.json(sc.parallelize(jsonList))

# For each person column, extract (RowID, Place-renamed-to-person) and drop
# the rows where that person contributed nothing.
intermediateList = []
for name in jsonDF.columns:
    per_person = jsonDF.selectExpr(f'{name}.RowID', f'{name}.Place AS {name}')
    intermediateList.append(per_person.where('RowID is not Null'))

# Join all per-person frames back together on RowID, one join at a time.
intermediateDF = intermediateList[0]
for other in intermediateList[1:]:
    intermediateDF = intermediateDF.join(other, on='RowID')
intermediateDF = intermediateDF.sort('RowID').select('RowID', 'Gaga', 'Katy', 'Bey')
intermediateDF.show()
# +-----+-----------+---------------+------+
# |RowID|       Gaga|           Katy|   Bey|
# +-----+-----------+---------------+------+
# |    1|India,US,UK|UK,India,Europe|Europe|
# |    2|       null|   India,Europe|    US|
# |    3|     Europe|             US|  null|
# +-----+-----------+---------------+------+
# Use a window to create the Id column
# Window over each RowID group; ordering by the partition key itself is
# arbitrary within a group, so row_number() simply hands out 1..n per RowID.
_by_row = Window.partitionBy('RowID')
rowWindow = _by_row.orderBy('RowID')
# Use the split and explode functions to obtain the final output
# Explode each person's comma-separated Place list into one row per place,
# numbering the places 1..n within each RowID via the window.
# FIX: the column is created as 'Id' (it was 'id') and 'RowID' is spelled
# consistently below — the original relied on spark.sql.caseSensitive=false
# and would break with case-sensitive resolution enabled.
finalDFList = [
    intermediateDF
    .select('RowID', explode(split(intermediateDF[col_], ',')).alias(col_))
    .withColumn('Id', row_number().over(rowWindow))
    for col_ in intermediateDF.columns[1:]
]

# Spine of every (RowID, Id) pair that occurs for any person.
# FIX: `union` replaces the deprecated `unionAll` (same positional-union
# semantics; duplicates are removed by the distinct() below).
finalDFID = reduce(
    lambda curr, nxt: curr.select('RowID', 'Id').union(nxt.select('RowID', 'Id')),
    finalDFList)

# Left-join each person's exploded places onto the spine so missing places
# surface as null, then de-duplicate and order the final output.
finalDF = reduce(
    lambda curr, nxt: curr.join(nxt, on=['RowID', 'Id'], how='left'),
    finalDFList, finalDFID)\
    .distinct()\
    .sort('RowID', 'Id')\
    .select('RowID', 'Id', 'Gaga', 'Katy', 'Bey')
finalDF.show()
# +-----+---+------+------+------+
# |RowID| Id|  Gaga|  Katy|   Bey|
# +-----+---+------+------+------+
# |    1|  1| India|    UK|Europe|
# |    1|  2|    US| India|  null|
# |    1|  3|    UK|Europe|  null|
# |    2|  1|  null| India|    US|
# |    2|  2|  null|Europe|  null|
# |    3|  1|Europe|    US|  null|
# +-----+---+------+------+------+