Uploaded image for project: 'Spark'
  1. Spark
  2. SPARK-29414

HasOutputCol param isSet() property is not preserved after persistence

    XMLWordPrintableJSON

Details

    • Bug
    • Status: Resolved
    • Major
    • Resolution: Fixed
    • 2.3.2
    • 2.4.4
    • ML, PySpark
    • None

    Description

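      The isSet() state of the HasOutputCol param is not preserved after saving and loading a model using DefaultParamsReadable and DefaultParamsWritable. In the example below, outputCol on a freshly constructed model is defined but not explicitly set (isDefined() is True, isSet() is False); after a save/load round trip, isSet() returns True for outputCol on the loaded model.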

      import pytest
      from pyspark import keyword_only
      from pyspark.ml import Model
      from pyspark.sql import DataFrame
      from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
      from pyspark.ml.param.shared import HasInputCol, HasOutputCol
      from pyspark.sql.functions import *
      
      
      class HasOutputColTester(
          Model,
          HasInputCol,
          HasOutputCol,
          DefaultParamsReadable,
          DefaultParamsWritable,
      ):
          """Pass-through model used to exercise param persistence.

          Combines the ``HasInputCol`` / ``HasOutputCol`` param mixins with
          ``DefaultParamsReadable`` / ``DefaultParamsWritable`` so an instance
          can be written to a path and loaded back.
          """

          @keyword_only
          def __init__(self, inputCol: str = None, outputCol: str = None):
              """Initialise the model and forward the keyword arguments to ``setParams``."""
              super().__init__()
              # The keyword arguments are read back from ``_input_kwargs``
              # rather than from the named parameters.
              self.setParams(**self._input_kwargs)

          @keyword_only
          def setParams(self, inputCol: str = None, outputCol: str = None):
              """Apply the keyword arguments as params and return ``self``."""
              supplied = self._input_kwargs
              self._set(**supplied)
              return self

          def _transform(self, data: DataFrame) -> DataFrame:
              """Return ``data`` unchanged."""
              return data
      
      
      class TestHasInputColParam(object):
          """Checks that ``isDefined`` / ``isSet`` param state survives save and load."""

          def test_persist_input_col_set(self, spark, temp_dir):
              """Round-trip a default-constructed model and compare its param state.

              The model is built with no arguments, so neither ``inputCol`` nor
              ``outputCol`` is explicitly set; ``outputCol`` is nevertheless
              defined. The same four facts are asserted again on the instance
              loaded back from ``path``.

              ``spark`` and ``temp_dir`` are test fixtures supplied by the
              harness (presumably an active session and a scratch directory
              path string -- confirm against the fixture definitions).
              """
              path = temp_dir + '/test_model'
              model = HasOutputColTester()
              assert not model.isDefined(model.inputCol)
              assert not model.isSet(model.inputCol)

              assert model.isDefined(model.outputCol)
              assert not model.isSet(model.outputCol)
              model.write().overwrite().save(path)

              loaded_model: HasOutputColTester = HasOutputColTester.load(path)
              # Query the loaded instance with its *own* Param objects so the
              # checks depend only on the loaded state, not on the original
              # model's params happening to resolve against it.
              assert not loaded_model.isDefined(loaded_model.inputCol)
              assert not loaded_model.isSet(loaded_model.inputCol)

              assert loaded_model.isDefined(loaded_model.outputCol)
              assert not loaded_model.isSet(loaded_model.outputCol)  # AssertionError: assert not True
      

      Attachments

        Activity

          People

            Unassigned Unassigned
            borys.biletskyy Borys Biletskyy
            Votes:
            0 Vote for this issue
            Watchers:
            3 Start watching this issue

            Dates

              Created:
              Updated:
              Resolved: