Spark / SPARK-45282

Join loses records for cached datasets



    • Type: Bug
    • Status: Resolved
    • Priority: Blocker
    • Resolution: Fixed
    • Affects Version/s: 3.4.1, 3.5.0
    • Fix Version/s: 3.4.2
    • Component/s: SQL
    • Environment: spark 3.4.1 on apache hadoop 3.3.6 or kubernetes 1.26 or databricks 13.3


      we observed this issue on spark 3.4.1 but it is also present on 3.5.0. it is not present on spark 3.3.1.

      it only shows up in a distributed environment. i cannot replicate it in a unit test. however, i did get it to show up on a hadoop cluster, on kubernetes, and on databricks 13.3.

      the issue is that records are dropped when two cached dataframes are joined. it seems that in spark 3.4.1 some Exchanges are dropped from the query plan as an optimization, while in spark 3.3.1 these Exchanges are still present. it seems to be an issue with AQE when canChangeCachedPlanOutputPartitioning=true.

      to reproduce on a distributed cluster, these settings are needed:

      spark.sql.adaptive.advisoryPartitionSizeInBytes 33554432
      spark.sql.adaptive.coalescePartitions.parallelismFirst false
      spark.sql.adaptive.enabled true
      spark.sql.optimizer.canChangeCachedPlanOutputPartitioning true 
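
the same four settings can also be applied in code when the session is built. this is a minimal sketch, not part of the original report: `reproSettings` and `withReproSettings` are hypothetical helper names, and the master URL is assumed to come from spark-submit or spark-shell.

```scala
import org.apache.spark.sql.SparkSession

// The four settings listed in the report, keyed by their config names.
val reproSettings: Map[String, String] = Map(
  "spark.sql.adaptive.enabled" -> "true",
  "spark.sql.adaptive.advisoryPartitionSizeInBytes" -> "33554432", // 32 MiB
  "spark.sql.adaptive.coalescePartitions.parallelismFirst" -> "false",
  "spark.sql.optimizer.canChangeCachedPlanOutputPartitioning" -> "true"
)

// Hypothetical helper: apply every setting to a session builder.
def withReproSettings(builder: SparkSession.Builder): SparkSession.Builder =
  reproSettings.foldLeft(builder) { case (b, (key, value)) => b.config(key, value) }
```

usage would look like `val spark = withReproSettings(SparkSession.builder().appName("spark-45282-repro")).getOrCreate()`, where the app name is only a placeholder.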

      scala code to reproduce:

      import java.util.UUID
      import org.apache.spark.sql.functions.col
      import spark.implicits._
      val data = (1 to 1000000).toDS().map(i => UUID.randomUUID().toString).persist()
      val left = data.map(k => (k, 1))
      val right = data.map(k => (k, k)) // if i change this to k => (k, 1) it works!
      println("number of left " + left.count())
      println("number of right " + right.count())
      println("number of (left join right) " +
        left.toDF("key", "value1").join(right.toDF("key", "value2"), "key").count())
      val left1 = left
        .toDF("key", "value1")
        .repartition(col("key")) // comment out this line to make it work
      println("number of left1 " + left1.count())
      val right1 = right
        .toDF("key", "value2")
        .repartition(col("key")) // comment out this line to make it work
      println("number of right1 " + right1.count())
      println("number of (left1 join right1) " +  left1.join(right1, "key").count()) // this gives incorrect result

      this produces the following output:

      number of left 1000000
      number of right 1000000
      number of (left join right) 1000000
      number of left1 1000000
      number of right1 1000000
      number of (left1 join right1) 859531 

      note that the last number (the incorrect one) actually varies depending on settings and cluster size etc.
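
one way to probe the suspicion about canChangeCachedPlanOutputPartitioning is to run the repartitioned join once with the flag on and once with it off, print the physical plan each time, and compare the counts. the sketch below is an assumption-laden helper, not code from the report: `joinCount` and its parameter names are hypothetical, and the flag is set before `persist()` on the assumption that it influences how the cached plan is built.

```scala
import java.util.UUID
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// Hypothetical helper: rebuild the cached dataset and the repartitioned join
// with the flag set to the given value, print the plan, and return the join count.
def joinCount(spark: SparkSession, n: Int, canChangeCachedPlan: Boolean): Long = {
  import spark.implicits._
  spark.conf.set(
    "spark.sql.optimizer.canChangeCachedPlanOutputPartitioning",
    canChangeCachedPlan.toString)
  val data = (1 to n).toDS().map(_ => UUID.randomUUID().toString).persist()
  try {
    val left1 = data.map(k => (k, 1)).toDF("key", "value1").repartition(col("key"))
    val right1 = data.map(k => (k, k)).toDF("key", "value2").repartition(col("key"))
    val joined = left1.join(right1, "key")
    joined.explain() // look for Exchange nodes on each side of the join
    joined.count()
  } finally {
    data.unpersist() // drop the cache so the next run starts clean
  }
}
```

on an affected cluster, calling `joinCount(spark, 1000000, canChangeCachedPlan = true)` and then the same with `false` should show whether the count and the Exchange nodes in the plan differ between the two runs. a correct join returns `n` in both cases.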






              Assignee: Emil Ejbyfeldt (eejbyfeldt)
              Reporter: koert kuipers (koert)