How to Select the First Row of Each Group in Apache Spark?

To select the first row of each group in Apache Spark, define a `Window` that partitions the data by the grouping column and orders the rows within each group. Number the rows with `row_number()` over that window, then keep only the rows numbered 1. The column you order by determines which row counts as "first". Below is an example using PySpark:

Using PySpark

Here’s a PySpark example to select the first row of each group:


from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number

# Create Spark session
spark = SparkSession.builder.appName("SelectFirstRow").getOrCreate()

# Sample data
data = [
    ("A", 10),
    ("A", 20),
    ("B", 30),
    ("B", 40),
    ("C", 50)
]

# Create DataFrame
columns = ["group", "value"]
df = spark.createDataFrame(data, schema=columns)

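# Define window spec: one partition per group, ordered by value,
# so the row with the smallest value in each group gets row number 1
windowSpec = Window.partitionBy("group").orderBy("value")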

# Add row numbers
df_with_row_num = df.withColumn("row_number", row_number().over(windowSpec))

# Filter to get the first row in each group
df_first_row = df_with_row_num.filter(df_with_row_num["row_number"] == 1).drop("row_number")

# Show the result
df_first_row.show()

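The output should look like this. Spark does not guarantee the order in which the groups appear, so add `.orderBy("group")` before `show()` if you need a fixed order: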


+-----+-----+
|group|value|
+-----+-----+
|    A|   10|
|    B|   30|
|    C|   50|
+-----+-----+
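This pattern assumes the ordering column picks out a single first row. If two rows in a group tie on `value`, `row_number()` still gives them distinct numbers (1 and 2), so only one survives the filter. Spark does not guarantee which one, and adding a second column to `orderBy` makes the choice deterministic. Spark's `rank()` function instead gives tied rows the same number, so filtering on a rank of 1 keeps all of them.

The sketch below is a plain-Python model of the two numbering schemes on one already-sorted group, not Spark code. `row_numbers` and `ranks` are hypothetical helper names chosen for illustration.

```python
def row_numbers(values):
    """Model of row_number(): 1, 2, 3, ... even when values tie."""
    return list(range(1, len(values) + 1))


def ranks(values):
    """Model of rank(): tied values share a rank; the next rank skips ahead."""
    result = []
    for i, v in enumerate(values):
        if i > 0 and v == values[i - 1]:
            result.append(result[-1])  # same value as the previous row -> same rank
        else:
            result.append(i + 1)       # otherwise rank = 1-based position in the ordering
    return result


# One group's values, already sorted as orderBy("value") would sort them.
group_values = [10, 10, 20]
print(row_numbers(group_values))  # [1, 2, 3] -> filtering on 1 keeps one row
print(ranks(group_values))        # [1, 1, 3] -> filtering on 1 keeps both 10s
```

In PySpark, the tie-keeping variant would swap `row_number()` for `rank()` from `pyspark.sql.functions` in the `withColumn` call.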

Using Scala

Here’s a Scala example to select the first row of each group:


import org.apache.spark.sql.{SparkSession, functions => F}
import org.apache.spark.sql.expressions.Window

// Create Spark session
val spark = SparkSession.builder.appName("SelectFirstRow").getOrCreate()
import spark.implicits._

// Sample data
val data = Seq(
  ("A", 10),
  ("A", 20),
  ("B", 30),
  ("B", 40),
  ("C", 50)
)

// Create DataFrame
val df = data.toDF("group", "value")

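// Define window spec: one partition per group, ordered by value,
// so the row with the smallest value in each group gets row number 1
val windowSpec = Window.partitionBy("group").orderBy("value")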

// Add row numbers
val dfWithRowNum = df.withColumn("row_number", F.row_number().over(windowSpec))

// Filter to get the first row in each group
val dfFirstRow = dfWithRowNum.filter($"row_number" === 1).drop("row_number")

// Show the result
dfFirstRow.show()

The output should look like this. As in PySpark, the order of the groups is not guaranteed unless you sort the result:


+-----+-----+
|group|value|
+-----+-----+
|    A|   10|
|    B|   30|
|    C|   50|
+-----+-----+

Summary

Both the PySpark and Scala examples follow the same three steps:

  • Creating a window specification to partition the data by the group column and order by the value column
  • Adding a row number to each row within its group
  • Filtering the DataFrame to keep only rows where the row number is 1
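The three steps can be modeled in plain Python to show what the window computes, without a Spark cluster. This is a sketch, not how Spark executes the query. The function name `first_row_per_group` and its index parameters are hypothetical. Sorting stands in for `partitionBy` plus `orderBy`, and `enumerate` stands in for `row_number()`.

```python
from itertools import groupby
from operator import itemgetter


def first_row_per_group(rows, key_index=0, order_index=1):
    """Plain-Python model of partition -> order -> number -> filter.

    rows: list of tuples such as ("A", 10). key_index is the grouping column,
    order_index is the column that decides which row is "first".
    """
    result = []
    # Steps 1 and 2 setup: sort by (group, value) so each group is contiguous
    # (like partitionBy) and ordered within itself (like orderBy).
    ordered = sorted(rows, key=itemgetter(key_index, order_index))
    for _, group_rows in groupby(ordered, key=itemgetter(key_index)):
        # Step 2: number the rows 1, 2, 3, ... within the group (like row_number()).
        for row_num, row in enumerate(group_rows, start=1):
            # Step 3: keep only the row numbered 1.
            if row_num == 1:
                result.append(row)
    return result


data = [("A", 10), ("A", 20), ("B", 30), ("B", 40), ("C", 50)]
print(first_row_per_group(data))  # [('A', 10), ('B', 30), ('C', 50)]
```

Unlike Spark, this model returns groups in sorted order, and Python's stable sort keeps ties in their input order. Spark guarantees neither.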

About Editorial Team

Our Editorial Team is made up of tech enthusiasts who are highly skilled in Apache Spark, PySpark, and Machine Learning. They are also proficient in Python, Pandas, R, Hive, PostgreSQL, Snowflake, and Databricks. They aren't just experts; they are passionate teachers. They are dedicated to making complex data concepts easy to understand through engaging and simple tutorials with examples.
