Description
When multiple Python UDFs are used, only the accumulator updates made by the last Python UDF are applied.
import pyspark
from pyspark.sql import SparkSession, Row
from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark import AccumulatorParam

spark = SparkSession.builder.getOrCreate()
spark.sparkContext.setLogLevel("ERROR")

# Shared accumulator, updated by both UDFs below.
# For the single input row, expect +1.0 from test and +100.0 from test2 => 101.0.
test_accum = spark.sparkContext.accumulator(0.0)

SHUFFLE = False

def main(data):
    print(">>> Check0", test_accum.value)

    def test(x):
        global test_accum
        test_accum += 1.0
        return x

    print(">>> Check1", test_accum.value)

    def test2(x):
        global test_accum
        test_accum += 100.0
        return x

    print(">>> Check2", test_accum.value)

    func_udf = F.udf(test, T.DoubleType())
    print(">>> Check3", test_accum.value)
    func_udf2 = F.udf(test2, T.DoubleType())
    print(">>> Check4", test_accum.value)

    data = data.withColumn("out1", func_udf(data["a"]))
    if SHUFFLE:
        data = data.repartition(2)
    print(">>> Check5", test_accum.value)

    data = data.withColumn("out2", func_udf2(data["b"]))
    if SHUFFLE:
        data = data.repartition(2)
    print(">>> Check6", test_accum.value)

    data.show()  # ACTION
    print(">>> Check7", test_accum.value)
    return data

df = spark.createDataFrame(
    [[1.0, 2.0]],
    schema=T.StructType([
        T.StructField(field_name, T.DoubleType(), True)
        for field_name in ["a", "b"]
    ]),
)
df2 = main(df)
######## Output 1 - with SHUFFLE=False
...
# >>> Check7 100.0

######## Output 2 - with SHUFFLE=True
...
# >>> Check7 101.0

With SHUFFLE=False only the second UDF's +100.0 is recorded; with SHUFFLE=True both updates (+1.0 and +100.0) arrive, giving 101.0.
In short, it looks like:
- Accumulator updates are applied only for the last UDF before a shuffle-like operation, as shown in the sketch below
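
Consistent with the SHUFFLE=True output above, a workaround is to force a shuffle boundary between the two UDFs so they are not evaluated in the same stage. The sketch below is a minimal, self-contained version of that idea (the function and variable names are illustrative, not from the report):

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import types as T

spark = SparkSession.builder.getOrCreate()
acc = spark.sparkContext.accumulator(0.0)

def add_one(x):
    acc.add(1.0)       # update from the first UDF
    return x

def add_hundred(x):
    acc.add(100.0)     # update from the second UDF
    return x

udf1 = F.udf(add_one, T.DoubleType())
udf2 = F.udf(add_hundred, T.DoubleType())

df = spark.createDataFrame([[1.0, 2.0]], schema="a double, b double")
df = df.withColumn("out1", udf1(df["a"]))
df = df.repartition(2)  # shuffle boundary keeps the two UDFs in separate stages
df = df.withColumn("out2", udf2(df["b"]))
df.show()

# With the repartition this prints 101.0 (both updates applied),
# matching Output 2; dropping it reproduces the 100.0 from Output 1.
print(acc.value)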