diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java index 8c8603f..2fe8192 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java @@ -41,6 +41,7 @@ import org.apache.hadoop.hive.ql.exec.vector.expressions.SelectColumnIsTrue; import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorUDAFCountStar; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFAvgDouble; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFAvgLong; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFCountDouble; @@ -942,6 +943,7 @@ private String getOutputColType(String inputType, String method) { {"min", "Double", VectorUDAFMinDouble.class}, {"max", "Long", VectorUDAFMaxLong.class}, {"max", "Double", VectorUDAFMaxDouble.class}, + {"count", null, VectorUDAFCountStar.class}, {"count", "Long", VectorUDAFCountLong.class}, {"count", "Double", VectorUDAFCountDouble.class}, {"sum", "Long", VectorUDAFSumLong.class}, @@ -966,6 +968,7 @@ private String getOutputColType(String inputType, String method) { public VectorAggregateExpression getAggregatorExpression(AggregationDesc desc) throws HiveException { + ArrayList paramDescList = desc.getParameters(); VectorExpression[] vectorParams = new VectorExpression[paramDescList.size()]; @@ -975,22 +978,25 @@ public VectorAggregateExpression getAggregatorExpression(AggregationDesc desc) } String aggregateName = desc.getGenericUDAFName(); - List params = desc.getParameters(); - //TODO: handle length != 1 - assert (params.size() == 1); - ExprNodeDesc inputExpr = 
params.get(0); - String inputType = getNormalizedTypeName(inputExpr.getTypeString()); + String inputType = null; + + if (paramDescList.size() > 0) { + ExprNodeDesc inputExpr = paramDescList.get(0); + inputType = getNormalizedTypeName(inputExpr.getTypeString()); + } for (Object[] aggDef : aggregatesDefinition) { - if (aggDef[0].equals (aggregateName) && - aggDef[1].equals(inputType)) { + if (aggregateName.equalsIgnoreCase((String) aggDef[0]) && + ((aggDef[1] == null && inputType == null) || + (aggDef[1] != null && aggDef[1].equals(inputType)))) { Class aggClass = (Class) (aggDef[2]); try { Constructor ctor = aggClass.getConstructor(VectorExpression.class); - VectorAggregateExpression aggExpr = ctor.newInstance(vectorParams[0]); + VectorAggregateExpression aggExpr = ctor.newInstance( + vectorParams.length > 0 ? vectorParams[0] : null); return aggExpr; } // TODO: change to 1.7 syntax when possible diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFCountStar.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFCountStar.java new file mode 100644 index 0000000..607e3ad --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFCountStar.java @@ -0,0 +1,124 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.LongWritable; + +/** +* VectorUDAFCountStar. Vectorized implementation for COUNT(*) aggregates. +*/ +@Description(name = "count", value = "_FUNC_(expr) - Returns count(*) (vectorized)") +public class VectorUDAFCountStar extends VectorAggregateExpression { + + /** + * Class for storing the current aggregate value.
+ */ + static class Aggregation implements AggregationBuffer { + long value; + boolean isNull; + } + + private final LongWritable result; + + public VectorUDAFCountStar(VectorExpression inputExpression) { + super(); + result = new LongWritable(0); + } + + private Aggregation getCurrentAggregationBuffer( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + int row) { + VectorAggregationBufferRow mySet = aggregationBufferSets[row]; + Aggregation myagg = (Aggregation) mySet.getAggregationBuffer(aggregateIndex); + return myagg; + } + + @Override + public void aggregateInputSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + VectorizedRowBatch batch) throws HiveException { + + int batchSize = batch.size; + + if (batchSize == 0) { + return; + } + + // count(*) cares not about NULLs nor selection + for (int i=0; i < batchSize; ++i) { + Aggregation myAgg = getCurrentAggregationBuffer( + aggregationBufferSets, aggregateIndex, i); + myAgg.isNull = false; + ++myAgg.value; + } + } + + @Override + public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch) + throws HiveException { + + int batchSize = batch.size; + + if (batchSize == 0) { + return; + } + + Aggregation myagg = (Aggregation)agg; + myagg.isNull = false; + myagg.value += batchSize; + } + + @Override + public AggregationBuffer getNewAggregationBuffer() throws HiveException { + return new Aggregation(); + } + + @Override + public void reset(AggregationBuffer agg) throws HiveException { + Aggregation myAgg = (Aggregation) agg; + myAgg.isNull = true; + } + + @Override + public Object evaluateOutput(AggregationBuffer agg) throws HiveException { + Aggregation myagg = (Aggregation) agg; + if (myagg.isNull) { + return null; + } + else { + result.set (myagg.value); + return result; + } + } + + @Override + public ObjectInspector getOutputObjectInspector() { + return PrimitiveObjectInspectorFactory.writableLongObjectInspector; + } +} + diff --git 
ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java index f9c05cf..6fc230f 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java @@ -79,6 +79,14 @@ private static AggregationDesc buildAggregationDesc( return agg; } + private static AggregationDesc buildAggregationDescCountStar( + VectorizationContext ctx) { + AggregationDesc agg = new AggregationDesc(); + agg.setGenericUDAFName("COUNT"); + agg.setParameters(new ArrayList()); + return agg; + } + private static GroupByDesc buildGroupByDesc( VectorizationContext ctx, @@ -98,6 +106,23 @@ private static GroupByDesc buildGroupByDesc( return desc; } + private static GroupByDesc buildGroupByDescCountStar( + VectorizationContext ctx) { + + AggregationDesc agg = buildAggregationDescCountStar(ctx); + ArrayList aggs = new ArrayList(); + aggs.add(agg); + + ArrayList outputColumnNames = new ArrayList(); + outputColumnNames.add("_col0"); + + GroupByDesc desc = new GroupByDesc(); + desc.setOutputColumnNames(outputColumnNames); + desc.setAggregators(aggs); + + return desc; + } + private static GroupByDesc buildKeyGroupByDesc( VectorizationContext ctx, @@ -117,6 +142,14 @@ private static GroupByDesc buildKeyGroupByDesc( } @Test + public void testCountStar () throws HiveException { + testAggregateCountStar( + 2, + Arrays.asList(new Long[]{13L,null,7L,19L}), + 4L); + } + + @Test public void testMinLongNullStringKeys() throws HiveException { testAggregateStringKeyAggregate( "min", @@ -947,6 +980,17 @@ public void testAggregateLongAggregate ( testAggregateLongIterable (aggregateName, fdr, expected); } + public void testAggregateCountStar ( + int batchSize, + Iterable values, + Object expected) throws HiveException { + + @SuppressWarnings("unchecked") + FakeVectorRowBatchFromLongIterables fdr = new 
FakeVectorRowBatchFromLongIterables(batchSize, values); + testAggregateCountStarIterable (fdr, expected); + } + + public static interface Validator { void validate (Object expected, Object result); }; @@ -1086,6 +1130,35 @@ public static Validator getValidator(String aggregate) throws HiveException { throw new HiveException("Missing validator for aggregate: " + aggregate); } + public void testAggregateCountStarIterable ( + Iterable data, + Object expected) throws HiveException { + Map mapColumnNames = new HashMap(); + mapColumnNames.put("A", 0); + VectorizationContext ctx = new VectorizationContext(mapColumnNames, 1); + + GroupByDesc desc = buildGroupByDescCountStar (ctx); + + VectorGroupByOperator vgo = new VectorGroupByOperator(ctx, desc); + + FakeCaptureOutputOperator out = FakeCaptureOutputOperator.addCaptureOutputChild(vgo); + vgo.initialize(null, null); + + for (VectorizedRowBatch unit: data) { + vgo.process(unit, 0); + } + vgo.close(false); + + List outBatchList = out.getCapturedRows(); + assertNotNull(outBatchList); + assertEquals(1, outBatchList.size()); + + Object result = outBatchList.get(0); + + Validator validator = getValidator("count"); + validator.validate(expected, result); + } + public void testAggregateLongIterable ( String aggregateName, Iterable data,