  1. Spark
  2. SPARK-29818 Missing persist on RDD
  3. SPARK-29824

Missing persist on trainDataset in ml.classification.GBTClassifier.train()

Details

    • Type: Sub-task
    • Status: Resolved
    • Priority: Major
    • Resolution: Duplicate
    • Affects Version/s: 2.4.3
    • Fix Version/s: None
    • Component/s: ML
    • Labels: None

    Description

      The RDD trainDataset in ml.classification.GBTClassifier.train() is first used by an action (trainDataset.first()) and then by further actions inside GradientBoostedTrees.run/runWithValidation. Because it is never persisted, each of those actions recomputes trainDataset from its lineage.

        override protected def train(
            dataset: Dataset[_]): GBTClassificationModel = instrumented { instr =>
          val categoricalFeatures: Map[Int, Int] =
            MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
          ...
          val numFeatures = trainDataset.first().features.size // first use trainDataset
          ...
          // trainDataset will be used by other actions in run methods.    
          val (baseLearners, learnerWeights) = if (withValidation) {
            GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy,
              $(seed), $(featureSubsetStrategy))
          } else {
            GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy))
          }
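
One way to avoid the recomputation is to persist the RDD around all of its uses and release it afterwards. The following is a minimal sketch of that pattern, not the change that was applied in Spark: `PersistGuard` and `withPersisted` are hypothetical names introduced here, and the storage level is an assumption. It only persists when the caller has not already chosen a storage level.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel

// Hypothetical helper, not part of Spark: persist an RDD for the duration
// of `body`, but only if the caller has not persisted it already.
object PersistGuard {
  def withPersisted[T, R](rdd: RDD[T])(body: RDD[T] => R): R = {
    val handlePersistence = rdd.getStorageLevel == StorageLevel.NONE
    // MEMORY_AND_DISK is an assumed choice for this sketch.
    if (handlePersistence) rdd.persist(StorageLevel.MEMORY_AND_DISK)
    try body(rdd)
    finally if (handlePersistence) rdd.unpersist()
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("persist-guard")
      .getOrCreate()
    try {
      val data = spark.sparkContext.parallelize(Seq(1.0, 2.0, 3.0))
      // Two actions on the same RDD, mirroring first() followed by the run methods.
      val (head, total) = withPersisted(data) { cached =>
        (cached.first(), cached.count())
      }
      println(s"first=$head count=$total")
    } finally spark.stop()
  }
}
```

In train(), the equivalent would be to wrap the first() call and the run/runWithValidation call in one such scope, so that the RDD is released once the base learners have been built.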
      

      This issue was reported by our tool CacheCheck, which dynamically detects misuses of the persist()/unpersist() API.
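
The recomputation described above can be observed directly. The sketch below is an illustration only and says nothing about how CacheCheck works internally: `PersistSketch`, `evaluations`, and the accumulator name `mapCalls` are hypothetical. It counts how often a map function runs when two actions are executed on the same RDD, with and without persist().

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel

// Hypothetical demo, not part of Spark or CacheCheck.
object PersistSketch {
  // Runs two actions over a mapped RDD and returns how many times
  // the map function was invoked in total.
  def evaluations(spark: SparkSession, persist: Boolean, n: Int = 1000): Long = {
    val sc = spark.sparkContext
    val calls = sc.longAccumulator("mapCalls")
    val rdd = sc.parallelize(1 to n, 4).map { x => calls.add(1); x * 2 }
    if (persist) rdd.persist(StorageLevel.MEMORY_AND_DISK)
    try {
      rdd.count()        // first action
      rdd.reduce(_ + _)  // second action on the same RDD
    } finally {
      if (persist) rdd.unpersist()
    }
    calls.value.longValue()
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("persist-sketch")
      .getOrCreate()
    try {
      val uncached = evaluations(spark, persist = false)
      val cached = evaluations(spark, persist = true)
      println(s"map calls without persist: $uncached, with persist: $cached")
    } finally spark.stop()
  }
}
```

Without persist(), the second action replays the lineage and invokes the map function again, so the count is higher than in the persisted run. Accumulator updates made inside transformations can be applied more than once if tasks are retried, so the numbers are indicative rather than exact outside a failure-free local run.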

            People

              Assignee: Unassigned
              Reporter: IcySanwitch (spark_cachecheck)