SPARK-16831: PySpark CrossValidator reports incorrect avgMetrics

    Details

    • Type: Bug
    • Status: Resolved
    • Priority: Major
    • Resolution: Fixed
    • Affects Version/s: 2.0.0
    • Fix Version/s: 2.0.1, 2.1.0
    • Component/s: ML, PySpark
    • Labels:
      None

    Description

      CrossValidator sums each parameter map's evaluation metric across all folds instead of averaging it, so the reported avgMetrics grow with numFolds. The fix is a one-line change in CrossValidator._fit():

      metrics[j] += metric

      should be

      metrics[j] += metric / nFolds
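To show why the one-line change matters, here is a toy, Spark-free sketch of a k-fold loop in plain Python. The function `k_fold_metrics`, its `evaluate` callback, and the random fold assignment are hypothetical stand-ins for illustration, not PySpark's actual code; only the `metrics[j]` update mirrors the lines above.

```python
import random


def k_fold_metrics(rows, num_models, num_folds, evaluate, average=True, seed=0):
    """Toy k-fold cross-validation loop (hypothetical, not PySpark code).

    evaluate(j, train, validation) stands in for fitting the estimator
    with parameter map j on `train` and scoring it on `validation`.
    Returns one metric per parameter map.
    """
    rng = random.Random(seed)
    # Illustrative assumption: each row is assigned to a fold by a
    # uniform random draw; min() guards the upper edge.
    tagged = [(min(int(rng.random() * num_folds), num_folds - 1), row)
              for row in rows]
    metrics = [0.0] * num_models
    for i in range(num_folds):
        validation = [row for fold, row in tagged if fold == i]
        train = [row for fold, row in tagged if fold != i]
        for j in range(num_models):
            metric = evaluate(j, train, validation)
            if average:
                metrics[j] += metric / num_folds  # fixed: mean over folds
            else:
                metrics[j] += metric              # buggy: sum over folds
    return metrics


def constant_metric(j, train, validation):
    # A constant score makes the sum-vs-mean difference easy to see.
    return 1.5


print(k_fold_metrics(range(100), 1, 4, constant_metric, average=False))  # [6.0]
print(k_fold_metrics(range(100), 1, 4, constant_metric))                 # [1.5]
```

With the buggy update, the reported value is roughly numFolds times the per-fold metric whenever the folds score similarly, which is the pattern the reproduction below looks for with numFolds of 3, 5 and 10.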

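      from pyspark.sql import SparkSession
      from pyspark.ml.linalg import Vectors
      from pyspark.ml.regression import LinearRegression
      from pyspark.ml.evaluation import RegressionEvaluator
      from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit, CrossValidator

      spark = SparkSession.builder.getOrCreate()

      # 5,000 rows: the same five points repeated 1,000 times
      dataset = spark.createDataFrame(
        [(Vectors.dense([0.0]), 0.0),
         (Vectors.dense([0.4]), 1.0),
         (Vectors.dense([0.5]), 0.0),
         (Vectors.dense([0.6]), 1.0),
         (Vectors.dense([1.0]), 1.0)] * 1000,
        ["features", "label"]).cache()

      # Empty grid: a single default parameter map
      paramGrid = ParamGridBuilder().build()

      # Baseline: TrainValidationSplit reports one metric from a single 80/20 split
      tvs = TrainValidationSplit(estimator=LinearRegression(),
                                 estimatorParamMaps=paramGrid,
                                 evaluator=RegressionEvaluator(),
                                 trainRatio=0.8)
      model = tvs.fit(dataset)
      print(model.validationMetrics)

      # With the bug, avgMetrics grows with numFolds instead of staying near the baseline
      for folds in (3, 5, 10):
        cv = CrossValidator(estimator=LinearRegression(),
                            estimatorParamMaps=paramGrid,
                            evaluator=RegressionEvaluator(),
                            numFolds=folds)
        cvModel = cv.fit(dataset)
        print(folds, cvModel.avgMetrics)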
      

    People

    • Assignee: Max Moroz (mmoroz)
    • Reporter: Max Moroz (mmoroz)
    • Votes: 0
    • Watchers: 2