Grouping Data by Multiple Columns in PySpark

Grouping data is a common operation in data analysis: it collects rows that share the same values into groups and summarizes each group. In PySpark, Apache Spark’s Python API, you can group by several columns at once and compute multiple aggregations per group. This article walks through grouping by multiple columns in PySpark, with worked examples and their output.

Understanding GroupBy in PySpark

Before grouping by multiple columns, it helps to understand how groupBy works in PySpark. The DataFrame.groupBy() method groups rows by the values of one or more columns and returns a GroupedData object. You then apply aggregation functions such as count, sum, or avg to that object, which produces a new DataFrame with one row per group.

Creating a PySpark Environment and DataFrame

To start using PySpark for grouping, you first need to set up a Spark session. Here’s how you can create a Spark session:


from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("Grouping Data Example") \
    .getOrCreate()

With the Spark session set up, let’s create a simple DataFrame that we’ll use to demonstrate grouping by multiple columns:


from pyspark.sql import Row

data = [
    Row(department="Finance", employee="John", salary=1000),
    Row(department="Marketing", employee="Jill", salary=1200),
    Row(department="Finance", employee="Bill", salary=1200),
    Row(department="Sales", employee="Steve", salary=1500),
    Row(department="Marketing", employee="Eva", salary=1300),
    Row(department="Sales", employee="Sophie", salary=1600),
    Row(department="Sales", employee="Brian", salary=1300),
    Row(department="Finance", employee="Dylan", salary=1100),
    Row(department="Marketing", employee="Alice", salary=1600)
]

df = spark.createDataFrame(data)
df.show()

The output of the code snippet will be:


+----------+--------+------+
|department|employee|salary|
+----------+--------+------+
|   Finance|    John|  1000|
| Marketing|    Jill|  1200|
|   Finance|    Bill|  1200|
|     Sales|   Steve|  1500|
| Marketing|     Eva|  1300|
|     Sales|  Sophie|  1600|
|     Sales|   Brian|  1300|
|   Finance|   Dylan|  1100|
| Marketing|   Alice|  1600|
+----------+--------+------+

Grouping by Multiple Columns

To group by more than one column, pass the column names to groupBy, either as a list or as separate arguments. After groupBy, you apply one or more aggregate functions. Let’s see how this works in practice.

Example: Grouping by Two Columns

Suppose you want to group the DataFrame above by department and employee to get the total salary for each employee within each department. Here’s how:


# Note: importing sum this way shadows Python's built-in sum()
from pyspark.sql.functions import sum

grouped_df = df.groupBy(["department", "employee"]).agg(sum("salary").alias("total_salary"))
grouped_df.show()

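After running the code above, the following output is displayed, showing the total salary for each employee in each department. Because each employee appears only once in this data, each total equals that employee’s salary. Row order in groupBy results is not guaranteed, so your rows may appear in a different order: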


+----------+--------+------------+
|department|employee|total_salary|
+----------+--------+------------+
|   Finance|    John|        1000|
| Marketing|    Jill|        1200|
|   Finance|    Bill|        1200|
|     Sales|   Steve|        1500|
| Marketing|     Eva|        1300|
|     Sales|  Sophie|        1600|
|     Sales|   Brian|        1300|
|   Finance|   Dylan|        1100|
| Marketing|   Alice|        1600|
+----------+--------+------------+

More Complex Aggregations

You can also compute other aggregations after grouping, such as the average, minimum, and maximum, on different columns. The following example groups by department alone and calculates the average, maximum, and minimum salary in each department. The same agg call works unchanged after a multi-column groupBy:


# Note: importing max and min this way shadows Python's built-in max() and min()
from pyspark.sql.functions import avg, max, min

complex_agg_df = df.groupBy("department") \
    .agg(
        avg("salary").alias("average_salary"),
        max("salary").alias("max_salary"),
        min("salary").alias("min_salary")
    )
complex_agg_df.show()

The resulting DataFrame has the average, maximum, and minimum salary for each department (row order may vary):


+----------+------------------+----------+----------+
|department|    average_salary|max_salary|min_salary|
+----------+------------------+----------+----------+
|   Finance|            1100.0|      1200|      1000|
| Marketing|1366.6666666666667|      1600|      1200|
|     Sales|1466.6666666666667|      1600|      1300|
+----------+------------------+----------+----------+

Sorting Grouped Data

After grouping and aggregating the data, you might need to sort the results. You can use the orderBy method to sort the data based on the aggregate function result.

Sorting Example with Aggregated Data

Let’s sort the previous complex aggregated DataFrame based on the average salary in descending order:


sorted_df = complex_agg_df.orderBy("average_salary", ascending=False)
sorted_df.show()

The output shows the departments sorted by average salary, from highest to lowest:


+----------+------------------+----------+----------+
|department|    average_salary|max_salary|min_salary|
+----------+------------------+----------+----------+
|     Sales|1466.6666666666667|      1600|      1300|
| Marketing|1366.6666666666667|      1600|      1200|
|   Finance|            1100.0|      1200|      1000|
+----------+------------------+----------+----------+

Grouping with Filter Conditions

Sometimes, before grouping the data, you may want to filter the DataFrame based on certain conditions. PySpark allows you to filter rows using the filter or where methods.

Filtering Before Grouping

Let’s say we only want to include employees whose salary is greater than 1200. Here is how to apply that filter before grouping:


filtered_df = df.filter(df.salary > 1200)
filtered_grouped_df = filtered_df.groupBy("department").agg(sum("salary").alias("total_salary"))
filtered_grouped_df.show()

The output shows total salaries by department, counting only employees with salaries greater than 1200. No Finance employee earns more than 1200, so Finance does not appear:


+----------+------------+
|department|total_salary|
+----------+------------+
| Marketing|        2900|
|     Sales|        4400|
+----------+------------+

In summary, grouping by multiple columns in PySpark lets you summarize data at a finer level of detail, such as per department and employee. Combined with multiple aggregation functions, sorting, and filters, it produces concise summaries that support data-driven decisions.

About Editorial Team

Our Editorial Team is made up of tech enthusiasts who are highly skilled in Apache Spark, PySpark, and Machine Learning. They are also proficient in Python, Pandas, R, Hive, PostgreSQL, Snowflake, and Databricks. They aren't just experts; they are passionate teachers. They are dedicated to making complex data concepts easy to understand through engaging and simple tutorials with examples.
