Spark / SPARK-11405

ROW_NUMBER function does not adhere to window ORDER BY, when joining


Details

    • Type: Bug
    • Status: Closed
    • Priority: Critical
    • Resolution: Fixed
    • Affects Version/s: 1.5.0
    • Fix Version/s: 1.5.2
    • Component/s: SQL
    • Labels: None
    • Environment: YARN

    Description

      The following query produces incorrect results:

      sqlContext.sql("""
        SELECT a.i, a.x,
          ROW_NUMBER() OVER (
            PARTITION BY a.i ORDER BY a.x) AS row_num
        FROM a
        JOIN b ON b.i = a.i
      """).show()
      +---+--------------------+-------+
      |  i|                   x|row_num|
      +---+--------------------+-------+
      |  1|  0.8717439935587555|      1|
      |  1|  0.6684483939068196|      2|
      |  1|  0.3378351523586306|      3|
      |  1|  0.2483285619632939|      4|
      |  1|  0.4796752841655936|      5|
      |  2|  0.2971739640384895|      1|
      |  2|  0.2199359901600595|      2|
      |  2|  0.4646004597998037|      3|
      |  2| 0.24823688829578183|      4|
      |  2|  0.5914212915574378|      5|
      |  3|0.010912835935112164|      1|
      |  3|  0.6520139509583123|      2|
      |  3|  0.8571994559240592|      3|
      |  3|  0.1122635843020473|      4|
      |  3| 0.45913022936460457|      5|
      +---+--------------------+-------+
      

      The row numbers don't follow the window's ORDER BY clause. The join appears to break the ordering: ROW_NUMBER() works correctly if the join result is first saved to a temporary table and the window function is applied in a second query.
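      For reference, here is a minimal pure-Python sketch (an illustration only, not Spark code and not part of the original report) of the semantics ROW_NUMBER() OVER (PARTITION BY i ORDER BY x) should follow, i.e. the ordering the workaround below actually produces:

```python
# Pure-Python reference for ROW_NUMBER() OVER (PARTITION BY i ORDER BY x):
# rows sharing the same partition key i are sorted by x and numbered from 1.
from itertools import groupby
from operator import itemgetter

def row_number(rows):
    """rows: iterable of (i, x) pairs; returns (i, x, row_num) tuples."""
    numbered = []
    # Sort by (i, x) so each partition's rows are contiguous and ordered by x.
    for i, group in groupby(sorted(rows, key=itemgetter(0, 1)),
                            key=itemgetter(0)):
        for num, (_, x) in enumerate(group, start=1):
            numbered.append((i, x, num))
    return numbered

sample = [(1, 0.87), (1, 0.25), (2, 0.30), (2, 0.22), (1, 0.34)]
print(row_number(sample))
# [(1, 0.25, 1), (1, 0.34, 2), (1, 0.87, 3), (2, 0.22, 1), (2, 0.30, 2)]
```

      In the buggy output below, the numbering instead follows the (unordered) row order coming out of the join, not the sort on x.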

      Here's a small PySpark test case to reproduce the error:

      from pyspark.sql import Row
      import random
      
      a = sc.parallelize([Row(i=i, x=random.random())
                          for i in range(5)
                          for j in range(5)])
      
      b = sc.parallelize([Row(i=i) for i in [1, 2, 3]])
      
      af = sqlContext.createDataFrame(a)
      bf = sqlContext.createDataFrame(b)
      af.registerTempTable('a')
      bf.registerTempTable('b')
      
      af.show()
      # +---+--------------------+
      # |  i|                   x|
      # +---+--------------------+
      # |  0| 0.12978974167478896|
      # |  0|  0.7105927498584452|
      # |  0| 0.21225679077448045|
      # |  0| 0.03849717391728036|
      # |  0|  0.4976622146442401|
      # |  1|  0.4796752841655936|
      # |  1|  0.8717439935587555|
      # |  1|  0.6684483939068196|
      # |  1|  0.3378351523586306|
      # |  1|  0.2483285619632939|
      # |  2|  0.2971739640384895|
      # |  2|  0.2199359901600595|
      # |  2|  0.5914212915574378|
      # |  2| 0.24823688829578183|
      # |  2|  0.4646004597998037|
      # |  3|  0.1122635843020473|
      # |  3|  0.6520139509583123|
      # |  3| 0.45913022936460457|
      # |  3|0.010912835935112164|
      # |  3|  0.8571994559240592|
      # +---+--------------------+
      # only showing top 20 rows
      
      bf.show()
      # +---+
      # |  i|
      # +---+
      # |  1|
      # |  2|
      # |  3|
      # +---+
      
      ### WRONG
      
      sqlContext.sql("""
        SELECT a.i, a.x,
          ROW_NUMBER() OVER (
            PARTITION BY a.i ORDER BY a.x) AS row_num
        FROM a
        JOIN b ON b.i = a.i
      """).show()
      # +---+--------------------+-------+
      # |  i|                   x|row_num|
      # +---+--------------------+-------+
      # |  1|  0.8717439935587555|      1|
      # |  1|  0.6684483939068196|      2|
      # |  1|  0.3378351523586306|      3|
      # |  1|  0.2483285619632939|      4|
      # |  1|  0.4796752841655936|      5|
      # |  2|  0.2971739640384895|      1|
      # |  2|  0.2199359901600595|      2|
      # |  2|  0.4646004597998037|      3|
      # |  2| 0.24823688829578183|      4|
      # |  2|  0.5914212915574378|      5|
      # |  3|0.010912835935112164|      1|
      # |  3|  0.6520139509583123|      2|
      # |  3|  0.8571994559240592|      3|
      # |  3|  0.1122635843020473|      4|
      # |  3| 0.45913022936460457|      5|
      # +---+--------------------+-------+
      
      ### WORKAROUND BY USING TEMP TABLE
      
      t = sqlContext.sql("""
        SELECT a.i, a.x
        FROM a
        JOIN b ON b.i = a.i
      """).cache()
      # trigger computation
      t.head()
      t.registerTempTable('t')
      
      sqlContext.sql("""
        SELECT i, x,
          ROW_NUMBER() OVER (
            PARTITION BY i ORDER BY x) AS row_num
        FROM t
      """).show()
      # +---+--------------------+-------+
      # |  i|                   x|row_num|
      # +---+--------------------+-------+
      # |  1|  0.2483285619632939|      1|
      # |  1|  0.3378351523586306|      2|
      # |  1|  0.4796752841655936|      3|
      # |  1|  0.6684483939068196|      4|
      # |  1|  0.8717439935587555|      5|
      # |  2|  0.2199359901600595|      1|
      # |  2| 0.24823688829578183|      2|
      # |  2|  0.2971739640384895|      3|
      # |  2|  0.4646004597998037|      4|
      # |  2|  0.5914212915574378|      5|
      # |  3|0.010912835935112164|      1|
      # |  3|  0.1122635843020473|      2|
      # |  3| 0.45913022936460457|      3|
      # |  3|  0.6520139509583123|      4|
      # |  3|  0.8571994559240592|      5|
      # +---+--------------------+-------+
      


          People

            Assignee: Josh Rosen (joshrosen)
            Reporter: Jarno Seppanen (jseppanen)
            Votes: 0
            Watchers: 5
