  Spark / SPARK-49261

Correlation between lit and round during grouping


Details

    • Type: Bug
    • Status: Resolved
    • Priority: Major
    • Resolution: Fixed
    • Affects Version/s: 3.2.4, 3.5.0, 3.3.4, 3.4.3
    • Fix Version/s: 4.0.0, 3.4.4, 3.5.4
    • Component/s: PySpark
    • Environment: Databricks DBR 14.3, Spark 3.5.0, Scala 2.12

    Description

      Running the following code:

       

      import pyspark.sql.functions as F
      from decimal import Decimal
      
      data = [
        (1, 100, Decimal("1.1"),  "L", True),
        (2, 200, Decimal("1.2"),  "H", False),
        (2, 300, Decimal("2.345"), "E", False),
      ]
      
      columns = ["group_a", "id", "amount", "selector_a", "selector_b"]
      
      df = spark.createDataFrame(data, schema=columns)
      
      df_final = (
        df.select(
          F.lit(6).alias("run_number"),  # this literal (6) matches the round() scale used after the aggregation
          F.lit("AA").alias("run_type"),
          F.col("group_a"),
          F.col("id"),
          F.col("amount"),
          F.col("selector_a"),
          F.col("selector_b"),
        )
        .withColumn(
          "amount_c",
          F.when(
            (F.col("selector_b") == False)
            & (F.col("selector_a").isin(["L", "H", "E"])),
            F.col("amount"),
          ).otherwise(F.lit(None))
        )
        .withColumn(
          "count_of_amount_c",
          F.when(
            (F.col("selector_b") == False)
            & (F.col("selector_a").isin(["L", "H", "E"])),
            F.col("id")
          ).otherwise(F.lit(None))
        )
      )
      
      group_by_cols = [
        "run_number",
        "group_a",
        "run_type"
      ]
      
      df_final = df_final.groupBy(group_by_cols).agg(
        F.countDistinct("id").alias("count_of_amount"),
        F.round(F.sum("amount")/ 1000, 1).alias("total_amount"),
        F.sum("amount_c").alias("amount_c"),
        F.countDistinct("count_of_amount_c").alias(
          "count_of_amount_c"
        ),
      )
      
      df_final = (
        df_final
        .withColumn(
          "total_amount",
          F.round(F.col("total_amount") / 1000, 6),
        )
        .withColumn(
          "count_of_amount", F.col("count_of_amount").cast("int")
        )
        .withColumn(
          "count_of_amount_c",
          F.when(
            F.col("amount_c").isNull(), F.lit(None).cast("int")
          ).otherwise(F.col("count_of_amount_c").cast("int")),
        )
      )
      
      df_final = df_final.select(
        F.col("total_amount"),
        "run_number",
        "group_a",
        "run_type",
        "count_of_amount",
        "amount_c",
        "count_of_amount_c",
      )
      
      df_final.show() 

      Produces the following error:

      [INTERNAL_ERROR] Couldn't find total_amount#1046 in [group_a#984L,count_of_amount#1054,amount_c#1033,count_of_amount_c#1034L] SQLSTATE: XX000

      With stack trace:

      org.apache.spark.SparkException: [INTERNAL_ERROR] Couldn't find total_amount#1046 in [group_a#984L,count_of_amount#1054,amount_c#1033,count_of_amount_c#1034L] SQLSTATE: XX000
        at org.apache.spark.SparkException$.internalError(SparkException.scala:97)
        at org.apache.spark.SparkException$.internalError(SparkException.scala:101)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:81)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:74)
        at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:505)
        at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(origin.scala:83)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:505)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:481)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:449)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:74)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$.$anonfun$bindReferences$1(BoundAttribute.scala:97)
        at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
        at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
        at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
        at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
        at scala.collection.TraversableLike.map(TraversableLike.scala:286)
        at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
        at scala.collection.AbstractTraversable.map(Traversable.scala:108)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReferences(BoundAttribute.scala:97)
        at org.apache.spark.sql.execution.ProjectExec.doConsume(basicPhysicalOperators.scala:74)
        at org.apache.spark.sql.execution.CodegenSupport.consume(WholeStageCodegenExec.scala:202)
        at org.apache.spark.sql.execution.CodegenSupport.consume$(WholeStageCodegenExec.scala:155)
        at org.apache.spark.sql.execution.aggregate.HashAggregateExec.consume(HashAggregateExec.scala:51)
        at org.apache.spark.sql.execution.aggregate.HashAggregateExec.generateResultFunction(HashAggregateExec.scala:411)
        at org.apache.spark.sql.execution.aggregate.HashAggregateExec.doConsumeWithKeys(HashAggregateExec.scala:995)
        at org.apache.spark.sql.execution.aggregate.AggregateCodegenSupport.doConsume(AggregateCodegenSupport.scala:81)
        at org.apache.spark.sql.execution.aggregate.AggregateCodegenSupport.doConsume$(AggregateCodegenSupport.scala:77)
        at org.apache.spark.sql.execution.aggregate.HashAggregateExec.doConsume(HashAggregateExec.scala:51)
        at org.apache.spark.sql.execution.CodegenSupport.constructDoConsumeFunction(WholeStageCodegenExec.scala:229)
        at org.apache.spark.sql.execution.CodegenSupport.consume(WholeStageCodegenExec.scala:200)
        at org.apache.spark.sql.execution.CodegenSupport.consume$(WholeStageCodegenExec.scala:155)
        at org.apache.spark.sql.execution.InputAdapter.consume(WholeStageCodegenExec.scala:506)
        at org.apache.spark.sql.execution.InputRDDCodegen.doProduce(WholeStageCodegenExec.scala:493)
        at org.apache.spark.sql.execution.InputRDDCodegen.doProduce$(WholeStageCodegenExec.scala:466)
        at org.apache.spark.sql.execution.InputAdapter.doProduce(WholeStageCodegenExec.scala:506)
        at org.apache.spark.sql.execution.CodegenSupport.$anonfun$produce$1(WholeStageCodegenExec.scala:100)
        at org.apache.spark.sql.execution.SparkPlan$.org$apache$spark$sql$execution$SparkPlan$$withExecuteQueryLogging(SparkPlan.scala:130)
        at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:385)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165)
        at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:381)
        at org.apache.spark.sql.execution.CodegenSupport.produce(WholeStageCodegenExec.scala:95)
        at org.apache.spark.sql.execution.CodegenSupport.produce$(WholeStageCodegenExec.scala:94)
        at org.apache.spark.sql.execution.InputAdapter.produce(WholeStageCodegenExec.scala:506)
        at org.apache.spark.sql.execution.aggregate.HashAggregateExec.doProduceWithKeys(HashAggregateExec.scala:629)
        at org.apache.spark.sql.execution.aggregate.AggregateCodegenSupport.doProduce(AggregateCodegenSupport.scala:73)
        at org.apache.spark.sql.execution.aggregate.AggregateCodegenSupport.doProduce$(AggregateCodegenSupport.scala:69)
        at org.apache.spark.sql.execution.aggregate.HashAggregateExec.doProduce(HashAggregateExec.scala:51)
        at org.apache.spark.sql.execution.CodegenSupport.$anonfun$produce$1(WholeStageCodegenExec.scala:100)
        at org.apache.spark.sql.execution.SparkPlan$.org$apache$spark$sql$execution$SparkPlan$$withExecuteQueryLogging(SparkPlan.scala:130)
        at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:385)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165)
        at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:381)
        at org.apache.spark.sql.execution.CodegenSupport.produce(WholeStageCodegenExec.scala:95)
        at org.apache.spark.sql.execution.CodegenSupport.produce$(WholeStageCodegenExec.scala:94)
        at org.apache.spark.sql.execution.aggregate.HashAggregateExec.produce(HashAggregateExec.scala:51)
        at org.apache.spark.sql.execution.ProjectExec.doProduce(basicPhysicalOperators.scala:59)
        at org.apache.spark.sql.execution.CodegenSupport.$anonfun$produce$1(WholeStageCodegenExec.scala:100)
        at org.apache.spark.sql.execution.SparkPlan$.org$apache$spark$sql$execution$SparkPlan$$withExecuteQueryLogging(SparkPlan.scala:130)
        at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:385)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165)
        at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:381)
        at org.apache.spark.sql.execution.CodegenSupport.produce(WholeStageCodegenExec.scala:95)
        at org.apache.spark.sql.execution.CodegenSupport.produce$(WholeStageCodegenExec.scala:94)
        at org.apache.spark.sql.execution.ProjectExec.produce(basicPhysicalOperators.scala:46)
        at org.apache.spark.sql.execution.WholeStageCodegenExec.doCodeGen(WholeStageCodegenExec.scala:666)
        at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:729)
        at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$2(SparkPlan.scala:327)
        at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:94)
        at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:327)
        at org.apache.spark.sql.execution.SparkPlan$.org$apache$spark$sql$execution$SparkPlan$$withExecuteQueryLogging(SparkPlan.scala:130)
        at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:385)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165)
        at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:381)
        at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:322)
        at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:117)
        at org.apache.spark.sql.execution.collect.Collector$.collect(Collector.scala:131)
        at org.apache.spark.sql.execution.qrc.InternalRowFormat$.collect(cachedSparkResults.scala:94)
        at org.apache.spark.sql.execution.qrc.InternalRowFormat$.collect(cachedSparkResults.scala:90)
        at org.apache.spark.sql.execution.qrc.InternalRowFormat$.collect(cachedSparkResults.scala:78)
        at org.apache.spark.sql.execution.qrc.ResultCacheManager.$anonfun$computeResult$1(ResultCacheManager.scala:549)
        at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:94)
        at org.apache.spark.sql.execution.qrc.ResultCacheManager.collectResult$1(ResultCacheManager.scala:540)
        at org.apache.spark.sql.execution.qrc.ResultCacheManager.$anonfun$computeResult$2(ResultCacheManager.scala:555)
        at org.apache.spark.sql.execution.adaptive.ResultQueryStageExec.$anonfun$doMaterialize$1(QueryStageExec.scala:663)
        at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:1175)
        at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$6(SQLExecution.scala:778)
        at com.databricks.util.LexicalThreadLocal$Handle.runWith(LexicalThreadLocal.scala:63)
        at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$5(SQLExecution.scala:778)
        at com.databricks.util.LexicalThreadLocal$Handle.runWith(LexicalThreadLocal.scala:63)
        at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$4(SQLExecution.scala:778)
        at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62)
        at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$3(SQLExecution.scala:777)
        at scala.util.DynamicVariable.withValue(DynamicVariable.scala:62)
        at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$2(SQLExecution.scala:776)
        at org.apache.spark.sql.execution.SQLExecution$.withOptimisticTransaction(SQLExecution.scala:798)
        at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$1(SQLExecution.scala:775)
        at java.util.concurrent.CompletableFuture$AsyncSupply.run(CompletableFuture.java:1604)
        at org.apache.spark.util.threads.SparkThreadLocalCapturingRunnable.$anonfun$run$1(SparkThreadLocalForwardingThreadPoolExecutor.scala:134)
        at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
        at com.databricks.spark.util.IdentityClaim$.withClaim(IdentityClaim.scala:48)
        at org.apache.spark.util.threads.SparkThreadLocalCapturingHelper.$anonfun$runWithCaptured$4(SparkThreadLocalForwardingThreadPoolExecutor.scala:91)
        at com.databricks.unity.UCSEphemeralState$Handle.runWith(UCSEphemeralState.scala:45)
        at org.apache.spark.util.threads.SparkThreadLocalCapturingHelper.runWithCaptured(SparkThreadLocalForwardingThreadPoolExecutor.scala:90)
        at org.apache.spark.util.threads.SparkThreadLocalCapturingHelper.runWithCaptured$(SparkThreadLocalForwardingThreadPoolExecutor.scala:67)
        at org.apache.spark.util.threads.SparkThreadLocalCapturingRunnable.runWithCaptured(SparkThreadLocalForwardingThreadPoolExecutor.scala:131)
        at org.apache.spark.util.threads.SparkThreadLocalCapturingRunnable.run(SparkThreadLocalForwardingThreadPoolExecutor.scala:134)
        at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
        at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
        at java.lang.Thread.run(Thread.java:750)
       

       

      The failure seems to come from a correlation between F.lit(6).alias("run_number") and F.round(F.col("total_amount") / 1000, 6): when the lit value and the scale argument of round are set to the same number (6 in this case), the code fails.

      If the two numbers are different, everything works.

      Moving F.lit(6).alias("run_number") to the final select also solves the problem when the lit value and the round scale are the same (see the sketch below).
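
      A rough sketch of that second workaround, reusing df and the imports from the failing example above (this reduced pipeline only illustrates moving the literal after the aggregation and has not been verified against the full repro):

      # Sketch: run_number is no longer part of the pre-aggregation select /
      # groupBy; it is only introduced in the final select, after round().
      df_grouped = (
        df
        .withColumn("run_type", F.lit("AA"))
        .groupBy("group_a", "run_type")
        .agg(
          F.countDistinct("id").alias("count_of_amount"),
          F.round(F.sum("amount") / 1000, 1).alias("total_amount"),
        )
        .withColumn("total_amount", F.round(F.col("total_amount") / 1000, 6))
      )

      df_out = df_grouped.select(
        F.lit(6).alias("run_number"),  # same value as the round() scale, added only here
        "group_a",
        "run_type",
        "count_of_amount",
        "total_amount",
      )

      df_out.show()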

      Example of working code (here the lit value is changed to 7, so it no longer matches the round scale of 6):

      import pyspark.sql.functions as F
      from decimal import Decimal
      
      data = [
        (1, 100, Decimal("1.1"),  "L", True),
        (2, 200, Decimal("1.2"),  "H", False),
        (2, 300, Decimal("2.345"), "E", False),
      ]
      
      columns = ["group_a", "id", "amount", "selector_a", "selector_b"]
      
      df = spark.createDataFrame(data, schema=columns)
      
      df_final = (
        df.select(
          F.lit(7).alias("run_number"),
          F.lit("AA").alias("run_type"),
          F.col("group_a"),
          F.col("id"),
          F.col("amount"),
          F.col("selector_a"),
          F.col("selector_b"),
        )
        .withColumn(
          "amount_c",
          F.when(
            (F.col("selector_b") == False)
            & (F.col("selector_a").isin(["L", "H", "E"])),
            F.col("amount"),
          ).otherwise(F.lit(None))
        )
        .withColumn(
          "count_of_amount_c",
          F.when(
            (F.col("selector_b") == False)
            & (F.col("selector_a").isin(["L", "H", "E"])),
            F.col("id")
          ).otherwise(F.lit(None))
        )
      )
      
      group_by_cols = [
        "run_number",
        "group_a",
        "run_type"
      ]
      
      df_final = df_final.groupBy(group_by_cols).agg(
        F.countDistinct("id").alias("count_of_amount"),
        F.round(F.sum("amount")/ 1000, 1).alias("total_amount"),
        F.sum("amount_c").alias("amount_c"),
        F.countDistinct("count_of_amount_c").alias(
          "count_of_amount_c"
        ),
      )
      
      df_final = (
        df_final
        .withColumn(
          "total_amount",
          F.round(F.col("total_amount") / 1000, 6),
        )
        .withColumn(
          "count_of_amount", F.col("count_of_amount").cast("int")
        )
        .withColumn(
          "count_of_amount_c",
          F.when(
            F.col("amount_c").isNull(), F.lit(None).cast("int")
          ).otherwise(F.col("count_of_amount_c").cast("int")),
        )
      )
      
      df_final = df_final.select(
        F.col("total_amount"),
        "run_number",
        "group_a",
        "run_type",
        "count_of_amount",
        "amount_c",
        "count_of_amount_c",
      )
      
      df_final.show() 

      Output:

      +------------+----------+-------+--------+---------------+--------------------+-----------------+
      |total_amount|run_number|group_a|run_type|count_of_amount|            amount_c|count_of_amount_c|
      +------------+----------+-------+--------+---------------+--------------------+-----------------+
      |    0.000000|         7|      2|      AA|              2|3.545000000000000000|                2|
      |    0.000000|         7|      1|      AA|              1|                NULL|             NULL|
      +------------+----------+-------+--------+---------------+--------------------+-----------------+

      Expected behavior:

      Values used in the lit function shouldn't interfere with the scale parameter in the round function.
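
      As a hypothetical sanity check of that expectation (not part of the original repro; it reuses df from the examples above and may or may not trigger the bug on affected versions), changing an unrelated literal should leave the rounded result unchanged:

      import pyspark.sql.functions as F

      # Expected invariance: the value of an unrelated lit() column should not
      # affect the scale argument of round() in the aggregation.
      def totals(run_number):
        return (
          df.select(F.lit(run_number).alias("run_number"), "group_a", "amount")
          .groupBy("run_number", "group_a")
          .agg(F.round(F.sum("amount") / 1000, 6).alias("total_amount"))
          .select("group_a", "total_amount")
          .orderBy("group_a")
          .collect()
        )

      # Both calls should succeed and return the same per-group totals.
      assert totals(6) == totals(7)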

       

       

       

            People

              Assignee: Bruce Robbins (bersprockets)
              Reporter: Krystian Kulig (krislig)
              Votes: 0
              Watchers: 4
