Apache Spark: Fixing Data Skew with the Salting Technique


When working with large datasets in Apache Spark, a common performance issue is data skew. This occurs when a few keys dominate the data distribution, leading to uneven partitions and slow queries. It mainly happens during operations that require shuffling, like joins or even regular aggregations.

A practical way to reduce skew is salting, which involves artificially spreading out heavy keys across multiple partitions. In this post, I’ll guide you through this with a practical example.

How Salting Resolves Data Skew Issues

By adding a randomly generated number to the join key and then joining over this combined key, we can distribute large keys more evenly. This makes the data distribution more uniform and spreads the load across more workers, instead of sending most of the data to one worker and leaving the others idle.

Benefits of Salting

  • Reduced Skew: Spreads data more evenly across partitions, preventing a few workers from being overloaded and improving cluster utilization.

  • Improved Performance: Speeds up joins and aggregations by balancing the workload.

  • Avoids Resource Contention: Reduces the risk of out-of-memory errors caused by large, uneven partitions.

When to Use Salting

During joins or aggregations with skewed keys, use salting when you notice long shuffle times or executor failures caused by data skew. It's also helpful in real-time streaming applications where partitioning affects processing efficiency, or when most workers sit idle while a few remain stuck running.
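A quick way to confirm the skew is to count rows per key and look for a handful of keys that dominate. This is only a sketch; df and customer_id stand in for your own DataFrame and join column:

import org.apache.spark.sql.functions.desc

// Count rows per join key; keys with far larger counts are skew candidates
df.groupBy("customer_id").
    count().
    orderBy(desc("count")).
    show(10)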

Salting Example in Scala

Input

Let's generate some data with an unbalanced number of rows. We can assume there are two datasets we need to join: one is a large dataset, and the other is a small dataset.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Simulated large dataset with skew
val largeDF = Seq(
  (1, "txn1"), (1, "txn2"), (1, "txn3"), (2, "txn4"), (3, "txn5")
).toDF("customer_id", "transaction")

// Small dataset
val smallDF = Seq(
  (1, "Ahmed"), (2, "Ali"), (3, "Hassan")
).toDF("customer_id", "name")

Let's add the salt column to the large dataset, using randomization to spread the rows of each heavy key across multiple smaller partitions:


// Step 1: create a salting key in the large dataset
val numBuckets = 3
val saltedLargeDF = largeDF.
    withColumn("salt", (rand() * numBuckets).cast("int")).
    withColumn("salted_customer_id", concat($"customer_id", lit("_"), $"salt"))

saltedLargeDF.show()
+-----------+-----------+----+------------------+
|customer_id|transaction|salt|salted_customer_id|
+-----------+-----------+----+------------------+
|          1|       txn1|   1|               1_1|
|          1|       txn2|   1|               1_1|
|          1|       txn3|   2|               1_2|
|          2|       txn4|   2|               2_2|
|          3|       txn5|   0|               3_0|
+-----------+-----------+----+------------------+

To make sure we cover all possible salted keys from the large dataset, we need to explode the small dataset with every possible salt value:


// Step 2: Explode rows in smallDF for possible salted keys
val saltedSmallDF = (0 until numBuckets).toDF("salt").
    crossJoin(smallDF).
    withColumn("salted_customer_id", concat($"customer_id", lit("_"), $"salt")) 

saltedSmallDF.show()
+----+-----------+------+------------------+
|salt|customer_id|  name|salted_customer_id|
+----+-----------+------+------------------+
|   0|          1| Ahmed|               1_0|
|   1|          1| Ahmed|               1_1|
|   2|          1| Ahmed|               1_2|
|   0|          2|   Ali|               2_0|
|   1|          2|   Ali|               2_1|
|   2|          2|   Ali|               2_2|
|   0|          3|Hassan|               3_0|
|   1|          3|Hassan|               3_1|
|   2|          3|Hassan|               3_2|
+----+-----------+------+------------------+

Now we can join the two datasets on the salted key:

// Step 3: Perform salted join
val joinedDF = saltedLargeDF.
    join(saltedSmallDF, Seq("salted_customer_id", "customer_id"), "inner").
    select("customer_id", "transaction", "name")

joinedDF.show()
+-----------+-----------+------+
|customer_id|transaction|  name|
+-----------+-----------+------+
|          1|       txn2| Ahmed|
|          1|       txn1| Ahmed|
|          1|       txn3| Ahmed|
|          2|       txn4|   Ali|
|          3|       txn5|Hassan|
+-----------+-----------+------+
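The same idea works for skewed aggregations: aggregate on the salted key first, then roll the partial results up per original key. The following is a minimal sketch, reusing saltedLargeDF from Step 1 to count transactions per customer:

// Stage 1: partial counts per salted key, so the shuffle is spread across buckets
val partialCounts = saltedLargeDF.
    groupBy("customer_id", "salted_customer_id").
    agg(count("transaction").as("partial_count"))

// Stage 2: cheap final roll-up per original customer key
val transactionCounts = partialCounts.
    groupBy("customer_id").
    agg(sum("partial_count").as("transaction_count"))

transactionCounts.show()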

Tuning Tip: Choosing numBuckets

  • If you set numBuckets = 100, each key can be divided into 100 sub-partitions. However, be cautious because using too many buckets can decrease performance, especially for keys with little data. Always test different values based on the skew profile of your dataset.

  • If you can identify the skewed keys, you can apply the salting to those keys only and set the salt for all other keys to a literal 0, e.g.:

// Step 1: create a salting key in the large dataset, salting only the skewed key (customer_id = 1)
val numBuckets = 3
val saltedLargeDF = largeDF.
    withColumn("salt", when($"customer_id" === 1, (rand() * numBuckets).cast("int")).otherwise(lit(0))).
    withColumn("salted_customer_id", concat($"customer_id", lit("_"), $"salt"))

// Step 2: Explode rows in smallDF for possible salted keys, only for the skewed key
val saltedSmallDF = (0 until numBuckets).toDF("salt").
    crossJoin(smallDF.filter($"customer_id" === 1)).
    select("customer_id", "salt", "name").
    union(smallDF.filter($"customer_id" =!= 1).withColumn("salt", lit(0)).select("customer_id", "salt", "name")).
    withColumn("salted_customer_id", concat($"customer_id", lit("_"), $"salt"))

Rule of Thumb:
Start small (e.g., 10-20) and increase gradually based on observed shuffle sizes and task runtime.
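To validate the value you picked, check how many rows each salted key ends up with; if a single salted key still holds most of the rows, increase numBuckets or salt more keys. Using saltedLargeDF from the example above:

// Count rows per salted key to verify the heavy key was actually spread out
saltedLargeDF.
    groupBy("salted_customer_id").
    count().
    orderBy(desc("count")).
    show()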


Final Thoughts

Salting is a simple and effective method to manage skew in Apache Spark when repartitioning or Spark's built-in skew join handling is insufficient. With the right tuning and monitoring, this technique can significantly decrease job execution times on highly skewed datasets.
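As a side note, if you are on Spark 3.x, adaptive query execution can also split skewed shuffle partitions automatically at runtime, so it is worth enabling before (or alongside) manual salting. This is just the basic switch; whether it covers your case depends on your Spark version and its skew thresholds:

// Spark 3.x: enable adaptive query execution and its skew join handling
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")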
