Spark / SPARK-28125

DataFrames created by randomSplit have overlapping rows



    • Important


      It appears that DataFrame.randomSplit creates a separate execution plan for each of the resulting DataFrames, or at least that is the impression I get from reading a few StackOverflow pages on it.
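      The "one execution per split" behaviour can be modelled in plain Python. This is a simplified sketch, not Spark's code: the names split_bounds, sample_range and random_split_model are hypothetical. Roughly following Spark's Dataset.randomSplit, it normalizes the weights into consecutive [lower, upper) ranges, then evaluates the parent once per split and keeps a row when its seeded uniform draw falls inside that split's range.

```python
import random
from typing import Callable, List, Sequence, Tuple


def split_bounds(weights: Sequence[float]) -> List[Tuple[float, float]]:
    """Normalize the weights and turn them into consecutive [lower, upper) ranges."""
    total = float(sum(weights))
    bounds: List[Tuple[float, float]] = []
    lower = 0.0
    for w in weights:
        upper = lower + w / total
        bounds.append((lower, upper))
        lower = upper
    # Pin the last bound to 1.0 so floating-point drift cannot drop rows in this model.
    bounds[-1] = (bounds[-1][0], 1.0)
    return bounds


def sample_range(rows: list, lower: float, upper: float, seed: int) -> list:
    """One pass over `rows`: keep a row if its seeded draw falls in [lower, upper)."""
    rng = random.Random(seed)  # the same seed replays the same draws on every pass
    return [row for row in rows if lower <= rng.random() < upper]


def random_split_model(compute_parent: Callable[[], list],
                       weights: Sequence[float], seed: int = 42) -> List[list]:
    """Hypothetical model: every output split re-runs compute_parent() from scratch."""
    return [sample_range(compute_parent(), lower, upper, seed)
            for lower, upper in split_bounds(weights)]


calls = []


def deterministic_parent() -> list:
    calls.append(1)  # count how many times the "parent plan" is executed
    return list(range(1000))


train, held_out = random_split_model(deterministic_parent, [0.85, 0.15])
```

      With a deterministic parent, row position i receives the same draw in both passes, so the splits are disjoint and together cover every row, even though the parent ran twice. As far as I can tell from the Spark source, randomSplit also sorts rows within each partition before sampling for this reason: the scheme only works if the parent yields the same rows in the same order on every execution.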




      Because each split is executed separately, it is easy to end up with DataFrames from randomSplit that share rows, for example when the parent DataFrame is computed by a non-deterministic UDF, as in the script below. So if people rely on randomSplit to divide a dataset into training and test sets, they can easily get the same rows in both, which leaks training data into the test set and makes model evaluation unreliable.


      I know that calling .cache() on the DataFrame before calling .randomSplit makes the returned DataFrames disjoint, but this workaround is far from obvious. I did not know about this issue and ended up creating improper datasets for model training and evaluation. randomSplit should be changed so that, under all circumstances, the returned DataFrames never share rows.
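      Why a non-deterministic parent produces overlap, and why caching helps, can be seen in a small pure-Python simulation. It is a sketch under assumptions, not Spark's implementation: split_once, two_way_split and nondeterministic_parent are hypothetical names, the parent imitates the script's random UDF followed by distinct() by returning a fresh set of distinct IDs on every call, and "caching" is modelled as evaluating the parent once and reusing the stored rows for both passes.

```python
import random


def split_once(rows, lower, upper, seed):
    """Keep the rows whose seeded draw lands in [lower, upper)."""
    rng = random.Random(seed)  # position i always gets the same draw for a given seed
    return {row for row in rows if lower <= rng.random() < upper}


def two_way_split(compute_parent, ratio, seed=7):
    """Two separate passes over the parent, one per output split."""
    first = split_once(compute_parent(), 0.0, ratio, seed)
    second = split_once(compute_parent(), ratio, 1.0, seed)
    return first, second


parent_rng = random.Random(0)


def nondeterministic_parent():
    # A different set (and order) of distinct IDs on every evaluation.
    return parent_rng.sample(range(2, 50002), 20000)


# Without caching: each pass sees different rows, so the splits can share IDs.
a, b = two_way_split(nondeterministic_parent, 0.15)
overlap_uncached = len(a & b)

# With "caching": the parent is evaluated once and both passes reuse the result.
materialized = nondeterministic_parent()
a, b = two_way_split(lambda: materialized, 0.15)
overlap_cached = len(a & b)

print("overlap without caching:", overlap_uncached)
print("overlap with caching:", overlap_cached)
```

      The sampling itself is reproducible in both cases; what changes is whether the two passes see the same rows at the same positions.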


      Here is a PySpark script I wrote that reproduces the issue; it also contains, commented out, the line that works around it:


      import numpy as np
      from pyspark.sql import Row, SparkSession
      from pyspark.sql.functions import udf
      from pyspark.sql.types import IntegerType

      # In the pyspark shell `spark` and `sc` already exist; getOrCreate() reuses them.
      spark = SparkSession.builder.getOrCreate()
      sc = spark.sparkContext

      N = 100000
      ratio1 = 0.85
      ratio2 = 0.15

      # Non-deterministic UDF: every evaluation draws a fresh random value.
      gen_rand = udf(lambda x: int(np.random.random() * 50000 + 2), IntegerType())

      # N rows, all with ID = 0.
      orig_list = list(np.zeros(N))
      rdd = sc.parallelize(orig_list).map(int).map(lambda x: {'ID': x})
      df = spark.createDataFrame(rdd.map(lambda x: Row(**x)))

      # Add a random ID2 column and keep only its distinct values.
      dfA = df.withColumn("ID2", gen_rand(df['ID']))
      dfA = dfA.select("ID2").distinct()
      dfA_els = dfA.rdd.map(lambda x: x['ID2']).collect()
      print("This confirms that, in the parent DataFrame, the ID2 column has unique values")
      print("Num rows parent DF: {}".format(len(dfA_els)))
      print("Num unique ID2 vals: {}".format(len(set(dfA_els))))

      # dfA = dfA.cache()  # Uncommenting this line fixes the issue

      df1, df2 = dfA.randomSplit([ratio2, ratio1])
      df1_ids = set(df1.rdd.map(lambda x: x['ID2']).distinct().collect())
      df2_ids = set(df2.rdd.map(lambda x: x['ID2']).distinct().collect())
      num_inter = len(df1_ids.intersection(df2_ids))
      print("Number of common IDs between the two splits: {}".format(num_inter))
      print("(should be zero if randomSplit is working as expected)")
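      One way to get disjoint splits that does not depend on caching is to assign each row by hashing a stable key instead of drawing a random number. This is a different technique from randomSplit, shown here only as a pure-Python sketch: bucket_fraction and hash_split are hypothetical helpers, and it assumes every row has a key (here ID2) whose value does not change between evaluations of the parent.

```python
import hashlib
import random


def bucket_fraction(key) -> float:
    """Map a key to a stable value in [0, 1) via MD5, identical across runs and processes."""
    digest = hashlib.md5(str(key).encode("utf-8")).hexdigest()
    return int(digest[:8], 16) / 2**32


def hash_split(rows, key, ratio):
    """Put a row in the first split when its key's bucket is below `ratio`."""
    first, second = [], []
    for row in rows:
        (first if bucket_fraction(key(row)) < ratio else second).append(row)
    return first, second


rng = random.Random(1)


def parent():
    # Non-deterministic parent: different distinct IDs on every call.
    return [{"ID2": v} for v in rng.sample(range(2, 50002), 20000)]


# Each split comes from a separate evaluation of the parent, as with randomSplit.
test_rows, _ = hash_split(parent(), lambda r: r["ID2"], 0.15)
_, train_rows = hash_split(parent(), lambda r: r["ID2"], 0.15)
test_ids = {r["ID2"] for r in test_rows}
train_ids = {r["ID2"] for r in train_rows}
print("common IDs:", len(test_ids & train_ids))
```

      Because a key's split depends only on the key, no ID can land in both sets, however many times the parent is recomputed. The trade-offs are that split sizes are only approximately proportional to the ratio, and rows sharing a key always go to the same split.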




            Assignee: Unassigned
            Reporter: Zachary (zacharydestefano)

