Details
- Type: Improvement
- Status: Closed
- Priority: Minor
- Resolution: Later
- Affects Version/s: 2.1.1
- Fix Version/s: None
- Component/s: None
Description
In database applications, users often use parameterized queries for repeated execution (e.g., via prepared statements), so I think this functionality would also be useful for Spark users.
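For reference, below is a minimal sketch of the plain JDBC pattern that this proposal mirrors (prepare once, bind and execute many times); the connection URL, table, and columns are illustrative:

import java.sql.DriverManager

// Prepare once: the database parses and plans the query a single time.
// (Connection URL and schema are made up for this sketch.)
val conn = DriverManager.getConnection("jdbc:postgresql://localhost/testdb")
val stmt = conn.prepareStatement("SELECT col2 FROM t WHERE col1 = ?")

// Execute many times: only the bound parameter value changes per execution.
Seq(1, 2, 3).foreach { v =>
  stmt.setInt(1, v)
  val rs = stmt.executeQuery()
  while (rs.next()) println(rs.getInt("col2"))
}
stmt.close()
conn.close()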
My prototype is here: https://github.com/apache/spark/compare/master...maropu:PreparedStmt2
What I suggest would look like the following:
scala> Seq((1, 2), (2, 3)).toDF("col1", "col2").createOrReplaceTempView("t")

// Define a query with a parameter placeholder named `val`
scala> val df = sql("SELECT * FROM t WHERE col1 = $val")

scala> df.explain
== Physical Plan ==
*Project [_1#13 AS col1#16, _2#14 AS col2#17]
+- *Filter (_1#13 = cast(parameterholder(val) as int))
   +- LocalTableScan [_1#13, _2#14]

// Apply optimizer rules and get an optimized logical plan with the parameter placeholder
scala> val preparedDf = df.prepared

// Bind an actual value and do execution
scala> preparedDf.bindParam("val", 1).show()
+----+----+
|col1|col2|
+----+----+
|   1|   2|
+----+----+
To implement this, my prototype adds a new leaf expression named `ParameterHolder`. In the binding phase, `bindParam` replaces this node with a `Literal` holding the actual value.
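To make the design concrete, here is a minimal sketch of these two pieces, assuming Catalyst's `LeafExpression`/`Unevaluable` traits and `transformAllExpressions`; the names and the default `dataType` are illustrative and may differ from the actual prototype:

import org.apache.spark.sql.catalyst.expressions.{LeafExpression, Literal, Unevaluable}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types.{DataType, StringType}

// A named placeholder that survives analysis/optimization but cannot be
// evaluated until a value is bound (hypothetical definition).
case class ParameterHolder(name: String) extends LeafExpression with Unevaluable {
  override def nullable: Boolean = true
  // Assumed default type; the analyzer inserts casts where needed, as in the
  // `cast(parameterholder(val) as int)` seen in the plan above.
  override def dataType: DataType = StringType
}

// Binding: replace every placeholder with the given name by a Literal.
def bindParam(plan: LogicalPlan, name: String, value: Any): LogicalPlan =
  plan transformAllExpressions {
    case ParameterHolder(`name`) => Literal(value)
  }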
Currently, Spark can spend much time rewriting logical plans in the `Optimizer` (e.g., the constant propagation described in SPARK-19846), so I feel this approach is also helpful in that case:
def timer(f: => Unit): Unit = {
  val count = 9
  val iters = (0 until count).map { i =>
    val t0 = System.nanoTime()
    f
    val t1 = System.nanoTime()
    val elapsed = t1 - t0 + 0.0
    println(s"#$i: ${elapsed / 1000000000.0}")
    elapsed
  }
  println("Avg. Elapsed Time: " + ((iters.sum / count) / 1000000000.0) + "s")
}

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

val numCols = 50
val df = spark.range(100).selectExpr((0 until numCols).map(i => s"id AS _c$i"): _*)

// Add conditions to take much time in Optimizer
val filter = (0 until 128).foldLeft(lit(false))((e, i) =>
  e.or(df.col(df.columns(i % numCols)) === (rand() * 10).cast("int")))
val df2 = df.filter(filter).sort(df.columns(0))

// Regular path
timer {
  df2.filter(df2.col(df2.columns(0)) === lit(3)).collect
  df2.filter(df2.col(df2.columns(0)) === lit(4)).collect
  df2.filter(df2.col(df2.columns(0)) === lit(5)).collect
  df2.filter(df2.col(df2.columns(0)) === lit(6)).collect
  df2.filter(df2.col(df2.columns(0)) === lit(7)).collect
  df2.filter(df2.col(df2.columns(0)) === lit(8)).collect
}

#0: 24.178487906
#1: 22.619839888
#2: 22.318617035
#3: 22.131305502
#4: 22.532095611
#5: 22.245152778
#6: 22.314114847
#7: 22.284385952
#8: 22.053593855
Avg. Elapsed Time: 22.519732597111112s

// Prepared path
val df3b = df2.filter(df2.col(df2.columns(0)) === param("val")).prepared
timer {
  df3b.bindParam("val", 3).collect
  df3b.bindParam("val", 4).collect
  df3b.bindParam("val", 5).collect
  df3b.bindParam("val", 6).collect
  df3b.bindParam("val", 7).collect
  df3b.bindParam("val", 8).collect
}

#0: 0.744693912
#1: 0.743187129
#2: 0.745100003
#3: 0.721668718
#4: 0.757573342
#5: 0.763240883
#6: 0.731287275
#7: 0.728740601
#8: 0.674275592
Avg. Elapsed Time: 0.7344186061111112s
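To confirm that the gap comes from planning rather than execution, one can time the optimizer alone through the Dataset's existing `queryExecution` handle (this check is independent of the prototype):

// Forcing `optimizedPlan` runs the analyzer and optimizer without executing
// the query, so this isolates the planning cost that the prepared path avoids.
val d = df2.filter(df2.col(df2.columns(0)) === lit(3))
val t0 = System.nanoTime()
d.queryExecution.optimizedPlan
val t1 = System.nanoTime()
println(s"Optimization time: ${(t1 - t0) / 1000000000.0}s")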
I'm not sure whether this approach is acceptable, so any suggestions and advice are welcome.