Description
Doing a pivot (using the pivot(pivotColumn: Column) overload) on a column containing arrays results in a runtime error:
scala> val df = Seq((1, Seq("a", "x"), 2), (1, Seq("b"), 3), (2, Seq("a", "x"), 10), (3, Seq(), 100)).toDF("x", "s", "y")
df: org.apache.spark.sql.DataFrame = [x: int, s: array<string> ... 1 more field]

scala> df.show
+---+------+---+
|  x|     s|  y|
+---+------+---+
|  1|[a, x]|  2|
|  1|   [b]|  3|
|  2|[a, x]| 10|
|  3|    []|100|
+---+------+---+

scala> df.groupBy("x").pivot("s").agg(collect_list($"y")).show
java.lang.RuntimeException: Unsupported literal type class scala.collection.mutable.WrappedArray$ofRef WrappedArray()
  at org.apache.spark.sql.catalyst.expressions.Literal$.apply(literals.scala:78)
  at org.apache.spark.sql.RelationalGroupedDataset$$anonfun$pivot$1.apply(RelationalGroupedDataset.scala:419)
  at org.apache.spark.sql.RelationalGroupedDataset$$anonfun$pivot$1.apply(RelationalGroupedDataset.scala:419)
  at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
  at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
  at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
  at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:35)
  at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
  at scala.collection.AbstractTraversable.map(Traversable.scala:104)
  at org.apache.spark.sql.RelationalGroupedDataset.pivot(RelationalGroupedDataset.scala:419)
  at org.apache.spark.sql.RelationalGroupedDataset.pivot(RelationalGroupedDataset.scala:397)
  at org.apache.spark.sql.RelationalGroupedDataset.pivot(RelationalGroupedDataset.scala:317)
  ... 49 elided
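Judging from the stack trace, the single-argument overload collects the distinct pivot values and passes each one through Literal.apply, which (in the build above) has a case for plain Arrays but not for the WrappedArray that Spark hands back for an array column. A minimal illustration of that apparent asymmetry, assuming the same Spark build:

import org.apache.spark.sql.catalyst.expressions.Literal

// Accepted: Literal.apply has a case for plain arrays such as Array[String].
Literal(Array("a", "x"))

// Rejected: a Seq arrives as scala.collection.mutable.WrappedArray and falls
// through to "Unsupported literal type" -- the same failure as in the pivot above.
Literal(Seq("a", "x"))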
However, this doesn't seem to be a fundamental limitation of pivot: it works fine with the pivot(pivotColumn: Column, values: Seq[Any]) overload, as long as the collected arrays are first converted to plain Array values:
scala> val rawValues = df.select("s").distinct.sort("s").collect
rawValues: Array[org.apache.spark.sql.Row] = Array([WrappedArray()], [WrappedArray(a, x)], [WrappedArray(b)])

scala> val values = rawValues.map(_.getSeq[String](0).to[Array])
values: Array[Array[String]] = Array(Array(), Array(a, x), Array(b))

scala> df.groupBy("x").pivot("s", values).agg(collect_list($"y")).show
+---+-----+------+---+
|  x|   []|[a, x]|[b]|
+---+-----+------+---+
|  1|   []|   [2]|[3]|
|  3|[100]|    []| []|
|  2|   []|  [10]| []|
+---+-----+------+---+
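For repeated use, the workaround generalizes to a small helper. The name pivotOnArrayColumn is mine, not a Spark API, and it assumes the array elements are strings, as in the example above:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.collect_list

// Hypothetical helper: collect the distinct values of the array column up
// front, convert each one to a plain Array, and feed them to the two-argument
// pivot overload, sidestepping the one-argument failure.
def pivotOnArrayColumn(df: DataFrame, groupCol: String, pivotCol: String,
                       aggCol: String): DataFrame = {
  val values: Seq[Any] = df.select(pivotCol).distinct.sort(pivotCol).collect
    .map(_.getSeq[String](0).toArray)
    .toSeq
  df.groupBy(groupCol).pivot(pivotCol, values).agg(collect_list(df(aggCol)))
}

With the DataFrame above, pivotOnArrayColumn(df, "x", "s", "y") should produce the same table as the explicit-values call.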
It would be nice if pivot were more resilient to Spark's own representation of array columns, so that the first version worked as well.
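One possible direction (a sketch only, not a proposed patch, and untested for element types other than the String case above) would be to unwrap WrappedArray values to their backing arrays before literal conversion, since the workaround shows the Array path already works:

import scala.collection.mutable.WrappedArray

// Sketch only: normalize a collected pivot value so the existing Array case
// can apply; `unwrap` is a hypothetical helper, not Spark code.
def unwrap(value: Any): Any = value match {
  case a: WrappedArray[_] => a.array // the backing Array, e.g. Array[String]
  case other              => other
}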