
SPARK-30397: [pyspark] Writer applied to custom model changes the type of dict keys from int to str


Details

    • Type: Bug
    • Status: Resolved
    • Priority: Major
    • Resolution: Not A Problem
    • Affects Version/s: 2.4.4
    • Fix Version/s: None
    • Component/s: ML, PySpark
    • Labels: None

    Description

      Hello,


      I have a custom model that I'm trying to persist. Within this custom model there is a Python dict mapping from int to int. When the model is saved (with write().save('path')), the keys of the dict are changed from int to str.
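
      I suspect this comes from the way DefaultParamsWritable persists params: they are stored as JSON, and JSON object keys are always strings, so int keys cannot survive a write/load round trip. A minimal sketch of the suspected cause (plain json, no Spark involved):

      import json

      d = {10: 5, 11: 4}
      # The int keys are written out as JSON strings and read back as str
      print(json.loads(json.dumps(d)))  # prints {'10': 5, '11': 4}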


      You can find below a full script that reproduces the issue:

      #!/usr/bin/env python3
      # -*- coding: utf-8 -*-
      """
      @author: Jean-Marc Montanier
      @date: 2019/12/31
      """
      
      from pyspark.sql import SparkSession
      
      from pyspark import keyword_only
      from pyspark.ml import Pipeline, PipelineModel
      from pyspark.ml import Estimator, Model
      from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
      from pyspark.ml.param import Param, Params
      from pyspark.ml.param.shared import HasInputCol, HasOutputCol
      from pyspark.sql.types import IntegerType
      from pyspark.sql.functions import udf
      
      
      spark = SparkSession \
          .builder \
          .appName("ImputeNormal") \
          .getOrCreate()
      
      
      class CustomFit(Estimator,
                      HasInputCol,
                      HasOutputCol,
                      DefaultParamsReadable,
                      DefaultParamsWritable,
                      ):
          @keyword_only
          def __init__(self, inputCol="inputCol", outputCol="outputCol"):
              super(CustomFit, self).__init__()
      
              self._setDefault(inputCol="inputCol", outputCol="outputCol")
              kwargs = self._input_kwargs
              self.setParams(**kwargs)
      
          @keyword_only
          def setParams(self, inputCol="inputCol", outputCol="outputCol"):
              """
              setParams(self, inputCol="inputCol", outputCol="outputCol")
              """
              kwargs = self._input_kwargs
              self._set(**kwargs)
              return self
      
          def _fit(self, data):
              inputCol = self.getInputCol()
              outputCol = self.getOutputCol()
      
              categories = data.where(data[inputCol].isNotNull()) \
                  .groupby(inputCol) \
                  .count() \
                  .orderBy("count", ascending=False) \
                  .limit(2)
              categories = dict(categories.toPandas().set_index(inputCol)["count"])
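              # Cast the numpy int64 counts to plain Python ints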
              for cat in categories:
                  categories[cat] = int(categories[cat])
      
              return CustomModel(categories=categories,
                                 input_col=inputCol,
                                 output_col=outputCol)
      
      
      class CustomModel(Model,
                        DefaultParamsReadable,
                        DefaultParamsWritable):
      
          input_col = Param(Params._dummy(), "input_col", "Name of the input column")
          output_col = Param(Params._dummy(), "output_col", "Name of the output column")
          categories = Param(Params._dummy(), "categories", "Top categories")
      
          def __init__(self, categories: dict = None, input_col="input_col", output_col="output_col"):
              super(CustomModel, self).__init__()
      
              self._set(categories=categories, input_col=input_col, output_col=output_col)
      
          def get_output_col(self) -> str:
              """
              output_col getter
              :return:
              """
              return self.getOrDefault(self.output_col)
      
          def get_input_col(self) -> str:
              """
              input_col getter
              :return:
              """
              return self.getOrDefault(self.input_col)
      
          def get_categories(self):
              """
              categories getter
              :return:
              """
              return self.getOrDefault(self.categories)
      
          def _transform(self, data):
              input_col = self.get_input_col()
              output_col = self.get_output_col()
              categories = self.get_categories()
      
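              # Map a raw value to its frequency count; null and unseen values map to -1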
              def get_cat(val):
                  if val is None:
                      return -1
                  if val not in categories:
                      return -1
                  return int(categories[val])
      
              get_cat_udf = udf(get_cat, IntegerType())
      
              df = data.withColumn(output_col,
                                   get_cat_udf(input_col))
      
              return df
      
      
      def test_without_write():
          fit_df = spark.createDataFrame([[10]] * 5 + [[11]] * 4 + [[12]] * 3 + [[None]] * 2, ['input'])
          custom_fit = CustomFit(inputCol='input', outputCol='output')
          pipeline = Pipeline(stages=[custom_fit])
          pipeline_model = pipeline.fit(fit_df)
      
          print("Categories: {}".format(pipeline_model.stages[0].get_categories()))
      
          transform_df = spark.createDataFrame([[10]] * 2 + [[11]] * 2 + [[12]] * 2 + [[None]] * 2, ['input'])
          test = pipeline_model.transform(transform_df)
          test.show()  # This output is the expected output
      
      
      def test_with_write():
          fit_df = spark.createDataFrame([[10]] * 5 + [[11]] * 4 + [[12]] * 3 + [[None]] * 2, ['input'])
          custom_fit = CustomFit(inputCol='input', outputCol='output')
          pipeline = Pipeline(stages=[custom_fit])
          pipeline_model = pipeline.fit(fit_df)
      
          print("Categories: {}".format(pipeline_model.stages[0].get_categories()))
      
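          # Note: save() fails if 'tmp' already exists; use write().overwrite().save('tmp') to re-run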
          pipeline_model.write().save('tmp')
          loaded_model = PipelineModel.load('tmp')
          # We can see that the type of the keys is now str instead of int
          print("Categories: {}".format(loaded_model.stages[0].get_categories()))
      
          transform_df = spark.createDataFrame([[10]] * 2 + [[11]] * 2 + [[12]] * 2 + [[None]] * 2, ['input'])
          test = loaded_model.transform(transform_df)
          test.show()  # We can see that the output does not match the expected output
      
      
      if __name__ == "__main__":
          test_without_write()
          test_with_write()
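
      The workaround I'm using for now is to coerce the keys back to int after loading. A minimal sketch (a hypothetical replacement for the get_categories method above):

          def get_categories(self):
              """
              categories getter; keys are coerced back to int because the JSON
              round trip performed by the writer turns them into str
              """
              cats = self.getOrDefault(self.categories)
              return {int(k): v for k, v in cats.items()}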
      
      



People

    • Assignee: Unassigned
    • Reporter: Jean-Marc Montanier (montanier)
    • Votes: 0
    • Watchers: 2