How to GroupBy and Filter on Count in a Scala DataFrame?

When working with Apache Spark and Scala, it’s quite common to perform group-by operations followed by filtering based on the aggregated counts. Below is a detailed answer with a code example.

GroupBy and Filter on Count in a Scala DataFrame

Let’s consider an example where we have a DataFrame of employee data with columns such as employee_id, department, and salary. We want to group the data by the department and then filter out the departments with fewer than a specified number of employees.

Step-by-Step Explanation

1. Import Required Libraries

First, you need to import the required libraries.


import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

2. Create SparkSession

Create a SparkSession, which is the entry point of our Spark application.


val spark = SparkSession.builder()
  .appName("GroupBy and Filter")
  // Add .master("local[*]") here if you run this locally rather than via spark-submit
  .getOrCreate()

// Required for the $"..." column syntax used below
import spark.implicits._

3. Sample DataFrame

Create a sample DataFrame for demonstration.


val data = Seq(
  (1, "HR", 10000),
  (2, "IT", 20000),
  (3, "HR", 15000),
  (4, "IT", 22000),
  (5, "Finance", 30000),
  (6, "Finance", 25000),
  (7, "HR", 13000),
  (8, "IT", 21500),
  (9, "Finance", 28000),
  (10, "HR", 12000)
)

val df = spark.createDataFrame(data).toDF("employee_id", "department", "salary")
df.show()

+-----------+----------+------+
|employee_id|department|salary|
+-----------+----------+------+
|          1|        HR| 10000|
|          2|        IT| 20000|
|          3|        HR| 15000|
|          4|        IT| 22000|
|          5|   Finance| 30000|
|          6|   Finance| 25000|
|          7|        HR| 13000|
|          8|        IT| 21500|
|          9|   Finance| 28000|
|         10|        HR| 12000|
+-----------+----------+------+
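
Since spark.implicits._ is in scope, the same DataFrame can also be built directly from the Seq with toDF; a minimal equivalent (df2 is just an illustrative name):


val df2 = data.toDF("employee_id", "department", "salary")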

4. GroupBy and Count

Group the DataFrame by the department column and then count the number of employees in each department.


val groupedDF = df.groupBy("department").count()
groupedDF.show()

+----------+-----+
|department|count|
+----------+-----+
|        HR|    4|
|        IT|    3|
|   Finance|    3|
+----------+-----+
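
A column literally named count can be awkward to work with later, since it shadows the count function in some expressions. If you prefer to name the aggregate yourself, agg gives you control over the column name; a minimal sketch (employee_count is an illustrative name):


val countedDF = df.groupBy("department")
  .agg(count("*").as("employee_count")) // same counts, custom column name
countedDF.show()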

5. Filter Based on Count

Now keep only the departments that have at least a threshold number of employees; here the threshold is 4.


val filteredDF = groupedDF.filter($"count" >= 4)
filteredDF.show()

+----------+-----+
|department|count|
+----------+-----+
|        HR|    4|
+----------+-----+
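
In practice, the group-by and the filter are often collapsed into a single chained expression; an equivalent sketch:


val filteredDF2 = df
  .groupBy("department")
  .count()
  .filter($"count" >= 4) // keep departments with at least 4 employees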

6. Use with Original DataFrame

To get the detailed employee records for the departments that survived the filter, join the filtered result back with the original DataFrame.


val resultDF = df.join(filteredDF, "department").drop("count") // drop the helper count column
resultDF.show()

+----------+-----------+------+
|department|employee_id|salary|
+----------+-----------+------+
|        HR|          1| 10000|
|        HR|          3| 15000|
|        HR|          7| 13000|
|        HR|         10| 12000|
+----------+-----------+------+

In this example, only the “HR” department remains after filtering out departments with fewer than 4 employees. The final join operation gives you the detailed employee records for the remaining departments.
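
If you only need the original employee columns, a left semi join is a tidy alternative: it keeps the matching rows of df without bringing over any columns from filteredDF, so there is no count column to drop. A sketch using the same DataFrames:


val semiDF = df.join(filteredDF, Seq("department"), "left_semi")
semiDF.show() // employee_id, department, salary for the HR rows only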

Summary

By following these steps, you can efficiently group by a column and filter based on aggregated counts in a Scala DataFrame using Apache Spark. This approach is helpful in a wide range of data processing tasks where such operations are frequently needed.
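
As an alternative that avoids the join entirely, a window function can attach the per-department count to every row and filter in a single pass. A minimal sketch, assuming the df defined above (dept_count is an illustrative column name):


import org.apache.spark.sql.expressions.Window

val byDept = Window.partitionBy("department")

val resultViaWindow = df
  .withColumn("dept_count", count("*").over(byDept)) // per-row department size
  .filter($"dept_count" >= 4)
  .drop("dept_count")

This keeps all of the original columns and sidesteps the explicit join, which can be convenient when the counts are only needed for filtering.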
