Spark / SPARK-16831

PySpark CrossValidator reports incorrect avgMetrics


Details

    • Type: Bug
    • Status: Resolved
    • Priority: Major
    • Resolution: Fixed
    • Affects Version/s: 2.0.0
    • Fix Version/s: 2.0.1, 2.1.0
    • Component/s: ML, PySpark
    • Labels: None

    Description

      CrossValidator.avgMetrics holds the sum of each param map's metric across all folds instead of the average. The fix is a one-line change in the CrossValidator._fit() method:

      metrics[j] += metric

      should be

      metrics[j] += metric / nFolds
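      The arithmetic behind the bug can be seen without Spark. The sketch below is a minimal pure-Python model of the fold loop, assuming the per-fold metrics are already computed; `cross_validate`, `fold_metrics` and the `buggy` flag are hypothetical names for illustration, not PySpark internals. Accumulating `metric` gives a value `nFolds` times the mean, while accumulating `metric / nFolds` gives the mean itself.

```python
from typing import List, Sequence


def cross_validate(fold_metrics: Sequence[Sequence[float]], buggy: bool = False) -> List[float]:
    """Aggregate per-fold metrics into one value per param map.

    Hypothetical helper: fold_metrics[i][j] is the evaluator's metric for
    param map j on fold i (assumed precomputed, not how PySpark stores it).
    """
    n_folds = len(fold_metrics)
    num_models = len(fold_metrics[0])
    metrics = [0.0] * num_models
    for i in range(n_folds):
        for j in range(num_models):
            metric = fold_metrics[i][j]
            if buggy:
                # Affected behaviour: running sum over folds.
                metrics[j] += metric
            else:
                # Fixed behaviour: running mean over folds.
                metrics[j] += metric / n_folds
    return metrics


# Four folds, two param maps.
fm = [[1.0, 0.5], [2.0, 0.5], [3.0, 0.5], [2.0, 0.5]]
print(cross_validate(fm, buggy=True))  # [8.0, 2.0]
print(cross_validate(fm))              # [2.0, 0.5]
```

      With four folds the summed values are exactly four times the means, which is the scaling the report describes.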

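      from pyspark.sql import SparkSession
      from pyspark.ml.linalg import Vectors
      import pyspark.ml.evaluation
      import pyspark.ml.regression
      import pyspark.ml.tuning

      spark = SparkSession.builder.getOrCreate()

      # 5,000 rows: the same five points repeated 1,000 times.
      dataset = spark.createDataFrame(
        [(Vectors.dense([0.0]), 0.0),
         (Vectors.dense([0.4]), 1.0),
         (Vectors.dense([0.5]), 0.0),
         (Vectors.dense([0.6]), 1.0),
         (Vectors.dense([1.0]), 1.0)] * 1000,
        ["features", "label"]).cache()

      # An empty grid yields a single param map, so one model is evaluated per split.
      paramGrid = pyspark.ml.tuning.ParamGridBuilder().build()

      # Baseline: one RMSE (RegressionEvaluator's default metric) on an 80/20 split.
      tvs = pyspark.ml.tuning.TrainValidationSplit(
          estimator=pyspark.ml.regression.LinearRegression(),
          estimatorParamMaps=paramGrid,
          evaluator=pyspark.ml.evaluation.RegressionEvaluator(),
          trainRatio=0.8)
      model = tvs.fit(dataset)
      print(model.validationMetrics)

      # Each avgMetrics value should be close to the RMSE above; with the bug
      # it is roughly numFolds times larger.
      for folds in (3, 5, 10):
        cv = pyspark.ml.tuning.CrossValidator(
            estimator=pyspark.ml.regression.LinearRegression(),
            estimatorParamMaps=paramGrid,
            evaluator=pyspark.ml.evaluation.RegressionEvaluator(),
            numFolds=folds)
        cvModel = cv.fit(dataset)
        print(folds, cvModel.avgMetrics)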
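      On the affected version, every entry of avgMetrics is the fold sum, i.e. the mean multiplied by the same positive factor numFolds. Model selection (which model becomes bestModel) is therefore unchanged; only the reported numbers are off. A minimal workaround sketch, assuming you still know the numFolds value you passed in, is to divide it back out; `corrected_avg_metrics` is a hypothetical helper, not part of PySpark.

```python
from typing import List, Sequence


def corrected_avg_metrics(avg_metrics: Sequence[float], num_folds: int) -> List[float]:
    """Hypothetical helper: turn summed fold metrics (affected version) into means."""
    if num_folds <= 0:
        raise ValueError("num_folds must be positive")
    # Sum over folds divided by the fold count is the per-fold mean.
    return [m / num_folds for m in avg_metrics]


# Values as the affected version would report them for 4 folds.
print(corrected_avg_metrics([8.0, 2.0], 4))  # [2.0, 0.5]
```

      Inside the loop above this would be called as `corrected_avg_metrics(cvModel.avgMetrics, folds)`. It should only be applied on versions that still sum, since on 2.0.1 and later avgMetrics is already averaged.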
      


    People

      Assignee: Max Moroz (mmoroz)
      Reporter: Max Moroz (mmoroz)
      Votes: 0
      Watchers: 2
