Spark / SPARK-23791

Sub-optimal generated code for sum aggregating

Details

    • Type: Bug
    • Status: Resolved
    • Priority: Major
    • Resolution: Duplicate
    • Affects Version/s: 2.2.0, 2.3.0
    • Fix Version/s: None
    • Component/s: Optimizer, SQL
    • Flags: Important

    Description

      With whole-stage codegen enabled, a simple Spark job that performs a sum aggregation over 50 columns appears to run ~4 times slower than the same job with whole-stage codegen disabled.

      See the test case code below. Note that the UDF is there only to prevent the elimination optimizations that could be applied if the columns were literals.

      import org.apache.spark.sql.functions._
      import org.apache.spark.sql.{Column, DataFrame, SparkSession}
      import org.apache.spark.sql.internal.SQLConf.WHOLESTAGE_CODEGEN_ENABLED
      
      object SPARK_23791 {
      
        def main(args: Array[String]): Unit = {
      
          val spark = SparkSession
            .builder()
            .master("local[4]")
            .appName("test")
            .getOrCreate()
      
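          // Appends `cnt` columns named s"$prefix$idx", each holding `value`.
          def addConstColumns(prefix: String, cnt: Int, value: Column)(inputDF: DataFrame) =
            (0 until cnt).foldLeft(inputDF)((df, idx) => df.withColumn(s"$prefix$idx", value))
      
          // Zero-argument UDF that always returns null. It is used instead of a
          // literal so that the optimizer cannot eliminate the columns.
          val dummy = udf(() => Option.empty[Int])
      
          // Runs the aggregation once and returns the elapsed wall-clock time in seconds.
          def test(cnt: Int = 50, rows: Int = 5000000, grps: Int = 1000): Double = {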
            val t0 = System.nanoTime()
            spark.range(rows).toDF()
              .withColumn("grp", col("id").mod(grps))
              .transform(addConstColumns("null_", cnt, dummy()))
              .groupBy("grp")
              .agg(sum("null_0"), (1 until cnt).map(idx => sum(s"null_$idx")): _*)
              .collect()
            val t1 = System.nanoTime()
            (t1 - t0) / 1e9
          }
      
          // Three rounds; each round times the job with whole-stage codegen on, then off.
          val timings = for (i <- 1 to 3) yield {
            spark.sessionState.conf.setConf(WHOLESTAGE_CODEGEN_ENABLED, true)
            val with_wholestage = test()
            spark.sessionState.conf.setConf(WHOLESTAGE_CODEGEN_ENABLED, false)
            val without_wholestage = test()
            (with_wholestage, without_wholestage)
          }
      
          timings.foreach(println)
      
          println("Press enter ...")
          System.in.read()
        }
      }
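
To see what whole-stage codegen actually produces for this aggregation, the generated Java source can be dumped with the helpers in Spark's `org.apache.spark.sql.execution.debug` package. The following is a minimal sketch, not part of the original report: the object name `InspectCodegen` and the helpers `sumOverNullColumns` and `generatedCode` are hypothetical, and the `debug` package is a developer-facing API whose output format may differ between Spark versions.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.functions._

// Hypothetical helper object, for illustration only.
object InspectCodegen {

  // Builds the same shape of query as the test case above:
  // `cnt` null-valued UDF columns, summed per group.
  def sumOverNullColumns(spark: SparkSession, cnt: Int, rows: Int, grps: Int): DataFrame = {
    val dummy = udf(() => Option.empty[Int])
    val base = spark.range(rows).toDF().withColumn("grp", col("id").mod(grps))
    val withNulls =
      (0 until cnt).foldLeft(base)((df, idx) => df.withColumn(s"null_$idx", dummy()))
    withNulls
      .groupBy("grp")
      .agg(sum("null_0"), (1 until cnt).map(idx => sum(s"null_$idx")): _*)
  }

  // Returns the generated source of every whole-stage codegen subtree
  // in the physical plan as a single string.
  def generatedCode(df: DataFrame): String =
    codegenString(df.queryExecution.executedPlan)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[4]")
      .appName("inspect-codegen")
      .getOrCreate()

    // The row count does not affect the generated code, so a small input is enough.
    val df = sumOverNullColumns(spark, cnt = 50, rows = 1000, grps = 10)
    df.explain()                // physical plan; codegen stages are marked with '*'
    println(generatedCode(df))  // generated Java source for each codegen subtree

    spark.stop()
  }
}
```

Comparing the output for a small `cnt` (say 2) with the output for `cnt = 50` shows how the size of the generated aggregation code grows with the number of summed columns.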
      

            People

              Assignee: Unassigned
              Reporter: Valentin Nikotin (rednikotin)
              Votes: 1
              Watchers: 5

                Time Tracking

                  Original Estimate: 24h
                  Remaining Estimate: 24h
                  Time Spent: Not Specified