Description
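The `isSet()` state of the `outputCol` param (from `HasOutputCol`) is not preserved across a save/load round trip with `DefaultParamsWritable` / `DefaultParamsReadable`. Before saving, `outputCol` has only its default value, so `isDefined(outputCol)` is `True` and `isSet(outputCol)` is `False`. After loading, `isSet(outputCol)` reports `True`. The following test reproduces the issue; its last assertion fails with `AssertionError: assert not True`.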
import os
from typing import Optional

import pytest  # noqa: F401 -- not referenced directly; the module is run by pytest
from pyspark import keyword_only
from pyspark.ml import Model
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql import DataFrame
from pyspark.sql.functions import *  # noqa: F401,F403 -- unused by this reproduction


class HasOutputColTester(
    Model, HasInputCol, HasOutputCol, DefaultParamsReadable, DefaultParamsWritable
):
    """Minimal pass-through model exposing the ``inputCol`` and ``outputCol`` params.

    It exists only to exercise persistence of the shared column params through
    ``DefaultParamsWritable`` (``write()``) and ``DefaultParamsReadable``
    (``load()``); its transform returns the input unchanged.
    """

    @keyword_only
    def __init__(
        self, inputCol: Optional[str] = None, outputCol: Optional[str] = None
    ):
        super().__init__()
        # ``keyword_only`` exposes the caller's keyword arguments as
        # ``self._input_kwargs``; forwarding them to ``setParams`` means a
        # no-argument construction sets nothing explicitly.
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(
        self, inputCol: Optional[str] = None, outputCol: Optional[str] = None
    ) -> "HasOutputColTester":
        """Explicitly set the given column params.

        :param inputCol: name of the input column, if supplied.
        :param outputCol: name of the output column, if supplied.
        :return: ``self``, to allow call chaining.
        """
        kwargs = self._input_kwargs
        self._set(**kwargs)
        return self

    def _transform(self, data: DataFrame) -> DataFrame:
        # Identity transform: the test below only checks param persistence.
        return data


class TestHasInputColParam:
    """Round-trip checks for the ``isDefined`` / ``isSet`` state of column params."""

    def test_persist_input_col_set(self, spark, temp_dir):
        """``isDefined`` and ``isSet`` must be unchanged by ``save`` + ``load``.

        ``inputCol`` is neither defined nor set on a freshly built model.
        ``outputCol`` is defined (it has a default) but is not explicitly set.
        After a round trip through ``DefaultParamsWritable`` /
        ``DefaultParamsReadable``, the loaded model must report the same state.
        """
        # NOTE(review): ``spark`` and ``temp_dir`` are fixtures defined outside
        # this file (presumably a conftest.py). ``spark`` is not referenced
        # directly; it is presumably requested so a SparkSession exists for
        # save/load.
        # os.path.join works whether the fixture yields a str or a path object.
        path = os.path.join(temp_dir, "test_model")

        model = HasOutputColTester()
        assert not model.isDefined(model.inputCol)
        assert not model.isSet(model.inputCol)
        assert model.isDefined(model.outputCol)
        assert not model.isSet(model.outputCol)

        model.write().overwrite().save(path)

        loaded_model: HasOutputColTester = HasOutputColTester.load(path)
        # Query the loaded model through its own Param objects rather than the
        # original model's. This way the checks do not depend on Params from
        # two different instances comparing equal.
        assert not loaded_model.isDefined(loaded_model.inputCol)
        assert not loaded_model.isSet(loaded_model.inputCol)
        assert loaded_model.isDefined(loaded_model.outputCol)
        # Reported to fail here with "AssertionError: assert not True": the
        # merely-defaulted outputCol comes back marked as explicitly set.
        assert not loaded_model.isSet(loaded_model.outputCol)