diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/OperatorUtils.java ql/src/java/org/apache/hadoop/hive/ql/exec/OperatorUtils.java index bec3cb7..d3b170d 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/OperatorUtils.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/OperatorUtils.java @@ -18,18 +18,26 @@ package org.apache.hadoop.hive.ql.exec; +import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.Stack; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.ql.exec.NodeUtils.Function; +import org.apache.hadoop.hive.ql.plan.OpTraits; import org.apache.hadoop.hive.ql.plan.OperatorDesc; import org.apache.hadoop.mapred.OutputCollector; +import com.google.common.base.Joiner; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + public class OperatorUtils { private static final Log LOG = LogFactory.getLog(OperatorUtils.class); @@ -140,4 +148,64 @@ private static void iterateParents(Operator operator, Function> f } } } + + public static List> flatten(Collection> top) { + Set visited = new HashSet(); + List> ops = new ArrayList>(); + Stack> stack = new Stack>(); + for(Operator op : top) { + stack.push(op); + } + while(!stack.isEmpty()) { + Operator op = stack.pop(); + String name = op.toString(); + if ( !visited.contains(name)) { + visited.add(name); + ops.add(op); + List> children = op.getChildOperators(); + for (int i = children.size() -1; i >= 0; i--) { + stack.push(children.get(i)); + } + } + } + return ops; + } + + public static List convert(Collection> ops, com.google.common.base.Function, String> fn ) { + return Lists.transform(ImmutableList.copyOf(ops), fn); + } + + public static List traits(Collection> ops) { + com.google.common.base.Function, String> fn = new com.google.common.base.Function, String>() { + public String apply(Operator op) { + String s = String.format("%s: None", op.toString()); + if (op.getOpTraits() != null) { + OpTraits t = op.getOpTraits(); + s = String.format("%s: bucketCols=%s,numBuckets=%d", op.toString(), + t.getBucketColNames(), t.getNumBuckets()); + } + return s; + } + }; + + return convert(ops, fn); + } + public static List columnExprMap(Collection> ops) { + com.google.common.base.Function, String> fn = new com.google.common.base.Function, String>() { + public String apply(Operator op) { + String s = String.format("%s: None", op.toString()); + if (op.getColumnExprMap() != null) { + s = String.format("%s: %s", op.toString(), + op.getColumnExprMap()); + } + return s; + } + }; + + return convert(ops, fn); + } + + public static String print(List l, String sep) { + return Joiner.on(sep).join(l); + } } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/PTFOperator.java ql/src/java/org/apache/hadoop/hive/ql/exec/PTFOperator.java index d3800c2..2705ab9 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/PTFOperator.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/PTFOperator.java @@ -83,6 +83,7 @@ protected void initializeOp(Configuration jobConf) throws HiveException { setupKeysWrapper(inputObjInspectors[0]); ptfInvocation = setupChain(); + ptfInvocation.initializeStreaming(jobConf, isMapOperator); firstMapRow = true; super.initializeOp(jobConf); @@ -282,6 +283,20 @@ boolean isStreaming() { return tabFn.canAcceptInputAsStream(); } + void initializeStreaming(Configuration cfg, boolean isMapSide) throws 
HiveException { + PartitionedTableFunctionDef tabDef = tabFn.getTableDef(); + PTFInputDef inputDef = tabDef.getInput(); + ObjectInspector inputOI = conf.getStartOfChain() == tabDef ? + inputObjInspectors[0] : inputDef.getOutputShape().getOI(); + if (isStreaming()) { + tabFn.initializeStreaming(cfg, (StructObjectInspector) inputOI, isMapSide); + } + + if ( next != null ) { + next.initializeStreaming(cfg, isMapSide); + } + } + void startPartition() throws HiveException { if ( isStreaming() ) { tabFn.startPartition(); @@ -301,15 +316,6 @@ void startPartition() throws HiveException { void processRow(Object row) throws HiveException { if ( isStreaming() ) { - if ( prev == null ) { - /* - * this is needed because during Translation we are still assuming that rows - * are collected into a PTFPartition. - * @Todo make translation handle the case when the first PTF is Streaming. - */ - row = ObjectInspectorUtils.copyToStandardObject(row, inputObjInspectors[0], - ObjectInspectorCopyOption.WRITABLE); - } handleOutputRows(tabFn.processRow(row)); } else { inputPart.append(row); diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java index b5adb11..21d85f1 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java @@ -51,13 +51,25 @@ protected PTFPartition(Configuration cfg, SerDe serDe, StructObjectInspector inputOI, StructObjectInspector outputOI) throws HiveException { + this(cfg, serDe, inputOI, outputOI, true); + } + + protected PTFPartition(Configuration cfg, + SerDe serDe, StructObjectInspector inputOI, + StructObjectInspector outputOI, + boolean createElemContainer) + throws HiveException { this.serDe = serDe; this.inputOI = inputOI; this.outputOI = outputOI; - int containerNumRows = HiveConf.getIntVar(cfg, ConfVars.HIVEJOINCACHESIZE); - elems = new PTFRowContainer>(containerNumRows, cfg, null); - elems.setSerDe(serDe, outputOI); - elems.setTableDesc(PTFRowContainer.createTableDesc(inputOI)); + if ( createElemContainer ) { + int containerNumRows = HiveConf.getIntVar(cfg, ConfVars.HIVEJOINCACHESIZE); + elems = new PTFRowContainer>(containerNumRows, cfg, null); + elems.setSerDe(serDe, outputOI); + elems.setTableDesc(PTFRowContainer.createTableDesc(inputOI)); + } else { + elems = null; + } } public void reset() throws HiveException { @@ -233,6 +245,16 @@ public static PTFPartition create(Configuration cfg, throws HiveException { return new PTFPartition(cfg, serDe, inputOI, outputOI); } + + public static PTFRollingPartition createRolling(Configuration cfg, + SerDe serDe, + StructObjectInspector inputOI, + StructObjectInspector outputOI, + int precedingSpan, + int followingSpan) + throws HiveException { + return new PTFRollingPartition(cfg, serDe, inputOI, outputOI, precedingSpan, followingSpan); + } public static StructObjectInspector setupPartitionOutputOI(SerDe serDe, StructObjectInspector tblFnOI) throws SerDeException { diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/PTFRollingPartition.java ql/src/java/org/apache/hadoop/hive/ql/exec/PTFRollingPartition.java new file mode 100644 index 0000000..8d2f667 --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/exec/PTFRollingPartition.java @@ -0,0 +1,161 @@ +package org.apache.hadoop.hive.ql.exec; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import 
org.apache.hadoop.hive.ql.plan.ptf.WindowFunctionDef; +import org.apache.hadoop.hive.serde2.SerDe; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; + +public class PTFRollingPartition extends PTFPartition { + + /* + * num rows whose output is evaluated. + */ + int numRowsProcessed; + + /* + * number rows to maintain before nextRowToProcess + */ + int precedingSpan; + + /* + * number rows to maintain after nextRowToProcess + */ + int followingSpan; + + /* + * number of rows received. + */ + int numRowsReceived; + + /* + * cache of rows; guaranteed to contain precedingSpan rows before nextRowToProcess. + */ + List currWindow; + + protected PTFRollingPartition(Configuration cfg, + SerDe serDe, StructObjectInspector inputOI, + StructObjectInspector outputOI, + int precedingSpan, + int succeedingSpan) + throws HiveException { + super(cfg, serDe, inputOI, outputOI, false); + this.precedingSpan = precedingSpan; + this.followingSpan = succeedingSpan; + currWindow = new ArrayList(); + } + + public void reset() throws HiveException { + currWindow.clear(); + numRowsProcessed = 0; + numRowsReceived = 0; + } + + public Object getAt(int i) throws HiveException { + int rangeStart = numRowsReceived - currWindow.size(); + return currWindow.get(i - rangeStart); + } + + public void append(Object o) throws HiveException { + @SuppressWarnings("unchecked") + List l = (List) + ObjectInspectorUtils.copyToStandardObject(o, inputOI, ObjectInspectorCopyOption.WRITABLE); + currWindow.add(l); + numRowsReceived++; + } + + public Object nextOutputRow() throws HiveException { + Object row = getAt(numRowsProcessed); + numRowsProcessed++; + if ( numRowsProcessed > precedingSpan ) { + currWindow.remove(0); + } + return row; + } + + public boolean processedAllRows() { + return numRowsProcessed >= numRowsReceived; + } + + public int getRowsReceived() { + return numRowsReceived; + } + + public int rowToProcess(WindowFunctionDef wFn) { + int rowToProcess = numRowsReceived - wFn.getWindowFrame().getEnd().getAmt() - 1; + return rowToProcess >= 0 ? rowToProcess : -1; + } + + public int size() { + return currWindow.size(); + } + + public PTFPartitionIterator iterator() throws HiveException { + return new RollingPItr(); + } + + public void close() { + } + + class RollingPItr implements PTFPartitionIterator { + + @Override + public boolean hasNext() { + throw new UnsupportedOperationException(); + } + + @Override + public Object next() { + throw new UnsupportedOperationException(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + + @Override + public int getIndex() { + return PTFRollingPartition.this.numRowsProcessed; + } + + @Override + public Object lead(int amt) throws HiveException { + int i = PTFRollingPartition.this.numRowsProcessed + amt; + i = i >= PTFRollingPartition.this.numRowsReceived ? + PTFRollingPartition.this.numRowsReceived - 1 : i; + return PTFRollingPartition.this.getAt(i); + } + + @Override + public Object lag(int amt) throws HiveException { + int i = PTFRollingPartition.this.numRowsProcessed - amt; + int start = PTFRollingPartition.this.numRowsReceived - + PTFRollingPartition.this.currWindow.size(); + + i = i < start ? 
+ start : i; + return PTFRollingPartition.this.getAt(i); + } + + @Override + public Object resetToIndex(int idx) throws HiveException { + return PTFRollingPartition.this.getAt(idx); + } + + @Override + public PTFPartition getPartition() { + return PTFRollingPartition.this; + } + + @Override + public void reset() throws HiveException { + } + + } +} diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/mr/MapRedTask.java ql/src/java/org/apache/hadoop/hive/ql/exec/mr/MapRedTask.java index 326654f..2fede2f 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/mr/MapRedTask.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/mr/MapRedTask.java @@ -130,6 +130,7 @@ public int execute(DriverContext driverContext) { runningViaChild = ShimLoader.getHadoopShims().isLocalMode(conf) || conf.getBoolVar(HiveConf.ConfVars.SUBMITVIACHILD); + runningViaChild=false; if(!runningViaChild) { // we are not running this mapred task via child jvm // so directly invoke ExecDriver diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/mr/MapredLocalTask.java ql/src/java/org/apache/hadoop/hive/ql/exec/mr/MapredLocalTask.java index 34b063d..3c605d3 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/mr/MapredLocalTask.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/mr/MapredLocalTask.java @@ -130,6 +130,12 @@ public boolean requireLock() { @Override public int execute(DriverContext driverContext) { + + boolean a = 1 > 0; + if(a) { + return executeFromChildJVM(driverContext); + } + try { // generate the cmd line to run in the child jvm Context ctx = driverContext.getCtx(); diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCumeDist.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCumeDist.java index 18c8c8d..fbadb91 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCumeDist.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCumeDist.java @@ -53,12 +53,12 @@ static final Log LOG = LogFactory.getLog(GenericUDAFCumeDist.class.getName()); @Override - protected GenericUDAFRankEvaluator createEvaluator() + protected GenericUDAFAbstractRankEvaluator createEvaluator() { return new GenericUDAFCumeDistEvaluator(); } - public static class GenericUDAFCumeDistEvaluator extends GenericUDAFRankEvaluator + public static class GenericUDAFCumeDistEvaluator extends GenericUDAFAbstractRankEvaluator { @Override public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFDenseRank.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFDenseRank.java index c1d43d8..8856fb7 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFDenseRank.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFDenseRank.java @@ -43,7 +43,7 @@ static final Log LOG = LogFactory.getLog(GenericUDAFDenseRank.class.getName()); @Override - protected GenericUDAFRankEvaluator createEvaluator() + protected GenericUDAFAbstractRankEvaluator createEvaluator() { return new GenericUDAFDenseRankEvaluator(); } diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFEvaluator.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFEvaluator.java index 5668a3b..2a9cc9a 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFEvaluator.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFEvaluator.java @@ -23,6 +23,7 @@ import org.apache.hadoop.hive.ql.exec.MapredContext; import 
org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.plan.ptf.WindowFrameDef; import org.apache.hadoop.hive.ql.udf.UDFType; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; @@ -232,5 +233,23 @@ public Object evaluate(AggregationBuffer agg) throws HiveException { * @return final aggregation result. */ public abstract Object terminate(AggregationBuffer agg) throws HiveException; + + + /** + * When evaluating an aggregates over a fixed Window, the naive way to compute results + * is to compute the aggregate for each row. But often there is a way to compute results + * in a more efficient manner. This method enables the basic evaluator to provide a function + * object that does the job in a more efficient manner. + *
<p>
+ * This method is called after this Evaluator is initialized. The returned Function must be + * initialized. It is passed the 'window' of aggregation for each row. + * + * @param wFrmDef the Window definition in play for this evaluation. + * @return null implies that this fn cannot be processed in Streaming mode. So each row is evaluated + * independently. + */ + public GenericUDAFEvaluator getWindowingEvaluator(WindowFrameDef wFrmDef) { + return null; + } } diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFPercentRank.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFPercentRank.java index aab1922..1cca03e 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFPercentRank.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFPercentRank.java @@ -49,12 +49,12 @@ static final Log LOG = LogFactory.getLog(GenericUDAFPercentRank.class.getName()); @Override - protected GenericUDAFRankEvaluator createEvaluator() + protected GenericUDAFAbstractRankEvaluator createEvaluator() { return new GenericUDAFPercentRankEvaluator(); } - public static class GenericUDAFPercentRankEvaluator extends GenericUDAFRankEvaluator + public static class GenericUDAFPercentRankEvaluator extends GenericUDAFAbstractRankEvaluator { @Override public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFRank.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFRank.java index 5c8f1e0..090e7a8 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFRank.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFRank.java @@ -27,6 +27,7 @@ import org.apache.hadoop.hive.ql.exec.WindowFunctionDescription; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.plan.ptf.WindowFrameDef; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; @@ -71,7 +72,7 @@ public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticE return createEvaluator(); } - protected GenericUDAFRankEvaluator createEvaluator() + protected GenericUDAFAbstractRankEvaluator createEvaluator() { return new GenericUDAFRankEvaluator(); } @@ -83,10 +84,12 @@ protected GenericUDAFRankEvaluator createEvaluator() Object[] currVal; int currentRank; int numParams; + boolean supportsStreaming; - RankBuffer(int numParams) + RankBuffer(int numParams, boolean supportsStreaming) { this.numParams = numParams; + this.supportsStreaming = supportsStreaming; init(); } @@ -96,20 +99,33 @@ void init() currentRowNum = 0; currentRank = 0; currVal = new Object[numParams]; + if ( supportsStreaming ) { + /* initialize rowNums to have 1 row */ + rowNums.add(null); + } } - + void incrRowNum() { currentRowNum++; } void addRank() { - rowNums.add(new IntWritable(currentRank)); + if ( supportsStreaming ) { + rowNums.set(0, new IntWritable(currentRank)); + } else { + rowNums.add(new IntWritable(currentRank)); + } } } - public static class GenericUDAFRankEvaluator extends GenericUDAFEvaluator + public static abstract class GenericUDAFAbstractRankEvaluator extends GenericUDAFEvaluator { ObjectInspector[] inputOI; ObjectInspector[] outputOI; + boolean isStreamingMode = false; + + protected boolean 
isStreaming() { + return isStreamingMode; + } @Override public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException @@ -132,7 +148,7 @@ public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveExc @Override public AggregationBuffer getNewAggregationBuffer() throws HiveException { - return new RankBuffer(inputOI.length); + return new RankBuffer(inputOI.length, isStreamingMode); } @Override @@ -182,6 +198,22 @@ public Object terminate(AggregationBuffer agg) throws HiveException } } + + public static class GenericUDAFRankEvaluator extends GenericUDAFAbstractRankEvaluator + implements ISupportStreamingModeForWindowing { + + @Override + public Object getNextResult(AggregationBuffer agg) throws HiveException { + return ((RankBuffer) agg).rowNums.get(0); + } + + @Override + public GenericUDAFEvaluator getWindowingEvaluator(WindowFrameDef wFrmDef) { + isStreamingMode = true; + return this; + } + + } public static int compare(Object[] o1, ObjectInspector[] oi1, Object[] o2, ObjectInspector[] oi2) diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFStreamingSum.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFStreamingSum.java new file mode 100644 index 0000000..1630e6a --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFStreamingSum.java @@ -0,0 +1,184 @@ +package org.apache.hadoop.hive.ql.udf.generic; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.parse.WindowingSpec.BoundarySpec; +import org.apache.hadoop.hive.ql.util.JavaDataModel; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.util.StringUtils; + +/* + * Set of classes that support summation over a 'fixed' Window. + * Aggregation for each row is computed based on the current state plus + * the row's value. + */ +public class GenericUDAFStreamingSum { + + static final Log LOG = LogFactory.getLog(GenericUDAFStreamingSum.class.getName()); + + + /** + * GenericUDAFSumDouble. + * + */ + public static class GenericUDAFSumDouble extends GenericUDAFEvaluator implements ISupportStreamingModeForWindowing { + private PrimitiveObjectInspector inputOI; + private DoubleWritable result; + private final int numPreceding; + private final int numFollowing; + + public GenericUDAFSumDouble(PrimitiveObjectInspector inputOI, DoubleWritable result, int numPreceding, int numFollowing) { + this.mode = Mode.COMPLETE; + this.inputOI = inputOI; + this.result = result; + this.numPreceding = numPreceding; + this.numFollowing = numFollowing; + } + + @Override + public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException { + assert (parameters.length == 1); + super.init(m, parameters); + result = new DoubleWritable(0); + inputOI = (PrimitiveObjectInspector) parameters[0]; + return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; + } + + /** class for storing double sum value. 
*/ + @AggregationType(estimable = true) + static class SumDoubleAgg extends AbstractAggregationBuffer { + boolean empty; + double sum; + final int numPreceding; + final int numFollowing; + final List results; + final List cumeSum; + final int windowSize; + int numRows; + + SumDoubleAgg(int numPreceding, int numFollowing) { + this.numPreceding = numPreceding; + this.numFollowing = numFollowing; + windowSize = numPreceding + numFollowing + 1; + results = new ArrayList(); + cumeSum = new ArrayList(); + } + + @Override + public int estimate() { return (4 * JavaDataModel.PRIMITIVES1) + ( (windowSize + 2) * JavaDataModel.PRIMITIVES2); } + } + + @Override + public AggregationBuffer getNewAggregationBuffer() throws HiveException { + SumDoubleAgg result = new SumDoubleAgg(numPreceding, numFollowing); + reset(result); + return result; + } + + @Override + public void reset(AggregationBuffer agg) throws HiveException { + SumDoubleAgg myagg = (SumDoubleAgg) agg; + myagg.empty = true; + myagg.sum = 0; + myagg.results.clear(); + myagg.cumeSum.clear(); + myagg.numRows = 0; + } + + boolean warned = false; + + @Override + public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException { + assert (parameters.length == 1); + try { + SumDoubleAgg myagg = (SumDoubleAgg) agg; + Object o = parameters[0]; + Double d = o == null ? null : PrimitiveObjectInspectorUtils.getDouble(o, inputOI); + if ( o != null ) { + myagg.empty = false; + myagg.sum += d; + } + + if ( myagg.numRows >= myagg.numFollowing ) { + double r = myagg.sum; + if ( myagg.numPreceding != BoundarySpec.UNBOUNDED_AMOUNT && + (myagg.numRows - myagg.numFollowing) >= (myagg.numPreceding + 1) ) { + r = r - myagg.cumeSum.remove(0); + } + myagg.results.add(r); + } + + if ( myagg.numPreceding != BoundarySpec.UNBOUNDED_AMOUNT ) { + myagg.cumeSum.add(myagg.sum); + } + //System.out.println(myagg.numRows + ": cumeSum=" + myagg.cumeSum + ", res = " + myagg.results); + myagg.numRows++; + + } catch (NumberFormatException e) { + if (!warned) { + warned = true; + LOG.warn(getClass().getSimpleName() + " " + + StringUtils.stringifyException(e)); + LOG + .warn(getClass().getSimpleName() + + " ignoring similar exceptions."); + } + } + } + + @Override + public Object terminate(AggregationBuffer agg) throws HiveException { + SumDoubleAgg myagg = (SumDoubleAgg) agg; + + for(int i=0; i < myagg.numFollowing; i++ ) { + double r = myagg.sum; + if ( myagg.numPreceding != BoundarySpec.UNBOUNDED_AMOUNT && + (myagg.numRows - myagg.numFollowing) >= (myagg.numPreceding + 1) ) { + r = r - myagg.cumeSum.remove(0); + } + myagg.results.add(r); + } + + if (myagg.empty) { + return null; + } + result.set(myagg.sum); + return result; + } + + @Override + public Object terminatePartial(AggregationBuffer agg) throws HiveException { + throw new HiveException(getClass().getSimpleName() + ": terminatePartial not supported"); + } + + @Override + public void merge(AggregationBuffer agg, Object partial) throws HiveException { + throw new HiveException(getClass().getSimpleName() + ": merge not supported"); + } + + @Override + public Object getNextResult(AggregationBuffer agg) throws HiveException { + SumDoubleAgg myagg = (SumDoubleAgg) agg; + if ( !myagg.results.isEmpty() ) { + Double d = myagg.results.remove(0); + if ( d == null ) { + return ISupportStreamingModeForWindowing.NULL_RESULT; + } + result.set(d); + return result; + } + return null; + } + + } + + +} diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java 
ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java index 8508ffb..5458736 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java @@ -24,6 +24,10 @@ import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.parse.WindowingSpec.BoundarySpec; +import org.apache.hadoop.hive.ql.plan.ptf.BoundaryDef; +import org.apache.hadoop.hive.ql.plan.ptf.ValueBoundaryDef; +import org.apache.hadoop.hive.ql.plan.ptf.WindowFrameDef; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode; import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.hadoop.hive.serde2.io.DoubleWritable; @@ -263,6 +267,23 @@ public Object terminate(AggregationBuffer agg) throws HiveException { result.set(myagg.sum); return result; } + + @Override + public GenericUDAFEvaluator getWindowingEvaluator(WindowFrameDef wFrmDef) { + + BoundaryDef start = wFrmDef.getStart(); + BoundaryDef end = wFrmDef.getEnd(); + + if ( start instanceof ValueBoundaryDef || end instanceof ValueBoundaryDef ) { + return null; + } + + if ( end.getAmt() == BoundarySpec.UNBOUNDED_AMOUNT ) { + return null; + } + + return new GenericUDAFStreamingSum.GenericUDAFSumDouble(inputOI, result, start.getAmt(), end.getAmt()); + } } diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/ISupportStreamingModeForWindowing.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/ISupportStreamingModeForWindowing.java new file mode 100644 index 0000000..4aa56a1 --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/ISupportStreamingModeForWindowing.java @@ -0,0 +1,19 @@ +package org.apache.hadoop.hive.ql.udf.generic; + +import org.apache.hadoop.hive.ql.exec.WindowFunctionInfo; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer; +import org.apache.hadoop.hive.ql.udf.ptf.WindowingTableFunction; + + +/** + * A GenericUDAF mode that provides it results as a List to the {@link WindowingTableFunction} + * (so it is a {@link WindowFunctionInfo#isPivotResult()} return true) may support this interface. + * If it does then the WindowingTableFunction will ask it for the next Result after every aggregate call. 
+ */ +public interface ISupportStreamingModeForWindowing { + + Object getNextResult(AggregationBuffer agg) throws HiveException; + + public static Object NULL_RESULT = new Object(); +} diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/NoopStreaming.java ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/NoopStreaming.java index d50a542..ec7b1cc 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/NoopStreaming.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/NoopStreaming.java @@ -21,13 +21,18 @@ import java.util.ArrayList; import java.util.List; +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.plan.PTFDesc; import org.apache.hadoop.hive.ql.plan.ptf.PartitionedTableFunctionDef; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; public class NoopStreaming extends Noop { List rows; + StructObjectInspector inputOI; NoopStreaming() { rows = new ArrayList(); @@ -37,6 +42,11 @@ public boolean canAcceptInputAsStream() { return true; } + public void initializeStreaming(Configuration cfg, StructObjectInspector inputOI, boolean isMapSide) + throws HiveException { + this.inputOI = inputOI; + } + public List processRow(Object row) throws HiveException { if (!canAcceptInputAsStream() ) { throw new HiveException(String.format( @@ -44,6 +54,8 @@ public boolean canAcceptInputAsStream() { getClass().getName())); } rows.clear(); + row = ObjectInspectorUtils.copyToStandardObject(row, inputOI, + ObjectInspectorCopyOption.WRITABLE); rows.add(row); return rows; } diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/NoopWithMapStreaming.java ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/NoopWithMapStreaming.java index be1f9ab..80ea631 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/NoopWithMapStreaming.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/NoopWithMapStreaming.java @@ -21,12 +21,17 @@ import java.util.ArrayList; import java.util.List; +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.plan.PTFDesc; import org.apache.hadoop.hive.ql.plan.ptf.PartitionedTableFunctionDef; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption; public class NoopWithMapStreaming extends NoopWithMap { List rows; + StructObjectInspector inputOI; NoopWithMapStreaming() { rows = new ArrayList(); @@ -36,6 +41,11 @@ public boolean canAcceptInputAsStream() { return true; } + public void initializeStreaming(Configuration cfg, StructObjectInspector inputOI, boolean isMapSide) + throws HiveException { + this.inputOI = inputOI; + } + public List processRow(Object row) throws HiveException { if (!canAcceptInputAsStream() ) { throw new HiveException(String.format( @@ -43,6 +53,8 @@ public boolean canAcceptInputAsStream() { getClass().getName())); } rows.clear(); + row = ObjectInspectorUtils.copyToStandardObject(row, inputOI, + ObjectInspectorCopyOption.WRITABLE); rows.add(row); return rows; } diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/TableFunctionEvaluator.java 
ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/TableFunctionEvaluator.java index 8a1e085..d62314a 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/TableFunctionEvaluator.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/TableFunctionEvaluator.java @@ -21,6 +21,7 @@ import java.util.Iterator; import java.util.List; +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hive.ql.exec.PTFOperator; import org.apache.hadoop.hive.ql.exec.PTFPartition; import org.apache.hadoop.hive.ql.exec.PTFPartition.PTFPartitionIterator; @@ -29,6 +30,7 @@ import org.apache.hadoop.hive.ql.plan.PTFDesc; import org.apache.hadoop.hive.ql.plan.ptf.PartitionedTableFunctionDef; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; /* @@ -218,6 +220,15 @@ public boolean canAcceptInputAsStream() { return false; } + public void initializeStreaming(Configuration cfg, StructObjectInspector inputOI, boolean isMapSide) + throws HiveException { + if (!canAcceptInputAsStream() ) { + throw new HiveException(String.format( + "Internal error: PTF %s, doesn't support Streaming", + getClass().getName())); + } + } + public void startPartition() throws HiveException { if (!canAcceptInputAsStream() ) { throw new HiveException(String.format( diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/WindowingTableFunction.java ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/WindowingTableFunction.java index de511f4..03689c4 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/WindowingTableFunction.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/WindowingTableFunction.java @@ -24,9 +24,12 @@ import java.util.List; import org.apache.commons.lang.ArrayUtils; +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hive.ql.exec.PTFOperator; import org.apache.hadoop.hive.ql.exec.PTFPartition; import org.apache.hadoop.hive.ql.exec.PTFPartition.PTFPartitionIterator; +import org.apache.hadoop.hive.ql.exec.Utilities.StreamStatus; +import org.apache.hadoop.hive.ql.exec.PTFRollingPartition; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.parse.PTFInvocationSpec.Order; import org.apache.hadoop.hive.ql.parse.SemanticException; @@ -42,6 +45,9 @@ import org.apache.hadoop.hive.ql.plan.ptf.WindowTableFunctionDef; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer; +import org.apache.hadoop.hive.ql.udf.generic.ISupportStreamingModeForWindowing; +import org.apache.hadoop.hive.serde2.SerDe; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructField; @@ -51,6 +57,9 @@ @SuppressWarnings("deprecation") public class WindowingTableFunction extends TableFunctionEvaluator { + StreamingState streamingState; + Boolean canAcceptStreamingFlag; + @SuppressWarnings({ "unchecked", "rawtypes" }) @Override public void execute(PTFPartitionIterator pItr, PTFPartition outP) throws HiveException { @@ -132,6 +141,243 @@ private boolean processWindow(WindowFunctionDef wFn) { return true; } + /* + * (non-Javadoc) + * @see org.apache.hadoop.hive.ql.udf.ptf.TableFunctionEvaluator#canAcceptInputAsStream() 
+ * + * WindowTableFunction supports streaming if all functions meet one of these conditions: + * 1. The Function implements ISupportStreamingModeForWindowing + * 2. Or returns a non null Object for the getWindowingEvaluator, that + * implements ISupportStreamingModeForWindowing. + * 3. Is an invocation on a 'fixed' window. So no Unbounded Preceding or Following. + */ + @Override + public boolean canAcceptInputAsStream() { + + if ( canAcceptStreamingFlag != null ) { + return canAcceptStreamingFlag; + } + + canAcceptStreamingFlag = false; + + WindowTableFunctionDef tabDef = (WindowTableFunctionDef) getTableDef(); + + for(int i=0; i < tabDef.getWindowFunctions().size(); i++) { + WindowFunctionDef wFnDef = tabDef.getWindowFunctions().get(i); + WindowFrameDef wdwFrame = wFnDef.getWindowFrame(); + GenericUDAFEvaluator fnEval = wFnDef.getWFnEval(); + if ( fnEval instanceof ISupportStreamingModeForWindowing ) { + continue; + } + GenericUDAFEvaluator streamingEval = fnEval.getWindowingEvaluator(wdwFrame); + if ( streamingEval != null && streamingEval instanceof ISupportStreamingModeForWindowing ) { + continue; + } + BoundaryDef start = wdwFrame.getStart(); + BoundaryDef end = wdwFrame.getEnd(); + if ( !(end instanceof ValueBoundaryDef) ) { + if ( end.getAmt() != BoundarySpec.UNBOUNDED_AMOUNT && + start.getAmt() != BoundarySpec.UNBOUNDED_AMOUNT && + end.getDirection() != Direction.PRECEDING && + start.getDirection() != Direction.FOLLOWING ) { + continue; + } + } + return canAcceptStreamingFlag; + } + canAcceptStreamingFlag = true; + return canAcceptStreamingFlag; + } + + @Override + public void initializeStreaming(Configuration cfg, StructObjectInspector inputOI, boolean isMapSide) + throws HiveException { + WindowTableFunctionDef tabDef = (WindowTableFunctionDef) getTableDef(); + int precedingSpan = 0; + int followingSpan = 0; + + for(int i=0; i < tabDef.getWindowFunctions().size(); i++) { + WindowFunctionDef wFnDef = tabDef.getWindowFunctions().get(i); + WindowFrameDef wdwFrame = wFnDef.getWindowFrame(); + GenericUDAFEvaluator fnEval = wFnDef.getWFnEval(); + GenericUDAFEvaluator streamingEval = fnEval.getWindowingEvaluator(wdwFrame); + if ( streamingEval != null ) { + wFnDef.setWFnEval(streamingEval); + if ( wFnDef.isPivotResult() ) { + ListObjectInspector listOI = (ListObjectInspector) wFnDef.getOI(); + wFnDef.setOI(listOI.getListElementObjectInspector()); + } + } else { + int amt = wdwFrame.getStart().getAmt(); + if ( amt > precedingSpan ) { + precedingSpan = amt; + } + + amt = wdwFrame.getEnd().getAmt(); + if ( amt > followingSpan ) { + followingSpan = amt; + } + } + } + streamingState = new StreamingState(cfg, inputOI, isMapSide, tabDef, precedingSpan, followingSpan); + } + + /* + * (non-Javadoc) + * @see org.apache.hadoop.hive.ql.udf.ptf.TableFunctionEvaluator#startPartition() + * + */ + @Override + public void startPartition() throws HiveException { + WindowTableFunctionDef tabDef = (WindowTableFunctionDef) getTableDef(); + streamingState.reset(tabDef); + } + + /* + * (non-Javadoc) + * @see org.apache.hadoop.hive.ql.udf.ptf.TableFunctionEvaluator#processRow(java.lang.Object) + * + * - hand row to each Function, provided there are enough rows for Function's window. + * - call getNextObject on each Function. 
+ * - output as many rows as possible, based on minimum sz of Output List + */ + @Override + public List processRow(Object row) throws HiveException { + + streamingState.rollingPart.append(row); + row = streamingState.rollingPart.getAt(streamingState.rollingPart.getRowsReceived() - 1); + + WindowTableFunctionDef tabDef = (WindowTableFunctionDef) getTableDef(); + + for(int i=0; i < tabDef.getWindowFunctions().size(); i++) { + WindowFunctionDef wFn = tabDef.getWindowFunctions().get(i); + GenericUDAFEvaluator fnEval = wFn.getWFnEval(); + + int a = 0; + if (wFn.getArgs() != null) { + for (PTFExpressionDef arg : wFn.getArgs()) { + streamingState.args[i][a++] = arg.getExprEvaluator().evaluate(row); + } + } + + if ( fnEval instanceof ISupportStreamingModeForWindowing ) { + fnEval.aggregate(streamingState.aggBuffers[i], streamingState.args[i]); + Object out = ((ISupportStreamingModeForWindowing) fnEval).getNextResult(streamingState.aggBuffers[i]); + if ( out != null ) { + streamingState.fnOutputs[i].add(out == ISupportStreamingModeForWindowing.NULL_RESULT ? null : out); + } + } else { + int rowToProcess = streamingState.rollingPart.rowToProcess(wFn); + if ( rowToProcess >= 0 ) { + Range rng = getRange(wFn, rowToProcess, streamingState.rollingPart, streamingState.order); + PTFPartitionIterator rItr = rng.iterator(); + PTFOperator.connectLeadLagFunctionsToPartition(ptfDesc, rItr); + Object out = evaluateWindowFunction(wFn, rItr); + streamingState.fnOutputs[i].add(out); + } + } + } + + List oRows = new ArrayList(); + while(true) { + boolean hasRow = true; + + for(int i=0; i < streamingState.fnOutputs.length; i++) { + if (streamingState.fnOutputs[i].size() == 0 ) { + hasRow = false; + break; + } + } + + if (!hasRow) { + break; + } + + List oRow = new ArrayList(); + Object iRow = streamingState.rollingPart.nextOutputRow(); + int i = 0; + for(; i < streamingState.fnOutputs.length; i++) { + oRow.add(streamingState.fnOutputs[i].remove(0)); + } + for (StructField f : streamingState.rollingPart.getOutputOI().getAllStructFieldRefs()) { + oRow.add(streamingState.rollingPart.getOutputOI().getStructFieldData(iRow, f)); + } + oRows.add(oRow); + } + + + return oRows.size() == 0 ? null : oRows; + } + + /* + * (non-Javadoc) + * @see org.apache.hadoop.hive.ql.udf.ptf.TableFunctionEvaluator#finishPartition() + * + * for fns that are not ISupportStreamingModeForWindowing + * give them the remaining rows (rows whose span went beyond the end of the partition) + * for rest of the functions invoke terminate. + * + * while numOutputRows < numInputRows + * for each Fn that doesn't have enough o/p + * invoke getNextObj + * if there is no O/p then flag this as an error. 
+ */ + @Override + public List finishPartition() throws HiveException { + + WindowTableFunctionDef tabDef = (WindowTableFunctionDef) getTableDef(); + for(int i=0; i < tabDef.getWindowFunctions().size(); i++) { + WindowFunctionDef wFn = tabDef.getWindowFunctions().get(i); + GenericUDAFEvaluator fnEval = wFn.getWFnEval(); + + if ( fnEval instanceof ISupportStreamingModeForWindowing ) { + fnEval.terminate(streamingState.aggBuffers[i]); + } else { + int numRowsRemaining = streamingState.rollingPart.size() - wFn.getWindowFrame().getEnd().getAmt(); + while(numRowsRemaining > 0 ) { + int rowToProcess = streamingState.rollingPart.size() - numRowsRemaining; + Range rng = getRange(wFn, rowToProcess, streamingState.rollingPart, streamingState.order); + PTFPartitionIterator rItr = rng.iterator(); + PTFOperator.connectLeadLagFunctionsToPartition(ptfDesc, rItr); + Object out = evaluateWindowFunction(wFn, rItr); + streamingState.fnOutputs[i].add(out); + numRowsRemaining--; + } + } + + } + + List oRows = new ArrayList(); + + while(!streamingState.rollingPart.processedAllRows()) { + boolean hasRow = true; + + for(int i=0; i < streamingState.fnOutputs.length; i++) { + if (streamingState.fnOutputs[i].size() == 0 ) { + hasRow = false; + break; + } + } + + if (!hasRow) { + throw new HiveException("Internal Error: cannot generate all output rows for a Partition"); + } + + List oRow = new ArrayList(); + Object iRow = streamingState.rollingPart.nextOutputRow(); + int i = 0; + for(; i < streamingState.fnOutputs.length; i++) { + oRow.add(streamingState.fnOutputs[i].remove(0)); + } + for (StructField f : streamingState.rollingPart.getOutputOI().getAllStructFieldRefs()) { + oRow.set(i++, streamingState.rollingPart.getOutputOI().getStructFieldData(iRow, f)); + } + oRows.add(oRow); + } + + return oRows.size() == 0 ? null : oRows; + } + @Override public boolean canIterateOutput() { return true; @@ -154,16 +400,19 @@ public boolean canIterateOutput() { Object out = evaluateWindowFunction(wFn, pItr); output.add(out); } else if (wFn.isPivotResult()) { - /* - * for functions that currently return the output as a List, - * for e.g. the ranking functions, lead/lag, ntile, cume_dist - * - for now continue to execute them here. The functions need to provide a way to get - * each output row as we are iterating through the input. This is relative - * easy to do for ranking functions; not possible for lead, ntile, cume_dist. 
- * - */ - outputFromPivotFunctions[i] = (List) evaluateWindowFunction(wFn, pItr); - output.add(null); + GenericUDAFEvaluator streamingEval = wFn.getWFnEval().getWindowingEvaluator(wFn.getWindowFrame()); + if ( streamingEval != null && streamingEval instanceof ISupportStreamingModeForWindowing ) { + wFn.setWFnEval(streamingEval); + if ( wFn.getOI() instanceof ListObjectInspector ) { + ListObjectInspector listOI = (ListObjectInspector) wFn.getOI(); + wFn.setOI(listOI.getListElementObjectInspector()); + } + output.add(null); + wFnsWithWindows.add(i); + } else { + outputFromPivotFunctions[i] = (List) evaluateWindowFunction(wFn, pItr); + output.add(null); + } } else { output.add(null); wFnsWithWindows.add(i); @@ -850,8 +1099,10 @@ public int size() { Order order; PTFDesc ptfDesc; StructObjectInspector inputOI; + AggregationBuffer[] aggBuffers; + Object[][] args; - WindowingIterator(PTFPartition iPart, ArrayList output, + WindowingIterator(PTFPartition iPart, ArrayList output, List[] outputFromPivotFunctions, int[] wFnsToProcess) { this.iPart = iPart; this.output = output; @@ -862,6 +1113,18 @@ public int size() { order = wTFnDef.getOrder().getExpressions().get(0).getOrder(); ptfDesc = getQueryDef(); inputOI = iPart.getOutputOI(); + + aggBuffers = new AggregationBuffer[wTFnDef.getWindowFunctions().size()]; + args = new Object[wTFnDef.getWindowFunctions().size()][]; + try { + for (int j : wFnsToProcess) { + WindowFunctionDef wFn = wTFnDef.getWindowFunctions().get(j); + aggBuffers[j] = wFn.getWFnEval().getNewAggregationBuffer(); + args[j] = new Object[wFn.getArgs() == null ? 0 : wFn.getArgs().size()]; + } + } catch (HiveException he) { + throw new RuntimeException(he); + } } @Override @@ -881,10 +1144,25 @@ public Object next() { try { for (int j : wFnsToProcess) { WindowFunctionDef wFn = wTFnDef.getWindowFunctions().get(j); - Range rng = getRange(wFn, currIdx, iPart, order); - PTFPartitionIterator rItr = rng.iterator(); - PTFOperator.connectLeadLagFunctionsToPartition(ptfDesc, rItr); - output.set(j, evaluateWindowFunction(wFn, rItr)); + if (wFn.getWFnEval() instanceof ISupportStreamingModeForWindowing) { + Object iRow = iPart.getAt(currIdx); + int a = 0; + if (wFn.getArgs() != null) { + for (PTFExpressionDef arg : wFn.getArgs()) { + args[j][a++] = arg.getExprEvaluator().evaluate(iRow); + } + } + wFn.getWFnEval().aggregate(aggBuffers[j], args[j]); + Object out = ((ISupportStreamingModeForWindowing) wFn.getWFnEval()) + .getNextResult(aggBuffers[j]); + out = ObjectInspectorUtils.copyToStandardObject(out, wFn.getOI()); + output.set(j, out); + } else { + Range rng = getRange(wFn, currIdx, iPart, order); + PTFPartitionIterator rItr = rng.iterator(); + PTFOperator.connectLeadLagFunctionsToPartition(ptfDesc, rItr); + output.set(j, evaluateWindowFunction(wFn, rItr)); + } } Object iRow = iPart.getAt(currIdx); @@ -908,4 +1186,54 @@ public void remove() { } + class StreamingState { + PTFRollingPartition rollingPart; + List[] fnOutputs; + AggregationBuffer[] aggBuffers; + Object[][] args; + Order order; + + @SuppressWarnings("unchecked") + StreamingState(Configuration cfg, StructObjectInspector inputOI, + boolean isMapSide, WindowTableFunctionDef tabDef, + int precedingSpan, int followingSpan) + throws HiveException { + SerDe serde = isMapSide ? tabDef.getInput().getOutputShape().getSerde() : + tabDef.getRawInputShape().getSerde(); + StructObjectInspector outputOI = isMapSide ? 
tabDef.getInput().getOutputShape().getOI() : + tabDef.getRawInputShape().getOI(); + rollingPart = PTFPartition.createRolling(cfg, + serde, + inputOI, + outputOI, precedingSpan, + followingSpan); + + order = tabDef.getOrder().getExpressions().get(0).getOrder(); + + int numFns = tabDef.getWindowFunctions().size(); + fnOutputs = new ArrayList[numFns]; + + aggBuffers = new AggregationBuffer[numFns]; + args = new Object[numFns][]; + for (int i=0; i< numFns; i++) { + fnOutputs[i] = new ArrayList(); + WindowFunctionDef wFn = tabDef.getWindowFunctions().get(i); + args[i] = new Object[wFn.getArgs() == null ? 0 : wFn.getArgs().size()]; + } + } + + void reset(WindowTableFunctionDef tabDef) throws HiveException { + int numFns = tabDef.getWindowFunctions().size(); + rollingPart.reset(); + for(int i=0; i < fnOutputs.length; i++) { + fnOutputs[i].clear(); + } + + for (int i=0; i< numFns; i++) { + WindowFunctionDef wFn = tabDef.getWindowFunctions().get(i); + aggBuffers[i] = wFn.getWFnEval().getNewAggregationBuffer(); + } + } + } + } diff --git ql/src/test/org/apache/hadoop/hive/ql/udaf/TestStreamingSum.java ql/src/test/org/apache/hadoop/hive/ql/udaf/TestStreamingSum.java new file mode 100644 index 0000000..7b0a4da --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/udaf/TestStreamingSum.java @@ -0,0 +1,168 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.hive.ql.udaf; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + +import junit.framework.Assert; + +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.parse.WindowingSpec.BoundarySpec; +import org.apache.hadoop.hive.ql.parse.WindowingSpec.Direction; +import org.apache.hadoop.hive.ql.plan.ptf.BoundaryDef; +import org.apache.hadoop.hive.ql.plan.ptf.CurrentRowDef; +import org.apache.hadoop.hive.ql.plan.ptf.RangeBoundaryDef; +import org.apache.hadoop.hive.ql.plan.ptf.WindowFrameDef; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFSum; +import org.apache.hadoop.hive.ql.udf.generic.ISupportStreamingModeForWindowing; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; +import org.junit.Test; + + +public class TestStreamingSum { + + public WindowFrameDef wdwFrame(int p, int f) { + WindowFrameDef wFrmDef = new WindowFrameDef(); + BoundaryDef start, end; + if ( p == 0 ) { + start = new CurrentRowDef(); + } else { + RangeBoundaryDef startR = new RangeBoundaryDef(); + startR.setDirection(Direction.PRECEDING); + startR.setAmt(p); + start = startR; + } + + if ( f == 0 ) { + end = new CurrentRowDef(); + } else { + RangeBoundaryDef endR = new RangeBoundaryDef(); + endR.setDirection(Direction.FOLLOWING); + endR.setAmt(f); + end = endR; + } + wFrmDef.setStart(start); + wFrmDef.setEnd(end); + return wFrmDef; + } + + public void sum(Iterator inVals, + int inSz, + int numPreceding, + int numFollowing, + Iterator outVals) throws HiveException { + + GenericUDAFSum fnR = new GenericUDAFSum(); + TypeInfo[] inputTypes = { + TypeInfoFactory.doubleTypeInfo + }; + ObjectInspector[] inputOIs = { + PrimitiveObjectInspectorFactory.writableDoubleObjectInspector + }; + + GenericUDAFEvaluator fn = fnR.getEvaluator(inputTypes); + fn.init(Mode.COMPLETE, inputOIs); + fn = fn.getWindowingEvaluator(wdwFrame(numPreceding, numFollowing)); + AggregationBuffer agg = fn.getNewAggregationBuffer(); + ISupportStreamingModeForWindowing oS = (ISupportStreamingModeForWindowing) fn; + + DoubleWritable[] in = new DoubleWritable[1]; + in[0] = new DoubleWritable(); + int outSz = 0; + while(inVals.hasNext()) { + in[0].set(inVals.next()); + fn.aggregate(agg, in); + Object out = oS.getNextResult(agg); + if ( out != null ) { + out = out == ISupportStreamingModeForWindowing.NULL_RESULT ? null : ((DoubleWritable)out).get(); + Assert.assertEquals(out, outVals.next()); + outSz++; + } + } + + fn.terminate(agg); + + while(outSz < inSz ) { + Object out = oS.getNextResult(agg); + out = out == ISupportStreamingModeForWindowing.NULL_RESULT ? 
null : ((DoubleWritable)out).get(); + Assert.assertEquals(out, outVals.next()); + outSz++; + } + } + + @Test + public void test_3_4() throws HiveException { + + List inVals = Arrays.asList(1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0); + List outVals = Arrays.asList(15.0, 21.0, 28.0, 36.0, 44.0, 52.0, 49.0, 45.0, 40.0, 34.0); + sum(inVals.iterator(),10,3,4, outVals.iterator()); + } + + @Test + public void test_3_0() throws HiveException { + List inVals = Arrays.asList(1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0); + List outVals = Arrays.asList(1.0, 3.0, 6.0, 10.0, 14.0, 18.0, 22.0, 26.0, 30.0, 34.0); + sum(inVals.iterator(),10,3,0, outVals.iterator()); + } + + @Test + public void test_unb_0() throws HiveException { + List inVals = Arrays.asList(1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0); + List outVals = Arrays.asList(1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0, 36.0, 45.0, 55.0); + sum(inVals.iterator(),10,BoundarySpec.UNBOUNDED_AMOUNT,0, outVals.iterator()); + } + + @Test + public void test_0_5() throws HiveException { + List inVals = Arrays.asList(1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0); + List outVals = Arrays.asList(21.0, 27.0, 33.0, 39.0, 45.0, 40.0, 34.0, 27.0, 19.0, 10.0); + sum(inVals.iterator(),10,0,5, outVals.iterator()); + } + + @Test + public void test_unb_5() throws HiveException { + List inVals = Arrays.asList(1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0); + List outVals = Arrays.asList(21.0, 28.0, 36.0, 45.0, 55.0, 55.0, 55.0, 55.0, 55.0, 55.0); + sum(inVals.iterator(),10,BoundarySpec.UNBOUNDED_AMOUNT,5, outVals.iterator()); + } + + @Test + public void test_7_2() throws HiveException { + List inVals = Arrays.asList(1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0); + List outVals = Arrays.asList(6.0, 10.0, 15.0, 21.0, 28.0, 36.0, 45.0, 55.0, 54.0, 52.0); + sum(inVals.iterator(),10,7,2, outVals.iterator()); + } + + @Test + public void test_15_15() throws HiveException { + List inVals = Arrays.asList(1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0); + List outVals = Arrays.asList(55.0, 55.0, 55.0, 55.0, 55.0, 55.0, 55.0, 55.0, 55.0, 55.0); + sum(inVals.iterator(),10,15,15, outVals.iterator()); + } + +}
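
Note (illustration only, not part of the patch): the fixed-window streaming summation that GenericUDAFStreamingSum.GenericUDAFSumDouble implements can be shown in isolation. The sketch below uses no Hive APIs; the class RollingSumSketch and its methods are invented for this example, and it covers only bounded PRECEDING/FOLLOWING frames (the patch additionally skips the cumulative-sum bookkeeping for UNBOUNDED PRECEDING). It keeps a running total plus a list of prefix sums, emits a result for a row once numFollowing later rows have arrived, and subtracts the prefix sum that has slid past the window start; main reproduces the frame exercised by TestStreamingSum.test_3_4.

// Illustration only (not part of the patch): a self-contained sketch of the
// fixed-window streaming summation technique used by GenericUDAFStreamingSum.
// Class and method names are invented for this example; no Hive APIs are used.
import java.util.ArrayList;
import java.util.List;

public class RollingSumSketch {
  private final int numPreceding;
  private final int numFollowing;
  private double runningSum = 0.0;
  private int numRows = 0;
  // Prefix sums; the head is the one that falls out of the window next.
  private final List<Double> cumeSums = new ArrayList<Double>();
  // Window results that are ready to be emitted.
  private final List<Double> pending = new ArrayList<Double>();

  public RollingSumSketch(int numPreceding, int numFollowing) {
    this.numPreceding = numPreceding;
    this.numFollowing = numFollowing;
  }

  /** Feed one value; a result becomes available once numFollowing later rows arrive. */
  public void add(double v) {
    runningSum += v;
    if (numRows >= numFollowing) {
      // The frame for row (numRows - numFollowing) is now complete.
      double r = runningSum;
      if (numRows - numFollowing >= numPreceding + 1) {
        // Subtract the prefix sum of rows that precede the window start.
        r -= cumeSums.remove(0);
      }
      pending.add(r);
    }
    cumeSums.add(runningSum);
    numRows++;
  }

  /** At partition end, emit the trailing rows whose frames were cut off by the end. */
  public void finish() {
    for (int i = 0; i < numFollowing; i++) {
      double r = runningSum;
      if (numRows - numFollowing >= numPreceding + 1) {
        r -= cumeSums.remove(0);
      }
      pending.add(r);
    }
  }

  /** Next ready result, or null if none is pending. */
  public Double nextResult() {
    return pending.isEmpty() ? null : pending.remove(0);
  }

  public static void main(String[] args) {
    // sum(x) over (rows between 3 preceding and 4 following) on 1..10,
    // i.e. the frame checked by TestStreamingSum.test_3_4.
    RollingSumSketch s = new RollingSumSketch(3, 4);
    for (int i = 1; i <= 10; i++) {
      s.add(i);
    }
    s.finish();
    Double r;
    while ((r = s.nextResult()) != null) {
      System.out.print(r + " ");
    }
    // prints: 15.0 21.0 28.0 36.0 44.0 52.0 49.0 45.0 40.0 34.0
    System.out.println();
  }
}

The payoff of this scheme is that per-partition state stays proportional to the window size (cumeSums never holds more than numPreceding + numFollowing + 1 entries) rather than to the partition size, in contrast to the non-streaming path, which buffers the whole partition in a PTFRowContainer.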