Description
Trying to understand and debug the performance of some of our jobs, I started digging into what whole-stage codegen was doing. We use a lot of CASE WHEN statements, and I found that a lot of unused subexpressions were left over after subexpression elimination; the more expressions you chain, the worse it gets. A simple example:
import org.apache.spark.sql.functions._
import spark.implicits._

val myUdf = udf((s: String) => {
  println("In UDF")
  s.toUpperCase
})

spark.range(5).select(when(length(myUdf($"id")) > 0, length(myUdf($"id")))).show()
Running the code, you'll see "In UDF" printed 10 times, twice for each of the 5 rows. If you change both expressions to log(length(myUdf($"id"))), "In UDF" is printed 20 times (I think one extra call comes from the cast from int to double and one more from the actual log calculation).
In the generated code for this query (without the log), these subexpressions are evaluated first:
UTF8String project_subExprValue_0 = project_subExpr_0(project_expr_0_0);
int project_subExprValue_1 = project_subExpr_1(project_expr_0_0);
UTF8String project_subExprValue_2 = project_subExpr_2(project_expr_0_0);
project_subExprValue_0 and project_subExprValue_2 are never actually used. It seems that subexpression elimination correctly identifies the two identical expressions and shares the result of project_subExprValue_1 between them, but it does not remove the calls for the other subexpressions.
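To check that this explanation accounts for the printed counts, here is a toy model in plain Scala (no Spark). It rests on an assumption read off the generated code above: every common subexpression is evaluated for each row, and each one recomputes its whole subtree instead of reusing the values of the other subexpressions. All names (SubexprModel, runEager, runShared and the Expr node types) are hypothetical. ToStr stands in for an assumed implicit cast of id to a string, and "common subexpression" is approximated naively as any non-leaf subtree that occurs at least twice. With that rule the CASE WHEN query yields three common subexpressions (the cast, the UDF call and length), which would plausibly line up with the two UTF8String values and one int above.

```scala
// Hypothetical toy model of the behavior described above; not Spark's implementation.
object SubexprModel {
  sealed trait Expr
  case object Id extends Expr                        // the id column
  final case class ToStr(child: Expr) extends Expr   // assumed implicit cast of id to string
  final case class MyUdf(child: Expr) extends Expr   // the UDF; each evaluation is counted
  final case class Length(child: Expr) extends Expr
  final case class ToDouble(child: Expr) extends Expr
  final case class Log(child: Expr) extends Expr
  final case class GtZero(child: Expr) extends Expr
  final case class When(cond: Expr, value: Expr) extends Expr

  var udfCalls = 0

  def children(e: Expr): List[Expr] = e match {
    case Id          => Nil
    case ToStr(c)    => List(c)
    case MyUdf(c)    => List(c)
    case Length(c)   => List(c)
    case ToDouble(c) => List(c)
    case Log(c)      => List(c)
    case GtZero(c)   => List(c)
    case When(c, v)  => List(c, v)
  }

  // The expression itself plus all of its descendants, in pre-order.
  def subtrees(e: Expr): List[Expr] = e :: children(e).flatMap(subtrees)

  // Assumption: a non-leaf subtree that occurs at least twice is a common subexpression.
  def commonSubexprs(root: Expr): List[Expr] =
    subtrees(root).filter(_ != Id).groupBy(identity).collect {
      case (e, occurrences) if occurrences.size > 1 => e
    }.toList

  // Evaluates e for one row, reusing any value already present in cache.
  def eval(e: Expr, row: Long, cache: Map[Expr, Any]): Any =
    if (cache.contains(e)) cache(e)
    else e match {
      case Id          => row
      case ToStr(c)    => eval(c, row, cache).toString
      case MyUdf(c)    =>
        udfCalls += 1
        eval(c, row, cache).asInstanceOf[String].toUpperCase
      case Length(c)   => eval(c, row, cache).asInstanceOf[String].length
      case ToDouble(c) => eval(c, row, cache).asInstanceOf[Int].toDouble
      case Log(c)      => math.log(eval(c, row, cache).asInstanceOf[Double])
      case GtZero(c)   => eval(c, row, cache) match {
        case i: Int    => i > 0
        case d: Double => d > 0
      }
      case When(c, v)  => if (eval(c, row, cache).asInstanceOf[Boolean]) eval(v, row, cache) else null
    }

  // Every common subexpression recomputes its whole subtree from scratch.
  def runEager(root: Expr, rows: Long): Int = {
    udfCalls = 0
    val common = commonSubexprs(root)
    for (row <- 0L until rows) {
      val cache = common.map(e => e -> eval(e, row, Map.empty)).toMap
      eval(root, row, cache)
    }
    udfCalls
  }

  // Smallest subexpressions first, so each one reuses values already computed.
  def runShared(root: Expr, rows: Long): Int = {
    udfCalls = 0
    val common = commonSubexprs(root).sortBy(subtrees(_).size)
    for (row <- 0L until rows) {
      val cache = common.foldLeft(Map.empty[Expr, Any])((acc, e) => acc + (e -> eval(e, row, acc)))
      eval(root, row, cache)
    }
    udfCalls
  }

  val udfOfId: Expr = MyUdf(ToStr(Id))
  // when(length(myUdf($"id")) > 0, length(myUdf($"id")))
  val caseWhen: Expr = When(GtZero(Length(udfOfId)), Length(udfOfId))
  // when(log(length(myUdf($"id"))) > 0, log(length(myUdf($"id"))))
  val logCaseWhen: Expr = {
    val logLen = Log(ToDouble(Length(udfOfId)))
    When(GtZero(logLen), logLen)
  }
}
```

Under these assumptions, runEager(caseWhen, 5) counts 10 UDF calls and runEager(logCaseWhen, 5) counts 20, the same numbers as in the report, whereas runShared counts 5 for both because the UDF result is computed once per row and reused by the larger subexpressions.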