diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java index 7a3280c..f231b06 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java @@ -228,7 +228,7 @@ private void convertJoinSMBJoin(JoinOperator joinOp, OptimizeTezProcContext cont @SuppressWarnings("unchecked") CommonMergeJoinOperator mergeJoinOp = (CommonMergeJoinOperator) OperatorFactory.get(new CommonMergeJoinDesc(numBuckets, - isSubQuery, mapJoinConversionPos, mapJoinDesc)); + isSubQuery, mapJoinConversionPos, mapJoinDesc), joinOp.getSchema()); OpTraits opTraits = new OpTraits(joinOp.getOpTraits().getBucketColNames(), numBuckets, joinOp.getOpTraits() .getSortCols()); diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/MergeJoinProc.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/MergeJoinProc.java index 8516643..e3c8727 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/MergeJoinProc.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/MergeJoinProc.java @@ -22,21 +22,6 @@ import org.apache.hadoop.hive.ql.plan.TezWork.VertexType; public class MergeJoinProc implements NodeProcessor { - - public Operator getLeafOperator(Operator op) { - for (Operator childOp : op.getChildOperators()) { - // FileSink or ReduceSink operators are used to create vertices. See - // TezCompiler. - if ((childOp instanceof ReduceSinkOperator) || (childOp instanceof FileSinkOperator)) { - return childOp; - } else { - return getLeafOperator(childOp); - } - } - - return null; - } - @Override public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, Object... nodeOutputs) @@ -60,13 +45,13 @@ // merge work already exists for this merge join operator, add the dummy store work to the // merge work. Else create a merge work, add above work to the merge work MergeJoinWork mergeWork = null; - if (context.opMergeJoinWorkMap.containsKey(getLeafOperator(mergeJoinOp))) { + if (context.opMergeJoinWorkMap.containsKey(mergeJoinOp)) { // we already have the merge work corresponding to this merge join operator - mergeWork = context.opMergeJoinWorkMap.get(getLeafOperator(mergeJoinOp)); + mergeWork = context.opMergeJoinWorkMap.get(mergeJoinOp); } else { mergeWork = new MergeJoinWork(); tezWork.add(mergeWork); - context.opMergeJoinWorkMap.put(getLeafOperator(mergeJoinOp), mergeWork); + context.opMergeJoinWorkMap.put(mergeJoinOp, mergeWork); } mergeWork.setMergeJoinOperator(mergeJoinOp); diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezWork.java ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezWork.java index 516e576..59a6327 100644 --- ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezWork.java +++ ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezWork.java @@ -137,15 +137,15 @@ public Object process(Node nd, Stack stack, // we are currently walking the big table side of the merge join. we need to create or hook up // merge join work. MergeJoinWork mergeJoinWork = null; - if (context.opMergeJoinWorkMap.containsKey(operator)) { + if (context.opMergeJoinWorkMap.containsKey(context.currentMergeJoinOperator)) { // we have found a merge work corresponding to this closing operator. Hook up this work. - mergeJoinWork = context.opMergeJoinWorkMap.get(operator); + mergeJoinWork = context.opMergeJoinWorkMap.get(context.currentMergeJoinOperator); } else { // we need to create the merge join work mergeJoinWork = new MergeJoinWork(); mergeJoinWork.setMergeJoinOperator(context.currentMergeJoinOperator); tezWork.add(mergeJoinWork); - context.opMergeJoinWorkMap.put(operator, mergeJoinWork); + context.opMergeJoinWorkMap.put(context.currentMergeJoinOperator, mergeJoinWork); } // connect the work correctly. mergeJoinWork.addMergedWork(work, null); @@ -334,10 +334,15 @@ public Object process(Node nd, Stack stack, UnionWork unionWork = (UnionWork) followingWork; int index = getMergeIndex(tezWork, unionWork, rs); // guaranteed to be instance of MergeJoinWork if index is valid - MergeJoinWork mergeJoinWork = (MergeJoinWork) tezWork.getChildren(unionWork).get(index); - // disconnect the connection to union work and connect to merge work - followingWork = mergeJoinWork; - rWork = (ReduceWork) mergeJoinWork.getMainWork(); + BaseWork baseWork = tezWork.getChildren(unionWork).get(index); + if (baseWork instanceof MergeJoinWork) { + MergeJoinWork mergeJoinWork = (MergeJoinWork) baseWork; + // disconnect the connection to union work and connect to merge work + followingWork = mergeJoinWork; + rWork = (ReduceWork) mergeJoinWork.getMainWork(); + } else { + rWork = (ReduceWork) baseWork; + } } else { rWork = (ReduceWork) followingWork; }