Uploaded image for project: 'Spark'
  1. Spark
  2. SPARK-24721

Failed to use PythonUDF with literal inputs in filter with data sources

    XML · Word · Printable · JSON

Details

    • Type: Bug
    • Status: Resolved
    • Priority: Major
    • Resolution: Fixed
    • Affects Version/s: 2.3.1
    • Fix Version/s: 2.4.0
    • Component/s: PySpark, SQL
    • Labels: None

    Description

      import random
      from pyspark.sql.functions import *
      from pyspark.sql.types import *
      
      def random_probability(label):
          """Draw a pseudo-random score that agrees with a binary label.

          A label equal to 1.0 yields ``random.uniform(0.5, 1.0)``; every other
          value (including None) yields ``random.uniform(0.0, 0.4999)``.
          """
          bounds = (0.5, 1.0) if label == 1.0 else (0.0, 0.4999)
          low, high = bounds
          return random.uniform(low, high)
      
      def randomize_label(ratio):
          """Return 1.0 or 0.0 at random, thresholded by *ratio*.

          A single ``random.random()`` draw is compared against *ratio*: 1.0 is
          returned when the draw is at least *ratio*, otherwise 0.0. For a ratio
          in [0, 1] this gives 1.0 with probability of about ``1 - ratio``.
          """
          draw = random.random()
          return 1.0 if draw >= ratio else 0.0
      
      # Wrap both plain functions as Spark UDFs returning DoubleType. This rebinds
      # the names, so the plain Python functions are no longer reachable by them.
      random_probability = udf(random_probability, DoubleType())
      randomize_label = udf(randomize_label, DoubleType())
      
      # Write a 10-row dataset to CSV and read it back, so the query below scans a
      # file-based data source.
      # NOTE(review): assumes `spark` is an active SparkSession (e.g. the pyspark
      # shell); it is not defined in this snippet.
      spark.range(10).write.mode("overwrite").format('csv').save("/tmp/tab3")
      babydf = spark.read.csv("/tmp/tab3")
      # The UDF's only argument is a literal (lit(1 - 0.1), shown as 0.9 in the
      # error message), so it references no column of babydf.
      data_modified_label = babydf.withColumn(
        'random_label', randomize_label(lit(1 - 0.1))
      )
      
      
      # NOTE: data_modified_random is not referenced by the filter/show below.
      data_modified_random = data_modified_label.withColumn(
        'random_probability', 
        random_probability(col('random_label'))
      )
      
      # Filtering on the UDF-produced column and calling show() is the action that
      # fails with "Invalid PythonUDF randomize_label(0.9), requires attributes
      # from more than one child."
      data_modified_label.filter(col('random_label') == 0).show()
      

      The above code will generate the following exception:

      Py4JJavaError: An error occurred while calling o446.showString.
      : java.lang.RuntimeException: Invalid PythonUDF randomize_label(0.9), requires attributes from more than one child.
      	at scala.sys.package$.error(package.scala:27)
      	at org.apache.spark.sql.execution.python.ExtractPythonUDFs$$anonfun$org$apache$spark$sql$execution$python$ExtractPythonUDFs$$extract$2.apply(ExtractPythonUDFs.scala:166)
      	at org.apache.spark.sql.execution.python.ExtractPythonUDFs$$anonfun$org$apache$spark$sql$execution$python$ExtractPythonUDFs$$extract$2.apply(ExtractPythonUDFs.scala:165)
      	at scala.collection.immutable.Stream.foreach(Stream.scala:594)
      	at org.apache.spark.sql.execution.python.ExtractPythonUDFs$.org$apache$spark$sql$execution$python$ExtractPythonUDFs$$extract(ExtractPythonUDFs.scala:165)
      	at org.apache.spark.sql.execution.python.ExtractPythonUDFs$$anonfun$apply$2.applyOrElse(ExtractPythonUDFs.scala:116)
      	at org.apache.spark.sql.execution.python.ExtractPythonUDFs$$anonfun$apply$2.applyOrElse(ExtractPythonUDFs.scala:112)
      	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:310)
      	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:310)
      	at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:77)
      	at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:309)
      	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307)
      	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307)
      	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:327)
      	at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:208)
      	at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:325)
      	at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:307)
      	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307)
      	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307)
      	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:327)
      	at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:208)
      	at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:325)
      	at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:307)
      	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307)
      	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307)
      	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:327)
      	at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:208)
      	at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:325)
      	at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:307)
      	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307)
      	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307)
      	at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:327)
      	at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:208)
      	at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:325)
      	at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:307)
      	at org.apache.spark.sql.execution.python.ExtractPythonUDFs$.apply(ExtractPythonUDFs.scala:112)
      	at org.apache.spark.sql.execution.python.ExtractPythonUDFs$.apply(ExtractPythonUDFs.scala:92)
      	at org.apache.spark.sql.execution.QueryExecution$$anonfun$prepareForExecution$1.apply(QueryExecution.scala:119)
      	at org.apache.spark.sql.execution.QueryExecution$$anonfun$prepareForExecution$1.apply(QueryExecution.scala:119)
      	at scala.collection.LinearSeqOptimized$class.foldLeft(LinearSeqOptimized.scala:124)
      	at scala.collection.immutable.List.foldLeft(List.scala:84)
      	at org.apache.spark.sql.execution.QueryExecution.prepareForExecution(QueryExecution.scala:119)
      	at org.apache.spark.sql.execution.QueryExecution.executedPlan$lzycompute(QueryExecution.scala:109)
      	at org.apache.spark.sql.execution.QueryExecution.executedPlan(QueryExecution.scala:109)
      	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3016)
      	at org.apache.spark.sql.Dataset.head(Dataset.scala:2216)
      	at org.apache.spark.sql.Dataset.take(Dataset.scala:2429)
      	at org.apache.spark.sql.Dataset.showString(Dataset.scala:248)
      	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
      	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
      	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
      	at java.lang.reflect.Method.invoke(Method.java:498)
      	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
      	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
      	at py4j.Gateway.invoke(Gateway.java:293)
      	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
      	at py4j.commands.CallCommand.execute(CallCommand.java:79)
      	at py4j.GatewayConnection.run(GatewayConnection.java:226)
      	at java.lang.Thread.run(Thread.java:748)
      

      Attachments

        Activity

          People

            Assignee: Li Jin (icexelloss)
            Reporter: Xiao Li (smilegator)
            Votes:
            0 (Vote for this issue)
            Watchers:
            5 (Start watching this issue)

            Dates

              Created:
              Updated:
              Resolved: