Description
val path = "./spark-2.0.0-bin-hadoop2.7/data/mllib/sample_multiclass_classification_data.txt" val data = spark.read.format("libsvm").load(path) val rf = new RandomForestClassifier() val model = rf.fit(data) model.numClasses res48: Int = 3 model.setThresholds(Array(0.5,0.1)) res49: org.apache.spark.ml.classification.RandomForestClassificationModel = RandomForestClassificationModel (uid=rfc_b39da354ac8b) with 20 trees model.transform(data) java.lang.IllegalArgumentException: requirement failed: RandomForestClassificationModel.transform() called with non-matching numClasses and thresholds.length. numClasses=3, but thresholds has length 2 at scala.Predef$.require(Predef.scala:224) at org.apache.spark.ml.classification.ProbabilisticClassificationModel.transform(ProbabilisticClassifier.scala:101) ... 58 elided
Although a model configured with mismatched thresholds will fail at prediction time, it would be better to raise the exception earlier, in setThresholds.
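A minimal sketch of the kind of eager check being requested, assuming a trained model where numClasses is already known. setThresholdsChecked is a hypothetical helper used here for illustration, not part of the Spark API; the actual fix would live inside Spark ML itself.

import org.apache.spark.ml.classification.RandomForestClassificationModel

// Hypothetical helper (not part of Spark): fail fast at set time with the same
// kind of message that transform() currently produces.
def setThresholdsChecked(
    model: RandomForestClassificationModel,
    thresholds: Array[Double]): RandomForestClassificationModel = {
  require(thresholds.length == model.numClasses,
    s"setThresholds() called with non-matching numClasses and thresholds.length. " +
      s"numClasses=${model.numClasses}, but thresholds has length ${thresholds.length}")
  model.setThresholds(thresholds)
}

// Usage: setThresholdsChecked(model, Array(0.5, 0.1)) would throw immediately
// instead of deferring the failure to model.transform(data).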
Issue Links
- is duplicated by SPARK-16863 ProbabilisticClassifier.fit check threshoulds' length (Resolved)