diff --git a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java index 4a86b0ae01299c80accaa5ab39f07dfda6e4d5e3..be335d3b79c763a9bacf4e9b602f017521f2b506 100644 --- a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java +++ b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java @@ -1634,6 +1634,9 @@ private static void populateLlapDaemonVarsSet(Set<String> llapDaemonVarsSetLocal + "the evaluation of certain joins, since we will not be emitting rows which are thrown away by " + "a Filter operator straight away. However, currently vectorization does not support them, thus " + "enabling it is only recommended when vectorization is disabled."), + HIVE_PTF_RANGECACHE_SIZE("hive.ptf.rangecache.size", 100000, + "Size of the cache used on the reducer side that stores boundaries of ranges within a PTF " + + "partition. Used if a query specifies a RANGE type window with an ORDER BY clause."), // CBO related HIVE_CBO_ENABLED("hive.cbo.enable", true, "Flag to control enabling Cost Based Optimizations using Calcite framework."), diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/BoundaryCache.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/BoundaryCache.java new file mode 100644 index 0000000000000000000000000000000000000000..46861a8ffff49ab5662b6b7abbb7aa3afbc920ca --- /dev/null +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/BoundaryCache.java @@ -0,0 +1,119 @@ +package org.apache.hadoop.hive.ql.exec; + +import java.util.LinkedList; +import java.util.Map; +import java.util.TreeMap; + +/** + * Cache for storing boundaries found within a partition - used for PTF functions. + * Stores key-value pairs where the key is the row index in the partition at which a range begins and + * the value is the corresponding row value (taken from the column the user specified in the ORDER BY clause). + */ +public class BoundaryCache extends TreeMap<Integer, Object> { + + private boolean isComplete = false; + private final int maxSize; + private final LinkedList<Integer> queue = new LinkedList<>(); + + public BoundaryCache(int maxSize) { + if (maxSize <= 1) { + throw new IllegalArgumentException("Cache size of 1 or below doesn't make sense."); + } + this.maxSize = maxSize; + } + + //True if the last range(s) of the partition are loaded into the cache. + public boolean isComplete() { + return isComplete; + } + + public void setComplete(boolean complete) { + isComplete = complete; + } + + @Override + public Object put(Integer key, Object value) { + if (key == value) { + return null; + } + Object result = super.put(key, value); + //Every new element is added to the FIFO too + if (result == null) { + queue.add(key); + } + //If the FIFO size reaches maxSize we evict the eldest entry. + if (queue.size() > maxSize) { + evictOne(); + } + return result; + } + + /** + * Puts a new key-value pair in the cache. + * @param key + * @param value + * @return false if the cache was full and the put failed, true otherwise. + */ + public Boolean putIfNotFull(Integer key, Object value) { + if ((queue.size() + 1) > maxSize) { + return false; + } else { + put(key, value); + return true; + } + } + + @Override + public void clear() { + this.isComplete = false; + this.queue.clear(); + super.clear(); + } + + /** + * Evicts the older half of the cache. + */ + public void evictHalf() { + int evictCount = queue.size() / 2; + for (int i = 0; i < evictCount; ++i) { + evictOne(); + } + } + + /** + * Calculates the percentile of the row's group position in the cache, 0 means cache + * beginning, 100 means cache end. 
+ * + * @param pos + * @return + */ + public int approxCachePositionOf(int pos) { + if (size() == 0) { + return 0; + } + Map.Entry floorEntry = floorEntry(pos); + if (floorEntry == null) { + return 100; + } else{ + //Using the fact, that queue is always filled from bottom to top in a partition + return (100 * (queue.indexOf(floorEntry.getKey()) + 1)) / maxSize; + } + } + + /** + * Returns entry corresponding to highest row index. + * @return + */ + public Map.Entry getMaxEntry() { + return floorEntry(Integer.MAX_VALUE); + } + + /** + * Removes eldest entry from the boundary cache. + */ + private void evictOne() { + Integer elementToDelete = queue.poll(); + this.remove(elementToDelete); + } + +} diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java index f125f9bcfa324364613d29e01a557f1b91cb9ff9..e17068e4d895ab46768550688fdf91a9407e51ae 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java @@ -46,6 +46,7 @@ StructObjectInspector inputOI; StructObjectInspector outputOI; private final PTFRowContainer> elems; + private final BoundaryCache boundaryCache; protected PTFPartition(Configuration cfg, AbstractSerDe serDe, StructObjectInspector inputOI, @@ -70,6 +71,8 @@ protected PTFPartition(Configuration cfg, } else { elems = null; } + int boundaryCacheSize = HiveConf.getIntVar(cfg, ConfVars.HIVE_PTF_RANGECACHE_SIZE); + boundaryCache = boundaryCacheSize > 1 ? new BoundaryCache(boundaryCacheSize) : null; } public void reset() throws HiveException { @@ -262,4 +265,8 @@ public static StructObjectInspector setupPartitionOutputOI(AbstractSerDe serDe, ObjectInspectorCopyOption.WRITABLE); } + public BoundaryCache getBoundaryCache() { + return boundaryCache; + } + } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java index d44604d2eced7e2a1bedb82ac8669e513fb5881c..20dc8628d6a6c0c86d964806b747474b568f25d2 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java @@ -256,6 +256,7 @@ protected static Range getRange(WindowFrameDef winFrame, int currRow, PTFPartiti end = getRowBoundaryEnd(endB, currRow, p); } else { ValueBoundaryScanner vbs = ValueBoundaryScanner.getScanner(winFrame, nullsLast); + vbs.handleCache(currRow, p); start = vbs.computeStart(currRow, p); end = vbs.computeEnd(currRow, p); } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/ValueBoundaryScanner.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/ValueBoundaryScanner.java index e633edb96e54af770086616003df6ea63233b993..ca95274ada96770a2f215fe9ea10e53d198c42ab 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/ValueBoundaryScanner.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/ValueBoundaryScanner.java @@ -18,10 +18,15 @@ package org.apache.hadoop.hive.ql.udf.ptf; +import java.util.Map; + +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import org.apache.hadoop.hive.common.type.Date; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.common.type.Timestamp; import org.apache.hadoop.hive.common.type.TimestampTZ; +import org.apache.hadoop.hive.ql.exec.BoundaryCache; import org.apache.hadoop.hive.ql.exec.PTFPartition; import 
org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.Order; @@ -44,10 +49,207 @@ public ValueBoundaryScanner(BoundaryDef start, BoundaryDef end, boolean nullsLas this.nullsLast = nullsLast; } + public abstract Object computeValue(Object row) throws HiveException; + + /** + * Checks if the distance of v2 to v1 is greater than the given amt. + * @return True if the value of v1 - v2 is greater than amt or either value is null. + */ + public abstract boolean isDistanceGreater(Object v1, Object v2, int amt); + + /** + * Checks if the values of v1 and v2 are the same. + * @return True if both values are the same or both are nulls. + */ + public abstract boolean isEqual(Object v1, Object v2); + public abstract int computeStart(int rowIdx, PTFPartition p) throws HiveException; public abstract int computeEnd(int rowIdx, PTFPartition p) throws HiveException; + /** + * Checks and maintains the cache content - keeps the cache window positioned around the current row + * so that it follows the progress of the scan. + * @param rowIdx current row + * @param p current partition for the PTF operator + * @throws HiveException + */ + public void handleCache(int rowIdx, PTFPartition p) throws HiveException { + BoundaryCache cache = p.getBoundaryCache(); + if (cache == null) { + return; + } + + //Start of partition + if (rowIdx == 0) { + cache.clear(); + } + if (cache.isComplete()) { + return; + } + + int cachePos = cache.approxCachePositionOf(rowIdx); + + if (cache.isEmpty()) { + fillCacheUntilEndOrFull(rowIdx, p); + } else if (cachePos > 50 && cachePos <= 75) { + if (!start.isPreceding() && end.isFollowing()) { + cache.evictHalf(); + fillCacheUntilEndOrFull(rowIdx, p); + } + } else if (cachePos > 75 && cachePos <= 95) { + if (start.isPreceding() && end.isFollowing()) { + cache.evictHalf(); + fillCacheUntilEndOrFull(rowIdx, p); + } + } else if (cachePos >= 95) { + if (start.isPreceding() && !end.isFollowing()) { + cache.evictHalf(); + fillCacheUntilEndOrFull(rowIdx, p); + } + + } + } + + /** + * Inserts values into the cache starting from rowIdx in the current partition p. Stops if the cache + * reaches its maximum size or we run out of rows in p. + * @param rowIdx + * @param p + * @throws HiveException + */ + private void fillCacheUntilEndOrFull(int rowIdx, PTFPartition p) throws HiveException { + BoundaryCache cache = p.getBoundaryCache(); + if (cache == null || p.size() <= 0) { + return; + } + + //If the cache is not empty, continue building it from its highest cached row index + Map.Entry<Integer, Object> ceilingEntry = cache.getMaxEntry(); + if (ceilingEntry != null) { + rowIdx = ceilingEntry.getKey(); + } + + Object rowVal = null; + Object lastRowVal = null; + + while (rowIdx < p.size()) { + rowVal = computeValue(p.getAt(rowIdx)); + if (!isEqual(rowVal, lastRowVal)){ + if (!cache.putIfNotFull(rowIdx, rowVal)){ + break; + } + } + lastRowVal = rowVal; + ++rowIdx; + + } + //Signaling end of all rows in a partition + if (cache.putIfNotFull(rowIdx, null)) { + cache.setComplete(true); + } + } + + /** + * Uses cache content to jump backwards if possible. If not, it steps one back. 
+ * @param r + * @param rowVal + * @param p + * @return pair of (row we stepped/jumped onto ; row value at this position) + * @throws HiveException + */ + protected Pair skipOrStepBack(int r, Object rowVal, PTFPartition p) + throws HiveException { + BoundaryCache cache = p.getBoundaryCache(); + + Map.Entry floorEntry = null; + Map.Entry ceilingEntry = null; + + if (cache != null) { + floorEntry = cache.floorEntry(r); + ceilingEntry = cache.ceilingEntry(r); + } + + if (floorEntry != null && ceilingEntry != null) { + r = floorEntry.getKey() - 1; + floorEntry = cache.floorEntry(r); + if (floorEntry != null) { + rowVal = floorEntry.getValue(); + } else if (r >= 0){ + rowVal = computeValue(p.getAt(r)); + } + } else { + r--; + if ( r >= 0 ) { + rowVal = computeValue(p.getAt(r)); + } + } + return new ImmutablePair<>(r, rowVal); + } + + /** + * Uses cache content to jump forward if possible. If not, it steps one forward. + * @param r + * @param rowVal + * @param p + * @return pair of (row we stepped/jumped onto ; row value at this position) + * @throws HiveException + */ + protected Pair skipOrStepForward(int r, Object rowVal, PTFPartition p) + throws HiveException { + BoundaryCache cache = p.getBoundaryCache(); + + Map.Entry floorEntry = null; + Map.Entry ceilingEntry = null; + + if (cache != null) { + floorEntry = cache.floorEntry(r); + ceilingEntry = cache.ceilingEntry(r); + } + + if (ceilingEntry != null && ceilingEntry.getKey().equals(r)){ + ceilingEntry = cache.ceilingEntry(r + 1); + } + if (floorEntry != null && ceilingEntry != null) { + r = ceilingEntry.getKey(); + rowVal = ceilingEntry.getValue(); + } else { + r++; + if ( r < p.size() ) { + rowVal = computeValue(p.getAt(r)); + } + } + return new ImmutablePair<>(r, rowVal); + } + + /** + * Uses cache to lookup row value. Computes it on the fly on cache miss. + * @param r + * @param p + * @return + * @throws HiveException + */ + protected Object computeValueUseCache(int r, PTFPartition p) throws HiveException { + BoundaryCache cache = p.getBoundaryCache(); + + Map.Entry floorEntry = null; + Map.Entry ceilingEntry = null; + + if (cache != null) { + floorEntry = cache.floorEntry(r); + ceilingEntry = cache.ceilingEntry(r); + } + + if (ceilingEntry != null && ceilingEntry.getKey().equals(r)){ + ceilingEntry = cache.ceilingEntry(r + 1); + } + if (floorEntry != null && ceilingEntry != null) { + return floorEntry.getValue(); + } else { + return computeValue(p.getAt(r)); + } + } + public static ValueBoundaryScanner getScanner(WindowFrameDef winFrameDef, boolean nullsLast) throws HiveException { OrderDef orderDef = winFrameDef.getOrderDef(); @@ -108,6 +310,7 @@ public SingleValueBoundaryScanner(BoundaryDef start, BoundaryDef end, | | | | | | such that R2.sk - R.sk > amt | |------+----------------+----------------+----------+-------+-----------------------------------| */ + @Override public int computeStart(int rowIdx, PTFPartition p) throws HiveException { switch(start.getDirection()) { @@ -127,18 +330,17 @@ protected int computeStartPreceding(int rowIdx, PTFPartition p) throws HiveExcep if ( amt == BoundarySpec.UNBOUNDED_AMOUNT ) { return 0; } - Object sortKey = computeValue(p.getAt(rowIdx)); + Object sortKey = computeValueUseCache(rowIdx, p); if ( sortKey == null ) { // Use Case 3. 
if (nullsLast || expressionDef.getOrder() == Order.DESC) { while ( sortKey == null && rowIdx >= 0 ) { - --rowIdx; - if ( rowIdx >= 0 ) { - sortKey = computeValue(p.getAt(rowIdx)); - } + Pair stepResult = skipOrStepBack(rowIdx, sortKey, p); + rowIdx = stepResult.getLeft(); + sortKey = stepResult.getRight(); } - return rowIdx+1; + return rowIdx + 1; } else { // Use Case 2. if ( expressionDef.getOrder() == Order.ASC ) { @@ -153,36 +355,34 @@ protected int computeStartPreceding(int rowIdx, PTFPartition p) throws HiveExcep // Use Case 4. if ( expressionDef.getOrder() == Order.DESC ) { while (r >= 0 && !isDistanceGreater(rowVal, sortKey, amt) ) { - r--; - if ( r >= 0 ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepBack(r, rowVal, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r + 1; } else { // Use Case 5. while (r >= 0 && !isDistanceGreater(sortKey, rowVal, amt) ) { - r--; - if ( r >= 0 ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepBack(r, rowVal, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } + return r + 1; } } protected int computeStartCurrentRow(int rowIdx, PTFPartition p) throws HiveException { - Object sortKey = computeValue(p.getAt(rowIdx)); + Object sortKey = computeValueUseCache(rowIdx, p); // Use Case 6. if ( sortKey == null ) { while ( sortKey == null && rowIdx >= 0 ) { - --rowIdx; - if ( rowIdx >= 0 ) { - sortKey = computeValue(p.getAt(rowIdx)); - } + Pair stepResult = skipOrStepBack(rowIdx, sortKey, p); + rowIdx = stepResult.getLeft(); + sortKey = stepResult.getRight(); } - return rowIdx+1; + return rowIdx + 1; } Object rowVal = sortKey; @@ -190,17 +390,16 @@ protected int computeStartCurrentRow(int rowIdx, PTFPartition p) throws HiveExce // Use Case 7. while (r >= 0 && isEqual(rowVal, sortKey) ) { - r--; - if ( r >= 0 ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepBack(r, rowVal, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r + 1; } protected int computeStartFollowing(int rowIdx, PTFPartition p) throws HiveException { int amt = start.getAmt(); - Object sortKey = computeValue(p.getAt(rowIdx)); + Object sortKey = computeValueUseCache(rowIdx, p); Object rowVal = sortKey; int r = rowIdx; @@ -212,10 +411,9 @@ protected int computeStartFollowing(int rowIdx, PTFPartition p) throws HiveExcep } else { // Use Case 10. while (r < p.size() && rowVal == null ) { - r++; - if ( r < p.size() ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepForward(r, rowVal, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r; } @@ -224,19 +422,17 @@ protected int computeStartFollowing(int rowIdx, PTFPartition p) throws HiveExcep // Use Case 11. if ( expressionDef.getOrder() == Order.DESC) { while (r < p.size() && !isDistanceGreater(sortKey, rowVal, amt) ) { - r++; - if ( r < p.size() ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepForward(r, rowVal, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r; } else { // Use Case 12. while (r < p.size() && !isDistanceGreater(rowVal, sortKey, amt) ) { - r++; - if ( r < p.size() ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepForward(r, rowVal, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r; } @@ -292,7 +488,7 @@ protected int computeEndPreceding(int rowIdx, PTFPartition p) throws HiveExcepti // Use Case 1. 
// amt == UNBOUNDED, is caught during translation - Object sortKey = computeValue(p.getAt(rowIdx)); + Object sortKey = computeValueUseCache(rowIdx, p); if ( sortKey == null ) { // Use Case 2. @@ -310,34 +506,31 @@ protected int computeEndPreceding(int rowIdx, PTFPartition p) throws HiveExcepti // Use Case 4. if ( expressionDef.getOrder() == Order.DESC ) { while (r >= 0 && !isDistanceGreater(rowVal, sortKey, amt) ) { - r--; - if ( r >= 0 ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepBack(r, rowVal, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r + 1; } else { // Use Case 5. while (r >= 0 && !isDistanceGreater(sortKey, rowVal, amt) ) { - r--; - if ( r >= 0 ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepBack(r, rowVal, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r + 1; } } protected int computeEndCurrentRow(int rowIdx, PTFPartition p) throws HiveException { - Object sortKey = computeValue(p.getAt(rowIdx)); + Object sortKey = computeValueUseCache(rowIdx, p); // Use Case 6. if ( sortKey == null ) { while ( sortKey == null && rowIdx < p.size() ) { - ++rowIdx; - if ( rowIdx < p.size() ) { - sortKey = computeValue(p.getAt(rowIdx)); - } + Pair stepResult = skipOrStepForward(rowIdx, sortKey, p); + rowIdx = stepResult.getLeft(); + sortKey = stepResult.getRight(); } return rowIdx; } @@ -347,10 +540,9 @@ protected int computeEndCurrentRow(int rowIdx, PTFPartition p) throws HiveExcept // Use Case 7. while (r < p.size() && isEqual(sortKey, rowVal) ) { - r++; - if ( r < p.size() ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepForward(r, rowVal, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r; } @@ -362,7 +554,7 @@ protected int computeEndFollowing(int rowIdx, PTFPartition p) throws HiveExcepti if ( amt == BoundarySpec.UNBOUNDED_AMOUNT ) { return p.size(); } - Object sortKey = computeValue(p.getAt(rowIdx)); + Object sortKey = computeValueUseCache(rowIdx, p); Object rowVal = sortKey; int r = rowIdx; @@ -374,10 +566,9 @@ protected int computeEndFollowing(int rowIdx, PTFPartition p) throws HiveExcepti } else { // Use Case 10. while (r < p.size() && rowVal == null ) { - r++; - if ( r < p.size() ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepForward(r, rowVal, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r; } @@ -386,19 +577,17 @@ protected int computeEndFollowing(int rowIdx, PTFPartition p) throws HiveExcepti // Use Case 11. if ( expressionDef.getOrder() == Order.DESC) { while (r < p.size() && !isDistanceGreater(sortKey, rowVal, amt) ) { - r++; - if ( r < p.size() ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepForward(r, rowVal, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r; } else { // Use Case 12. 
while (r < p.size() && !isDistanceGreater(rowVal, sortKey, amt) ) { - r++; - if ( r < p.size() ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepForward(r, rowVal, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r; } @@ -717,15 +906,14 @@ protected int computeStartPreceding(int rowIdx, PTFPartition p) throws HiveExcep } protected int computeStartCurrentRow(int rowIdx, PTFPartition p) throws HiveException { - Object[] sortKey = computeValues(p.getAt(rowIdx)); - Object[] rowVal = sortKey; + Object sortKey = computeValue(p.getAt(rowIdx)); + Object rowVal = sortKey; int r = rowIdx; while (r >= 0 && isEqual(rowVal, sortKey) ) { - r--; - if ( r >= 0 ) { - rowVal = computeValues(p.getAt(r)); - } + Pair stepResult = skipOrStepBack(r, rowVal, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r + 1; } @@ -741,6 +929,7 @@ protected int computeStartCurrentRow(int rowIdx, PTFPartition p) throws HiveExce | 2. | FOLLOWING | UNB | ANY | ANY | end = partition.size() | |------+----------------+---------------+----------+-------+-----------------------------------| */ + @Override public int computeEnd(int rowIdx, PTFPartition p) throws HiveException { switch(end.getDirection()) { @@ -756,15 +945,14 @@ public int computeEnd(int rowIdx, PTFPartition p) throws HiveException { } protected int computeEndCurrentRow(int rowIdx, PTFPartition p) throws HiveException { - Object[] sortKey = computeValues(p.getAt(rowIdx)); - Object[] rowVal = sortKey; + Object sortKey = computeValue(p.getAt(rowIdx)); + Object rowVal = sortKey; int r = rowIdx; while (r < p.size() && isEqual(sortKey, rowVal) ) { - r++; - if ( r < p.size() ) { - rowVal = computeValues(p.getAt(r)); - } + Pair stepResult = skipOrStepForward(r, rowVal, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r; } @@ -778,7 +966,8 @@ protected int computeEndFollowing(int rowIdx, PTFPartition p) throws HiveExcepti "FOLLOWING needs UNBOUNDED for RANGE with multiple expressions in ORDER BY"); } - public Object[] computeValues(Object row) throws HiveException { + @Override + public Object computeValue(Object row) throws HiveException { Object[] objs = new Object[orderDef.getExpressions().size()]; for (int i = 0; i < objs.length; i++) { Object o = orderDef.getExpressions().get(i).getExprEvaluator().evaluate(row); @@ -787,7 +976,11 @@ protected int computeEndFollowing(int rowIdx, PTFPartition p) throws HiveExcepti return objs; } - public boolean isEqual(Object[] v1, Object[] v2) { + @Override + public boolean isEqual(Object val1, Object val2) { + Object[] v1 = (Object[]) val1; + Object[] v2 = (Object[]) val2; + assert v1.length == v2.length; for (int i = 0; i < v1.length; i++) { if (v1[i] == null && v2[i] == null) { @@ -804,5 +997,10 @@ public boolean isEqual(Object[] v1, Object[] v2) { } return true; } + + @Override + public boolean isDistanceGreater(Object v1, Object v2, int amt) { + throw new UnsupportedOperationException("Only unbounded ranges supported"); + } } diff --git a/ql/src/test/org/apache/hadoop/hive/ql/udf/ptf/TestBoundaryCache.java b/ql/src/test/org/apache/hadoop/hive/ql/udf/ptf/TestBoundaryCache.java new file mode 100644 index 0000000000000000000000000000000000000000..4585cbab8a2f3aca5bb10f3ff78b7f39aed36b1e --- /dev/null +++ b/ql/src/test/org/apache/hadoop/hive/ql/udf/ptf/TestBoundaryCache.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.udf.ptf; + +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; + +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.hadoop.hive.ql.exec.BoundaryCache; +import org.apache.hadoop.hive.ql.exec.PTFPartition; +import org.apache.hadoop.hive.ql.parse.PTFInvocationSpec; +import org.apache.hadoop.hive.ql.parse.WindowingSpec; +import org.apache.hadoop.hive.ql.plan.ptf.BoundaryDef; +import org.apache.hadoop.hive.ql.plan.ptf.OrderExpressionDef; +import org.apache.hadoop.io.IntWritable; + +import com.google.common.collect.Lists; + +import org.junit.BeforeClass; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.Order.ASC; +import static org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.Order.DESC; +import static org.apache.hadoop.hive.ql.parse.WindowingSpec.BoundarySpec.UNBOUNDED_AMOUNT; +import static org.apache.hadoop.hive.ql.parse.WindowingSpec.Direction.CURRENT; +import static org.apache.hadoop.hive.ql.parse.WindowingSpec.Direction.FOLLOWING; +import static org.apache.hadoop.hive.ql.parse.WindowingSpec.Direction.PRECEDING; +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +public class TestBoundaryCache { + + private static final Logger LOG = LoggerFactory.getLogger(TestBoundaryCache.class); + private static final LinkedList> testPartition = new LinkedList<>(); + private static final List cacheSizes = Lists.newArrayList(null, 2, 5, 9, 15); + private static final List orders = Lists.newArrayList(ASC, DESC); + private static final int orderByCol = 2; + + @BeforeClass + public static void setupTests() throws Exception { + //8 ranges, max cache content is 8+1=9 entries + addRow(testPartition, 1, 1, -7); + addRow(testPartition, 2, 1, -1); + addRow(testPartition, 3, 1, -1); + addRow(testPartition, 4, 1, 1); + addRow(testPartition, 5, 1, 1); + addRow(testPartition, 6, 1, 1); + addRow(testPartition, 7, 1, 1); + addRow(testPartition, 8, 1, 2); + addRow(testPartition, 9, 1, 2); + addRow(testPartition, 10, 1, 2); + addRow(testPartition, 11, 1, 2); + addRow(testPartition, 12, 1, 3); + addRow(testPartition, 13, 1, 5); + addRow(testPartition, 14, 1, 5); + addRow(testPartition, 15, 1, 5); + addRow(testPartition, 16, 1, 5); + addRow(testPartition, 17, 1, 6); + addRow(testPartition, 18, 1, 6); + addRow(testPartition, 19, 1, 9); + + } + + @Test + public void testPrecedingUnboundedFollowingUnbounded() throws Exception { + runTest(PRECEDING, UNBOUNDED_AMOUNT, FOLLOWING, UNBOUNDED_AMOUNT); + } 
+ + @Test + public void testPrecedingUnboundedCurrentRow() throws Exception { + runTest(PRECEDING, UNBOUNDED_AMOUNT, CURRENT, 0); + } + + @Test + public void testPrecedingUnboundedPreceding2() throws Exception { + runTest(PRECEDING, UNBOUNDED_AMOUNT, PRECEDING, 2); + } + + @Test + public void testPreceding4Preceding1() throws Exception { + runTest(PRECEDING, 4, PRECEDING, 1); + } + + @Test + public void testPreceding2CurrentRow() throws Exception { + runTest(PRECEDING, 2, CURRENT, 0); + } + + @Test + public void testPreceding2Following2() throws Exception { + runTest(PRECEDING, 2, FOLLOWING, 2); + } + + @Test + public void testCurrentRowFollowing3() throws Exception { + runTest(CURRENT, 0, FOLLOWING, 3); + } + + @Test + public void testCurrentRowFFollowingUnbounded() throws Exception { + runTest(CURRENT, 0, FOLLOWING, UNBOUNDED_AMOUNT); + } + + @Test + public void testFollowing2Following4() throws Exception { + runTest(FOLLOWING, 2, FOLLOWING, 4); + } + + @Test + public void testFollowing2FollowingUnbounded() throws Exception { + runTest(FOLLOWING, 2, FOLLOWING, UNBOUNDED_AMOUNT); + } + + private void runTest(WindowingSpec.Direction startDirection, int startAmount, + WindowingSpec.Direction endDirection, int endAmount) throws Exception { + + BoundaryDef startBoundary = new BoundaryDef(startDirection, startAmount); + BoundaryDef endBoundary = new BoundaryDef(endDirection, endAmount); + + int[] expectedBoundaryStarts = new int[testPartition.size()]; + int[] expectedBoundaryEnds = new int[testPartition.size()]; + + for (PTFInvocationSpec.Order order : orders) { + for (Integer cacheSize : cacheSizes) { + LOG.info(Thread.currentThread().getStackTrace()[2].getMethodName()); + LOG.info("Cache: " + cacheSize + " order: " + order); + BoundaryCache cache = cacheSize == null ? 
null : new BoundaryCache(cacheSize); + Pair mocks = setupMocks(testPartition, + orderByCol, startBoundary, endBoundary, order, cache); + PTFPartition ptfPartition = mocks.getLeft(); + ValueBoundaryScanner scanner = mocks.getRight(); + for (int i = 0; i < testPartition.size(); ++i) { + scanner.handleCache(i, ptfPartition); + int start = scanner.computeStart(i, ptfPartition); + int end = scanner.computeEnd(i, ptfPartition) - 1; + if (cache == null) { + //Cache-less version should be baseline + expectedBoundaryStarts[i] = start; + expectedBoundaryEnds[i] = end; + } else { + assertEquals(expectedBoundaryStarts[i], start); + assertEquals(expectedBoundaryEnds[i], end); + } + LOG.info(String.format("%d|\t%d\t%d\t%d\t|%d-%d", i, + testPartition.get(i).get(0).get(), testPartition.get(i).get(1).get(), + testPartition.get(i).get(2).get(), start, end)); + } + } + } + } + + private static Pair setupMocks( + List> partition, int orderByCol, BoundaryDef start, BoundaryDef end, + PTFInvocationSpec.Order order, BoundaryCache cache + ) throws Exception { + PTFPartition partitionMock = mock(PTFPartition.class); + doAnswer(invocationOnMock -> { + int idx = invocationOnMock.getArgumentAt(0, Integer.class); + return partition.get(idx); + }).when(partitionMock).getAt(any(Integer.class)); + doAnswer(invocationOnMock -> { + return partition.size(); + }).when(partitionMock).size(); + when(partitionMock.getBoundaryCache()).thenReturn(cache); + + OrderExpressionDef orderDef = mock(OrderExpressionDef.class); + when(orderDef.getOrder()).thenReturn(order); + + ValueBoundaryScanner scanner = new LongValueBoundaryScanner(start, end, orderDef, true); + ValueBoundaryScanner scannerSpy = spy(scanner); + doAnswer(invocationOnMock -> { + List row = invocationOnMock.getArgumentAt(0, List.class); + return row.get(orderByCol); + }).when(scannerSpy).computeValue(any(Object.class)); + doAnswer(invocationOnMock -> { + IntWritable v1 = invocationOnMock.getArgumentAt(0, IntWritable.class); + IntWritable v2 = invocationOnMock.getArgumentAt(1, IntWritable.class); + return (v1 != null && v2 != null) ? v1.get() == v2.get() : v1 == null && v2 == null; + }).when(scannerSpy).isEqual(any(Object.class), any(Object.class)); + doAnswer(invocationOnMock -> { + IntWritable v1 = invocationOnMock.getArgumentAt(0, IntWritable.class); + IntWritable v2 = invocationOnMock.getArgumentAt(1, IntWritable.class); + Integer amt = invocationOnMock.getArgumentAt(2, Integer.class); + return (v1 != null && v2 != null) ? (v1.get() - v2.get()) > amt : v1 != null || v2 != null; + }).when(scannerSpy).isDistanceGreater(any(Object.class), any(Object.class), any(Integer.class)); + + setOrderOnTestPartitions(order); + return new ImmutablePair<>(partitionMock, scannerSpy); + + } + + private static void addRow(List> partition, int a, int b, + int c) { + partition.add(Lists.newArrayList( + new IntWritable(a), + new IntWritable(b), + new IntWritable(c) + )); + } + + private static void setOrderOnTestPartitions(PTFInvocationSpec.Order order) { + boolean isAscCurrently = testPartition.getFirst().get(orderByCol).get() < + testPartition.getLast().get(orderByCol).get(); + + if ((ASC.equals(order) && !isAscCurrently) || (DESC.equals(order) && isAscCurrently)) { + Collections.reverse(testPartition); + } + } + +}
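
For orientation, the patch above caches, per reducer-side PTF partition, the first row index of each ORDER BY range in a size-bounded BoundaryCache (a TreeMap with FIFO eviction), and ValueBoundaryScanner consults it via floorEntry/ceilingEntry to jump over whole ranges instead of stepping row by row. The snippet below is a minimal illustrative sketch of that data structure's behaviour and is not part of the patch: the class name BoundaryCacheSketch and the literal row indices and values are invented for the example; only the BoundaryCache API introduced in the diff is assumed.

import java.util.Map;

import org.apache.hadoop.hive.ql.exec.BoundaryCache;

public class BoundaryCacheSketch {
  public static void main(String[] args) {
    // Cache holding at most 3 range boundaries (in the real operator the limit
    // comes from hive.ptf.rangecache.size).
    BoundaryCache cache = new BoundaryCache(3);

    // Store the first row index of each range together with the ORDER BY value at that row.
    cache.putIfNotFull(0, 10);  // range with value 10 starts at row 0
    cache.putIfNotFull(4, 20);  // range with value 20 starts at row 4
    cache.putIfNotFull(9, 30);  // range with value 30 starts at row 9

    // floorEntry(rowIdx) yields the boundary of the range containing rowIdx; this is the
    // lookup skipOrStepBack/skipOrStepForward use to jump across a whole range at once.
    Map.Entry<Integer, Object> range = cache.floorEntry(6);
    System.out.println(range.getKey() + " -> " + range.getValue());  // prints: 4 -> 20

    // The cache is FIFO-bounded: adding a fourth boundary evicts the eldest one (row 0).
    cache.put(12, 40);
    System.out.println(cache.firstKey());  // prints: 4
  }
}

In the actual scanner the cached values are the ORDER BY expression results for the partition's rows, and a trailing entry with a null value marks that the last range of the partition has been reached (setComplete(true)); plain integers are used here only to show the floorEntry lookup and the FIFO eviction.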