How to Find the Maximum Row per Group in a Spark DataFrame?

Finding the row with the maximum value per group in a Spark DataFrame is a common task in data analysis. Here’s how you can do it in both PySpark and Scala using a window function.

PySpark

Let’s start with an example in PySpark. Suppose you have the following DataFrame:


from pyspark.sql import SparkSession
from pyspark.sql.functions import col, max as max_

# Initialize Spark session
spark = SparkSession.builder.appName("example-max-row-group").getOrCreate()

# Sample data
data = [
    ("A", 1),
    ("A", 2),
    ("A", 3),
    ("B", 10),
    ("B", 20),
    ("B", 30)
]

# Create DataFrame
columns = ["group", "value"]
df = spark.createDataFrame(data, columns)

# Show DataFrame
df.show()

The output will be:


+-----+-----+
|group|value|
+-----+-----+
|    A|    1|
|    A|    2|
|    A|    3|
|    B|   10|
|    B|   20|
|    B|   30|
+-----+-----+

Next, we will use a window function with row_number to keep only the row with the maximum value in each group:


from pyspark.sql.window import Window
from pyspark.sql.functions import row_number

# Define the window specification
windowSpec = Window.partitionBy("group").orderBy(col("value").desc())

# Add row number to each row within the window partition
df_with_row_num = df.withColumn("row_num", row_number().over(windowSpec))

# Filter rows where row number is 1
result_df = df_with_row_num.filter(col("row_num") == 1).drop("row_num")

# Show the result
result_df.show()

The output will be:


+-----+-----+
|group|value|
+-----+-----+
|    A|    3|
|    B|   30|
+-----+-----+
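An alternative approach is to aggregate the maximum value per group and join it back onto the original DataFrame. This is a minimal sketch that reuses the max_ alias imported at the top (variable names here are illustrative); note that, unlike row_number, it keeps every row that ties for the maximum:


# Aggregate the maximum value per group
max_df = df.groupBy("group").agg(max_("value").alias("value"))

# Join back on both columns to recover the matching rows
alt_result_df = df.join(max_df, on=["group", "value"], how="inner")

alt_result_df.show()

For this sample data it produces the same two rows as the window approach.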

Scala

Here is a similar example in Scala:


import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().appName("example-max-row-group").getOrCreate()

// Sample data
val data = Seq(
    ("A", 1),
    ("A", 2),
    ("A", 3),
    ("B", 10),
    ("B", 20),
    ("B", 30)
)

val columns = Seq("group", "value")

// Create DataFrame
import spark.implicits._
val df = data.toDF(columns: _*)

// Show DataFrame
df.show()

The output will be:


+-----+-----+
|group|value|
+-----+-----+
|    A|    1|
|    A|    2|
|    A|    3|
|    B|   10|
|    B|   20|
|    B|   30|
+-----+-----+

Next, we will again use a window function with row_number to keep only the row with the maximum value in each group:


// Define the window specification
val windowSpec = Window.partitionBy("group").orderBy(col("value").desc)

// Add row number to each row within the window partition
val df_with_row_num = df.withColumn("row_num", row_number().over(windowSpec))

// Filter rows where row number is 1
val result_df = df_with_row_num.filter(col("row_num") === 1).drop("row_num")

// Show the result
result_df.show()

The output will be:


+-----+-----+
|group|value|
+-----+-----+
|    A|    3|
|    B|   30|
+-----+-----+

Both the PySpark and Scala examples follow similar steps:

  1. Define a window specification that partitions the data by group and orders it by the value column in descending order.
  2. Add a row number to each row within the window partition.
  3. Filter the rows to keep only those where the row number is 1, which corresponds to the maximum value per group.

These methods allow you to efficiently find the row with the maximum value per group in a Spark DataFrame.
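One caveat: row_number always assigns a unique number, so if two rows tie for the maximum within a group, only one of them (chosen arbitrarily among the tied rows) is kept. If you want to keep all tied rows instead, a minimal PySpark sketch using rank in place of row_number would look like this:


from pyspark.sql.window import Window
from pyspark.sql.functions import col, rank

# rank gives tied values the same rank, so every row sharing the maximum is kept
windowSpec = Window.partitionBy("group").orderBy(col("value").desc())
ties_df = (
    df.withColumn("rnk", rank().over(windowSpec))
      .filter(col("rnk") == 1)
      .drop("rnk")
)

ties_df.show()

The sample data above has no ties, so this returns the same result as the row_number version.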
