Spark / SPARK-23833

Incorrect primitive type check for input arguments of udf

Details

    • Type: Bug
    • Status: Resolved
    • Priority: Major
    • Resolution: Incomplete
    • Affects Version/s: 2.2.0, 2.3.0
    • Fix Version/s: None
    • Component/s: Optimizer, SQL
    Description

      The documented behavior for Scala UDFs with primitive-type arguments is:

      Note that if you use primitive parameters, you are not able to check if it is null or not, and the UDF will return null for you if the primitive input is null.
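
      For instance, a minimal spark-shell sketch of that guarantee (the column expression is illustrative):

      import org.apache.spark.sql.functions.{lit, udf}

      val plusOne = udf((x: Long) => x + 1)  // primitive Long parameter
      // the analyzer null-guards the primitive argument, so the result row
      // shows null instead of the function ever seeing an unboxed value
      spark.range(1).select(plusOne(lit(null).cast("long"))).show()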

      This null-propagation behavior was introduced by SPARK-11725 and its corresponding PR.

      The problem is that ScalaReflection.getParameterTypes does not work correctly due to type erasure.
      The correct "is this type primitive" check should instead be based on a TypeTag, something like this:

      typeTag[T].tpe.typeSymbol.asClass.isPrimitive
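
      For instance, this check distinguishes the primitive scala.Long from the boxed java.lang.Long (a quick standalone illustration, not Spark code):

      import scala.reflect.runtime.universe.typeTag

      // scala.Long is one of the primitive value classes
      typeTag[Long].tpe.typeSymbol.asClass.isPrimitive            // true
      // the boxed java.lang.Long is an ordinary reference class
      typeTag[java.lang.Long].tpe.typeSymbol.asClass.isPrimitive  // false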
      


      The problem appears with higher-order functions:

      // f has a primitive Long parameter; identity(f) wraps it in a generic
      // function, so the parameter type erases to java.lang.Object
      val f = (x: Long) => x
      def identity[T, U](f: T => U): T => U = (t: T) => f(t)
      val udf0 = udf(f)
      val udf1 = udf(identity(f))
      // produces a null boxed Long to feed into the UDFs above
      val getNull = udf(() => null.asInstanceOf[java.lang.Long])
      spark.range(5).toDF().
        withColumn("udf0", udf0(getNull())).
        withColumn("udf1", udf1(getNull())).
        show()
      spark.range(5).toDF().
        withColumn("udf0", udf0(getNull())).
        withColumn("udf1", udf1(getNull())).
        explain()
      

      Test execution in the Spark 2.2 spark-shell:

      scala> val f = (x: Long) => x
      f: Long => Long = <function1>
      
      scala> def identity[T, U](f: T => U): T => U = (t: T) => f(t)
      identity: [T, U](f: T => U)T => U
      
      scala> val udf0 = udf(f)
      udf0: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,LongType,Some(List(LongType)))
      
      scala> val udf1 = udf(identity(f))
      udf1: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,LongType,Some(List(LongType)))
      
      scala> val getNull = udf(() => null.asInstanceOf[java.lang.Long])
      getNull: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function0>,LongType,Some(List()))
      
      scala> spark.range(5).toDF().
           |   withColumn("udf0", udf0(getNull())).
           |   withColumn("udf1", udf1(getNull())).
           |   show()
      +---+----+----+                                                                 
      | id|udf0|udf1|
      +---+----+----+
      |  0|null|   0|
      |  1|null|   0|
      |  2|null|   0|
      |  3|null|   0|
      |  4|null|   0|
      +---+----+----+
      
      
      scala> spark.range(5).toDF().
           |   withColumn("udf0", udf0(getNull())).
           |   withColumn("udf1", udf1(getNull())).
           |   explain()
      == Physical Plan ==
      *Project [id#19L, if (isnull(UDF())) null else UDF(UDF()) AS udf0#24L, UDF(UDF()) AS udf1#28L]
      +- *Range (0, 5, step=1, splits=6)
      


      In the plan, udf0 is wrapped in a null check (if (isnull(...)) null else ...) while udf1 is not: the erased parameter type of identity(f) is java.lang.Object, which is not primitive, so the guard is skipped and the null input is unboxed to 0.

      The TypeTag information about the input parameters is available in the udf function, but it is only used to derive the schema; it should be passed on to ScalaUDF as well so that we can use it later:

      def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = {
        val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: Nil).toOption
        UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes)
      }
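
      A minimal sketch of that idea (the inputPrimitives helper below is hypothetical, not the actual Spark API): compute the primitive flags from the TypeTags at udf() definition time, before erasure discards them, and carry them alongside inputTypes down to ScalaUDF:

      import scala.reflect.runtime.universe.{typeTag, TypeTag}

      // hypothetical helper: the TypeTags still carry the static types,
      // so Long and java.lang.Long are distinguishable here
      def inputPrimitives[A1: TypeTag, A2: TypeTag]: Seq[Boolean] = {
        def isPrim[T: TypeTag]: Boolean = {
          val sym = typeTag[T].tpe.typeSymbol
          sym.isClass && sym.asClass.isPrimitive
        }
        isPrim[A1] :: isPrim[A2] :: Nil
      }

      inputPrimitives[Long, String]            // Seq(true, false)
      inputPrimitives[java.lang.Long, Double]  // Seq(false, true)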
      


      Here are the current and the desired versions:

      scala> import org.apache.spark.sql.catalyst.ScalaReflection
      import org.apache.spark.sql.catalyst.ScalaReflection
      
      scala> ScalaReflection.getParameterTypes(identity(f))
      res2: Seq[Class[_]] = WrappedArray(class java.lang.Object)
      
      scala> ScalaReflection.getParameterTypes(identity(f)).map(_.isPrimitive)
      res7: Seq[Boolean] = ArrayBuffer(false)
      

      versus

      scala> import scala.reflect.runtime.universe.{typeTag, TypeTag}
      import scala.reflect.runtime.universe.{typeTag, TypeTag}
      
      scala> def myGetParameterTypes[T : TypeTag, U](func: T => U) = {
           |   typeTag[T].tpe.typeSymbol.asClass
           | }
      myGetParameterTypes: [T, U](func: T => U)(implicit evidence$1: reflect.runtime.universe.TypeTag[T])reflect.runtime.universe.ClassSymbol
      
      scala> myGetParameterTypes(f)
      res3: reflect.runtime.universe.ClassSymbol = class Long
      
      scala> myGetParameterTypes(f).isPrimitive
      res4: Boolean = true
      

      For this particular case there is a workaround: with @specialized(Long) the compiler generates a Long-specialized apply method, so the primitive parameter type survives erasure:

      scala> def identity2[@specialized(Long) T, U](f: T => U): T => U = (t: T) => f(t)
      identity2: [T, U](f: T => U)T => U
      
      scala> ScalaReflection.getParameterTypes(identity2(f))
      res10: Seq[Class[_]] = WrappedArray(long)
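
      If the workaround holds, it should restore the null guard end-to-end; an untested sketch reusing the definitions from the sessions above:

      // identity2 keeps the primitive parameter through specialization, so
      // getParameterTypes reports `long` and the planner can insert the same
      // isnull guard it adds for udf0
      val udf2 = udf(identity2(f))
      spark.range(5).toDF().withColumn("udf2", udf2(getNull())).show()
      // expected: udf2 is null on every row, matching udf0 rather than udf1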
      

          People

            Assignee: Unassigned
            Reporter: Valentin Nikotin (rednikotin)