Description
When submitting an application to Spark built with Scala 2.13, Decimal overflow issues show up when using unary minus (and also abs(), which uses unary minus under the hood).
Here is an example PySpark reproducer:
from decimal import Decimal
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, DecimalType

spark = SparkSession.builder \
    .master("local[*]") \
    .appName("decimal_precision") \
    .config("spark.rapids.sql.explain", "ALL") \
    .config("spark.sql.ansi.enabled", "true") \
    .config("spark.sql.legacy.allowNegativeScaleOfDecimal", 'true') \
    .getOrCreate()

precision = 38
scale = 0
DECIMAL_MIN = Decimal('-' + ('9' * precision) + 'e' + str(-scale))

data = [[DECIMAL_MIN]]
schema = StructType([StructField("a", DecimalType(precision, scale), True)])

df = spark.createDataFrame(data=data, schema=schema)
df.selectExpr("a", "-a").show()
This particular example will run successfully on Spark built with Scala 2.12, but throw a java.math.ArithmeticException on Spark built with Scala 2.13.
If you change DECIMAL_MIN in the previous code to a value just above the original DECIMAL_MIN, no exception is thrown; instead you get an incorrect answer (possibly due to overflow):
...
DECIMAL_MIN = Decimal('-8' + ('9' * (precision-1)) + 'e' + str(-scale))
...
Output:
+--------------------+--------------------+
|                   a|               (- a)|
+--------------------+--------------------+
|-8999999999999999...|90000000000000000...|
+--------------------+--------------------+
It looks like the code in Decimal.scala uses scala.math.BigDecimal. See https://github.com/scala/bug/issues/11590 for updates on how Scala 2.13 handles BigDecimal. It looks like a java.math.MathContext is missing when these operations are performed.
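For reference, the behavior difference can be seen outside of Spark with a small Scala sketch (hypothetical, not taken from Decimal.scala) that wraps a 38-digit java.math.BigDecimal the way Spark's Decimal does and then negates it. On Scala 2.13 the default MathContext.DECIMAL128 (34 significant digits) appears to be applied during the negation, while supplying a wider MathContext keeps the value exact:

import java.math.{BigDecimal => JBigDecimal, MathContext}

object NegateSketch {
  def main(args: Array[String]): Unit = {
    // 38-digit value at the edge of DecimalType(38, 0), matching the repro above.
    val jbd = new JBigDecimal("-" + "9" * 38)

    // BigDecimal(jbd) wraps the value with the default MathContext.DECIMAL128
    // (34 significant digits); no rounding happens at construction time.
    val wrapped = BigDecimal(jbd)

    // On Scala 2.13 unary minus appears to apply that MathContext, rounding the
    // 38-digit result to 34 significant digits; on 2.12 it is negated exactly.
    println(-wrapped)

    // Wrapping with an explicit MathContext wide enough for 38 digits avoids the rounding.
    val exact = BigDecimal(jbd, new MathContext(38))
    println(-exact)
  }
}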