diff --git a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
index 4a86b0ae01299c80accaa5ab39f07dfda6e4d5e3..11f165a085971769479a74569c26c25ab49d431c 100644
--- a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
+++ b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
@@ -1634,6 +1634,10 @@ private static void populateLlapDaemonVarsSet(Set<String> llapDaemonVarsSetLocal
         + "the evaluation of certain joins, since we will not be emitting rows which are thrown away by "
         + "a Filter operator straight away. However, currently vectorization does not support them, thus "
         + "enabling it is only recommended when vectorization is disabled."),
+    HIVE_PTF_RANGECACHE_SIZE("hive.ptf.rangecache.size", 10000,
+        "Size of the cache used on the reducer side to store boundaries of ranges within a PTF "
+        + "partition. It is used when a query specifies a RANGE-type window with an ORDER BY clause. "
+        + "Set this to 0 to disable the cache (any value below 2 disables it)."),
     // CBO related
     HIVE_CBO_ENABLED("hive.cbo.enable", true, "Flag to control enabling Cost Based Optimizations using Calcite framework."),
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/BoundaryCache.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/BoundaryCache.java
new file mode 100644
index 0000000000000000000000000000000000000000..7cf278ca05ec24256ad7624f5e1540e18f41ccbd
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/BoundaryCache.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec;
+
+import java.util.LinkedList;
+import java.util.Map;
+import java.util.TreeMap;
+
+/**
+ * Cache that stores range boundaries found within a partition - used by PTF functions.
+ * It stores key-value pairs where the key is the row index in the partition at which a range
+ * begins, and the value is the corresponding row value (taken from the ORDER BY expression the
+ * user specified).
+ */
+public class BoundaryCache extends TreeMap<Integer, Object> {
+
+  private boolean isComplete = false;
+  private final int maxSize;
+  private final LinkedList<Integer> queue = new LinkedList<>();
+
+  public BoundaryCache(int maxSize) {
+    if (maxSize <= 1) {
+      throw new IllegalArgumentException("Cache size must be greater than 1.");
+    }
+    this.maxSize = maxSize;
+  }
+
+  /**
+   * Tells whether the last range(s) of the partition have been loaded into the cache.
+   * @return true if the cache already covers the end of the partition.
+   */
+  public boolean isComplete() {
+    return isComplete;
+  }
+
+  public void setComplete(boolean complete) {
+    isComplete = complete;
+  }
+
+  @Override
+  public Object put(Integer key, Object value) {
+    Object result = super.put(key, value);
+    //Every newly added key is appended to the FIFO eviction queue too.
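To make the cache contract concrete, here is a minimal usage sketch (the row indexes and ORDER BY values are invented for illustration and do not come from the patch). Each entry marks the row index where a new range starts, and a null-valued entry one past the last row marks the end of the partition:

    import org.apache.hadoop.hive.ql.exec.BoundaryCache;
    import org.apache.hadoop.io.IntWritable;

    public class BoundaryCacheSketch {
      public static void main(String[] args) {
        // ORDER BY column values 10, 10, 20, 20, 20, 30 on row indexes 0..5.
        BoundaryCache cache = new BoundaryCache(8);   // capacity normally comes from hive.ptf.rangecache.size
        cache.put(0, new IntWritable(10));            // rows 0..1 form the range with value 10
        cache.put(2, new IntWritable(20));            // rows 2..4 form the range with value 20
        cache.put(5, new IntWritable(30));            // row 5 forms the range with value 30
        cache.put(6, null);                           // end-of-partition marker
        cache.setComplete(true);
        System.out.println(cache.floorEntry(3));      // prints 2=20: row 3 belongs to the range starting at row 2
      }
    }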
+    if (result == null) {
+      queue.add(key);
+    }
+    //If the FIFO queue exceeds maxSize we evict the eldest entry.
+    if (queue.size() > maxSize) {
+      evictOne();
+    }
+    return result;
+  }
+
+  /**
+   * Puts a new key-value pair into the cache.
+   * @param key row index at which the range begins.
+   * @param value row value of that range.
+   * @return false if the cache was full and the put failed, true otherwise.
+   */
+  public Boolean putIfNotFull(Integer key, Object value) {
+    if (isFull()) {
+      return false;
+    } else {
+      put(key, value);
+      return true;
+    }
+  }
+
+  /**
+   * Checks if the cache is full.
+   * @return true if full, false otherwise.
+   */
+  public Boolean isFull() {
+    return queue.size() >= maxSize;
+  }
+
+  @Override
+  public void clear() {
+    this.isComplete = false;
+    this.queue.clear();
+    super.clear();
+  }
+
+  /**
+   * Returns the entry corresponding to the highest row index.
+   * @return max entry.
+   */
+  public Map.Entry<Integer, Object> getMaxEntry() {
+    return floorEntry(Integer.MAX_VALUE);
+  }
+
+  /**
+   * Removes the eldest entry from the boundary cache.
+   */
+  public void evictOne() {
+    if (queue.isEmpty()) {
+      return;
+    }
+    Integer elementToDelete = queue.poll();
+    this.remove(elementToDelete);
+  }
+
+  /**
+   * Evicts the entry at rowIdx and every entry before it.
+   */
+  public void evictThisAndAllBefore(int rowIdx) {
+    while (!queue.isEmpty() && queue.peek() <= rowIdx) {
+      evictOne();
+    }
+  }
+
+}
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java
index f125f9bcfa324364613d29e01a557f1b91cb9ff9..e17068e4d895ab46768550688fdf91a9407e51ae 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java
@@ -46,6 +46,7 @@ StructObjectInspector inputOI;
   StructObjectInspector outputOI;
   private final PTFRowContainer<List<Object>> elems;
+  private final BoundaryCache boundaryCache;
 
   protected PTFPartition(Configuration cfg,
       AbstractSerDe serDe, StructObjectInspector inputOI,
@@ -70,6 +71,8 @@ protected PTFPartition(Configuration cfg,
     } else {
       elems = null;
     }
+    int boundaryCacheSize = HiveConf.getIntVar(cfg, ConfVars.HIVE_PTF_RANGECACHE_SIZE);
+    boundaryCache = boundaryCacheSize > 1 ?
new BoundaryCache(boundaryCacheSize) : null; } public void reset() throws HiveException { @@ -262,4 +265,8 @@ public static StructObjectInspector setupPartitionOutputOI(AbstractSerDe serDe, ObjectInspectorCopyOption.WRITABLE); } + public BoundaryCache getBoundaryCache() { + return boundaryCache; + } + } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java index d44604d2eced7e2a1bedb82ac8669e513fb5881c..20dc8628d6a6c0c86d964806b747474b568f25d2 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/BasePartitionEvaluator.java @@ -256,6 +256,7 @@ protected static Range getRange(WindowFrameDef winFrame, int currRow, PTFPartiti end = getRowBoundaryEnd(endB, currRow, p); } else { ValueBoundaryScanner vbs = ValueBoundaryScanner.getScanner(winFrame, nullsLast); + vbs.handleCache(currRow, p); start = vbs.computeStart(currRow, p); end = vbs.computeEnd(currRow, p); } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/ValueBoundaryScanner.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/ValueBoundaryScanner.java index e633edb96e54af770086616003df6ea63233b993..524812fd15cf470d7ddfc742b5439456bc805841 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/ValueBoundaryScanner.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/ValueBoundaryScanner.java @@ -18,10 +18,15 @@ package org.apache.hadoop.hive.ql.udf.ptf; +import java.util.Map; + +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import org.apache.hadoop.hive.common.type.Date; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.common.type.Timestamp; import org.apache.hadoop.hive.common.type.TimestampTZ; +import org.apache.hadoop.hive.ql.exec.BoundaryCache; import org.apache.hadoop.hive.ql.exec.PTFPartition; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.Order; @@ -44,10 +49,258 @@ public ValueBoundaryScanner(BoundaryDef start, BoundaryDef end, boolean nullsLas this.nullsLast = nullsLast; } + public abstract Object computeValue(Object row) throws HiveException; + + /** + * Checks if the distance of v2 to v1 is greater than the given amt. + * @return True if the value of v1 - v2 is greater than amt or either value is null. + */ + public abstract boolean isDistanceGreater(Object v1, Object v2, int amt); + + /** + * Checks if the values of v1 or v2 are the same. + * @return True if both values are the same or both are nulls. + */ + public abstract boolean isEqual(Object v1, Object v2); + public abstract int computeStart(int rowIdx, PTFPartition p) throws HiveException; public abstract int computeEnd(int rowIdx, PTFPartition p) throws HiveException; + /** + * Checks and maintains cache content - optimizes cache window to always be around current row + * thereby makes it follow the current progress. + * @param rowIdx current row + * @param p current partition for the PTF operator + * @throws HiveException + */ + public void handleCache(int rowIdx, PTFPartition p) throws HiveException { + BoundaryCache cache = p.getBoundaryCache(); + if (cache == null) { + return; + } + + //No need to setup/fill cache. + if (start.isUnbounded() && end.isUnbounded()) { + return; + } + + //Start of partition. 
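The BasePartitionEvaluator change above is the only call site of the new cache maintenance: before each start/end computation the scanner gets a chance to slide its cache window. A simplified sketch of that per-row driving pattern, paraphrased from getRange and the unit test further below (the surrounding aggregation loop is assumed and is not part of the patch):

    // vbs is the ValueBoundaryScanner chosen for the window frame, p is the PTFPartition.
    ValueBoundaryScanner vbs = ValueBoundaryScanner.getScanner(winFrame, nullsLast);
    for (int row = 0; row < p.size(); row++) {
      vbs.handleCache(row, p);            // evict ranges already passed, fill ahead of the current row
      int start = vbs.computeStart(row, p);
      int end = vbs.computeEnd(row, p);   // exclusive end
      // ... aggregate over rows [start, end) for this row's window ...
    }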
+ if (rowIdx == 0) { + cache.clear(); + } + if (cache.isComplete()) { + return; + } + if (cache.isEmpty()) { + fillCacheUntilEndOrFull(rowIdx, p); + return; + } + + if (start.isPreceding()) { + if (start.isUnbounded()) { + if (end.isPreceding()) { + //We can wait with cache eviction until we're at the end of currently known ranges. + Map.Entry maxEntry = cache.getMaxEntry(); + if (maxEntry != null && maxEntry.getKey() <= rowIdx) { + cache.evictOne(); + } + } else { + //Starting from current row, all previous ranges can be evicted. + checkIfCacheCanEvict(rowIdx, p, true); + } + } else { + //We either evict when we're at the end of currently known ranges, or if not there yet and + // END is of FOLLOWING type: we should remove ranges preceding the current range beginning. + Map.Entry maxEntry = cache.getMaxEntry(); + if (maxEntry != null && maxEntry.getKey() <= rowIdx) { + cache.evictOne(); + } else if (end.isFollowing()) { + int startIdx = computeStart(rowIdx, p); + checkIfCacheCanEvict(startIdx - 1, p, true); + } + } + } + + if (start.isCurrentRow()) { + //Starting from current row, all previous ranges before the previous range can be evicted. + checkIfCacheCanEvict(rowIdx, p, false); + } + if (start.isFollowing()) { + //Starting from current row, all previous ranges can be evicted. + checkIfCacheCanEvict(rowIdx, p, true); + } + + fillCacheUntilEndOrFull(rowIdx, p); + } + + /** + * Retrieves the range for rowIdx, then removes all previous range entries before it. + * @param rowIdx row index. + * @param p partition. + * @param willScanFwd false: removal is started only from the previous previous range. + */ + private void checkIfCacheCanEvict(int rowIdx, PTFPartition p, boolean willScanFwd) { + BoundaryCache cache = p.getBoundaryCache(); + if (cache == null) { + return; + } + Map.Entry floorEntry = cache.floorEntry(rowIdx); + if (floorEntry != null) { + floorEntry = cache.floorEntry(floorEntry.getKey() - 1); + if (floorEntry != null) { + if (willScanFwd) { + cache.evictThisAndAllBefore(floorEntry.getKey()); + } else { + floorEntry = cache.floorEntry(floorEntry.getKey() - 1); + if (floorEntry != null) { + cache.evictThisAndAllBefore(floorEntry.getKey()); + } + } + } + } + } + + /** + * Inserts values into cache starting from rowIdx in the current partition p. Stops if cache + * reaches its maximum size or we get out of rows in p. + * @param rowIdx + * @param p + * @throws HiveException + */ + private void fillCacheUntilEndOrFull(int rowIdx, PTFPartition p) throws HiveException { + BoundaryCache cache = p.getBoundaryCache(); + if (cache == null || p.size() <= 0) { + return; + } + + Object rowVal = null; + + //If we continue building cache + Map.Entry ceilingEntry = cache.getMaxEntry(); + if (ceilingEntry != null) { + rowIdx = ceilingEntry.getKey(); + rowVal = ceilingEntry.getValue(); + ++rowIdx; + } + + Object lastRowVal = rowVal; + + while (rowIdx < p.size() && !cache.isFull()) { + rowVal = computeValue(p.getAt(rowIdx)); + if (!isEqual(rowVal, lastRowVal)){ + cache.put(rowIdx, rowVal); + } + lastRowVal = rowVal; + ++rowIdx; + + } + //Signaling end of all rows in a partition + if (cache.putIfNotFull(rowIdx, null)) { + cache.setComplete(true); + } + } + + /** + * Uses cache content to jump backwards if possible. If not, it steps one back. 
+ * @param r + * @param p + * @return pair of (row we stepped/jumped onto ; row value at this position) + * @throws HiveException + */ + protected Pair skipOrStepBack(int r, PTFPartition p) + throws HiveException { + Object rowVal = null; + BoundaryCache cache = p.getBoundaryCache(); + + Map.Entry floorEntry = null; + Map.Entry ceilingEntry = null; + + if (cache != null) { + floorEntry = cache.floorEntry(r); + ceilingEntry = cache.ceilingEntry(r); + } + + if (floorEntry != null && ceilingEntry != null) { + r = floorEntry.getKey() - 1; + floorEntry = cache.floorEntry(r); + if (floorEntry != null) { + rowVal = floorEntry.getValue(); + } else if (r >= 0){ + rowVal = computeValue(p.getAt(r)); + } + } else { + r--; + if (r >= 0) { + rowVal = computeValue(p.getAt(r)); + } + } + return new ImmutablePair<>(r, rowVal); + } + + /** + * Uses cache content to jump forward if possible. If not, it steps one forward. + * @param r + * @param p + * @return pair of (row we stepped/jumped onto ; row value at this position) + * @throws HiveException + */ + protected Pair skipOrStepForward(int r, PTFPartition p) + throws HiveException { + Object rowVal = null; + BoundaryCache cache = p.getBoundaryCache(); + + Map.Entry floorEntry = null; + Map.Entry ceilingEntry = null; + + if (cache != null) { + floorEntry = cache.floorEntry(r); + ceilingEntry = cache.ceilingEntry(r); + } + + if (ceilingEntry != null && ceilingEntry.getKey().equals(r)){ + ceilingEntry = cache.ceilingEntry(r + 1); + } + if (floorEntry != null && ceilingEntry != null) { + r = ceilingEntry.getKey(); + rowVal = ceilingEntry.getValue(); + } else { + r++; + if (r < p.size()) { + rowVal = computeValue(p.getAt(r)); + } + } + return new ImmutablePair<>(r, rowVal); + } + + /** + * Uses cache to lookup row value. Computes it on the fly on cache miss. + * @param r + * @param p + * @return row value. + * @throws HiveException + */ + protected Object computeValueUseCache(int r, PTFPartition p) throws HiveException { + BoundaryCache cache = p.getBoundaryCache(); + + Map.Entry floorEntry = null; + Map.Entry ceilingEntry = null; + + if (cache != null) { + floorEntry = cache.floorEntry(r); + ceilingEntry = cache.ceilingEntry(r); + } + + if (ceilingEntry != null && ceilingEntry.getKey().equals(r)){ + return ceilingEntry.getValue(); + } + if (floorEntry != null && ceilingEntry != null) { + return floorEntry.getValue(); + } else { + return computeValue(p.getAt(r)); + } + } + public static ValueBoundaryScanner getScanner(WindowFrameDef winFrameDef, boolean nullsLast) throws HiveException { OrderDef orderDef = winFrameDef.getOrderDef(); @@ -108,6 +361,7 @@ public SingleValueBoundaryScanner(BoundaryDef start, BoundaryDef end, | | | | | | such that R2.sk - R.sk > amt | |------+----------------+----------------+----------+-------+-----------------------------------| */ + @Override public int computeStart(int rowIdx, PTFPartition p) throws HiveException { switch(start.getDirection()) { @@ -127,18 +381,17 @@ protected int computeStartPreceding(int rowIdx, PTFPartition p) throws HiveExcep if ( amt == BoundarySpec.UNBOUNDED_AMOUNT ) { return 0; } - Object sortKey = computeValue(p.getAt(rowIdx)); + Object sortKey = computeValueUseCache(rowIdx, p); if ( sortKey == null ) { // Use Case 3. 
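The two skip/step helpers are where the cache pays off: when the boundaries around r are already cached, the scanner jumps a whole range at a time instead of calling computeValue() on every intermediate row. A stand-alone sketch of the forward jump, using a plain TreeMap with the same invented boundaries as the earlier example (this is illustrative code, not code from the patch):

    import java.util.Map;
    import java.util.TreeMap;

    public class SkipForwardSketch {
      public static void main(String[] args) {
        TreeMap<Integer, Object> cache = new TreeMap<>();
        cache.put(0, 10); cache.put(2, 20); cache.put(5, 30); cache.put(6, null);

        int r = 0;
        Map.Entry<Integer, Object> ceiling = cache.ceilingEntry(r);
        if (ceiling != null && ceiling.getKey().equals(r)) {
          ceiling = cache.ceilingEntry(r + 1);          // r itself starts a range, look past it
        }
        Map.Entry<Integer, Object> floor = cache.floorEntry(r);
        if (floor != null && ceiling != null) {
          r = ceiling.getKey();                         // jump straight to the next range start (row 2)
        } else {
          r++;                                          // cache miss: fall back to a single step
        }
        System.out.println("moved to row " + r);        // prints: moved to row 2
      }
    }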
if (nullsLast || expressionDef.getOrder() == Order.DESC) { while ( sortKey == null && rowIdx >= 0 ) { - --rowIdx; - if ( rowIdx >= 0 ) { - sortKey = computeValue(p.getAt(rowIdx)); - } + Pair stepResult = skipOrStepBack(rowIdx, p); + rowIdx = stepResult.getLeft(); + sortKey = stepResult.getRight(); } - return rowIdx+1; + return rowIdx + 1; } else { // Use Case 2. if ( expressionDef.getOrder() == Order.ASC ) { @@ -153,36 +406,34 @@ protected int computeStartPreceding(int rowIdx, PTFPartition p) throws HiveExcep // Use Case 4. if ( expressionDef.getOrder() == Order.DESC ) { while (r >= 0 && !isDistanceGreater(rowVal, sortKey, amt) ) { - r--; - if ( r >= 0 ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepBack(r, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r + 1; } else { // Use Case 5. while (r >= 0 && !isDistanceGreater(sortKey, rowVal, amt) ) { - r--; - if ( r >= 0 ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepBack(r, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } + return r + 1; } } protected int computeStartCurrentRow(int rowIdx, PTFPartition p) throws HiveException { - Object sortKey = computeValue(p.getAt(rowIdx)); + Object sortKey = computeValueUseCache(rowIdx, p); // Use Case 6. if ( sortKey == null ) { while ( sortKey == null && rowIdx >= 0 ) { - --rowIdx; - if ( rowIdx >= 0 ) { - sortKey = computeValue(p.getAt(rowIdx)); - } + Pair stepResult = skipOrStepBack(rowIdx, p); + rowIdx = stepResult.getLeft(); + sortKey = stepResult.getRight(); } - return rowIdx+1; + return rowIdx + 1; } Object rowVal = sortKey; @@ -190,17 +441,16 @@ protected int computeStartCurrentRow(int rowIdx, PTFPartition p) throws HiveExce // Use Case 7. while (r >= 0 && isEqual(rowVal, sortKey) ) { - r--; - if ( r >= 0 ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepBack(r, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r + 1; } protected int computeStartFollowing(int rowIdx, PTFPartition p) throws HiveException { int amt = start.getAmt(); - Object sortKey = computeValue(p.getAt(rowIdx)); + Object sortKey = computeValueUseCache(rowIdx, p); Object rowVal = sortKey; int r = rowIdx; @@ -212,10 +462,9 @@ protected int computeStartFollowing(int rowIdx, PTFPartition p) throws HiveExcep } else { // Use Case 10. while (r < p.size() && rowVal == null ) { - r++; - if ( r < p.size() ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepForward(r, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r; } @@ -224,19 +473,17 @@ protected int computeStartFollowing(int rowIdx, PTFPartition p) throws HiveExcep // Use Case 11. if ( expressionDef.getOrder() == Order.DESC) { while (r < p.size() && !isDistanceGreater(sortKey, rowVal, amt) ) { - r++; - if ( r < p.size() ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepForward(r, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r; } else { // Use Case 12. while (r < p.size() && !isDistanceGreater(rowVal, sortKey, amt) ) { - r++; - if ( r < p.size() ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepForward(r, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r; } @@ -292,7 +539,7 @@ protected int computeEndPreceding(int rowIdx, PTFPartition p) throws HiveExcepti // Use Case 1. 
// amt == UNBOUNDED, is caught during translation - Object sortKey = computeValue(p.getAt(rowIdx)); + Object sortKey = computeValueUseCache(rowIdx, p); if ( sortKey == null ) { // Use Case 2. @@ -310,34 +557,31 @@ protected int computeEndPreceding(int rowIdx, PTFPartition p) throws HiveExcepti // Use Case 4. if ( expressionDef.getOrder() == Order.DESC ) { while (r >= 0 && !isDistanceGreater(rowVal, sortKey, amt) ) { - r--; - if ( r >= 0 ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepBack(r, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r + 1; } else { // Use Case 5. while (r >= 0 && !isDistanceGreater(sortKey, rowVal, amt) ) { - r--; - if ( r >= 0 ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepBack(r, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r + 1; } } protected int computeEndCurrentRow(int rowIdx, PTFPartition p) throws HiveException { - Object sortKey = computeValue(p.getAt(rowIdx)); + Object sortKey = computeValueUseCache(rowIdx, p); // Use Case 6. if ( sortKey == null ) { while ( sortKey == null && rowIdx < p.size() ) { - ++rowIdx; - if ( rowIdx < p.size() ) { - sortKey = computeValue(p.getAt(rowIdx)); - } + Pair stepResult = skipOrStepForward(rowIdx, p); + rowIdx = stepResult.getLeft(); + sortKey = stepResult.getRight(); } return rowIdx; } @@ -347,10 +591,9 @@ protected int computeEndCurrentRow(int rowIdx, PTFPartition p) throws HiveExcept // Use Case 7. while (r < p.size() && isEqual(sortKey, rowVal) ) { - r++; - if ( r < p.size() ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepForward(r, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r; } @@ -362,7 +605,7 @@ protected int computeEndFollowing(int rowIdx, PTFPartition p) throws HiveExcepti if ( amt == BoundarySpec.UNBOUNDED_AMOUNT ) { return p.size(); } - Object sortKey = computeValue(p.getAt(rowIdx)); + Object sortKey = computeValueUseCache(rowIdx, p); Object rowVal = sortKey; int r = rowIdx; @@ -374,10 +617,9 @@ protected int computeEndFollowing(int rowIdx, PTFPartition p) throws HiveExcepti } else { // Use Case 10. while (r < p.size() && rowVal == null ) { - r++; - if ( r < p.size() ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepForward(r, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r; } @@ -386,19 +628,17 @@ protected int computeEndFollowing(int rowIdx, PTFPartition p) throws HiveExcepti // Use Case 11. if ( expressionDef.getOrder() == Order.DESC) { while (r < p.size() && !isDistanceGreater(sortKey, rowVal, amt) ) { - r++; - if ( r < p.size() ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepForward(r, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r; } else { // Use Case 12. 
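For readers following the numbered use cases: the bounded PRECEDING cases all rely on isDistanceGreater() to decide when to stop walking backwards. A worked toy example for an ascending ORDER BY column and a "RANGE 2 PRECEDING" start boundary (values invented; nulls and the cache jump are left out to keep it short):

    public class RangePrecedingSketch {
      public static void main(String[] args) {
        int[] orderCol = {1, 1, 3, 6, 7};
        int currRow = 3, amt = 2;                       // current sort key = 6
        int r = currRow;
        // Walk back while the current sort key is not more than amt greater than the inspected row.
        while (r >= 0 && !(orderCol[currRow] - orderCol[r] > amt)) {
          r--;                                          // the real scanner may jump here via the cache
        }
        System.out.println("window starts at row " + (r + 1));  // prints 3: only value 6 lies within 2 of 6
      }
    }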
while (r < p.size() && !isDistanceGreater(rowVal, sortKey, amt) ) { - r++; - if ( r < p.size() ) { - rowVal = computeValue(p.getAt(r)); - } + Pair stepResult = skipOrStepForward(r, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r; } @@ -717,15 +957,14 @@ protected int computeStartPreceding(int rowIdx, PTFPartition p) throws HiveExcep } protected int computeStartCurrentRow(int rowIdx, PTFPartition p) throws HiveException { - Object[] sortKey = computeValues(p.getAt(rowIdx)); - Object[] rowVal = sortKey; + Object sortKey = computeValueUseCache(rowIdx, p); + Object rowVal = sortKey; int r = rowIdx; while (r >= 0 && isEqual(rowVal, sortKey) ) { - r--; - if ( r >= 0 ) { - rowVal = computeValues(p.getAt(r)); - } + Pair stepResult = skipOrStepBack(r, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r + 1; } @@ -741,6 +980,7 @@ protected int computeStartCurrentRow(int rowIdx, PTFPartition p) throws HiveExce | 2. | FOLLOWING | UNB | ANY | ANY | end = partition.size() | |------+----------------+---------------+----------+-------+-----------------------------------| */ + @Override public int computeEnd(int rowIdx, PTFPartition p) throws HiveException { switch(end.getDirection()) { @@ -756,15 +996,14 @@ public int computeEnd(int rowIdx, PTFPartition p) throws HiveException { } protected int computeEndCurrentRow(int rowIdx, PTFPartition p) throws HiveException { - Object[] sortKey = computeValues(p.getAt(rowIdx)); - Object[] rowVal = sortKey; + Object sortKey = computeValueUseCache(rowIdx, p); + Object rowVal = sortKey; int r = rowIdx; while (r < p.size() && isEqual(sortKey, rowVal) ) { - r++; - if ( r < p.size() ) { - rowVal = computeValues(p.getAt(r)); - } + Pair stepResult = skipOrStepForward(r, p); + r = stepResult.getLeft(); + rowVal = stepResult.getRight(); } return r; } @@ -778,7 +1017,8 @@ protected int computeEndFollowing(int rowIdx, PTFPartition p) throws HiveExcepti "FOLLOWING needs UNBOUNDED for RANGE with multiple expressions in ORDER BY"); } - public Object[] computeValues(Object row) throws HiveException { + @Override + public Object computeValue(Object row) throws HiveException { Object[] objs = new Object[orderDef.getExpressions().size()]; for (int i = 0; i < objs.length; i++) { Object o = orderDef.getExpressions().get(i).getExprEvaluator().evaluate(row); @@ -787,7 +1027,14 @@ protected int computeEndFollowing(int rowIdx, PTFPartition p) throws HiveExcepti return objs; } - public boolean isEqual(Object[] v1, Object[] v2) { + @Override + public boolean isEqual(Object val1, Object val2) { + if (val1 == null || val2 == null) { + return (val1 == null && val2 == null); + } + Object[] v1 = (Object[]) val1; + Object[] v2 = (Object[]) val2; + assert v1.length == v2.length; for (int i = 0; i < v1.length; i++) { if (v1[i] == null && v2[i] == null) { @@ -804,5 +1051,10 @@ public boolean isEqual(Object[] v1, Object[] v2) { } return true; } + + @Override + public boolean isDistanceGreater(Object v1, Object v2, int amt) { + throw new UnsupportedOperationException("Only unbounded ranges supported"); + } } diff --git a/ql/src/test/org/apache/hadoop/hive/ql/udf/ptf/TestBoundaryCache.java b/ql/src/test/org/apache/hadoop/hive/ql/udf/ptf/TestBoundaryCache.java new file mode 100644 index 0000000000000000000000000000000000000000..714c51badcdf6144f49b28b0d043471276340e54 --- /dev/null +++ b/ql/src/test/org/apache/hadoop/hive/ql/udf/ptf/TestBoundaryCache.java @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.udf.ptf; + +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.hadoop.hive.ql.exec.BoundaryCache; +import org.apache.hadoop.hive.ql.exec.PTFPartition; +import org.apache.hadoop.hive.ql.parse.PTFInvocationSpec; +import org.apache.hadoop.hive.ql.parse.WindowingSpec; +import org.apache.hadoop.hive.ql.plan.ptf.BoundaryDef; +import org.apache.hadoop.hive.ql.plan.ptf.OrderExpressionDef; +import org.apache.hadoop.io.IntWritable; + +import com.google.common.collect.Lists; + +import org.junit.BeforeClass; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static java.util.Optional.ofNullable; +import static java.util.stream.Collectors.toCollection; +import static java.util.stream.Collectors.toList; +import static org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.Order.ASC; +import static org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.Order.DESC; +import static org.apache.hadoop.hive.ql.parse.WindowingSpec.BoundarySpec.UNBOUNDED_AMOUNT; +import static org.apache.hadoop.hive.ql.parse.WindowingSpec.Direction.CURRENT; +import static org.apache.hadoop.hive.ql.parse.WindowingSpec.Direction.FOLLOWING; +import static org.apache.hadoop.hive.ql.parse.WindowingSpec.Direction.PRECEDING; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +/** + * Tests BoundaryCache used for RANGE windows in PTF functions. + */ +public class TestBoundaryCache { + + private static final Logger LOG = LoggerFactory.getLogger(TestBoundaryCache.class); + private static final LinkedList> TEST_PARTITION = new LinkedList<>(); + //Null for using no cache at all, 2 is minimum cache length, 5-9-15 for checking with smaller, + // exactly equal and larger cache than needed. 
+ private static final List CACHE_SIZES = Lists.newArrayList(null, 2, 5, 9, 15); + private static final List ORDERS = Lists.newArrayList(ASC, DESC); + private static final int ORDER_BY_COL = 2; + + @BeforeClass + public static void setupTests() throws Exception { + //8 ranges, max cache content is 8+1=9 entries + addRow(TEST_PARTITION, 1, 1, -7); + addRow(TEST_PARTITION, 2, 1, -1); + addRow(TEST_PARTITION, 3, 1, -1); + addRow(TEST_PARTITION, 4, 1, 1); + addRow(TEST_PARTITION, 5, 1, 1); + addRow(TEST_PARTITION, 6, 1, 1); + addRow(TEST_PARTITION, 7, 1, 1); + addRow(TEST_PARTITION, 8, 1, 2); + addRow(TEST_PARTITION, 9, 1, 2); + addRow(TEST_PARTITION, 10, 1, 2); + addRow(TEST_PARTITION, 11, 1, 2); + addRow(TEST_PARTITION, 12, 1, 3); + addRow(TEST_PARTITION, 13, 1, 5); + addRow(TEST_PARTITION, 14, 1, 5); + addRow(TEST_PARTITION, 15, 1, 5); + addRow(TEST_PARTITION, 16, 1, 5); + addRow(TEST_PARTITION, 17, 1, 6); + addRow(TEST_PARTITION, 18, 1, 6); + addRow(TEST_PARTITION, 19, 1, 9); + addRow(TEST_PARTITION, 20, 1, null); + addRow(TEST_PARTITION, 21, 1, null); + + } + + @Test + public void testPrecedingUnboundedFollowingUnbounded() throws Exception { + runTest(PRECEDING, UNBOUNDED_AMOUNT, FOLLOWING, UNBOUNDED_AMOUNT); + } + + @Test + public void testPrecedingUnboundedCurrentRow() throws Exception { + runTest(PRECEDING, UNBOUNDED_AMOUNT, CURRENT, 0); + } + + @Test + public void testPrecedingUnboundedPreceding2() throws Exception { + runTest(PRECEDING, UNBOUNDED_AMOUNT, PRECEDING, 2); + } + + @Test + public void testPreceding4Preceding1() throws Exception { + runTest(PRECEDING, 4, PRECEDING, 1); + } + + @Test + public void testPreceding2CurrentRow() throws Exception { + runTest(PRECEDING, 2, CURRENT, 0); + } + + @Test + public void testPreceding2Following100() throws Exception { + runTest(PRECEDING, 1, FOLLOWING, 100); + } + + @Test + public void testCurrentRowFollowing3() throws Exception { + runTest(CURRENT, 0, FOLLOWING, 3); + } + + @Test + public void testCurrentRowFFollowingUnbounded() throws Exception { + runTest(CURRENT, 0, FOLLOWING, UNBOUNDED_AMOUNT); + } + + @Test + public void testFollowing2Following4() throws Exception { + runTest(FOLLOWING, 2, FOLLOWING, 4); + } + + @Test + public void testFollowing2FollowingUnbounded() throws Exception { + runTest(FOLLOWING, 2, FOLLOWING, UNBOUNDED_AMOUNT); + } + + /** + * Executes test on a given window definition. Such a test will be executed against the values set + * in ORDERS and CACHE_SIZES, validating ORDERS X CACHE_SIZES test cases. Cache size of null will + * be used to setup baseline. + * @param startDirection + * @param startAmount + * @param endDirection + * @param endAmount + * @throws Exception + */ + private void runTest(WindowingSpec.Direction startDirection, int startAmount, + WindowingSpec.Direction endDirection, int endAmount) throws Exception { + + BoundaryDef startBoundary = new BoundaryDef(startDirection, startAmount); + BoundaryDef endBoundary = new BoundaryDef(endDirection, endAmount); + AtomicInteger readCounter = new AtomicInteger(0); + + int[] expectedBoundaryStarts = new int[TEST_PARTITION.size()]; + int[] expectedBoundaryEnds = new int[TEST_PARTITION.size()]; + int expectedReadCountWithoutCache = -1; + + for (PTFInvocationSpec.Order order : ORDERS) { + for (Integer cacheSize : CACHE_SIZES) { + LOG.info(Thread.currentThread().getStackTrace()[2].getMethodName()); + LOG.info("Cache: " + cacheSize + " order: " + order); + BoundaryCache cache = cacheSize == null ? 
null : new BoundaryCache(cacheSize); + Pair mocks = setupMocks(TEST_PARTITION, + ORDER_BY_COL, startBoundary, endBoundary, order, cache, readCounter); + PTFPartition ptfPartition = mocks.getLeft(); + ValueBoundaryScanner scanner = mocks.getRight(); + for (int i = 0; i < TEST_PARTITION.size(); ++i) { + scanner.handleCache(i, ptfPartition); + int start = scanner.computeStart(i, ptfPartition); + int end = scanner.computeEnd(i, ptfPartition) - 1; + if (cache == null) { + //Cache-less version should be baseline + expectedBoundaryStarts[i] = start; + expectedBoundaryEnds[i] = end; + } else { + assertEquals(expectedBoundaryStarts[i], start); + assertEquals(expectedBoundaryEnds[i], end); + } + Integer col0 = ofNullable(TEST_PARTITION.get(i).get(0)).map(v -> v.get()).orElse(null); + Integer col1 = ofNullable(TEST_PARTITION.get(i).get(1)).map(v -> v.get()).orElse(null); + Integer col2 = ofNullable(TEST_PARTITION.get(i).get(2)).map(v -> v.get()).orElse(null); + LOG.info(String.format("%d|\t%d\t%d\t%d\t|%d-%d", i, col0, col1, col2, start, end)); + } + if (cache == null) { + expectedReadCountWithoutCache = readCounter.get(); + } else { + //Read count should be smaller with cache being used, but larger than the minimum of + // reading every row once. + assertTrue(expectedReadCountWithoutCache >= readCounter.get()); + if (startAmount != UNBOUNDED_AMOUNT || endAmount != UNBOUNDED_AMOUNT) { + assertTrue(TEST_PARTITION.size() <= readCounter.get()); + } + } + readCounter.set(0); + } + } + } + + /** + * Sets up mock and spy objects used for testing. + * @param partition The real partition containing row values. + * @param orderByCol Index of column in the row used for separating ranges. + * @param start Window definition. + * @param end Window definition. + * @param order Window definition. + * @param cache BoundaryCache instance, it may come in various sizes. + * @param readCounter counts how many times reading was invoked + * @return Mocked PTFPartition instance and ValueBoundaryScanner spy. + * @throws Exception + */ + private static Pair setupMocks( + List> partition, int orderByCol, BoundaryDef start, BoundaryDef end, + PTFInvocationSpec.Order order, BoundaryCache cache, + AtomicInteger readCounter) throws Exception { + PTFPartition partitionMock = mock(PTFPartition.class); + doAnswer(invocationOnMock -> { + int idx = invocationOnMock.getArgumentAt(0, Integer.class); + return partition.get(idx); + }).when(partitionMock).getAt(any(Integer.class)); + doAnswer(invocationOnMock -> { + return partition.size(); + }).when(partitionMock).size(); + when(partitionMock.getBoundaryCache()).thenReturn(cache); + + OrderExpressionDef orderDef = mock(OrderExpressionDef.class); + when(orderDef.getOrder()).thenReturn(order); + + ValueBoundaryScanner scan = new LongValueBoundaryScanner(start, end, orderDef, order == ASC); + ValueBoundaryScanner scannerSpy = spy(scan); + doAnswer(invocationOnMock -> { + readCounter.incrementAndGet(); + List row = invocationOnMock.getArgumentAt(0, List.class); + return row.get(orderByCol); + }).when(scannerSpy).computeValue(any(Object.class)); + doAnswer(invocationOnMock -> { + IntWritable v1 = invocationOnMock.getArgumentAt(0, IntWritable.class); + IntWritable v2 = invocationOnMock.getArgumentAt(1, IntWritable.class); + return (v1 != null && v2 != null) ? 
v1.get() == v2.get() : v1 == null && v2 == null; + }).when(scannerSpy).isEqual(any(Object.class), any(Object.class)); + doAnswer(invocationOnMock -> { + IntWritable v1 = invocationOnMock.getArgumentAt(0, IntWritable.class); + IntWritable v2 = invocationOnMock.getArgumentAt(1, IntWritable.class); + Integer amt = invocationOnMock.getArgumentAt(2, Integer.class); + return (v1 != null && v2 != null) ? (v1.get() - v2.get()) > amt : v1 != null || v2 != null; + }).when(scannerSpy).isDistanceGreater(any(Object.class), any(Object.class), any(Integer.class)); + + setOrderOnTestPartitions(order); + return new ImmutablePair<>(partitionMock, scannerSpy); + + } + + private static void addRow(List> partition, Integer col0, Integer col1, + Integer col2) { + partition.add(Lists.newArrayList( + col0 != null ? new IntWritable(col0) : null, + col1 != null ? new IntWritable(col1) : null, + col2 != null ? new IntWritable(col2) : null + )); + } + + /** + * Reverses order on actual data if needed, based on order parameter. + * @param order + */ + private static void setOrderOnTestPartitions(PTFInvocationSpec.Order order) { + LinkedList> notNulls = TEST_PARTITION.stream().filter( + r -> r.get(ORDER_BY_COL) != null).collect(toCollection(LinkedList::new)); + List> nulls = TEST_PARTITION.stream().filter( + r -> r.get(ORDER_BY_COL) == null).collect(toList()); + + boolean isAscCurrently = notNulls.getFirst().get(ORDER_BY_COL).get() < + notNulls.getLast().get(ORDER_BY_COL).get(); + + if ((ASC.equals(order) && !isAscCurrently) || (DESC.equals(order) && isAscCurrently)) { + Collections.reverse(notNulls); + TEST_PARTITION.clear(); + TEST_PARTITION.addAll(notNulls); + TEST_PARTITION.addAll(nulls); + } + } + +}
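A closing note on switching the feature on and off: PTFPartition only constructs a BoundaryCache when the configured size is greater than 1, so both 0 and 1 fall back to the old, cache-less scanning. A short snippet showing how the value would be set programmatically (a session-level `SET hive.ptf.rangecache.size=...;` should have the same effect):

    HiveConf conf = new HiveConf();
    conf.setIntVar(HiveConf.ConfVars.HIVE_PTF_RANGECACHE_SIZE, 50000);  // enlarge the per-partition cache
    conf.setIntVar(HiveConf.ConfVars.HIVE_PTF_RANGECACHE_SIZE, 0);      // or disable it (0 or 1 means no cache)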