diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/SharedWorkOptimizer.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/SharedWorkOptimizer.java index b60512b905..aff5520c7d 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/SharedWorkOptimizer.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/SharedWorkOptimizer.java @@ -909,8 +909,8 @@ private static SharedResult extractSharedOptimizationInfo(ParseContext pctx, } } - discardableInputOps.addAll(gatherDPPBranchOps(pctx, optimizerCache, discardableInputOps)); - discardableInputOps.addAll(gatherDPPBranchOps(pctx, optimizerCache, discardableOps)); + discardableInputOps.addAll(gatherDPPBranchOps(pctx, optimizerCache, + Sets.union(discardableInputOps, discardableOps))); discardableInputOps.addAll(gatherDPPBranchOps(pctx, optimizerCache, retainableOps, discardableInputOps)); return new SharedResult(retainableOps, discardableOps, discardableInputOps, @@ -947,11 +947,7 @@ private static SharedResult extractSharedOptimizationInfo(ParseContext pctx, .get((TableScanOperator) op); for (Operator dppSource : c) { // Remove the branches - Operator currentOp = dppSource; - while (currentOp.getNumChild() <= 1) { - dppBranches.add(currentOp); - currentOp = currentOp.getParentOperators().get(0); - } + removeBranch(dppSource, dppBranches, ops); } } } @@ -971,11 +967,7 @@ private static SharedResult extractSharedOptimizationInfo(ParseContext pctx, findAscendantWorkOperators(pctx, optimizerCache, dppSource); if (!Collections.disjoint(ascendants, discardedOps)) { // Remove branch - Operator currentOp = dppSource; - while (currentOp.getNumChild() <= 1) { - dppBranches.add(currentOp); - currentOp = currentOp.getParentOperators().get(0); - } + removeBranch(dppSource, dppBranches, ops); } } } @@ -983,6 +975,23 @@ private static SharedResult extractSharedOptimizationInfo(ParseContext pctx, return dppBranches; } + private static void removeBranch(Operator currentOp, Set> branchesOps, + Set> discardableOps) { + if (currentOp.getNumChild() > 1) { + for (Operator childOp : currentOp.getChildOperators()) { + if (!branchesOps.contains(childOp) && !discardableOps.contains(childOp)) { + return; + } + } + } + branchesOps.add(currentOp); + if (currentOp.getParentOperators() != null) { + for (Operator parentOp : currentOp.getParentOperators()) { + removeBranch(parentOp, branchesOps, discardableOps); + } + } + } + private static List> compareAndGatherOps(ParseContext pctx, Operator op1, Operator op2) throws SemanticException { List> result = new ArrayList<>();