Pivoting is a data-transformation operation that reshapes a DataFrame by turning the distinct values of one column into multiple new columns, applying an aggregation function to fill them. In Apache Spark, pivoting can be performed efficiently with the DataFrame API. Below, we walk through pivoting step by step, with examples in both PySpark and Scala.
Understanding Pivot
Pivoting a DataFrame typically involves three steps:
- Group by one or more columns.
- Pivot on a column whose distinct values become the new columns.
- Apply an aggregation function.
Example Dataset
Let’s use an example dataset to demonstrate the pivot functionality. Consider the following DataFrame showing sales data:
+------+-------+---------+-----+
| Year | Month | Item    | Qty |
+------+-------+---------+-----+
| 2021 | Jan   | Pen     |  10 |
| 2021 | Jan   | Pencils |  20 |
| 2021 | Feb   | Pen     |  15 |
| 2021 | Feb   | Pencils |  25 |
| 2022 | Jan   | Pen     |   5 |
| 2022 | Jan   | Pencils |  30 |
| 2022 | Feb   | Pen     |  20 |
| 2022 | Feb   | Pencils |  35 |
+------+-------+---------+-----+
Pivoting using PySpark
Below is the code snippet to pivot this dataset using PySpark:
from pyspark.sql import SparkSession
# Initialize SparkSession
spark = SparkSession.builder.appName("example-pivot").getOrCreate()
# Sample data
data = [
(2021, "Jan", "Pen", 10),
(2021, "Jan", "Pencils", 20),
(2021, "Feb", "Pen", 15),
(2021, "Feb", "Pencils", 25),
(2022, "Jan", "Pen", 5),
(2022, "Jan", "Pencils", 30),
(2022, "Feb", "Pen", 20),
(2022, "Feb", "Pencils", 35)
]
# Create DataFrame
columns = ["Year", "Month", "Item", "Qty"]
df = spark.createDataFrame(data, columns)
# Pivot DataFrame
pivot_df = df.groupBy("Year", "Month").pivot("Item").sum("Qty")
# Show results
pivot_df.show()
+----+-----+---+-------+
|Year|Month|Pen|Pencils|
+----+-----+---+-------+
|2021|  Jan| 10|     20|
|2021|  Feb| 15|     25|
|2022|  Jan|  5|     30|
|2022|  Feb| 20|     35|
+----+-----+---+-------+
Pivoting using Scala
Below is the code snippet to pivot this dataset using Scala in Spark:
import org.apache.spark.sql.SparkSession
// Initialize SparkSession
val spark = SparkSession.builder.appName("example-pivot").getOrCreate()
// Sample data
val data = Seq(
(2021, "Jan", "Pen", 10),
(2021, "Jan", "Pencils", 20),
(2021, "Feb", "Pen", 15),
(2021, "Feb", "Pencils", 25),
(2022, "Jan", "Pen", 5),
(2022, "Jan", "Pencils", 30),
(2022, "Feb", "Pen", 20),
(2022, "Feb", "Pencils", 35)
)
// Create DataFrame
import spark.implicits._
val df = data.toDF("Year", "Month", "Item", "Qty")
// Pivot DataFrame
val pivotDF = df.groupBy("Year", "Month").pivot("Item").sum("Qty")
// Show results
pivotDF.show()
+----+-----+---+-------+
|Year|Month|Pen|Pencils|
+----+-----+---+-------+
|2021|  Jan| 10|     20|
|2021|  Feb| 15|     25|
|2022|  Jan|  5|     30|
|2022|  Feb| 20|     35|
+----+-----+---+-------+
Conclusion
Pivoting is a powerful tool in data transformation and analysis. Using the DataFrame API, Spark makes it easy and efficient to pivot large datasets. Whether using PySpark or Scala, the process involves grouping, pivoting, and aggregating, leading to a reorganized view of the original dataset. This guide should help get you started with pivoting in Spark.