diff --git ql/src/java/org/apache/hadoop/hive/ql/ppd/OpProcFactory.java ql/src/java/org/apache/hadoop/hive/ql/ppd/OpProcFactory.java index 7aaf455..88f59dc 100644 --- ql/src/java/org/apache/hadoop/hive/ql/ppd/OpProcFactory.java +++ ql/src/java/org/apache/hadoop/hive/ql/ppd/OpProcFactory.java @@ -133,6 +133,26 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, } } + + public static class PTFPPD extends ScriptPPD { + + /* + * For WindowingTableFunction if: + * a. there is a Rank/DenseRank function: if there are unpushedPred of the form + * rnkValue < Constant; then use the smallest Constant val as the 'rankLimit' + * on the WindowingTableFn. + * b. If there are no Wdw Fns with an End Boundary past the current row, the + * condition can be pushed down as a limit pushdown(mapGroupBy=true) + * + * (non-Javadoc) + * @see org.apache.hadoop.hive.ql.ppd.OpProcFactory.ScriptPPD#process(org.apache.hadoop.hive.ql.lib.Node, java.util.Stack, org.apache.hadoop.hive.ql.lib.NodeProcessorCtx, java.lang.Object[]) + */ + @Override + public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, + Object... 
nodeOutputs) throws SemanticException { + return super.process(nd, stack, procCtx, nodeOutputs); + } + } public static class UDTFPPD extends DefaultPPD implements NodeProcessor { @Override diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/WindowingTableFunction.java ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/WindowingTableFunction.java index 0a67fea..df9ad43 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/WindowingTableFunction.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/WindowingTableFunction.java @@ -60,6 +60,7 @@ public class WindowingTableFunction extends TableFunctionEvaluator { StreamingState streamingState; + RankLimit rnkLimitDef; @SuppressWarnings({ "unchecked", "rawtypes" }) @Override @@ -141,6 +142,11 @@ private boolean processWindow(WindowFunctionDef wFn) { } return true; } + + public void setRankLimit(int rankLimit, int rnkFnIdx) { + WindowTableFunctionDef tabDef = (WindowTableFunctionDef) getTableDef(); + rnkLimitDef = new RankLimit(rankLimit, rnkFnIdx, tabDef.getWindowFunctions()); + } /* * (non-Javadoc) @@ -236,6 +242,9 @@ public void initializeStreaming(Configuration cfg, } } } + //@TODO remove + rnkLimitDef = new RankLimit(2, 0, tabDef.getWindowFunctions()); + streamingState = new StreamingState(cfg, inputOI, isMapSide, tabDef, span[0], span[1]); } @@ -266,6 +275,13 @@ public void startPartition() throws HiveException { @Override public List processRow(Object row) throws HiveException { + /* + * Once enough rows have been output, there is no need to process input rows. + */ + if ( streamingState.rankLimitReached() ) { + return null; + } + streamingState.rollingPart.append(row); row = streamingState.rollingPart .getAt(streamingState.rollingPart.size() - 1); @@ -334,6 +350,13 @@ public void startPartition() throws HiveException { */ @Override public List finishPartition() throws HiveException { + + /* + * Once enough rows have been output, there is no need to generate more output. 
+ */ + if ( streamingState.rankLimitReached() ) { + return null; + } WindowTableFunctionDef tabDef = (WindowTableFunctionDef) getTableDef(); for (int i = 0; i < tabDef.getWindowFunctions().size(); i++) { @@ -373,15 +396,17 @@ public void startPartition() throws HiveException { List oRows = new ArrayList(); - while (!streamingState.rollingPart.processedAllRows()) { + while (!streamingState.rollingPart.processedAllRows() && + !streamingState.rankLimitReached() ) { boolean hasRow = streamingState.hasOutputRow(); - ; - if (!hasRow) { + if (!hasRow && !streamingState.rankLimitReached() ) { throw new HiveException( "Internal Error: cannot generate all output rows for a Partition"); } - oRows.add(streamingState.nextOutputRow()); + if ( hasRow ) { + oRows.add(streamingState.nextOutputRow()); + } } return oRows.size() == 0 ? null : oRows; @@ -434,6 +459,9 @@ public boolean canIterateOutput() { output.add(null); } + //@TODO remove + rnkLimitDef = new RankLimit(2, 0, wTFnDef.getWindowFunctions()); + return new WindowingIterator(iPart, output, outputFromPivotFunctions, ArrayUtils.toPrimitive(wFnsWithWindows.toArray(new Integer[wFnsWithWindows.size()]))); } @@ -1143,6 +1171,7 @@ public int size() { StructObjectInspector inputOI; AggregationBuffer[] aggBuffers; Object[][] args; + RankLimit rnkLimit; WindowingIterator(PTFPartition iPart, ArrayList output, List[] outputFromPivotFunctions, int[] wFnsToProcess) { @@ -1167,10 +1196,17 @@ public int size() { } catch (HiveException he) { throw new RuntimeException(he); } + if ( WindowingTableFunction.this.rnkLimitDef != null ) { + rnkLimit = new RankLimit(WindowingTableFunction.this.rnkLimitDef); + } } @Override public boolean hasNext() { + + if ( rnkLimit != null && rnkLimit.limitReached() ) { + return false; + } return currIdx < iPart.size(); } @@ -1217,6 +1253,9 @@ public Object next() { throw new RuntimeException(he); } + if ( rnkLimit != null ) { + rnkLimit.updateRank(output); + } currIdx++; return output; } @@ -1234,6 +1273,7 @@ 
public void remove() { AggregationBuffer[] aggBuffers; Object[][] funcArgs; Order order; + RankLimit rnkLimit; @SuppressWarnings("unchecked") StreamingState(Configuration cfg, StructObjectInspector inputOI, @@ -1258,6 +1298,9 @@ public void remove() { WindowFunctionDef wFn = tabDef.getWindowFunctions().get(i); funcArgs[i] = new Object[wFn.getArgs() == null ? 0 : wFn.getArgs().size()]; } + if ( WindowingTableFunction.this.rnkLimitDef != null ) { + rnkLimit = new RankLimit(WindowingTableFunction.this.rnkLimitDef); + } } void reset(WindowTableFunctionDef tabDef) throws HiveException { @@ -1271,9 +1314,17 @@ void reset(WindowTableFunctionDef tabDef) throws HiveException { WindowFunctionDef wFn = tabDef.getWindowFunctions().get(i); aggBuffers[i] = wFn.getWFnEval().getNewAggregationBuffer(); } + + if ( rnkLimit != null ) { + rnkLimit.reset(); + } } boolean hasOutputRow() { + if ( rankLimitReached() ) { + return false; + } + for (int i = 0; i < fnOutputs.length; i++) { if (fnOutputs[i].size() == 0) { return false; @@ -1292,8 +1343,65 @@ boolean hasOutputRow() { for (StructField f : rollingPart.getOutputOI().getAllStructFieldRefs()) { oRow.add(rollingPart.getOutputOI().getStructFieldData(iRow, f)); } + if ( rnkLimit != null ) { + rnkLimit.updateRank(oRow); + } return oRow; } + + boolean rankLimitReached() { + return rnkLimit != null && rnkLimit.limitReached(); + } + } + + static class RankLimit { + + /* + * Rows with a rank <= rankLimit are output. + * Only the first row with rank = rankLimit is output. + */ + final int rankLimit; + + /* + * the rankValue of the last row output. + */ + int currentRank; + + /* + * index of Rank function. 
+ */ + final int rankFnIdx; + + final PrimitiveObjectInspector fnOutOI; + + RankLimit(int rankLimit, int rankFnIdx, List wdwFnDefs) { + this.rankLimit = rankLimit; + this.rankFnIdx = rankFnIdx; + this.fnOutOI = (PrimitiveObjectInspector) wdwFnDefs.get(rankFnIdx).getOI(); + this.currentRank = -1; + } + + RankLimit(RankLimit rl) { + this.rankLimit = rl.rankLimit; + this.rankFnIdx = rl.rankFnIdx; + this.fnOutOI = rl.fnOutOI; + this.currentRank = -1; + } + + void reset() { + this.currentRank = -1; + } + + void updateRank(List oRow) { + int r = (Integer) fnOutOI.getPrimitiveJavaObject(oRow.get(rankFnIdx)); + if ( r > currentRank ) { + currentRank = r; + } + } + + boolean limitReached() { + return currentRank >= rankLimit; + } } }