SPARK-43781

IllegalStateException when cogrouping two datasets derived from the same source


Details

    • Type: Bug
    • Status: Resolved
    • Priority: Major
    • Resolution: Fixed
    • Affects Version/s: 3.3.1, 3.4.0
    • Fix Version/s: 4.0.0
    • Component/s: SQL
    • Labels: None
    • Environment: Reproduces in a unit test, using Spark 3.3.1, the Java API, and a local[2] SparkSession.

    Description

      Attempting to cogroup two datasets derived from the same source dataset yields an IllegalStateException when the query is executed.
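
      The sparkSession used below is assumed to be an ordinary local session; a minimal setup matching the environment above (Spark 3.3.1, Java API, local[2]) might look like the following sketch, with an illustrative application name:

      import org.apache.spark.sql.SparkSession;

      // Minimal local[2] session per the environment field; the app name is illustrative.
      SparkSession sparkSession = SparkSession.builder()
          .appName("spark-43781-repro")
          .master("local[2]")
          .getOrCreate();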

      Minimal reproducer:

      import java.util.ArrayList;
      import java.util.List;

      import org.apache.spark.api.java.function.CoGroupFunction;
      import org.apache.spark.sql.Dataset;
      import org.apache.spark.sql.KeyValueGroupedDataset;
      import org.apache.spark.sql.Row;
      import org.apache.spark.sql.RowFactory;
      import org.apache.spark.sql.catalyst.encoders.RowEncoder;
      import org.apache.spark.sql.types.DataTypes;
      import org.apache.spark.sql.types.StructField;
      import org.apache.spark.sql.types.StructType;

      StructType inputType = DataTypes.createStructType(
          new StructField[]{
              DataTypes.createStructField("id", DataTypes.LongType, false),
              DataTypes.createStructField("type", DataTypes.StringType, false)
          }
      );
      
      StructType keyType = DataTypes.createStructType(
          new StructField[]{
              DataTypes.createStructField("id", DataTypes.LongType, false)
          }
      );
      
      List<Row> inputRows = new ArrayList<>();
      inputRows.add(RowFactory.create(1L, "foo"));
      inputRows.add(RowFactory.create(1L, "bar"));
      inputRows.add(RowFactory.create(2L, "foo"));
      Dataset<Row> input = sparkSession.createDataFrame(inputRows, inputType);
      
      KeyValueGroupedDataset<Row, Row> fooGroups = input
          .filter("type = 'foo'")
          .groupBy("id")
          .as(RowEncoder.apply(keyType), RowEncoder.apply(inputType));
      
      KeyValueGroupedDataset<Row, Row> barGroups = input
          .filter("type = 'bar'")
          .groupBy("id")
          .as(RowEncoder.apply(keyType), RowEncoder.apply(inputType));
      
      Dataset<Row> result = fooGroups.cogroup(
          barGroups,
          (CoGroupFunction<Row, Row, Row, Row>) (row, iterator, iterator1) -> new ArrayList<Row>().iterator(),
          RowEncoder.apply(inputType));
      
      result.explain();
      result.show();

      Explain output (note the mismatch in column IDs between the Sort/Exchange and the LocalTableScan on the first input to the CoGroup):

      == Physical Plan ==
      AdaptiveSparkPlan isFinalPlan=false
      +- SerializeFromObject [validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 0, id), LongType, false) AS id#37L, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 1, type), StringType, false), true, false, true) AS type#38]
         +- CoGroup org.apache.spark.sql.KeyValueGroupedDataset$$Lambda$1478/1869116781@77856cc5, createexternalrow(id#16L, StructField(id,LongType,false)), createexternalrow(id#16L, type#17.toString, StructField(id,LongType,false), StructField(type,StringType,false)), createexternalrow(id#16L, type#17.toString, StructField(id,LongType,false), StructField(type,StringType,false)), [id#39L], [id#39L], [id#39L, type#40], [id#39L, type#40], obj#36: org.apache.spark.sql.Row
            :- !Sort [id#39L ASC NULLS FIRST], false, 0
            :  +- !Exchange hashpartitioning(id#39L, 2), ENSURE_REQUIREMENTS, [plan_id=19]
            :     +- LocalTableScan [id#16L, type#17]
            +- Sort [id#39L ASC NULLS FIRST], false, 0
               +- Exchange hashpartitioning(id#39L, 2), ENSURE_REQUIREMENTS, [plan_id=20]
                  +- LocalTableScan [id#39L, type#40]

      Exception:

      java.lang.IllegalStateException: Couldn't find id#39L in [id#16L,type#17]
              at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:80)
              at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:73)
              at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:584)
              at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:176)
              at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:584)
              at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:589)
              at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:75)
              at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:35)
              at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:698)
              at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:589)
              at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:589)
              at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren(TreeNode.scala:1254)
              at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren$(TreeNode.scala:1253)
              at org.apache.spark.sql.catalyst.expressions.BinaryExpression.mapChildren(Expression.scala:608)
              at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:589)
              at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:560)
              at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:528)
              at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:73)
              at org.apache.spark.sql.catalyst.expressions.BindReferences$.$anonfun$bindReferences$1(BoundAttribute.scala:94)
              at scala.collection.immutable.List.map(List.scala:246)
              at scala.collection.immutable.List.map(List.scala:79)
              at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReferences(BoundAttribute.scala:94)
              at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:160)
              at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.getPartitionKeyExtractor$1(ShuffleExchangeExec.scala:323)
              at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.$anonfun$prepareShuffleDependency$13(ShuffleExchangeExec.scala:391)
              at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.$anonfun$prepareShuffleDependency$13$adapted(ShuffleExchangeExec.scala:390)
              at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2(RDD.scala:877)
              at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2$adapted(RDD.scala:877)
              at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
              at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
              at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
              at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
              at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
              at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
              at org.apache.spark.scheduler.Task.run(Task.scala:136)
              at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
              at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
              at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
              at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
              at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
              at java.lang.Thread.run(Thread.java:748) 

      Other observations:

      • The same code works if I call createDataFrame() twice and use two separate datasets as input to the cogroup (see the sketch after this list).
      • The real code uses two different filters on the same cached dataset as the two inputs to the cogroup. However, this results in the same exception, and the same apparent error in the physical plan, which looks as follows:
        == Physical Plan ==
        AdaptiveSparkPlan isFinalPlan=false
        +- SerializeFromObject [validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 0, id), LongType, false) AS id#47L, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 1, type), StringType, false), true, false, true) AS type#48]
           +- CoGroup org.apache.spark.sql.KeyValueGroupedDataset$$Lambda$1526/693211959@7b2e931, createexternalrow(id#16L, StructField(id,LongType,false)), createexternalrow(id#16L, type#17.toString, StructField(id,LongType,false), StructField(type,StringType,false)), createexternalrow(id#16L, type#17.toString, StructField(id,LongType,false), StructField(type,StringType,false)), [id#49L], [id#49L], [id#49L, type#50], [id#49L, type#50], obj#46: org.apache.spark.sql.Row
              :- !Sort [id#49L ASC NULLS FIRST], false, 0
              :  +- !Exchange hashpartitioning(id#49L, 2), ENSURE_REQUIREMENTS, [plan_id=26]
              :     +- Filter (type#17 = foo)
              :        +- InMemoryTableScan [id#16L, type#17], [(type#17 = foo)]
              :              +- InMemoryRelation [id#16L, type#17], StorageLevel(disk, memory, deserialized, 1 replicas)
              :                    +- LocalTableScan [id#16L, type#17]
              +- Sort [id#49L ASC NULLS FIRST], false, 0
                 +- Exchange hashpartitioning(id#49L, 2), ENSURE_REQUIREMENTS, [plan_id=27]
                    +- Filter (type#50 = bar)
                       +- InMemoryTableScan [id#49L, type#50], [(type#50 = bar)]
                             +- InMemoryRelation [id#49L, type#50], StorageLevel(disk, memory, deserialized, 1 replicas)
                                   +- LocalTableScan [id#16L, type#17] 
      • The issue doesn't arise if I write the same code in PySpark, using FlatMapCoGroupsInPandas.
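
      The following is a sketch of the workaround noted in the first bullet above: calling createDataFrame() twice so that the two cogroup inputs come from independent datasets rather than one shared source. It reuses inputRows, inputType, and keyType from the reproducer; the names with a "2" suffix are illustrative.

      // Workaround sketch: two independent DataFrames over the same rows, so the
      // two cogroup inputs no longer share a logical plan.
      Dataset<Row> inputFoo = sparkSession.createDataFrame(inputRows, inputType);
      Dataset<Row> inputBar = sparkSession.createDataFrame(inputRows, inputType);

      KeyValueGroupedDataset<Row, Row> fooGroups2 = inputFoo
          .filter("type = 'foo'")
          .groupBy("id")
          .as(RowEncoder.apply(keyType), RowEncoder.apply(inputType));

      KeyValueGroupedDataset<Row, Row> barGroups2 = inputBar
          .filter("type = 'bar'")
          .groupBy("id")
          .as(RowEncoder.apply(keyType), RowEncoder.apply(inputType));

      Dataset<Row> workaroundResult = fooGroups2.cogroup(
          barGroups2,
          (CoGroupFunction<Row, Row, Row, Row>) (key, left, right) -> new ArrayList<Row>().iterator(),
          RowEncoder.apply(inputType));

      // Runs without the IllegalStateException, per the observation above.
      workaroundResult.show();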


    People

      Assignee: Jia Fan (fanjia)
      Reporter: Derek Murray (mrry)
      Votes: 1
      Watchers: 5

