Let's say that our input dataframe is:
+---+-------+----+----------+
|id |col1   |col2|updated_at|
+---+-------+----+----------+
|123|null   |null|1634228709|
|123|null   |80  |1634228724|
|123|update2|90  |1634229000|
|12 |update1|null|1634221233|
|12 |null   |80  |1634228333|
|12 |update2|null|1634221220|
+---+-------+----+----------+
What we want is to convert updated_at to TimestampType, then order by id and by updated_at in descending order:
df = df.withColumn("updated_at", F.col("updated_at").cast(TimestampType())).orderBy( F.col("id"), F.col("updated_at").desc() )
that gives us:
+---+-------+----+-------------------+
|id |col1   |col2|updated_at         |
+---+-------+----+-------------------+
|12 |null   |80  |2021-10-14 18:18:53|
|12 |update1|null|2021-10-14 16:20:33|
|12 |update2|null|2021-10-14 16:20:20|
|123|update2|90  |2021-10-14 18:30:00|
|123|null   |80  |2021-10-14 18:25:24|
|123|null   |null|2021-10-14 18:25:09|
+---+-------+----+-------------------+
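Casting the integer column straight to TimestampType works here because the values are Unix epoch seconds. If you prefer to make that conversion explicit, a minimal alternative sketch (assuming updated_at really does store epoch seconds) could use from_unixtime instead:

# Alternative sketch, assuming updated_at stores Unix epoch seconds:
# from_unixtime renders the seconds as a datetime string in the session
# time zone, which is then cast back to a proper timestamp.
df = df.withColumn(
    "updated_at", F.from_unixtime(F.col("updated_at")).cast(TimestampType())
).orderBy(F.col("id"), F.col("updated_at").desc())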
Now group by id and take the first non-null value in each column (or null if the column has no non-null value for that id):
exp = [F.first(x, ignorenulls=True).alias(x) for x in df.columns[1:]]
df = df.groupBy(F.col("id")).agg(*exp)
And the result is:
+---+-------+----+-------------------+
|id |col1   |col2|updated_at         |
+---+-------+----+-------------------+
|123|update2|90  |2021-10-14 18:30:00|
|12 |update1|80  |2021-10-14 18:18:53|
+---+-------+----+-------------------+
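As a side note, the same one-row-per-id result can also be sketched with a window function instead of the groupBy. The sketch below assumes df is the timestamp-cast DataFrame from before the groupBy step, and df_latest is just an illustrative name:

from pyspark.sql import Window

# Window-based alternative sketch: per id, take the last non-null value of
# each column when rows are ordered by updated_at ascending, then keep a
# single row per id.
w = (
    Window.partitionBy("id")
    .orderBy(F.col("updated_at").asc())
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
)
df_latest = df.select(
    "id", *[F.last(c, ignorenulls=True).over(w).alias(c) for c in df.columns[1:]]
).dropDuplicates(["id"])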
Here's the full example code:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.types import TimestampType

if __name__ == "__main__":
    spark = SparkSession.builder.master("local").appName("Test").getOrCreate()

    data = [
        (123, None, None, 1634228709),
        (123, None, 80, 1634228724),
        (123, "update2", 90, 1634229000),
        (12, "update1", None, 1634221233),
        (12, None, 80, 1634228333),
        (12, "update2", None, 1634221220),
    ]
    columns = ["id", "col1", "col2", "updated_at"]
    df = spark.createDataFrame(data, columns)

    # Cast epoch seconds to a timestamp, then sort by id and updated_at (descending)
    df = df.withColumn("updated_at", F.col("updated_at").cast(TimestampType())).orderBy(
        F.col("id"), F.col("updated_at").desc()
    )

    # For every column except id, take the first non-null value per group
    exp = [F.first(x, ignorenulls=True).alias(x) for x in df.columns[1:]]
    df = df.groupBy(F.col("id")).agg(*exp)
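The tables shown earlier are the output of show; to print the final grouped result, you can append something like this at the end of the script (inside the main block):

    # Display the aggregated result; truncate=False keeps full column values visible
    df.show(truncate=False)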