diff --git a/ql/src/java/org/apache/hadoop/hive/ql/lib/RuleRegExp.java b/ql/src/java/org/apache/hadoop/hive/ql/lib/RuleRegExp.java index ddc96c2..4d37acc 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/lib/RuleRegExp.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/lib/RuleRegExp.java @@ -31,7 +31,30 @@ public class RuleRegExp implements Rule { private final String ruleName; - private final Pattern pattern; + private final Pattern patternWithWildCardChar; + private final String patternWithoutWildCardChar; + + /** + * The function iterates through the list of wild card characters and sees if + * this regular expression contains a wild card character. + * + * @param pattern + * pattern expressed as a regular Expression + */ + private static boolean patternHasWildCardChar(String pattern) { + if (pattern == null) { + return false; + } + + final char[] wildCards = {'[', '^', '*', ']', '+', '|', '(', '\\', '*', ')'}; + + for (char wc : wildCards) { + if (pattern.indexOf(wc) != -1) { + return true; + } + } + return false; + } /** * The rule specified by the regular expression. Note that, the regular @@ -46,25 +69,59 @@ **/ public RuleRegExp(String ruleName, String regExp) { this.ruleName = ruleName; - pattern = Pattern.compile(regExp); + + if (patternHasWildCardChar(regExp)) { + this.patternWithWildCardChar = Pattern.compile(regExp); + this.patternWithoutWildCardChar = null; + } else { + this.patternWithWildCardChar = null; + this.patternWithoutWildCardChar = regExp; + } } /** - * This function returns the cost of the rule for the specified stack. Lower - * the cost, the better the rule is matched - * + * This function returns the cost of the rule for the specified stack when the pattern + * matched for has no wildcard character in it. The function expects patternWithoutWildCardChar + * to be not null. * @param stack * Node stack encountered so far * @return cost of the function * @throws SemanticException */ - @Override - public int cost(Stack stack) throws SemanticException { + private int costPatternWithoutWildCardChar(Stack stack) throws SemanticException { int numElems = (stack != null ? stack.size() : 0); String name = ""; + int patLen = patternWithoutWildCardChar.length(); + + for (int pos = numElems - 1; pos >= 0; pos--) { + name = stack.get(pos).getName() + "%" + name; + if (name.length() >= patLen) { + if (patternWithoutWildCardChar.equals(name)) { + return patLen; + } else { + return -1; + } + } + } + return -1; + } + + /** + * This function returns the cost of the rule for the specified stack when the pattern + * matched for has wildcard character in it. The function expects patternWithWildCardChar + * to be not null. + * + * @param stack + * Node stack encountered so far + * @return cost of the function + * @throws SemanticException + */ + private int costPatternWithWildCardChar(Stack stack) throws SemanticException { + int numElems = (stack != null ? stack.size() : 0); + String name = ""; for (int pos = numElems - 1; pos >= 0; pos--) { name = stack.get(pos).getName() + "%" + name; - Matcher m = pattern.matcher(name); + Matcher m = patternWithWildCardChar.matcher(name); if (m.matches()) { return m.group().length(); } @@ -73,6 +130,47 @@ public int cost(Stack stack) throws SemanticException { } /** + * Returns true if the rule pattern is valid and has wild character in it. + */ + boolean rulePatternIsValidWithWildCardChar() { + return patternWithoutWildCardChar == null && patternWithWildCardChar != null; + } + + /** + * Returns true if the rule pattern is valid and has wild character in it. + */ + boolean rulePatternIsValidWithoutWildCardChar() { + return patternWithWildCardChar == null && patternWithoutWildCardChar != null; + } + + /** + * This function returns the cost of the rule for the specified stack. Lower + * the cost, the better the rule is matched + * + * @param stack + * Node stack encountered so far + * @return cost of the function + * @throws SemanticException + */ + @Override + public int cost(Stack stack) throws SemanticException { + if (rulePatternIsValidWithoutWildCardChar()) { + return costPatternWithoutWildCardChar(stack); + } + if (rulePatternIsValidWithWildCardChar()) { + return costPatternWithWildCardChar(stack); + } + // If we reached here, either : + // 1. patternWithWildCardChar and patternWithoutWildCardChar are both nulls. + // 2. patternWithWildCardChar and patternWithoutWildCardChar are both not nulls. + // This is an internal error and we should not let this happen, so throw an exception. + throw new SemanticException ( + "Rule pattern is invalid for " + getName() + " : patternWithWildCardChar = " + + patternWithWildCardChar + " patternWithoutWildCardChar = " + + patternWithoutWildCardChar); + } + + /** * @return the name of the Node **/ @Override diff --git a/ql/src/test/org/apache/hadoop/hive/ql/lib/TestRuleRegExp.java b/ql/src/test/org/apache/hadoop/hive/ql/lib/TestRuleRegExp.java new file mode 100644 index 0000000..f06d0df --- /dev/null +++ b/ql/src/test/org/apache/hadoop/hive/ql/lib/TestRuleRegExp.java @@ -0,0 +1,118 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.lib; + +import static org.junit.Assert.*; + +import java.util.List; +import java.util.Stack; + +import org.apache.hadoop.hive.ql.exec.FileSinkOperator; +import org.apache.hadoop.hive.ql.exec.FilterOperator; +import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator; +import org.apache.hadoop.hive.ql.exec.SelectOperator; +import org.apache.hadoop.hive.ql.exec.TableScanOperator; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.junit.Test; + +public class TestRuleRegExp { + + public class TestNode implements Node { + private String name; + + TestNode (String name) { + this.name = name; + } + + @Override + public List getChildren() { + return null; + } + + @Override + public String getName() { + return name; + } + } + + @Test + public void testPatternWithoutWildCardChar() { + String patternStr = + ReduceSinkOperator.getOperatorName() + "%" + + SelectOperator.getOperatorName() + "%" + + FileSinkOperator.getOperatorName() + "%"; + RuleRegExp rule1 = new RuleRegExp("R1", patternStr); + assertEquals(rule1.rulePatternIsValidWithoutWildCardChar(), true); + assertEquals(rule1.rulePatternIsValidWithWildCardChar(), false); + // positive test + Stack ns1 = new Stack(); + ns1.push(new TestNode(ReduceSinkOperator.getOperatorName())); + ns1.push(new TestNode(SelectOperator.getOperatorName())); + ns1.push(new TestNode(FileSinkOperator.getOperatorName())); + try { + assertEquals(rule1.cost(ns1), patternStr.length()); + } catch (SemanticException e) { + fail(e.getMessage()); + } + // negative test + Stack ns2 = new Stack(); + ns2.push(new TestNode(ReduceSinkOperator.getOperatorName())); + ns1.push(new TestNode(TableScanOperator.getOperatorName())); + ns1.push(new TestNode(FileSinkOperator.getOperatorName())); + try { + assertEquals(rule1.cost(ns2), -1); + } catch (SemanticException e) { + fail(e.getMessage()); + } + } + + @Test + public void testPatternWithWildCardChar() { + RuleRegExp rule1 = new RuleRegExp("R1", + "(" + TableScanOperator.getOperatorName() + "%" + + FilterOperator.getOperatorName() + "%)|(" + + TableScanOperator.getOperatorName() + "%" + + FileSinkOperator.getOperatorName() + "%)"); + assertEquals(rule1.rulePatternIsValidWithoutWildCardChar(), false); + assertEquals(rule1.rulePatternIsValidWithWildCardChar(), true); + // positive test + Stack ns1 = new Stack(); + ns1.push(new TestNode(TableScanOperator.getOperatorName())); + ns1.push(new TestNode(FilterOperator.getOperatorName())); + Stack ns2 = new Stack(); + ns2.push(new TestNode(TableScanOperator.getOperatorName())); + ns2.push(new TestNode(FileSinkOperator.getOperatorName())); + try { + assertNotEquals(rule1.cost(ns1), -1); + assertNotEquals(rule1.cost(ns2), -1); + } catch (SemanticException e) { + fail(e.getMessage()); + } + // negative test + Stack ns3 = new Stack(); + ns3.push(new TestNode(ReduceSinkOperator.getOperatorName())); + ns3.push(new TestNode(ReduceSinkOperator.getOperatorName())); + ns3.push(new TestNode(FileSinkOperator.getOperatorName())); + try { + assertEquals(rule1.cost(ns3), -1); + } catch (SemanticException e) { + fail(e.getMessage()); + } + } + +}