Details
Sub-task
Status: Resolved
Major
Resolution: Duplicate
2.4.3
None
None
Description
The RDD oldDataset in ml.regression.RandomForestRegressor.train() needs to be persisted because it is used by two actions: RandomForest.run() and oldDataset.first().
override protected def train(
    dataset: Dataset[_]): RandomForestRegressionModel = instrumented { instr =>
  val categoricalFeatures: Map[Int, Int] =
    MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
  val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) // Needs to persist
  val strategy =
    super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)

  instr.logPipelineStage(this)
  instr.logDataset(dataset)
  instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, numTrees,
    featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
    minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, checkpointInterval)

  // First use oldDataset
  val trees = RandomForest
    .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
    .map(_.asInstanceOf[DecisionTreeRegressionModel])

  // Second use oldDataset
  val numFeatures = oldDataset.first().features.size
  instr.logNamedValue(Instrumentation.loggerTags.numFeatures, numFeatures)
  new RandomForestRegressionModel(uid, trees, numFeatures)
}
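The remedy the description asks for is to persist oldDataset before its first action and unpersist it after the second. The following is a minimal standalone sketch of that pattern, not the actual Spark patch. It assumes a local Spark 2.4.x session. It uses a hypothetical helper, countEvaluations, with a tuple RDD named points standing in for oldDataset and count() standing in for RandomForest.run(). A LongAccumulator counts how often the upstream map function runs.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

object PersistTwoActions {

  /** Hypothetical helper: runs two actions over one derived RDD and
    * returns how many times the RDD's map function was evaluated. */
  def countEvaluations(sc: SparkContext, persist: Boolean): Long = {
    // Counts evaluations of the upstream transformation across all tasks.
    val evaluations = sc.longAccumulator("evaluations")

    // Stand-in for extractLabeledPoints(dataset): (label, features) pairs.
    val points: RDD[(Double, Array[Double])] = sc.parallelize(1 to 100, 4).map { i =>
      evaluations.add(1L)
      (i.toDouble, Array(i.toDouble, 2.0 * i))
    }

    if (persist) points.persist(StorageLevel.MEMORY_AND_DISK)
    try {
      // First action, standing in for RandomForest.run(oldDataset, ...).
      val numRows = points.count()
      // Second action, mirroring oldDataset.first().features.size.
      val numFeatures = points.first()._2.length
      require(numRows == 100L && numFeatures == 2)
      evaluations.sum
    } finally {
      // Release the cached blocks once both actions have completed.
      if (persist) points.unpersist()
    }
  }

  def main(args: Array[String]): Unit = {
    val sc = SparkContext.getOrCreate(
      new SparkConf().setMaster("local[1]").setAppName("persist-two-actions"))
    try {
      println(s"with persist:    ${countEvaluations(sc, persist = true)}")
      println(s"without persist: ${countEvaluations(sc, persist = false)}")
    } finally {
      sc.stop()
    }
  }
}
```

With persist = true, the map runs once per element during the first action, and first() is served from the cache. Without it, first() re-evaluates part of the lineage. Inside train() itself, the analogous change would wrap both uses of oldDataset in the same persist / try / finally / unpersist structure. Other Spark ML estimators such as LogisticRegression guard this with a check like dataset.storageLevel == StorageLevel.NONE, so that input the caller has already cached is not persisted twice.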
The same situation exists in ml.classification.RandomForestClassifier.train().
This issue was reported by our tool CacheCheck, which dynamically detects misuse of the persist()/unpersist() APIs.
Issue Links
- duplicates: SPARK-29818 Missing persist on RDD (Resolved)