Spark / SPARK-35142

`OneVsRest` classifier uses incorrect data type for `rawPrediction` column

Details

    • Type: Bug
    • Status: Resolved
    • Priority: Major
    • Resolution: Fixed
    • Affects Version/s: 3.0.0, 3.0.2, 3.1.0, 3.1.1
    • Fix Version/s: 3.0.3, 3.1.2, 3.2.0
    • Component/s: ML
    • Labels: None

    Description

      `OneVsRest` classifier uses an incorrect data type for the `rawPrediction` column.

       Code to reproduce the issue:

      from pyspark.ml.classification import LogisticRegression, OneVsRest
      from pyspark.ml.linalg import Vectors
      from pyspark.sql import SparkSession
      from sklearn.datasets import load_iris
      
      spark = SparkSession.builder.getOrCreate()
      
      X, y = load_iris(return_X_y=True)
      df = spark.createDataFrame(
          [(Vectors.dense(features), int(label)) for features, label in zip(X, y)],
          ["features", "label"],
      )
      train, test = df.randomSplit([0.8, 0.2])
      lor = LogisticRegression(maxIter=5)
      ovr = OneVsRest(classifier=lor)
      ovrModel = ovr.fit(train)
      pred = ovrModel.transform(test)
      
      pred.printSchema()
      # This prints out:
      # root
      #  |-- features: vector (nullable = true)
      #  |-- label: long (nullable = true)
      #  |-- rawPrediction: string (nullable = true)  # <- should not be string
      #  |-- prediction: double (nullable = true)
      
      # pred.show()  # this fails because of the incorrect datatype

      I ran the code above using GitHub Actions:

      https://github.com/harupy/SPARK-35142/pull/1

       

      It looks like the UDF that computes the `rawPrediction` column is created without specifying the return type:
      https://github.com/apache/spark/blob/0494dc90af48ce7da0625485a4dc6917a244d580/python/pyspark/ml/classification.py#L3154

      rawPredictionUDF = udf(func)
      

       

          People

            Assignee: harupy Harutaka Kawamura
            Reporter: harupy Harutaka Kawamura