Spark / SPARK-16863

ProbabilisticClassifier.fit should check the length of thresholds


Details

    • Type: Improvement
    • Status: Resolved
    • Priority: Minor
    • Resolution: Duplicate
    • Affects Version/s: None
    • Fix Version/s: None
    • Component/s: ML
    • Labels: None

    Description

      import org.apache.spark.ml.classification.RandomForestClassifier

      // Sample data shipped with Spark 2.0.0; it contains 3 classes.
      val path = "./spark-2.0.0-bin-hadoop2.7/data/mllib/sample_multiclass_classification_data.txt"
      val data = spark.read.format("libsvm").load(path)
      val rf = new RandomForestClassifier
      // Set 5 thresholds although the data has only 3 classes.
      rf.setThresholds(Array(0.1, 0.2, 0.3, 0.4, 0.5))
      
      // fit() succeeds even though thresholds.length != numClasses.
      val rfm = rf.fit(data)
      rfm: org.apache.spark.ml.classification.RandomForestClassificationModel = RandomForestClassificationModel (uid=rfc_fec31a5b954d) with 20 trees
      
      rfm.numClasses
      res2: Int = 3
      
      // The invalid thresholds are copied to the fitted model unchanged.
      rfm.getThresholds
      res3: Array[Double] = Array(0.1, 0.2, 0.3, 0.4, 0.5)
      
      // The mismatch only surfaces here, at transform() time.
      rfm.transform(data)
      java.lang.IllegalArgumentException: requirement failed: RandomForestClassificationModel.transform() called with non-matching numClasses and thresholds.length. numClasses=3, but thresholds has length 5
        at scala.Predef$.require(Predef.scala:224)
        at org.apache.spark.ml.classification.ProbabilisticClassificationModel.transform(ProbabilisticClassifier.scala:101)
        ... 72 elided
      

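      ProbabilisticClassifier.fit() should throw an exception if its thresholds are set incorrectly, that is, when thresholds.length does not match the number of classes. Currently fit() succeeds and the error only appears later, when transform() is called on the fitted model.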
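
The requested behavior could look roughly like the sketch below. It is plain Scala with no Spark dependency, and `ThresholdCheck` and `validate` are hypothetical names used only for illustration; this is not Spark's actual code or the fix applied in the duplicate issue. It assumes the number of classes is known before training and treats unset thresholds as valid.

```scala
// Hypothetical helper, not part of Spark: a fail-fast check of the kind
// this issue asks fit() to perform before training starts.
object ThresholdCheck {

  /** Throws IllegalArgumentException if thresholds are set but their
    * length differs from numClasses. Unset thresholds (None) pass. */
  def validate(numClasses: Int, thresholds: Option[Array[Double]]): Unit =
    thresholds.foreach { t =>
      require(
        t.length == numClasses,
        "fit() called with non-matching numClasses and thresholds.length. " +
          s"numClasses=$numClasses, but thresholds has length ${t.length}"
      )
    }
}
```

With the values from the report, `ThresholdCheck.validate(3, Some(Array(0.1, 0.2, 0.3, 0.4, 0.5)))` would fail immediately with an `IllegalArgumentException`, while an array of length 3 or `None` would pass.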

            People

              Assignee: Unassigned
              Reporter: Ruifeng Zheng (podongfeng)
              Votes: 0
              Watchers: 2
