SPARK-48965

toJSON produces wrong values if DecimalType information is lost in as[Product]

    Description

      Consider this example:

      package com.jetbrains.jetstat.etl
      
      import org.apache.spark.sql.SparkSession
      import org.apache.spark.sql.types.DecimalType
      
      object A {
        case class Example(x: BigDecimal)
      
        def main(args: Array[String]): Unit = {
          val spark = SparkSession.builder()
            .master("local[1]")
            .getOrCreate()
      
          import spark.implicits._
      
          val originalRaw = BigDecimal("123.456")
          val original = Example(originalRaw)
      
          val ds1 = spark.createDataset(Seq(original))
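          // Cast the column to DecimalType(12, 6); the stored values now use scale 6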
          val ds2 = ds1
            .withColumn("x", $"x" cast DecimalType(12, 6))
      
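          // .as[Example] re-applies the Example encoder; its schema is the default DecimalType(38, 18) for BigDecimal again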
          val ds3 = ds2
            .as[Example]
      
          println(s"DS1: schema=${ds1.schema}, encoder.schema=${ds1.encoder.schema}")
          println(s"DS2: schema=${ds1.schema}, encoder.schema=${ds2.encoder.schema}")
          println(s"DS3: schema=${ds1.schema}, encoder.schema=${ds3.encoder.schema}")
      
          val json1 = ds1.toJSON.collect().head
          val json2 = ds2.toJSON.collect().head
          val json3 = ds3.toJSON.collect().head
      
          val collect1 = ds1.collect().head
          val collect2_ = ds2.collect().head
          val collect2 = collect2_.getDecimal(collect2_.fieldIndex("x"))
          val collect3 = ds3.collect().head
      
          println(s"Original: $original (scale = ${original.x.scale}, precision = ${original.x.precision})")
          println(s"Collect1: $collect1 (scale = ${collect1.x.scale}, precision = ${collect1.x.precision})")
          println(s"Collect2: $collect2 (scale = ${collect2.scale}, precision = ${collect2.precision})")
          println(s"Collect3: $collect3 (scale = ${collect3.x.scale}, precision = ${collect3.x.precision})")
          println(s"json1: $json1")
          println(s"json2: $json2")
          println(s"json3: $json3")
        }
      }
      

      Running this, you'll see that json3 contains clearly wrong data. After a bit of debugging (apologies, I'm not well versed in Spark internals), I found that:

      • The in-memory representation of the data in this example is an UnsafeRow, which stores small Decimal values compactly as unscaled longs; its .getDecimal does not know the decimal's precision and scale, so the caller has to supply them (see the sketch after this list),
      • However, there are at least two sources for the precision and scale to pass to that method: Dataset.schema (which is based on query execution and always contains (38, 18) for me) and Dataset.encoder.schema (which gets updated to (12, 6) in ds2 but is reset in ds3). There is also a Dataset.deserializer that seems to combine those two non-trivially.
      • This does not seem to affect Dataset.collect(), which goes through the deserializer, but Dataset.toJSON only uses the first schema.
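
      A minimal sketch of the first point, using the developer-API class org.apache.spark.sql.types.Decimal directly: the same unscaled long produces different values depending on the precision and scale the reader supplies.

      import org.apache.spark.sql.types.Decimal

      // BigDecimal("123.456") cast to DecimalType(12, 6) is stored in the UnsafeRow
      // as the unscaled long 123456000; precision and scale are not stored alongside it
      // and must be supplied by whoever calls getDecimal.
      val unscaled = 123456000L

      // Read back with the matching (12, 6): 123.456000
      val right = Decimal(unscaled, 12, 6)

      // Read back with (38, 18), the default DecimalType for a Scala BigDecimal: ~1.23456E-10
      val wrong = Decimal(unscaled, 38, 18)

      println(right)
      println(wrong)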

      It seems to me that either .toJSON should take the actual decimal precision and scale into account, or .as[] should adjust the data or the schema so that they agree.
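
      As a workaround sketch, continuing the example above (and assuming that re-encoding collected values is acceptable), the JSON can be taken either before the .as[Example] call or from values that have gone through the deserializer:

      // Take JSON from ds2, where the column type and the stored values still agree on (12, 6).
      val jsonBeforeAs = ds2.toJSON.collect().head

      // Or round-trip through the deserializer: collect() is not affected, so re-creating
      // a Dataset from the collected case classes should yield a numerically correct value.
      val jsonRoundTrip = spark.createDataset(ds3.collect().toSeq).toJSON.collect().head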

People

  Assignee: Bruce Robbins (bersprockets)
  Reporter: Dmitry Lapshin (LDVSoft)
