Retrieving the top N records in each group of a PySpark DataFrame is a common requirement in data processing and analysis. We can achieve this with a `Window` specification, using `partitionBy` to define the groups and `orderBy` to sort the records within each group. Here, I will provide a detailed explanation along with a code example and its output.
Steps to Retrieve Top N Records in Each Group
1. Import Required Libraries
First, we need to import the necessary libraries and create a Spark session.
```python
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number

# Create a Spark session
spark = SparkSession.builder.appName("TopNRecords").getOrCreate()
```
2. Create Sample DataFrame
Next, let’s create a sample DataFrame to work with:
```python
data = [
    ("A", 1, 10),
    ("A", 2, 20),
    ("A", 3, 30),
    ("B", 1, 5),
    ("B", 2, 15),
    ("B", 3, 25)
]
columns = ["Category", "ID", "Value"]
df = spark.createDataFrame(data, columns)
df.show()
```
Output:

```
+--------+---+-----+
|Category| ID|Value|
+--------+---+-----+
|       A|  1|   10|
|       A|  2|   20|
|       A|  3|   30|
|       B|  1|    5|
|       B|  2|   15|
|       B|  3|   25|
+--------+---+-----+
```
3. Define Window Specification
Use a `Window` specification, partitioning by the group column and ordering within each group:

```python
window_spec = Window.partitionBy("Category").orderBy(col("Value").desc())
```
4. Assign Row Numbers
Add a row number to each record within its group:
```python
df_with_row_num = df.withColumn("row_number", row_number().over(window_spec))
df_with_row_num.show()
```
Output:

```
+--------+---+-----+----------+
|Category| ID|Value|row_number|
+--------+---+-----+----------+
|       A|  3|   30|         1|
|       A|  2|   20|         2|
|       A|  1|   10|         3|
|       B|  3|   25|         1|
|       B|  2|   15|         2|
|       B|  1|    5|         3|
+--------+---+-----+----------+
```
5. Filter Top N Records in Each Group
Finally, keep only the top N records in each group:

```python
N = 2
top_n_df = df_with_row_num.filter(col("row_number") <= N).drop("row_number")
top_n_df.show()
```
Output:

```
+--------+---+-----+
|Category| ID|Value|
+--------+---+-----+
|       A|  3|   30|
|       A|  2|   20|
|       B|  3|   25|
|       B|  2|   15|
+--------+---+-----+
```
In this example, we have retrieved the top 2 records by `Value` within each `Category` group.
Conclusion
By following these steps, you can retrieve the top N records in each group of a PySpark DataFrame. This approach uses a `Window` specification with `row_number` to assign rankings within each group, making it easy to keep only the desired number of top records.