Spark / SPARK-34045

OneVsRestModel.transform should not call setter of submodels


Details

    • Type: Improvement
    • Status: Resolved
    • Priority: Minor
    • Resolution: Fixed
    • Affects Version/s: 3.2.0
    • Fix Version/s: 3.2.0
    • Component/s: ML
    • Labels: None

    Description

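      OneVsRestModel.transform passes its own featuresCol to each submodel through the submodel's setter, so the submodels' featuresCol can be changed as a side effect of calling transform: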

      scala> val df = spark.read.format("libsvm").load("/d0/Dev/Opensource/spark/data/mllib/sample_multiclass_classification_data.txt")
      21/01/08 09:52:01 WARN LibSVMFileFormat: 'numFeatures' option not specified, determining the number of features by going though the input. If you know the number in advance, please specify it via 'numFeatures' option to avoid the extra scan.
      df: org.apache.spark.sql.DataFrame = [label: double, features: vector]
      
      scala> val lr = new LogisticRegression().setMaxIter(1).setTol(1E-6).setFitIntercept(true)
      lr: org.apache.spark.ml.classification.LogisticRegression = logreg_3003cb3321a1
      
      scala> val ovr = new OneVsRest().setClassifier(lr)
      ovr: org.apache.spark.ml.classification.OneVsRest = oneVsRest_b2ec3ec45dbf
      
      scala> val ovrm = ovr.fit(df)
      21/01/08 09:52:05 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
      21/01/08 09:52:05 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
      ovrm: org.apache.spark.ml.classification.OneVsRestModel = OneVsRestModel: uid=oneVsRest_b2ec3ec45dbf, classifier=logreg_3003cb3321a1, numClasses=3, numFeatures=4
      
      scala> val df2 = df.withColumnRenamed("features", "features2")
      df2: org.apache.spark.sql.DataFrame = [label: double, features2: vector]
      
      scala> ovrm.setFeaturesCol("features2")
      res0: ovrm.type = OneVsRestModel: uid=oneVsRest_b2ec3ec45dbf, classifier=logreg_3003cb3321a1, numClasses=3, numFeatures=4
      
      
      scala> ovrm.models.map(_.getFeaturesCol)
      res1: Array[String] = Array(features, features, features)
      
      scala> ovrm.transform(df2)
      res2: org.apache.spark.sql.DataFrame = [label: double, features2: vector ... 2 more fields]
      
      scala> ovrm.models.map(_.getFeaturesCol)
      res3: Array[String] = Array(features2, features2, features2)
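      The side effect can be reproduced without Spark. The sketch below uses hypothetical toy classes (`ToySubModel`, `ToyOvrModel`, `transformMutating` and `transformWithCopies` are illustrative names, not Spark API) to contrast a transform that writes the parent's featuresCol into every submodel through its setter with one that works on per-call copies and leaves the submodels untouched. It is not meant to reproduce the actual Spark change.

```scala
// Hypothetical toy classes illustrating the side effect reported in SPARK-34045.
// These are NOT Spark ML classes; names and methods are made up for this sketch.

/** A stand-in for a fitted binary classifier that reads one input column. */
final class ToySubModel(private var featuresCol: String) {
  def getFeaturesCol: String = featuresCol

  /** Mutating setter, analogous in spirit to a Spark model's setFeaturesCol. */
  def setFeaturesCol(value: String): this.type = { featuresCol = value; this }

  /** Returns a new instance reading a different column; the receiver is unchanged. */
  def copyWith(newFeaturesCol: String): ToySubModel = new ToySubModel(newFeaturesCol)

  /** "Scores" a row by reading the configured column. */
  def transform(row: Map[String, Double]): Double = row(featuresCol)
}

/** A stand-in for OneVsRestModel: one submodel per class plus its own featuresCol. */
final class ToyOvrModel(val models: Array[ToySubModel], private var featuresCol: String) {
  def getFeaturesCol: String = featuresCol
  def setFeaturesCol(value: String): this.type = { featuresCol = value; this }

  /** Mirrors the reported behavior: the parent's column is written into every submodel. */
  def transformMutating(row: Map[String, Double]): Seq[Double] =
    models.toList.map { m =>
      m.setFeaturesCol(featuresCol) // side effect: permanently changes the submodel
      m.transform(row)
    }

  /** Side-effect-free variant: each submodel is copied with the parent's column first. */
  def transformWithCopies(row: Map[String, Double]): Seq[Double] =
    models.toList.map(m => m.copyWith(featuresCol).transform(row))
}

object ToyOvrModel {
  /** Builds a model whose submodels were all "fitted" on the column "features". */
  def fresh(numClasses: Int): ToyOvrModel =
    new ToyOvrModel(Array.fill(numClasses)(new ToySubModel("features")), "features")
}
```

      With `val m = ToyOvrModel.fresh(3).setFeaturesCol("features2")` and a row `Map("features2" -> 1.0)`, calling `m.transformWithCopies(row)` leaves `m.models.map(_.getFeaturesCol)` as three copies of `features`, whereas `m.transformMutating(row)` rewrites all three to `features2`, the same pattern as `res3` above. Spark ML params support a comparable copy-based approach through `copy(extra: ParamMap)`; this issue does not say which approach the fix uses.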
      


    People

      Assignee: Ruifeng Zheng (podongfeng)
      Reporter: Ruifeng Zheng (podongfeng)