Spark / SPARK-22277

Chi Square selector garbling Vector content.

Details

    • Type: Bug
    • Status: Resolved
    • Priority: Major
    • Resolution: Incomplete
    • Affects Version/s: 2.1.1
    • Fix Version/s: None
    • Component/s: MLlib

    Description

      DecisionTreeClassifier behaves differently depending on whether it is trained on the output of ChiSqSelector or on the features directly.
      In the code below, I use ChiSqSelector as a pass-through, but the decision tree classifier is unable to process its output. It can, however, process the same features when they are used directly.

      The example is taken directly from the Apache Spark Python documentation.

      Kindly help.

      import sys

      from pyspark.sql import SparkSession
      from pyspark.ml.feature import ChiSqSelector
      from pyspark.ml.linalg import Vectors
      from pyspark.ml.classification import DecisionTreeClassifier
      from pyspark.ml.evaluation import MulticlassClassificationEvaluator

      # `spark` is predefined in the pyspark shell; getOrCreate() reuses it if present.
      spark = SparkSession.builder.getOrCreate()

      df = spark.createDataFrame([
          (7, Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0),
          (8, Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0),
          (9, Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0)],
          ["id", "features", "clicked"])

      # ChiSqSelector will just be a pass-through: all four input features
      # also appear in the output.
      selector = ChiSqSelector(numTopFeatures=4, featuresCol="features",
                               outputCol="selectedFeatures", labelCol="clicked")
      result = selector.fit(df).transform(df)
      print("ChiSqSelector output with top %d features selected"
            % selector.getNumTopFeatures())

      try:
          # Fails: training on the selector's output column
          dt = DecisionTreeClassifier(labelCol="clicked", featuresCol="selectedFeatures")
          model = dt.fit(result)
      except Exception:
          print(sys.exc_info())
          # Works: training on the original features column
          dt = DecisionTreeClassifier(labelCol="clicked", featuresCol="features")
          model = dt.fit(df)

      # Make predictions on the same dataset (no train/test split).
      predictions = model.transform(result)

      # Select example rows to display.
      predictions.select("prediction", "clicked", "features").show(5)

      # Select (prediction, true label) and compute the test error.
      evaluator = MulticlassClassificationEvaluator(
          labelCol="clicked", predictionCol="prediction", metricName="accuracy")
      accuracy = evaluator.evaluate(predictions)
      print("Test Error = %g " % (1.0 - accuracy))
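
The pass-through claim in the code comment can be sanity-checked with a simplified, pure-Python sketch of top-k feature selection. It assumes, as with ChiSqSelector's numTopFeatures selection, that features are ranked by chi-squared test p-value and that the kept indices are returned in ascending order; the helper name `select_top_features` and the p-values are hypothetical, not computed from this data. When k equals the vector length, every index is kept, so the selected vector should equal the input vector.

```python
def select_top_features(vector, p_values, k):
    """Keep the k features with the smallest p-values, in original index order.

    Hypothetical, simplified sketch of top-k selection; not Spark's actual code.
    """
    if len(vector) != len(p_values):
        raise ValueError("vector and p_values must have the same length")
    # Rank feature indices by p-value (smaller = stronger association with the label).
    ranked = sorted(range(len(p_values)), key=lambda i: p_values[i])
    # Keep the top k, then restore ascending index order so positions are preserved.
    kept = sorted(ranked[:k])
    return [vector[i] for i in kept]


# Hypothetical p-values for the four features of the first row.
row = [0.0, 0.0, 18.0, 1.0]
p_values = [0.5, 0.5, 0.1, 0.5]

print(select_top_features(row, p_values, 4))  # [0.0, 0.0, 18.0, 1.0]
print(select_top_features(row, p_values, 1))  # [18.0]
```

With k = 4 the vector content is unchanged, which is consistent with treating the selector as a pass-through in this example.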
      

      Output:

      ChiSqSelector output with top 4 features selected
      (<class 'pyspark.sql.utils.IllegalArgumentException'>, IllegalArgumentException('Feature 0 is marked as Nominal (categorical), but it does not have the number of values specified.', 'org.apache.spark.ml.util.MetadataUtils$$anonfun$getCategoricalFeatures$1.apply(MetadataUtils.scala:69)\n\t at org.apache.spark.ml.util.MetadataUtils$$anonfun$getCategoricalFeatures$1.apply(MetadataUtils.scala:59)\n\t at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241)\n\t at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241)\n\t at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)\n\t at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)\n\t at scala.collection.TraversableLike$class.flatMap(TraversableLike.scala:241)\n\t at scala.collection.mutable.ArrayOps$ofRef.flatMap(ArrayOps.scala:186)\n\t at org.apache.spark.ml.util.MetadataUtils$.getCategoricalFeatures(MetadataUtils.scala:59)\n\t at org.apache.spark.ml.classification.DecisionTreeClassifier.train(DecisionTreeClassifier.scala:101)\n\t at org.apache.spark.ml.classification.DecisionTreeClassifier.train(DecisionTreeClassifier.scala:45)\n\t at org.apache.spark.ml.Predictor.fit(Predictor.scala:96)\n\t at org.apache.spark.ml.Predictor.fit(Predictor.scala:72)\n\t at sun.reflect.GeneratedMethodAccessor280.invoke(Unknown Source)\n\t at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\n\t at java.lang.reflect.Method.invoke(Method.java:498)\n\t at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)\n\t at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)\n\t at py4j.Gateway.invoke(Gateway.java:280)\n\t at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)\n\t at py4j.commands.CallCommand.execute(CallCommand.java:79)\n\t at py4j.GatewayConnection.run(GatewayConnection.java:214)\n\t at java.lang.Thread.run(Thread.java:745)'), <traceback object at 0x0A87D878>)
      +----------+-------+------------------+
      |prediction|clicked|          features|
      +----------+-------+------------------+
      |       1.0|    1.0|[0.0,0.0,18.0,1.0]|
      |       0.0|    0.0|[0.0,1.0,12.0,0.0]|
      |       0.0|    0.0|[1.0,0.0,15.0,0.1]|
      +----------+-------+------------------+

      Test Error = 0
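
According to the stack trace, the exception is raised in MetadataUtils.getCategoricalFeatures, i.e. while reading the ML attribute metadata of the feature column, where feature 0 is marked as nominal (categorical) without a number of values. A rough, pure-Python sketch of that check follows. It assumes the metadata uses Spark ML's attribute-group layout (`ml_attr` → `attrs` → `nominal`, with `idx`, `vals` and `num_vals` keys) and that it can be read in PySpark as `result.schema["selectedFeatures"].metadata`; the helper name and the sample metadata are hypothetical, not taken from this report.

```python
def nominal_features_missing_num_values(metadata):
    """Return indices of features marked nominal that have neither 'num_vals' nor 'vals'.

    Hypothetical sketch of the validation behind the error above; assumes
    Spark ML's attribute-group metadata layout.
    """
    attrs = metadata.get("ml_attr", {}).get("attrs", {})
    missing = []
    for attr in attrs.get("nominal", []):
        # Tree training needs the number of categories of every nominal feature.
        if "num_vals" not in attr and "vals" not in attr:
            missing.append(attr.get("idx"))
    return missing


# Hypothetical metadata: four nominal attributes with no category counts,
# a layout that would trigger "Feature 0 is marked as Nominal ...".
sample = {"ml_attr": {"attrs": {"nominal": [{"idx": i} for i in range(4)]},
                      "num_attrs": 4}}
print(nominal_features_missing_num_values(sample))  # [0, 1, 2, 3]
```

Calling it on `result.schema["features"].metadata` and on `result.schema["selectedFeatures"].metadata` would show whether the difference between the two columns lies in their metadata rather than in their vector values.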

      People

        Assignee: Unassigned
        Reporter: Cheburakshu
        Votes: 0
        Watchers: 5