Spark / SPARK-48463

MLlib function unable to handle nested data



    Description

I am trying to use a feature transformer on nested data after flattening it, but it fails.

       

import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{IntegerType, StructType}

val structureData = Seq(
  Row(Row(10, 12), 1000),
  Row(Row(12, 14), 4300),
  Row(Row(37, 891), 1400),
  Row(Row(8902, 12), 4000),
  Row(Row(12, 89), 1000)
)

val structureSchema = new StructType()
  .add("location", new StructType()
    .add("longitude", IntegerType)
    .add("latitude", IntegerType))
  .add("salary", IntegerType)

val df = spark.createDataFrame(spark.sparkContext.parallelize(structureData), structureSchema)

// Recursively flattens a nested schema, aliasing each leaf column with its
// dotted path (e.g. "location.longitude").
def flattenSchema(schema: StructType, prefix: String = null, prefixSelect: String = null): Array[Column] = {
  schema.fields.flatMap { f =>
    val colName = if (prefix == null) f.name else prefix + "." + f.name
    val colnameSelect = if (prefix == null) f.name else prefixSelect + "." + f.name

    f.dataType match {
      case st: StructType => flattenSchema(st, colName, colnameSelect)
      case _ => Array(col(colName).as(colnameSelect))
    }
  }
}

val flattenColumns = flattenSchema(df.schema)
val flattenedDf = df.select(flattenColumns: _*)
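For reference, the flattened DataFrame ends up with three top-level columns whose names contain a literal dot: location.longitude, location.latitude and salary (the same names the AnalysisException below lists). A quick way to confirm this:

// Inspect the flattened schema: the dots are now part of the column names,
// they no longer denote struct field access.
flattenedDf.printSchema()
flattenedDf.columns.foreach(println)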

Now I use the StringIndexer on the dot-notation column name.

       

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.StringIndexer

val si = new StringIndexer().setInputCol("location.longitude").setOutputCol("longitutdee")
val pipeline = new Pipeline().setStages(Array(si))
pipeline.fit(flattenedDf).transform(flattenedDf).show()

The above code fails with:

Exception in thread "main" org.apache.spark.sql.AnalysisException: Cannot resolve column name "location.longitude" among (location.longitude, location.latitude, salary); did you mean to quote the `location.longitude` column?
          at org.apache.spark.sql.errors.QueryCompilationErrors$.cannotResolveColumnNameAmongFieldsError(QueryCompilationErrors.scala:2261)
          at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$resolveException(Dataset.scala:258)
          at org.apache.spark.sql.Dataset.$anonfun$resolve$1(Dataset.scala:250)
      ..... 

This points to the same failure as when we try to select dot-notation columns in a Spark DataFrame, which is solved by wrapping the column name in backticks (`column.name`):

      https://stackoverflow.com/a/51430335/11688337
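To illustrate at the plain DataFrame level (a quick check using the flattenedDf from above):

// With backticks, Spark treats the whole string as a single literal column name,
// so the select resolves:
flattenedDf.select("`location.longitude`").show()

// Without backticks the dot is parsed as struct-field access and the select
// fails with the AnalysisException shown above:
// flattenedDf.select("location.longitude").show()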

       

So next, I use backticks when defining the StringIndexer:

      val si = new StringIndexer().setInputCol("`location.longitude`").setOutputCol("longitutdee") 

In this case it fails again (with a different error), this time inside the StringIndexer code itself:

      Exception in thread "main" org.apache.spark.SparkException: Input column `location.longitude` does not exist.
          at org.apache.spark.ml.feature.StringIndexerBase.$anonfun$validateAndTransformSchema$2(StringIndexer.scala:128)
          at scala.collection.TraversableLike.$anonfun$flatMap$1(TraversableLike.scala:244)
          at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
          at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33) 

       

This blocks me from using feature transformation functions on nested (flattened) columns.
Any help in solving this problem would be highly appreciated.
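In the meantime, a possible user-level workaround (only a sketch, assuming it is acceptable to rename the flattened columns; dotFreeDf, si2 and the underscore-based names are mine, not part of any Spark API) is to strip the dots from the column names before running the transformer:

// Workaround sketch: give every flattened column a dot-free name, so the
// StringIndexer can resolve it without any quoting.
val dotFreeDf = flattenedDf.toDF(flattenedDf.columns.map(_.replace(".", "_")): _*)

val si2 = new StringIndexer()
  .setInputCol("location_longitude")
  .setOutputCol("longitude_indexed")

new Pipeline().setStages(Array(si2)).fit(dotFreeDf).transform(dotFreeDf).show()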


People

    Assignee: Weichen Xu (weichenxu123)
    Reporter: Chhavi Bansal (chhavibansal)