diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/LLRowResolver.java ql/src/java/org/apache/hadoop/hive/ql/exec/LLRowResolver.java
new file mode 100644
index 0000000..2c05b50
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/LLRowResolver.java
@@ -0,0 +1,123 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec;
+
+import java.util.List;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.plan.ptf.PTFExpressionDef;
+import org.apache.hadoop.hive.ql.plan.ptf.WindowFunctionDef;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFLead;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFLeadLag;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+
+/**
+ * Resolves the value of a single LEAD/LAG window function for the current row
+ * of a PTFPartition, falling back to the function's default expression when
+ * the requested row falls outside the partition. A negative amount means LEAD
+ * (rows ahead of the current row); a positive amount means LAG (rows behind).
+ */
+public class LLRowResolver {
+
+  private final PTFExpressionDef valueExpr;
+  private final PTFExpressionDef defaultExpr;
+  private final ObjectInspectorConverters.Converter defaultValueConverter;
+  // negative = LEAD distance, positive = LAG distance
+  private final int amount;
+
+  public LLRowResolver(List<PTFExpressionDef> argExprs, int amount) {
+    this.valueExpr = argExprs.get(0);
+    // arg 2 (optional) is the default value returned at partition boundaries
+    this.defaultExpr = argExprs.size() > 2 ? argExprs.get(2) : null;
+    this.defaultValueConverter = defaultExpr != null ?
+        ObjectInspectorConverters.getConverter(defaultExpr.getOI(), valueExpr.getOI()) : null;
+    this.amount = amount;
+  }
+
+  /** Evaluates for the iterator's current row; copies the result to a standard object. */
+  public Object evaluateAndCopy(PTFPartition.PTFPartitionIterator<Object> iterator) throws HiveException {
+    Object evaluate = evaluate(iterator);
+    return evaluate != null ?
+        ObjectInspectorUtils.copyToStandardObject(evaluate, valueExpr.getOI()) : null;
+  }
+
+  private Object evaluate(PTFPartition.PTFPartitionIterator<Object> iterator) throws HiveException {
+    boolean hasValue = amount < 0 ? iterator.hasLead(-amount) : iterator.hasLag(amount);
+    if (hasValue) {
+      Object row = amount < 0 ?
+          iterator.lead(-amount) : iterator.lag(amount);
+      return valueExpr.getExprEvaluator().evaluate(row);
+    }
+    // out of partition range: use the default expression, if one was given
+    if (defaultExpr != null) {
+      Object evaluate = defaultExpr.getExprEvaluator().evaluate(iterator.current());
+      return defaultValueConverter.convert(evaluate);
+    }
+    return null;
+  }
+
+  /** True when the window function is a LEAD or LAG evaluator. */
+  public static boolean isLL(WindowFunctionDef function) {
+    return function.getWFnEval() instanceof GenericUDAFLeadLag.GenericUDAFLeadLagEvaluator;
+  }
+
+  /**
+   * Builds one resolver per window function, with null entries for functions
+   * that are not LEAD/LAG. Returns null when no function is LEAD/LAG so that
+   * callers can skip the row-resolving iterator entirely.
+   */
+  public static LLRowResolver[] toResolver(List<WindowFunctionDef> functions) {
+    // Allocated lazily: stays null unless at least one LEAD/LAG function exists.
+    LLRowResolver[] resolvers = null;
+    for (int i = 0; i < functions.size(); i++) {
+      WindowFunctionDef function = functions.get(i);
+      GenericUDAFEvaluator eval = function.getWFnEval();
+      if (!(eval instanceof GenericUDAFLeadLag.GenericUDAFLeadLagEvaluator)) {
+        continue;
+      }
+      if (resolvers == null) {
+        resolvers = new LLRowResolver[functions.size()];
+      }
+      int amt = ((GenericUDAFLeadLag.GenericUDAFLeadLagEvaluator) eval).getAmt();
+      if (eval instanceof GenericUDAFLead.GenericUDAFLeadEvaluator) {
+        amt = -amt;
+      }
+      resolvers[i] = new LLRowResolver(function.getArgs(), amt);
+    }
+    return resolvers;
+  }
+
+  /** Returns {max lag, min lead}; the lead bound is negative or zero. */
+  public static int[] toRange(LLRowResolver[] resolvers) {
+    int[] range = new int[2];
+    for (LLRowResolver window : resolvers) {
+      if (window == null) {
+        continue;
+      }
+      if (window.amount > 0) {
+        range[0] = Math.max(range[0], window.amount); // lag
+      } else if (window.amount < 0) {
+        range[1] = Math.min(range[1], window.amount); // lead
+      }
+    }
+    return range;
+  }
+}
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java
index 21d85f1..132fae8 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/PTFPartition.java
@@ -115,6 +115,10 @@ public int size() {
     return new PItr(0, size());
   }
 
+  public PTFPartitionIterator<Object> iterator(LLRowResolver[] resolvers) {
+    return new RowResolvingPItr(0, size(), resolvers);
+  }
+
   public PTFPartitionIterator<Object> range(int start, int end) {
     assert (start >= 0);
     assert (end <= size());
@@ -130,6 +134,33 @@ public void close() {
     }
   }
 
+  class RowResolvingPItr extends PItr {
+
+    final LLRowResolver[] resolvers;
+    final Object[] buffer;
+
+    public RowResolvingPItr(int start, int end, LLRowResolver[] resolvers) {
+      super(start, end);
+      this.resolvers = resolvers;
+      this.buffer = new Object[resolvers.length + 1];
+    }
+
+    @Override
+    public Object next() {
+      for (int i = 0; i < resolvers.length; i++) {
+        if (resolvers[i] != null) {
+          try {
+            buffer[i] = resolvers[i].evaluateAndCopy(this);
+          } catch (HiveException e) {
+            throw new RuntimeException(e);
+          }
+        }
+      }
+      buffer[resolvers.length] = super.next();
+      return buffer;
+    }
+  }
+
   class PItr implements PTFPartitionIterator<Object> {
     int idx;
     final int start;
@@ -180,6 +211,11 @@ private Object getAt(int i) throws HiveException {
     }
 
     @Override
+    public final boolean hasLead(int amt) {
+      return idx + amt < end;
+    }
+
+    @Override
     public Object lead(int amt) throws HiveException {
       int i = idx + amt;
       i = i >= end ? end - 1 : i;
@@ -187,6 +223,16 @@ public Object lead(int amt) throws HiveException {
     }
 
     @Override
+    public final boolean hasLag(int amt) {
+      return idx - amt >= start;
+    }
+
+    @Override
+    public Object current() throws HiveException {
+      return getAt(idx);
+    }
+
+    @Override
     public Object lag(int amt) throws HiveException {
       int i = idx - amt;
       i = i < start ? start : i;
@@ -212,7 +258,7 @@ public PTFPartition getPartition() {
     public void reset() {
       idx = start;
     }
-  };
+  }
 
   /*
    * provide an Iterator on the rows in a Partiton.
@@ -222,10 +268,16 @@ public void reset() {
   public static interface PTFPartitionIterator<T> extends Iterator<T> {
     int getIndex();
 
+    boolean hasLead(int amt);
+
+    boolean hasLag(int amt);
+
     T lead(int amt) throws HiveException;
 
     T lag(int amt) throws HiveException;
 
+    T current() throws HiveException;
+
     /*
      * after a lead and lag call, allow Object associated with SerDe and writable associated with
      * partition to be reset
diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/WindowingTableFunction.java ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/WindowingTableFunction.java
index 903a9b0..aa5dbcc 100644
--- ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/WindowingTableFunction.java
+++ ql/src/java/org/apache/hadoop/hive/ql/udf/ptf/WindowingTableFunction.java
@@ -29,6 +29,7 @@
 import org.apache.hadoop.hive.conf.HiveConf.ConfVars;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
+import org.apache.hadoop.hive.ql.exec.LLRowResolver;
 import org.apache.hadoop.hive.ql.exec.PTFOperator;
 import org.apache.hadoop.hive.ql.exec.PTFPartition;
 import org.apache.hadoop.hive.ql.exec.WindowFunctionInfo;
@@ -69,13 +70,20 @@ public void execute(PTFPartitionIterator<Object> pItr, PTFPartition outP) throws
       HiveException {
     ArrayList<List<?>> oColumns = new ArrayList<List<?>>();
     PTFPartition iPart = pItr.getPartition();
-    StructObjectInspector inputOI;
-    inputOI = (StructObjectInspector) iPart.getOutputOI();
+    StructObjectInspector inputOI = iPart.getOutputOI();
 
     WindowTableFunctionDef wTFnDef = (WindowTableFunctionDef) getTableDef();
     Order order = wTFnDef.getOrder().getExpressions().get(0).getOrder();
-    for(WindowFunctionDef wFn : wTFnDef.getWindowFunctions()) {
+    final List<WindowFunctionDef> functions = wTFnDef.getWindowFunctions();
+    final LLRowResolver[] resolvers = LLRowResolver.toResolver(functions);
+
+    for (int wi = 0; wi < functions.size(); wi++) {
+      if (resolvers != null && resolvers[wi] != null) {
+        oColumns.add(null);
+        continue;
+      }
+      WindowFunctionDef wFn = functions.get(wi);
       boolean processWindow = processWindow(wFn);
       pItr.reset();
       if ( !processWindow ) {
@@ -89,18 +97,41 @@ public void execute(PTFPartitionIterator<Object> pItr, PTFPartition outP) throws
       }
     }
 
+    PTFPartitionIterator<Object> iterator;
+    if (resolvers != null) {
+      iterator = iPart.iterator(resolvers);
+    } else {
+      iterator = iPart.iterator();
+    }
+
     /*
      * Output Columns in the following order
      * - the columns representing the output from Window Fns
      * - the input Rows columns
      */
-    for(int i=0; i < iPart.size(); i++) {
-      ArrayList oRow = new ArrayList();
-      Object iRow = iPart.getAt(i);
+    ArrayList<Object> oRow = new ArrayList<Object>();
+    while (iterator.hasNext()) {
+      int ir = iterator.getIndex();
 
-      for(int j=0; j < oColumns.size(); j++) {
-        oRow.add(oColumns.get(j).get(i));
+      for (List<?> columns : oColumns) {
+        if (columns != null) {
+          oRow.add(columns.get(ir));
+        } else {
+          oRow.add(null);
+        }
+      }
+      Object iRow = iterator.next();
+
+      if (resolvers != null) {
+        Object[] cached = (Object[]) iRow;
+        for (int iw = 0; iw < resolvers.length; iw++) {
+          if (resolvers[iw] != null) {
+            oRow.set(iw, cached[iw]);
+          }
+        }
+        iRow = cached[cached.length - 1];
       }
 
       for(StructField f : inputOI.getAllStructFieldRefs()) {
@@ -108,6 +139,7 @@ public void execute(PTFPartitionIterator<Object> pItr, PTFPartition outP) throws
       }
 
       outP.append(oRow);
+      oRow.clear();
     }
   }
 
@@ -514,7 +546,6 @@ public boolean canIterateOutput() {
       i++;
     }
 
-    i=0;
     for(i=0; i < iPart.getOutputOI().getAllStructFieldRefs().size(); i++) {
       output.add(null);
     }