diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/OperatorUtils.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/OperatorUtils.java index 7b2ae40107..456786c240 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/OperatorUtils.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/OperatorUtils.java @@ -29,6 +29,7 @@ import java.util.Stack; import org.apache.hadoop.hive.ql.exec.NodeUtils.Function; +import org.apache.hadoop.hive.ql.parse.SemiJoinBranchInfo; import org.apache.hadoop.hive.ql.parse.spark.SparkPartitionPruningSinkOperator; import org.apache.hadoop.hive.ql.plan.BaseWork; import org.apache.hadoop.hive.ql.plan.MapJoinDesc; @@ -441,4 +442,63 @@ public static boolean isInBranch(SparkPartitionPruningSinkOperator op) { } return null; } + + public static Set> + findWorkOperatorsAndSemiJoinEdges(Operator start, + final Map rsToSemiJoinBranchInfo, + Set semiJoinOps, Set> terminalOps) { + Set> found = new HashSet<>(); + findWorkOperatorsAndSemiJoinEdges(start, + found, rsToSemiJoinBranchInfo, semiJoinOps, terminalOps); + return found; + } + + private static void + findWorkOperatorsAndSemiJoinEdges(Operator start, Set> found, + final Map rsToSemiJoinBranchInfo, + Set semiJoinOps, Set> terminalOps) { + found.add(start); + + if (start.getParentOperators() != null) { + for (Operator parent : start.getParentOperators()) { + if (parent instanceof ReduceSinkOperator) { + continue; + } + if (!found.contains(parent)) { + findWorkOperatorsAndSemiJoinEdges(parent, found, rsToSemiJoinBranchInfo, semiJoinOps, terminalOps); + } + } + } + if (start instanceof TerminalOperator) { + // This could be RS1 in semijoin edge which looks like, + // SEL->GBY1->RS1->GBY2->RS2 + boolean semiJoin = false; + if (start.getChildOperators().size() == 1) { + Operator gb2 = start.getChildOperators().get(0); + if (gb2 instanceof GroupByOperator && gb2.getChildOperators().size() == 1) { + Operator rs2 = gb2.getChildOperators().get(0); + if (rs2 instanceof ReduceSinkOperator && 
(rsToSemiJoinBranchInfo.get(rs2) != null)) { + // Semijoin edge found. Add all the operators to the set + found.add(start); + found.add(gb2); + found.add(rs2); + semiJoinOps.add((ReduceSinkOperator)rs2); + semiJoin = true; + } + } + } + if (!semiJoin) { + terminalOps.add((TerminalOperator)start); + } + return; + } + if (start.getChildOperators() != null) { + for (Operator child : start.getChildOperators()) { + if (!found.contains(child)) { + findWorkOperatorsAndSemiJoinEdges(child, found, rsToSemiJoinBranchInfo, semiJoinOps, terminalOps); + } + } + } + return; + } } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/ParseContext.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/ParseContext.java index 89121e3c8d..de52191ae6 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/parse/ParseContext.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/ParseContext.java @@ -18,6 +18,7 @@ package org.apache.hadoop.hive.ql.parse; +import com.google.common.collect.Multimap; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.ql.Context; import org.apache.hadoop.hive.ql.QueryProperties; @@ -34,6 +35,7 @@ import org.apache.hadoop.hive.ql.exec.SMBMapJoinOperator; import org.apache.hadoop.hive.ql.exec.SelectOperator; import org.apache.hadoop.hive.ql.exec.TableScanOperator; +import org.apache.hadoop.hive.ql.exec.TerminalOperator; import org.apache.hadoop.hive.ql.exec.Task; import org.apache.hadoop.hive.ql.hooks.LineageInfo; import org.apache.hadoop.hive.ql.hooks.ReadEntity; @@ -136,6 +138,7 @@ private Map> semiJoinHints; private boolean disableMapJoin; + private Multimap, ReduceSinkOperator> terminalOpToRSMap; public ParseContext() { } @@ -713,4 +716,12 @@ public void setDisableMapJoin(boolean disableMapJoin) { public boolean getDisableMapJoin() { return disableMapJoin; } + + public void setTerminalOpToRSMap(Multimap, ReduceSinkOperator> terminalOpToRSMap) { + this.terminalOpToRSMap = terminalOpToRSMap; + } + + public Multimap, ReduceSinkOperator> 
getTerminalOpToRSMap() { + return terminalOpToRSMap; + } } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java index cc0bd07f6d..0333ddf1d6 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java @@ -31,6 +31,8 @@ import java.util.Stack; import java.util.concurrent.atomic.AtomicInteger; +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.Multimap; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.conf.HiveConf.ConfVars; import org.apache.hadoop.hive.ql.Context; @@ -50,6 +52,7 @@ import org.apache.hadoop.hive.ql.exec.SelectOperator; import org.apache.hadoop.hive.ql.exec.TableScanOperator; import org.apache.hadoop.hive.ql.exec.Task; +import org.apache.hadoop.hive.ql.exec.TerminalOperator; import org.apache.hadoop.hive.ql.exec.TezDummyStoreOperator; import org.apache.hadoop.hive.ql.exec.UnionOperator; import org.apache.hadoop.hive.ql.exec.tez.TezTask; @@ -202,6 +205,8 @@ protected void optimizeOperatorPlan(ParseContext pCtx, Set inputs, private void runCycleAnalysisForPartitionPruning(OptimizeTezProcContext procCtx, Set inputs, Set outputs) throws SemanticException { + // Semijoins may have created task level cycles, examine those + connectTerminalOps(procCtx.parseContext); boolean cycleFree = false; while (!cycleFree) { cycleFree = true; @@ -317,7 +322,6 @@ private void removeCycleOperator(Set> component, OptimizeTezProcCont + ((DynamicPruningEventDesc) victim.getConf()).getTableScan().toString() + ". 
Needed to break cyclic dependency"); } - return; } // Tarjan's algo @@ -351,20 +355,25 @@ private void connect(Operator o, AtomicInteger index, Stack> node List> children; if (o instanceof AppMasterEventOperator) { - children = new ArrayList>(); - children.addAll(o.getChildOperators()); + children = new ArrayList<>((o.getChildOperators())); TableScanOperator ts = ((DynamicPruningEventDesc) o.getConf()).getTableScan(); LOG.debug("Adding special edge: " + o.getName() + " --> " + ts.toString()); children.add(ts); - } else if (o instanceof ReduceSinkOperator){ - // semijoin case - children = new ArrayList>(); - children.addAll(o.getChildOperators()); - SemiJoinBranchInfo sjInfo = parseContext.getRsToSemiJoinBranchInfo().get(o); - if (sjInfo != null ) { - TableScanOperator ts = sjInfo.getTsOp(); - LOG.debug("Adding special edge: " + o.getName() + " --> " + ts.toString()); - children.add(ts); + } else if (o instanceof TerminalOperator) { + children = new ArrayList<>((o.getChildOperators())); + for (ReduceSinkOperator rs : parseContext.getTerminalOpToRSMap().get((TerminalOperator)o)) { + // add an edge + LOG.debug("Adding special edge: From terminal op to semijoin edge " + o.getName() + " --> " + rs.toString()); + children.add(rs); + } + if (o instanceof ReduceSinkOperator) { + // semijoin case + SemiJoinBranchInfo sjInfo = parseContext.getRsToSemiJoinBranchInfo().get(o); + if (sjInfo != null) { + TableScanOperator ts = sjInfo.getTsOp(); + LOG.debug("Adding special edge: " + o.getName() + " --> " + ts.toString()); + children.add(ts); + } } } else { children = o.getChildOperators(); @@ -428,7 +437,8 @@ private void semijoinRemovalBasedTransformations(OptimizeTezProcContext procCtx, final boolean dynamicPartitionPruningEnabled = procCtx.conf.getBoolVar(ConfVars.TEZ_DYNAMIC_PARTITION_PRUNING); final boolean semiJoinReductionEnabled = dynamicPartitionPruningEnabled && - procCtx.conf.getBoolVar(ConfVars.TEZ_DYNAMIC_SEMIJOIN_REDUCTION); + 
procCtx.conf.getBoolVar(ConfVars.TEZ_DYNAMIC_SEMIJOIN_REDUCTION) && + procCtx.parseContext.getRsToSemiJoinBranchInfo().size() != 0; final boolean extendedReductionEnabled = dynamicPartitionPruningEnabled && procCtx.conf.getBoolVar(ConfVars.TEZ_DYNAMIC_PARTITION_PRUNING_EXTENDED); @@ -438,46 +448,31 @@ private void semijoinRemovalBasedTransformations(OptimizeTezProcContext procCtx, } perfLogger.PerfLogEnd(this.getClass().getName(), PerfLogger.TEZ_COMPILER, "Run remove dynamic pruning by size"); - perfLogger.PerfLogBegin(this.getClass().getName(), PerfLogger.TEZ_COMPILER); if (semiJoinReductionEnabled) { + perfLogger.PerfLogBegin(this.getClass().getName(), PerfLogger.TEZ_COMPILER); markSemiJoinForDPP(procCtx); - } - perfLogger.PerfLogEnd(this.getClass().getName(), PerfLogger.TEZ_COMPILER, "Mark certain semijoin edges important based "); - - // Removing semijoin optimization when it may not be beneficial - perfLogger.PerfLogBegin(this.getClass().getName(), PerfLogger.TEZ_COMPILER); - if (semiJoinReductionEnabled) { - removeSemijoinOptimizationByBenefit(procCtx); - } - perfLogger.PerfLogEnd(this.getClass().getName(), PerfLogger.TEZ_COMPILER, "Remove Semijoins based on cost benefits"); + perfLogger.PerfLogEnd(this.getClass().getName(), PerfLogger.TEZ_COMPILER, "Mark certain semijoin edges important based "); - // Remove any parallel edge between semijoin and mapjoin. - perfLogger.PerfLogBegin(this.getClass().getName(), PerfLogger.TEZ_COMPILER); - if (semiJoinReductionEnabled) { + // Remove any parallel edge between semijoin and mapjoin. 
+ perfLogger.PerfLogBegin(this.getClass().getName(), PerfLogger.TEZ_COMPILER); removeSemijoinsParallelToMapJoin(procCtx); - } - perfLogger.PerfLogEnd(this.getClass().getName(), PerfLogger.TEZ_COMPILER, "Remove any parallel edge between semijoin and mapjoin"); - - // Remove semijoin optimization if it creates a cycle with mapside joins - perfLogger.PerfLogBegin(this.getClass().getName(), PerfLogger.TEZ_COMPILER); - if (semiJoinReductionEnabled && procCtx.parseContext.getRsToSemiJoinBranchInfo().size() != 0) { - removeSemiJoinCyclesDueToMapsideJoins(procCtx); - } - perfLogger.PerfLogEnd(this.getClass().getName(), PerfLogger.TEZ_COMPILER, "Remove semijoin optimizations if it creates a cycle with mapside join"); + perfLogger.PerfLogEnd(this.getClass().getName(), PerfLogger.TEZ_COMPILER, "Remove any parallel edge between semijoin and mapjoin"); - // Remove semijoin optimization if SMB join is created. - perfLogger.PerfLogBegin(this.getClass().getName(), PerfLogger.TEZ_COMPILER); - if (semiJoinReductionEnabled && procCtx.parseContext.getRsToSemiJoinBranchInfo().size() != 0) { + // Remove semijoin optimization if SMB join is created. 
+ perfLogger.PerfLogBegin(this.getClass().getName(), PerfLogger.TEZ_COMPILER); removeSemijoinOptimizationFromSMBJoins(procCtx); - } - perfLogger.PerfLogEnd(this.getClass().getName(), PerfLogger.TEZ_COMPILER, "Remove semijoin optimizations if needed"); + perfLogger.PerfLogEnd(this.getClass().getName(), PerfLogger.TEZ_COMPILER, "Remove semijoin optimizations if needed"); - // Remove bloomfilter if no stats generated - perfLogger.PerfLogBegin(this.getClass().getName(), PerfLogger.TEZ_COMPILER); - if (semiJoinReductionEnabled && procCtx.parseContext.getRsToSemiJoinBranchInfo().size() != 0) { + // Remove bloomfilter if no stats generated + perfLogger.PerfLogBegin(this.getClass().getName(), PerfLogger.TEZ_COMPILER); removeSemiJoinIfNoStats(procCtx); + perfLogger.PerfLogEnd(this.getClass().getName(), PerfLogger.TEZ_COMPILER, "Remove bloom filter optimizations if needed"); + + // Removing semijoin optimization when it may not be beneficial + perfLogger.PerfLogBegin(this.getClass().getName(), PerfLogger.TEZ_COMPILER); + removeSemijoinOptimizationByBenefit(procCtx); + perfLogger.PerfLogEnd(this.getClass().getName(), PerfLogger.TEZ_COMPILER, "Remove Semijoins based on cost benefits"); } - perfLogger.PerfLogEnd(this.getClass().getName(), PerfLogger.TEZ_COMPILER, "Remove bloom filter optimizations if needed"); // after the stats phase we might have some cyclic dependencies that we need // to take care of. @@ -842,107 +837,52 @@ private static void removeSemijoinOptimizationFromSMBJoins( } } - private static class SemiJoinCycleRemovalDueTOMapsideJoinContext implements NodeProcessorCtx { - HashMap,Operator> childParentMap = new HashMap,Operator>(); - } - - private static class SemiJoinCycleRemovalDueToMapsideJoins implements NodeProcessor { + private static class TerminalOpsInfo { + public Set> terminalOps; + public Set rsOps; - @Override - public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, - Object... 
nodeOutputs) throws SemanticException { - - SemiJoinCycleRemovalDueTOMapsideJoinContext ctx = - (SemiJoinCycleRemovalDueTOMapsideJoinContext) procCtx; - ctx.childParentMap.put((Operator)stack.get(stack.size() - 2), (Operator) nd); - return null; + TerminalOpsInfo(Set> terminalOps, Set rsOps) { + this.terminalOps = terminalOps; + this.rsOps = rsOps; } } - private static void removeSemiJoinCyclesDueToMapsideJoins( - OptimizeTezProcContext procCtx) throws SemanticException { - Map opRules = new LinkedHashMap(); - opRules.put( - new RuleRegExp("R1", MapJoinOperator.getOperatorName() + "%" + - MapJoinOperator.getOperatorName() + "%"), - new SemiJoinCycleRemovalDueToMapsideJoins()); - opRules.put( - new RuleRegExp("R2", MapJoinOperator.getOperatorName() + "%" + - CommonMergeJoinOperator.getOperatorName() + "%"), - new SemiJoinCycleRemovalDueToMapsideJoins()); - opRules.put( - new RuleRegExp("R3", CommonMergeJoinOperator.getOperatorName() + "%" + - MapJoinOperator.getOperatorName() + "%"), - new SemiJoinCycleRemovalDueToMapsideJoins()); - opRules.put( - new RuleRegExp("R4", CommonMergeJoinOperator.getOperatorName() + "%" + - CommonMergeJoinOperator.getOperatorName() + "%"), - new SemiJoinCycleRemovalDueToMapsideJoins()); - - SemiJoinCycleRemovalDueTOMapsideJoinContext ctx = - new SemiJoinCycleRemovalDueTOMapsideJoinContext(); - Dispatcher disp = new DefaultRuleDispatcher(null, opRules, ctx); - List topNodes = new ArrayList(); - topNodes.addAll(procCtx.parseContext.getTopOps().values()); - GraphWalker ogw = new PreOrderOnceWalker(disp); - ogw.startWalking(topNodes, null); + private void connectTerminalOps(ParseContext pCtx) { + // The map which contains the virtual edges from non-semijoin terminal ops to semijoin RSs. 
+ Multimap, ReduceSinkOperator> terminalOpToRSMap = ArrayListMultimap.create(); - // process the list - ParseContext pCtx = procCtx.parseContext; - for (Operator parentJoin : ctx.childParentMap.keySet()) { - Operator childJoin = ctx.childParentMap.get(parentJoin); + // Map of semijoin RS to work ops to ensure no work is examined more than once. + Map rsToTerminalOpsInfo = new HashMap<>(); - if (parentJoin.getChildOperators().size() == 1) { - continue; + // Get all the terminal ops + for (ReduceSinkOperator rs : pCtx.getRsToSemiJoinBranchInfo().keySet()) { + TerminalOpsInfo terminalOpsInfo = rsToTerminalOpsInfo.get(rs); + if (terminalOpsInfo != null) { + continue; // done with this one } - for (Operator child : parentJoin.getChildOperators()) { - if (!(child instanceof SelectOperator)) { - continue; - } - - while(child.getChildOperators().size() > 0) { - child = child.getChildOperators().get(0); - } - - if (!(child instanceof ReduceSinkOperator)) { - continue; - } - - ReduceSinkOperator rs = ((ReduceSinkOperator) child); - SemiJoinBranchInfo sjInfo = pCtx.getRsToSemiJoinBranchInfo().get(rs); - if (sjInfo == null) { - continue; - } - - TableScanOperator ts = sjInfo.getTsOp(); - // This is a semijoin branch. Find if this is creating a potential - // cycle with childJoin. - for (Operator parent : childJoin.getParentOperators()) { - if (parent == parentJoin) { - continue; - } - - assert parent instanceof ReduceSinkOperator; - while (parent.getParentOperators().size() > 0) { - parent = parent.getParentOperators().get(0); - } - - if (parent == ts) { - // We have a cycle! - if (sjInfo.getIsHint()) { - throw new SemanticException("Removing hinted semijoin as it is creating cycles with mapside joins " + rs + " : " + ts); - } - if (LOG.isDebugEnabled()) { - LOG.debug("Semijoin cycle due to mapjoin. 
Removing semijoin " - + OperatorUtils.getOpNamePretty(rs) + " - " + OperatorUtils.getOpNamePretty(ts)); - } - GenTezUtils.removeBranch(rs); - GenTezUtils.removeSemiJoinOperator(pCtx, rs, ts); - } + Set workRSOps = new HashSet<>(); + Set> workTerminalOps = new HashSet<>(); + // Get the SEL Op in the semijoin-branch, SEL->GBY1->RS1->GBY2->RS2 + Operator selOp = rs.getParentOperators().get(0) + .getParentOperators().get(0) + .getParentOperators().get(0) + .getParentOperators().get(0); + OperatorUtils.findWorkOperatorsAndSemiJoinEdges(selOp, + pCtx.getRsToSemiJoinBranchInfo(), workRSOps, workTerminalOps); + + TerminalOpsInfo candidate = new TerminalOpsInfo(workTerminalOps, workRSOps); + + // A work may contain multiple semijoin edges, traverse rsOps and add for each + for (ReduceSinkOperator rsFound : workRSOps) { + rsToTerminalOpsInfo.put(rsFound, candidate); + for (TerminalOperator terminalOp : candidate.terminalOps) { + terminalOpToRSMap.put(terminalOp, rsFound); } } } + + pCtx.setTerminalOpToRSMap(terminalOpToRSMap); } private void removeSemiJoinIfNoStats(OptimizeTezProcContext procCtx) @@ -1558,7 +1498,7 @@ private static double computeBloomFilterNetBenefit( private void removeSemijoinOptimizationByBenefit(OptimizeTezProcContext procCtx) throws SemanticException { - List semijoinRsToRemove = new ArrayList(); + List semijoinRsToRemove = new ArrayList<>(); Map map = procCtx.parseContext.getRsToSemiJoinBranchInfo(); double semijoinReductionThreshold = procCtx.conf.getFloatVar( HiveConf.ConfVars.TEZ_DYNAMIC_SEMIJOIN_REDUCTION_THRESHOLD);