diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java index cf8e843..47b229f 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java @@ -571,53 +571,52 @@ protected void optimizeTaskPlan(List> rootTasks, Pa return; } - private static class SemijoinRemovalContext implements NodeProcessorCtx { - List> parents = new ArrayList>(); + private static class SMBJoinOpProcContext implements NodeProcessorCtx { + HashMap JoinOpToTsOpMap = new HashMap(); } - private static class SemijoinRemovalProc implements NodeProcessor { + private static class SMBJoinOpProc implements NodeProcessor { @Override public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, Object... nodeOutputs) throws SemanticException { - SemijoinRemovalContext ctx = (SemijoinRemovalContext) procCtx; - Operator parent = (Operator) stack.get(stack.size() - 2); - ctx.parents.add(parent); + SMBJoinOpProcContext ctx = (SMBJoinOpProcContext) procCtx; + ctx.JoinOpToTsOpMap.put((CommonMergeJoinOperator) nd, + (TableScanOperator) stack.get(0)); return null; } } - private static void collectSemijoinOps(Operator ts, NodeProcessorCtx ctx) throws SemanticException { - // create a walker which walks the tree in a DFS manner while maintaining - // the operator stack. The dispatcher - // generates the plan from the operator tree + private static void removeSemijoinOptimizationFromSMBJoins( + OptimizeTezProcContext procCtx) throws SemanticException { + if (!procCtx.conf.getBoolVar(ConfVars.TEZ_DYNAMIC_SEMIJOIN_REDUCTION) || + procCtx.parseContext.getRsOpToTsOpMap().size() == 0) { + return; + } + Map opRules = new LinkedHashMap(); - opRules.put(new RuleRegExp("R1", SelectOperator.getOperatorName() + "%" + - TezDummyStoreOperator.getOperatorName() + "%"), - new SemijoinRemovalProc()); - opRules.put(new RuleRegExp("R2", SelectOperator.getOperatorName() + "%" + + opRules.put( + new RuleRegExp("R1", TableScanOperator.getOperatorName() + "%" + + ".*" + TezDummyStoreOperator.getOperatorName() + "%" + CommonMergeJoinOperator.getOperatorName() + "%"), - new SemijoinRemovalProc()); + new SMBJoinOpProc()); + + SMBJoinOpProcContext ctx = new SMBJoinOpProcContext(); + // The dispatcher finds SMB and if there is semijoin optimization before it, removes it. Dispatcher disp = new DefaultRuleDispatcher(null, opRules, ctx); + List topNodes = new ArrayList(); + topNodes.addAll(procCtx.parseContext.getTopOps().values()); GraphWalker ogw = new PreOrderOnceWalker(disp); - List startNodes = new ArrayList(); - startNodes.add(ts); - - HashMap outputMap = new HashMap(); - ogw.startWalking(startNodes, null); - } - - private static class SMBJoinOpProc implements NodeProcessor { + ogw.startWalking(topNodes, null); - @Override - public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, - Object... nodeOutputs) throws SemanticException { + // Iterate over the map and remove semijoin optimizations if needed. + for (CommonMergeJoinOperator joinOp : ctx.JoinOpToTsOpMap.keySet()) { List tsOps = new ArrayList(); // Get one top level TS Op directly from the stack - tsOps.add((TableScanOperator)stack.get(0)); + tsOps.add(ctx.JoinOpToTsOpMap.get(joinOp)); // Get the other one by examining Join Op - List> parents = ((CommonMergeJoinOperator) nd).getParentOperators(); + List> parents = joinOp.getParentOperators(); for (Operator parent : parents) { if (parent instanceof TezDummyStoreOperator) { // already accounted for @@ -636,7 +635,7 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, // Now the relevant TableScanOperators are known, find if there exists // a semijoin filter on any of them, if so, remove it. - ParseContext pctx = ((OptimizeTezProcContext) procCtx).parseContext; + ParseContext pctx = procCtx.parseContext; for (TableScanOperator ts : tsOps) { for (ReduceSinkOperator rs : pctx.getRsOpToTsOpMap().keySet()) { if (ts == pctx.getRsOpToTsOpMap().get(rs)) { @@ -646,11 +645,27 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, } } } + } + } + + private static class SemiJoinCycleRemovalDueTOMapsideJoinContext implements NodeProcessorCtx { + HashMap,Operator> childParentMap = new HashMap,Operator>(); + } + + private static class SemiJoinCycleRemovalDueToMapsideJoins implements NodeProcessor { + + @Override + public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, + Object... nodeOutputs) throws SemanticException { + + SemiJoinCycleRemovalDueTOMapsideJoinContext ctx = + (SemiJoinCycleRemovalDueTOMapsideJoinContext) procCtx; + ctx.childParentMap.put((Operator)stack.get(stack.size() - 2), (Operator) nd); return null; } } - private static void removeSemijoinOptimizationFromSMBJoins( + private static void removeSemiJoinCyclesDueToMapsideJoins( OptimizeTezProcContext procCtx) throws SemanticException { if (!procCtx.conf.getBoolVar(ConfVars.TEZ_DYNAMIC_SEMIJOIN_REDUCTION) || procCtx.parseContext.getRsOpToTsOpMap().size() == 0) { @@ -659,31 +674,37 @@ private static void removeSemijoinOptimizationFromSMBJoins( Map opRules = new LinkedHashMap(); opRules.put( - new RuleRegExp("R1", TableScanOperator.getOperatorName() + "%" + - ".*" + TezDummyStoreOperator.getOperatorName() + "%" + + new RuleRegExp("R1", MapJoinOperator.getOperatorName() + "%" + + MapJoinOperator.getOperatorName() + "%"), + new SemiJoinCycleRemovalDueToMapsideJoins()); + opRules.put( + new RuleRegExp("R2", MapJoinOperator.getOperatorName() + "%" + CommonMergeJoinOperator.getOperatorName() + "%"), - new SMBJoinOpProc()); + new SemiJoinCycleRemovalDueToMapsideJoins()); + opRules.put( + new RuleRegExp("R3", CommonMergeJoinOperator.getOperatorName() + "%" + + MapJoinOperator.getOperatorName() + "%"), + new SemiJoinCycleRemovalDueToMapsideJoins()); + opRules.put( + new RuleRegExp("R4", CommonMergeJoinOperator.getOperatorName() + "%" + + CommonMergeJoinOperator.getOperatorName() + "%"), + new SemiJoinCycleRemovalDueToMapsideJoins()); - // The dispatcher finds SMB and if there is semijoin optimization before it, removes it. - Dispatcher disp = new DefaultRuleDispatcher(null, opRules, procCtx); + SemiJoinCycleRemovalDueTOMapsideJoinContext ctx = + new SemiJoinCycleRemovalDueTOMapsideJoinContext(); + Dispatcher disp = new DefaultRuleDispatcher(null, opRules, ctx); List topNodes = new ArrayList(); topNodes.addAll(procCtx.parseContext.getTopOps().values()); GraphWalker ogw = new PreOrderOnceWalker(disp); ogw.startWalking(topNodes, null); - } - private static class SemiJoinCycleRemovalDueToMapsideJoins implements NodeProcessor { - - @Override - public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, - Object... nodeOutputs) throws SemanticException { - ParseContext pCtx = ((OptimizeTezProcContext) procCtx).parseContext; - Operator childJoin = ((Operator) nd); - Operator parentJoin = ((Operator) stack.get(stack.size() - 2)); + // process the list + ParseContext pCtx = procCtx.parseContext; + for (Operator parentJoin : ctx.childParentMap.keySet()) { + Operator childJoin = ctx.childParentMap.get(parentJoin); if (parentJoin.getChildOperators().size() == 1) { - // Nothing to do here - return null; + continue; } for (Operator child : parentJoin.getChildOperators()) { @@ -723,40 +744,7 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, } } } - return null; - } - } - - private static void removeSemiJoinCyclesDueToMapsideJoins( - OptimizeTezProcContext procCtx) throws SemanticException { - if (!procCtx.conf.getBoolVar(ConfVars.TEZ_DYNAMIC_SEMIJOIN_REDUCTION) || - procCtx.parseContext.getRsOpToTsOpMap().size() == 0) { - return; } - - Map opRules = new LinkedHashMap(); - opRules.put( - new RuleRegExp("R1", MapJoinOperator.getOperatorName() + "%" + - MapJoinOperator.getOperatorName() + "%"), - new SemiJoinCycleRemovalDueToMapsideJoins()); - opRules.put( - new RuleRegExp("R2", MapJoinOperator.getOperatorName() + "%" + - CommonMergeJoinOperator.getOperatorName() + "%"), - new SemiJoinCycleRemovalDueToMapsideJoins()); - opRules.put( - new RuleRegExp("R3", CommonMergeJoinOperator.getOperatorName() + "%" + - MapJoinOperator.getOperatorName() + "%"), - new SemiJoinCycleRemovalDueToMapsideJoins()); - opRules.put( - new RuleRegExp("R4", CommonMergeJoinOperator.getOperatorName() + "%" + - CommonMergeJoinOperator.getOperatorName() + "%"), - new SemiJoinCycleRemovalDueToMapsideJoins()); - - Dispatcher disp = new DefaultRuleDispatcher(null, opRules, procCtx); - List topNodes = new ArrayList(); - topNodes.addAll(procCtx.parseContext.getTopOps().values()); - GraphWalker ogw = new PreOrderOnceWalker(disp); - ogw.startWalking(topNodes, null); } private static class SemiJoinRemovalIfNoStatsProc implements NodeProcessor {