Uploaded image for project: 'Spark'
  1. Spark
  2. SPARK-11885

UDAF may nondeterministically generate wrong results

    Export: XML | Word | Printable | JSON

Details

    • Type: Bug
    • Status: Resolved
    • Priority: Critical
    • Resolution: Fixed
    • Affects Version/s: 1.5.2
    • Fix Version/s: 1.5.3
    • Component/s: SQL
    • Labels: None

    Description

      I could not reproduce this on the 1.6 branch, but it is easy to reproduce on 1.5, so I believe the issue is specific to the 1.5 branch.

      Run the following in Spark 1.5 (on a cluster) to see the problem.

      import java.math.BigDecimal
      
      import org.apache.spark.sql.expressions.MutableAggregationBuffer
      import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
      import org.apache.spark.sql.Row
      import org.apache.spark.sql.types.{StructType, StructField, DataType, DoubleType, LongType}
      
      /**
       * Geometric mean of a single DoubleType column, implemented as a Spark SQL UDAF.
       *
       * The aggregation buffer holds the number of non-null values seen (`count`) and
       * their running product (`product`); `evaluate` returns `product ^ (1 / count)`.
       *
       * Null inputs are skipped, and a group with no non-null values evaluates to null
       * rather than NaN.
       */
      class GeometricMean extends UserDefinedAggregateFunction {
        /** One input column named "value" of type DoubleType. */
        def inputSchema: StructType =
          StructType(StructField("value", DoubleType) :: Nil)

        /** Buffer layout: slot 0 = count of aggregated values, slot 1 = running product. */
        def bufferSchema: StructType = StructType(
          StructField("count", LongType) ::
            StructField("product", DoubleType) :: Nil
        )

        /** The result is a Double. */
        def dataType: DataType = DoubleType

        /** The same input rows always produce the same result. */
        def deterministic: Boolean = true

        /** Starts from the identity of multiplication: zero values seen, product 1.0. */
        def initialize(buffer: MutableAggregationBuffer): Unit = {
          buffer(0) = 0L
          buffer(1) = 1.0
        }

        /** Folds one input row into the buffer; rows whose value is null are ignored. */
        def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
          // getAs[Double] on a null cell unboxes to 0.0, which would silently zero the
          // entire product, so nulls must be skipped explicitly.
          if (!input.isNullAt(0)) {
            buffer(0) = buffer.getAs[Long](0) + 1
            buffer(1) = buffer.getAs[Double](1) * input.getAs[Double](0)
          }
        }

        /** Combines two partial buffers: counts add and products multiply. */
        def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
          buffer1(0) = buffer1.getAs[Long](0) + buffer2.getAs[Long](0)
          buffer1(1) = buffer1.getAs[Double](1) * buffer2.getAs[Double](1)
        }

        /**
         * Returns `product ^ (1 / count)`, or null when no values were aggregated
         * (otherwise this would be pow(1.0, Infinity), which is NaN).
         */
        def evaluate(buffer: Row): Any = {
          val count = buffer.getLong(0)
          if (count == 0L) null
          else math.pow(buffer.getDouble(1), 1.0d / count)
        }
      }
      
      // Expose the UDAF to SQL under the name "gm".
      sqlContext.udf.register("gm", new GeometricMean)
      
      // Six receipts; five are from Italy, and one (receipt 6, amount 42) is later
      // removed by the `amount > 50` filter.
      val df = Seq(
        (1, "italy", "emilia", 42, BigDecimal.valueOf(100, 0), "john"),
        (2, "italy", "toscana", 42, BigDecimal.valueOf(505, 1), "jim"),
        (3, "italy", "puglia", 42, BigDecimal.valueOf(70, 0), "jenn"),
        (4, "italy", "emilia", 42, BigDecimal.valueOf(75, 0), "jack"),
        (5, "uk", "london", 42, BigDecimal.valueOf(200, 0), "carl"),
        (6, "italy", "emilia", 42, BigDecimal.valueOf(42, 0), "john")
      ).toDF("receipt_id", "store_country", "store_region", "store_id", "amount", "seller_name")
      df.registerTempTable("receipts")
        
      // Built-in aggregates and the UDAF in one grouped query. From the data above, the
      // expected gm(amount) values are: emilia = sqrt(100 * 75) ~ 86.60, toscana = 50.5,
      // puglia = 70.0. Any other value for gm indicates the bug.
      val q = sql("""
      select   store_country,
               store_region,
               avg(amount),
               sum(amount),
               gm(amount)
      from     receipts
      where    amount > 50
               and store_country = 'italy'
      group by store_country, store_region
      """)
      
      // show() is side-effecting, so call it with explicit parentheses
      // (auto-application is deprecated in Scala 2.13 and an error in Scala 3).
      q.show()
      

      Attachments

        Issue Links

          Activity

            People

              davies Davies Liu
              yhuai Yin Huai
              Votes:
              1 Vote for this issue
              Watchers:
              5 Start watching this issue

              Dates

                Created:
                Updated:
                Resolved: