Spark / SPARK-49793

Enable PredictBatchUDFTests.test_caching for NumPy 2


Details

    • Type: Story
    • Status: Open
    • Priority: Major
    • Resolution: Unresolved
    • Affects Version/s: 4.0.0
    • Fix Version/s: None
    • Component/s: ML, Tests
    • Labels: None

    Description

       

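      import numpy as np
      import pandas as pd
      from pyspark.ml.functions import predict_batch_udf
      from pyspark.sql import SparkSession
      from pyspark.sql.functions import struct
      from pyspark.sql.types import DoubleType

      spark = SparkSession.builder.getOrCreate()  # already defined in the pyspark shell

      data = np.arange(0, 36, dtype=np.float64).reshape(-1, 4)
      pdf = pd.DataFrame(data, columns=["a", "b", "c", "d"])
      df = spark.createDataFrame(pdf)

      def make_predict_fn():
          # Emulates loading a model; this should run only once per Python worker,
          # after which the returned predict function is cached and reused.
          fake_output = np.random.random()

          def predict(inputs):
              # Return the same value for every row in the batch.
              return np.array([fake_output for _ in inputs])

          return predict

      identity = predict_batch_udf(make_predict_fn, return_type=DoubleType(), batch_size=5)

      # Both queries should reuse the cached predict function, so the results must match.
      df1 = df.withColumn("preds", identity(struct("a"))).toPandas()
      df2 = df.withColumn("preds", identity(struct("a"))).toPandas()
      print(df1.equals(df2))  # expected True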
      

      Observed with NumPy 2.1.0 (predictions differ from row to row and between the two runs):

      >>> df1
            a     b     c     d     preds
      0   0.0   1.0   2.0   3.0  0.431752
      1   4.0   5.0   6.0   7.0  0.912097
      2   8.0   9.0  10.0  11.0  0.679628
      3  12.0  13.0  14.0  15.0  0.853850
      4  16.0  17.0  18.0  19.0  0.389971
      5  20.0  21.0  22.0  23.0  0.654521
      6  24.0  25.0  26.0  27.0  0.430569
      7  28.0  29.0  30.0  31.0  0.331055
      8  32.0  33.0  34.0  35.0  0.306073
      >>> df2
            a     b     c     d     preds
      0   0.0   1.0   2.0   3.0  0.679628
      1   4.0   5.0   6.0   7.0  0.430569
      2   8.0   9.0  10.0  11.0  0.853850
      3  12.0  13.0  14.0  15.0  0.306073
      4  16.0  17.0  18.0  19.0  0.654521
      5  20.0  21.0  22.0  23.0  0.389971
      6  24.0  25.0  26.0  27.0  0.507598
      7  28.0  29.0  30.0  31.0  0.912097
      8  32.0  33.0  34.0  35.0  0.431752 

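      Expected output (the value itself is random; what matters is that every row of both DataFrames gets the same value):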

      >>> df1
            a     b     c     d     preds
      0   0.0   1.0   2.0   3.0  0.685941
      1   4.0   5.0   6.0   7.0  0.685941
      2   8.0   9.0  10.0  11.0  0.685941
      3  12.0  13.0  14.0  15.0  0.685941
      4  16.0  17.0  18.0  19.0  0.685941
      5  20.0  21.0  22.0  23.0  0.685941
      6  24.0  25.0  26.0  27.0  0.685941
      7  28.0  29.0  30.0  31.0  0.685941
      8  32.0  33.0  34.0  35.0  0.685941
      >>> df2
            a     b     c     d     preds
      0   0.0   1.0   2.0   3.0  0.685941
      1   4.0   5.0   6.0   7.0  0.685941
      2   8.0   9.0  10.0  11.0  0.685941
      3  12.0  13.0  14.0  15.0  0.685941
      4  16.0  17.0  18.0  19.0  0.685941
      5  20.0  21.0  22.0  23.0  0.685941
      6  24.0  25.0  26.0  27.0  0.685941
      7  28.0  29.0  30.0  31.0  0.685941
      8  32.0  33.0  34.0  35.0  0.685941 
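The property behind the expected output can be checked directly on the collected `preds` columns: a single distinct value within a run, and identical values across both runs. Below is a minimal sketch of such a check; `predictions_are_cached` is a hypothetical helper, and the sample lists are the `preds` columns copied from the outputs in this issue.

```python
def predictions_are_cached(preds1, preds2) -> bool:
    """True if one prediction was shared by every row and both runs agree.

    Accepts any iterables of floats, e.g. df1["preds"] and df2["preds"].
    """
    preds1, preds2 = list(preds1), list(preds2)
    return len(set(preds1)) == 1 and preds1 == preds2


# "preds" columns copied from the NumPy 2.1.0 outputs.
observed1 = [0.431752, 0.912097, 0.679628, 0.853850, 0.389971,
             0.654521, 0.430569, 0.331055, 0.306073]
observed2 = [0.679628, 0.430569, 0.853850, 0.306073, 0.654521,
             0.389971, 0.507598, 0.912097, 0.431752]
# "preds" column from the expected outputs: one value in every row of both runs.
expected = [0.685941] * 9

print(predictions_are_cached(observed1, observed2))  # False
print(predictions_are_cached(expected, expected))    # True
```

In the repro it would be called as `predictions_are_cached(df1["preds"], df2["preds"])`. Exact float comparison is appropriate here because every cached row should carry the very same Python value, not a recomputed one.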

       

            People

              Weichen Xu (weichenxu123)
              Xinrong Meng (XinrongM)
              Votes: 0
              Watchers: 1
