Spark / SPARK-43298

predict_batch_udf with scalar input fails when batch size consists of a single value


Details

    • Type: Bug
    • Status: Resolved
    • Priority: Major
    • Resolution: Fixed
    • Affects Version/s: 3.4.0
    • Fix Version/s: 3.5.0
    • Component/s: ML, PySpark
    • Labels: None

    Description

      This is related to SPARK-42250. For scalar inputs, predict_batch_udf fails when the batch size is 1:

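      import numpy as np
      from pyspark.ml.functions import predict_batch_udf
      from pyspark.sql import SparkSession
      from pyspark.sql.types import DoubleType
      
      spark = SparkSession.builder.getOrCreate()
      
      # Single scalar (double) column with two rows.
      df = spark.createDataFrame([[1.0],[2.0]], schema=["a"])
      
      def make_predict_fn():
          # Identity "model": return the input batch unchanged.
          def predict(inputs):
              return inputs
          return predict
      
      # batch_size=1 makes every batch a single value, which triggers the error.
      identity = predict_batch_udf(make_predict_fn, return_type=DoubleType(), batch_size=1)
      preds = df.withColumn("preds", identity("a")).collect()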
      

      fails with:

        File "/.../spark/python/pyspark/worker.py", line 869, in main
          process()
        File "/.../spark/python/pyspark/worker.py", line 861, in process
          serializer.dump_stream(out_iter, outfile)
        File "/.../spark/python/pyspark/sql/pandas/serializers.py", line 354, in dump_stream
          return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
        File "/.../spark/python/pyspark/sql/pandas/serializers.py", line 86, in dump_stream
          for batch in iterator:
        File "/.../spark/python/pyspark/sql/pandas/serializers.py", line 347, in init_stream_yield_batches
          for series in iterator:
        File "/.../spark/python/pyspark/worker.py", line 555, in func
          for result_batch, result_type in result_iter:
        File "/.../spark/python/pyspark/ml/functions.py", line 818, in predict
          yield _validate_and_transform_prediction_result(
        File "/.../spark/python/pyspark/ml/functions.py", line 339, in _validate_and_transform_prediction_result
          if len(preds_array) != num_input_rows:
      TypeError: len() of unsized object
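      The message "len() of unsized object" is what NumPy raises when len() is called on a 0-d array, so the prediction result for a one-row batch reaches the row-count check without a row dimension. The sketch below reproduces that shape behaviour in plain NumPy. It assumes (not confirmed by this report) that the UDF turns a single scalar column into a (rows, 1) array and squeezes it; to_scalar_input, to_scalar_input_1d and the np.atleast_1d workaround are hypothetical illustrations, not the actual Spark code or the committed fix.

```python
import numpy as np

# Hypothetical stand-in for the UDF's scalar-input conversion (an assumption):
# build a (rows, 1) array and squeeze away all size-1 axes.
def to_scalar_input(column_values):
    return np.asarray(column_values, dtype=np.float64).reshape(-1, 1).squeeze()

two_rows = to_scalar_input([1.0, 2.0])
one_row = to_scalar_input([1.0])
print(two_rows.shape)  # (2,)
print(one_row.shape)   # () -- a 0-d array, so len() fails below

try:
    len(one_row)
except TypeError as exc:
    print(exc)  # len() of unsized object

# Squeezing only the trailing axis keeps one entry per row,
# so len() matches the number of input rows even for a single row.
def to_scalar_input_1d(column_values):
    return np.asarray(column_values, dtype=np.float64).reshape(-1, 1).squeeze(axis=1)

print(to_scalar_input_1d([1.0]).shape)  # (1,)

# Possible user-side workaround on affected versions (untested assumption):
# promote the predictions back to at least 1-D before returning them.
def make_predict_fn():
    def predict(inputs):
        return np.atleast_1d(inputs)
    return predict
```

      Squeezing a specific axis rather than all size-1 axes is the design point: a bare squeeze() cannot tell "one row" apart from "one column", while squeeze(axis=1) or reshape(-1) only ever removes the column axis.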
      

    People

      Assignee: Lee Yang (leewyang)
      Reporter: Lee Yang (leewyang)
      Votes: 0
      Watchers: 3