How to Automatically and Elegantly Flatten DataFrames in Spark SQL?

In Apache Spark, flattening nested DataFrames is a common task, particularly when dealing with complex data structures like JSON. A flat schema requires two operations: expanding struct columns into top-level columns and exploding arrays of structs into one row per element. We can achieve this elegantly with a small recursive helper in either the PySpark or Scala API. Let's start with an example in PySpark.

Flattening DataFrames in PySpark

Consider a nested DataFrame that we want to flatten:


from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import col, explode
from pyspark.sql.types import ArrayType, LongType, StringType, StructField, StructType

# Initialize Spark session
spark = SparkSession.builder.appName("FlattenDataFrame").getOrCreate()

# Sample nested JSON data
data = [
    {
        "id": 1,
        "name": "John",
        "address": {"city": "New York", "street": "5th Avenue"},
        "phones": [{"type": "home", "number": "111-111-1111"}, {"type": "work", "number": "222-222-2222"}]
    },
    {
        "id": 2,
        "name": "Jane",
        "address": {"city": "San Francisco", "street": "Market Street"},
        "phones": [{"type": "home", "number": "333-333-3333"}]
    }
]

# Explicit schema so the nested objects become structs
# (plain dict inference would give map columns instead)
schema = StructType([
    StructField("id", LongType()),
    StructField("name", StringType()),
    StructField("address", StructType([
        StructField("city", StringType()),
        StructField("street", StringType()),
    ])),
    StructField("phones", ArrayType(StructType([
        StructField("number", StringType()),
        StructField("type", StringType()),
    ]))),
])

df = spark.createDataFrame(data, schema)

# Function to recursively flatten struct and array-of-struct columns
def flatten(df: DataFrame) -> DataFrame:
    # Explode arrays of structs first: each array element becomes its own row
    for col_name, dtype in df.dtypes:
        if dtype.startswith("array<struct"):
            return flatten(df.withColumn(col_name, explode(col(col_name))))

    struct_cols = [name for name, dtype in df.dtypes if dtype.startswith("struct")]
    if not struct_cols:
        return df  # nothing left to flatten

    # Expand each struct column into top-level columns prefixed with its name
    flat_cols = [name for name in df.columns if name not in struct_cols]
    expanded = [
        col(f"{n}.{c}").alias(f"{n}_{c}")
        for n in struct_cols
        for c in df.select(f"{n}.*").columns
    ]
    return flatten(df.select(flat_cols + expanded))

flattened_df = flatten(df)
flattened_df.show(truncate=False)

+---+----+-------------+--------------+-------------+-----------+
|id |name|address_city |address_street|phones_number|phones_type|
+---+----+-------------+--------------+-------------+-----------+
|1  |John|New York     |5th Avenue    |111-111-1111 |home       |
|1  |John|New York     |5th Avenue    |222-222-2222 |work       |
|2  |Jane|San Francisco|Market Street |333-333-3333 |home       |
+---+----+-------------+--------------+-------------+-----------+
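
A quick way to confirm that nothing nested survived is to inspect the schema: after flattening, every field should be a primitive type. With the sample data above, the output should look like this:

flattened_df.printSchema()
# root
#  |-- id: long (nullable = true)
#  |-- name: string (nullable = true)
#  |-- address_city: string (nullable = true)
#  |-- address_street: string (nullable = true)
#  |-- phones_number: string (nullable = true)
#  |-- phones_type: string (nullable = true)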

Explanation

In the example above:

  • We categorize the columns by data type: primitive columns stay as they are, arrays of structs are exploded with `explode` so each element becomes its own row, and struct columns are marked for expansion.
  • Struct sub-columns are selected with `col("parent.child")` and aliased to `parent_child`, so the original nesting remains visible in the column names.
  • The function calls itself until no nested columns remain, and the flattened DataFrame is displayed using the `.show()` method (an equivalent Spark SQL query is sketched below).
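
Since the question is about Spark SQL, the same result can also be written as a plain SQL query when the schema is known up front. The following is a minimal sketch, assuming the DataFrame is registered under the illustrative view name `people`; the aliases mirror the naming convention used by `flatten`:

# Hand-written SQL equivalent for this specific schema (view name "people" is illustrative)
df.createOrReplaceTempView("people")
spark.sql("""
    SELECT id, name,
           address.city   AS address_city,
           address.street AS address_street,
           phone.number   AS phones_number,
           phone.type     AS phones_type
    FROM people
    LATERAL VIEW explode(phones) p AS phone
""").show(truncate=False)

The SQL form is explicit rather than generic: it has to be rewritten whenever the schema changes, which is precisely what the recursive helper avoids.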

Flattening DataFrames in Scala

Now, let’s look at a similar example in Scala:


import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, explode}
import org.apache.spark.sql.types.{ArrayType, StructType}

// Initialize Spark session
val spark = SparkSession.builder.appName("FlattenDataFrame").getOrCreate()

// Sample nested data: case classes become struct columns,
// and a Seq of case classes becomes an array of structs
case class Address(city: String, street: String)
case class Phone(number: String, `type`: String)

val data = Seq(
  (1, "John", Address("New York", "5th Avenue"), Seq(Phone("111-111-1111", "home"), Phone("222-222-2222", "work"))),
  (2, "Jane", Address("San Francisco", "Market Street"), Seq(Phone("333-333-3333", "home")))
)

val df = spark.createDataFrame(data).toDF("id", "name", "address", "phones")

// Function to recursively flatten struct and array-of-struct columns
def flatten(df: DataFrame): DataFrame = {
  val structCols = df.schema.fields.collect {
    case f if f.dataType.isInstanceOf[StructType] => f.name
  }
  val arrayCols = df.schema.fields.collect {
    case f if f.dataType.isInstanceOf[ArrayType] &&
      f.dataType.asInstanceOf[ArrayType].elementType.isInstanceOf[StructType] => f.name
  }

  if (structCols.isEmpty && arrayCols.isEmpty) df
  else {
    // Explode arrays of structs so each element becomes its own row
    val exploded = arrayCols.foldLeft(df)((d, n) => d.withColumn(n, explode(col(n))))
    // Expand each struct column into top-level columns prefixed with its name
    val structs = exploded.schema.fields.collect {
      case f if f.dataType.isInstanceOf[StructType] => f.name
    }
    val flatCols = exploded.columns.filterNot(structs.contains).map(col)
    val expanded = structs.flatMap(n => exploded.select(s"$n.*").columns.map(c => col(s"$n.$c").alias(s"${n}_$c")))
    flatten(exploded.select(flatCols ++ expanded: _*))
  }
}

val flattenedDf = flatten(df)
flattenedDf.show(false)

+---+----+-------------+--------------+-------------+-----------+
|id |name|address_city |address_street|phones_number|phones_type|
+---+----+-------------+--------------+-------------+-----------+
|1  |John|New York     |5th Avenue    |111-111-1111 |home       |
|1  |John|New York     |5th Avenue    |222-222-2222 |work       |
|2  |Jane|San Francisco|Market Street |333-333-3333 |home       |
+---+----+-------------+--------------+-------------+-----------+

Explanation

Similar to the PySpark example, the Scala code flattens the nested DataFrame recursively:

  • Columns are classified by inspecting each schema field's `dataType`: `StructType` columns are expanded, and `ArrayType` columns whose elements are structs are exploded first.
  • Struct sub-columns are selected with `col(s"$n.$c")` and aliased to `${n}_$c`, preserving the parent name as a prefix.
  • The function recurses until the schema is flat, and the result is displayed using `.show(false)`.

In summary, flattening nested DataFrames in Spark can be elegantly achieved through recursive expansion of nested columns. Whether using PySpark, Scala, or another API, similar principles apply to unwrapping nested structures to create a flat schema.
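
To see the recursion at work, here is a quick sanity check with a struct nested inside another struct. The record below is made up for illustration and reuses the PySpark `flatten` helper defined earlier:

from pyspark.sql import Row

# Hypothetical record with two levels of struct nesting
nested = spark.createDataFrame([
    Row(id=1, meta=Row(geo=Row(lat=40.7, lon=-74.0), source="api"))
])
flatten(nested).show()
# Pass 1 produces meta_geo (still a struct) and meta_source;
# pass 2 yields the flat columns id, meta_source, meta_geo_lat, meta_geo_lon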
