How to Retrieve Top N Records in Each Group Using PySpark DataFrame?

Retrieving the top N records in each group of a PySpark DataFrame is a common requirement in data processing and analysis. We can achieve this using the `Window` class in PySpark, combined with `partitionBy` to define the groups and `orderBy` to sort the records within each group. Here is a detailed explanation along with a code example and its output.

Steps to Retrieve Top N Records in Each Group

1. Import Required Libraries

First, we need to import the necessary libraries and create a Spark session.


from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number

# Create a Spark session
spark = SparkSession.builder.appName("TopNRecords").getOrCreate()

2. Create Sample DataFrame

Next, let’s create a sample DataFrame to work with:


data = [
    ("A", 1, 10),
    ("A", 2, 20),
    ("A", 3, 30),
    ("B", 1, 5),
    ("B", 2, 15),
    ("B", 3, 25)
]

columns = ["Category", "ID", "Value"]

df = spark.createDataFrame(data, columns)
df.show()

Output:


+--------+---+-----+
|Category| ID|Value|
+--------+---+-----+
|       A|  1|   10|
|       A|  2|   20|
|       A|  3|   30|
|       B|  1|    5|
|       B|  2|   15|
|       B|  3|   25|
+--------+---+-----+

3. Define Window Specification

Define a window specification using the `Window` class, partitioning by the group column and ordering the records within each group:


# Partition by Category; order by Value descending within each partition
window_spec = Window.partitionBy("Category").orderBy(col("Value").desc())
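
Note that `row_number` assigns a strictly increasing number even when two records tie on `Value`, so ties are broken arbitrarily. If tied records should share a position, PySpark's `rank` or `dense_rank` can be used with the same window specification; here is a minimal sketch of that variation (the output column name `ranking` is just illustrative):


from pyspark.sql.functions import rank, dense_rank

# rank() gives tied records the same number and leaves gaps after ties;
# dense_rank() gives tied records the same number without gaps.
df.withColumn("ranking", rank().over(window_spec)).show()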

4. Assign Row Numbers

Add a row number to each record within its group:


# row_number() assigns 1, 2, 3, ... within each partition, following the ordering
df_with_row_num = df.withColumn("row_number", row_number().over(window_spec))
df_with_row_num.show()

Output:


+--------+---+-----+----------+
|Category| ID|Value|row_number|
+--------+---+-----+----------+
|       A|  3|   30|         1|
|       A|  2|   20|         2|
|       A|  1|   10|         3|
|       B|  3|   25|         1|
|       B|  2|   15|         2|
|       B|  1|    5|         3|
+--------+---+-----+----------+

5. Filter Top N Records in Each Group

Finally, keep only the top N records in each group and drop the helper column:


N = 2  # number of top records to keep per group
top_n_df = df_with_row_num.filter(col("row_number") <= N).drop("row_number")
top_n_df.show()

Output:


+--------+---+-----+
|Category| ID|Value|
+--------+---+-----+
|       A|  3|   30|
|       A|  2|   20|
|       B|  3|   25|
|       B|  2|   15|
+--------+---+-----+

In this example, we have successfully retrieved the top 2 records by the `Value` column in each `Category` group.
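
For reference, the same result can be produced in a single chained expression. This is just a compact sketch combining steps 3 through 5 (the temporary column name `rn` is arbitrary):


top_n_df = (
    df.withColumn("rn", row_number().over(
        Window.partitionBy("Category").orderBy(col("Value").desc())
    ))
    .filter(col("rn") <= 2)   # keep the top 2 records per group
    .drop("rn")
)
top_n_df.show()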

Conclusion

By following these steps, you can retrieve the top N records in each group of a PySpark DataFrame. This approach uses a window specification and `row_number` to assign a ranking within each group, making it easy to keep only the desired number of top records.
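
If you use this pattern often, it can be wrapped in a small helper function. The sketch below is one possible way to do so; the name `top_n_per_group` and the temporary column `_rn` are illustrative choices, not part of the PySpark API.


from pyspark.sql import DataFrame
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number

def top_n_per_group(df: DataFrame, group_col: str, order_col: str, n: int) -> DataFrame:
    # Rank records within each group by order_col (descending) and keep the top n.
    w = Window.partitionBy(group_col).orderBy(col(order_col).desc())
    return (
        df.withColumn("_rn", row_number().over(w))
        .filter(col("_rn") <= n)
        .drop("_rn")
    )

# Usage with the DataFrame from this article:
# top_n_per_group(df, "Category", "Value", 2).show()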
