diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/PTFTopNHash.java ql/src/java/org/apache/hadoop/hive/ql/exec/PTFTopNHash.java
new file mode 100644
index 0000000..de7d71b
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/PTFTopNHash.java
@@ -0,0 +1,197 @@
+package org.apache.hadoop.hive.ql.exec;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
+import org.apache.hadoop.hive.ql.io.HiveKey;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.io.BytesWritable;
+
+
+public class PTFTopNHash extends TopNHash {
+
+  protected float memUsage;
+  protected boolean isMapGroupBy;
+  private Map<Key, TopNHash> partitionHeaps;
+  private TopNHash largestPartition;
+  private boolean prevIndexPartIsNull;
+  private Set<Integer> indexesWithNullPartKey;
+
+  public void initialize(
+      int topN, float memUsage, boolean isMapGroupBy, BinaryCollector collector) {
+    super.initialize(topN, memUsage, isMapGroupBy, collector);
+    this.isMapGroupBy = isMapGroupBy;
+    this.memUsage = memUsage;
+    partitionHeaps = new HashMap<Key, TopNHash>();
+    indexesWithNullPartKey = new HashSet<Integer>();
+  }
+
+  public int tryStoreKey(HiveKey key, boolean partColsIsNull) throws HiveException, IOException {
+    prevIndexPartIsNull = partColsIsNull;
+    return _tryStoreKey(key, partColsIsNull, -1);
+  }
+
+  private void updateLargest(TopNHash p) {
+    if ( largestPartition == null || largestPartition.usage < p.usage) {
+      largestPartition = p;
+    }
+  }
+
+  private void findLargest() {
+    for(TopNHash p : partitionHeaps.values() ) {
+      updateLargest(p);
+    }
+  }
+
+  public int _tryStoreKey(HiveKey key, boolean partColsIsNull, int batchIndex) throws HiveException, IOException {
+    if (!isEnabled) {
+      return FORWARD; // short-circuit quickly - forward all rows
+    }
+    if (topN == 0) {
+      return EXCLUDE; // short-circuit quickly - eat all rows
+    }
+    Key pk = new Key(partColsIsNull, key.hashCode());
+    TopNHash partHeap = partitionHeaps.get(pk);
+    if ( partHeap == null ) {
+      partHeap = new TopNHash();
+      partHeap.initialize(topN, memUsage, isMapGroupBy, collector);
+      if ( batchIndex >= 0 ) {
+        partHeap.startVectorizedBatch(batchSize);
+      }
+      partitionHeaps.put(pk, partHeap);
+    }
+    usage = usage - partHeap.usage;
+    int r = 0;
+    if ( batchIndex >= 0 ) {
+      partHeap.tryStoreVectorizedKey(key, false, batchIndex);
+    } else {
+      r = partHeap.tryStoreKey(key, false);
+    }
+    usage = usage + partHeap.usage;
+    updateLargest(partHeap);
+
+    if ( usage > threshold ) {
+      usage -= largestPartition.usage;
+      largestPartition.flush();
+      usage += largestPartition.usage;
+      largestPartition = null;
+      findLargest();
+    }
+    return r;
+  }
+
+  public void storeValue(int index, int hashCode, BytesWritable value, boolean vectorized) {
+    Key pk = new Key(prevIndexPartIsNull, hashCode);
+    TopNHash partHeap = partitionHeaps.get(pk);
+    usage = usage - partHeap.usage;
+    partHeap.storeValue(index, hashCode, value, vectorized);
+    usage = usage + partHeap.usage;
+    updateLargest(partHeap);
+  }
+
+  public void flush() throws HiveException {
+    if (!isEnabled || (topN == 0)) return;
+    for(TopNHash partHash : partitionHeaps.values()) {
+      partHash.flush();
+    }
+  }
+
+  public int startVectorizedBatch(int size) throws IOException, HiveException {
+    if (!isEnabled) {
+      return FORWARD; // short-circuit quickly - forward all rows
+    } else if (topN == 0) {
+      return EXCLUDE; // short-circuit quickly - eat all rows
+    }
+    for(TopNHash partHash : partitionHeaps.values()) {
+      usage = usage - partHash.usage;
+      partHash.startVectorizedBatch(size);
+      usage = usage + partHash.usage;
+      updateLargest(partHash);
+    }
+    batchSize = size;
+    if (batchIndexToResult == null || batchIndexToResult.length < batchSize) {
+      batchIndexToResult = new int[Math.max(batchSize, VectorizedRowBatch.DEFAULT_SIZE)];
+    }
+    indexesWithNullPartKey.clear();
+    return 0;
+  }
+
+  public void tryStoreVectorizedKey(HiveKey key, boolean partColsIsNull, int batchIndex)
+      throws HiveException, IOException {
+    _tryStoreKey(key, partColsIsNull, batchIndex);
+    if ( partColsIsNull ) {
+      indexesWithNullPartKey.add(batchIndex);
+    }
+    batchIndexToResult[batchIndex] = key.hashCode();
+  }
+
+  public int getVectorizedBatchResult(int batchIndex) {
+    prevIndexPartIsNull = indexesWithNullPartKey.contains(batchIndex);
+    Key pk = new Key(prevIndexPartIsNull, batchIndexToResult[batchIndex]);
+    TopNHash partHeap = partitionHeaps.get(pk);
+    return partHeap.getVectorizedBatchResult(batchIndex);
+  }
+
+  public HiveKey getVectorizedKeyToForward(int batchIndex) {
+    prevIndexPartIsNull = indexesWithNullPartKey.contains(batchIndex);
+    Key pk = new Key(prevIndexPartIsNull, batchIndexToResult[batchIndex]);
+    TopNHash partHeap = partitionHeaps.get(pk);
+    return partHeap.getVectorizedKeyToForward(batchIndex);
+  }
+
+  public int getVectorizedKeyDistLength(int batchIndex) {
+    prevIndexPartIsNull = indexesWithNullPartKey.contains(batchIndex);
+    Key pk = new Key(prevIndexPartIsNull, batchIndexToResult[batchIndex]);
+    TopNHash partHeap = partitionHeaps.get(pk);
+    return partHeap.getVectorizedKeyDistLength(batchIndex);
+  }
+
+  public int getVectorizedKeyHashCode(int batchIndex) {
+    prevIndexPartIsNull = indexesWithNullPartKey.contains(batchIndex);
+    Key pk = new Key(prevIndexPartIsNull, batchIndexToResult[batchIndex]);
+    TopNHash partHeap = partitionHeaps.get(pk);
+    return partHeap.getVectorizedKeyHashCode(batchIndex);
+  }
+
+  static class Key {
+    boolean isNull;
+    int hashCode;
+
+    public Key(boolean isNull, int hashCode) {
+      super();
+      this.isNull = isNull;
+      this.hashCode = hashCode;
+    }
+
+    @Override
+    public int hashCode() {
+      return hashCode;
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      if (this == obj)
+        return true;
+      if (obj == null)
+        return false;
+      if (getClass() != obj.getClass())
+        return false;
+      Key other = (Key) obj;
+      if (hashCode != other.hashCode)
+        return false;
+      if (isNull != other.isNull)
+        return false;
+      return true;
+    }
+
+    @Override
+    public String toString() {
+      return "" + hashCode + "," + isNull;
+    }
+
+  }
+}
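
The new PTFTopNHash above keeps one TopNHash per (partition-key hash, isNull) pair, tracks the combined memory usage, and flushes the largest per-partition heap once the threshold is crossed. A minimal standalone sketch of the per-partition top-N idea (hypothetical class, not the Hive API; memory accounting and the vectorized path are left out, and n is assumed to be positive):

    import java.util.Comparator;
    import java.util.HashMap;
    import java.util.Map;
    import java.util.PriorityQueue;

    // Hypothetical sketch: an independent top-N heap per partition, so a rank limit
    // applies within each window partition rather than globally.
    public class PerPartitionTopN<K extends Comparable<K>> {
      private final int n;  // assumed > 0
      private final Map<Integer, PriorityQueue<K>> heaps = new HashMap<>();

      public PerPartitionTopN(int n) { this.n = n; }

      // Returns true if key is among the n smallest seen so far for its partition.
      public boolean offer(int partitionHash, K key) {
        // Max-heap per partition: the head is the largest of the n keys kept.
        PriorityQueue<K> heap = heaps.computeIfAbsent(
            partitionHash, h -> new PriorityQueue<K>(Comparator.<K>reverseOrder()));
        if (heap.size() < n) {
          heap.add(key);
          return true;
        }
        if (key.compareTo(heap.peek()) < 0) {
          heap.poll();
          heap.add(key);
          return true;
        }
        return false; // cannot be in this partition's top n, drop it
      }
    }
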
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/ReduceSinkOperator.java ql/src/java/org/apache/hadoop/hive/ql/exec/ReduceSinkOperator.java
index 03a64e8..2babcba 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/ReduceSinkOperator.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/ReduceSinkOperator.java
@@ -118,7 +118,6 @@ public void setOutputCollector(OutputCollector _out) {
   protected transient TopNHash reducerHash = new TopNHash();
   @Override
   protected void initializeOp(Configuration hconf) throws HiveException {
-
     try {
       List<ExprNodeDesc> keys = conf.getKeyCols();
       keyEval = new ExprNodeEvaluator[keys.size()];
@@ -174,7 +173,9 @@ protected void initializeOp(Configuration hconf) throws HiveException {
       int limit = conf.getTopN();
       float memUsage = conf.getTopNMemoryUsage();
+
       if (limit >= 0 && memUsage > 0) {
+        reducerHash = conf.isPTFReduceSink() ? new PTFTopNHash() : reducerHash;
         reducerHash.initialize(limit, memUsage, conf.isMapGroupBy(), this);
       }
@@ -316,8 +317,14 @@ public void processOp(Object row, int tag) throws HiveException {
       firstKey.setHashCode(hashCode);
+      /*
+       * In the case of TopN for windowing, we need to distinguish between rows with
+       * null partition keys and rows with value 0 for partition keys.
+       */
+      boolean partKeyNull = conf.isPTFReduceSink() && partitionKeysAreNull(row);
+
       // Try to store the first key. If it's not excluded, we will proceed.
-      int firstIndex = reducerHash.tryStoreKey(firstKey);
+      int firstIndex = reducerHash.tryStoreKey(firstKey, partKeyNull);
       if (firstIndex == TopNHash.EXCLUDE) return; // Nothing to do.
       // Compute value and hashcode - we'd either store or forward them.
       BytesWritable value = makeValueWritable(row);
@@ -326,7 +333,7 @@ public void processOp(Object row, int tag) throws HiveException {
         collect(firstKey, value);
       } else {
         assert firstIndex >= 0;
-        reducerHash.storeValue(firstIndex, value, false);
+        reducerHash.storeValue(firstIndex, firstKey.hashCode(), value, false);
       }

       // All other distinct keys will just be forwarded. This could be optimized...
@@ -405,6 +412,19 @@ private int computeHashCode(Object row) throws HiveException {
     }
     return keyHashCode;
   }
+
+  private boolean partitionKeysAreNull(Object row) throws HiveException {
+    if ( partitionEval.length != 0 ) {
+      for (int i = 0; i < partitionEval.length; i++) {
+        Object o = partitionEval[i].evaluate(row);
+        if ( o != null ) {
+          return false;
+        }
+      }
+      return true;
+    }
+    return false;
+  }

   private int computeHashCode(Object row, int buckNum) throws HiveException {
     int keyHashCode = computeHashCode(row);
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/TopNHash.java ql/src/java/org/apache/hadoop/hive/ql/exec/TopNHash.java
index bc81467..484006a 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/TopNHash.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/TopNHash.java
@@ -57,11 +57,11 @@ public static final int EXCLUDE = -2; // Discard the row.
   private static final int MAY_FORWARD = -3; // Vectorized - may forward the row, not sure yet.

-  private BinaryCollector collector;
-  private int topN;
+  protected BinaryCollector collector;
+  protected int topN;

-  private long threshold;   // max heap size
-  private long usage;
+  protected long threshold; // max heap size
+  protected long usage;

   // binary keys, values and hashCodes of rows, lined up by index
   private byte[][] keys;
@@ -76,10 +76,10 @@
   // temporary single-batch context used for vectorization
   private int batchNumForwards = 0; // whether current batch has any forwarded keys
   private int[] indexToBatchIndex; // mapping of index (lined up w/keys) to index in the batch
-  private int[] batchIndexToResult; // mapping of index in the batch (linear) to hash result
-  private int batchSize; // Size of the current batch.
+  protected int[] batchIndexToResult; // mapping of index in the batch (linear) to hash result
+  protected int batchSize; // Size of the current batch.

-  private boolean isEnabled = false;
+  protected boolean isEnabled = false;

   private final Comparator<Integer> C = new Comparator<Integer>() {
     public int compare(Integer o1, Integer o2) {
@@ -124,7 +124,7 @@ public void initialize(
   *  TopNHash.EXCLUDED if the row should be discarded;
   *  any other number if the row is to be stored; the index should be passed to storeValue.
   */
-  public int tryStoreKey(HiveKey key) throws HiveException, IOException {
+  public int tryStoreKey(HiveKey key, boolean partColsIsNull) throws HiveException, IOException {
     if (!isEnabled) {
       return FORWARD; // short-circuit quickly - forward all rows
     }
@@ -191,7 +191,7 @@ public int startVectorizedBatch(int size) throws IOException, HiveException {
   * @param key the key.
   * @param batchIndex The index of the key in the vectorized batch (sequential, not .selected).
   */
-  public void tryStoreVectorizedKey(HiveKey key, int batchIndex)
+  public void tryStoreVectorizedKey(HiveKey key, boolean partColsIsNull, int batchIndex)
           throws HiveException, IOException {
     // Assumption - batchIndex is increasing; startVectorizedBatch was called
     int size = indexes.size();
@@ -201,6 +201,13 @@ public void tryStoreVectorizedKey(HiveKey key, int batchIndex)
     hashes[index] = key.hashCode();
     Integer collisionIndex = indexes.store(index);
     if (null != collisionIndex) {
+      /*
+       * Since there is a collision, 'index' will be reused for the next value,
+       * so have the map point back to the original index.
+       */
+      if ( indexes instanceof HashForGroup ) {
+        indexes.store(collisionIndex);
+      }
       // forward conditional on the survival of the corresponding key currently in indexes.
       ++batchNumForwards;
       batchIndexToResult[batchIndex] = MAY_FORWARD - collisionIndex;
@@ -283,11 +290,12 @@ public int getVectorizedKeyHashCode(int batchIndex) {
   /**
   * Stores the value for the key in the heap.
   * @param index The index, either from tryStoreKey or from tryStoreVectorizedKey result.
+  * @param hashCode The hash code of the key, used by PTFTopNHash.
   * @param value The value to store.
   * @param keyHash The key hash to store.
   * @param vectorized Whether the result is coming from a vectorized batch.
   */
-  public void storeValue(int index, BytesWritable value, boolean vectorized) {
+  public void storeValue(int index, int hashCode, BytesWritable value, boolean vectorized) {
     values[index] = Arrays.copyOf(value.getBytes(), value.getLength());
     // Vectorized doesn't adjust usage for the keys while processing the batch
     usage += values[index].length + (vectorized ? keys[index].length : 0);
@@ -369,7 +377,7 @@ private void flushInternal() throws IOException, HiveException {
     }
     excluded = 0;
   }
-  
+
   private interface IndexStore {
     int size();
     /**
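
With the two signature changes above, the calling convention remains the same three-step protocol; a condensed sketch of the per-row sequence a ReduceSink-style operator follows against TopNHash (assumes an already initialized instance, and omits the vectorized path and the flush() at operator close):

    import java.io.IOException;

    import org.apache.hadoop.hive.ql.exec.TopNHash;
    import org.apache.hadoop.hive.ql.io.HiveKey;
    import org.apache.hadoop.hive.ql.metadata.HiveException;
    import org.apache.hadoop.io.BytesWritable;

    class TopNProtocolSketch {
      /** Returns what was done with the row: "forwarded", "excluded", or "stored". */
      String processRow(TopNHash reducerHash, HiveKey key, boolean partKeyNull, BytesWritable value)
          throws HiveException, IOException {
        int index = reducerHash.tryStoreKey(key, partKeyNull);
        if (index == TopNHash.EXCLUDE) {
          return "excluded";   // key cannot make the top N, drop the row
        }
        if (index == TopNHash.FORWARD) {
          return "forwarded";  // top-N is disabled or undecided, emit the row directly
        }
        // Any other result is a slot index: store the value under it, passing the key's
        // hash code so a PTFTopNHash can locate the per-partition heap it used.
        reducerHash.storeValue(index, key.hashCode(), value, false);
        return "stored";
      }
    }
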
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorReduceSinkOperator.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorReduceSinkOperator.java
index 11024da..4d0816d 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorReduceSinkOperator.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorReduceSinkOperator.java
@@ -24,6 +24,7 @@
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hive.ql.exec.PTFTopNHash;
 import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
 import org.apache.hadoop.hive.ql.exec.TopNHash;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
@@ -208,6 +209,7 @@ public void assign(VectorExpressionWriter[] writers,
       int limit = conf.getTopN();
       float memUsage = conf.getTopNMemoryUsage();
       if (limit >= 0 && memUsage > 0) {
+        reducerHash = conf.isPTFReduceSink() ? new PTFTopNHash() : reducerHash;
         reducerHash.initialize(limit, memUsage, conf.isMapGroupBy(), this);
       }
     } catch(Exception e) {
@@ -291,7 +293,12 @@ public void processOp(Object row, int tag) throws HiveException {
         firstKey.setHashCode(hashCode);

         if (useTopN) {
-          reducerHash.tryStoreVectorizedKey(firstKey, batchIndex);
+          /*
+           * In the case of TopN for windowing, we need to distinguish between
+           * rows with null partition keys and rows with value 0 for partition keys.
+           */
+          boolean partkeysNull = conf.isPTFReduceSink() && partitionKeysAreNull(vrg, rowIndex);
+          reducerHash.tryStoreVectorizedKey(firstKey, partkeysNull, batchIndex);
         } else {
           // No TopN, just forward the first key and all others.
           BytesWritable value = makeValueWritable(vrg, rowIndex);
@@ -320,9 +327,9 @@ public void processOp(Object row, int tag) throws HiveException {
           hashCode = firstKey.hashCode();
           collect(firstKey, value);
         } else {
-          reducerHash.storeValue(result, value, true);
-          distKeyLength = reducerHash.getVectorizedKeyDistLength(batchIndex);
           hashCode = reducerHash.getVectorizedKeyHashCode(batchIndex);
+          reducerHash.storeValue(result, hashCode, value, true);
+          distKeyLength = reducerHash.getVectorizedKeyDistLength(batchIndex);
         }
         // Now forward other the rows if there's multi-distinct (but see TODO in forward...).
         // Unfortunately, that means we will have to rebuild the cachedKeys. Start at 1.
@@ -443,6 +450,22 @@ private int computeHashCode(VectorizedRowBatch vrg, int rowIndex) throws HiveExc
     return keyHashCode;
   }
+  private boolean partitionKeysAreNull(VectorizedRowBatch vrg, int rowIndex)
+      throws HiveException {
+    if (partitionEval.length != 0) {
+      for (int p = 0; p < partitionEval.length; p++) {
+        ColumnVector columnVector = vrg.cols[partitionEval[p].getOutputColumn()];
+        Object partitionValue = partitionWriters[p].writeValue(columnVector,
+            rowIndex);
+        if (partitionValue != null) {
+          return false;
+        }
+      }
+      return true;
+    }
+    return false;
+  }
+
   private int computeHashCode(VectorizedRowBatch vrg, int rowIndex, int buckNum) throws HiveException {
     int keyHashCode = computeHashCode(vrg, rowIndex);
     keyHashCode = keyHashCode * 31 + buckNum;
diff --git ql/src/java/org/apache/hadoop/hive/ql/plan/ReduceSinkDesc.java ql/src/java/org/apache/hadoop/hive/ql/plan/ReduceSinkDesc.java
index 8c1d336..d760761 100644
--- ql/src/java/org/apache/hadoop/hive/ql/plan/ReduceSinkDesc.java
+++ ql/src/java/org/apache/hadoop/hive/ql/plan/ReduceSinkDesc.java
@@ -86,6 +86,8 @@
   private int topN = -1;
   private float topNMemoryUsage = -1;
   private boolean mapGroupBy;  // for group-by, values with same key on top-K should be forwarded
+  // flag used to control how TopN is handled for PTF/Windowing partitions.
+  private boolean isPTFReduceSink = false;
   private boolean skipTag; // Skip writing tags when feeding into mapjoin hashtable

   private boolean autoParallel = false; // Is reducer parallelism automatic or fixed
@@ -253,6 +255,14 @@ public void setMapGroupBy(boolean mapGroupBy) {
     this.mapGroupBy = mapGroupBy;
   }
+  public boolean isPTFReduceSink() {
+    return isPTFReduceSink;
+  }
+
+  public void setPTFReduceSink(boolean isPTFReduceSink) {
+    this.isPTFReduceSink = isPTFReduceSink;
+  }
+
   /**
    * Returns the number of reducers for the map-reduce job. -1 means to decide
    * the number of reducers at runtime. This enables Hive to estimate the number
diff --git ql/src/java/org/apache/hadoop/hive/ql/plan/ptf/WindowTableFunctionDef.java ql/src/java/org/apache/hadoop/hive/ql/plan/ptf/WindowTableFunctionDef.java
index c547e62..083aaf2 100644
--- ql/src/java/org/apache/hadoop/hive/ql/plan/ptf/WindowTableFunctionDef.java
+++ ql/src/java/org/apache/hadoop/hive/ql/plan/ptf/WindowTableFunctionDef.java
@@ -23,6 +23,9 @@
 public class WindowTableFunctionDef extends PartitionedTableFunctionDef {
   List<WindowFunctionDef> windowFunctions;
+
+  int rankLimit = -1;
+  int rankLimitFunction;

   public List<WindowFunctionDef> getWindowFunctions() {
     return windowFunctions;
@@ -30,4 +33,16 @@ public void setWindowFunctions(List<WindowFunctionDef> windowFunctions) {
     this.windowFunctions = windowFunctions;
   }
+  public int getRankLimit() {
+    return rankLimit;
+  }
+  public void setRankLimit(int rankLimit) {
+    this.rankLimit = rankLimit;
+  }
+  public int getRankLimitFunction() {
+    return rankLimitFunction;
+  }
+  public void setRankLimitFunction(int rankLimitFunction) {
+    this.rankLimitFunction = rankLimitFunction;
+  }
 }
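
The new ReduceSinkDesc flag and the rankLimit fields on WindowTableFunctionDef are populated together by the PTFPPD rule added further down; a condensed sketch of that wiring (the values of rLimit, fnIdx and threshold are assumed to come from predicate analysis and HiveConf, as in pushRankLimitToRedSink below):

    import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
    import org.apache.hadoop.hive.ql.plan.ptf.WindowTableFunctionDef;

    class RankLimitWiringSketch {
      void wire(WindowTableFunctionDef wTFn, ReduceSinkDesc rDesc,
          int rLimit, int fnIdx, float threshold) {
        wTFn.setRankLimit(rLimit);          // lets WindowingTableFunction stop early
        wTFn.setRankLimitFunction(fnIdx);   // which window function produces the rank
        rDesc.setTopN(rLimit);              // additionally prune rows at the ReduceSink
        rDesc.setTopNMemoryUsage(threshold);
        rDesc.setMapGroupBy(true);
        rDesc.setPTFReduceSink(true);       // use PTFTopNHash: top-N per window partition
      }
    }
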
diff --git ql/src/java/org/apache/hadoop/hive/ql/ppd/OpProcFactory.java ql/src/java/org/apache/hadoop/hive/ql/ppd/OpProcFactory.java
index 7aaf455..ad03e68 100644
--- ql/src/java/org/apache/hadoop/hive/ql/ppd/OpProcFactory.java
+++ ql/src/java/org/apache/hadoop/hive/ql/ppd/OpProcFactory.java
@@ -37,6 +37,7 @@
 import org.apache.hadoop.hive.ql.exec.LateralViewJoinOperator;
 import org.apache.hadoop.hive.ql.exec.Operator;
 import org.apache.hadoop.hive.ql.exec.OperatorFactory;
+import org.apache.hadoop.hive.ql.exec.PTFOperator;
 import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
 import org.apache.hadoop.hive.ql.exec.RowSchema;
 import org.apache.hadoop.hive.ql.exec.TableScanOperator;
@@ -53,7 +54,9 @@
 import org.apache.hadoop.hive.ql.parse.OpParseContext;
 import org.apache.hadoop.hive.ql.parse.RowResolver;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.parse.WindowingSpec.Direction;
 import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
+import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeDescUtils;
 import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
@@ -61,8 +64,23 @@
 import org.apache.hadoop.hive.ql.plan.JoinCondDesc;
 import org.apache.hadoop.hive.ql.plan.JoinDesc;
 import org.apache.hadoop.hive.ql.plan.OperatorDesc;
+import org.apache.hadoop.hive.ql.plan.PTFDesc;
+import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
 import org.apache.hadoop.hive.ql.plan.TableScanDesc;
+import org.apache.hadoop.hive.ql.plan.ptf.BoundaryDef;
+import org.apache.hadoop.hive.ql.plan.ptf.ValueBoundaryDef;
+import org.apache.hadoop.hive.ql.plan.ptf.WindowFrameDef;
+import org.apache.hadoop.hive.ql.plan.ptf.WindowFunctionDef;
+import org.apache.hadoop.hive.ql.plan.ptf.WindowTableFunctionDef;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFDenseRank.GenericUDAFDenseRankEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFLead.GenericUDAFLeadEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFRank.GenericUDAFRankEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqualOrLessThan;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPLessThan;
+import org.apache.hadoop.hive.ql.udf.ptf.WindowingTableFunction;
 import org.apache.hadoop.hive.serde2.Deserializer;
+import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
 import org.apache.hadoop.mapred.JobConf;

 /**
@@ -133,6 +151,176 @@ public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
     }
   }
+
+  public static class PTFPPD extends ScriptPPD {
+
+    /*
+     * For WindowingTableFunction:
+     * a. If there is a Rank/DenseRank function and there are unpushed predicates of the
+     *    form rankValue < Constant, use the smallest Constant value as the 'rankLimit'
+     *    on the WindowingTableFunction.
+     * b. If there are no window functions with an End Boundary past the current row, the
+     *    condition can also be pushed down to the ReduceSink as a limit pushdown
+     *    (mapGroupBy=true).
+     *
+     * (non-Javadoc)
+     * @see org.apache.hadoop.hive.ql.ppd.OpProcFactory.ScriptPPD#process(org.apache.hadoop.hive.ql.lib.Node, java.util.Stack, org.apache.hadoop.hive.ql.lib.NodeProcessorCtx, java.lang.Object[])
+     */
+    @Override
+    public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
+        Object... nodeOutputs) throws SemanticException {
+      LOG.info("Processing for " + nd.getName() + "("
+          + ((Operator) nd).getIdentifier() + ")");
+      OpWalkerInfo owi = (OpWalkerInfo) procCtx;
+      PTFOperator ptfOp = (PTFOperator) nd;
+
+      pushRankLimit(ptfOp, owi);
+      return super.process(nd, stack, procCtx, nodeOutputs);
+    }
+
+    private void pushRankLimit(PTFOperator ptfOp, OpWalkerInfo owi) throws SemanticException {
+      PTFDesc conf = ptfOp.getConf();
+
+      if ( !conf.forWindowing() ) {
+        return;
+      }
+
+      float threshold = owi.getParseContext().getConf().getFloatVar(HiveConf.ConfVars.HIVELIMITPUSHDOWNMEMORYUSAGE);
+      if (threshold <= 0 || threshold >= 1) {
+        return;
+      }
+
+      WindowTableFunctionDef wTFn = (WindowTableFunctionDef) conf.getFuncDef();
+      List<Integer> rFnIdxs = rankingFunctions(wTFn);
+
+      if ( rFnIdxs.size() == 0 ) {
+        return;
+      }
+
+      ExprWalkerInfo childInfo = getChildWalkerInfo((Operator<?>) ptfOp, owi);
+
+      List<ExprNodeDesc> preds = new ArrayList<ExprNodeDesc>();
+      Iterator<List<ExprNodeDesc>> iterator = childInfo.getFinalCandidates().values().iterator();
+      while (iterator.hasNext()) {
+        for (ExprNodeDesc pred : iterator.next()) {
+          preds = ExprNodeDescUtils.split(pred, preds);
+        }
+      }
+
+      int rLimit = -1;
+      int fnIdx = -1;
+      for(ExprNodeDesc pred : preds) {
+        int[] pLimit = getLimit(wTFn, rFnIdxs, pred);
+        if ( pLimit != null ) {
+          if ( rLimit == -1 || rLimit >= pLimit[0] ) {
+            rLimit = pLimit[0];
+            fnIdx = pLimit[1];
+          }
+        }
+      }
+
+      if ( rLimit != -1 ) {
+        wTFn.setRankLimit(rLimit);
+        wTFn.setRankLimitFunction(fnIdx);
+        if ( canPushLimitToReduceSink(wTFn)) {
+          pushRankLimitToRedSink(ptfOp, owi.getParseContext().getConf(), rLimit);
+        }
+      }
+    }
+
+    private List<Integer> rankingFunctions(WindowTableFunctionDef wTFn) {
+      List<Integer> rFns = new ArrayList<Integer>();
+      for(int i=0; i < wTFn.getWindowFunctions().size(); i++ ) {
+        WindowFunctionDef wFnDef = wTFn.getWindowFunctions().get(i);
+        if ( (wFnDef.getWFnEval() instanceof GenericUDAFRankEvaluator) ||
+            (wFnDef.getWFnEval() instanceof GenericUDAFDenseRankEvaluator )  ) {
+          rFns.add(i);
+        }
+      }
+      return rFns;
+    }
+
+    private int[] getLimit(WindowTableFunctionDef wTFn, List<Integer> rFnIdxs, ExprNodeDesc expr) {
+
+      if ( !(expr instanceof ExprNodeGenericFuncDesc) ) {
+        return null;
+      }
+
+      ExprNodeGenericFuncDesc fExpr = (ExprNodeGenericFuncDesc) expr;
+
+      if ( !(fExpr.getGenericUDF() instanceof GenericUDFOPLessThan) &&
+          !(fExpr.getGenericUDF() instanceof GenericUDFOPEqualOrLessThan) ) {
+        return null;
+      }
+
+      if ( !(fExpr.getChildren().get(0) instanceof ExprNodeColumnDesc) ) {
+        return null;
+      }
+
+      if ( !(fExpr.getChildren().get(1) instanceof ExprNodeConstantDesc) ) {
+        return null;
+      }
+
+      ExprNodeConstantDesc constantExpr = (ExprNodeConstantDesc) fExpr.getChildren().get(1);
+
+      if ( constantExpr.getTypeInfo() != TypeInfoFactory.intTypeInfo ) {
+        return null;
+      }
+
+      int limit = (Integer) constantExpr.getValue();
+      if ( fExpr.getGenericUDF() instanceof GenericUDFOPEqualOrLessThan ) {
+        limit = limit + 1;
+      }
+      String colName = ((ExprNodeColumnDesc)fExpr.getChildren().get(0)).getColumn();
+
+      for(int i=0; i < rFnIdxs.size(); i++ ) {
+        String fAlias = wTFn.getWindowFunctions().get(i).getAlias();
+        if ( fAlias.equals(colName)) {
+          return new int[] {limit,i};
+        }
+      }
+
+      return null;
+    }
+
+    private boolean canPushLimitToReduceSink(WindowTableFunctionDef wTFn) {
+
+      for(WindowFunctionDef wFnDef : wTFn.getWindowFunctions() ) {
+        if ( (wFnDef.getWFnEval() instanceof GenericUDAFRankEvaluator) ||
+            (wFnDef.getWFnEval() instanceof GenericUDAFDenseRankEvaluator ) ||
+            (wFnDef.getWFnEval() instanceof GenericUDAFLeadEvaluator ) ) {
+          continue;
+        }
+        WindowFrameDef wdwFrame = wFnDef.getWindowFrame();
+        BoundaryDef end = wdwFrame.getEnd();
+        if ( end instanceof ValueBoundaryDef ) {
+          return false;
+        }
+        if ( end.getDirection() == Direction.FOLLOWING ) {
+          return false;
+        }
+      }
+      return true;
+    }
+
+    private void pushRankLimitToRedSink(PTFOperator ptfOp, HiveConf conf, int rLimit) throws SemanticException {
+
+      Operator<? extends OperatorDesc> parent = ptfOp.getParentOperators().get(0);
+      Operator<? extends OperatorDesc> gP = parent == null ? null : parent.getParentOperators().get(0);
+
+      if ( gP == null || !(gP instanceof ReduceSinkOperator )) {
+        return;
+      }
+
+      float threshold = conf.getFloatVar(HiveConf.ConfVars.HIVELIMITPUSHDOWNMEMORYUSAGE);
+
+      ReduceSinkOperator rSink = (ReduceSinkOperator) gP;
+      ReduceSinkDesc rDesc = rSink.getConf();
+      rDesc.setTopN(rLimit);
+      rDesc.setTopNMemoryUsage(threshold);
+      rDesc.setMapGroupBy(true);
+      rDesc.setPTFReduceSink(true);
+    }
+  }

   public static class UDTFPPD extends DefaultPPD implements NodeProcessor {
     @Override
@@ -865,7 +1053,7 @@ public static NodeProcessor getDefaultProc() {
   }

   public static NodeProcessor getPTFProc() {
-    return new ScriptPPD();
+    return new PTFPPD();
   }

   public static NodeProcessor getSCRProc() {
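
PTFPPD only recognizes predicates of the shape rankAlias < intConstant or rankAlias <= intConstant; for <= the limit is bumped by one so the boundary rank is still produced and the residual filter stays correct. A standalone sketch of that derivation (hypothetical helper types, not Hive's ExprNodeDesc tree):

    import java.util.Collections;
    import java.util.Set;

    final class RankPredicateSketch {
      enum Op { LT, LE }

      // "r < 4" yields rankLimit 4; "r <= 3" also yields 4; anything else is not pushed.
      static Integer deriveRankLimit(String column, Op op, int constant, Set<String> rankAliases) {
        if (!rankAliases.contains(column)) {
          return null;  // compared column is not a rank()/dense_rank() alias
        }
        return op == Op.LE ? constant + 1 : constant;
      }

      public static void main(String[] args) {
        Set<String> rankAliases = Collections.singleton("r");
        System.out.println(deriveRankLimit("r", Op.LT, 4, rankAliases)); // 4
        System.out.println(deriveRankLimit("r", Op.LE, 3, rankAliases)); // 4
        System.out.println(deriveRankLimit("f", Op.LT, 4, rankAliases)); // null, not pushed
      }
    }
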
diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/WindowingTableFunction.java ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/WindowingTableFunction.java
index 2290766..2fa3a00 100644
--- ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/WindowingTableFunction.java
+++ ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/WindowingTableFunction.java
@@ -62,6 +62,7 @@ public class WindowingTableFunction extends TableFunctionEvaluator {

   StreamingState streamingState;
+  RankLimit rnkLimitDef;

   @SuppressWarnings({ "unchecked", "rawtypes" })
   @Override
@@ -283,6 +284,12 @@ public void initializeStreaming(Configuration cfg,
         }
       }
     }
+
+    if ( tabDef.getRankLimit() != -1 ) {
+      rnkLimitDef = new RankLimit(tabDef.getRankLimit(),
+          tabDef.getRankLimitFunction(), tabDef.getWindowFunctions());
+    }
+
     streamingState = new StreamingState(cfg, inputOI, isMapSide, tabDef, span[0], span[1]);
   }
@@ -313,6 +320,13 @@ public void startPartition() throws HiveException {

   @Override
   public List<Object> processRow(Object row) throws HiveException {
+    /*
+     * Once enough rows have been output, there is no need to process input rows.
+     */
+    if ( streamingState.rankLimitReached() ) {
+      return null;
+    }
+
     streamingState.rollingPart.append(row);
     row = streamingState.rollingPart
         .getAt(streamingState.rollingPart.size() - 1);
@@ -381,6 +395,13 @@ public void startPartition() throws HiveException {
    */
   @Override
   public List<Object> finishPartition() throws HiveException {
+
+    /*
+     * Once enough rows have been output, there is no need to generate more output.
+     */
+    if ( streamingState.rankLimitReached() ) {
+      return null;
+    }

     WindowTableFunctionDef tabDef = (WindowTableFunctionDef) getTableDef();
     for (int i = 0; i < tabDef.getWindowFunctions().size(); i++) {
@@ -428,15 +449,17 @@ public void startPartition() throws HiveException {

     List<Object> oRows = new ArrayList<Object>();

-    while (!streamingState.rollingPart.processedAllRows()) {
+    while (!streamingState.rollingPart.processedAllRows() &&
+        !streamingState.rankLimitReached() ) {
       boolean hasRow = streamingState.hasOutputRow();
-      ;
-      if (!hasRow) {
+      if (!hasRow && !streamingState.rankLimitReached() ) {
         throw new HiveException(
             "Internal Error: cannot generate all output rows for a Partition");
       }
-      oRows.add(streamingState.nextOutputRow());
+      if ( hasRow ) {
+        oRows.add(streamingState.nextOutputRow());
+      }
     }

     return oRows.size() == 0 ? null : oRows;
@@ -496,6 +519,11 @@ public boolean canIterateOutput() {
       output.add(null);
     }
+
+    if ( wTFnDef.getRankLimit() != -1 ) {
+      rnkLimitDef = new RankLimit(wTFnDef.getRankLimit(),
+          wTFnDef.getRankLimitFunction(), wTFnDef.getWindowFunctions());
+    }
+
     return new WindowingIterator(iPart, output, outputFromPivotFunctions,
         ArrayUtils.toPrimitive(wFnsWithWindows.toArray(new Integer[wFnsWithWindows.size()])));
   }
@@ -1205,6 +1233,7 @@ public int size() {
     StructObjectInspector inputOI;
     AggregationBuffer[] aggBuffers;
     Object[][] args;
+    RankLimit rnkLimit;

     WindowingIterator(PTFPartition iPart, ArrayList<Object> output,
         List<?>[] outputFromPivotFunctions, int[] wFnsToProcess) {
@@ -1229,10 +1258,17 @@ public int size() {
       } catch (HiveException he) {
        throw new RuntimeException(he);
       }
+      if ( WindowingTableFunction.this.rnkLimitDef != null ) {
+        rnkLimit = new RankLimit(WindowingTableFunction.this.rnkLimitDef);
+      }
     }

     @Override
     public boolean hasNext() {
+
+      if ( rnkLimit != null && rnkLimit.limitReached() ) {
+        return false;
+      }
       return currIdx < iPart.size();
     }
@@ -1279,6 +1315,9 @@ public Object next() {
         throw new RuntimeException(he);
       }

+      if ( rnkLimit != null ) {
+        rnkLimit.updateRank(output);
+      }
       currIdx++;
       return output;
     }
@@ -1296,6 +1335,7 @@ public void remove() {
     AggregationBuffer[] aggBuffers;
     Object[][] funcArgs;
     Order order;
+    RankLimit rnkLimit;

     @SuppressWarnings("unchecked")
     StreamingState(Configuration cfg, StructObjectInspector inputOI,
@@ -1321,6 +1361,9 @@ public void remove() {
         funcArgs[i] = new Object[wFn.getArgs() == null ? 0 : wFn.getArgs().size()];
         aggBuffers[i] = wFn.getWFnEval().getNewAggregationBuffer();
       }
+      if ( WindowingTableFunction.this.rnkLimitDef != null ) {
+        rnkLimit = new RankLimit(WindowingTableFunction.this.rnkLimitDef);
+      }
     }

     void reset(WindowTableFunctionDef tabDef) throws HiveException {
@@ -1334,9 +1377,17 @@ void reset(WindowTableFunctionDef tabDef) throws HiveException {
         WindowFunctionDef wFn = tabDef.getWindowFunctions().get(i);
         aggBuffers[i] = wFn.getWFnEval().getNewAggregationBuffer();
       }
+
+      if ( rnkLimit != null ) {
+        rnkLimit.reset();
+      }
     }

     boolean hasOutputRow() {
+      if ( rankLimitReached() ) {
+        return false;
+      }
+
       for (int i = 0; i < fnOutputs.length; i++) {
         if (fnOutputs[i].size() == 0) {
           return false;
         }
@@ -1355,8 +1406,65 @@ boolean hasOutputRow() {
       for (StructField f : rollingPart.getOutputOI().getAllStructFieldRefs()) {
         oRow.add(rollingPart.getOutputOI().getStructFieldData(iRow, f));
       }
+      if ( rnkLimit != null ) {
+        rnkLimit.updateRank(oRow);
+      }
       return oRow;
     }
+
+    boolean rankLimitReached() {
+      return rnkLimit != null && rnkLimit.limitReached();
+    }
   }
+
+  static class RankLimit {
+
+    /*
+     * Rows with a rank <= rankLimit are output.
+     * Only the first row with rank = rankLimit is output.
+     */
+    final int rankLimit;
+
+    /*
+     * the rank value of the last row output.
+     */
+    int currentRank;
+
+    /*
+     * index of the Rank function.
+     */
+    final int rankFnIdx;
+
+    final PrimitiveObjectInspector fnOutOI;
+
+    RankLimit(int rankLimit, int rankFnIdx, List<WindowFunctionDef> wdwFnDefs) {
+      this.rankLimit = rankLimit;
+      this.rankFnIdx = rankFnIdx;
+      this.fnOutOI = (PrimitiveObjectInspector) wdwFnDefs.get(rankFnIdx).getOI();
+      this.currentRank = -1;
+    }
+
+    RankLimit(RankLimit rl) {
+      this.rankLimit = rl.rankLimit;
+      this.rankFnIdx = rl.rankFnIdx;
+      this.fnOutOI = rl.fnOutOI;
+      this.currentRank = -1;
+    }
+
+    void reset() {
+      this.currentRank = -1;
+    }
+
+    void updateRank(List<Object> oRow) {
+      int r = (Integer) fnOutOI.getPrimitiveJavaObject(oRow.get(rankFnIdx));
+      if ( r > currentRank ) {
+        currentRank = r;
+      }
+    }
+
+    boolean limitReached() {
+      return currentRank >= rankLimit;
+    }
+  }
 }
diff --git ql/src/test/queries/clientpositive/windowing_streaming.q ql/src/test/queries/clientpositive/windowing_streaming.q
new file mode 100644
index 0000000..69a43d4
--- /dev/null
+++ ql/src/test/queries/clientpositive/windowing_streaming.q
@@ -0,0 +1,69 @@
+DROP TABLE if exists part;
+
+-- data setup
+CREATE TABLE part(
+    p_partkey INT,
+    p_name STRING,
+    p_mfgr STRING,
+    p_brand STRING,
+    p_type STRING,
+    p_size INT,
+    p_container STRING,
+    p_retailprice DOUBLE,
+    p_comment STRING
+);
+
+LOAD DATA LOCAL INPATH '../../data/files/part_tiny.txt' overwrite into table part;
+
+drop table over10k;
+
+create table over10k(
+    t tinyint,
+    si smallint,
+    i int,
+    b bigint,
+    f float,
+    d double,
+    bo boolean,
+    s string,
+    ts timestamp,
+    dec decimal(4,2),
+    bin binary)
+  row format delimited
+  fields terminated by '|';
+
+load data local inpath '../../data/files/over10k' into table over10k;
+
+set hive.limit.pushdown.memory.usage=.8;
+
+-- part tests
+select *
+from ( select p_mfgr, rank() over(partition by p_mfgr order by p_name) r from part) a
+where r < 4;
+
+select *
+from ( select p_mfgr, rank() over(partition by p_mfgr order by p_name) r from part) a
+where r < 2;
+
+-- over10k tests
+select *
+from (select t, f, rank() over(partition by t order by f) r from over10k) a
+where r < 6 and t < 5;
+
+set hive.vectorized.execution.enabled=false;
+set hive.limit.pushdown.memory.usage=0.8;
+drop table if exists sB;
+create table sB ROW FORMAT 
DELIMITED FIELDS TERMINATED BY ',' STORED AS TEXTFILE as +select * from (select ctinyint, cdouble, rank() over(partition by ctinyint order by cdouble) r from alltypesorc) a where r < 5; + +select * from sB +where ctinyint is null; + +set hive.vectorized.execution.enabled=true; +set hive.limit.pushdown.memory.usage=0.8; +drop table if exists sD; +create table sD ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS TEXTFILE as +select * from (select ctinyint, cdouble, rank() over(partition by ctinyint order by cdouble) r from alltypesorc) a where r < 5; + +select * from sD +where ctinyint is null; \ No newline at end of file diff --git ql/src/test/results/clientpositive/windowing_streaming.q.out ql/src/test/results/clientpositive/windowing_streaming.q.out new file mode 100644 index 0000000..1256077 --- /dev/null +++ ql/src/test/results/clientpositive/windowing_streaming.q.out @@ -0,0 +1,317 @@ +PREHOOK: query: DROP TABLE if exists part +PREHOOK: type: DROPTABLE +POSTHOOK: query: DROP TABLE if exists part +POSTHOOK: type: DROPTABLE +PREHOOK: query: -- data setup +CREATE TABLE part( + p_partkey INT, + p_name STRING, + p_mfgr STRING, + p_brand STRING, + p_type STRING, + p_size INT, + p_container STRING, + p_retailprice DOUBLE, + p_comment STRING +) +PREHOOK: type: CREATETABLE +PREHOOK: Output: database:default +POSTHOOK: query: -- data setup +CREATE TABLE part( + p_partkey INT, + p_name STRING, + p_mfgr STRING, + p_brand STRING, + p_type STRING, + p_size INT, + p_container STRING, + p_retailprice DOUBLE, + p_comment STRING +) +POSTHOOK: type: CREATETABLE +POSTHOOK: Output: database:default +POSTHOOK: Output: default@part +PREHOOK: query: LOAD DATA LOCAL INPATH '../../data/files/part_tiny.txt' overwrite into table part +PREHOOK: type: LOAD +#### A masked pattern was here #### +PREHOOK: Output: default@part +POSTHOOK: query: LOAD DATA LOCAL INPATH '../../data/files/part_tiny.txt' overwrite into table part +POSTHOOK: type: LOAD +#### A masked pattern was here #### +POSTHOOK: Output: default@part +PREHOOK: query: drop table over10k +PREHOOK: type: DROPTABLE +POSTHOOK: query: drop table over10k +POSTHOOK: type: DROPTABLE +PREHOOK: query: create table over10k( + t tinyint, + si smallint, + i int, + b bigint, + f float, + d double, + bo boolean, + s string, + ts timestamp, + dec decimal(4,2), + bin binary) + row format delimited + fields terminated by '|' +PREHOOK: type: CREATETABLE +PREHOOK: Output: database:default +POSTHOOK: query: create table over10k( + t tinyint, + si smallint, + i int, + b bigint, + f float, + d double, + bo boolean, + s string, + ts timestamp, + dec decimal(4,2), + bin binary) + row format delimited + fields terminated by '|' +POSTHOOK: type: CREATETABLE +POSTHOOK: Output: database:default +POSTHOOK: Output: default@over10k +PREHOOK: query: load data local inpath '../../data/files/over10k' into table over10k +PREHOOK: type: LOAD +#### A masked pattern was here #### +PREHOOK: Output: default@over10k +POSTHOOK: query: load data local inpath '../../data/files/over10k' into table over10k +POSTHOOK: type: LOAD +#### A masked pattern was here #### +POSTHOOK: Output: default@over10k +PREHOOK: query: -- part tests +select * +from ( select p_mfgr, rank() over(partition by p_mfgr order by p_name) r from part) a +where r < 4 +PREHOOK: type: QUERY +PREHOOK: Input: default@part +#### A masked pattern was here #### +POSTHOOK: query: -- part tests +select * +from ( select p_mfgr, rank() over(partition by p_mfgr order by p_name) r from part) a +where r < 4 +POSTHOOK: type: QUERY +POSTHOOK: 
Input: default@part +#### A masked pattern was here #### +Manufacturer#1 1 +Manufacturer#1 1 +Manufacturer#1 3 +Manufacturer#2 1 +Manufacturer#2 2 +Manufacturer#2 3 +Manufacturer#3 1 +Manufacturer#3 2 +Manufacturer#3 3 +Manufacturer#4 1 +Manufacturer#4 2 +Manufacturer#4 3 +Manufacturer#5 1 +Manufacturer#5 2 +Manufacturer#5 3 +PREHOOK: query: select * +from ( select p_mfgr, rank() over(partition by p_mfgr order by p_name) r from part) a +where r < 2 +PREHOOK: type: QUERY +PREHOOK: Input: default@part +#### A masked pattern was here #### +POSTHOOK: query: select * +from ( select p_mfgr, rank() over(partition by p_mfgr order by p_name) r from part) a +where r < 2 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@part +#### A masked pattern was here #### +Manufacturer#1 1 +Manufacturer#1 1 +Manufacturer#2 1 +Manufacturer#3 1 +Manufacturer#4 1 +Manufacturer#5 1 +PREHOOK: query: -- over10k tests +select * +from (select t, f, rank() over(partition by t order by f) r from over10k) a +where r < 6 and t < 5 +PREHOOK: type: QUERY +PREHOOK: Input: default@over10k +#### A masked pattern was here #### +POSTHOOK: query: -- over10k tests +select * +from (select t, f, rank() over(partition by t order by f) r from over10k) a +where r < 6 and t < 5 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@over10k +#### A masked pattern was here #### +-3 0.56 1 +-3 0.83 2 +-3 2.26 3 +-3 2.48 4 +-3 3.82 5 +-2 1.55 1 +-2 1.65 2 +-2 1.79 3 +-2 4.06 4 +-2 4.4 5 +-1 0.79 1 +-1 0.95 2 +-1 1.27 3 +-1 1.49 4 +-1 2.8 5 +0 0.08 1 +0 0.94 2 +0 1.44 3 +0 2.0 4 +0 2.12 5 +1 0.13 1 +1 0.44 2 +1 1.04 3 +1 3.41 4 +1 3.45 5 +2 2.21 1 +2 3.1 2 +2 9.93 3 +2 11.43 4 +2 15.45 5 +3 0.12 1 +3 0.19 2 +3 7.14 3 +3 7.97 4 +3 8.95 5 +4 2.26 1 +4 5.51 2 +4 5.53 3 +4 5.76 4 +4 7.26 5 +PREHOOK: query: drop table if exists sB +PREHOOK: type: DROPTABLE +POSTHOOK: query: drop table if exists sB +POSTHOOK: type: DROPTABLE +PREHOOK: query: create table sB ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS TEXTFILE as +select * from (select ctinyint, cdouble, rank() over(partition by ctinyint order by cdouble) r from alltypesorc) a where r < 5 +PREHOOK: type: CREATETABLE_AS_SELECT +PREHOOK: Input: default@alltypesorc +POSTHOOK: query: create table sB ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS TEXTFILE as +select * from (select ctinyint, cdouble, rank() over(partition by ctinyint order by cdouble) r from alltypesorc) a where r < 5 +POSTHOOK: type: CREATETABLE_AS_SELECT +POSTHOOK: Input: default@alltypesorc +POSTHOOK: Output: default@sB +PREHOOK: query: select * from sB +where ctinyint is null +PREHOOK: type: QUERY +PREHOOK: Input: default@sb +#### A masked pattern was here #### +POSTHOOK: query: select * from sB +where ctinyint is null +POSTHOOK: type: QUERY +POSTHOOK: Input: default@sb +#### A masked pattern was here #### +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +PREHOOK: query: drop table if exists sD +PREHOOK: type: DROPTABLE +POSTHOOK: query: drop table if exists sD +POSTHOOK: type: DROPTABLE +PREHOOK: query: create table sD ROW FORMAT 
DELIMITED FIELDS TERMINATED BY ',' STORED AS TEXTFILE as +select * from (select ctinyint, cdouble, rank() over(partition by ctinyint order by cdouble) r from alltypesorc) a where r < 5 +PREHOOK: type: CREATETABLE_AS_SELECT +PREHOOK: Input: default@alltypesorc +POSTHOOK: query: create table sD ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS TEXTFILE as +select * from (select ctinyint, cdouble, rank() over(partition by ctinyint order by cdouble) r from alltypesorc) a where r < 5 +POSTHOOK: type: CREATETABLE_AS_SELECT +POSTHOOK: Input: default@alltypesorc +POSTHOOK: Output: default@sD +PREHOOK: query: select * from sD +where ctinyint is null +PREHOOK: type: QUERY +PREHOOK: Input: default@sd +#### A masked pattern was here #### +POSTHOOK: query: select * from sD +where ctinyint is null +POSTHOOK: type: QUERY +POSTHOOK: Input: default@sd +#### A masked pattern was here #### +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1 +NULL NULL 1
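
For reference, the early-termination pattern that the RankLimit class in WindowingTableFunction enables, shown as a standalone sketch (a hypothetical tracker mirroring updateRank/limitReached, not the Hive class) driven by the r < 4 case exercised in the tests above:

    import java.util.Arrays;
    import java.util.List;

    final class RankLimitSketch {
      private final int rankLimit;
      private int currentRank = -1;

      RankLimitSketch(int rankLimit) { this.rankLimit = rankLimit; }

      void update(int rankOfEmittedRow) {
        if (rankOfEmittedRow > currentRank) {
          currentRank = rankOfEmittedRow;
        }
      }

      boolean limitReached() {
        return currentRank >= rankLimit;
      }

      public static void main(String[] args) {
        // rank values of successive rows in one partition; "r < 4" pushes rankLimit = 4
        List<Integer> ranks = Arrays.asList(1, 1, 3, 4, 4, 5);
        RankLimitSketch limit = new RankLimitSketch(4);
        for (int r : ranks) {
          if (limit.limitReached()) {
            break;               // the rest of the partition is skipped
          }
          System.out.println("emit row with rank " + r);
          limit.update(r);
        }
        // emits ranks 1, 1, 3, 4; the rank-4 row is later removed by the residual filter
      }
    }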