Details
- Type: Bug
- Status: Resolved
- Priority: Major
- Resolution: Fixed
- 3.1.1, 3.5.1
Description
Consider this example:
package com.jetbrains.jetstat.etl

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.DecimalType

object A {
  case class Example(x: BigDecimal)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .getOrCreate()

    import spark.implicits._

    val originalRaw = BigDecimal("123.456")
    val original = Example(originalRaw)

    val ds1 = spark.createDataset(Seq(original))
    val ds2 = ds1
      .withColumn("x", $"x" cast DecimalType(12, 6))
    val ds3 = ds2
      .as[Example]

    println(s"DS1: schema=${ds1.schema}, encoder.schema=${ds1.encoder.schema}")
    println(s"DS2: schema=${ds1.schema}, encoder.schema=${ds2.encoder.schema}")
    println(s"DS3: schema=${ds1.schema}, encoder.schema=${ds3.encoder.schema}")

    val json1 = ds1.toJSON.collect().head
    val json2 = ds2.toJSON.collect().head
    val json3 = ds3.toJSON.collect().head

    val collect1 = ds1.collect().head
    val collect2_ = ds2.collect().head
    val collect2 = collect2_.getDecimal(collect2_.fieldIndex("x"))
    val collect3 = ds3.collect().head

    println(s"Original: $original (scale = ${original.x.scale}, precision = ${original.x.precision})")
    println(s"Collect1: $collect1 (scale = ${collect1.x.scale}, precision = ${collect1.x.precision})")
    println(s"Collect2: $collect2 (scale = ${collect2.scale}, precision = ${collect2.precision})")
    println(s"Collect3: $collect3 (scale = ${collect3.x.scale}, precision = ${collect3.x.precision})")

    println(s"json1: $json1")
    println(s"json2: $json2")
    println(s"json3: $json3")
  }
}
If you run it, you'll see that json3 contains badly wrong data. After a bit of debugging (apologies, I'm not well versed in Spark internals), I found that:
- The in-memory representation of the data in this example uses UnsafeRow, which stores small Decimal values compactly as longs. The row itself doesn't remember the decimal's precision and scale, so its .getDecimal relies on the caller to pass them in.
- However, there are at least two sources for the precision and scale passed to that method: Dataset.schema (which is based on query execution and always contains 38,18 for me) and Dataset.encoder.schema (which is updated to 12,6 in `ds2` but then reset in `ds3`). There is also a Dataset.deserializer that seems to combine the two non-trivially.
- This doesn't seem to affect the Dataset.collect() methods, since they use the deserializer, but Dataset.toJSON uses only the first schema.
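To illustrate why a mismatched scale corrupts the value, here is a minimal sketch in plain Scala with no Spark dependency. It is not Spark's actual code: `decode` is a hypothetical helper that stands in for any reader that rebuilds a decimal from a raw unscaled long using a scale supplied by the caller, and the assumption is that the value was written with scale 6 (as after the cast to DecimalType(12, 6)) but read back with scale 18.

```scala
import java.math.{BigDecimal => JBigDecimal, BigInteger}

object DecimalScaleMismatch {
  // Hypothetical helper (not a Spark API): rebuilds a decimal from a raw
  // unscaled long, trusting whatever scale the caller supplies.
  def decode(unscaled: Long, scale: Int): JBigDecimal =
    new JBigDecimal(BigInteger.valueOf(unscaled), scale)

  def main(args: Array[String]): Unit = {
    // Assumption: the value is stored with scale 6, so only the unscaled
    // long survives in the compact representation.
    val written = new JBigDecimal("123.456").setScale(6)
    val unscaled = written.unscaledValue.longValueExact

    // Reading with the scale used for writing recovers the value;
    // reading with a different scale silently shifts the decimal point.
    val sameScale = decode(unscaled, 6)
    val otherScale = decode(unscaled, 18)

    println(s"unscaled long: $unscaled")
    println(s"read with scale 6:  ${sameScale.toPlainString}")
    println(s"read with scale 18: ${otherScale.toPlainString}")
  }
}
```

The stored long is the same in both reads; only the scale handed to the reader differs, which is why nothing fails loudly and the result is simply a wrong number.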
Seems to me that either .toJSON should be more aware of what's going on or .as[] should be doing something else.