Details
- Type: Bug
- Status: Resolved
- Priority: Minor
- Resolution: Not A Problem
- Affects Version/s: 2.3.0, 2.3.1
- Fix Version/s: None
- Component/s: None
- Environment: Spark 2.3.1, Pandas 0.23.1
Description
I noticed that producing `NaN` values with the new Pandas UDF feature triggers a JVM exception. I am not sure whether this is an issue with PySpark or PyArrow. Here is a somewhat contrived example to showcase the problem.
In [1]: import pandas as pd
   ...: from pyspark.sql.functions import lit, pandas_udf, PandasUDFType

In [2]: d = [{'key': 'a', 'value': 1},
   ...:      {'key': 'a', 'value': 2},
   ...:      {'key': 'b', 'value': 3},
   ...:      {'key': 'b', 'value': -2}]
   ...: df = spark.createDataFrame(d, "key: string, value: int")
   ...: df.show()
+---+-----+
|key|value|
+---+-----+
|  a|    1|
|  a|    2|
|  b|    3|
|  b|   -2|
+---+-----+

In [3]: df_tmp = df.withColumn('new', lit(1.0))  # add a DoubleType column
   ...: df_tmp.printSchema()
root
 |-- key: string (nullable = true)
 |-- value: integer (nullable = true)
 |-- new: double (nullable = false)
The Pandas UDF simply creates a new column in which negative values are replaced by a particular float, in this case `inf`, and it works fine:
In [4]: @pandas_udf(df_tmp.schema, PandasUDFType.GROUPED_MAP)
   ...: def func(pdf):
   ...:     pdf['new'] = pdf['value'].where(pdf['value'] > 0, float('inf'))
   ...:     return pdf

In [5]: df.groupby('key').apply(func).show()
+---+-----+--------+
|key|value|     new|
+---+-----+--------+
|  b|    3|     3.0|
|  b|   -2|Infinity|
|  a|    1|     1.0|
|  a|    2|     2.0|
+---+-----+--------+
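This works, presumably, because `inf` is an ordinary IEEE-754 double: the Arrow buffer that reaches the JVM has a value at every index and no nulls. A quick way to check this outside of Spark (a sketch, assuming pyarrow is importable on the driver; `pyarrow.Array.from_pandas` applies pandas' missing-value semantics during the conversion):

import pandas as pd
import pyarrow as pa

# inf is a representable double value, so the resulting Arrow array has no nulls
arr = pa.Array.from_pandas(pd.Series([1.0, float('inf')]))
print(arr.null_count)  # 0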
However, if we set this value to `NaN` instead, it triggers an exception:
In [6]: @pandas_udf(df_tmp.schema, PandasUDFType.GROUPED_MAP)
   ...: def func(pdf):
   ...:     pdf['new'] = pdf['value'].where(pdf['value'] > 0, float('nan'))
   ...:     return pdf
   ...:
   ...: df.groupby('key').apply(func).show()
[Stage 23:======================================================> (73 + 2) / 75]
2018-07-07 16:26:27 ERROR Executor:91 - Exception in task 36.0 in stage 23.0 (TID 414)
java.lang.IllegalStateException: Value at index is null
    at org.apache.arrow.vector.Float8Vector.get(Float8Vector.java:98)
    at org.apache.spark.sql.vectorized.ArrowColumnVector$DoubleAccessor.getDouble(ArrowColumnVector.java:344)
    at org.apache.spark.sql.vectorized.ArrowColumnVector.getDouble(ArrowColumnVector.java:99)
    at org.apache.spark.sql.execution.vectorized.MutableColumnarRow.getDouble(MutableColumnarRow.java:126)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:409)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
    at org.apache.spark.scheduler.Task.run(Task.scala:109)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
    at java.lang.Thread.run(Thread.java:745)
2018-07-07 16:26:27 WARN TaskSetManager:66 - Lost task 36.0 in stage 23.0 (TID 414, localhost, executor driver): java.lang.IllegalStateException: Value at index is null
    at org.apache.arrow.vector.Float8Vector.get(Float8Vector.java:98)
    at org.apache.spark.sql.vectorized.ArrowColumnVector$DoubleAccessor.getDouble(ArrowColumnVector.java:344)
    at org.apache.spark.sql.vectorized.ArrowColumnVector.getDouble(ArrowColumnVector.java:99)
    at org.apache.spark.sql.execution.vectorized.MutableColumnarRow.getDouble(MutableColumnarRow.java:126)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:409)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
    at org.apache.spark.scheduler.Task.run(Task.scala:109)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
    at java.lang.Thread.run(Thread.java:745)
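My best guess at what is going on (I have not traced the code, so treat this as a hypothesis): the same pandas-to-Arrow conversion treats `NaN` in a float column as a missing value, so the `NaN` arrives on the JVM side as an Arrow null. Since the UDF reuses `df_tmp.schema` as its return type, and there the `new` field is `nullable = false` (it came from `lit(1.0)`), the generated projection reads the double without a null check and `Float8Vector.get` throws. The same check as above, but with `NaN` (again a sketch):

import pandas as pd
import pyarrow as pa

# NaN doubles as pandas' missing-value sentinel, so it converts to an Arrow null
arr = pa.Array.from_pandas(pd.Series([1.0, float('nan')]))
print(arr.null_count)  # 1

If that is right, a workaround is to declare the return type with nullable fields instead of reusing `df_tmp.schema`, for example as a DDL string (fields parsed from DDL are nullable by default). `func_nullable` below is just a hypothetical name for illustration, and I would expect the `NaN` to come back as null on the Spark side:

# a nullable output schema, declared as a DDL string rather than df_tmp.schema
@pandas_udf("key string, value int, new double", PandasUDFType.GROUPED_MAP)
def func_nullable(pdf):
    pdf['new'] = pdf['value'].where(pdf['value'] > 0, float('nan'))
    return pdf

df.groupby('key').apply(func_nullable).show()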