diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFRowNumber.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFRowNumber.java index 8e672e64759db40df53a56fb17bef6bb80dbdad1..e56aeea0dabe70dcd497df815fc292cb673bc25a 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFRowNumber.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFRowNumber.java @@ -27,7 +27,10 @@ import org.apache.hadoop.hive.ql.exec.WindowFunctionDescription; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.plan.ptf.WindowFrameDef; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFRank.GenericUDAFAbstractRankEvaluator; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFRank.RankBuffer; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; @@ -59,22 +62,36 @@ public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticE ArrayList rowNums; int nextRow; + boolean supportsStreaming; void init() { rowNums = new ArrayList(); + nextRow = 1; + if (supportsStreaming) { + rowNums.add(null); + } } - RowNumberBuffer() { + RowNumberBuffer(boolean supportsStreaming) { + this.supportsStreaming = supportsStreaming; init(); - nextRow = 1; } void incr() { - rowNums.add(new IntWritable(nextRow++)); + if (supportsStreaming) { + rowNums.set(0,new IntWritable(nextRow++)); + } else { + rowNums.add(new IntWritable(nextRow++)); + } } } - public static class GenericUDAFRowNumberEvaluator extends GenericUDAFEvaluator { + public static class GenericUDAFAbstractRowNumberEvaluator extends GenericUDAFEvaluator { + boolean isStreamingMode = false; + + protected boolean isStreaming() { + return isStreamingMode; + } @Override public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException { @@ -89,7 +106,7 @@ public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveExc @Override public AggregationBuffer getNewAggregationBuffer() throws HiveException { - return new RowNumberBuffer(); + return new RowNumberBuffer(isStreamingMode); } @Override @@ -118,5 +135,26 @@ public Object terminate(AggregationBuffer agg) throws HiveException { } } + + public static class GenericUDAFRowNumberEvaluator extends GenericUDAFAbstractRowNumberEvaluator + implements ISupportStreamingModeForWindowing { + + @Override + public Object getNextResult(AggregationBuffer agg) throws HiveException { + return ((RowNumberBuffer) agg).rowNums.get(0); + } + + @Override + public GenericUDAFEvaluator getWindowingEvaluator(WindowFrameDef wFrmDef) { + isStreamingMode = true; + return this; + } + + @Override + public int getRowsRemainingAfterTerminate() throws HiveException { + return 0; + } + } + } diff --git a/ql/src/test/queries/clientpositive/windowing_streaming.q b/ql/src/test/queries/clientpositive/windowing_streaming.q index 294fe091c16a4d3c6cfd92e4c3b75cbc3ba2d912..b8442e68bdd4f5564a259e5853b33d65f020134b 100644 --- a/ql/src/test/queries/clientpositive/windowing_streaming.q +++ b/ql/src/test/queries/clientpositive/windowing_streaming.q @@ -43,6 +43,10 @@ select * from (select t, f, rank() over(partition by t order by f) r from over10k) a where r < 6 and t < 5; +select * +from (select t, f, row_number() over(partition by t order by f) r from over10k) a +where r < 8 and t < 0; + set hive.vectorized.execution.enabled=false; set hive.limit.pushdown.memory.usage=0.8; diff --git a/ql/src/test/results/clientpositive/windowing_streaming.q.out b/ql/src/test/results/clientpositive/windowing_streaming.q.out index a74ddb3d05851c7b7cb09dd1bd2e1a212f4ba182..701ae40768459c6df39afdc27a98fbb5083ed28b 100644 --- a/ql/src/test/results/clientpositive/windowing_streaming.q.out +++ b/ql/src/test/results/clientpositive/windowing_streaming.q.out @@ -287,6 +287,39 @@ POSTHOOK: Input: default@over10k 4 5.53 3 4 5.76 4 4 7.26 5 +PREHOOK: query: select * +from (select t, f, row_number() over(partition by t order by f) r from over10k) a +where r < 8 and t < 0 +PREHOOK: type: QUERY +PREHOOK: Input: default@over10k +#### A masked pattern was here #### +POSTHOOK: query: select * +from (select t, f, row_number() over(partition by t order by f) r from over10k) a +where r < 8 and t < 0 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@over10k +#### A masked pattern was here #### +-3 0.56 1 +-3 0.83 2 +-3 2.26 3 +-3 2.48 4 +-3 3.82 5 +-3 6.8 6 +-3 6.83 7 +-2 1.55 1 +-2 1.65 2 +-2 1.79 3 +-2 4.06 4 +-2 4.4 5 +-2 5.43 6 +-2 5.59 7 +-1 0.79 1 +-1 0.95 2 +-1 1.27 3 +-1 1.49 4 +-1 2.8 5 +-1 4.08 6 +-1 4.31 7 PREHOOK: query: explain select * from (select ctinyint, cdouble, rank() over(partition by ctinyint order by cdouble) r from alltypesorc) a where r < 5 PREHOOK: type: QUERY