diff --git itests/src/test/resources/testconfiguration.properties itests/src/test/resources/testconfiguration.properties index fd6901c..c4ba277 100644 --- itests/src/test/resources/testconfiguration.properties +++ itests/src/test/resources/testconfiguration.properties @@ -270,6 +270,7 @@ minitez.query.files.shared=acid_globallimit.q,\ vector_coalesce_2.q,\ vector_complex_all.q,\ vector_complex_join.q,\ + vector_count.q,\ vector_count_distinct.q,\ vector_data_types.q,\ vector_date_1.q,\ @@ -293,6 +294,8 @@ minitez.query.files.shared=acid_globallimit.q,\ vector_decimal_udf2.q,\ vector_distinct_2.q,\ vector_elt.q,\ + vector_groupby4.q,\ + vector_groupby6.q,\ vector_groupby_3.q,\ vector_groupby_mapjoin.q,\ vector_groupby_reduce.q,\ diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/AggregateDefinition.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/AggregateDefinition.java index 3f15c6f..0334c40 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/AggregateDefinition.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/AggregateDefinition.java @@ -20,19 +20,20 @@ import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression; import org.apache.hadoop.hive.ql.plan.GroupByDesc; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; class AggregateDefinition { private String name; private VectorExpressionDescriptor.ArgumentType type; - private GroupByDesc.Mode mode; + private GenericUDAFEvaluator.Mode udafEvaluatorMode; private Class aggClass; - AggregateDefinition(String name, VectorExpressionDescriptor.ArgumentType type, - GroupByDesc.Mode mode, Class aggClass) { + AggregateDefinition(String name, VectorExpressionDescriptor.ArgumentType type, + GenericUDAFEvaluator.Mode udafEvaluatorMode, Class aggClass) { this.name = name; this.type = type; - this.mode = mode; + this.udafEvaluatorMode = udafEvaluatorMode; this.aggClass = aggClass; } @@ -42,8 +43,8 @@ String getName() { 
VectorExpressionDescriptor.ArgumentType getType() { return type; } - GroupByDesc.Mode getMode() { - return mode; + GenericUDAFEvaluator.Mode getUdafEvaluatorMode() { + return udafEvaluatorMode; } Class getAggClass() { return aggClass; diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorGroupByOperator.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorGroupByOperator.java index 98a9bf6..6e53526 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorGroupByOperator.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorGroupByOperator.java @@ -41,6 +41,7 @@ import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; import org.apache.hadoop.hive.ql.plan.GroupByDesc; import org.apache.hadoop.hive.ql.plan.OperatorDesc; +import org.apache.hadoop.hive.ql.plan.VectorGroupByDesc.ProcessingMode; import org.apache.hadoop.hive.ql.plan.api.OperatorType; import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; @@ -51,6 +52,7 @@ import org.slf4j.LoggerFactory; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; /** * Vectorized GROUP BY operator implementation. Consumes the vectorized input and @@ -542,18 +544,17 @@ private void checkHashModeEfficiency() throws HiveException { if (numEntriesHashTable > sumBatchSize * minReductionHashAggr) { flush(true); - changeToUnsortedStreamingMode(); + changeToStreamingMode(); } } } } /** - * Unsorted streaming processing mode. Each input VectorizedRowBatch may have - * a mix of different keys (hence unsorted). Intermediate values are flushed - * each time key changes. + * Streaming processing mode on ALREADY GROUPED data. Each input VectorizedRowBatch may + * have a mix of different keys. Intermediate values are flushed each time key changes. 
*/ - private class ProcessingModeUnsortedStreaming extends ProcessingModeBase { + private class ProcessingModeStreaming extends ProcessingModeBase { /** * The aggregation buffers used in streaming mode @@ -675,7 +676,7 @@ public void close(boolean aborted) throws HiveException { * writeGroupRow does this and finally increments outputBatch.size. * */ - private class ProcessingModeReduceMergePartialKeys extends ProcessingModeBase { + private class ProcessingModeReduceMergePartial extends ProcessingModeBase { private boolean inGroup; private boolean first; @@ -761,8 +762,7 @@ public VectorGroupByOperator(CompilationOpContext ctx, aggregators = new VectorAggregateExpression[aggrDesc.size()]; for (int i = 0; i < aggrDesc.size(); ++i) { AggregationDesc aggDesc = aggrDesc.get(i); - aggregators[i] = - vContext.getAggregatorExpression(aggDesc, desc.getVectorDesc().isReduceMergePartial()); + aggregators[i] = vContext.getAggregatorExpression(aggDesc); } isVectorOutput = desc.getVectorDesc().isVectorOutput(); @@ -810,12 +810,10 @@ protected void initializeOp(Configuration hconf) throws HiveException { objectInspectors.add(aggregators[i].getOutputObjectInspector()); } - if (outputKeyLength > 0 && !conf.getVectorDesc().isReduceMergePartial()) { - // These data structures are only used by the non Reduce Merge-Partial Keys processing modes. 
- keyWrappersBatch = VectorHashKeyWrapperBatch.compileKeyWrapperBatch(keyExpressions); - aggregationBatchInfo = new VectorAggregationBufferBatch(); - aggregationBatchInfo.compileAggregationBatchInfo(aggregators); - } + keyWrappersBatch = VectorHashKeyWrapperBatch.compileKeyWrapperBatch(keyExpressions); + aggregationBatchInfo = new VectorAggregationBufferBatch(); + aggregationBatchInfo.compileAggregationBatchInfo(aggregators); + LOG.info("VectorGroupByOperator is vector output {}", isVectorOutput); outputObjInspector = ObjectInspectorFactory.getStandardStructObjectInspector( outputFieldNames, objectInspectors); @@ -835,29 +833,35 @@ protected void initializeOp(Configuration hconf) throws HiveException { forwardCache = new Object[outputKeyLength + aggregators.length]; - if (outputKeyLength == 0) { - // Hash and MergePartial global aggregation are both handled here. + switch (conf.getVectorDesc().getProcessingMode()) { + case GLOBAL: + Preconditions.checkState(outputKeyLength == 0); processingMode = this.new ProcessingModeGlobalAggregate(); - } else if (conf.getVectorDesc().isReduceMergePartial()) { - // Sorted GroupBy of vector batches where an individual batch has the same group key (e.g. reduce). - processingMode = this.new ProcessingModeReduceMergePartialKeys(); - } else if (conf.getVectorDesc().isReduceStreaming()) { - processingMode = this.new ProcessingModeUnsortedStreaming(); - } else { - // We start in hash mode and may dynamically switch to unsorted stream mode. 
+ break; + case HASH: processingMode = this.new ProcessingModeHashAggregate(); + break; + case MERGE_PARTIAL: + processingMode = this.new ProcessingModeReduceMergePartial(); + break; + case STREAMING: + processingMode = this.new ProcessingModeStreaming(); + break; + default: + throw new RuntimeException("Unsupported vector GROUP BY processing mode " + + conf.getVectorDesc().getProcessingMode().name()); } processingMode.initialize(hconf); } /** - * changes the processing mode to unsorted streaming + * changes the processing mode to streaming * This is done at the request of the hash agg mode, if the number of keys * exceeds the minReductionHashAggr factor * @throws HiveException */ - private void changeToUnsortedStreamingMode() throws HiveException { - processingMode = this.new ProcessingModeUnsortedStreaming(); + private void changeToStreamingMode() throws HiveException { + processingMode = this.new ProcessingModeStreaming(); processingMode.initialize(null); LOG.trace("switched to streaming mode"); } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java index 58ce063..1a3299b 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java @@ -47,8 +47,8 @@ import org.apache.hadoop.hive.ql.exec.FunctionRegistry; import org.apache.hadoop.hive.ql.exec.UDF; import org.apache.hadoop.hive.ql.exec.vector.ColumnVector.Type; +import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor.ArgumentType; import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor.InputExpressionType; -import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor.Mode; import org.apache.hadoop.hive.ql.exec.vector.expressions.*; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression; import 
org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorUDAFAvgDecimal; @@ -114,6 +114,7 @@ import org.apache.hadoop.hive.ql.udf.UDFToShort; import org.apache.hadoop.hive.ql.udf.UDFToString; import org.apache.hadoop.hive.ql.udf.generic.*; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode; import org.apache.hadoop.hive.serde2.ByteStream.Output; import org.apache.hadoop.hive.serde2.binarysortable.fast.BinarySortableSerializeWrite; import org.apache.hadoop.hive.serde2.io.DateWritable; @@ -404,7 +405,7 @@ public int allocateScratchColumn(String hiveTypeName) throws HiveException { } private VectorExpression getColumnVectorExpression(ExprNodeColumnDesc - exprDesc, Mode mode) throws HiveException { + exprDesc, VectorExpressionDescriptor.Mode mode) throws HiveException { int columnNum = getInputColumnIndex(exprDesc.getColumn()); VectorExpression expr = null; switch (mode) { @@ -425,7 +426,7 @@ private VectorExpression getColumnVectorExpression(ExprNodeColumnDesc // Ok, try the UDF. 
castToBooleanExpr = getVectorExpressionForUdf(null, UDFToBoolean.class, exprAsList, - Mode.PROJECTION, null); + VectorExpressionDescriptor.Mode.PROJECTION, null); if (castToBooleanExpr == null) { throw new HiveException("Cannot vectorize converting expression " + exprDesc.getExprString() + " to boolean"); @@ -443,10 +444,10 @@ private VectorExpression getColumnVectorExpression(ExprNodeColumnDesc } public VectorExpression[] getVectorExpressions(List exprNodes) throws HiveException { - return getVectorExpressions(exprNodes, Mode.PROJECTION); + return getVectorExpressions(exprNodes, VectorExpressionDescriptor.Mode.PROJECTION); } - public VectorExpression[] getVectorExpressions(List exprNodes, Mode mode) + public VectorExpression[] getVectorExpressions(List exprNodes, VectorExpressionDescriptor.Mode mode) throws HiveException { int i = 0; @@ -461,7 +462,7 @@ private VectorExpression getColumnVectorExpression(ExprNodeColumnDesc } public VectorExpression getVectorExpression(ExprNodeDesc exprDesc) throws HiveException { - return getVectorExpression(exprDesc, Mode.PROJECTION); + return getVectorExpression(exprDesc, VectorExpressionDescriptor.Mode.PROJECTION); } /** @@ -472,7 +473,7 @@ public VectorExpression getVectorExpression(ExprNodeDesc exprDesc) throws HiveEx * @return {@link VectorExpression} * @throws HiveException */ - public VectorExpression getVectorExpression(ExprNodeDesc exprDesc, Mode mode) throws HiveException { + public VectorExpression getVectorExpression(ExprNodeDesc exprDesc, VectorExpressionDescriptor.Mode mode) throws HiveException { VectorExpression ve = null; if (exprDesc instanceof ExprNodeColumnDesc) { ve = getColumnVectorExpression((ExprNodeColumnDesc) exprDesc, mode); @@ -873,14 +874,14 @@ ExprNodeDesc evaluateCastOnConstants(ExprNodeDesc exprDesc) throws HiveException } private VectorExpression getConstantVectorExpression(Object constantValue, TypeInfo typeInfo, - Mode mode) throws HiveException { + VectorExpressionDescriptor.Mode mode) throws 
HiveException { String typeName = typeInfo.getTypeName(); VectorExpressionDescriptor.ArgumentType vectorArgType = VectorExpressionDescriptor.ArgumentType.fromHiveTypeName(typeName); if (vectorArgType == VectorExpressionDescriptor.ArgumentType.NONE) { throw new HiveException("No vector argument type for type name " + typeName); } int outCol = -1; - if (mode == Mode.PROJECTION) { + if (mode == VectorExpressionDescriptor.Mode.PROJECTION) { outCol = ocm.allocateOutputColumn(typeName); } if (constantValue == null) { @@ -889,7 +890,7 @@ private VectorExpression getConstantVectorExpression(Object constantValue, TypeI // Boolean is special case. if (typeName.equalsIgnoreCase("boolean")) { - if (mode == Mode.FILTER) { + if (mode == VectorExpressionDescriptor.Mode.FILTER) { if (((Boolean) constantValue).booleanValue()) { return new FilterConstantBooleanVectorExpression(1); } else { @@ -961,7 +962,7 @@ private VectorExpression getIdentityExpression(List childExprList) } private VectorExpression getVectorExpressionForUdf(GenericUDF genericeUdf, - Class udfClass, List childExpr, Mode mode, + Class udfClass, List childExpr, VectorExpressionDescriptor.Mode mode, TypeInfo returnType) throws HiveException { int numChildren = (childExpr == null) ? 
0 : childExpr.size(); @@ -973,13 +974,13 @@ private VectorExpression getVectorExpressionForUdf(GenericUDF genericeUdf, Class vclass; if (genericeUdf instanceof GenericUDFOPOr) { - if (mode == Mode.PROJECTION) { + if (mode == VectorExpressionDescriptor.Mode.PROJECTION) { vclass = ColOrCol.class; } else { vclass = FilterExprOrExpr.class; } } else if (genericeUdf instanceof GenericUDFOPAnd) { - if (mode == Mode.PROJECTION) { + if (mode == VectorExpressionDescriptor.Mode.PROJECTION) { vclass = ColAndCol.class; } else { vclass = FilterExprAndExpr.class; @@ -987,8 +988,8 @@ private VectorExpression getVectorExpressionForUdf(GenericUDF genericeUdf, } else { throw new RuntimeException("Unexpected multi-child UDF"); } - Mode childrenMode = getChildrenMode(mode, udfClass); - if (mode == Mode.PROJECTION) { + VectorExpressionDescriptor.Mode childrenMode = getChildrenMode(mode, udfClass); + if (mode == VectorExpressionDescriptor.Mode.PROJECTION) { return createVectorMultiAndOrProjectionExpr(vclass, childExpr, childrenMode, returnType); } else { return createVectorExpression(vclass, childExpr, childrenMode, returnType); @@ -1027,12 +1028,12 @@ private VectorExpression getVectorExpressionForUdf(GenericUDF genericeUdf, } return null; } - Mode childrenMode = getChildrenMode(mode, udfClass); + VectorExpressionDescriptor.Mode childrenMode = getChildrenMode(mode, udfClass); return createVectorExpression(vclass, childExpr, childrenMode, returnType); } private void determineChildrenVectorExprAndArguments(Class vectorClass, - List childExpr, int numChildren, Mode childrenMode, + List childExpr, int numChildren, VectorExpressionDescriptor.Mode childrenMode, VectorExpression.Type [] inputTypes, List children, Object[] arguments) throws HiveException { for (int i = 0; i < numChildren; i++) { @@ -1048,7 +1049,7 @@ private void determineChildrenVectorExprAndArguments(Class vectorClass, arguments[i] = vChild.getOutputColumn(); } else if (child instanceof ExprNodeColumnDesc) { int colIndex = 
getInputColumnIndex((ExprNodeColumnDesc) child); - if (childrenMode == Mode.FILTER) { + if (childrenMode == VectorExpressionDescriptor.Mode.FILTER) { // In filter mode, the column must be a boolean children.add(new SelectColumnIsTrue(colIndex)); } @@ -1063,7 +1064,7 @@ private void determineChildrenVectorExprAndArguments(Class vectorClass, } private VectorExpression createVectorExpression(Class vectorClass, - List childExpr, Mode childrenMode, TypeInfo returnType) throws HiveException { + List childExpr, VectorExpressionDescriptor.Mode childrenMode, TypeInfo returnType) throws HiveException { int numChildren = childExpr == null ? 0: childExpr.size(); VectorExpression.Type [] inputTypes = new VectorExpression.Type[numChildren]; List children = new ArrayList(); @@ -1087,7 +1088,7 @@ private VectorExpression createVectorExpression(Class vectorClass, } private VectorExpression createVectorMultiAndOrProjectionExpr(Class vectorClass, - List childExpr, Mode childrenMode, TypeInfo returnType) throws HiveException { + List childExpr, VectorExpressionDescriptor.Mode childrenMode, TypeInfo returnType) throws HiveException { int numChildren = childExpr == null ? 
0: childExpr.size(); VectorExpression.Type [] inputTypes = new VectorExpression.Type[numChildren]; List children = new ArrayList(); @@ -1119,11 +1120,11 @@ private VectorExpression createVectorMultiAndOrProjectionExpr(Class vectorCla } } - private Mode getChildrenMode(Mode mode, Class udf) { - if (mode.equals(Mode.FILTER) && (udf.equals(GenericUDFOPAnd.class) || udf.equals(GenericUDFOPOr.class))) { - return Mode.FILTER; + private VectorExpressionDescriptor.Mode getChildrenMode(VectorExpressionDescriptor.Mode mode, Class udf) { + if (mode.equals(VectorExpressionDescriptor.Mode.FILTER) && (udf.equals(GenericUDFOPAnd.class) || udf.equals(GenericUDFOPOr.class))) { + return VectorExpressionDescriptor.Mode.FILTER; } - return Mode.PROJECTION; + return VectorExpressionDescriptor.Mode.PROJECTION; } private String getNewInstanceArgumentString(Object [] args) { @@ -1196,7 +1197,7 @@ private VectorExpression instantiateExpression(Class vclass, TypeInfo returnT } private VectorExpression getGenericUdfVectorExpression(GenericUDF udf, - List childExpr, Mode mode, TypeInfo returnType) throws HiveException { + List childExpr, VectorExpressionDescriptor.Mode mode, TypeInfo returnType) throws HiveException { List castedChildren = evaluateCastOnConstants(childExpr); childExpr = castedChildren; @@ -1204,7 +1205,7 @@ private VectorExpression getGenericUdfVectorExpression(GenericUDF udf, //First handle special cases. If one of the special case methods cannot handle it, // it returns null. 
VectorExpression ve = null; - if (udf instanceof GenericUDFBetween && mode == Mode.FILTER) { + if (udf instanceof GenericUDFBetween && mode == VectorExpressionDescriptor.Mode.FILTER) { ve = getBetweenFilterExpression(childExpr, mode, returnType); } else if (udf instanceof GenericUDFIn) { ve = getInExpression(childExpr, mode, returnType); @@ -1249,13 +1250,13 @@ private VectorExpression getGenericUdfVectorExpression(GenericUDF udf, } private VectorExpression getCastToTimestamp(GenericUDFTimestamp udf, - List childExpr, Mode mode, TypeInfo returnType) throws HiveException { + List childExpr, VectorExpressionDescriptor.Mode mode, TypeInfo returnType) throws HiveException { VectorExpression ve = getVectorExpressionForUdf(udf, udf.getClass(), childExpr, mode, returnType); // Replace with the milliseconds conversion if (!udf.isIntToTimestampInSeconds() && ve instanceof CastLongToTimestamp) { ve = createVectorExpression(CastMillisecondsLongToTimestamp.class, - childExpr, Mode.PROJECTION, returnType); + childExpr, VectorExpressionDescriptor.Mode.PROJECTION, returnType); } return ve; @@ -1266,7 +1267,7 @@ private VectorExpression getCoalesceExpression(List childExpr, Typ int[] inputColumns = new int[childExpr.size()]; VectorExpression[] vectorChildren = null; try { - vectorChildren = getVectorExpressions(childExpr, Mode.PROJECTION); + vectorChildren = getVectorExpressions(childExpr, VectorExpressionDescriptor.Mode.PROJECTION); int i = 0; for (VectorExpression ve : vectorChildren) { @@ -1293,7 +1294,7 @@ private VectorExpression getEltExpression(List childExpr, TypeInfo int[] inputColumns = new int[childExpr.size()]; VectorExpression[] vectorChildren = null; try { - vectorChildren = getVectorExpressions(childExpr, Mode.PROJECTION); + vectorChildren = getVectorExpressions(childExpr, VectorExpressionDescriptor.Mode.PROJECTION); int i = 0; for (VectorExpression ve : vectorChildren) { @@ -1363,7 +1364,7 @@ public static InConstantType 
getInConstantTypeFromPrimitiveCategory(PrimitiveCat } private VectorExpression getStructInExpression(List childExpr, ExprNodeDesc colExpr, - TypeInfo colTypeInfo, List inChildren, Mode mode, TypeInfo returnType) + TypeInfo colTypeInfo, List inChildren, VectorExpressionDescriptor.Mode mode, TypeInfo returnType) throws HiveException { VectorExpression expr = null; @@ -1484,9 +1485,9 @@ private VectorExpression getStructInExpression(List childExpr, Exp // generate the serialized keys of the batch. int scratchBytesCol = ocm.allocateOutputColumn("string"); - Class cl = (mode == Mode.FILTER ? FilterStructColumnInList.class : StructColumnInList.class); + Class cl = (mode == VectorExpressionDescriptor.Mode.FILTER ? FilterStructColumnInList.class : StructColumnInList.class); - expr = createVectorExpression(cl, null, Mode.PROJECTION, returnType); + expr = createVectorExpression(cl, null, VectorExpressionDescriptor.Mode.PROJECTION, returnType); ((IStringInExpr) expr).setInListValues(serializedInChildren); @@ -1500,7 +1501,7 @@ private VectorExpression getStructInExpression(List childExpr, Exp /** * Create a filter or boolean-valued expression for column IN ( ) */ - private VectorExpression getInExpression(List childExpr, Mode mode, TypeInfo returnType) + private VectorExpression getInExpression(List childExpr, VectorExpressionDescriptor.Mode mode, TypeInfo returnType) throws HiveException { ExprNodeDesc colExpr = childExpr.get(0); List inChildren = childExpr.subList(1, childExpr.size()); @@ -1538,53 +1539,53 @@ private VectorExpression getInExpression(List childExpr, Mode mode // determine class Class cl = null; if (isIntFamily(colType)) { - cl = (mode == Mode.FILTER ? FilterLongColumnInList.class : LongColumnInList.class); + cl = (mode == VectorExpressionDescriptor.Mode.FILTER ? 
FilterLongColumnInList.class : LongColumnInList.class); long[] inVals = new long[childrenForInList.size()]; for (int i = 0; i != inVals.length; i++) { inVals[i] = getIntFamilyScalarAsLong((ExprNodeConstantDesc) childrenForInList.get(i)); } - expr = createVectorExpression(cl, childExpr.subList(0, 1), Mode.PROJECTION, returnType); + expr = createVectorExpression(cl, childExpr.subList(0, 1), VectorExpressionDescriptor.Mode.PROJECTION, returnType); ((ILongInExpr) expr).setInListValues(inVals); } else if (isTimestampFamily(colType)) { - cl = (mode == Mode.FILTER ? FilterTimestampColumnInList.class : TimestampColumnInList.class); + cl = (mode == VectorExpressionDescriptor.Mode.FILTER ? FilterTimestampColumnInList.class : TimestampColumnInList.class); Timestamp[] inVals = new Timestamp[childrenForInList.size()]; for (int i = 0; i != inVals.length; i++) { inVals[i] = getTimestampScalar(childrenForInList.get(i)); } - expr = createVectorExpression(cl, childExpr.subList(0, 1), Mode.PROJECTION, returnType); + expr = createVectorExpression(cl, childExpr.subList(0, 1), VectorExpressionDescriptor.Mode.PROJECTION, returnType); ((ITimestampInExpr) expr).setInListValues(inVals); } else if (isStringFamily(colType)) { - cl = (mode == Mode.FILTER ? FilterStringColumnInList.class : StringColumnInList.class); + cl = (mode == VectorExpressionDescriptor.Mode.FILTER ? FilterStringColumnInList.class : StringColumnInList.class); byte[][] inVals = new byte[childrenForInList.size()][]; for (int i = 0; i != inVals.length; i++) { inVals[i] = getStringScalarAsByteArray((ExprNodeConstantDesc) childrenForInList.get(i)); } - expr = createVectorExpression(cl, childExpr.subList(0, 1), Mode.PROJECTION, returnType); + expr = createVectorExpression(cl, childExpr.subList(0, 1), VectorExpressionDescriptor.Mode.PROJECTION, returnType); ((IStringInExpr) expr).setInListValues(inVals); } else if (isFloatFamily(colType)) { - cl = (mode == Mode.FILTER ? 
FilterDoubleColumnInList.class : DoubleColumnInList.class); + cl = (mode == VectorExpressionDescriptor.Mode.FILTER ? FilterDoubleColumnInList.class : DoubleColumnInList.class); double[] inValsD = new double[childrenForInList.size()]; for (int i = 0; i != inValsD.length; i++) { inValsD[i] = getNumericScalarAsDouble(childrenForInList.get(i)); } - expr = createVectorExpression(cl, childExpr.subList(0, 1), Mode.PROJECTION, returnType); + expr = createVectorExpression(cl, childExpr.subList(0, 1), VectorExpressionDescriptor.Mode.PROJECTION, returnType); ((IDoubleInExpr) expr).setInListValues(inValsD); } else if (isDecimalFamily(colType)) { - cl = (mode == Mode.FILTER ? FilterDecimalColumnInList.class : DecimalColumnInList.class); + cl = (mode == VectorExpressionDescriptor.Mode.FILTER ? FilterDecimalColumnInList.class : DecimalColumnInList.class); HiveDecimal[] inValsD = new HiveDecimal[childrenForInList.size()]; for (int i = 0; i != inValsD.length; i++) { inValsD[i] = (HiveDecimal) getVectorTypeScalarValue( (ExprNodeConstantDesc) childrenForInList.get(i)); } - expr = createVectorExpression(cl, childExpr.subList(0, 1), Mode.PROJECTION, returnType); + expr = createVectorExpression(cl, childExpr.subList(0, 1), VectorExpressionDescriptor.Mode.PROJECTION, returnType); ((IDecimalInExpr) expr).setInListValues(inValsD); } else if (isDateFamily(colType)) { - cl = (mode == Mode.FILTER ? FilterLongColumnInList.class : LongColumnInList.class); + cl = (mode == VectorExpressionDescriptor.Mode.FILTER ? 
FilterLongColumnInList.class : LongColumnInList.class); long[] inVals = new long[childrenForInList.size()]; for (int i = 0; i != inVals.length; i++) { inVals[i] = (Integer) getVectorTypeScalarValue((ExprNodeConstantDesc) childrenForInList.get(i)); } - expr = createVectorExpression(cl, childExpr.subList(0, 1), Mode.PROJECTION, returnType); + expr = createVectorExpression(cl, childExpr.subList(0, 1), VectorExpressionDescriptor.Mode.PROJECTION, returnType); ((ILongInExpr) expr).setInListValues(inVals); } @@ -1607,7 +1608,7 @@ private VectorExpression getInExpression(List childExpr, Mode mode * descriptor based lookup. */ private VectorExpression getGenericUDFBridgeVectorExpression(GenericUDFBridge udf, - List childExpr, Mode mode, TypeInfo returnType) throws HiveException { + List childExpr, VectorExpressionDescriptor.Mode mode, TypeInfo returnType) throws HiveException { Class cl = udf.getUdfClass(); VectorExpression ve = null; if (isCastToIntFamily(cl)) { @@ -1741,21 +1742,21 @@ private VectorExpression getCastToDecimal(List childExpr, TypeInfo // Return a constant vector expression Object constantValue = ((ExprNodeConstantDesc) child).getValue(); HiveDecimal decimalValue = castConstantToDecimal(constantValue, child.getTypeInfo()); - return getConstantVectorExpression(decimalValue, returnType, Mode.PROJECTION); + return getConstantVectorExpression(decimalValue, returnType, VectorExpressionDescriptor.Mode.PROJECTION); } if (isIntFamily(inputType)) { - return createVectorExpression(CastLongToDecimal.class, childExpr, Mode.PROJECTION, returnType); + return createVectorExpression(CastLongToDecimal.class, childExpr, VectorExpressionDescriptor.Mode.PROJECTION, returnType); } else if (isFloatFamily(inputType)) { - return createVectorExpression(CastDoubleToDecimal.class, childExpr, Mode.PROJECTION, returnType); + return createVectorExpression(CastDoubleToDecimal.class, childExpr, VectorExpressionDescriptor.Mode.PROJECTION, returnType); } else if 
(decimalTypePattern.matcher(inputType).matches()) { - return createVectorExpression(CastDecimalToDecimal.class, childExpr, Mode.PROJECTION, + return createVectorExpression(CastDecimalToDecimal.class, childExpr, VectorExpressionDescriptor.Mode.PROJECTION, returnType); } else if (isStringFamily(inputType)) { - return createVectorExpression(CastStringToDecimal.class, childExpr, Mode.PROJECTION, returnType); + return createVectorExpression(CastStringToDecimal.class, childExpr, VectorExpressionDescriptor.Mode.PROJECTION, returnType); } else if (inputType.equals("timestamp")) { - return createVectorExpression(CastTimestampToDecimal.class, childExpr, Mode.PROJECTION, returnType); + return createVectorExpression(CastTimestampToDecimal.class, childExpr, VectorExpressionDescriptor.Mode.PROJECTION, returnType); } - throw null; + return null; } private VectorExpression getCastToString(List childExpr, TypeInfo returnType) @@ -1766,19 +1767,19 @@ private VectorExpression getCastToString(List childExpr, TypeInfo // Return a constant vector expression Object constantValue = ((ExprNodeConstantDesc) child).getValue(); String strValue = castConstantToString(constantValue, child.getTypeInfo()); - return getConstantVectorExpression(strValue, returnType, Mode.PROJECTION); + return getConstantVectorExpression(strValue, returnType, VectorExpressionDescriptor.Mode.PROJECTION); } if (inputType.equals("boolean")) { // Boolean must come before the integer family. It's a special case. 
- return createVectorExpression(CastBooleanToStringViaLongToString.class, childExpr, Mode.PROJECTION, null); + return createVectorExpression(CastBooleanToStringViaLongToString.class, childExpr, VectorExpressionDescriptor.Mode.PROJECTION, null); } else if (isIntFamily(inputType)) { - return createVectorExpression(CastLongToString.class, childExpr, Mode.PROJECTION, returnType); + return createVectorExpression(CastLongToString.class, childExpr, VectorExpressionDescriptor.Mode.PROJECTION, returnType); } else if (isDecimalFamily(inputType)) { - return createVectorExpression(CastDecimalToString.class, childExpr, Mode.PROJECTION, returnType); + return createVectorExpression(CastDecimalToString.class, childExpr, VectorExpressionDescriptor.Mode.PROJECTION, returnType); } else if (isDateFamily(inputType)) { - return createVectorExpression(CastDateToString.class, childExpr, Mode.PROJECTION, returnType); + return createVectorExpression(CastDateToString.class, childExpr, VectorExpressionDescriptor.Mode.PROJECTION, returnType); } else if (isStringFamily(inputType)) { - return createVectorExpression(CastStringGroupToString.class, childExpr, Mode.PROJECTION, returnType); + return createVectorExpression(CastStringGroupToString.class, childExpr, VectorExpressionDescriptor.Mode.PROJECTION, returnType); } return null; } @@ -1794,15 +1795,15 @@ private VectorExpression getCastToChar(List childExpr, TypeInfo re } if (inputType.equals("boolean")) { // Boolean must come before the integer family. It's a special case. 
- return createVectorExpression(CastBooleanToCharViaLongToChar.class, childExpr, Mode.PROJECTION, returnType); + return createVectorExpression(CastBooleanToCharViaLongToChar.class, childExpr, VectorExpressionDescriptor.Mode.PROJECTION, returnType); } else if (isIntFamily(inputType)) { - return createVectorExpression(CastLongToChar.class, childExpr, Mode.PROJECTION, returnType); + return createVectorExpression(CastLongToChar.class, childExpr, VectorExpressionDescriptor.Mode.PROJECTION, returnType); } else if (isDecimalFamily(inputType)) { - return createVectorExpression(CastDecimalToChar.class, childExpr, Mode.PROJECTION, returnType); + return createVectorExpression(CastDecimalToChar.class, childExpr, VectorExpressionDescriptor.Mode.PROJECTION, returnType); } else if (isDateFamily(inputType)) { - return createVectorExpression(CastDateToChar.class, childExpr, Mode.PROJECTION, returnType); + return createVectorExpression(CastDateToChar.class, childExpr, VectorExpressionDescriptor.Mode.PROJECTION, returnType); } else if (isStringFamily(inputType)) { - return createVectorExpression(CastStringGroupToChar.class, childExpr, Mode.PROJECTION, returnType); + return createVectorExpression(CastStringGroupToChar.class, childExpr, VectorExpressionDescriptor.Mode.PROJECTION, returnType); } return null; } @@ -1818,15 +1819,15 @@ private VectorExpression getCastToVarChar(List childExpr, TypeInfo } if (inputType.equals("boolean")) { // Boolean must come before the integer family. It's a special case. 
- return createVectorExpression(CastBooleanToVarCharViaLongToVarChar.class, childExpr, Mode.PROJECTION, returnType); + return createVectorExpression(CastBooleanToVarCharViaLongToVarChar.class, childExpr, VectorExpressionDescriptor.Mode.PROJECTION, returnType); } else if (isIntFamily(inputType)) { - return createVectorExpression(CastLongToVarChar.class, childExpr, Mode.PROJECTION, returnType); + return createVectorExpression(CastLongToVarChar.class, childExpr, VectorExpressionDescriptor.Mode.PROJECTION, returnType); } else if (isDecimalFamily(inputType)) { - return createVectorExpression(CastDecimalToVarChar.class, childExpr, Mode.PROJECTION, returnType); + return createVectorExpression(CastDecimalToVarChar.class, childExpr, VectorExpressionDescriptor.Mode.PROJECTION, returnType); } else if (isDateFamily(inputType)) { - return createVectorExpression(CastDateToVarChar.class, childExpr, Mode.PROJECTION, returnType); + return createVectorExpression(CastDateToVarChar.class, childExpr, VectorExpressionDescriptor.Mode.PROJECTION, returnType); } else if (isStringFamily(inputType)) { - return createVectorExpression(CastStringGroupToVarChar.class, childExpr, Mode.PROJECTION, returnType); + return createVectorExpression(CastStringGroupToVarChar.class, childExpr, VectorExpressionDescriptor.Mode.PROJECTION, returnType); } return null; } @@ -1839,17 +1840,17 @@ private VectorExpression getCastToDoubleExpression(Class udf, List childExpr) String inputType = childExpr.get(0).getTypeString(); if (child instanceof ExprNodeConstantDesc) { if (null == ((ExprNodeConstantDesc)child).getValue()) { - return getConstantVectorExpression(null, TypeInfoFactory.booleanTypeInfo, Mode.PROJECTION); + return getConstantVectorExpression(null, TypeInfoFactory.booleanTypeInfo, VectorExpressionDescriptor.Mode.PROJECTION); } // Don't do constant folding here. Wait until the optimizer is changed to do it. // Family of related JIRAs: HIVE-7421, HIVE-7422, and HIVE-7424. 
@@ -1875,7 +1876,7 @@ private VectorExpression getCastToBoolean(List childExpr) if (isStringFamily(inputType)) { // string casts to false if it is 0 characters long, otherwise true VectorExpression lenExpr = createVectorExpression(StringLength.class, childExpr, - Mode.PROJECTION, null); + VectorExpressionDescriptor.Mode.PROJECTION, null); int outputCol = ocm.allocateOutputColumn("Long"); VectorExpression lenToBoolExpr = @@ -1895,7 +1896,7 @@ private VectorExpression getCastToLongExpression(List childExpr) // Return a constant vector expression Object constantValue = ((ExprNodeConstantDesc) child).getValue(); Long longValue = castConstantToLong(constantValue, child.getTypeInfo()); - return getConstantVectorExpression(longValue, TypeInfoFactory.longTypeInfo, Mode.PROJECTION); + return getConstantVectorExpression(longValue, TypeInfoFactory.longTypeInfo, VectorExpressionDescriptor.Mode.PROJECTION); } // Float family, timestamp are handled via descriptor based lookup, int family needs // special handling. @@ -1912,10 +1913,10 @@ private VectorExpression getCastToLongExpression(List childExpr) * needs to be done differently than the standard way where all arguments are * passed to the VectorExpression constructor. */ - private VectorExpression getBetweenFilterExpression(List childExpr, Mode mode, TypeInfo returnType) + private VectorExpression getBetweenFilterExpression(List childExpr, VectorExpressionDescriptor.Mode mode, TypeInfo returnType) throws HiveException { - if (mode == Mode.PROJECTION) { + if (mode == VectorExpressionDescriptor.Mode.PROJECTION) { // Projection mode is not yet supported for [NOT] BETWEEN. Return null so Vectorizer // knows to revert to row-at-a-time execution. 
@@ -2000,17 +2001,17 @@ private VectorExpression getBetweenFilterExpression(List childExpr } else if (isDateFamily(colType) && notKeywordPresent) { cl = FilterLongColumnNotBetween.class; } - return createVectorExpression(cl, childrenAfterNot, Mode.PROJECTION, returnType); + return createVectorExpression(cl, childrenAfterNot, VectorExpressionDescriptor.Mode.PROJECTION, returnType); } /* * Return vector expression for a custom (i.e. not built-in) UDF. */ - private VectorExpression getCustomUDFExpression(ExprNodeGenericFuncDesc expr, Mode mode) + private VectorExpression getCustomUDFExpression(ExprNodeGenericFuncDesc expr, VectorExpressionDescriptor.Mode mode) throws HiveException { boolean isFilter = false; // Assume. - if (mode == Mode.FILTER) { + if (mode == VectorExpressionDescriptor.Mode.FILTER) { // Is output type a BOOLEAN? TypeInfo resultTypeInfo = expr.getTypeInfo(); @@ -2043,7 +2044,7 @@ private VectorExpression getCustomUDFExpression(ExprNodeGenericFuncDesc expr, M for (int i = 0; i < childExprList.size(); i++) { ExprNodeDesc child = childExprList.get(i); if (child instanceof ExprNodeGenericFuncDesc) { - VectorExpression e = getVectorExpression(child, Mode.PROJECTION); + VectorExpression e = getVectorExpression(child, VectorExpressionDescriptor.Mode.PROJECTION); vectorExprs.add(e); variableArgPositions.add(i); exprResultColumnNums.add(e.getOutputColumn()); @@ -2384,67 +2385,125 @@ public static String mapTypeNameSynonyms(String typeName) { } } - // TODO: When we support vectorized STRUCTs and can handle more in the reduce-side (MERGEPARTIAL): - // TODO: Write reduce-side versions of AVG. Currently, only map-side (HASH) versions are in table. - // TODO: And, investigate if different reduce-side versions are needed for var* and std*, or if map-side aggregate can be used.. Right now they are conservatively - // marked map-side (HASH). + + /* + * In the aggregatesDefinition table, Mode is GenericUDAFEvaluator.Mode. 
+ * + * It enumerates the different modes of a UDAF (User Defined Aggregation Function). + * + * (Notice that these names are a subset of GroupByDesc.Mode...) + * + * PARTIAL1 Original data --> Partial aggregation data + * + * PARTIAL2 Partial aggregation data --> Partial aggregation data + * + * FINAL Partial aggregation data --> Full aggregation data + * + * COMPLETE Original data --> Full aggregation data + * + * + * SIMPLEST CASE --> The data type/semantics of original data, partial aggregation + * data, and full aggregation data ARE THE SAME. E.g. MIN, MAX, SUM. The different + * modes can be handled by one aggregation class. + * + * This case has a null for the Mode. + * + * FOR OTHERS --> The data type/semantics of partial aggregation data and full aggregation data + * ARE THE SAME but different than original data. This results in 2 aggregation classes: + * + * 1) A class that takes original rows and outputs partial/full aggregation + * (PARTIAL1/COMPLETE) + * + * and + * + * 2) A class that takes partial aggregation and produces full aggregation + * (PARTIAL2/FINAL). + * + * E.g. COUNT(*) and COUNT(column) + * + * OTHERWISE FULL --> The data type/semantics of partial aggregation data is different than + * original data and full aggregation data. + * + * E.g. AVG uses a STRUCT with count and sum for partial aggregation data. It divides + * sum by count to produce the average for final aggregation.
+ * + */ static ArrayList aggregatesDefinition = new ArrayList() {{ - add(new AggregateDefinition("min", VectorExpressionDescriptor.ArgumentType.INT_DATE_INTERVAL_YEAR_MONTH, null, VectorUDAFMinLong.class)); - add(new AggregateDefinition("min", VectorExpressionDescriptor.ArgumentType.FLOAT_FAMILY, null, VectorUDAFMinDouble.class)); - add(new AggregateDefinition("min", VectorExpressionDescriptor.ArgumentType.STRING_FAMILY, null, VectorUDAFMinString.class)); - add(new AggregateDefinition("min", VectorExpressionDescriptor.ArgumentType.DECIMAL, null, VectorUDAFMinDecimal.class)); - add(new AggregateDefinition("min", VectorExpressionDescriptor.ArgumentType.TIMESTAMP, null, VectorUDAFMinTimestamp.class)); - add(new AggregateDefinition("max", VectorExpressionDescriptor.ArgumentType.INT_DATE_INTERVAL_YEAR_MONTH, null, VectorUDAFMaxLong.class)); - add(new AggregateDefinition("max", VectorExpressionDescriptor.ArgumentType.FLOAT_FAMILY, null, VectorUDAFMaxDouble.class)); - add(new AggregateDefinition("max", VectorExpressionDescriptor.ArgumentType.STRING_FAMILY, null, VectorUDAFMaxString.class)); - add(new AggregateDefinition("max", VectorExpressionDescriptor.ArgumentType.DECIMAL, null, VectorUDAFMaxDecimal.class)); - add(new AggregateDefinition("max", VectorExpressionDescriptor.ArgumentType.TIMESTAMP, null, VectorUDAFMaxTimestamp.class)); - add(new AggregateDefinition("count", VectorExpressionDescriptor.ArgumentType.NONE, GroupByDesc.Mode.HASH, VectorUDAFCountStar.class)); - add(new AggregateDefinition("count", VectorExpressionDescriptor.ArgumentType.INT_DATE_INTERVAL_YEAR_MONTH, GroupByDesc.Mode.HASH, VectorUDAFCount.class)); - add(new AggregateDefinition("count", VectorExpressionDescriptor.ArgumentType.INT_FAMILY, GroupByDesc.Mode.MERGEPARTIAL, VectorUDAFCountMerge.class)); - add(new AggregateDefinition("count", VectorExpressionDescriptor.ArgumentType.FLOAT_FAMILY, GroupByDesc.Mode.HASH, VectorUDAFCount.class)); - add(new AggregateDefinition("count", 
VectorExpressionDescriptor.ArgumentType.STRING_FAMILY, GroupByDesc.Mode.HASH, VectorUDAFCount.class)); - add(new AggregateDefinition("count", VectorExpressionDescriptor.ArgumentType.DECIMAL, GroupByDesc.Mode.HASH, VectorUDAFCount.class)); - add(new AggregateDefinition("count", VectorExpressionDescriptor.ArgumentType.TIMESTAMP, GroupByDesc.Mode.HASH, VectorUDAFCount.class)); - add(new AggregateDefinition("count", VectorExpressionDescriptor.ArgumentType.INTERVAL_DAY_TIME, GroupByDesc.Mode.HASH, VectorUDAFCount.class)); - add(new AggregateDefinition("sum", VectorExpressionDescriptor.ArgumentType.INT_FAMILY, null, VectorUDAFSumLong.class)); - add(new AggregateDefinition("sum", VectorExpressionDescriptor.ArgumentType.FLOAT_FAMILY, null, VectorUDAFSumDouble.class)); - add(new AggregateDefinition("sum", VectorExpressionDescriptor.ArgumentType.DECIMAL, null, VectorUDAFSumDecimal.class)); - add(new AggregateDefinition("avg", VectorExpressionDescriptor.ArgumentType.INT_FAMILY, GroupByDesc.Mode.HASH, VectorUDAFAvgLong.class)); - add(new AggregateDefinition("avg", VectorExpressionDescriptor.ArgumentType.FLOAT_FAMILY, GroupByDesc.Mode.HASH, VectorUDAFAvgDouble.class)); - add(new AggregateDefinition("avg", VectorExpressionDescriptor.ArgumentType.DECIMAL, GroupByDesc.Mode.HASH, VectorUDAFAvgDecimal.class)); - add(new AggregateDefinition("avg", VectorExpressionDescriptor.ArgumentType.TIMESTAMP, GroupByDesc.Mode.HASH, VectorUDAFAvgTimestamp.class)); - add(new AggregateDefinition("variance", VectorExpressionDescriptor.ArgumentType.INT_FAMILY, GroupByDesc.Mode.HASH, VectorUDAFVarPopLong.class)); - add(new AggregateDefinition("var_pop", VectorExpressionDescriptor.ArgumentType.INT_FAMILY, GroupByDesc.Mode.HASH, VectorUDAFVarPopLong.class)); - add(new AggregateDefinition("variance", VectorExpressionDescriptor.ArgumentType.FLOAT_FAMILY, GroupByDesc.Mode.HASH, VectorUDAFVarPopDouble.class)); - add(new AggregateDefinition("var_pop", VectorExpressionDescriptor.ArgumentType.FLOAT_FAMILY, 
GroupByDesc.Mode.HASH, VectorUDAFVarPopDouble.class)); - add(new AggregateDefinition("variance", VectorExpressionDescriptor.ArgumentType.DECIMAL, GroupByDesc.Mode.HASH, VectorUDAFVarPopDecimal.class)); - add(new AggregateDefinition("var_pop", VectorExpressionDescriptor.ArgumentType.DECIMAL, GroupByDesc.Mode.HASH, VectorUDAFVarPopDecimal.class)); - add(new AggregateDefinition("variance", VectorExpressionDescriptor.ArgumentType.TIMESTAMP, GroupByDesc.Mode.HASH, VectorUDAFVarPopTimestamp.class)); - add(new AggregateDefinition("var_pop", VectorExpressionDescriptor.ArgumentType.TIMESTAMP, GroupByDesc.Mode.HASH, VectorUDAFVarPopTimestamp.class)); - add(new AggregateDefinition("var_samp", VectorExpressionDescriptor.ArgumentType.INT_FAMILY, GroupByDesc.Mode.HASH, VectorUDAFVarSampLong.class)); - add(new AggregateDefinition("var_samp" , VectorExpressionDescriptor.ArgumentType.FLOAT_FAMILY, GroupByDesc.Mode.HASH, VectorUDAFVarSampDouble.class)); - add(new AggregateDefinition("var_samp" , VectorExpressionDescriptor.ArgumentType.DECIMAL, GroupByDesc.Mode.HASH, VectorUDAFVarSampDecimal.class)); - add(new AggregateDefinition("var_samp" , VectorExpressionDescriptor.ArgumentType.TIMESTAMP, GroupByDesc.Mode.HASH, VectorUDAFVarSampTimestamp.class)); - add(new AggregateDefinition("std", VectorExpressionDescriptor.ArgumentType.INT_FAMILY, GroupByDesc.Mode.HASH, VectorUDAFStdPopLong.class)); - add(new AggregateDefinition("stddev", VectorExpressionDescriptor.ArgumentType.INT_FAMILY, GroupByDesc.Mode.HASH, VectorUDAFStdPopLong.class)); - add(new AggregateDefinition("stddev_pop", VectorExpressionDescriptor.ArgumentType.INT_FAMILY, GroupByDesc.Mode.HASH, VectorUDAFStdPopLong.class)); - add(new AggregateDefinition("std", VectorExpressionDescriptor.ArgumentType.FLOAT_FAMILY, GroupByDesc.Mode.HASH, VectorUDAFStdPopDouble.class)); - add(new AggregateDefinition("stddev", VectorExpressionDescriptor.ArgumentType.FLOAT_FAMILY, GroupByDesc.Mode.HASH, VectorUDAFStdPopDouble.class)); - add(new 
AggregateDefinition("stddev_pop", VectorExpressionDescriptor.ArgumentType.FLOAT_FAMILY, GroupByDesc.Mode.HASH, VectorUDAFStdPopDouble.class)); - add(new AggregateDefinition("std", VectorExpressionDescriptor.ArgumentType.DECIMAL, GroupByDesc.Mode.HASH, VectorUDAFStdPopDecimal.class)); - add(new AggregateDefinition("stddev", VectorExpressionDescriptor.ArgumentType.DECIMAL, GroupByDesc.Mode.HASH, VectorUDAFStdPopDecimal.class)); - add(new AggregateDefinition("stddev_pop", VectorExpressionDescriptor.ArgumentType.DECIMAL, GroupByDesc.Mode.HASH, VectorUDAFStdPopDecimal.class)); - add(new AggregateDefinition("std", VectorExpressionDescriptor.ArgumentType.TIMESTAMP, GroupByDesc.Mode.HASH, VectorUDAFStdPopTimestamp.class)); - add(new AggregateDefinition("stddev", VectorExpressionDescriptor.ArgumentType.TIMESTAMP, GroupByDesc.Mode.HASH, VectorUDAFStdPopTimestamp.class)); - add(new AggregateDefinition("stddev_pop", VectorExpressionDescriptor.ArgumentType.TIMESTAMP, GroupByDesc.Mode.HASH, VectorUDAFStdPopTimestamp.class)); - add(new AggregateDefinition("stddev_samp", VectorExpressionDescriptor.ArgumentType.INT_FAMILY, GroupByDesc.Mode.HASH, VectorUDAFStdSampLong.class)); - add(new AggregateDefinition("stddev_samp", VectorExpressionDescriptor.ArgumentType.FLOAT_FAMILY, GroupByDesc.Mode.HASH, VectorUDAFStdSampDouble.class)); - add(new AggregateDefinition("stddev_samp", VectorExpressionDescriptor.ArgumentType.DECIMAL, GroupByDesc.Mode.HASH, VectorUDAFStdSampDecimal.class)); - add(new AggregateDefinition("stddev_samp", VectorExpressionDescriptor.ArgumentType.TIMESTAMP, GroupByDesc.Mode.HASH, VectorUDAFStdSampTimestamp.class)); + + // MIN, MAX, and SUM have the same representation for partial and full aggregation, so the + // same class can be used for all modes (PARTIAL1, PARTIAL2, FINAL, and COMPLETE). 
+ add(new AggregateDefinition("min", ArgumentType.INT_DATE_INTERVAL_YEAR_MONTH, null, VectorUDAFMinLong.class)); + add(new AggregateDefinition("min", ArgumentType.FLOAT_FAMILY, null, VectorUDAFMinDouble.class)); + add(new AggregateDefinition("min", ArgumentType.STRING_FAMILY, null, VectorUDAFMinString.class)); + add(new AggregateDefinition("min", ArgumentType.DECIMAL, null, VectorUDAFMinDecimal.class)); + add(new AggregateDefinition("min", ArgumentType.TIMESTAMP, null, VectorUDAFMinTimestamp.class)); + add(new AggregateDefinition("max", ArgumentType.INT_DATE_INTERVAL_YEAR_MONTH, null, VectorUDAFMaxLong.class)); + add(new AggregateDefinition("max", ArgumentType.FLOAT_FAMILY, null, VectorUDAFMaxDouble.class)); + add(new AggregateDefinition("max", ArgumentType.STRING_FAMILY, null, VectorUDAFMaxString.class)); + add(new AggregateDefinition("max", ArgumentType.DECIMAL, null, VectorUDAFMaxDecimal.class)); + add(new AggregateDefinition("max", ArgumentType.TIMESTAMP, null, VectorUDAFMaxTimestamp.class)); + add(new AggregateDefinition("sum", ArgumentType.INT_FAMILY, null, VectorUDAFSumLong.class)); + add(new AggregateDefinition("sum", ArgumentType.FLOAT_FAMILY, null, VectorUDAFSumDouble.class)); + add(new AggregateDefinition("sum", ArgumentType.DECIMAL, null, VectorUDAFSumDecimal.class)); + + // COUNT(column) doesn't count rows whose column value is NULL. + add(new AggregateDefinition("count", ArgumentType.ALL_FAMILY, Mode.PARTIAL1, VectorUDAFCount.class)); + add(new AggregateDefinition("count", ArgumentType.ALL_FAMILY, Mode.COMPLETE, VectorUDAFCount.class)); + + // COUNT(*) counts all rows regardless of whether the column value(s) are NULL. + add(new AggregateDefinition("count", ArgumentType.NONE, Mode.PARTIAL1, VectorUDAFCountStar.class)); + add(new AggregateDefinition("count", ArgumentType.NONE, Mode.COMPLETE, VectorUDAFCountStar.class)); + + // Merge the counts produced by either COUNT(column) or COUNT(*) modes PARTIAL1 or PARTIAL2. 
+ add(new AggregateDefinition("count", ArgumentType.INT_FAMILY, Mode.PARTIAL2, VectorUDAFCountMerge.class)); + add(new AggregateDefinition("count", ArgumentType.INT_FAMILY, Mode.FINAL, VectorUDAFCountMerge.class)); + + // Since the partial aggregation produced by AVG is a STRUCT with count and sum and the + // STRUCT data type isn't vectorized yet, we currently only support PARTIAL1. When we do + // support STRUCTs for average partial aggregation, we'll need 4 variations: + // + // PARTIAL1 Original data --> STRUCT Average Partial Aggregation + // PARTIAL2 STRUCT Average Partial Aggregation --> STRUCT Average Partial Aggregation + // FINAL STRUCT Average Partial Aggregation --> Full Aggregation + // COMPLETE Original data --> Full Aggregation + // + add(new AggregateDefinition("avg", ArgumentType.INT_FAMILY, Mode.PARTIAL1, VectorUDAFAvgLong.class)); + add(new AggregateDefinition("avg", ArgumentType.FLOAT_FAMILY, Mode.PARTIAL1, VectorUDAFAvgDouble.class)); + add(new AggregateDefinition("avg", ArgumentType.DECIMAL, Mode.PARTIAL1, VectorUDAFAvgDecimal.class)); + add(new AggregateDefinition("avg", ArgumentType.TIMESTAMP, Mode.PARTIAL1, VectorUDAFAvgTimestamp.class)); + + // We haven't had a chance to examine the VAR* and STD* area and expand it beyond PARTIAL1. 
+ add(new AggregateDefinition("variance", ArgumentType.INT_FAMILY, Mode.PARTIAL1, VectorUDAFVarPopLong.class)); + add(new AggregateDefinition("var_pop", ArgumentType.INT_FAMILY, Mode.PARTIAL1, VectorUDAFVarPopLong.class)); + add(new AggregateDefinition("variance", ArgumentType.FLOAT_FAMILY, Mode.PARTIAL1, VectorUDAFVarPopDouble.class)); + add(new AggregateDefinition("var_pop", ArgumentType.FLOAT_FAMILY, Mode.PARTIAL1, VectorUDAFVarPopDouble.class)); + add(new AggregateDefinition("variance", ArgumentType.DECIMAL, Mode.PARTIAL1, VectorUDAFVarPopDecimal.class)); + add(new AggregateDefinition("var_pop", ArgumentType.DECIMAL, Mode.PARTIAL1, VectorUDAFVarPopDecimal.class)); + add(new AggregateDefinition("variance", ArgumentType.TIMESTAMP, Mode.PARTIAL1, VectorUDAFVarPopTimestamp.class)); + add(new AggregateDefinition("var_pop", ArgumentType.TIMESTAMP, Mode.PARTIAL1, VectorUDAFVarPopTimestamp.class)); + add(new AggregateDefinition("var_samp", ArgumentType.INT_FAMILY, Mode.PARTIAL1, VectorUDAFVarSampLong.class)); + add(new AggregateDefinition("var_samp" , ArgumentType.FLOAT_FAMILY, Mode.PARTIAL1, VectorUDAFVarSampDouble.class)); + add(new AggregateDefinition("var_samp" , ArgumentType.DECIMAL, Mode.PARTIAL1, VectorUDAFVarSampDecimal.class)); + add(new AggregateDefinition("var_samp" , ArgumentType.TIMESTAMP, Mode.PARTIAL1, VectorUDAFVarSampTimestamp.class)); + add(new AggregateDefinition("std", ArgumentType.INT_FAMILY, Mode.PARTIAL1, VectorUDAFStdPopLong.class)); + add(new AggregateDefinition("stddev", ArgumentType.INT_FAMILY, Mode.PARTIAL1, VectorUDAFStdPopLong.class)); + add(new AggregateDefinition("stddev_pop", ArgumentType.INT_FAMILY, Mode.PARTIAL1, VectorUDAFStdPopLong.class)); + add(new AggregateDefinition("std", ArgumentType.FLOAT_FAMILY, Mode.PARTIAL1, VectorUDAFStdPopDouble.class)); + add(new AggregateDefinition("stddev", ArgumentType.FLOAT_FAMILY, Mode.PARTIAL1, VectorUDAFStdPopDouble.class)); + add(new AggregateDefinition("stddev_pop", ArgumentType.FLOAT_FAMILY, 
Mode.PARTIAL1, VectorUDAFStdPopDouble.class)); + add(new AggregateDefinition("std", ArgumentType.DECIMAL, Mode.PARTIAL1, VectorUDAFStdPopDecimal.class)); + add(new AggregateDefinition("stddev", ArgumentType.DECIMAL, Mode.PARTIAL1, VectorUDAFStdPopDecimal.class)); + add(new AggregateDefinition("stddev_pop", ArgumentType.DECIMAL, Mode.PARTIAL1, VectorUDAFStdPopDecimal.class)); + add(new AggregateDefinition("std", ArgumentType.TIMESTAMP, Mode.PARTIAL1, VectorUDAFStdPopTimestamp.class)); + add(new AggregateDefinition("stddev", ArgumentType.TIMESTAMP, Mode.PARTIAL1, VectorUDAFStdPopTimestamp.class)); + add(new AggregateDefinition("stddev_pop", ArgumentType.TIMESTAMP, Mode.PARTIAL1, VectorUDAFStdPopTimestamp.class)); + add(new AggregateDefinition("stddev_samp", ArgumentType.INT_FAMILY, Mode.PARTIAL1, VectorUDAFStdSampLong.class)); + add(new AggregateDefinition("stddev_samp", ArgumentType.FLOAT_FAMILY, Mode.PARTIAL1, VectorUDAFStdSampDouble.class)); + add(new AggregateDefinition("stddev_samp", ArgumentType.DECIMAL, Mode.PARTIAL1, VectorUDAFStdSampDecimal.class)); + add(new AggregateDefinition("stddev_samp", ArgumentType.TIMESTAMP, Mode.PARTIAL1, VectorUDAFStdSampTimestamp.class)); }}; - public VectorAggregateExpression getAggregatorExpression(AggregationDesc desc, boolean isReduceMergePartial) + public VectorAggregateExpression getAggregatorExpression(AggregationDesc desc) throws HiveException { ArrayList paramDescList = desc.getParameters(); @@ -2452,7 +2511,7 @@ public VectorAggregateExpression getAggregatorExpression(AggregationDesc desc, b for (int i = 0; i< paramDescList.size(); ++i) { ExprNodeDesc exprDesc = paramDescList.get(i); - vectorParams[i] = this.getVectorExpression(exprDesc, Mode.PROJECTION); + vectorParams[i] = this.getVectorExpression(exprDesc, VectorExpressionDescriptor.Mode.PROJECTION); } String aggregateName = desc.getGenericUDAFName(); @@ -2466,15 +2525,16 @@ public VectorAggregateExpression getAggregatorExpression(AggregationDesc desc, b } } + 
GenericUDAFEvaluator.Mode udafEvaluatorMode = desc.getMode(); for (AggregateDefinition aggDef : aggregatesDefinition) { if (aggregateName.equalsIgnoreCase(aggDef.getName()) && ((aggDef.getType() == VectorExpressionDescriptor.ArgumentType.NONE && inputType == VectorExpressionDescriptor.ArgumentType.NONE) || (aggDef.getType().isSameTypeOrFamily(inputType)))) { - if (aggDef.getMode() == GroupByDesc.Mode.HASH && isReduceMergePartial) { - continue; - } else if (aggDef.getMode() == GroupByDesc.Mode.MERGEPARTIAL && !isReduceMergePartial) { + // A null means all modes are ok. + GenericUDAFEvaluator.Mode aggDefUdafEvaluatorMode = aggDef.getUdafEvaluatorMode(); + if (aggDefUdafEvaluatorMode != null && aggDefUdafEvaluatorMode != udafEvaluatorMode) { continue; } @@ -2495,7 +2555,9 @@ public VectorAggregateExpression getAggregatorExpression(AggregationDesc desc, b } throw new HiveException("Vector aggregate not implemented: \"" + aggregateName + - "\" for type: \"" + inputType.name() + " (reduce-merge-partial = " + isReduceMergePartial + ")"); + "\" for type: \"" + inputType.name() + + " (UDAF evaluator mode = " + + (udafEvaluatorMode == null ? 
"NULL" : udafEvaluatorMode.name()) + ")"); } public int firstOutputColumnIndex() { diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java index c1d6582..00203ae 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java @@ -33,6 +33,7 @@ import java.util.Stack; import java.util.regex.Pattern; +import org.apache.calcite.util.Pair; import org.apache.commons.lang3.tuple.ImmutablePair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -96,6 +97,8 @@ import org.apache.hadoop.hive.ql.plan.MapJoinDesc; import org.apache.hadoop.hive.ql.plan.MapWork; import org.apache.hadoop.hive.ql.plan.OperatorDesc; +import org.apache.hadoop.hive.ql.plan.VectorGroupByDesc.ProcessingMode; +import org.apache.hadoop.hive.ql.plan.VectorPartitionConversion; import org.apache.hadoop.hive.ql.plan.PartitionDesc; import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc; import org.apache.hadoop.hive.ql.plan.ReduceWork; @@ -1540,76 +1543,131 @@ private boolean validateGroupByOperator(GroupByOperator op, boolean isReduce, bo LOG.info("Pruning grouping set id not supported in vector mode"); return false; } + if (desc.getMode() != GroupByDesc.Mode.HASH && desc.isDistinct()) { + LOG.info("DISTINCT not supported in vector mode"); + return false; + } boolean ret = validateExprNodeDesc(desc.getKeys()); if (!ret) { - LOG.info("Cannot vectorize groupby key expression"); + LOG.info("Cannot vectorize groupby key expression " + desc.getKeys().toString()); return false; } - if (!isReduce) { - - // MapWork - - ret = validateHashAggregationDesc(desc.getAggregators()); - if (!ret) { - return false; - } - } else { - - // ReduceWork - - boolean isComplete = desc.getMode() == GroupByDesc.Mode.COMPLETE; - if (desc.getMode() != GroupByDesc.Mode.HASH) { - - // Reduce Merge-Partial GROUP BY. 
- - // A merge-partial GROUP BY is fed by grouping by keys from reduce-shuffle. It is the - // first (or root) operator for its reduce task. - // TODO: Technically, we should also handle FINAL, PARTIAL1, PARTIAL2 and PARTIALS - // that are not hash or complete, but aren't merge-partial, somehow. - - if (desc.isDistinct()) { - LOG.info("Vectorized Reduce MergePartial GROUP BY does not support DISTINCT"); - return false; - } - - boolean hasKeys = (desc.getKeys().size() > 0); - - // Do we support merge-partial aggregation AND the output is primitive? - ret = validateReduceMergePartialAggregationDesc(desc.getAggregators(), hasKeys); - if (!ret) { - return false; - } + /** + * + * GROUP BY DEFINITIONS: + * + * GroupByDesc.Mode enumeration: + * + * The different modes of a GROUP BY operator. + * + * These descriptions are hopefully less cryptic than the comments for GroupByDesc.Mode. + * + * COMPLETE Aggregates original rows into full aggregation row(s). + * + * If the key length is 0, this is also called Global aggregation and + * 1 output row is produced. + * + * When the key length is > 0, the original rows come in ALREADY GROUPED. + * + * An example for key length > 0 is a GROUP BY being applied to the + * ALREADY GROUPED rows coming from an upstream JOIN operator. Or, + * ALREADY GROUPED rows coming from upstream MERGEPARTIAL GROUP BY + * operator. + * + * PARTIAL1 The first of 2 (or more) phases that aggregates ALREADY GROUPED + * original rows into partial aggregations. + * + * Subsequent phases PARTIAL2 (optional) and MERGEPARTIAL will merge + * the partial aggregations and output full aggregations. + * + * PARTIAL2 Accept ALREADY GROUPED partial aggregations and merge them into another + * partial aggregation. Output the merged partial aggregations. + * + * (Haven't seen this one used) + * + * PARTIALS (Behaves for non-distinct the same as PARTIAL2; and behaves for + * distinct the same as PARTIAL1.) 
+ * + * FINAL Accept ALREADY GROUPED original rows and aggregate them into + * full aggregations. + * + * Example is a GROUP BY being applied to rows from a sorted table, where + * the group key is the table sort key (or a prefix). + * + * HASH Accept UNORDERED original rows and aggregate them into a memory table. + * Output the partial aggregations on closeOp (or low memory). + * + * Similar to PARTIAL1 except original rows are UNORDERED. + * + * Commonly used in both Mapper and Reducer nodes. Always followed by + * a Reducer with MERGEPARTIAL GROUP BY. + * + * MERGEPARTIAL Always first operator of a Reducer. Data is grouped by reduce-shuffle. + * + * (Behaves for non-distinct aggregations the same as FINAL; and behaves + * for distinct aggregations the same as COMPLETE.) + * + * The output is full aggregation(s). + * + * Used in Reducers after a stage with a HASH GROUP BY operator. + * + * + * VectorGroupByDesc.ProcessingMode for VectorGroupByOperator: + * + * GLOBAL No key. All rows --> 1 full aggregation on end of input + * + * HASH Rows aggregated in to hash table on group key --> + * 1 partial aggregation per key (normally, unless there is spilling) + * + * MERGE_PARTIAL As first operator in a REDUCER, partial aggregations come grouped from + * reduce-shuffle --> + * aggregate the partial aggregations and emit full aggregation on + * endGroup / closeOp + * + * STREAMING Rows come from PARENT operator ALREADY GROUPED --> + * aggregate the rows and emit full aggregation on key change / closeOp + * + * NOTE: Hash can spill partial result rows prematurely if it runs low on memory. + * NOTE: Streaming has to compare keys where MergePartial gets an endGroup call. + * + * + * DECIDER: Which VectorGroupByDesc.ProcessingMode for VectorGroupByOperator? + * + * Decides using GroupByDesc.Mode and whether there are keys with the + * VectorGroupByDesc.groupByDescModeToVectorProcessingMode method. + * + * Mode.COMPLETE --> (numKeys == 0 ? 
ProcessingMode.GLOBAL : ProcessingMode.STREAMING) + * + * Mode.HASH --> ProcessingMode.HASH + * + * Mode.MERGEPARTIAL --> (numKeys == 0 ? ProcessingMode.GLOBAL : ProcessingMode.MERGE_PARTIAL) + * + * Mode.PARTIAL1, + * Mode.PARTIAL2, + * Mode.PARTIALS, + * Mode.FINAL --> ProcessingMode.STREAMING + * + */ + boolean hasKeys = (desc.getKeys().size() > 0); - if (hasKeys) { - if (op.getParentOperators().size() > 0 && !isComplete) { - LOG.info("Vectorized Reduce MergePartial GROUP BY keys can only handle a key group when it is fed by reduce-shuffle"); - return false; - } + ProcessingMode processingMode = + VectorGroupByDesc.groupByDescModeToVectorProcessingMode(desc.getMode(), hasKeys); - LOG.info("Vectorized Reduce MergePartial GROUP BY will process key groups"); + Pair retPair = + validateAggregationDescs(desc.getAggregators(), processingMode, hasKeys); + if (!retPair.left) { + return false; + } - // Primitive output validation above means we can output VectorizedRowBatch to the - // children operators. - vectorDesc.setVectorOutput(true); - } else { - LOG.info("Vectorized Reduce MergePartial GROUP BY will do global aggregation"); - } - if (!isComplete) { - vectorDesc.setIsReduceMergePartial(true); - } else { - vectorDesc.setIsReduceStreaming(true); - } - } else { + // If all the aggregation outputs are primitive, we can output VectorizedRowBatch. + // Otherwise, the rest of the operator tree will be row mode. + vectorDesc.setVectorOutput(retPair.right); - // Reduce Hash GROUP BY or global aggregation.
+ vectorDesc.setProcessingMode(processingMode); - ret = validateHashAggregationDesc(desc.getAggregators()); - if (!ret) { - return false; - } - } - } + LOG.info("Vector GROUP BY operator will use processing mode " + processingMode.name() + + ", isVectorOutput " + vectorDesc.isVectorOutput()); return true; } @@ -1633,23 +1691,19 @@ private boolean validateExprNodeDesc(List descs, return true; } - - private boolean validateHashAggregationDesc(List descs) { - return validateAggregationDesc(descs, /* isReduceMergePartial */ false, false); - } - - private boolean validateReduceMergePartialAggregationDesc(List descs, boolean hasKeys) { - return validateAggregationDesc(descs, /* isReduceMergePartial */ true, hasKeys); - } - - private boolean validateAggregationDesc(List descs, boolean isReduceMergePartial, boolean hasKeys) { + private Pair validateAggregationDescs(List descs, + ProcessingMode processingMode, boolean hasKeys) { + boolean outputIsPrimitive = true; for (AggregationDesc d : descs) { - boolean ret = validateAggregationDesc(d, isReduceMergePartial, hasKeys); - if (!ret) { - return false; + Pair retPair = validateAggregationDesc(d, processingMode, hasKeys); + if (!retPair.left) { + return retPair; + } + if (!retPair.right) { + outputIsPrimitive = false; } } - return true; + return new Pair(true, outputIsPrimitive); } private boolean validateExprNodeDescRecursive(ExprNodeDesc desc, VectorExpressionDescriptor.Mode mode) { @@ -1787,38 +1841,45 @@ private boolean validateAggregationIsPrimitive(VectorAggregateExpression vectorA return (outputObjInspector.getCategory() == ObjectInspector.Category.PRIMITIVE); } - private boolean validateAggregationDesc(AggregationDesc aggDesc, boolean isReduceMergePartial, - boolean hasKeys) { + private Pair validateAggregationDesc(AggregationDesc aggDesc, ProcessingMode processingMode, + boolean hasKeys) { String udfName = aggDesc.getGenericUDAFName().toLowerCase(); if (!supportedAggregationUdfs.contains(udfName)) { LOG.info("Cannot 
vectorize groupby aggregate expression: UDF " + udfName + " not supported"); - return false; + return new Pair(false, false); } if (aggDesc.getParameters() != null && !validateExprNodeDesc(aggDesc.getParameters())) { LOG.info("Cannot vectorize groupby aggregate expression: UDF parameters not supported"); - return false; + return new Pair(false, false); } // See if we can vectorize the aggregation. VectorizationContext vc = new ValidatorVectorizationContext(); VectorAggregateExpression vectorAggrExpr; try { - vectorAggrExpr = vc.getAggregatorExpression(aggDesc, isReduceMergePartial); + vectorAggrExpr = vc.getAggregatorExpression(aggDesc); } catch (Exception e) { // We should have already attempted to vectorize in validateAggregationDesc. if (LOG.isDebugEnabled()) { LOG.debug("Vectorization of aggreation should have succeeded ", e); } - return false; + return new Pair(false, false); + } + if (LOG.isDebugEnabled()) { + LOG.debug("Aggregation " + aggDesc.getExprString() + " --> " + + " vector expression " + vectorAggrExpr.toString()); } - if (isReduceMergePartial && hasKeys && !validateAggregationIsPrimitive(vectorAggrExpr)) { + boolean outputIsPrimitive = validateAggregationIsPrimitive(vectorAggrExpr); + if (processingMode == ProcessingMode.MERGE_PARTIAL && + hasKeys && + !outputIsPrimitive) { LOG.info("Vectorized Reduce MergePartial GROUP BY keys can only handle aggregate outputs that are primitive types"); - return false; + return new Pair(false, false); } - return true; + return new Pair(true, outputIsPrimitive); } private boolean validateDataType(String type, VectorExpressionDescriptor.Mode mode) { diff --git ql/src/java/org/apache/hadoop/hive/ql/plan/VectorGroupByDesc.java ql/src/java/org/apache/hadoop/hive/ql/plan/VectorGroupByDesc.java index e613a4e..08f8ebf 100644 --- ql/src/java/org/apache/hadoop/hive/ql/plan/VectorGroupByDesc.java +++ ql/src/java/org/apache/hadoop/hive/ql/plan/VectorGroupByDesc.java @@ -30,23 +30,45 @@ private static long serialVersionUID = 
1L; - private boolean isReduceMergePartial; + /** + * GLOBAL No key. All rows --> 1 full aggregation on end of input + * + * HASH Rows aggregated into hash table on group key --> + * 1 partial aggregation per key (normally, unless there is spilling) + * + * MERGE_PARTIAL As first operator in a REDUCER, partial aggregations come grouped from + * reduce-shuffle --> + * aggregate the partial aggregations and emit full aggregation on + * endGroup / closeOp + * + * STREAMING Rows come from PARENT operator already grouped --> + * aggregate the rows and emit full aggregation on key change / closeOp + * + * NOTE: Hash can spill partial result rows prematurely if it runs low on memory. + * NOTE: Streaming has to compare keys where MergePartial gets an endGroup call. + */ + public static enum ProcessingMode { + NONE, + GLOBAL, + HASH, + MERGE_PARTIAL, + STREAMING + }; - private boolean isVectorOutput; + private ProcessingMode processingMode; - private boolean isReduceStreaming; + private boolean isVectorOutput; public VectorGroupByDesc() { - this.isReduceMergePartial = false; + this.processingMode = ProcessingMode.NONE; this.isVectorOutput = false; } - public boolean isReduceMergePartial() { - return isReduceMergePartial; + public void setProcessingMode(ProcessingMode processingMode) { + this.processingMode = processingMode; } - - public void setIsReduceMergePartial(boolean isReduceMergePartial) { - this.isReduceMergePartial = isReduceMergePartial; + public ProcessingMode getProcessingMode() { + return processingMode; } public boolean isVectorOutput() { @@ -57,11 +79,39 @@ public void setVectorOutput(boolean isVectorOutput) { this.isVectorOutput = isVectorOutput; } - public void setIsReduceStreaming(boolean isReduceStreaming) { - this.isReduceStreaming = isReduceStreaming; - } - - public boolean isReduceStreaming() { - return isReduceStreaming; + /** + * Which ProcessingMode for VectorGroupByOperator? + * + * Decides using GroupByDesc.Mode and whether there are keys.
+ * + * Mode.COMPLETE --> (numKeys == 0 ? ProcessingMode.GLOBAL : ProcessingMode.STREAMING) + * + * Mode.HASH --> ProcessingMode.HASH + * + * Mode.MERGEPARTIAL --> (numKeys == 0 ? ProcessingMode.GLOBAL : ProcessingMode.MERGE_PARTIAL) + * + * Mode.PARTIAL1, + * Mode.PARTIAL2, + * Mode.PARTIALS, + * Mode.FINAL --> ProcessingMode.STREAMING + * + */ + public static ProcessingMode groupByDescModeToVectorProcessingMode(GroupByDesc.Mode mode, + boolean hasKeys) { + switch (mode) { + case COMPLETE: + return (hasKeys ? ProcessingMode.STREAMING : ProcessingMode.GLOBAL); + case HASH: + return ProcessingMode.HASH; + case MERGEPARTIAL: + return (hasKeys ? ProcessingMode.MERGE_PARTIAL : ProcessingMode.GLOBAL); + case PARTIAL1: + case PARTIAL2: + case PARTIALS: + case FINAL: + return ProcessingMode.STREAMING; + default: + throw new RuntimeException("Unexpected GROUP BY mode " + mode.name()); + } } } diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java index 451947b..f5b5d9d 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java @@ -50,6 +50,8 @@ import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; import org.apache.hadoop.hive.ql.plan.GroupByDesc; import org.apache.hadoop.hive.ql.plan.VectorGroupByDesc; +import org.apache.hadoop.hive.ql.plan.VectorGroupByDesc.ProcessingMode; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.serde2.io.ByteWritable; import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; @@ -84,6 +86,7 @@ private static ExprNodeDesc buildColumnDesc( private static AggregationDesc buildAggregationDesc( VectorizationContext ctx, String aggregate, + GenericUDAFEvaluator.Mode mode, String column, TypeInfo typeInfo) { @@ -94,6 
+97,7 @@ private static AggregationDesc buildAggregationDesc( AggregationDesc agg = new AggregationDesc(); agg.setGenericUDAFName(aggregate); + agg.setMode(mode); agg.setParameters(params); return agg; @@ -102,6 +106,7 @@ private static AggregationDesc buildAggregationDescCountStar( VectorizationContext ctx) { AggregationDesc agg = new AggregationDesc(); agg.setGenericUDAFName("COUNT"); + agg.setMode(GenericUDAFEvaluator.Mode.PARTIAL1); agg.setParameters(new ArrayList()); return agg; } @@ -110,10 +115,11 @@ private static AggregationDesc buildAggregationDescCountStar( private static GroupByDesc buildGroupByDescType( VectorizationContext ctx, String aggregate, + GenericUDAFEvaluator.Mode mode, String column, TypeInfo dataType) { - AggregationDesc agg = buildAggregationDesc(ctx, aggregate, + AggregationDesc agg = buildAggregationDesc(ctx, aggregate, mode, column, dataType); ArrayList aggs = new ArrayList(); aggs.add(agg); @@ -124,6 +130,7 @@ private static GroupByDesc buildGroupByDescType( GroupByDesc desc = new GroupByDesc(); desc.setOutputColumnNames(outputColumnNames); desc.setAggregators(aggs); + desc.getVectorDesc().setProcessingMode(ProcessingMode.GLOBAL); return desc; } @@ -154,7 +161,8 @@ private static GroupByDesc buildKeyGroupByDesc( String key, TypeInfo keyTypeInfo) { - GroupByDesc desc = buildGroupByDescType(ctx, aggregate, column, dataTypeInfo); + GroupByDesc desc = buildGroupByDescType(ctx, aggregate, GenericUDAFEvaluator.Mode.PARTIAL1, column, dataTypeInfo); + desc.getVectorDesc().setProcessingMode(ProcessingMode.HASH); ExprNodeDesc keyExp = buildColumnDesc(ctx, key, keyTypeInfo); ArrayList keys = new ArrayList(); @@ -1716,7 +1724,7 @@ private void testMultiKey( ArrayList aggs = new ArrayList(1); aggs.add( - buildAggregationDesc(ctx, aggregateName, + buildAggregationDesc(ctx, aggregateName, GenericUDAFEvaluator.Mode.PARTIAL1, "value", TypeInfoFactory.getPrimitiveTypeInfo(columnTypes[i]))); for(i=0; i keys = new HashSet(); - AggregationDesc agg = 
buildAggregationDesc(ctx, aggregateName, + AggregationDesc agg = buildAggregationDesc(ctx, aggregateName, GenericUDAFEvaluator.Mode.PARTIAL1, "Value", TypeInfoFactory.getPrimitiveTypeInfo(data.getTypes()[1])); ArrayList aggs = new ArrayList(); aggs.add(agg); @@ -1839,6 +1848,7 @@ private void testKeyTypeAggregate( GroupByDesc desc = new GroupByDesc(); desc.setOutputColumnNames(outputColumnNames); desc.setAggregators(aggs); + desc.getVectorDesc().setProcessingMode(ProcessingMode.HASH); ExprNodeDesc keyExp = buildColumnDesc(ctx, "Key", TypeInfoFactory.getPrimitiveTypeInfo(data.getTypes()[0])); @@ -2242,6 +2252,7 @@ public void testAggregateCountStarIterable ( VectorizationContext ctx = new VectorizationContext("name", mapColumnNames); GroupByDesc desc = buildGroupByDescCountStar (ctx); + desc.getVectorDesc().setProcessingMode(ProcessingMode.HASH); CompilationOpContext cCtx = new CompilationOpContext(); VectorGroupByOperator vgo = new VectorGroupByOperator(cCtx, ctx, desc); @@ -2271,9 +2282,9 @@ public void testAggregateCountReduceIterable ( mapColumnNames.add("A"); VectorizationContext ctx = new VectorizationContext("name", mapColumnNames); - GroupByDesc desc = buildGroupByDescType(ctx, "count", "A", TypeInfoFactory.longTypeInfo); + GroupByDesc desc = buildGroupByDescType(ctx, "count", GenericUDAFEvaluator.Mode.FINAL, "A", TypeInfoFactory.longTypeInfo); VectorGroupByDesc vectorDesc = desc.getVectorDesc(); - vectorDesc.setIsReduceMergePartial(true); + vectorDesc.setProcessingMode(ProcessingMode.GLOBAL); // Use GLOBAL when no key for Reduce. 
CompilationOpContext cCtx = new CompilationOpContext(); VectorGroupByOperator vgo = new VectorGroupByOperator(cCtx, ctx, desc); @@ -2303,7 +2314,7 @@ public void testAggregateStringIterable ( mapColumnNames.add("A"); VectorizationContext ctx = new VectorizationContext("name", mapColumnNames); - GroupByDesc desc = buildGroupByDescType(ctx, aggregateName, "A", + GroupByDesc desc = buildGroupByDescType(ctx, aggregateName, GenericUDAFEvaluator.Mode.PARTIAL1, "A", TypeInfoFactory.stringTypeInfo); CompilationOpContext cCtx = new CompilationOpContext(); @@ -2336,7 +2347,7 @@ public void testAggregateDecimalIterable ( VectorizationContext ctx = new VectorizationContext("name", mapColumnNames); GroupByDesc desc = - buildGroupByDescType(ctx, aggregateName, "A", TypeInfoFactory.getDecimalTypeInfo(30, 4)); + buildGroupByDescType(ctx, aggregateName, GenericUDAFEvaluator.Mode.PARTIAL1, "A", TypeInfoFactory.getDecimalTypeInfo(30, 4)); CompilationOpContext cCtx = new CompilationOpContext(); VectorGroupByOperator vgo = new VectorGroupByOperator(cCtx, ctx, desc); @@ -2368,7 +2379,7 @@ public void testAggregateDoubleIterable ( mapColumnNames.add("A"); VectorizationContext ctx = new VectorizationContext("name", mapColumnNames); - GroupByDesc desc = buildGroupByDescType (ctx, aggregateName, "A", + GroupByDesc desc = buildGroupByDescType (ctx, aggregateName, GenericUDAFEvaluator.Mode.PARTIAL1, "A", TypeInfoFactory.doubleTypeInfo); CompilationOpContext cCtx = new CompilationOpContext(); @@ -2400,7 +2411,7 @@ public void testAggregateLongIterable ( mapColumnNames.add("A"); VectorizationContext ctx = new VectorizationContext("name", mapColumnNames); - GroupByDesc desc = buildGroupByDescType(ctx, aggregateName, "A", TypeInfoFactory.longTypeInfo); + GroupByDesc desc = buildGroupByDescType(ctx, aggregateName, GenericUDAFEvaluator.Mode.PARTIAL1, "A", TypeInfoFactory.longTypeInfo); CompilationOpContext cCtx = new CompilationOpContext(); VectorGroupByOperator vgo = new 
VectorGroupByOperator(cCtx, ctx, desc); diff --git ql/src/test/org/apache/hadoop/hive/ql/optimizer/physical/TestVectorizer.java ql/src/test/org/apache/hadoop/hive/ql/optimizer/physical/TestVectorizer.java index 9d4ca76..3295372 100644 --- ql/src/test/org/apache/hadoop/hive/ql/optimizer/physical/TestVectorizer.java +++ ql/src/test/org/apache/hadoop/hive/ql/optimizer/physical/TestVectorizer.java @@ -34,6 +34,7 @@ import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FuncAbsLongToLong; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.plan.*; +import org.apache.hadoop.hive.ql.plan.VectorGroupByDesc.ProcessingMode; import org.apache.hadoop.hive.ql.udf.generic.*; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; @@ -107,6 +108,7 @@ public void testAggregateOnUDF() throws HiveException { GroupByOperator gbyOp = new GroupByOperator(new CompilationOpContext()); gbyOp.setConf(desc); + desc.setMode(GroupByDesc.Mode.HASH); Vectorizer v = new Vectorizer(); Assert.assertTrue(v.validateMapWorkOperator(gbyOp, null, false)); @@ -148,9 +150,9 @@ public void testValidateNestedExpressions() { Assert.assertFalse(v.validateExprNodeDesc(andExprDesc, VectorExpressionDescriptor.Mode.FILTER)); Assert.assertFalse(v.validateExprNodeDesc(andExprDesc, VectorExpressionDescriptor.Mode.PROJECTION)); } - + /** - * prepareAbstractMapJoin prepares a join operator descriptor, used as helper by SMB and Map join tests. + * prepareAbstractMapJoin prepares a join operator descriptor, used as helper by SMB and Map join tests. */ private void prepareAbstractMapJoin(AbstractMapJoinOperator map, MapJoinDesc mjdesc) { mjdesc.setPosBigTable(0); @@ -189,15 +191,15 @@ private void prepareAbstractMapJoin(AbstractMapJoinOperator