Description
SPARK-41468 tried to fix a bug but introduced a new regression. Its change to EquivalentExpressions added a supportedExpression() guard to the addExprTree() and getExprState() methods, but didn't add the same guard to the other "add" entry point – addExpr().
As a result, a caller that adds a single expression to CSE via addExpr() may succeed, but a later lookup of the same expression via getExprState() fails the guard and returns None, which is inconsistent with the add.
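The asymmetry can be illustrated with a minimal toy model. This is a sketch under stated assumptions, not Spark's EquivalentExpressions: ToyExpr, ToyEquivalentExpressions and the hasLambda flag are hypothetical, and the guard here is a stand-in predicate that does not reproduce the real supportedExpression() rules.

```scala
import scala.collection.mutable

// Hypothetical stand-in for a Catalyst expression: its text plus a flag
// that the toy guard below uses to reject it.
final case class ToyExpr(sql: String, hasLambda: Boolean)

// Simplified model of the add/get asymmetry; NOT Spark's EquivalentExpressions.
class ToyEquivalentExpressions {
  private val counts = mutable.HashMap.empty[ToyExpr, Int]

  // Assumed guard: rejects anything flagged as containing a lambda.
  private def supportedExpression(e: ToyExpr): Boolean = !e.hasLambda

  // Unguarded "add" entry point: returns true if a matching expression
  // was already present.
  def addExpr(e: ToyExpr): Boolean = {
    val seen = counts.contains(e)
    counts(e) = counts.getOrElse(e, 0) + 1
    seen
  }

  // Guarded "get" entry point: returns None when the guard rejects the
  // expression, even if it was added successfully.
  def getExprState(e: ToyExpr): Option[Int] =
    if (supportedExpression(e)) counts.get(e) else None
}

object GuardMismatchDemo {
  val expr: ToyExpr = ToyExpr("max(transform(array(id), x -> x))", hasLambda = true)
  val equiv = new ToyEquivalentExpressions

  val firstAdd: Boolean = equiv.addExpr(expr)          // first occurrence
  val secondAdd: Boolean = equiv.addExpr(expr)         // treated as a duplicate
  val lookedUp: Option[Int] = equiv.getExprState(expr) // rejected by the guard

  def main(args: Array[String]): Unit =
    println(s"firstAdd=$firstAdd secondAdd=$secondAdd lookedUp=$lookedUp")
}
```

In this model the second add reports a duplicate while the lookup comes back empty, which is the inconsistency described above.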
We need to make sure the "add" and "get" methods are consistent. It could be done by one of:
1. Adding the same supportedExpression() guard to addExpr(), or
2. Removing the guard from getExprState(), relying solely on the guard on the "add" path to make sure only intended state is added.
(or other alternative refactorings to fuse the guard into various methods to make it more efficient)
Both directions have trade-offs. addExpr() used to allow more (potentially incorrect) expressions to be CSE'd, so making it more restrictive may cause performance regressions for the cases that happened to work.
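To compare the two directions side by side, here is a small sketch in the same spirit as a toy model. The guardAdd and guardGet switches, ToyExpr, and the guard predicate are all hypothetical and exist only for illustration; none of this is Spark code.

```scala
import scala.collection.mutable

// Hypothetical stand-in for a Catalyst expression.
final case class ToyExpr(sql: String, hasLambda: Boolean)

// guardAdd / guardGet are illustrative switches, not real Spark options.
class ToyEquivalentExpressions(guardAdd: Boolean, guardGet: Boolean) {
  private val seen = mutable.HashSet.empty[ToyExpr]

  // Assumed guard: rejects anything flagged as containing a lambda.
  private def supportedExpression(e: ToyExpr): Boolean = !e.hasLambda

  // Returns true if a matching expression was already present.
  def addExpr(e: ToyExpr): Boolean =
    if (guardAdd && !supportedExpression(e)) false
    else !seen.add(e) // HashSet.add returns false when already present

  def getExprState(e: ToyExpr): Option[ToyExpr] =
    if (guardGet && !supportedExpression(e)) None
    else seen.find(_ == e)
}

object FixOptionsDemo {
  private val expr = ToyExpr("max(transform(array(id), x -> x))", hasLambda = true)

  // Returns (second add seen as duplicate?, state retrievable afterwards?).
  def run(guardAdd: Boolean, guardGet: Boolean): (Boolean, Boolean) = {
    val equiv = new ToyEquivalentExpressions(guardAdd, guardGet)
    equiv.addExpr(expr)
    val duplicate = equiv.addExpr(expr)
    (duplicate, equiv.getExprState(expr).isDefined)
  }

  // Guard on "get" only: duplicate detected, but lookup fails (inconsistent).
  val regression: (Boolean, Boolean) = run(guardAdd = false, guardGet = true)
  // Option 1, guard on both: consistent, and the expression is never CSE'd.
  val option1: (Boolean, Boolean) = run(guardAdd = true, guardGet = true)
  // Option 2, no guard on "get": consistent, and the expression is still CSE'd.
  val option2: (Boolean, Boolean) = run(guardAdd = false, guardGet = false)
}
```

Only the mixed configuration disagrees with itself; options 1 and 2 differ in whether the guarded expression still gets deduplicated.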
Example:
select max(transform(array(id), x -> x)), max(transform(array(id), x -> x)) from range(2)
Running this query on the Spark 3.2 branch returns the correct value:
scala> spark.sql("select max(transform(array(id), x -> x)), max(transform(array(id), x -> x)) from range(2)").collect
res0: Array[org.apache.spark.sql.Row] = Array([WrappedArray(1),WrappedArray(1)])
Here, max(transform(array(id), x -> x)) is an AggregateExpression that was (potentially unsafely) recognized by addExpr() as a common subexpression, and getExprState() doesn't do any extra guarding. So during physical planning, in PhysicalAggregation, this expression gets CSE'd in both the aggregation expression list and the result expressions list.
AdaptiveSparkPlan isFinalPlan=false
+- SortAggregate(key=[], functions=[max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false)))])
   +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=11]
      +- SortAggregate(key=[], functions=[partial_max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false)))])
         +- Range (0, 2, step=1, splits=16)
Running the same query on current master triggers an error when binding the result expression to the aggregate expression in the Aggregate operators (for a WSCG-enabled operator like HashAggregateExec, the same error would show up during codegen):
ERROR TaskSetManager: Task 0 in stage 2.0 failed 1 times; aborting job
org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 2.0 failed 1 times, most recent failure: Lost task 0.0 in stage 2.0 (TID 16) (ip-10-110-16-93.us-west-2.compute.internal executor driver): java.lang.IllegalStateException: Couldn't find max(transform(array(id#0L), lambdafunction(lambda x#2L, lambda x#2L, false)))#4 in [max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false)))#3]
	at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:80)
	at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:73)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:512)
	at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:104)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:512)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:517)
	at org.apache.spark.sql.catalyst.trees.UnaryLike.mapChildren(TreeNode.scala:1249)
	at org.apache.spark.sql.catalyst.trees.UnaryLike.mapChildren$(TreeNode.scala:1248)
	at org.apache.spark.sql.catalyst.expressions.UnaryExpression.mapChildren(Expression.scala:532)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:517)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:488)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:456)
	at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:73)
	at org.apache.spark.sql.catalyst.expressions.BindReferences$.$anonfun$bindReferences$1(BoundAttribute.scala:94)
	at scala.collection.immutable.List.map(List.scala:297)
	at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReferences(BoundAttribute.scala:94)
	at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:161)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator.generateResultProjection(AggregationIterator.scala:246)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator.<init>(AggregationIterator.scala:296)
	at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.<init>(SortBasedAggregationIterator.scala:49)
	at org.apache.spark.sql.execution.aggregate.SortAggregateExec.$anonfun$doExecute$1(SortAggregateExec.scala:79)
	at org.apache.spark.sql.execution.aggregate.SortAggregateExec.$anonfun$doExecute$1$adapted(SortAggregateExec.scala:59)
	...
Note that the aggregate expressions are deduplicated in PhysicalAggregation, but the result expressions could not be deduplicated consistently with them because of the bug described in this ticket.
AdaptiveSparkPlan isFinalPlan=false
+- SortAggregate(key=[], functions=[max(transform(array(id#15L), lambdafunction(lambda x#16L, lambda x#16L, false)))])
   +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=38]
      +- SortAggregate(key=[], functions=[partial_max(transform(array(id#15L), lambdafunction(lambda x#16L, lambda x#16L, false)))])
         +- Range (0, 2, step=1, splits=16)
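The mismatch between the deduplicated aggregate expressions and the result expressions can be modeled roughly as below. This is a sketch, not PhysicalAggregation itself: it assumes the planner deduplicates with addExpr() and then rewrites each result expression via getExprState(), falling back to the expression's own attribute when the lookup returns None. That fallback, and the names ToyAgg, resultId and guardRejects, are assumptions made for illustration; the ids 3 and 4 simply mirror the #3 and #4 in the error message.

```scala
import scala.collection.mutable

// Hypothetical aggregate expression: `sql` is its semantic form and
// `resultId` is the id of its own output attribute.
final case class ToyAgg(sql: String, resultId: Int)

// Toy model with an unguarded add and a guarded get; NOT Spark's class.
class ToyEquivalentExpressions(guardRejects: ToyAgg => Boolean) {
  private val bySql = mutable.LinkedHashMap.empty[String, ToyAgg]

  // Returns true if a semantically equal expression was already added.
  def addExpr(e: ToyAgg): Boolean =
    if (bySql.contains(e.sql)) true else { bySql(e.sql) = e; false }

  // Returns the representative expression, unless the guard rejects `e`.
  def getExprState(e: ToyAgg): Option[ToyAgg] =
    if (guardRejects(e)) None else bySql.get(e.sql)
}

object PlanningDemo {
  private val sql = "max(transform(array(id), x -> x))"
  val resultExprs: Seq[ToyAgg] = Seq(ToyAgg(sql, 3), ToyAgg(sql, 4))

  // Returns (ids produced by the aggregate, ids referenced by the results).
  def plan(guardRejects: ToyAgg => Boolean): (Seq[Int], Seq[Int]) = {
    val equiv = new ToyEquivalentExpressions(guardRejects)
    // Step 1: keep only the first of each group of equal aggregates.
    val produced = resultExprs.filter(e => !equiv.addExpr(e)).map(_.resultId)
    // Step 2: point each result at its representative, or at itself
    // when the lookup fails (assumed fallback).
    val referenced = resultExprs.map(e => equiv.getExprState(e).getOrElse(e).resultId)
    (produced, referenced)
  }

  val withoutGuard: (Seq[Int], Seq[Int]) = plan(_ => false)
  val withGuard: (Seq[Int], Seq[Int]) = plan(_ => true)
}
```

With the guard active on the lookup only, the results reference an id that the aggregate never produces, which corresponds to the "Couldn't find ... #4 in [... #3]" failure.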
Fixing it via method 1 is more correct than method 2 in terms of avoiding incorrect CSE:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index 330d66a21b..12def60042 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -40,7 +40,11 @@ class EquivalentExpressions {
    * Returns true if there was already a matching expression.
    */
   def addExpr(expr: Expression): Boolean = {
-    updateExprInMap(expr, equivalenceMap)
+    if (supportedExpression(expr)) {
+      updateExprInMap(expr, equivalenceMap)
+    } else {
+      false
+    }
   }
 
   /**
With method 1 applied, the query runs correctly again, but the aggregate expression is no longer CSE'd. This now happens consistently for both the aggregate expressions and the result expressions:
AdaptiveSparkPlan isFinalPlan=false
+- SortAggregate(key=[], functions=[max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false))), max(transform(array(id#0L), lambdafunction(lambda x#2L, lambda x#2L, false)))])
   +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=11]
      +- SortAggregate(key=[], functions=[partial_max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, false))), partial_max(transform(array(id#0L), lambdafunction(lambda x#2L, lambda x#2L, false)))])
         +- Range (0, 2, step=1, splits=16)
For this particular case, the CSE that used to take place was actually safe, so losing it here is a performance regression.