diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java index 7614463525262f01375c1336e89a18670862bb7d..e229de9fdf850b08b8ac8204a880e78be33153b9 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java @@ -57,6 +57,8 @@ import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject; import org.apache.hadoop.hive.ql.optimizer.calcite.translator.ExprNodeConverter; import org.apache.hadoop.hive.ql.parse.ASTNode; +import org.apache.hadoop.hive.ql.parse.HiveParser; +import org.apache.hadoop.hive.ql.parse.ParseUtils; import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; import com.google.common.base.Function; @@ -93,15 +95,7 @@ } public static boolean validateASTForUnsupportedTokens(ASTNode ast) { - String astTree = ast.toStringTree(); - // if any of following tokens are present in AST, bail out - String[] tokens = { "TOK_CHARSETLITERAL", "TOK_TABLESPLITSAMPLE" }; - for (String token : tokens) { - if (astTree.contains(token)) { - return false; - } - } - return true; + return ParseUtils.containsTokenOfType(ast, HiveParser.TOK_CHARSETLITERAL, HiveParser.TOK_TABLESPLITSAMPLE); } public static List getProjsFromBelowAsInputRef(final RelNode rel) { diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/ParseUtils.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/ParseUtils.java index 373429cbf666f1b19828c532aea3c07f08f95e1a..c1029c52a62419f6e90a25148153260d8d710c5f 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/parse/ParseUtils.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/ParseUtils.java @@ -18,15 +18,13 @@ package org.apache.hadoop.hive.ql.parse; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import org.apache.hadoop.hive.common.JavaUtils; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.metastore.api.FieldSchema; import org.apache.hadoop.hive.ql.ErrorMsg; +import org.apache.hadoop.hive.ql.exec.PTFUtils; import org.apache.hadoop.hive.ql.exec.Utilities; +import org.apache.hadoop.hive.ql.lib.Node; import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; @@ -35,6 +33,16 @@ import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Set; +import java.util.HashSet; +import java.util.Queue; +import java.util.ArrayDeque; + + /** * Library of utility functions used in the parse code. @@ -216,6 +224,48 @@ public static DecimalTypeInfo getDecimalTypeTypeInfo(ASTNode node) return TypeInfoFactory.getDecimalTypeInfo(precision, scale); } + public static boolean containsTokenOfType(ASTNode root, Integer ... tokens) { + final Set tokensToMatch = new HashSet(); + for (Integer tokenTypeToMatch : tokens) { + tokensToMatch.add(tokenTypeToMatch); + } + + return ParseUtils.containsTokenOfType(root, new PTFUtils.Predicate() { + @Override + public boolean apply(ASTNode node) { + return tokensToMatch.contains(node.getType()); + } + }); + } + + public static boolean containsTokenOfType(ASTNode root, PTFUtils.Predicate predicate) { + Set seenNodes = new HashSet(); + Queue queue = new ArrayDeque(); + + // BFS + queue.add(root); + while (!queue.isEmpty()) { + ASTNode current = queue.remove(); + // If the predicate matches, then return true. + // Otherwise visit the next set of nodes that haven't been seen. + if (predicate.apply(current)) { + return true; + } else { + // Guard because ASTNode.getChildren.iterator returns null if no children available (bug). + if (current.getChildCount() > 0) { + for (Node child : current.getChildren()) { + // Set.add returns true if the node has not already been added, false otherwise. + if (seenNodes.add(child)) { + queue.add((ASTNode)child); + } + } + } + } + } + + return false; + } + public static String ensureClassExists(String className) throws SemanticException { if (className == null) {