Description
Currently, `ShuffledHashJoinExec.outputPartitioning` is inherited from `HashJoin.outputPartitioning`, which preserves only the stream-side partitioning:
`HashJoin.scala`
    override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
This loses the build-side partitioning information and causes an extra shuffle if another join or group-by follows this join.
Example:
    withSQLConf(
        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "50",
        SQLConf.SHUFFLE_PARTITIONS.key -> "2",
        SQLConf.PREFER_SORTMERGEJOIN.key -> "false") {
      val df1 = spark.range(10).select($"id".as("k1"))
      val df2 = spark.range(30).select($"id".as("k2"))
      Seq("inner", "cross").foreach(joinType => {
        val plan = df1.join(df2, $"k1" === $"k2", joinType).groupBy($"k1").count()
          .queryExecution.executedPlan
        assert(plan.collect { case _: ShuffledHashJoinExec => true }.size === 1)
        // No extra shuffle before aggregate
        assert(plan.collect { case _: ShuffleExchangeExec => true }.size === 2)
      })
    }
Current physical plan (having an extra shuffle on `k1` before aggregate)
    *(4) HashAggregate(keys=[k1#220L], functions=[count(1)], output=[k1#220L, count#235L])
    +- Exchange hashpartitioning(k1#220L, 2), true, [id=#117]
       +- *(3) HashAggregate(keys=[k1#220L], functions=[partial_count(1)], output=[k1#220L, count#239L])
          +- *(3) Project [k1#220L]
             +- ShuffledHashJoin [k1#220L], [k2#224L], Inner, BuildLeft
                :- Exchange hashpartitioning(k1#220L, 2), true, [id=#109]
                :  +- *(1) Project [id#218L AS k1#220L]
                :     +- *(1) Range (0, 10, step=1, splits=2)
                +- Exchange hashpartitioning(k2#224L, 2), true, [id=#111]
                   +- *(2) Project [id#222L AS k2#224L]
                      +- *(2) Range (0, 30, step=1, splits=2)
Ideal physical plan (no shuffle on `k1` before aggregate)
    *(3) HashAggregate(keys=[k1#220L], functions=[count(1)], output=[k1#220L, count#235L])
    +- *(3) HashAggregate(keys=[k1#220L], functions=[partial_count(1)], output=[k1#220L, count#239L])
       +- *(3) Project [k1#220L]
          +- ShuffledHashJoin [k1#220L], [k2#224L], Inner, BuildLeft
             :- Exchange hashpartitioning(k1#220L, 2), true, [id=#107]
             :  +- *(1) Project [id#218L AS k1#220L]
             :     +- *(1) Range (0, 10, step=1, splits=2)
             +- Exchange hashpartitioning(k2#224L, 2), true, [id=#109]
                +- *(2) Project [id#222L AS k2#224L]
                   +- *(2) Range (0, 30, step=1, splits=2)
This can be fixed by overriding the `outputPartitioning` method in `ShuffledHashJoinExec`, as `SortMergeJoinExec` already does.
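To illustrate why reporting both sides removes the extra shuffle, here is a minimal, self-contained sketch. It does not use Spark's classes: `Partitioning`, `HashPartitioning`, `PartitioningCollection`, `InnerShuffledHashJoin`, and `satisfiesClustering` are simplified stand-ins written for this example, and only the inner-join case is modeled. The actual change would need to handle the other join types as well.

```scala
// Simplified stand-ins for Spark's partitioning classes (hypothetical, not the real API).
sealed trait Partitioning {
  // True if rows that agree on all `required` keys are guaranteed to share a partition.
  def satisfiesClustering(required: Set[String]): Boolean
}

case class HashPartitioning(keys: Seq[String], numPartitions: Int) extends Partitioning {
  // Hash partitioning on a non-empty subset of the required keys is sufficient.
  def satisfiesClustering(required: Set[String]): Boolean =
    keys.nonEmpty && keys.forall(required.contains)
}

case class PartitioningCollection(partitionings: Seq[Partitioning]) extends Partitioning {
  // The output satisfies a requirement if any member partitioning does.
  def satisfiesClustering(required: Set[String]): Boolean =
    partitionings.exists(_.satisfiesClustering(required))
}

sealed trait BuildSide
case object BuildLeft extends BuildSide
case object BuildRight extends BuildSide

// Toy model of an inner shuffled hash join over two already-shuffled children.
case class InnerShuffledHashJoin(left: Partitioning, right: Partitioning, buildSide: BuildSide) {
  // Behavior described in this ticket: only the streamed (non-build) side is reported.
  def streamedOnlyPartitioning: Partitioning = buildSide match {
    case BuildLeft  => right
    case BuildRight => left
  }

  // Proposed behavior for inner joins: report both sides.
  def bothSidesPartitioning: Partitioning = PartitioningCollection(Seq(left, right))
}

object Demo {
  // Mirrors the example: left hashed on k1, right hashed on k2, BuildLeft.
  val join = InnerShuffledHashJoin(
    HashPartitioning(Seq("k1"), 2), HashPartitioning(Seq("k2"), 2), BuildLeft)

  // A group-by on k1 requires clustering on k1.
  val needsShuffleBefore: Boolean = !join.streamedOnlyPartitioning.satisfiesClustering(Set("k1"))
  val needsShuffleAfter: Boolean = !join.bothSidesPartitioning.satisfiesClustering(Set("k1"))
}
```

In this model, with `BuildLeft` the streamed side is partitioned on `k2`, so a group-by on `k1` sees no usable partitioning and a shuffle is required. Once both sides are reported, the `k1` partitioning from the build side is visible and the aggregate can run without another exchange.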