Description
Currently, pandas UDFs do not support columns of vectors as input, only columns of arrays. This means that feature columns containing dense or sparse vectors (generated by CountVectorizer, for example) are not supported by pandas UDFs out of the box; one needs to convert the vectors into arrays first. This is not documented anywhere, and I had to find it out by trial and error. Below is an example.
from pyspark.sql.functions import udf, pandas_udf import pyspark.sql.functions as F from pyspark.ml.linalg import DenseVector, Vectors, VectorUDT from pyspark.sql.types import * import numpy as np columns = ['features','id'] vals = [ (DenseVector([1, 2, 1, 3]),1), (DenseVector([2, 2, 1, 3]),2) ] sdf = spark.createDataFrame(vals,columns) sdf.show() +-----------------+---+ | features| id| +-----------------+---+ |[1.0,2.0,1.0,3.0]| 1| |[2.0,2.0,1.0,3.0]| 2| +-----------------+---+
@udf(returnType=ArrayType(FloatType())) def vector_to_array(v): # convert column of vectors into column of arrays a = v.values.tolist() return a sdf = sdf.withColumn('features_array',vector_to_array('features')) sdf.show() sdf.dtypes +-----------------+---+--------------------+ | features| id| features_array| +-----------------+---+--------------------+ |[1.0,2.0,1.0,3.0]| 1|[1.0, 2.0, 1.0, 3.0]| |[2.0,2.0,1.0,3.0]| 2|[2.0, 2.0, 1.0, 3.0]| +-----------------+---+--------------------+ [('features', 'vector'), ('id', 'bigint'), ('features_array', 'array<float>')]
import pandas as pd @pandas_udf(LongType()) def _pandas_udf(v): res = [] for i in v: res.append(i.mean()) return pd.Series(res) sdf.select(_pandas_udf('features_array')).show() +---------------------------+ |_pandas_udf(features_array)| +---------------------------+ | 1| | 2| +---------------------------+
But if I pass the vector column directly, I get the following error.
sdf.select(_pandas_udf('features')).show() --------------------------------------------------------------------------- Py4JJavaError Traceback (most recent call last) <ipython-input-74-d93e4117f661> in <module> 13 14 ---> 15 sdf.select(_pandas_udf('features')).show() ~/.pyenv/versions/anaconda3-5.3.1/lib/python3.7/site-packages/pyspark/sql/dataframe.py in show(self, n, truncate, vertical) 376 """ 377 if isinstance(truncate, bool) and truncate: --> 378 print(self._jdf.showString(n, 20, vertical)) 379 else: 380 print(self._jdf.showString(n, int(truncate), vertical)) ~/.pyenv/versions/3.4.4/lib/python3.4/site-packages/pyspark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args) 1255 answer = self.gateway_client.send_command(command) 1256 return_value = get_return_value( -> 1257 answer, self.gateway_client, self.target_id, self.name) 1258 1259 for temp_arg in temp_args: ~/.pyenv/versions/anaconda3-5.3.1/lib/python3.7/site-packages/pyspark/sql/utils.py in deco(*a, **kw) 61 def deco(*a, **kw): 62 try: ---> 63 return f(*a, **kw) 64 except py4j.protocol.Py4JJavaError as e: 65 s = e.java_exception.toString() ~/.pyenv/versions/3.4.4/lib/python3.4/site-packages/pyspark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name) 326 raise Py4JJavaError( 327 "An error occurred while calling {0}{1}{2}.\n". --> 328 format(target_id, ".", name), value) 329 else: 330 raise Py4JError( Py4JJavaError: An error occurred while calling o2635.showString. 
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 156.0 failed 1 times, most recent failure: Lost task 0.0 in stage 156.0 (TID 606, localhost, executor driver): java.lang.UnsupportedOperationException: Unsupported data type: struct<type:tinyint,size:int,indices:array<int>,values:array<double>> at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowType(ArrowUtils.scala:56) at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowField(ArrowUtils.scala:92) at org.apache.spark.sql.execution.arrow.ArrowUtils$$anonfun$toArrowSchema$1.apply(ArrowUtils.scala:116) at org.apache.spark.sql.execution.arrow.ArrowUtils$$anonfun$toArrowSchema$1.apply(ArrowUtils.scala:115) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.Iterator$class.foreach(Iterator.scala:891) at scala.collection.AbstractIterator.foreach(Iterator.scala:1334) at scala.collection.IterableLike$class.foreach(IterableLike.scala:72) at org.apache.spark.sql.types.StructType.foreach(StructType.scala:99) at scala.collection.TraversableLike$class.map(TraversableLike.scala:234) at org.apache.spark.sql.types.StructType.map(StructType.scala:99) at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowSchema(ArrowUtils.scala:115) at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$2.writeIteratorToStream(ArrowPythonRunner.scala:71) at org.apache.spark.api.python.BasePythonRunner$WriterThread$$anonfun$run$1.apply(PythonRunner.scala:345) at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1945) at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:194) Driver stacktrace: at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1889) at 
org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1877) at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1876) at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48) at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1876) at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926) at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926) at scala.Option.foreach(Option.scala:257) at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926) at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2110) at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2059) at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2048) at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49) at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737) at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061) at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082) at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101) at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:365) at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38) at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3383) at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2544) at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2544) at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3364) at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78) at 
org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125) at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73) at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3363) at org.apache.spark.sql.Dataset.head(Dataset.scala:2544) at org.apache.spark.sql.Dataset.take(Dataset.scala:2758) at org.apache.spark.sql.Dataset.getRows(Dataset.scala:254) at org.apache.spark.sql.Dataset.showString(Dataset.scala:291) at sun.reflect.GeneratedMethodAccessor81.invoke(Unknown Source) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244) at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357) at py4j.Gateway.invoke(Gateway.java:282) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.GatewayConnection.run(GatewayConnection.java:238) at java.lang.Thread.run(Thread.java:748) Caused by: java.lang.UnsupportedOperationException: Unsupported data type: struct<type:tinyint,size:int,indices:array<int>,values:array<double>> at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowType(ArrowUtils.scala:56) at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowField(ArrowUtils.scala:92) at org.apache.spark.sql.execution.arrow.ArrowUtils$$anonfun$toArrowSchema$1.apply(ArrowUtils.scala:116) at org.apache.spark.sql.execution.arrow.ArrowUtils$$anonfun$toArrowSchema$1.apply(ArrowUtils.scala:115) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.Iterator$class.foreach(Iterator.scala:891) at scala.collection.AbstractIterator.foreach(Iterator.scala:1334) at scala.collection.IterableLike$class.foreach(IterableLike.scala:72) at 
org.apache.spark.sql.types.StructType.foreach(StructType.scala:99) at scala.collection.TraversableLike$class.map(TraversableLike.scala:234) at org.apache.spark.sql.types.StructType.map(StructType.scala:99) at org.apache.spark.sql.execution.arrow.ArrowUtils$.toArrowSchema(ArrowUtils.scala:115) at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$2.writeIteratorToStream(ArrowPythonRunner.scala:71) at org.apache.spark.api.python.BasePythonRunner$WriterThread$$anonfun$run$1.apply(PythonRunner.scala:345) at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1945) at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:194)