Description
# Reproduction: two Python UDFs built on the `random` module, one applied to
# a literal-only argument, followed by a filter on that UDF's output column.
#
# NOTE(review): `spark` is never defined in this snippet; it assumes an
# existing SparkSession bound to `spark` (e.g. the pyspark shell or a notebook).
import random

from pyspark.sql.functions import col, lit, udf
from pyspark.sql.types import DoubleType


def random_probability(label):
    """Return a random score that agrees with *label*.

    Draws uniformly from [0.5, 1.0] when ``label == 1.0``, otherwise from
    [0.0, 0.4999].
    """
    if label == 1.0:
        return random.uniform(0.5, 1.0)
    return random.uniform(0.0, 0.4999)


def randomize_label(ratio):
    """Return 1.0 when a uniform draw in [0, 1) is >= *ratio*, else 0.0."""
    return 1.0 if random.random() >= ratio else 0.0


# NOTE(review): both UDFs draw from `random`, but udf() treats functions as
# deterministic by default, which leaves the optimizer free to re-evaluate or
# relocate the calls (the trace below fails inside ExtractPythonUDFs). On
# Spark >= 2.3, chaining `.asNondeterministic()` declares them
# nondeterministic and may avoid this failure -- confirm. It is intentionally
# left off here so the snippet still reproduces the reported error.
# The Python function names are kept because the UDF name appears in the
# exception text ("randomize_label(0.9)").
random_probability = udf(random_probability, DoubleType())
randomize_label = udf(randomize_label, DoubleType())

# spark.range(10) produces 10 rows; round-trip them through CSV at /tmp/tab3.
spark.range(10).write.mode("overwrite").format('csv').save("/tmp/tab3")
babydf = spark.read.csv("/tmp/tab3")

# randomize_label is given only a literal (1 - 0.1), no column of babydf.
data_modified_label = babydf.withColumn(
    'random_label', randomize_label(lit(1 - 0.1))
)
# Not used by the failing action below; the error is triggered by filtering
# data_modified_label directly.
data_modified_random = data_modified_label.withColumn(
    'random_probability', random_probability(col('random_label'))
)

# Per the report, this raises Py4JJavaError ("Invalid PythonUDF
# randomize_label(0.9), requires attributes from more than one child").
data_modified_label.filter(col('random_label') == 0).show()
Running the code above fails at the final `show()` call with the following exception:
Py4JJavaError: An error occurred while calling o446.showString. : java.lang.RuntimeException: Invalid PythonUDF randomize_label(0.9), requires attributes from more than one child. at scala.sys.package$.error(package.scala:27) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$$anonfun$org$apache$spark$sql$execution$python$ExtractPythonUDFs$$extract$2.apply(ExtractPythonUDFs.scala:166) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$$anonfun$org$apache$spark$sql$execution$python$ExtractPythonUDFs$$extract$2.apply(ExtractPythonUDFs.scala:165) at scala.collection.immutable.Stream.foreach(Stream.scala:594) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$.org$apache$spark$sql$execution$python$ExtractPythonUDFs$$extract(ExtractPythonUDFs.scala:165) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$$anonfun$apply$2.applyOrElse(ExtractPythonUDFs.scala:116) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$$anonfun$apply$2.applyOrElse(ExtractPythonUDFs.scala:112) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:310) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:310) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:77) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:309) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:327) at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:208) at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:325) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:307) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307) at 
org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:327) at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:208) at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:325) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:307) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:327) at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:208) at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:325) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:307) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:327) at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:208) at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:325) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:307) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$.apply(ExtractPythonUDFs.scala:112) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$.apply(ExtractPythonUDFs.scala:92) at org.apache.spark.sql.execution.QueryExecution$$anonfun$prepareForExecution$1.apply(QueryExecution.scala:119) at org.apache.spark.sql.execution.QueryExecution$$anonfun$prepareForExecution$1.apply(QueryExecution.scala:119) at scala.collection.LinearSeqOptimized$class.foldLeft(LinearSeqOptimized.scala:124) at scala.collection.immutable.List.foldLeft(List.scala:84) at 
org.apache.spark.sql.execution.QueryExecution.prepareForExecution(QueryExecution.scala:119) at org.apache.spark.sql.execution.QueryExecution.executedPlan$lzycompute(QueryExecution.scala:109) at org.apache.spark.sql.execution.QueryExecution.executedPlan(QueryExecution.scala:109) at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3016) at org.apache.spark.sql.Dataset.head(Dataset.scala:2216) at org.apache.spark.sql.Dataset.take(Dataset.scala:2429) at org.apache.spark.sql.Dataset.showString(Dataset.scala:248) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244) at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380) at py4j.Gateway.invoke(Gateway.java:293) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.GatewayConnection.run(GatewayConnection.java:226) at java.lang.Thread.run(Thread.java:748)