Description
Attempting to cogroup two datasets derived from the same source dataset yields an IllegalStateException when the query is executed.
Minimal reproducer:
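```java
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.api.java.function.CoGroupFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.KeyValueGroupedDataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.encoders.RowEncoder;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class CoGroupSameSourceRepro {
    public static void main(String[] args) {
        // Assumption: a local session; any active SparkSession reproduces the issue.
        SparkSession sparkSession = SparkSession.builder().master("local[*]").getOrCreate();

        // Schema of the source rows and of the grouping key.
        StructType inputType = DataTypes.createStructType(new StructField[]{
            DataTypes.createStructField("id", DataTypes.LongType, false),
            DataTypes.createStructField("type", DataTypes.StringType, false)
        });
        StructType keyType = DataTypes.createStructType(new StructField[]{
            DataTypes.createStructField("id", DataTypes.LongType, false)
        });

        List<Row> inputRows = new ArrayList<>();
        inputRows.add(RowFactory.create(1L, "foo"));
        inputRows.add(RowFactory.create(1L, "bar"));
        inputRows.add(RowFactory.create(2L, "foo"));
        Dataset<Row> input = sparkSession.createDataFrame(inputRows, inputType);

        // Both grouped datasets are derived from the same `input` DataFrame.
        KeyValueGroupedDataset<Row, Row> fooGroups = input
            .filter("type = 'foo'")
            .groupBy("id")
            .as(RowEncoder.apply(keyType), RowEncoder.apply(inputType));
        KeyValueGroupedDataset<Row, Row> barGroups = input
            .filter("type = 'bar'")
            .groupBy("id")
            .as(RowEncoder.apply(keyType), RowEncoder.apply(inputType));

        // The cogroup function is irrelevant to the bug; it emits no rows.
        Dataset<Row> result = fooGroups.cogroup(
            barGroups,
            (CoGroupFunction<Row, Row, Row, Row>) (key, fooRows, barRows) -> new ArrayList<Row>().iterator(),
            RowEncoder.apply(inputType));

        result.explain();
        result.show();  // fails at execution time; see the exception below

        sparkSession.stop();
    }
}
```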
Explain output (note the mismatch in column IDs between the Sort/Exchange nodes and the LocalTableScan on the first input to the CoGroup):
```
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- SerializeFromObject [validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 0, id), LongType, false) AS id#37L, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 1, type), StringType, false), true, false, true) AS type#38]
   +- CoGroup org.apache.spark.sql.KeyValueGroupedDataset$$Lambda$1478/1869116781@77856cc5, createexternalrow(id#16L, StructField(id,LongType,false)), createexternalrow(id#16L, type#17.toString, StructField(id,LongType,false), StructField(type,StringType,false)), createexternalrow(id#16L, type#17.toString, StructField(id,LongType,false), StructField(type,StringType,false)), [id#39L], [id#39L], [id#39L, type#40], [id#39L, type#40], obj#36: org.apache.spark.sql.Row
      :- !Sort [id#39L ASC NULLS FIRST], false, 0
      :  +- !Exchange hashpartitioning(id#39L, 2), ENSURE_REQUIREMENTS, [plan_id=19]
      :     +- LocalTableScan [id#16L, type#17]
      +- Sort [id#39L ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(id#39L, 2), ENSURE_REQUIREMENTS, [plan_id=20]
            +- LocalTableScan [id#39L, type#40]
```
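In this plan, the CoGroup's grouping and data attribute lists for both inputs ([id#39L] and [id#39L, type#40]) carry the IDs of the second scan, while the first scan still outputs id#16L and type#17. The toy model that follows is a hypothetical sketch, not Spark code: it represents attributes as the `name#id` strings printed by `explain()`, assumes that deduplication gives the right-hand relation fresh IDs (16 to 39, 17 to 40), and shows that rewriting only the right side keeps both groupings resolvable, whereas applying the same rewrite to the left grouping yields an unresolvable left key like the one in the plan. The class and method names are made up for illustration.

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Hypothetical, simplified model (not Spark code). Attributes are plain
 * strings of the form name#exprId, as printed by explain().
 */
public class CoGroupDedupSketch {

    // Replace every attribute that appears as a key in the rewrite map.
    static List<String> rewrite(List<String> attrs, Map<String, String> rewrites) {
        return attrs.stream()
                .map(a -> rewrites.getOrDefault(a, a))
                .collect(Collectors.toList());
    }

    // An attribute list is usable only if the child produces every attribute in it.
    static boolean resolvable(List<String> attrs, List<String> childOutput) {
        return childOutput.containsAll(attrs);
    }

    public static void main(String[] args) {
        // Both inputs come from the same relation, so they start with identical IDs.
        List<String> leftOutput = List.of("id#16L", "type#17");
        List<String> groupingAttrs = List.of("id#16L");

        // Assumed deduplication step: the right-hand relation gets fresh IDs.
        Map<String, String> fresh = Map.of("id#16L", "id#39L", "type#17", "type#40");
        List<String> rightOutput = rewrite(leftOutput, fresh);

        // Expected behaviour: only the right-hand grouping is rewritten.
        List<String> leftGrouping = groupingAttrs;
        List<String> rightGrouping = rewrite(groupingAttrs, fresh);

        // What the plan looks like: the rewrite was also applied to the left grouping.
        List<String> buggyLeftGrouping = rewrite(groupingAttrs, fresh);

        System.out.println("right output: " + rightOutput);
        System.out.println("left grouping " + leftGrouping + " resolvable: "
                + resolvable(leftGrouping, leftOutput));
        System.out.println("right grouping " + rightGrouping + " resolvable: "
                + resolvable(rightGrouping, rightOutput));
        System.out.println("buggy left grouping " + buggyLeftGrouping + " resolvable: "
                + resolvable(buggyLeftGrouping, leftOutput));
    }
}
```

Running `main` prints that the buggy left grouping [id#39L] is not resolvable against [id#16L, type#17], which mirrors the `!Sort`/`!Exchange` markers on the first CoGroup input.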
Exception:
```
java.lang.IllegalStateException: Couldn't find id#39L in [id#16L,type#17]
    at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:80)
    at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:73)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:584)
    at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:176)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:584)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:589)
    at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:75)
    at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:35)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:698)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:589)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:589)
    at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren(TreeNode.scala:1254)
    at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren$(TreeNode.scala:1253)
    at org.apache.spark.sql.catalyst.expressions.BinaryExpression.mapChildren(Expression.scala:608)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:589)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:560)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:528)
    at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:73)
    at org.apache.spark.sql.catalyst.expressions.BindReferences$.$anonfun$bindReferences$1(BoundAttribute.scala:94)
    at scala.collection.immutable.List.map(List.scala:246)
    at scala.collection.immutable.List.map(List.scala:79)
    at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReferences(BoundAttribute.scala:94)
    at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:160)
    at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.getPartitionKeyExtractor$1(ShuffleExchangeExec.scala:323)
    at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.$anonfun$prepareShuffleDependency$13(ShuffleExchangeExec.scala:391)
    at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.$anonfun$prepareShuffleDependency$13$adapted(ShuffleExchangeExec.scala:390)
    at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2(RDD.scala:877)
    at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2$adapted(RDD.scala:877)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
    at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
    at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
    at org.apache.spark.scheduler.Task.run(Task.scala:136)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)
```
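The trace fails in `BindReferences.bindReference` while `ShuffleExchangeExec` builds its partition-key projection: the left Exchange partitions on id#39L, but its child only produces [id#16L,type#17]. The following hypothetical sketch (plain Java, not Spark's implementation; `BindSketch` and `bind` are made-up names) models that binding step as a lookup of the attribute in the child's output list and reproduces the same message format.

```java
import java.util.List;

/**
 * Hypothetical, simplified model (not Spark's implementation) of binding an
 * attribute reference to an ordinal in a child's output.
 */
public class BindSketch {

    // Return the attribute's position in the child's output, or fail like the trace does.
    static int bind(String attr, List<String> childOutput) {
        int ordinal = childOutput.indexOf(attr);
        if (ordinal < 0) {
            throw new IllegalStateException(
                    "Couldn't find " + attr + " in [" + String.join(",", childOutput) + "]");
        }
        return ordinal;
    }

    public static void main(String[] args) {
        // Output of the first LocalTableScan in the physical plan.
        List<String> leftScanOutput = List.of("id#16L", "type#17");

        // Binding the attribute the scan actually produces succeeds (ordinal 0).
        System.out.println("id#16L -> ordinal " + bind("id#16L", leftScanOutput));

        // Binding the Exchange's partitioning key id#39L fails.
        try {
            bind("id#39L", leftScanOutput);
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage());
        }
    }
}
```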
Other observations:
- The same code works if I call createDataFrame() twice and use two separate datasets as input to the cogroup.
- My real code applies two different filters to the same cached dataset to produce the two inputs to the cogroup. This results in the same exception and the same apparent error in the physical plan:
```
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- SerializeFromObject [validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 0, id), LongType, false) AS id#47L, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 1, type), StringType, false), true, false, true) AS type#48]
   +- CoGroup org.apache.spark.sql.KeyValueGroupedDataset$$Lambda$1526/693211959@7b2e931, createexternalrow(id#16L, StructField(id,LongType,false)), createexternalrow(id#16L, type#17.toString, StructField(id,LongType,false), StructField(type,StringType,false)), createexternalrow(id#16L, type#17.toString, StructField(id,LongType,false), StructField(type,StringType,false)), [id#49L], [id#49L], [id#49L, type#50], [id#49L, type#50], obj#46: org.apache.spark.sql.Row
      :- !Sort [id#49L ASC NULLS FIRST], false, 0
      :  +- !Exchange hashpartitioning(id#49L, 2), ENSURE_REQUIREMENTS, [plan_id=26]
      :     +- Filter (type#17 = foo)
      :        +- InMemoryTableScan [id#16L, type#17], [(type#17 = foo)]
      :           +- InMemoryRelation [id#16L, type#17], StorageLevel(disk, memory, deserialized, 1 replicas)
      :              +- LocalTableScan [id#16L, type#17]
      +- Sort [id#49L ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(id#49L, 2), ENSURE_REQUIREMENTS, [plan_id=27]
            +- Filter (type#50 = bar)
               +- InMemoryTableScan [id#49L, type#50], [(type#50 = bar)]
                  +- InMemoryRelation [id#49L, type#50], StorageLevel(disk, memory, deserialized, 1 replicas)
                     +- LocalTableScan [id#16L, type#17]
```
- The issue doesn't arise when I write the equivalent code in PySpark, which uses FlatMapCoGroupsInPandas.
Issue Links
- duplicates: SPARK-42132 DeduplicateRelations rule breaks plan when co-grouping the same DataFrame (Resolved)
- fixes: SPARK-42132 DeduplicateRelations rule breaks plan when co-grouping the same DataFrame (Resolved)
- links to