diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/AbstractFilterStringColLikeStringScalar.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/AbstractFilterStringColLikeStringScalar.java index 272ff9c..b70beef 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/AbstractFilterStringColLikeStringScalar.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/AbstractFilterStringColLikeStringScalar.java @@ -24,10 +24,13 @@ import java.nio.charset.Charset; import java.nio.charset.CharsetDecoder; import java.nio.charset.CodingErrorAction; +import java.util.ArrayList; import java.util.List; +import java.util.StringTokenizer; import java.util.regex.Matcher; import java.util.regex.Pattern; +import org.apache.commons.lang.ArrayUtils; import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; @@ -218,8 +221,8 @@ public String getOutputType() { /** * Matches the whole string to its pattern. */ - protected static class NoneChecker implements Checker { - byte [] byteSub; + protected static final class NoneChecker implements Checker { + final byte [] byteSub; NoneChecker(String pattern) { try { @@ -246,8 +249,8 @@ public boolean check(byte[] byteS, int start, int len) { /** * Matches the beginning of each string to a pattern. */ - protected static class BeginChecker implements Checker { - byte[] byteSub; + protected static final class BeginChecker implements Checker { + final byte[] byteSub; BeginChecker(String pattern) { try { @@ -258,23 +261,20 @@ public boolean check(byte[] byteS, int start, int len) { } public boolean check(byte[] byteS, int start, int len) { + int lenSub = byteSub.length; if (len < byteSub.length) { return false; } - for (int i = start, j = 0; j < byteSub.length; i++, j++) { - if (byteS[i] != byteSub[j]) { - return false; - } - } - return true; + return StringExpr.equal(byteSub, 0, lenSub, byteS, start, lenSub); } } /** * Matches the ending of each string to its pattern. */ - protected static class EndChecker implements Checker { - byte[] byteSub; + protected static final class EndChecker implements Checker { + final byte[] byteSub; + EndChecker(String pattern) { try { byteSub = pattern.getBytes("UTF-8"); @@ -288,21 +288,16 @@ public boolean check(byte[] byteS, int start, int len) { if (len < lenSub) { return false; } - for (int i = start + len - lenSub, j = 0; j < lenSub; i++, j++) { - if (byteS[i] != byteSub[j]) { - return false; - } - } - return true; + return StringExpr.equal(byteSub, 0, lenSub, byteS, start + len - lenSub, lenSub); } } /** * Matches the middle of each string to its pattern. */ - protected static class MiddleChecker implements Checker { - byte[] byteSub; - int lenSub; + protected static final class MiddleChecker implements Checker { + final byte[] byteSub; + final int lenSub; MiddleChecker(String pattern) { try { @@ -314,25 +309,134 @@ public boolean check(byte[] byteS, int start, int len) { } public boolean check(byte[] byteS, int start, int len) { + return index(byteS, start, len) != -1; + } + + /* + * Returns absolute offset of the match + */ + public int index(byte[] byteS, int start, int len) { if (len < lenSub) { - return false; + return -1; } int end = start + len - lenSub + 1; - boolean match = false; for (int i = start; i < end; i++) { - match = true; - for (int j = 0; j < lenSub; j++) { - if (byteS[i + j] != byteSub[j]) { - match = false; - break; - } + if (StringExpr.equal(byteSub, 0, lenSub, byteS, i, lenSub)) { + return i; } - if (match) { - return true; + } + return -1; + } + } + + /** + * Matches a chained sequence of checkers. + * + * This has 4 chain scenarios cases in it (has no escaping or single char wildcards) + * + * 1) anchored left "abc%def%" + * 2) anchored right "%abc%def" + * 3) unanchored "%abc%def%" + * 4) anchored on both sides "abc%def" + */ + protected static final class ChainedChecker implements Checker { + + final int minLen; + final BeginChecker begin; + final EndChecker end; + final MiddleChecker[] middle; + final int[] midLens; + final int beginLen; + final int endLen; + + ChainedChecker(String pattern) { + final StringTokenizer tokens = new StringTokenizer(pattern, "%"); + final boolean leftAnchor = pattern.startsWith("%") == false; + final boolean rightAnchor = pattern.endsWith("%") == false; + int len = 0; + // at least 2 checkers always + BeginChecker left = null; + EndChecker right = null; + int leftLen = 0; // not -1 + int rightLen = 0; // not -1 + final List checkers = new ArrayList(2); + final List lengths = new ArrayList(2); + + for (int i = 0; tokens.hasMoreTokens(); i++) { + String chunk = tokens.nextToken(); + if (chunk.length() == 0) { + // %% is folded in the .*?.*? regex usually into .*? + continue; + } + len += utf8Length(chunk); + if (leftAnchor && i == 0) { + // first item + left = new BeginChecker(chunk); + leftLen = utf8Length(chunk); + } else if (rightAnchor && tokens.hasMoreTokens() == false) { + // last item + right = new EndChecker(chunk); + rightLen = utf8Length(chunk); + } else { + // middle items in order + checkers.add(new MiddleChecker(chunk)); + lengths.add(utf8Length(chunk)); + } + } + midLens = ArrayUtils.toPrimitive(lengths.toArray(ArrayUtils.EMPTY_INTEGER_OBJECT_ARRAY)); + middle = checkers.toArray(new MiddleChecker[0]); + minLen = len; + begin = left; + end = right; + beginLen = leftLen; + endLen = rightLen; + } + + public boolean check(byte[] byteS, final int start, final int len) { + int pos = start; + int mark = len; + if (len < minLen) { + return false; + } + // prefix, extend start + if (begin != null && false == begin.check(byteS, pos, mark)) { + // no match + return false; + } else { + pos += beginLen; + mark -= beginLen; + } + // suffix, reduce len + if (end != null && false == end.check(byteS, pos, mark)) { + // no match + return false; + } else { + // no pos change - no need since we've shrunk the string with same pos + mark -= endLen; + } + // loop for middles + for (int i = 0; i < middle.length; i++) { + int index = middle[i].index(byteS, pos, mark); + if (index == -1) { + // no match + return false; + } else { + mark -= ((index-pos) + midLens[i]); + pos = index + midLens[i]; } } - return match; + // if all is good + return true; } + + private int utf8Length(String chunk) { + try { + return chunk.getBytes("UTF-8").length; + } catch (UnsupportedEncodingException ue) { + throw new RuntimeException(ue); + } + } + } /** diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/FilterStringColLikeStringScalar.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/FilterStringColLikeStringScalar.java index c03c34e..0b279c7 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/FilterStringColLikeStringScalar.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/FilterStringColLikeStringScalar.java @@ -38,6 +38,7 @@ new EndCheckerFactory(), new MiddleCheckerFactory(), new NoneCheckerFactory(), + new ChainedCheckerFactory(), new ComplexCheckerFactory()); public FilterStringColLikeStringScalar() { @@ -119,6 +120,23 @@ public Checker tryCreate(String pattern) { } /** + * Accepts chained LIKE patterns without escaping like "abc%def%ghi%" and creates corresponding + * checkers. + * + */ + private static class ChainedCheckerFactory implements CheckerFactory { + private static final Pattern CHAIN_PATTERN = Pattern.compile("(%?[^%_\\\\]+%?)+"); + + public Checker tryCreate(String pattern) { + Matcher matcher = CHAIN_PATTERN.matcher(pattern); + if (matcher.matches()) { + return new ChainedChecker(pattern); + } + return null; + } + } + + /** * Accepts any LIKE patterns and creates corresponding checkers. */ private static class ComplexCheckerFactory implements CheckerFactory { diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorStringExpressions.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorStringExpressions.java index a51837e..5c323ba 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorStringExpressions.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorStringExpressions.java @@ -18,8 +18,13 @@ package org.apache.hadoop.hive.ql.exec.vector.expressions; +import static org.junit.Assert.assertEquals; + import java.io.UnsupportedEncodingException; +import java.nio.charset.StandardCharsets; import java.util.Arrays; +import java.util.Random; +import java.util.StringTokenizer; import junit.framework.Assert; @@ -55,15 +60,23 @@ import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.StringGroupColLessStringGroupColumn; import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.StringScalarEqualStringGroupColumn; import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.VarCharScalarEqualStringGroupColumn; +import org.apache.hadoop.hive.ql.exec.vector.util.VectorizedRowGroupGenUtil; import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFLike; +import org.apache.hadoop.io.BooleanWritable; import org.apache.hadoop.io.Text; import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Test vectorized expression and filter evaluation for strings. */ public class TestVectorStringExpressions { + private static final Logger LOG = LoggerFactory + .getLogger(TestVectorStringExpressions.class); + private static byte[] red; private static byte[] redred; private static byte[] red2; // second copy of red, different object @@ -99,7 +112,7 @@ mixedUpLower = "mixedup".getBytes("UTF-8"); mixedUpUpper = "MIXEDUP".getBytes("UTF-8"); mixPercentPattern = "mix%".getBytes("UTF-8"); // for use as wildcard pattern to test LIKE - multiByte = new byte[100]; + multiByte = new byte[10]; addMultiByteChars(multiByte); blanksLeft = " foo".getBytes("UTF-8"); blanksRight = "foo ".getBytes("UTF-8"); @@ -4237,50 +4250,179 @@ public void testStringLike() throws HiveException { Assert.assertEquals(initialBatchSize, batch.size); } + @Test public void testStringLikePatternType() throws UnsupportedEncodingException, HiveException { FilterStringColLikeStringScalar expr; + VectorizedRowBatch vrb = VectorizedRowGroupGenUtil.getVectorizedRowBatch(1, 1, 1); + vrb.cols[0] = new BytesColumnVector(1); + BytesColumnVector bcv = (BytesColumnVector) vrb.cols[0]; + vrb.size = 0; // BEGIN pattern expr = new FilterStringColLikeStringScalar(0, "abc%".getBytes()); + expr.evaluate(vrb); Assert.assertEquals(FilterStringColLikeStringScalar.BeginChecker.class, expr.checker.getClass()); // END pattern expr = new FilterStringColLikeStringScalar(0, "%abc".getBytes("UTF-8")); + expr.evaluate(vrb); Assert.assertEquals(FilterStringColLikeStringScalar.EndChecker.class, expr.checker.getClass()); // MIDDLE pattern expr = new FilterStringColLikeStringScalar(0, "%abc%".getBytes()); + expr.evaluate(vrb); Assert.assertEquals(FilterStringColLikeStringScalar.MiddleChecker.class, expr.checker.getClass()); - // COMPLEX pattern + // CHAIN pattern expr = new FilterStringColLikeStringScalar(0, "%abc%de".getBytes()); + expr.evaluate(vrb); + Assert.assertEquals(FilterStringColLikeStringScalar.ChainedChecker.class, + expr.checker.getClass()); + + // COMPLEX pattern + expr = new FilterStringColLikeStringScalar(0, "%abc_%de".getBytes()); + expr.evaluate(vrb); Assert.assertEquals(FilterStringColLikeStringScalar.ComplexChecker.class, expr.checker.getClass()); // NONE pattern expr = new FilterStringColLikeStringScalar(0, "abc".getBytes()); + expr.evaluate(vrb); Assert.assertEquals(FilterStringColLikeStringScalar.NoneChecker.class, expr.checker.getClass()); } - public void testStringLikeMultiByte() throws HiveException { + @Test + public void testStringLikeMultiByte() throws HiveException, UnsupportedEncodingException { FilterStringColLikeStringScalar expr; VectorizedRowBatch batch; // verify that a multi byte LIKE expression matches a matching string batch = makeStringBatchMixedCharSize(); - expr = new FilterStringColLikeStringScalar(0, ("%" + multiByte + "%").getBytes()); + expr = new FilterStringColLikeStringScalar(0, ('%' + new String(multiByte) + '%').getBytes(StandardCharsets.UTF_8)); expr.evaluate(batch); - Assert.assertEquals(batch.size, 1); + Assert.assertEquals(1, batch.size); // verify that a multi byte LIKE expression doesn't match a non-matching string batch = makeStringBatchMixedCharSize(); - expr = new FilterStringColLikeStringScalar(0, ("%" + multiByte + "x").getBytes()); + expr = new FilterStringColLikeStringScalar(0, ('%' + new String(multiByte) + 'x').getBytes(StandardCharsets.UTF_8)); expr.evaluate(batch); - Assert.assertEquals(batch.size, 0); + Assert.assertEquals(0, batch.size); + } + + private String randomizePattern(Random control, String value) { + switch (control.nextInt(10)) { + default: + case 0: { + return value; + } + case 1: { + return control.nextInt(1000) + value; + } + case 2: { + return value + control.nextInt(1000); + } + case 3: { + return control.nextInt(1000) + value.substring(1); + } + case 4: { + return value.substring(1) + control.nextInt(1000); + } + case 5: { + return control.nextInt(1000) + value.substring(0, value.length() - 1); + } + case 6: { + return ""; + } + case 7: { + return value.toLowerCase(); + } + case 8: { + StringBuffer sb = new StringBuffer(8); + for (int i = 0; i < control.nextInt(12); i++) { + sb.append((char) ('a' + control.nextInt(26))); + } + return sb.toString(); + } + case 9: { + StringBuffer sb = new StringBuffer(8); + for (int i = 0; i < control.nextInt(12); i++) { + sb.append((char) ('A' + control.nextInt(26))); + } + return sb.toString(); + } + } + } + + private String generateCandidate(Random control, String pattern) { + StringBuffer sb = new StringBuffer(); + final StringTokenizer tokens = new StringTokenizer(pattern, "%"); + final boolean leftAnchor = pattern.startsWith("%"); + final boolean rightAnchor = pattern.endsWith("%"); + for (int i = 0; tokens.hasMoreTokens(); i++) { + String chunk = tokens.nextToken(); + if (leftAnchor && i == 0) { + // first item + sb.append(randomizePattern(control, chunk)); + } else if (rightAnchor && tokens.hasMoreTokens() == false) { + // last item + sb.append(randomizePattern(control, chunk)); + } else { + // middle item + sb.append(randomizePattern(control, chunk)); + } + } + return sb.toString(); + } + + @Test + public void testStringLikeRandomized() throws HiveException, UnsupportedEncodingException { + final String [] patterns = new String[] { + "ABC%", + "%ABC", + "%ABC%", + "ABC%DEF", + "ABC%DEF%", + "%ABC%DEF", + "%ABC%DEF%", + "ABC%DEF%EFG", + "%ABC%DEF%EFG", + "%ABC%DEF%EFG%H", + }; + long positive = 0; + long negative = 0; + Random control = new Random(1234); + UDFLike udf = new UDFLike(); + for (String pattern : patterns) { + VectorExpression expr = new FilterStringColLikeStringScalar(0, pattern.getBytes("utf-8")); + VectorizedRowBatch batch = VectorizedRowGroupGenUtil.getVectorizedRowBatch(1, 1, 1); + batch.cols[0] = new BytesColumnVector(1); + BytesColumnVector bcv = (BytesColumnVector) batch.cols[0]; + + Text pText = new Text(pattern); + for (int i=0; i < 1024; i++) { + String input = generateCandidate(control,pattern); + BooleanWritable like = udf.evaluate(new Text(input), pText); + batch.reset(); + bcv.initBuffer(); + byte[] utf8 = input.getBytes("utf-8"); + bcv.setVal(0, utf8, 0, utf8.length); + bcv.noNulls = true; + batch.size = 1; + expr.evaluate(batch); + if (like.get()) { + positive++; + } else { + negative++; + } + assertEquals(String.format("Checking '%s' against '%s'", input, pattern), like.get(), (batch.size != 0)); + } + } + LOG.info(String.format("Randomized testing: ran %d positive tests and %d negative tests", + positive, negative)); } @Test