Details
Type: Bug
Status: Resolved
Priority: Major
Resolution: Incomplete
Affects Version/s: 1.6.2, 1.6.3
Fix Version/s: None
Description
I am trying to use a custom HashMap implementation as a UserDefinedType instead of MapType in Spark. The code works fine in Spark 1.5.2, but in Spark 1.6.2 it fails with java.lang.ClassCastException: scala.collection.immutable.HashMap$HashMap1 cannot be cast to org.apache.spark.sql.catalyst.util.MapData.
The code:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

import scala.collection.immutable.HashMap

class Test extends UserDefinedAggregateFunction {

  def inputSchema: StructType = StructType(Array(StructField("input", StringType)))

  def bufferSchema = StructType(Array(StructField("top_n", CustomHashMapType)))

  def dataType: DataType = CustomHashMapType

  def deterministic = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = HashMap.empty[String, Long]
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val buff0 = buffer.getAs[HashMap[String, Long]](0)
    buffer(0) = buff0.updated("test", buff0.getOrElse("test", 0L) + 1L)
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1
      .getAs[HashMap[String, Long]](0)
      .merged(buffer2.getAs[HashMap[String, Long]](0))({ case ((k, v1), (_, v2)) => (k, v1 + v2) })
  }

  def evaluate(buffer: Row): Any = {
    buffer(0)
  }
}

private case object CustomHashMapType extends UserDefinedType[HashMap[String, Long]] {

  override def sqlType: DataType = MapType(StringType, LongType)

  override def serialize(obj: Any): Map[String, Long] =
    obj.asInstanceOf[Map[String, Long]]

  override def deserialize(datum: Any): HashMap[String, Long] = {
    datum.asInstanceOf[Map[String, Long]] ++: HashMap.empty[String, Long]
  }

  override def userClass: Class[HashMap[String, Long]] = classOf[HashMap[String, Long]]
}
The wrapper class to run the UDAF:
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

object TestJob {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[4]").setAppName("DataStatsExecution")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df = sc.parallelize(Seq(1, 2, 3, 4)).toDF("col")
    val udaf = new Test()
    val outdf = df.agg(udaf(df("col")))
    outdf.show
  }
}
Stack trace:
Caused by: java.lang.ClassCastException: scala.collection.immutable.HashMap$HashMap1 cannot be cast to org.apache.spark.sql.catalyst.util.MapData
    at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow$class.getMap(rows.scala:50)
    at org.apache.spark.sql.catalyst.expressions.GenericMutableRow.getMap(rows.scala:248)
    at org.apache.spark.sql.catalyst.expressions.JoinedRow.getMap(JoinedRow.scala:115)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificMutableProjection.apply(Unknown Source)
    at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$31.apply(AggregationIterator.scala:345)
    at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$31.apply(AggregationIterator.scala:344)
    at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:154)
    at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:29)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
    at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:149)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:73)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:41)
    at org.apache.spark.scheduler.Task.run(Task.scala:89)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:227)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)
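For context, the generated projection in the trace calls getMap, which expects Catalyst's internal org.apache.spark.sql.catalyst.util.MapData, while the UDT above hands back a plain Scala HashMap from serialize. Below is a minimal sketch (not part of the original report) of a UDT variant whose serialize/deserialize work in terms of MapData; the name CatalystBackedHashMapType is made up for illustration, and whether this actually avoids the ClassCastException on 1.6.x is an assumption, not a verified workaround.

import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import scala.collection.immutable.HashMap

// Hypothetical UDT variant that hands Catalyst a MapData value instead of a Scala HashMap.
// Assumes Spark 1.6.x expects the sqlType's internal representation (MapData with
// UTF8String keys) from serialize; this is a sketch, not a confirmed fix for this issue.
private case object CatalystBackedHashMapType extends UserDefinedType[HashMap[String, Long]] {

  override def sqlType: DataType = MapType(StringType, LongType)

  // Convert the user-level HashMap into Catalyst's ArrayBasedMapData.
  override def serialize(obj: Any): MapData = {
    val m = obj.asInstanceOf[Map[String, Long]]
    val keys: Array[Any] = m.keys.map(k => UTF8String.fromString(k): Any).toArray
    val values: Array[Any] = m.values.map(v => v: Any).toArray
    new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
  }

  // Convert Catalyst's MapData back into the user-level HashMap.
  override def deserialize(datum: Any): HashMap[String, Long] = {
    val map = datum.asInstanceOf[MapData]
    val keys = map.keyArray().toObjectArray(StringType).map(_.asInstanceOf[UTF8String].toString)
    val values = map.valueArray().toObjectArray(LongType).map(_.asInstanceOf[Long])
    HashMap(keys.zip(values): _*)
  }

  override def userClass: Class[HashMap[String, Long]] = classOf[HashMap[String, Long]]
}

Even with a Catalyst-side representation like this, the UDAF's getAs[HashMap[String, Long]] calls on the aggregation buffer would still have to match whatever value Spark actually stores in the buffer row, so the sketch only illustrates the representation mismatch that the exception points at.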