diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkReduceSinkMapJoinProc.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkReduceSinkMapJoinProc.java
index 92600be..2ea8da3 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkReduceSinkMapJoinProc.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkReduceSinkMapJoinProc.java
@@ -26,12 +26,15 @@
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.ql.exec.HashTableDummyOperator;
+import org.apache.hadoop.hive.ql.exec.HashTableSinkOperator;
 import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
 import org.apache.hadoop.hive.ql.exec.Operator;
 import org.apache.hadoop.hive.ql.exec.OperatorFactory;
 import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
 import org.apache.hadoop.hive.ql.exec.RowSchema;
+import org.apache.hadoop.hive.ql.exec.TableScanOperator;
 import org.apache.hadoop.hive.ql.lib.Node;
 import org.apache.hadoop.hive.ql.lib.NodeProcessor;
 import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
@@ -40,6 +43,8 @@
 import org.apache.hadoop.hive.ql.plan.BaseWork;
 import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
 import org.apache.hadoop.hive.ql.plan.HashTableDummyDesc;
+import org.apache.hadoop.hive.ql.plan.HashTableSinkDesc;
+import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
 import org.apache.hadoop.hive.ql.plan.OperatorDesc;
 import org.apache.hadoop.hive.ql.plan.PlanUtils;
 import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
@@ -55,7 +60,7 @@
 
   /* (non-Javadoc)
    * This processor addresses the RS-MJ case that occurs in spark on the small/hash
-   * table side of things. The work that RS will be a part of must be connected
+   * table side of things. The work that RS will be a part of must be connected
    * to the MJ work via be a broadcast edge.
    * We should not walk down the tree when we encounter this pattern because:
    * the type of work (map work or reduce work) needs to be determined
@@ -139,7 +144,7 @@ public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procContext,
       SparkWork sparkWork = context.currentTask.getWork();
       LOG.debug("connecting "+parentWork.getName()+" with "+myWork.getName());
       sparkWork.connect(parentWork, myWork, edgeProp);
-
+
       ReduceSinkOperator r = null;
       if (parentRS.getConf().getOutputName() != null) {
         LOG.debug("Cloning reduce sink for multi-child broadcast edge");
@@ -165,8 +170,8 @@ public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procContext,
       }
       linkWorkMap.put(parentWork, edgeProp);
       context.linkOpWithWorkMap.put(mapJoinOp, linkWorkMap);
-
-      List<ReduceSinkOperator> reduceSinks
+
+      List<ReduceSinkOperator> reduceSinks
          = context.linkWorkWithReduceSinkMap.get(parentWork);
       if (reduceSinks == null) {
         reduceSinks = new ArrayList<ReduceSinkOperator>();
@@ -225,6 +230,20 @@ public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procContext,
     }
     context.linkChildOpWithDummyOp.put(mapJoinOp, dummyOperators);
 
+
+    //replace ReduceSinkOp with HashTableSinkOp for the RSops which are parents of MJop
+    MapJoinDesc mjDesc = mapJoinOp.getConf();
+
+    HashTableSinkDesc hashTableSinkDesc = new HashTableSinkDesc(mjDesc);
+    HashTableSinkOperator hashTableSinkOp = (HashTableSinkOperator) OperatorFactory
+        .get(hashTableSinkDesc);
+
+    //get all parents of reduce sink
+    List<Operator<? extends OperatorDesc>> RSparentOps = parentRS.getParentOperators();
+    for (Operator<? extends OperatorDesc> parent : RSparentOps) {
+      parent.replaceChild(parentRS, hashTableSinkOp);
+    }
+    hashTableSinkOp.setParentOperators(RSparentOps);
     return true;
   }
 }
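
Reviewer note on the last hunk: the added code swaps the small-table ReduceSinkOperator out of the operator DAG and splices a HashTableSinkOperator into its place, first by rewriting each parent's child pointer and then by having the new operator adopt those parents. The new operator is built from the existing MapJoinDesc (new HashTableSinkDesc(mjDesc)), so it carries the map join's descriptor information. The sketch below illustrates only that rewiring step on a minimal stand-in graph; PlanNode, its fields, and the node names are invented for illustration and are not Hive classes or APIs.

import java.util.ArrayList;
import java.util.List;

// Minimal stand-in for an operator-DAG node; NOT Hive's Operator class.
class PlanNode {
  final String name;
  final List<PlanNode> parents = new ArrayList<>();
  final List<PlanNode> children = new ArrayList<>();

  PlanNode(String name) { this.name = name; }

  // Plays the role of Operator.replaceChild(): swap one child reference in place.
  void replaceChild(PlanNode oldChild, PlanNode newChild) {
    children.set(children.indexOf(oldChild), newChild);
  }
}

public class RewireSketch {
  public static void main(String[] args) {
    PlanNode tableScan = new PlanNode("TS[small table]");
    PlanNode reduceSink = new PlanNode("RS");
    PlanNode hashTableSink = new PlanNode("HTS");

    // Small-table side before the patch: TS -> RS.
    tableScan.children.add(reduceSink);
    reduceSink.parents.add(tableScan);

    // What the new hunk does: every parent of RS now points at HTS instead,
    // and HTS adopts RS's parent list, leaving RS detached on this side.
    for (PlanNode parent : new ArrayList<>(reduceSink.parents)) {
      parent.replaceChild(reduceSink, hashTableSink);
    }
    hashTableSink.parents.addAll(reduceSink.parents);

    System.out.println(tableScan.children.get(0).name);    // prints HTS
    System.out.println(hashTableSink.parents.get(0).name); // prints TS[small table]
  }
}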