diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java index 02cebdc5ac..7cb5dcb8bf 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java @@ -481,6 +481,11 @@ private void semijoinRemovalBasedTransformations(OptimizeTezProcContext procCtx, markSemiJoinForDPP(procCtx); perfLogger.PerfLogEnd(this.getClass().getName(), PerfLogger.TEZ_COMPILER, "Mark certain semijoin edges important based "); + // Remove any semi join edges from Union Op + perfLogger.PerfLogBegin(this.getClass().getName(), PerfLogger.TEZ_COMPILER); + removeSemiJoinEdgesForUnion(procCtx); + perfLogger.PerfLogEnd(this.getClass().getName(), PerfLogger.TEZ_COMPILER, "Remove any semi join edge between Union and RS"); + // Remove any parallel edge between semijoin and mapjoin. perfLogger.PerfLogBegin(this.getClass().getName(), PerfLogger.TEZ_COMPILER); removeSemijoinsParallelToMapJoin(procCtx); @@ -1313,6 +1318,46 @@ private boolean findParallelSemiJoinBranch(Operator mapjoin, TableScanOperato return parallelEdges; } + /* + * Given an operator this method removes all semi join edges downstream (children) until it hits RS + */ + private void removeSemiJoinEdges(Operator op, OptimizeTezProcContext procCtx) throws SemanticException { + if(op instanceof ReduceSinkOperator && op.getNumChild() == 0) { + Map sjMap = procCtx.parseContext.getRsToSemiJoinBranchInfo(); + if(sjMap.get(op) != null) { + //remove semi join + GenTezUtils.removeBranch(op); + GenTezUtils.removeSemiJoinOperator(procCtx.parseContext, (ReduceSinkOperator)op, sjMap.get(op).getTsOp()); + } else { + return; + } + } + for(Operator child:op.getChildOperators()) { + removeSemiJoinEdges(child, procCtx); + } + } + + private void removeSemiJoinEdgesForUnion(OptimizeTezProcContext procCtx) throws SemanticException{ + // Get all the TS ops. + List> topOps = new ArrayList<>(); + topOps.addAll(procCtx.parseContext.getTopOps().values()); + Set> unionOps = new HashSet<>(); + + Map semijoins = new HashMap<>(); + for (Operator parent : topOps) { + Deque> deque = new LinkedList<>(); + deque.add(parent); + while (!deque.isEmpty()) { + Operator op = deque.pollLast(); + if (op instanceof UnionOperator && !unionOps.contains(op)) { + removeSemiJoinEdges(op, procCtx); + unionOps.add(op); + } + deque.addAll(op.getChildOperators()); + } + } + } + /* * The algorithm looks at all the mapjoins in the operator pipeline until * it hits RS Op and for each mapjoin examines if it has paralllel semijoin