diff --git a/ql/src/java/org/apache/hadoop/hive/ql/QueryProperties.java b/ql/src/java/org/apache/hadoop/hive/ql/QueryProperties.java index 3bc9432..8ad555c 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/QueryProperties.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/QueryProperties.java @@ -64,6 +64,7 @@ private boolean multiDestQuery; private boolean filterWithSubQuery; + private boolean hasUnion; public boolean isQuery() { @@ -140,6 +141,14 @@ public int getOuterJoinCount() { return noOfOuterJoins; } + public void setHasUnion(boolean hasUnion) { + this.hasUnion = hasUnion; + } + + public boolean hasUnion() { + return hasUnion; + } + public void setHasLateralViews(boolean hasLateralViews) { this.hasLateralViews = hasLateralViews; } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/OperatorUtils.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/OperatorUtils.java index 4828d70..f8365dd 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/OperatorUtils.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/OperatorUtils.java @@ -32,6 +32,7 @@ import org.apache.hadoop.mapred.OutputCollector; import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Multimap; public class OperatorUtils { @@ -185,6 +186,41 @@ public static void setChildrenCollector(List> c return lastOp; } + public static enum SEARCH_STATUS { + ABSENT, PRESENT, TERMINATED_EARLY, COULDNOT_NAVIGATE + }; + + public static SEARCH_STATUS findOperatorDownStream(Operator start, + String endClass, ImmutableSet classToLookFor) { + SEARCH_STATUS srchStatus = SEARCH_STATUS.ABSENT; + String curOpclassName; + Operator currentOp = null; + + if (start.getChildOperators().size() > 1) { + srchStatus = SEARCH_STATUS.COULDNOT_NAVIGATE; + } else if ((start.getChildOperators().size() == 1)) { + currentOp = start.getChildOperators().get(0); + } + + while (currentOp != null) { + curOpclassName = currentOp.getClass().getName(); + if (endClass.equalsIgnoreCase(curOpclassName)) { + srchStatus = SEARCH_STATUS.PRESENT; + currentOp = null; + } else if (classToLookFor.contains(curOpclassName)) { + srchStatus = SEARCH_STATUS.TERMINATED_EARLY; + currentOp = null; + } else if (currentOp.getChildOperators().size() > 1) { + srchStatus = SEARCH_STATUS.COULDNOT_NAVIGATE; + currentOp = null; + } else { + currentOp = currentOp.getChildOperators().size() == 1 ? currentOp.getChildOperators().get(0) + : null; + } + } + return srchStatus; + } + public static void iterateParents(Operator operator, Function> function) { iterateParents(operator, function, new HashSet>()); } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/ColumnPrunerProcCtx.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/ColumnPrunerProcCtx.java index b18a034..3a9b43f 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/ColumnPrunerProcCtx.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/ColumnPrunerProcCtx.java @@ -108,7 +108,7 @@ public ParseContext getParseContext() { prunList = joinPrunedColLists.get(child).get((byte) tag); } else if (child instanceof UnionOperator) { List positions = unionPrunedColLists.get(child); - if (positions != null && positions.size() > 0) { + if (positions != null) { prunList = new ArrayList<>(); RowSchema oldRS = curOp.getSchema(); for (Integer pos : positions) { @@ -278,7 +278,7 @@ public void handleFilterUnionChildren(Operator curOp) for (Operator child : curOp.getChildOperators()) { if (child instanceof UnionOperator) { prunList = genColLists(curOp, child); - if (prunList == null || prunList.size() == 0 || parentPrunList.size() == prunList.size()) { + if (prunList == null || parentPrunList.size() == prunList.size()) { continue; } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/ColumnPrunerProcFactory.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/ColumnPrunerProcFactory.java index 78bce23..dc01b69 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/ColumnPrunerProcFactory.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/ColumnPrunerProcFactory.java @@ -993,13 +993,16 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx ctx, Object.. if (inputSchema != null) { List positions = new ArrayList<>(); RowSchema oldRS = op.getSchema(); + ArrayList newSignature = new ArrayList(); for (int index = 0; index < oldRS.getSignature().size(); index++) { ColumnInfo colInfo = oldRS.getSignature().get(index); if (childColLists.contains(colInfo.getInternalName())) { positions.add(index); + newSignature.add(colInfo); } } cppCtx.getUnionPrunedColLists().put(op, positions); + op.getSchema().setSignature(newSignature); } return null; } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/IdentityProjectRemover.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/IdentityProjectRemover.java index 114c683..fa9765c 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/IdentityProjectRemover.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/IdentityProjectRemover.java @@ -18,24 +18,43 @@ package org.apache.hadoop.hive.ql.optimizer; +import java.lang.reflect.Array; import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Stack; import com.google.common.base.Predicates; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterators; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.ql.exec.CommonJoinOperator; +import org.apache.hadoop.hive.ql.exec.FileSinkOperator; import org.apache.hadoop.hive.ql.exec.FilterOperator; +import org.apache.hadoop.hive.ql.exec.GroupByOperator; +import org.apache.hadoop.hive.ql.exec.HashTableSinkOperator; +import org.apache.hadoop.hive.ql.exec.JoinOperator; import org.apache.hadoop.hive.ql.exec.LateralViewForwardOperator; +import org.apache.hadoop.hive.ql.exec.MapJoinOperator; import org.apache.hadoop.hive.ql.exec.Operator; +import org.apache.hadoop.hive.ql.exec.OperatorUtils; +import org.apache.hadoop.hive.ql.exec.OperatorUtils.SEARCH_STATUS; import org.apache.hadoop.hive.ql.exec.PTFOperator; import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator; +import org.apache.hadoop.hive.ql.exec.SMBMapJoinOperator; import org.apache.hadoop.hive.ql.exec.SelectOperator; +import org.apache.hadoop.hive.ql.exec.SparkHashTableSinkOperator; +import org.apache.hadoop.hive.ql.exec.TemporaryHashSinkOperator; +import org.apache.hadoop.hive.ql.exec.TerminalOperator; +import org.apache.hadoop.hive.ql.exec.UDTFOperator; +import org.apache.hadoop.hive.ql.exec.UnionOperator; +import org.apache.hadoop.hive.ql.exec.vector.VectorFileSinkOperator; import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker; import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher; import org.apache.hadoop.hive.ql.lib.GraphWalker; @@ -85,15 +104,35 @@ public ParseContext transform(ParseContext pctx) throws SemanticException { // 1. We apply the transformation Map opRules = new LinkedHashMap(); opRules.put(new RuleRegExp("R1", - "(" + SelectOperator.getOperatorName() + "%)"), new ProjectRemover()); + "(" + SelectOperator.getOperatorName() + "%)"), new ProjectRemover(pctx.getQueryProperties().hasUnion())); GraphWalker ogw = new DefaultGraphWalker(new DefaultRuleDispatcher(null, opRules, null)); ArrayList topNodes = new ArrayList(); topNodes.addAll(pctx.getTopOps().values()); ogw.startWalking(topNodes, null); + pctx.getUCtx(); return pctx; } private static class ProjectRemover implements NodeProcessor { + //TODO: Remove this once HIVE-12355 is fixed + static final ImmutableSet opsThatMutateObjInsp = ImmutableSet.of(MapJoinOperator.class + .getName(), SMBMapJoinOperator.class + .getName(), JoinOperator.class + .getName(), FileSinkOperator.class + .getName(), + HashTableSinkOperator.class + .getName(), + ReduceSinkOperator.class.getName(), + SparkHashTableSinkOperator.class + .getName(), SelectOperator.class + .getName(), UDTFOperator.class + .getName(), GroupByOperator.class.getName()); + + private final boolean unionOpPresent; + + ProjectRemover(boolean unionOpPresent) { + this.unionOpPresent = unionOpPresent; + } @Override public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, @@ -121,6 +160,18 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, if ((curParent instanceof PTFOperator)) { return null; } + + // If Union is present then we need to check if obj insp would be set + // correctly in the op pipeline before reaching union. + // TODO: Remove this once we match OutputObjInspectors to RowSchema of operators(HIVE-12355) + if (unionOpPresent && !opsThatMutateObjInsp.contains(curParent.getClass().getName())) { + SEARCH_STATUS status = OperatorUtils.findOperatorDownStream(sel, + UnionOperator.class.getName(), opsThatMutateObjInsp); + if (status == SEARCH_STATUS.PRESENT || status == SEARCH_STATUS.COULDNOT_NAVIGATE) { + return null; + } + } + if ((curParent instanceof FilterOperator) && curParent.getParentOperators() != null) { ancestorList.addAll(curParent.getParentOperators()); } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java index d2c3a7c..3500ad0 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java @@ -448,6 +448,8 @@ public void doPhase1QBExpr(ASTNode ast, QBExpr qbexpr, String id, String alias) doPhase1QBExpr((ASTNode) ast.getChild(1), qbexpr2, id + SUBQUERY_TAG_2, alias + SUBQUERY_TAG_2); qbexpr.setQBExpr2(qbexpr2); + + queryProperties.setHasUnion(true); } break; } @@ -1344,6 +1346,7 @@ public boolean doPhase1(ASTNode ast, QB qb, Phase1Ctx ctx_1, PlannerContext plan ErrorMsg.UNION_NOTIN_SUBQ.getMsg())); } skipRecursion = false; + queryProperties.setHasUnion(true); break; case HiveParser.TOK_INSERT: diff --git a/ql/src/test/queries/clientpositive/colpruner1.q b/ql/src/test/queries/clientpositive/colpruner1.q new file mode 100644 index 0000000..164834d --- /dev/null +++ b/ql/src/test/queries/clientpositive/colpruner1.q @@ -0,0 +1,6 @@ +set hive.cbo.enable=false; +set hive.optimize.constant.propagation=false; +set hive.optimize.remove.identity.project=false; + +select count(*) from (SELECT c1, c2 FROM (SELECT key as c1, CAST(NULL AS INT) AS c2 FROM src UNION ALL SELECT key as c1, value as c2 FROM src) x)v1 WHERE v1.c2 = 0; +select count(*) from (SELECT c1, c2 FROM (SELECT key as c1, value as c2 FROM src UNION ALL SELECT key as c1, CAST(NULL AS INT) AS c2 FROM src) x)v1 WHERE v1.c2 = 0; diff --git a/ql/src/test/queries/clientpositive/colpruner2.q b/ql/src/test/queries/clientpositive/colpruner2.q new file mode 100644 index 0000000..0273d60 --- /dev/null +++ b/ql/src/test/queries/clientpositive/colpruner2.q @@ -0,0 +1,2 @@ +select count(*) from (SELECT c1, c2 FROM (SELECT key as c1, CAST(NULL AS INT) AS c2 FROM src UNION ALL SELECT key as c1, value as c2 FROM src) x)v1 WHERE v1.c2 = 0; +select count(*) from (SELECT c1, c2 FROM (SELECT key as c1, value as c2 FROM src UNION ALL SELECT key as c1, CAST(NULL AS INT) AS c2 FROM src) x)v1 WHERE v1.c2 = 0; diff --git a/ql/src/test/results/clientpositive/colpruner.q.out b/ql/src/test/results/clientpositive/colpruner.q.out new file mode 100644 index 0000000..bfbd767 --- /dev/null +++ b/ql/src/test/results/clientpositive/colpruner.q.out @@ -0,0 +1,9 @@ +PREHOOK: query: select count(*) from (SELECT c1, c2 FROM (SELECT key as c1, CAST(NULL AS INT) AS c2 FROM src UNION ALL SELECT key as c1, value as c2 FROM src) x)v1 WHERE v1.c2 = 0 +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: select count(*) from (SELECT c1, c2 FROM (SELECT key as c1, CAST(NULL AS INT) AS c2 FROM src UNION ALL SELECT key as c1, value as c2 FROM src) x)v1 WHERE v1.c2 = 0 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +0 diff --git a/ql/src/test/results/clientpositive/colpruner1.q.out b/ql/src/test/results/clientpositive/colpruner1.q.out new file mode 100644 index 0000000..49074e4 --- /dev/null +++ b/ql/src/test/results/clientpositive/colpruner1.q.out @@ -0,0 +1,18 @@ +PREHOOK: query: select count(*) from (SELECT c1, c2 FROM (SELECT key as c1, CAST(NULL AS INT) AS c2 FROM src UNION ALL SELECT key as c1, value as c2 FROM src) x)v1 WHERE v1.c2 = 0 +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: select count(*) from (SELECT c1, c2 FROM (SELECT key as c1, CAST(NULL AS INT) AS c2 FROM src UNION ALL SELECT key as c1, value as c2 FROM src) x)v1 WHERE v1.c2 = 0 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +0 +PREHOOK: query: select count(*) from (SELECT c1, c2 FROM (SELECT key as c1, value as c2 FROM src UNION ALL SELECT key as c1, CAST(NULL AS INT) AS c2 FROM src) x)v1 WHERE v1.c2 = 0 +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: select count(*) from (SELECT c1, c2 FROM (SELECT key as c1, value as c2 FROM src UNION ALL SELECT key as c1, CAST(NULL AS INT) AS c2 FROM src) x)v1 WHERE v1.c2 = 0 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +0 diff --git a/ql/src/test/results/clientpositive/colpruner2.q.out b/ql/src/test/results/clientpositive/colpruner2.q.out new file mode 100644 index 0000000..49074e4 --- /dev/null +++ b/ql/src/test/results/clientpositive/colpruner2.q.out @@ -0,0 +1,18 @@ +PREHOOK: query: select count(*) from (SELECT c1, c2 FROM (SELECT key as c1, CAST(NULL AS INT) AS c2 FROM src UNION ALL SELECT key as c1, value as c2 FROM src) x)v1 WHERE v1.c2 = 0 +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: select count(*) from (SELECT c1, c2 FROM (SELECT key as c1, CAST(NULL AS INT) AS c2 FROM src UNION ALL SELECT key as c1, value as c2 FROM src) x)v1 WHERE v1.c2 = 0 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +0 +PREHOOK: query: select count(*) from (SELECT c1, c2 FROM (SELECT key as c1, value as c2 FROM src UNION ALL SELECT key as c1, CAST(NULL AS INT) AS c2 FROM src) x)v1 WHERE v1.c2 = 0 +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: select count(*) from (SELECT c1, c2 FROM (SELECT key as c1, value as c2 FROM src UNION ALL SELECT key as c1, CAST(NULL AS INT) AS c2 FROM src) x)v1 WHERE v1.c2 = 0 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +0