Description
We encountered a very strange bug while using withField heavily in production with Spark 3.1.1. Jobs were dying with a lot of odd JVM crashes (such as jshort_disjoint_arraycopy during a copyMemory call) and occasional NegativeArraySizeException errors. We eventually found a workaround by ordering our withField calls in a certain way, and I was able to create some minimal examples that reproduce similar broken behavior.
It seems to stem from the optimizations added in https://github.com/apache/spark/pull/29812. Because the same optimization also runs as an analyzer rule, there seem to be two different ways this issue can crop up: once at analysis time and once at runtime.
While these examples might seem odd, they reflect a helper function we've written that can set columns at arbitrary nested paths, even when the intermediate fields don't exist yet.
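For context, here is a minimal sketch of that kind of helper, assuming the DataFrame and its schema are in hand. The name with_nested_field and its exact shape are hypothetical, not our actual production code:

import pyspark.sql.functions as F
from pyspark.sql import DataFrame
from pyspark.sql.types import StructType


def with_nested_field(df: DataFrame, column: str, path: str, value):
    # Hypothetical sketch (not our exact production code): set <column>.<path>
    # to `value`, creating empty intermediate structs for any path segments
    # that don't exist yet in the DataFrame's schema.
    exists = column in df.columns
    col = F.col(column) if exists else F.struct()
    schema = df.schema[column].dataType if exists else StructType()
    parts = path.split('.')
    for i, part in enumerate(parts[:-1]):
        if isinstance(schema, StructType) and part in schema.names:
            schema = schema[part].dataType
        else:
            # Intermediate field is missing: create it as an empty struct so
            # the following dotted withField call has something to descend into.
            col = col.withField('.'.join(parts[:i + 1]), F.struct())
            schema = StructType()
    return df.withColumn(column, col.withField(path, value))

Calling a helper like this repeatedly for paths under 'a' and 'b' in alternating order produces exactly the kind of withField chains shown in the examples below.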
Example of what I assume is an issue during analysis:
import pyspark.sql.functions as F

df = spark.range(1).withColumn('data', F.struct()
    .withField('a', F.struct())
    .withField('b', F.struct())
    .withField('a.aa', F.lit('aa'))
    .withField('b.ba', F.lit('ba'))
    .withField('a.ab', F.lit('ab'))
    .withField('b.bb', F.lit('bb'))
    .withField('a.ac', F.lit('ac'))
)
df.printSchema()
Output schema:
root
 |-- id: long (nullable = false)
 |-- data: struct (nullable = false)
 |    |-- b: struct (nullable = false)
 |    |    |-- aa: string (nullable = false)
 |    |    |-- ab: string (nullable = false)
 |    |    |-- bb: string (nullable = false)
 |    |-- a: struct (nullable = false)
 |    |    |-- aa: string (nullable = false)
 |    |    |-- ab: string (nullable = false)
 |    |    |-- ac: string (nullable = false)
And an example of a runtime data issue:
df = (spark.range(1)
    .withColumn('data', F.struct()
        .withField('a', F.struct().withField('aa', F.lit('aa')))
        .withField('b', F.struct().withField('ba', F.lit('ba')))
    )
    .withColumn('data', F.col('data').withField('b.bb', F.lit('bb')))
    .withColumn('data', F.col('data').withField('a.ab', F.lit('ab')))
)
df.printSchema()
df.groupBy('data.a.aa', 'data.a.ab', 'data.b.ba', 'data.b.bb').count().show()
Output:
root
 |-- id: long (nullable = false)
 |-- data: struct (nullable = false)
 |    |-- a: struct (nullable = false)
 |    |    |-- aa: string (nullable = false)
 |    |    |-- ab: string (nullable = false)
 |    |-- b: struct (nullable = false)
 |    |    |-- ba: string (nullable = false)
 |    |    |-- bb: string (nullable = false)

+---+---+---+---+-----+
| aa| ab| ba| bb|count|
+---+---+---+---+-----+
| ba| bb| aa| ab|    1|
+---+---+---+---+-----+
The columns contain the wrong data, even though the schema is correct. Additionally, if you add one more nested field, you get an exception:
df = (spark.range(1)
    .withColumn('data', F.struct()
        .withField('a', F.struct().withField('aa', F.lit('aa')))
        .withField('b', F.struct().withField('ba', F.lit('ba')))
    )
    .withColumn('data', F.col('data').withField('a.ab', F.lit('ab')))
    .withColumn('data', F.col('data').withField('b.bb', F.lit('bb')))
    .withColumn('data', F.col('data').withField('a.ac', F.lit('ac')))
)
df.groupBy('data.a.aa', 'data.a.ab', 'data.a.ac', 'data.b.ba', 'data.b.bb').count().show()

java.lang.ArrayIndexOutOfBoundsException: 2
  at org.apache.spark.sql.catalyst.expressions.GenericInternalRow.genericGet(rows.scala:201)
  at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow.getAs(rows.scala:35)
  at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow.getUTF8String(rows.scala:46)
  at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow.getUTF8String$(rows.scala:46)
  at org.apache.spark.sql.catalyst.expressions.GenericInternalRow.getUTF8String(rows.scala:195)
  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.agg_doConsume_0$(Unknown Source)
  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.agg_doAggregateWithKeys_0$(Unknown Source)
  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
But if you reorder the withField expressions, you get correct behavior:
df = (spark.range(1)
    .withColumn('data', F.struct()
        .withField('a', F.struct().withField('aa', F.lit('aa')))
        .withField('b', F.struct().withField('ba', F.lit('ba')))
    )
    .withColumn('data', F.col('data').withField('a.ab', F.lit('ab')))
    .withColumn('data', F.col('data').withField('a.ac', F.lit('ac')))
    .withColumn('data', F.col('data').withField('b.bb', F.lit('bb')))
)
df.groupBy('data.a.aa', 'data.a.ab', 'data.a.ac', 'data.b.ba', 'data.b.bb').count().show()

+---+---+---+---+---+-----+
| aa| ab| ac| ba| bb|count|
+---+---+---+---+---+-----+
| aa| ab| ac| ba| bb|    1|
+---+---+---+---+---+-----+
I think this has to do with the double ".reverse" used to dedupe field expressions in OptimizeUpdateFields. I'm working on a PR to try to fix these issues.
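To illustrate the idiom I mean, here is a rough Python sketch of a "reverse, dedupe by name, reverse" pass (illustrative only; the real rule operates on Catalyst expressions in Scala). Note how deduping this way can move a field whose last assignment comes late, which looks a lot like the reordered schema in the first example:

def dedupe_keep_last(assignments):
    # assignments: list of (field_name, value) pairs where later assignments
    # are supposed to win. Reverse, keep the first occurrence of each name,
    # then reverse back.
    seen = set()
    kept = []
    for name, value in reversed(assignments):
        if name not in seen:
            seen.add(name)
            kept.append((name, value))
    return list(reversed(kept))

# Field 'a' is reassigned last, so it ends up *after* 'b' even though it was
# written first -- the values are right, but the field order has changed.
print(dedupe_keep_last([('a', 1), ('b', 2), ('a', 3)]))
# [('b', 2), ('a', 3)]

If anything downstream assumes the original field order (or pairs names and values independently), this kind of reordering would explain both the shuffled schema and the mismatched data above.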