Description
from pyspark.sql.functions import udf

spark.range(10).write.mode("overwrite").parquet("/tmp/abc")

@udf(returnType='string')
def my_udf(arg):
    return arg

df = spark.read.parquet("/tmp/abc")
df.limit(10).withColumn("prediction", my_udf(df["id"])).explain()
In the example above, since Python UDFs are executed asynchronously, pushing the limit below the UDF evaluation benefits performance. However, the limit is no longer pushed below BatchEvalPython:
== Physical Plan ==
CollectLimit 10
+- *(2) Project [id#3L, pythonUDF0#10 AS prediction#6]
   +- BatchEvalPython [my_udf(id#3L)#5], [pythonUDF0#10]
      +- *(1) ColumnarToRow
         +- FileScan parquet [id#3L] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/tmp/abc], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:bigint>
This is a regression from Spark 3.3.1, where the limit was pushed below the Python UDF evaluation:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [id#3L, pythonUDF0#10 AS prediction#6]
   +- BatchEvalPython [my_udf(id#3L)#5], [pythonUDF0#10]
      +- GlobalLimit 10
         +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=30]
            +- LocalLimit 10
               +- FileScan parquet [id#3L] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/tmp/abc], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:bigint>
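
For reference, a minimal sketch of one way to observe the difference, assuming the same interactive spark session as the repro above. The path /tmp/abc_large, the 1000-row dataset, and counting_udf are hypothetical additions, not part of the original repro. An accumulator counts how many rows actually reach the Python worker: with the limit pushed below BatchEvalPython (as in 3.3.1) the count stays close to 10, while with the limit kept above it the asynchronous batch evaluation can feed many more rows to the UDF.

from pyspark.sql.functions import udf

# hypothetical larger dataset so the effect of the missing push-down is visible
spark.range(1000).write.mode("overwrite").parquet("/tmp/abc_large")

# accumulator used only to observe how many rows the Python UDF evaluates
calls = spark.sparkContext.accumulator(0)

@udf(returnType='string')
def counting_udf(arg):
    calls.add(1)
    return str(arg)

df = spark.read.parquet("/tmp/abc_large")
df.limit(10).withColumn("prediction", counting_udf(df["id"])).collect()

# close to 10 when the limit is pushed below the UDF; noticeably larger otherwise
print(calls.value)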