Description
The following query produces incorrect results:
sqlContext.sql("""
    SELECT a.i, a.x, ROW_NUMBER() OVER (
        PARTITION BY a.i ORDER BY a.x) AS row_num
    FROM a JOIN b ON b.i = a.i
""").show()

+---+--------------------+-------+
|  i|                   x|row_num|
+---+--------------------+-------+
|  1|  0.8717439935587555|      1|
|  1|  0.6684483939068196|      2|
|  1|  0.3378351523586306|      3|
|  1|  0.2483285619632939|      4|
|  1|  0.4796752841655936|      5|
|  2|  0.2971739640384895|      1|
|  2|  0.2199359901600595|      2|
|  2|  0.4646004597998037|      3|
|  2| 0.24823688829578183|      4|
|  2|  0.5914212915574378|      5|
|  3|0.010912835935112164|      1|
|  3|  0.6520139509583123|      2|
|  3|  0.8571994559240592|      3|
|  3|  0.1122635843020473|      4|
|  3| 0.45913022936460457|      5|
+---+--------------------+-------+
The row numbers don't follow the correct order: the join appears to break the ordering within each partition. ROW_NUMBER() works correctly if the join result is first saved to a temporary table and the window function is applied in a second query.
Here's a small PySpark test case to reproduce the error:
from pyspark.sql import Row
import random

a = sc.parallelize([Row(i=i, x=random.random()) for i in range(5) for j in range(5)])
b = sc.parallelize([Row(i=i) for i in [1, 2, 3]])
af = sqlContext.createDataFrame(a)
bf = sqlContext.createDataFrame(b)
af.registerTempTable('a')
bf.registerTempTable('b')

af.show()
# +---+--------------------+
# |  i|                   x|
# +---+--------------------+
# |  0| 0.12978974167478896|
# |  0|  0.7105927498584452|
# |  0| 0.21225679077448045|
# |  0| 0.03849717391728036|
# |  0|  0.4976622146442401|
# |  1|  0.4796752841655936|
# |  1|  0.8717439935587555|
# |  1|  0.6684483939068196|
# |  1|  0.3378351523586306|
# |  1|  0.2483285619632939|
# |  2|  0.2971739640384895|
# |  2|  0.2199359901600595|
# |  2|  0.5914212915574378|
# |  2| 0.24823688829578183|
# |  2|  0.4646004597998037|
# |  3|  0.1122635843020473|
# |  3|  0.6520139509583123|
# |  3| 0.45913022936460457|
# |  3|0.010912835935112164|
# |  3|  0.8571994559240592|
# +---+--------------------+
# only showing top 20 rows

bf.show()
# +---+
# |  i|
# +---+
# |  1|
# |  2|
# |  3|
# +---+

### WRONG
sqlContext.sql("""
    SELECT a.i, a.x, ROW_NUMBER() OVER (
        PARTITION BY a.i ORDER BY a.x) AS row_num
    FROM a JOIN b ON b.i = a.i
""").show()
# +---+--------------------+-------+
# |  i|                   x|row_num|
# +---+--------------------+-------+
# |  1|  0.8717439935587555|      1|
# |  1|  0.6684483939068196|      2|
# |  1|  0.3378351523586306|      3|
# |  1|  0.2483285619632939|      4|
# |  1|  0.4796752841655936|      5|
# |  2|  0.2971739640384895|      1|
# |  2|  0.2199359901600595|      2|
# |  2|  0.4646004597998037|      3|
# |  2| 0.24823688829578183|      4|
# |  2|  0.5914212915574378|      5|
# |  3|0.010912835935112164|      1|
# |  3|  0.6520139509583123|      2|
# |  3|  0.8571994559240592|      3|
# |  3|  0.1122635843020473|      4|
# |  3| 0.45913022936460457|      5|
# +---+--------------------+-------+

### WORKAROUND BY USING TEMP TABLE
t = sqlContext.sql("""
    SELECT a.i, a.x
    FROM a JOIN b ON b.i = a.i
""").cache()
# trigger computation
t.head()
t.registerTempTable('t')
sqlContext.sql("""
    SELECT i, x, ROW_NUMBER() OVER (
        PARTITION BY i ORDER BY x) AS row_num
    FROM t
""").show()
# +---+--------------------+-------+
# |  i|                   x|row_num|
# +---+--------------------+-------+
# |  1|  0.2483285619632939|      1|
# |  1|  0.3378351523586306|      2|
# |  1|  0.4796752841655936|      3|
# |  1|  0.6684483939068196|      4|
# |  1|  0.8717439935587555|      5|
# |  2|  0.2199359901600595|      1|
# |  2| 0.24823688829578183|      2|
# |  2|  0.2971739640384895|      3|
# |  2|  0.4646004597998037|      4|
# |  2|  0.5914212915574378|      5|
# |  3|0.010912835935112164|      1|
# |  3|  0.1122635843020473|      2|
# |  3| 0.45913022936460457|      3|
# |  3|  0.6520139509583123|      4|
# |  3|  0.8571994559240592|      5|
# +---+--------------------+-------+
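For reference, the semantics the WRONG query should have produced (and which the temp-table workaround does produce) can be sketched in plain Python outside Spark. This is an illustrative sketch, not Spark code; the helper name `row_numbers` is hypothetical:

```python
from collections import defaultdict

def row_numbers(rows):
    """Mimic ROW_NUMBER() OVER (PARTITION BY i ORDER BY x):
    within each partition key i, rows sorted ascending by x
    are numbered 1, 2, 3, ..."""
    by_key = defaultdict(list)
    for i, x in rows:
        by_key[i].append(x)
    out = []
    for i in sorted(by_key):
        for n, x in enumerate(sorted(by_key[i]), start=1):
            out.append((i, x, n))
    return out

# row_num must be 1 for the smallest x in each partition,
# regardless of the input row order
rows = [(1, 0.87), (1, 0.66), (2, 0.29), (1, 0.33), (2, 0.21)]
print(row_numbers(rows))
# [(1, 0.33, 1), (1, 0.66, 2), (1, 0.87, 3), (2, 0.21, 1), (2, 0.29, 2)]
```

In the WRONG output above, the numbering instead follows the physical row order produced by the join, which is what suggests the ORDER BY inside the window is being dropped when the window operator sits directly on top of the join.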