diff --git ql/src/java/org/apache/hadoop/hive/ql/lib/RuleExactMatch.java ql/src/java/org/apache/hadoop/hive/ql/lib/RuleExactMatch.java index 5e5c054..6f7962e 100644 --- ql/src/java/org/apache/hadoop/hive/ql/lib/RuleExactMatch.java +++ ql/src/java/org/apache/hadoop/hive/ql/lib/RuleExactMatch.java @@ -30,7 +30,7 @@ public class RuleExactMatch implements Rule { private final String ruleName; - private final String pattern; + private final String[] pattern; /** * The rule specified as operator names separated by % symbols, the left side represents the @@ -45,7 +45,7 @@ * @param regExp * string specification of the rule **/ - public RuleExactMatch(String ruleName, String pattern) { + public RuleExactMatch(String ruleName, String[] pattern) { this.ruleName = ruleName; this.pattern = pattern; } @@ -62,23 +62,24 @@ public RuleExactMatch(String ruleName, String pattern) { * @return cost of the function * @throws SemanticException */ + @Override public int cost(Stack stack) throws SemanticException { int numElems = (stack != null ? stack.size() : 0); - String name = new String(); - for (int pos = numElems - 1; pos >= 0; pos--) { - name = stack.get(pos).getName() + "%" + name; + if (numElems != pattern.length) { + return -1; } - - if (pattern.equals(name)) { - return 1; + for (int pos = numElems - 1; pos >= 0; pos--) { + if(!stack.get(pos).getName().equals(pattern[pos])) { + return -1; + } } - - return -1; + return numElems; } /** * @return the name of the Node **/ + @Override public String getName() { return ruleName; } diff --git ql/src/java/org/apache/hadoop/hive/ql/lib/RuleRegExp.java ql/src/java/org/apache/hadoop/hive/ql/lib/RuleRegExp.java index ddc96c2..2470ac2 100644 --- ql/src/java/org/apache/hadoop/hive/ql/lib/RuleRegExp.java +++ ql/src/java/org/apache/hadoop/hive/ql/lib/RuleRegExp.java @@ -18,6 +18,9 @@ package org.apache.hadoop.hive.ql.lib; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; import java.util.Stack; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -31,9 +34,54 @@ public class RuleRegExp implements Rule { private final String ruleName; - private final Pattern pattern; + private final Pattern patternWithWildCardChar; + private final String patternWithoutWildCardChar; + private String[] patternORWildChar; + private static final Set wildCards = new HashSet(Arrays.asList( + '[', '^', '$', '*', ']', '+', '|', '(', '\\', '.', '?', ')', '&')); /** + * The function iterates through the list of wild card characters and sees if + * this regular expression contains a wild card character. + * + * @param pattern + * pattern expressed as a regular Expression + */ + private static boolean patternHasWildCardChar(String pattern) { + if (pattern == null) { + return false; + } + for (char pc : pattern.toCharArray()) { + if (wildCards.contains(pc)) { + return true; + } + } + return false; + } + + /** + * The function iterates through the list of wild card characters and sees if + * this regular expression contains a wild card character. + * + * @param pattern + * pattern expressed as a regular Expression + */ + private static boolean patternHasOnlyWildCardChar(String pattern, char wcc) { + if (pattern == null) { + return false; + } + boolean ret = true; + boolean hasWildCard = false; + for (char pc : pattern.toCharArray()) { + if (wildCards.contains(pc)) { + hasWildCard = true; + ret = ret && (pc == wcc); + } + } + return ret && hasWildCard; + } + + /** * The rule specified by the regular expression. Note that, the regular * expression is specified in terms of Node name. For eg: TS.*RS -> means * TableScan Node followed by anything any number of times followed by @@ -46,33 +94,156 @@ **/ public RuleRegExp(String ruleName, String regExp) { this.ruleName = ruleName; - pattern = Pattern.compile(regExp); + + if (patternHasWildCardChar(regExp)) { + if (patternHasOnlyWildCardChar(regExp, '|')) { + this.patternWithWildCardChar = null; + this.patternWithoutWildCardChar = null; + this.patternORWildChar = regExp.split("\\|"); + } else { + this.patternWithWildCardChar = Pattern.compile(regExp); + this.patternWithoutWildCardChar = null; + this.patternORWildChar = null; + } + } else { + this.patternWithWildCardChar = null; + this.patternWithoutWildCardChar = regExp; + this.patternORWildChar = null; + } } /** - * This function returns the cost of the rule for the specified stack. Lower - * the cost, the better the rule is matched - * + * This function returns the cost of the rule for the specified stack when the pattern + * matched for has no wildcard character in it. The function expects patternWithoutWildCardChar + * to be not null. * @param stack * Node stack encountered so far * @return cost of the function * @throws SemanticException */ - @Override - public int cost(Stack stack) throws SemanticException { + private int costPatternWithoutWildCardChar(Stack stack) throws SemanticException { + int numElems = (stack != null ? stack.size() : 0); + String name = new String(""); + int patLen = patternWithoutWildCardChar.length(); + + for (int pos = numElems - 1; pos >= 0; pos--) { + name = stack.get(pos).getName() + "%" + name; + if (name.length() >= patLen) { + if (patternWithoutWildCardChar.equals(name)) { + return patLen; + } else { + return -1; + } + } + } + return -1; + } + + /** + * This function returns the cost of the rule for the specified stack when the pattern + * matched for has only OR wildcard character in it. The function expects patternORWildChar + * to be not null. + * @param stack + * Node stack encountered so far + * @return cost of the function + * @throws SemanticException + */ + private int costPatternWithORWildCardChar(Stack stack) throws SemanticException { int numElems = (stack != null ? stack.size() : 0); + for (String pattern : patternORWildChar) { + String name = new String(""); + int patLen = pattern.length(); + + for (int pos = numElems - 1; pos >= 0; pos--) { + name = stack.get(pos).getName() + "%" + name; + if (name.length() >= patLen) { + if (pattern.equals(name)) { + return patLen; + } else { + break; + } + } + } + } + return -1; + } + + /** + * This function returns the cost of the rule for the specified stack when the pattern + * matched for has wildcard character in it. The function expects patternWithWildCardChar + * to be not null. + * + * @param stack + * Node stack encountered so far + * @return cost of the function + * @throws SemanticException + */ + private int costPatternWithWildCardChar(Stack stack) throws SemanticException { + int numElems = (stack != null ? stack.size() : 0); String name = ""; + Matcher m = patternWithWildCardChar.matcher(""); for (int pos = numElems - 1; pos >= 0; pos--) { name = stack.get(pos).getName() + "%" + name; - Matcher m = pattern.matcher(name); + m.reset(name); if (m.matches()) { - return m.group().length(); + return name.length(); } } return -1; } /** + * Returns true if the rule pattern is valid and has wild character in it. + */ + boolean rulePatternIsValidWithWildCardChar() { + return patternWithoutWildCardChar == null && patternWithWildCardChar != null && this.patternORWildChar == null; + } + + /** + * Returns true if the rule pattern is valid and has wild character in it. + */ + boolean rulePatternIsValidWithoutWildCardChar() { + return patternWithWildCardChar == null && patternWithoutWildCardChar != null && this.patternORWildChar == null; + } + + /** + * Returns true if the rule pattern is valid and has wild character in it. + */ + boolean rulePatternIsValidWithORWildCardChar() { + return patternWithoutWildCardChar == null && patternWithWildCardChar == null && this.patternORWildChar != null; + } + + /** + * This function returns the cost of the rule for the specified stack. Lower + * the cost, the better the rule is matched + * + * @param stack + * Node stack encountered so far + * @return cost of the function + * @throws SemanticException + */ + @Override + public int cost(Stack stack) throws SemanticException { + if (rulePatternIsValidWithoutWildCardChar()) { + return costPatternWithoutWildCardChar(stack); + } + if (rulePatternIsValidWithWildCardChar()) { + return costPatternWithWildCardChar(stack); + } + if (rulePatternIsValidWithORWildCardChar()) { + return costPatternWithORWildCardChar(stack); + } + // If we reached here, either : + // 1. patternWithWildCardChar and patternWithoutWildCardChar are both nulls. + // 2. patternWithWildCardChar and patternWithoutWildCardChar are both not nulls. + // This is an internal error and we should not let this happen, so throw an exception. + throw new SemanticException ( + "Rule pattern is invalid for " + getName() + " : patternWithWildCardChar = " + + patternWithWildCardChar + " patternWithoutWildCardChar = " + + patternWithoutWildCardChar); + } + + /** * @return the name of the Node **/ @Override diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/PrunerUtils.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/PrunerUtils.java index 108177e..5d375f6 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/PrunerUtils.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/PrunerUtils.java @@ -35,7 +35,9 @@ import org.apache.hadoop.hive.ql.lib.NodeProcessor; import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx; import org.apache.hadoop.hive.ql.lib.Rule; +import org.apache.hadoop.hive.ql.lib.RuleExactMatch; import org.apache.hadoop.hive.ql.lib.RuleRegExp; +import org.apache.hadoop.hive.ql.lib.TypeRule; import org.apache.hadoop.hive.ql.parse.ParseContext; import org.apache.hadoop.hive.ql.parse.SemanticException; import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc; @@ -76,9 +78,8 @@ public static void walkOperatorTree(ParseContext pctx, NodeProcessorCtx opWalker String tsOprName = TableScanOperator.getOperatorName(); String filtOprName = FilterOperator.getOperatorName(); - opRules.put(new RuleRegExp("R1", new StringBuilder().append("(").append(tsOprName).append("%") - .append(filtOprName).append("%)|(").append(tsOprName).append("%").append(filtOprName) - .append("%").append(filtOprName).append("%)").toString()), filterProc); + opRules.put(new RuleExactMatch("R1", new String[] {tsOprName, filtOprName, filtOprName}), filterProc); + opRules.put(new RuleExactMatch("R2", new String[] {tsOprName, filtOprName}), filterProc); // The dispatcher fires the processor corresponding to the closest matching // rule and passes the context along @@ -111,10 +112,9 @@ public static void walkOperatorTree(ParseContext pctx, NodeProcessorCtx opWalker // the operator stack. The dispatcher // generates the plan from the operator tree Map exprRules = new LinkedHashMap(); - exprRules.put(new RuleRegExp("R1", ExprNodeColumnDesc.class.getName() + "%"), colProc); - exprRules.put(new RuleRegExp("R2", ExprNodeFieldDesc.class.getName() + "%"), fieldProc); - exprRules.put(new RuleRegExp("R5", ExprNodeGenericFuncDesc.class.getName() + "%"), - genFuncProc); + exprRules.put(new TypeRule(ExprNodeColumnDesc.class) , colProc); + exprRules.put(new TypeRule(ExprNodeFieldDesc.class), fieldProc); + exprRules.put(new TypeRule(ExprNodeGenericFuncDesc.class), genFuncProc); // The dispatcher fires the processor corresponding to the closest matching // rule and passes the context along diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/BucketingSortingInferenceOptimizer.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/BucketingSortingInferenceOptimizer.java index f370d4d..a6b8d54 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/BucketingSortingInferenceOptimizer.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/BucketingSortingInferenceOptimizer.java @@ -104,10 +104,10 @@ private void inferBucketingSorting(List mapRedTasks) throws Semantic BucketingSortingOpProcFactory.getSelProc()); // Matches only GroupByOperators which are reducers, rather than map group by operators, // or multi group by optimization specific operators - opRules.put(new RuleExactMatch("R2", GroupByOperator.getOperatorName() + "%"), + opRules.put(new RuleExactMatch("R2", new String[]{GroupByOperator.getOperatorName()}), BucketingSortingOpProcFactory.getGroupByProc()); // Matches only JoinOperators which are reducers, rather than map joins, SMB map joins, etc. - opRules.put(new RuleExactMatch("R3", JoinOperator.getOperatorName() + "%"), + opRules.put(new RuleExactMatch("R3", new String[]{JoinOperator.getOperatorName()}), BucketingSortingOpProcFactory.getJoinProc()); opRules.put(new RuleRegExp("R5", FileSinkOperator.getOperatorName() + "%"), BucketingSortingOpProcFactory.getFileSinkProc()); @@ -126,8 +126,8 @@ private void inferBucketingSorting(List mapRedTasks) throws Semantic BucketingSortingOpProcFactory.getForwardProc()); // Matches only ForwardOperators which are reducers and are followed by GroupByOperators // (specific to the multi group by optimization) - opRules.put(new RuleExactMatch("R12", ForwardOperator.getOperatorName() + "%" + - GroupByOperator.getOperatorName() + "%"), + opRules.put(new RuleExactMatch("R12",new String[]{ ForwardOperator.getOperatorName(), + GroupByOperator.getOperatorName()}), BucketingSortingOpProcFactory.getMultiGroupByProc()); // The dispatcher fires the processor corresponding to the closest matching rule and passes diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/ASTNode.java ql/src/java/org/apache/hadoop/hive/ql/parse/ASTNode.java index c8dbe97..d2f2027 100644 --- ql/src/java/org/apache/hadoop/hive/ql/parse/ASTNode.java +++ ql/src/java/org/apache/hadoop/hive/ql/parse/ASTNode.java @@ -31,7 +31,7 @@ */ public class ASTNode extends CommonTree implements Node,Serializable { private static final long serialVersionUID = 1L; - + private transient String str; private transient ASTNodeOrigin origin; public ASTNode() { @@ -81,6 +81,7 @@ public Tree dupNode() { * * @see org.apache.hadoop.hive.ql.lib.Node#getName() */ + @Override public String getName() { return (Integer.valueOf(super.getToken().getType())).toString(); } @@ -126,4 +127,34 @@ private StringBuilder dump(StringBuilder sb, String ws) { } return sb; } + + @Override + public String toStringTree() { + + if (null != str) { + return str; + } + if (getChildCount() == 0) { + str = this.toString(); + return str; + } + StringBuilder buf = new StringBuilder(); + if ( !isNil() ) { + buf.append("("); + buf.append(this.toString()); + buf.append(' '); + } + for (int i = 0; i < children.size(); i++) { + Tree t = (Tree)children.get(i); + if ( i>0 ) { + buf.append(' '); + } + buf.append(t.toStringTree()); + } + if ( !isNil() ) { + buf.append(")"); + } + str = buf.toString(); + return str; + } } diff --git ql/src/java/org/apache/hadoop/hive/ql/ppd/ExprWalkerProcFactory.java ql/src/java/org/apache/hadoop/hive/ql/ppd/ExprWalkerProcFactory.java index 3a07b17..6a1bef9 100644 --- ql/src/java/org/apache/hadoop/hive/ql/ppd/ExprWalkerProcFactory.java +++ ql/src/java/org/apache/hadoop/hive/ql/ppd/ExprWalkerProcFactory.java @@ -38,7 +38,9 @@ import org.apache.hadoop.hive.ql.lib.NodeProcessor; import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx; import org.apache.hadoop.hive.ql.lib.Rule; +import org.apache.hadoop.hive.ql.lib.RuleExactMatch; import org.apache.hadoop.hive.ql.lib.RuleRegExp; +import org.apache.hadoop.hive.ql.lib.TypeRule; import org.apache.hadoop.hive.ql.parse.SemanticException; import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc; import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; @@ -267,14 +269,9 @@ public static ExprWalkerInfo extractPushdownPreds(OpWalkerInfo opContext, // the operator stack. The dispatcher // generates the plan from the operator tree Map exprRules = new LinkedHashMap(); - exprRules.put( - new RuleRegExp("R1", ExprNodeColumnDesc.class.getName() + "%"), - getColumnProcessor()); - exprRules.put( - new RuleRegExp("R2", ExprNodeFieldDesc.class.getName() + "%"), - getFieldProcessor()); - exprRules.put(new RuleRegExp("R3", ExprNodeGenericFuncDesc.class.getName() - + "%"), getGenericFuncProcessor()); + exprRules.put(new TypeRule(ExprNodeColumnDesc.class), getColumnProcessor()); + exprRules.put(new TypeRule(ExprNodeFieldDesc.class), getFieldProcessor()); + exprRules.put(new TypeRule(ExprNodeGenericFuncDesc.class), getGenericFuncProcessor()); // The dispatcher fires the processor corresponding to the closest matching // rule and passes the context along @@ -319,9 +316,9 @@ private static void extractFinalCandidates(ExprNodeDesc expr, assert ctx.getNewToOldExprMap().containsKey(expr); for (int i = 0; i < expr.getChildren().size(); i++) { ctx.getNewToOldExprMap().put( - (ExprNodeDesc) expr.getChildren().get(i), + expr.getChildren().get(i), ctx.getNewToOldExprMap().get(expr).getChildren().get(i)); - extractFinalCandidates((ExprNodeDesc) expr.getChildren().get(i), + extractFinalCandidates(expr.getChildren().get(i), ctx, conf); } return;