diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvg.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvg.txt index fc3d01f..cf5cc69 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvg.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvg.txt @@ -56,20 +56,9 @@ public class <ClassName> extends VectorAggregateExpression { transient private double sum; transient private long count; - /** - * Value is explicitly (re)initialized in reset() - */ - transient private boolean isNull = true; - public void avgValue(<ValueType> value) { - if (isNull) { - sum = value; - count = 1; - isNull = false; - } else { - sum += value; - count++; - } + sum += value; + count++; } @Override @@ -79,7 +68,6 @@ public class <ClassName> extends VectorAggregateExpression { @Override public void reset () { - isNull = true; sum = 0; count = 0L; } @@ -151,15 +139,9 @@ public class <ClassName> extends VectorAggregateExpression { } } else { if (inputVector.isRepeating) { - if (batch.selectedInUse) { - iterateHasNullsRepeatingSelectionWithAggregationSelection( - aggregationBufferSets, bufferIndex, - vector[0], batchSize, batch.selected, inputVector.isNull); - } else { - iterateHasNullsRepeatingWithAggregationSelection( - aggregationBufferSets, bufferIndex, - vector[0], batchSize, inputVector.isNull); - } + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, bufferIndex, + vector[0], batchSize, inputVector.isNull); } else { if (batch.selectedInUse) { iterateHasNullsSelectionWithAggregationSelection( @@ -219,28 +201,6 @@ public class <ClassName> extends VectorAggregateExpression { } } - private void iterateHasNullsRepeatingSelectionWithAggregationSelection( - VectorAggregationBufferRow[] aggregationBufferSets, - int bufferIndex, - <ValueType> value, - int batchSize, - int[] selection, - boolean[] isNull) { - - if (isNull[0]) { - return; - } - - for (int i=0; i < batchSize; ++i) { - Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, - bufferIndex, - i); - myagg.avgValue(value); - } - - } - private void iterateHasNullsRepeatingWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int bufferIndex, @@ -321,11 +281,6 @@ public class <ClassName> extends VectorAggregateExpression { if (inputVector.isRepeating) { if (inputVector.noNulls || !inputVector.isNull[0]) { - if (myagg.isNull) { - myagg.isNull = false; - myagg.sum = 0; - myagg.count = 0; - } myagg.sum += vector[0]*batchSize; myagg.count += batchSize; } @@ -353,14 +308,8 @@ for (int j=0; j< batchSize; ++j) { int i = selected[j]; if (!isNull[i]) { - <ValueType> value = vector[i]; - if (myagg.isNull) { - myagg.isNull = false; - myagg.sum = 0; - myagg.count = 0; - } - myagg.sum += value; - myagg.count += 1; + myagg.sum += vector[i]; + myagg.count++; } } } @@ -371,16 +320,9 @@ int batchSize, int[] selected) { - if (myagg.isNull) { - myagg.isNull = false; - myagg.sum = 0; - myagg.count = 0; - } - for (int i=0; i< batchSize; ++i) { - <ValueType> value = vector[selected[i]]; - myagg.sum += value; - myagg.count += 1; + myagg.sum += vector[selected[i]]; + myagg.count++; } } @@ -392,13 +334,7 @@ for(int i=0;i<batchSize;++i) { if (!isNull[i]) { - <ValueType> value = vector[i]; - if (myagg.isNull) { - myagg.isNull = false; - myagg.sum = 0; - myagg.count = 0; - } - myagg.sum += value; + myagg.sum += vector[i]; myagg.count += 1; } } @@ -408,15 +344,9 @@ Aggregation myagg, <ValueType>[] vector, int batchSize) { - if (myagg.isNull) { - myagg.isNull = false; - myagg.sum = 0; - myagg.count = 0; - } for (int i=0;i<batchSize;++i) { - <ValueType> value = vector[i]; - myagg.sum += value; + myagg.sum += vector[i]; myagg.count += 1; } } @@ -483,15 +413,11 @@ public class <ClassName> extends VectorAggregateExpression { #ENDIF COMPLETE Aggregation myagg = (Aggregation) agg; - if (myagg.isNull) { - outputColVector.noNulls = false; - outputColVector.isNull[batchIndex] = true; - return; - } - Preconditions.checkState(myagg.count > 0); - outputColVector.isNull[batchIndex] = false; #IF PARTIAL1 + // For AVG, we do not mark NULL if all inputs were NULL. + outputColVector.isNull[batchIndex] = false; + ColumnVector[] fields = outputColVector.fields; fields[AVERAGE_COUNT_FIELD_INDEX].isNull[batchIndex] = false; ((LongColumnVector) fields[AVERAGE_COUNT_FIELD_INDEX]).vector[batchIndex] = myagg.count; @@ -506,6 +432,12 @@ public class <ClassName> extends VectorAggregateExpression { #ENDIF PARTIAL1 #IF COMPLETE + if (myagg.count == 0) { + outputColVector.noNulls = false; + outputColVector.isNull[batchIndex] = true; + return; + } + outputColVector.isNull[batchIndex] = false; outputColVector.vector[batchIndex] = myagg.sum / myagg.count; #ENDIF COMPLETE }
diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgDecimal.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgDecimal.txt index f512639..3caeecd 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgDecimal.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgDecimal.txt @@ -61,23 +61,11 @@ public class <ClassName> extends VectorAggregateExpression { transient private final HiveDecimalWritable sum = new HiveDecimalWritable(); transient private long count; - transient private boolean isNull; public void avgValue(HiveDecimalWritable writable) { - if (isNull) { - // Make a copy since we intend to mutate sum. - sum.set(writable); - count = 1; - isNull = false; - } else { - // Note that if sum is out of range, mutateAdd will ignore the call. - // At the end, sum.isSet() can be checked for null. - sum.mutateAdd(writable); - count++; - } - } - public void avgValueNoNullCheck(HiveDecimalWritable writable) { + // Note that if sum is out of range, mutateAdd will ignore the call. + // At the end, sum.isSet() can be checked for null.
sum.mutateAdd(writable); count++; } @@ -89,7 +77,6 @@ public class extends VectorAggregateExpression { @Override public void reset() { - isNull = true; sum.setFromLong(0L); count = 0; } @@ -189,15 +176,9 @@ public class extends VectorAggregateExpression { } } else { if (inputVector.isRepeating) { - if (batch.selectedInUse) { - iterateHasNullsRepeatingSelectionWithAggregationSelection( - aggregationBufferSets, bufferIndex, - vector[0], batchSize, batch.selected, inputVector.isNull); - } else { - iterateHasNullsRepeatingWithAggregationSelection( - aggregationBufferSets, bufferIndex, - vector[0], batchSize, inputVector.isNull); - } + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, bufferIndex, + vector[0], batchSize, inputVector.isNull); } else { if (batch.selectedInUse) { iterateHasNullsSelectionWithAggregationSelection( @@ -257,28 +238,6 @@ public class extends VectorAggregateExpression { } } - private void iterateHasNullsRepeatingSelectionWithAggregationSelection( - VectorAggregationBufferRow[] aggregationBufferSets, - int bufferIndex, - HiveDecimalWritable value, - int batchSize, - int[] selection, - boolean[] isNull) { - - if (isNull[0]) { - return; - } - - for (int i=0; i < batchSize; ++i) { - Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, - bufferIndex, - i); - myagg.avgValue(value); - } - - } - private void iterateHasNullsRepeatingWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int bufferIndex, @@ -360,11 +319,6 @@ public class extends VectorAggregateExpression { if (inputVector.isRepeating) { if (inputVector.noNulls || !inputVector.isNull[0]) { - if (myagg.isNull) { - myagg.isNull = false; - myagg.sum.setFromLong(0L); - myagg.count = 0; - } HiveDecimal value = vector[0].getHiveDecimal(); HiveDecimal multiple = value.multiply(HiveDecimal.create(batchSize)); myagg.sum.mutateAdd(multiple); @@ -408,14 +362,8 @@ public class extends VectorAggregateExpression { int batchSize, int[] selected) { - if (myagg.isNull) { - myagg.isNull = false; - myagg.sum.setFromLong(0L); - myagg.count = 0; - } - for (int i=0; i< batchSize; ++i) { - myagg.avgValueNoNullCheck(vector[selected[i]]); + myagg.avgValue(vector[selected[i]]); } } @@ -436,14 +384,9 @@ public class extends VectorAggregateExpression { Aggregation myagg, HiveDecimalWritable[] vector, int batchSize) { - if (myagg.isNull) { - myagg.isNull = false; - myagg.sum.setFromLong(0L); - myagg.count = 0; - } for (int i=0;i extends VectorAggregateExpression { #ENDIF COMPLETE Aggregation myagg = (Aggregation) agg; - if (myagg.isNull || !myagg.sum.isSet()) { + + // For AVG, we only mark NULL on actual overflow. + if (!myagg.sum.isSet()) { outputColVector.noNulls = false; outputColVector.isNull[batchIndex] = true; return; } - Preconditions.checkState(myagg.count > 0); + outputColVector.isNull[batchIndex] = false; #IF PARTIAL1 @@ -532,6 +477,12 @@ public class extends VectorAggregateExpression { #ENDIF PARTIAL1 #IF COMPLETE + // For AVG, we mark NULL on count 0 or on overflow. 
+ if (myagg.count == 0 || !myagg.sum.isSet()) { + outputColVector.noNulls = false; + outputColVector.isNull[batchIndex] = true; + return; + } tempDecWritable.setFromLong (myagg.count); HiveDecimalWritable result = outputColVector.vector[batchIndex]; result.set(myagg.sum); diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgDecimal64ToDecimal.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgDecimal64ToDecimal.txt index 53dceeb..39e0562 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgDecimal64ToDecimal.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgDecimal64ToDecimal.txt @@ -81,7 +81,6 @@ public class extends VectorAggregateExpression { /** * Value is explicitly (re)initialized in reset() */ - private boolean isNull = true; private boolean usingRegularDecimal = false; public Aggregation(int inputScale, HiveDecimalWritable temp) { @@ -90,26 +89,21 @@ public class extends VectorAggregateExpression { } public void avgValue(long value) { - if (isNull) { - sum = value; - count = 1; - isNull = false; - } else { - if (Math.abs(sum) > nearDecimal64Max) { - if (!usingRegularDecimal) { - usingRegularDecimal = true; - regularDecimalSum.deserialize64(sum, inputScale); - } else { - temp.deserialize64(sum, inputScale); - regularDecimalSum.mutateAdd(temp); - } - sum = value; + + if (Math.abs(sum) > nearDecimal64Max) { + if (!usingRegularDecimal) { + usingRegularDecimal = true; + regularDecimalSum.deserialize64(sum, inputScale); } else { - sum += value; + temp.deserialize64(sum, inputScale); + regularDecimalSum.mutateAdd(temp); } - - count++; + sum = value; + } else { + sum += value; } + + count++; } @Override @@ -119,7 +113,6 @@ public class extends VectorAggregateExpression { @Override public void reset () { - isNull = true; usingRegularDecimal = false; sum = 0; regularDecimalSum.setFromLong(0); @@ -202,15 +195,9 @@ public class extends VectorAggregateExpression { } } else { if (inputVector.isRepeating) { - if (batch.selectedInUse) { - iterateHasNullsRepeatingSelectionWithAggregationSelection( - aggregationBufferSets, bufferIndex, - vector[0], batchSize, batch.selected, inputVector.isNull); - } else { - iterateHasNullsRepeatingWithAggregationSelection( - aggregationBufferSets, bufferIndex, - vector[0], batchSize, inputVector.isNull); - } + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, bufferIndex, + vector[0], batchSize, inputVector.isNull); } else { if (batch.selectedInUse) { iterateHasNullsSelectionWithAggregationSelection( @@ -270,28 +257,6 @@ public class extends VectorAggregateExpression { } } - private void iterateHasNullsRepeatingSelectionWithAggregationSelection( - VectorAggregationBufferRow[] aggregationBufferSets, - int bufferIndex, - long value, - int batchSize, - int[] selection, - boolean[] isNull) { - - if (isNull[0]) { - return; - } - - for (int i=0; i < batchSize; ++i) { - Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, - bufferIndex, - i); - myagg.avgValue(value); - } - - } - private void iterateHasNullsRepeatingWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int bufferIndex, @@ -502,8 +467,10 @@ public class extends VectorAggregateExpression { #ENDIF COMPLETE Aggregation myagg = (Aggregation) agg; - final boolean isNull; - if (!myagg.isNull) { + +#IF PARTIAL1 + if (myagg.count > 0) { + if (!myagg.usingRegularDecimal) { myagg.regularDecimalSum.deserialize64(myagg.sum, inputScale); } else { @@ -511,19 +478,15 @@ public class extends VectorAggregateExpression { 
myagg.regularDecimalSum.mutateAdd(myagg.temp); } - isNull = !myagg.regularDecimalSum.isSet(); - } else { - isNull = true; - } - if (isNull) { - outputColVector.noNulls = false; - outputColVector.isNull[batchIndex] = true; - return; + // For AVG, we only mark NULL on actual overflow. + if (!myagg.regularDecimalSum.isSet()) { + outputColVector.noNulls = false; + outputColVector.isNull[batchIndex] = true; + return; + } } - Preconditions.checkState(myagg.count > 0); - outputColVector.isNull[batchIndex] = false; -#IF PARTIAL1 + outputColVector.isNull[batchIndex] = false; ColumnVector[] fields = outputColVector.fields; fields[AVERAGE_COUNT_FIELD_INDEX].isNull[batchIndex] = false; ((LongColumnVector) fields[AVERAGE_COUNT_FIELD_INDEX]).vector[batchIndex] = myagg.count; @@ -539,6 +502,27 @@ public class extends VectorAggregateExpression { #ENDIF PARTIAL1 #IF COMPLETE + final boolean isNull; + if (myagg.count > 0) { + if (!myagg.usingRegularDecimal) { + myagg.regularDecimalSum.deserialize64(myagg.sum, inputScale); + } else { + myagg.temp.deserialize64(myagg.sum, inputScale); + myagg.regularDecimalSum.mutateAdd(myagg.temp); + } + + isNull = !myagg.regularDecimalSum.isSet(); + } else { + isNull = true; + } + if (isNull) { + outputColVector.noNulls = false; + outputColVector.isNull[batchIndex] = true; + return; + } + Preconditions.checkState(myagg.count > 0); + outputColVector.isNull[batchIndex] = false; + temp.setFromLong (myagg.count); HiveDecimalWritable result = outputColVector.vector[batchIndex]; result.set(myagg.regularDecimalSum); diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgDecimalMerge.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgDecimalMerge.txt index 5fe9256..3691c05 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgDecimalMerge.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgDecimalMerge.txt @@ -188,15 +188,9 @@ public class extends VectorAggregateExpression { } } else { if (inputStructColVector.isRepeating) { - if (batch.selectedInUse) { - iterateHasNullsRepeatingSelectionWithAggregationSelection( - aggregationBufferSets, bufferIndex, - countVector[0], sumVector[0], batchSize, batch.selected, inputStructColVector.isNull); - } else { - iterateHasNullsRepeatingWithAggregationSelection( - aggregationBufferSets, bufferIndex, - countVector[0], sumVector[0], batchSize, inputStructColVector.isNull); - } + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, bufferIndex, + countVector[0], sumVector[0], batchSize, inputStructColVector.isNull); } else { if (batch.selectedInUse) { iterateHasNullsSelectionWithAggregationSelection( @@ -260,29 +254,6 @@ public class extends VectorAggregateExpression { } } - private void iterateHasNullsRepeatingSelectionWithAggregationSelection( - VectorAggregationBufferRow[] aggregationBufferSets, - int bufferIndex, - long count, - HiveDecimalWritable sum, - int batchSize, - int[] selection, - boolean[] isNull) { - - if (isNull[0]) { - return; - } - - for (int i=0; i < batchSize; ++i) { - Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, - bufferIndex, - i); - myagg.merge(count, sum); - } - - } - private void iterateHasNullsRepeatingWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int bufferIndex, diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgMerge.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgMerge.txt index 162d1ba..2e93efd 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgMerge.txt +++ 
ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgMerge.txt @@ -154,15 +154,9 @@ public class extends VectorAggregateExpression { } } else { if (inputStructColVector.isRepeating) { - if (batch.selectedInUse) { - iterateHasNullsRepeatingSelectionWithAggregationSelection( - aggregationBufferSets, bufferIndex, - countVector[0], sumVector[0], batchSize, batch.selected, inputStructColVector.isNull); - } else { - iterateHasNullsRepeatingWithAggregationSelection( - aggregationBufferSets, bufferIndex, - countVector[0], sumVector[0], batchSize, inputStructColVector.isNull); - } + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, bufferIndex, + countVector[0], sumVector[0], batchSize, inputStructColVector.isNull); } else { if (batch.selectedInUse) { iterateHasNullsSelectionWithAggregationSelection( @@ -226,29 +220,6 @@ public class extends VectorAggregateExpression { } } - private void iterateHasNullsRepeatingSelectionWithAggregationSelection( - VectorAggregationBufferRow[] aggregationBufferSets, - int bufferIndex, - long count, - double sum, - int batchSize, - int[] selection, - boolean[] isNull) { - - if (isNull[0]) { - return; - } - - for (int i=0; i < batchSize; ++i) { - Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, - bufferIndex, - i); - myagg.merge(count, sum); - } - - } - private void iterateHasNullsRepeatingWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int bufferIndex, diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgTimestamp.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgTimestamp.txt index 810f31f..358d108 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgTimestamp.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgTimestamp.txt @@ -59,20 +59,9 @@ public class extends VectorAggregateExpression { transient private double sum; transient private long count; - /** - * Value is explicitly (re)initialized in reset() - */ - transient private boolean isNull = true; - - public void sumValue(double value) { - if (isNull) { - sum = value; - count = 1; - isNull = false; - } else { - sum += value; - count++; - } + public void avgValue(double value) { + sum += value; + count++; } @Override @@ -82,7 +71,6 @@ public class extends VectorAggregateExpression { @Override public void reset() { - isNull = true; sum = 0; count = 0L; } @@ -153,15 +141,9 @@ public class extends VectorAggregateExpression { } } else { if (inputColVector.isRepeating) { - if (batch.selectedInUse) { - iterateHasNullsRepeatingSelectionWithAggregationSelection( - aggregationBufferSets, bufferIndex, - inputColVector.getDouble(0), batchSize, batch.selected, inputColVector.isNull); - } else { - iterateHasNullsRepeatingWithAggregationSelection( - aggregationBufferSets, bufferIndex, - inputColVector.getDouble(0), batchSize, inputColVector.isNull); - } + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, bufferIndex, + inputColVector.getDouble(0), batchSize, inputColVector.isNull); } else { if (batch.selectedInUse) { iterateHasNullsSelectionWithAggregationSelection( @@ -187,7 +169,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, bufferIndex, i); - myagg.sumValue(value); + myagg.avgValue(value); } } @@ -203,7 +185,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, bufferIndex, i); - myagg.sumValue( + myagg.avgValue( inputColVector.getDouble(selection[i])); } } @@ -218,45 +200,27 @@ public class extends VectorAggregateExpression { 
aggregationBufferSets, bufferIndex, i); - myagg.sumValue(inputColVector.getDouble(i)); + myagg.avgValue(inputColVector.getDouble(i)); } } - private void iterateHasNullsRepeatingSelectionWithAggregationSelection( + private void iterateHasNullsRepeatingWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int bufferIndex, double value, int batchSize, - int[] selection, boolean[] isNull) { - for (int i=0; i < batchSize; ++i) { - if (!isNull[selection[i]]) { - Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, - bufferIndex, - i); - myagg.sumValue(value); - } + if (isNull[0]) { + return; } - } - - private void iterateHasNullsRepeatingWithAggregationSelection( - VectorAggregationBufferRow[] aggregationBufferSets, - int bufferIndex, - double value, - int batchSize, - boolean[] isNull) { - for (int i=0; i < batchSize; ++i) { - if (!isNull[i]) { - Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, - bufferIndex, - i); - myagg.sumValue(value); - } + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + bufferIndex, + i); + myagg.avgValue(value); } } @@ -275,7 +239,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, bufferIndex, j); - myagg.sumValue(inputColVector.getDouble(i)); + myagg.avgValue(inputColVector.getDouble(i)); } } } @@ -293,7 +257,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, bufferIndex, i); - myagg.sumValue(inputColVector.getDouble(i)); + myagg.avgValue(inputColVector.getDouble(i)); } } } @@ -318,11 +282,6 @@ public class extends VectorAggregateExpression { if (inputColVector.isRepeating) { if (inputColVector.noNulls || !inputColVector.isNull[0]) { - if (myagg.isNull) { - myagg.isNull = false; - myagg.sum = 0; - myagg.count = 0; - } myagg.sum += inputColVector.getDouble(0)*batchSize; myagg.count += batchSize; } @@ -353,13 +312,7 @@ public class extends VectorAggregateExpression { for (int j=0; j< batchSize; ++j) { int i = selected[j]; if (!isNull[i]) { - double value = inputColVector.getDouble(i); - if (myagg.isNull) { - myagg.isNull = false; - myagg.sum = 0; - myagg.count = 0; - } - myagg.sum += value; + myagg.sum += inputColVector.getDouble(i); myagg.count += 1; } } @@ -371,15 +324,8 @@ public class extends VectorAggregateExpression { int batchSize, int[] selected) { - if (myagg.isNull) { - myagg.isNull = false; - myagg.sum = 0; - myagg.count = 0; - } - for (int i=0; i< batchSize; ++i) { - double value = inputColVector.getDouble(selected[i]); - myagg.sum += value; + myagg.sum += inputColVector.getDouble(selected[i]); myagg.count += 1; } } @@ -392,13 +338,7 @@ public class extends VectorAggregateExpression { for(int i=0;i extends VectorAggregateExpression { Aggregation myagg, TimestampColumnVector inputColVector, int batchSize) { - if (myagg.isNull) { - myagg.isNull = false; - myagg.sum = 0; - myagg.count = 0; - } for (int i=0;i extends VectorAggregateExpression { public void assignRowColumn(VectorizedRowBatch batch, int batchIndex, int columnNum, AggregationBuffer agg) throws HiveException { + Aggregation myagg = (Aggregation) agg; + #IF PARTIAL1 StructColumnVector outputColVector = (StructColumnVector) batch.cols[columnNum]; -#ENDIF PARTIAL1 -#IF COMPLETE - DoubleColumnVector outputColVector = (DoubleColumnVector) batch.cols[columnNum]; -#ENDIF COMPLETE - Aggregation myagg = (Aggregation) agg; - if (myagg.isNull) { - outputColVector.noNulls = false; - outputColVector.isNull[batchIndex] = true; - return; - } - 
Preconditions.checkState(myagg.count > 0); + // For AVG, we do not mark NULL if all inputs were NULL. outputColVector.isNull[batchIndex] = false; -#IF PARTIAL1 ColumnVector[] fields = outputColVector.fields; fields[AVERAGE_COUNT_FIELD_INDEX].isNull[batchIndex] = false; ((LongColumnVector) fields[AVERAGE_COUNT_FIELD_INDEX]).vector[batchIndex] = myagg.count; @@ -506,6 +431,15 @@ public class extends VectorAggregateExpression { #ENDIF PARTIAL1 #IF COMPLETE + DoubleColumnVector outputColVector = (DoubleColumnVector) batch.cols[columnNum]; + + if (myagg.count == 0) { + outputColVector.noNulls = false; + outputColVector.isNull[batchIndex] = true; + return; + } + outputColVector.isNull[batchIndex] = false; + outputColVector.vector[batchIndex] = myagg.sum / myagg.count; #ENDIF COMPLETE } diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMax.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMax.txt index 2df45bb..3569d51 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMax.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMax.txt @@ -145,15 +145,9 @@ public class extends VectorAggregateExpression { } } else { if (inputVector.isRepeating) { - if (batch.selectedInUse) { - iterateHasNullsRepeatingSelectionWithAggregationSelection( - aggregationBufferSets, aggregrateIndex, - vector[0], batchSize, batch.selected, inputVector.isNull); - } else { - iterateHasNullsRepeatingWithAggregationSelection( - aggregationBufferSets, aggregrateIndex, - vector[0], batchSize, inputVector.isNull); - } + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, aggregrateIndex, + vector[0], batchSize, inputVector.isNull); } else { if (batch.selectedInUse) { iterateHasNullsSelectionWithAggregationSelection( @@ -213,28 +207,6 @@ public class extends VectorAggregateExpression { } } - private void iterateHasNullsRepeatingSelectionWithAggregationSelection( - VectorAggregationBufferRow[] aggregationBufferSets, - int aggregrateIndex, - value, - int batchSize, - int[] selection, - boolean[] isNull) { - - if (isNull[0]) { - return; - } - - for (int i=0; i < batchSize; ++i) { - Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, - aggregrateIndex, - i); - myagg.minmaxValue(value); - } - - } - private void iterateHasNullsRepeatingWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregrateIndex, @@ -363,7 +335,7 @@ public class extends VectorAggregateExpression { for (int i=0; i< batchSize; ++i) { value = vector[selected[i]]; - myagg.minmaxValueNoCheck(value); + myagg.minmaxValue(value); } } @@ -437,7 +409,7 @@ public class extends VectorAggregateExpression { outputColVector = () batch.cols[columnNum]; Aggregation myagg = (Aggregation) agg; - if (myagg.isNull) { + if (myagg.isNull) { outputColVector.noNulls = false; outputColVector.isNull[batchIndex] = true; return; diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxDecimal.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxDecimal.txt index 9c8ebcc..eb63301 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxDecimal.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxDecimal.txt @@ -63,7 +63,7 @@ public class extends VectorAggregateExpression { value = new HiveDecimalWritable(); } - public void checkValue(HiveDecimalWritable writable, short scale) { + public void minmaxValue(HiveDecimalWritable writable, short scale) { if (isNull) { isNull = false; this.value.set(writable); @@ -144,15 +144,9 @@ public class extends 
VectorAggregateExpression { } } else { if (inputVector.isRepeating) { - if (batch.selectedInUse) { - iterateHasNullsRepeatingSelectionWithAggregationSelection( - aggregationBufferSets, aggregrateIndex, - vector[0], inputVector.scale, batchSize, batch.selected, inputVector.isNull); - } else { - iterateHasNullsRepeatingWithAggregationSelection( - aggregationBufferSets, aggregrateIndex, - vector[0], inputVector.scale, batchSize, inputVector.isNull); - } + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, aggregrateIndex, + vector[0], inputVector.scale, batchSize, inputVector.isNull); } else { if (batch.selectedInUse) { iterateHasNullsSelectionWithAggregationSelection( @@ -179,14 +173,14 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(value, scale); + myagg.minmaxValue(value, scale); } } private void iterateNoNullsSelectionWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregrateIndex, - HiveDecimalWritable[] values, + HiveDecimalWritable[] vector, short scale, int[] selection, int batchSize) { @@ -196,14 +190,14 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(values[selection[i]], scale); + myagg.minmaxValue(vector[selection[i]], scale); } } private void iterateNoNullsWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregrateIndex, - HiveDecimalWritable[] values, + HiveDecimalWritable[] vector, short scale, int batchSize) { for (int i=0; i < batchSize; ++i) { @@ -211,31 +205,10 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(values[i], scale); + myagg.minmaxValue(vector[i], scale); } } - private void iterateHasNullsRepeatingSelectionWithAggregationSelection( - VectorAggregationBufferRow[] aggregationBufferSets, - int aggregrateIndex, - HiveDecimalWritable value, - short scale, - int batchSize, - int[] selection, - boolean[] isNull) { - - for (int i=0; i < batchSize; ++i) { - if (!isNull[selection[i]]) { - Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, - aggregrateIndex, - i); - myagg.checkValue(value, scale); - } - } - - } - private void iterateHasNullsRepeatingWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregrateIndex, @@ -253,14 +226,14 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(value, scale); + myagg.minmaxValue(value, scale); } } private void iterateHasNullsSelectionWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregrateIndex, - HiveDecimalWritable[] values, + HiveDecimalWritable[] vector, short scale, int batchSize, int[] selection, @@ -273,7 +246,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, j); - myagg.checkValue(values[i], scale); + myagg.minmaxValue(vector[i], scale); } } } @@ -281,7 +254,7 @@ public class extends VectorAggregateExpression { private void iterateHasNullsWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregrateIndex, - HiveDecimalWritable[] values, + HiveDecimalWritable[] vector, short scale, int batchSize, boolean[] isNull) { @@ -292,7 +265,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(values[i], scale); + myagg.minmaxValue(vector[i], scale); } } } @@ 
-318,10 +291,8 @@ public class extends VectorAggregateExpression { HiveDecimalWritable[] vector = inputVector.vector; if (inputVector.isRepeating) { - if ((inputVector.noNulls || !inputVector.isNull[0]) && - (myagg.isNull || (myagg.value.compareTo(vector[0]) 0))) { - myagg.isNull = false; - myagg.value.set(vector[0]); + if (inputVector.noNulls || !inputVector.isNull[0]) { + myagg.minmaxValue(vector[0], inputVector.scale); } return; } @@ -353,14 +324,7 @@ public class extends VectorAggregateExpression { for (int j=0; j< batchSize; ++j) { int i = selected[j]; if (!isNull[i]) { - HiveDecimalWritable writable = vector[i]; - if (myagg.isNull) { - myagg.isNull = false; - myagg.value.set(writable); - } - else if (myagg.value.compareTo(writable) 0) { - myagg.value.set(writable); - } + myagg.minmaxValue(vector[i], scale); } } } @@ -372,16 +336,8 @@ public class extends VectorAggregateExpression { int batchSize, int[] selected) { - if (myagg.isNull) { - myagg.value.set(vector[selected[0]]); - myagg.isNull = false; - } - for (int i=0; i< batchSize; ++i) { - HiveDecimalWritable writable = vector[selected[i]]; - if (myagg.value.compareTo(writable) 0) { - myagg.value.set(writable); - } + myagg.minmaxValue(vector[selected[i]], scale); } } @@ -394,14 +350,7 @@ public class extends VectorAggregateExpression { for(int i=0;i 0) { - myagg.value.set(writable); - } + myagg.minmaxValue(vector[i], scale); } } } @@ -411,16 +360,9 @@ public class extends VectorAggregateExpression { HiveDecimalWritable[] vector, short scale, int batchSize) { - if (myagg.isNull) { - myagg.value.set(vector[0]); - myagg.isNull = false; - } for (int i=0;i 0) { - myagg.value.set(writable); - } + myagg.minmaxValue(vector[i], scale); } } diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxIntervalDayTime.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxIntervalDayTime.txt index 9a0a6e7..9fdf77c 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxIntervalDayTime.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxIntervalDayTime.txt @@ -62,7 +62,7 @@ public class extends VectorAggregateExpression { value = new HiveIntervalDayTime(); } - public void checkValue(IntervalDayTimeColumnVector colVector, int index) { + public void minmaxValue(IntervalDayTimeColumnVector colVector, int index) { if (isNull) { isNull = false; colVector.intervalDayTimeUpdate(this.value, index); @@ -141,15 +141,9 @@ public class extends VectorAggregateExpression { } } else { if (inputColVector.isRepeating) { - if (batch.selectedInUse) { - iterateHasNullsRepeatingSelectionWithAggregationSelection( - aggregationBufferSets, aggregrateIndex, - inputColVector, batchSize, batch.selected, inputColVector.isNull); - } else { - iterateHasNullsRepeatingWithAggregationSelection( - aggregationBufferSets, aggregrateIndex, - inputColVector, batchSize, inputColVector.isNull); - } + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, aggregrateIndex, + inputColVector, batchSize, inputColVector.isNull); } else { if (batch.selectedInUse) { iterateHasNullsSelectionWithAggregationSelection( @@ -176,7 +170,7 @@ public class extends VectorAggregateExpression { aggregrateIndex, i); // Repeating use index 0. 
- myagg.checkValue(inputColVector, 0); + myagg.minmaxValue(inputColVector, 0); } } @@ -192,7 +186,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(inputColVector, selection[i]); + myagg.minmaxValue(inputColVector, selection[i]); } } @@ -206,47 +200,28 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(inputColVector, i); + myagg.minmaxValue(inputColVector, i); } } - private void iterateHasNullsRepeatingSelectionWithAggregationSelection( + private void iterateHasNullsRepeatingWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregrateIndex, IntervalDayTimeColumnVector inputColVector, int batchSize, - int[] selection, boolean[] isNull) { - for (int i=0; i < batchSize; ++i) { - if (!isNull[selection[i]]) { - Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, - aggregrateIndex, - i); - // Repeating use index 0. - myagg.checkValue(inputColVector, 0); - } + if (isNull[0]) { + return; } - } - - private void iterateHasNullsRepeatingWithAggregationSelection( - VectorAggregationBufferRow[] aggregationBufferSets, - int aggregrateIndex, - IntervalDayTimeColumnVector inputColVector, - int batchSize, - boolean[] isNull) { - for (int i=0; i < batchSize; ++i) { - if (!isNull[i]) { - Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, - aggregrateIndex, - i); - // Repeating use index 0. - myagg.checkValue(inputColVector, 0); - } + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregrateIndex, + i); + // Repeating use index 0. + myagg.minmaxValue(inputColVector, 0); } } @@ -265,7 +240,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, j); - myagg.checkValue(inputColVector, i); + myagg.minmaxValue(inputColVector, i); } } } @@ -283,7 +258,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(inputColVector, i); + myagg.minmaxValue(inputColVector, i); } } } @@ -307,10 +282,8 @@ public class extends VectorAggregateExpression { Aggregation myagg = (Aggregation)agg; if (inputColVector.isRepeating) { - if ((inputColVector.noNulls || !inputColVector.isNull[0]) && - (myagg.isNull || (inputColVector.compareTo(myagg.value, 0) 0))) { - myagg.isNull = false; - inputColVector.intervalDayTimeUpdate(myagg.value, 0); + if (inputColVector.noNulls || !inputColVector.isNull[0]) { + myagg.minmaxValue(inputColVector, 0); } return; } @@ -341,13 +314,7 @@ public class extends VectorAggregateExpression { for (int j=0; j< batchSize; ++j) { int i = selected[j]; if (!isNull[i]) { - if (myagg.isNull) { - myagg.isNull = false; - inputColVector.intervalDayTimeUpdate(myagg.value, i); - } - else if (inputColVector.compareTo(myagg.value, i) 0) { - inputColVector.intervalDayTimeUpdate(myagg.value, i); - } + myagg.minmaxValue(inputColVector, i); } } } @@ -358,16 +325,9 @@ public class extends VectorAggregateExpression { int batchSize, int[] selected) { - if (myagg.isNull) { - inputColVector.intervalDayTimeUpdate(myagg.value, selected[0]); - myagg.isNull = false; - } - for (int i=0; i< batchSize; ++i) { int sel = selected[i]; - if (inputColVector.compareTo(myagg.value, sel) 0) { - inputColVector.intervalDayTimeUpdate(myagg.value, sel); - } + myagg.minmaxValue(inputColVector, sel); } } @@ -379,13 +339,7 @@ public class extends VectorAggregateExpression { for(int i=0;i 0) { - 
inputColVector.intervalDayTimeUpdate(myagg.value, i); - } + myagg.minmaxValue(inputColVector, i); } } } @@ -394,15 +348,9 @@ public class extends VectorAggregateExpression { Aggregation myagg, IntervalDayTimeColumnVector inputColVector, int batchSize) { - if (myagg.isNull) { - inputColVector.intervalDayTimeUpdate(myagg.value, 0); - myagg.isNull = false; - } for (int i=0;i 0) { - inputColVector.intervalDayTimeUpdate(myagg.value, i); - } + myagg.minmaxValue(inputColVector, i); } } @@ -447,7 +395,7 @@ public class extends VectorAggregateExpression { IntervalDayTimeColumnVector outputColVector = (IntervalDayTimeColumnVector) batch.cols[columnNum]; Aggregation myagg = (Aggregation) agg; - if (myagg.isNull) { + if (myagg.isNull) { outputColVector.noNulls = false; outputColVector.isNull[batchIndex] = true; return; diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxString.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxString.txt index 4f0b5a5..3387c0d 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxString.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxString.txt @@ -60,7 +60,7 @@ public class extends VectorAggregateExpression { */ transient private boolean isNull = true; - public void checkValue(byte[] bytes, int start, int length) { + public void minmaxValue(byte[] bytes, int start, int length) { if (isNull) { isNull = false; assign(bytes, start, length); @@ -151,7 +151,9 @@ public class extends VectorAggregateExpression { } } else { if (inputColumn.isRepeating) { - // All nulls, no-op for min/max + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, aggregrateIndex, + inputColumn, batchSize, inputColumn.isNull); } else { if (batch.selectedInUse) { iterateHasNullsSelectionWithAggregationSelection( @@ -180,7 +182,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(bytes, start, length); + myagg.minmaxValue(bytes, start, length); } } @@ -197,7 +199,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(inputColumn.vector[row], + myagg.minmaxValue(inputColumn.vector[row], inputColumn.start[row], inputColumn.length[row]); } @@ -213,12 +215,36 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(inputColumn.vector[i], + myagg.minmaxValue(inputColumn.vector[i], inputColumn.start[i], inputColumn.length[i]); } } + private void iterateHasNullsRepeatingWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregrateIndex, + BytesColumnVector inputColumn, + int batchSize, + boolean[] isNull) { + + if (isNull[0]) { + return; + } + + for (int i=0; i < batchSize; ++i) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregrateIndex, + i); + // Repeating use index 0. 
+ myagg.minmaxValue(inputColumn.vector[0], + inputColumn.start[0], + inputColumn.length[0]); + } + + } + private void iterateHasNullsSelectionWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregrateIndex, @@ -233,7 +259,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(inputColumn.vector[row], + myagg.minmaxValue(inputColumn.vector[row], inputColumn.start[row], inputColumn.length[row]); } @@ -252,7 +278,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(inputColumn.vector[i], + myagg.minmaxValue(inputColumn.vector[i], inputColumn.start[i], inputColumn.length[i]); } @@ -279,7 +305,7 @@ public class extends VectorAggregateExpression { if (inputColumn.isRepeating) { if (inputColumn.noNulls || !inputColumn.isNull[0]) { - myagg.checkValue(inputColumn.vector[0], + myagg.minmaxValue(inputColumn.vector[0], inputColumn.start[0], inputColumn.length[0]); } @@ -309,7 +335,7 @@ public class extends VectorAggregateExpression { for (int j=0; j< batchSize; ++j) { int i = selected[j]; if (!inputColumn.isNull[i]) { - myagg.checkValue(inputColumn.vector[i], + myagg.minmaxValue(inputColumn.vector[i], inputColumn.start[i], inputColumn.length[i]); } @@ -324,7 +350,7 @@ public class extends VectorAggregateExpression { for (int j=0; j< batchSize; ++j) { int i = selected[j]; - myagg.checkValue(inputColumn.vector[i], + myagg.minmaxValue(inputColumn.vector[i], inputColumn.start[i], inputColumn.length[i]); } @@ -337,7 +363,7 @@ public class extends VectorAggregateExpression { for (int i=0; i< batchSize; ++i) { if (!inputColumn.isNull[i]) { - myagg.checkValue(inputColumn.vector[i], + myagg.minmaxValue(inputColumn.vector[i], inputColumn.start[i], inputColumn.length[i]); } @@ -349,7 +375,7 @@ public class extends VectorAggregateExpression { BytesColumnVector inputColumn, int batchSize) { for (int i=0; i< batchSize; ++i) { - myagg.checkValue(inputColumn.vector[i], + myagg.minmaxValue(inputColumn.vector[i], inputColumn.start[i], inputColumn.length[i]); } diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxTimestamp.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxTimestamp.txt index 5114cda..b8d71d6 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxTimestamp.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxTimestamp.txt @@ -64,7 +64,7 @@ public class extends VectorAggregateExpression { value = new Timestamp(0); } - public void checkValue(TimestampColumnVector colVector, int index) { + public void minmaxValue(TimestampColumnVector colVector, int index) { if (isNull) { isNull = false; colVector.timestampUpdate(this.value, index); @@ -143,15 +143,9 @@ public class extends VectorAggregateExpression { } } else { if (inputColVector.isRepeating) { - if (batch.selectedInUse) { - iterateHasNullsRepeatingSelectionWithAggregationSelection( - aggregationBufferSets, aggregrateIndex, - inputColVector, batchSize, batch.selected, inputColVector.isNull); - } else { - iterateHasNullsRepeatingWithAggregationSelection( - aggregationBufferSets, aggregrateIndex, - inputColVector, batchSize, inputColVector.isNull); - } + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, aggregrateIndex, + inputColVector, batchSize, inputColVector.isNull); } else { if (batch.selectedInUse) { iterateHasNullsSelectionWithAggregationSelection( @@ -178,7 +172,7 @@ public class extends VectorAggregateExpression { 
aggregrateIndex, i); // Repeating use index 0. - myagg.checkValue(inputColVector, 0); + myagg.minmaxValue(inputColVector, 0); } } @@ -194,7 +188,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(inputColVector, selection[i]); + myagg.minmaxValue(inputColVector, selection[i]); } } @@ -208,47 +202,28 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(inputColVector, i); + myagg.minmaxValue(inputColVector, i); } } - private void iterateHasNullsRepeatingSelectionWithAggregationSelection( + private void iterateHasNullsRepeatingWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregrateIndex, TimestampColumnVector inputColVector, int batchSize, - int[] selection, boolean[] isNull) { - for (int i=0; i < batchSize; ++i) { - if (!isNull[selection[i]]) { - Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, - aggregrateIndex, - i); - // Repeating use index 0. - myagg.checkValue(inputColVector, 0); - } + if (isNull[0]) { + return; } - } - - private void iterateHasNullsRepeatingWithAggregationSelection( - VectorAggregationBufferRow[] aggregationBufferSets, - int aggregrateIndex, - TimestampColumnVector inputColVector, - int batchSize, - boolean[] isNull) { - for (int i=0; i < batchSize; ++i) { - if (!isNull[i]) { - Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, - aggregrateIndex, - i); - // Repeating use index 0. - myagg.checkValue(inputColVector, 0); - } + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregrateIndex, + i); + // Repeating use index 0. + myagg.minmaxValue(inputColVector, 0); } } @@ -267,7 +242,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, j); - myagg.checkValue(inputColVector, i); + myagg.minmaxValue(inputColVector, i); } } } @@ -285,7 +260,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(inputColVector, i); + myagg.minmaxValue(inputColVector, i); } } } @@ -309,10 +284,8 @@ public class extends VectorAggregateExpression { Aggregation myagg = (Aggregation)agg; if (inputColVector.isRepeating) { - if ((inputColVector.noNulls || !inputColVector.isNull[0]) && - (myagg.isNull || (inputColVector.compareTo(myagg.value, 0) 0))) { - myagg.isNull = false; - inputColVector.timestampUpdate(myagg.value, 0); + if (inputColVector.noNulls || !inputColVector.isNull[0]) { + myagg.minmaxValue(inputColVector, 0); } return; } @@ -343,13 +316,7 @@ public class extends VectorAggregateExpression { for (int j=0; j< batchSize; ++j) { int i = selected[j]; if (!isNull[i]) { - if (myagg.isNull) { - myagg.isNull = false; - inputColVector.timestampUpdate(myagg.value, i); - } - else if (inputColVector.compareTo(myagg.value, i) 0) { - inputColVector.timestampUpdate(myagg.value, i); - } + myagg.minmaxValue(inputColVector, i); } } } @@ -360,16 +327,9 @@ public class extends VectorAggregateExpression { int batchSize, int[] selected) { - if (myagg.isNull) { - inputColVector.timestampUpdate(myagg.value, selected[0]); - myagg.isNull = false; - } - - for (int i=0; i< batchSize; ++i) { + for (int i=0; i< batchSize; ++i) { int sel = selected[i]; - if (inputColVector.compareTo(myagg.value, sel) 0) { - inputColVector.timestampUpdate(myagg.value, sel); - } + myagg.minmaxValue(inputColVector, sel); } } @@ -381,13 +341,7 @@ public class extends VectorAggregateExpression { for(int 
i=0;i 0) { - inputColVector.timestampUpdate(myagg.value, i); - } + myagg.minmaxValue(inputColVector, i); } } } @@ -396,15 +350,9 @@ public class extends VectorAggregateExpression { Aggregation myagg, TimestampColumnVector inputColVector, int batchSize) { - if (myagg.isNull) { - inputColVector.timestampUpdate(myagg.value, 0); - myagg.isNull = false; - } for (int i=0;i 0) { - inputColVector.timestampUpdate(myagg.value, i); - } + myagg.minmaxValue(inputColVector, i); } } @@ -449,7 +397,7 @@ public class extends VectorAggregateExpression { TimestampColumnVector outputColVector = (TimestampColumnVector) batch.cols[columnNum]; Aggregation myagg = (Aggregation) agg; - if (myagg.isNull) { + if (myagg.isNull) { outputColVector.noNulls = false; outputColVector.isNull[batchIndex] = true; return; diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFSum.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFSum.txt index c731869..548125e 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFSum.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFSum.txt @@ -142,15 +142,9 @@ public class extends VectorAggregateExpression { } } else { if (inputVector.isRepeating) { - if (batch.selectedInUse) { - iterateHasNullsRepeatingSelectionWithAggregationSelection( - aggregationBufferSets, aggregateIndex, - vector[0], batchSize, batch.selected, inputVector.isNull); - } else { - iterateHasNullsRepeatingWithAggregationSelection( - aggregationBufferSets, aggregateIndex, - vector[0], batchSize, inputVector.isNull); - } + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, aggregateIndex, + vector[0], batchSize, inputVector.isNull); } else { if (batch.selectedInUse) { iterateHasNullsSelectionWithAggregationSelection( @@ -210,28 +204,6 @@ public class extends VectorAggregateExpression { } } - private void iterateHasNullsRepeatingSelectionWithAggregationSelection( - VectorAggregationBufferRow[] aggregationBufferSets, - int aggregateIndex, - value, - int batchSize, - int[] selection, - boolean[] isNull) { - - if (isNull[0]) { - return; - } - - for (int i=0; i < batchSize; ++i) { - Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, - aggregateIndex, - i); - myagg.sumValue(value); - } - - } - private void iterateHasNullsRepeatingWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVar.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVar.txt index 876ead5..995190f 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVar.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVar.txt @@ -127,7 +127,7 @@ public class extends VectorAggregateExpression { private void init() { #IF COMPLETE - String aggregateName = vecAggrDesc.getAggrDesc().getGenericUDAFName(); + String aggregateName = vecAggrDesc.getAggregationName(); varianceKind = VarianceKind.nameMap.get(aggregateName); #ENDIF COMPLETE } @@ -490,11 +490,8 @@ public class extends VectorAggregateExpression { StructColumnVector outputColVector = (StructColumnVector) batch.cols[columnNum]; Aggregation myagg = (Aggregation) agg; - if (myagg.isNull) { - outputColVector.noNulls = false; - outputColVector.isNull[batchIndex] = true; - return; - } + + // For Variance Family, we do not mark NULL if all inputs were NULL. 
outputColVector.isNull[batchIndex] = false; ColumnVector[] fields = outputColVector.fields;
diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarDecimal.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarDecimal.txt index cf19b14..a831610 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarDecimal.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarDecimal.txt @@ -138,7 +138,7 @@ public class <ClassName> extends VectorAggregateExpression { private void init() { #IF COMPLETE - String aggregateName = vecAggrDesc.getAggrDesc().getGenericUDAFName(); + String aggregateName = vecAggrDesc.getAggregationName(); varianceKind = VarianceKind.nameMap.get(aggregateName); #ENDIF COMPLETE } @@ -450,15 +450,12 @@ public class <ClassName> extends VectorAggregateExpression { public void assignRowColumn(VectorizedRowBatch batch, int batchIndex, int columnNum, AggregationBuffer agg) throws HiveException { + Aggregation myagg = (Aggregation) agg; + #IF PARTIAL1 StructColumnVector outputColVector = (StructColumnVector) batch.cols[columnNum]; - Aggregation myagg = (Aggregation) agg; - if (myagg.isNull) { - outputColVector.noNulls = false; - outputColVector.isNull[batchIndex] = true; - return; - } + // For Variance Family, we do not mark NULL if all inputs were NULL. outputColVector.isNull[batchIndex] = false; ColumnVector[] fields = outputColVector.fields; @@ -469,7 +466,13 @@ public class <ClassName> extends VectorAggregateExpression { #IF COMPLETE DoubleColumnVector outputColVector = (DoubleColumnVector) batch.cols[columnNum]; - Aggregation myagg = (Aggregation) agg; + if (myagg.isNull) { + outputColVector.noNulls = false; + outputColVector.isNull[batchIndex] = true; + return; + } + outputColVector.isNull[batchIndex] = false; + if (GenericUDAFVariance.isVarianceNull(myagg.count, varianceKind)) { // SQL standard - return null for zero (or 1 for sample) elements
diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarMerge.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarMerge.txt index ccc5a22..a80c261 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarMerge.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarMerge.txt @@ -60,27 +60,21 @@ public class <ClassName> extends VectorAggregateExpression { transient private double mergeCount; transient private double mergeSum; transient private double mergeVariance; - /** - * Value is explicitly (re)initialized in reset() - */ - transient private boolean isNull = true; - public void merge(long partialCount, double partialSum, double partialVariance) { - if (isNull || mergeCount == 0) { + if (mergeCount == 0) { // Just copy the information since there is nothing so far. mergeCount = partialCount; mergeSum = partialSum; mergeVariance = partialVariance; - isNull = false; return; } if (partialCount > 0 && mergeCount > 0) { // Merge the two partials.
- mergeVariance += + mergeVariance = GenericUDAFVariance.calculateMerge( partialCount, mergeCount, partialSum, mergeSum, partialVariance, mergeVariance); @@ -98,7 +95,6 @@ public class extends VectorAggregateExpression { @Override public void reset () { - isNull = true; mergeCount = 0L; mergeSum = 0; mergeVariance = 0; @@ -127,7 +123,7 @@ public class extends VectorAggregateExpression { private void init() { #IF FINAL - String aggregateName = vecAggrDesc.getAggrDesc().getGenericUDAFName(); + String aggregateName = vecAggrDesc.getAggregationName(); varianceKind = VarianceKind.nameMap.get(aggregateName); #ENDIF FINAL } @@ -183,15 +179,9 @@ public class extends VectorAggregateExpression { } } else { if (inputStructColVector.isRepeating) { - if (batch.selectedInUse) { - iterateHasNullsRepeatingSelectionWithAggregationSelection( - aggregationBufferSets, bufferIndex, - countVector[0], sumVector[0], varianceVector[0], batchSize, batch.selected, inputStructColVector.isNull); - } else { - iterateHasNullsRepeatingWithAggregationSelection( - aggregationBufferSets, bufferIndex, - countVector[0], sumVector[0], varianceVector[0], batchSize, inputStructColVector.isNull); - } + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, bufferIndex, + countVector[0], sumVector[0], varianceVector[0], batchSize, inputStructColVector.isNull); } else { if (batch.selectedInUse) { iterateHasNullsSelectionWithAggregationSelection( @@ -258,30 +248,6 @@ public class extends VectorAggregateExpression { } } - private void iterateHasNullsRepeatingSelectionWithAggregationSelection( - VectorAggregationBufferRow[] aggregationBufferSets, - int bufferIndex, - long count, - double sum, - double variance, - int batchSize, - int[] selection, - boolean[] isNull) { - - if (isNull[0]) { - return; - } - - for (int i=0; i < batchSize; ++i) { - Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, - bufferIndex, - i); - myagg.merge(count, sum, variance); - } - - } - private void iterateHasNullsRepeatingWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int bufferIndex, @@ -488,8 +454,6 @@ public class extends VectorAggregateExpression { #ENDIF FINAL */ -/* - There seems to be a Wrong Results bug in VectorUDAFVarFinal -- disabling vectorization for now... return GenericUDAFVariance.isVarianceFamilyName(name) && inputColVectorType == ColumnVector.Type.STRUCT && @@ -501,8 +465,6 @@ public class extends VectorAggregateExpression { outputColVectorType == ColumnVector.Type.DOUBLE && mode == Mode.FINAL; #ENDIF FINAL -*/ - return false; } @Override @@ -513,11 +475,8 @@ public class extends VectorAggregateExpression { StructColumnVector outputColVector = (StructColumnVector) batch.cols[columnNum]; Aggregation myagg = (Aggregation) agg; - if (myagg.isNull) { - outputColVector.noNulls = false; - outputColVector.isNull[batchIndex] = true; - return; - } + + // For Variance Family, we do not mark NULL if all inputs were NULL. 
outputColVector.isNull[batchIndex] = false; ColumnVector[] fields = outputColVector.fields; diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarTimestamp.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarTimestamp.txt index 1dd5ab4..4e79f22 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarTimestamp.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarTimestamp.txt @@ -128,7 +128,7 @@ public class extends VectorAggregateExpression { private void init() { #IF COMPLETE - String aggregateName = vecAggrDesc.getAggrDesc().getGenericUDAFName(); + String aggregateName = vecAggrDesc.getAggregationName(); varianceKind = VarianceKind.nameMap.get(aggregateName); #ENDIF COMPLETE } @@ -422,15 +422,12 @@ public class extends VectorAggregateExpression { public void assignRowColumn(VectorizedRowBatch batch, int batchIndex, int columnNum, AggregationBuffer agg) throws HiveException { + Aggregation myagg = (Aggregation) agg; + #IF PARTIAL1 StructColumnVector outputColVector = (StructColumnVector) batch.cols[columnNum]; - Aggregation myagg = (Aggregation) agg; - if (myagg.isNull) { - outputColVector.noNulls = false; - outputColVector.isNull[batchIndex] = true; - return; - } + // For Variance Family, we do not mark NULL if all inputs were NULL. outputColVector.isNull[batchIndex] = false; ColumnVector[] fields = outputColVector.fields; @@ -441,7 +438,13 @@ public class extends VectorAggregateExpression { #IF COMPLETE DoubleColumnVector outputColVector = (DoubleColumnVector) batch.cols[columnNum]; - Aggregation myagg = (Aggregation) agg; + if (myagg.isNull) { + outputColVector.noNulls = false; + outputColVector.isNull[batchIndex] = true; + return; + } + outputColVector.isNull[batchIndex] = false; + if (GenericUDAFVariance.isVarianceNull(myagg.count, varianceKind)) { // SQL standard - return null for zero (or 1 for sample) elements diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorAggregationDesc.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorAggregationDesc.java index 5736399..417beec 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorAggregationDesc.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorAggregationDesc.java @@ -86,7 +86,7 @@ private static final long serialVersionUID = 1L; - private final AggregationDesc aggrDesc; + private final String aggregationName; private final TypeInfo inputTypeInfo; private final ColumnVector.Type inputColVectorType; @@ -99,15 +99,19 @@ private final Class vecAggrClass; private GenericUDAFEvaluator evaluator; + private GenericUDAFEvaluator.Mode udafEvaluatorMode; - public VectorAggregationDesc(AggregationDesc aggrDesc, GenericUDAFEvaluator evaluator, + public VectorAggregationDesc(String aggregationName, GenericUDAFEvaluator evaluator, + GenericUDAFEvaluator.Mode udafEvaluatorMode, TypeInfo inputTypeInfo, ColumnVector.Type inputColVectorType, VectorExpression inputExpression, TypeInfo outputTypeInfo, ColumnVector.Type outputColVectorType, Class vecAggrClass) { - this.aggrDesc = aggrDesc; + this.aggregationName = aggregationName; + this.evaluator = evaluator; + this.udafEvaluatorMode = udafEvaluatorMode; this.inputTypeInfo = inputTypeInfo; this.inputColVectorType = inputColVectorType; @@ -122,8 +126,12 @@ public VectorAggregationDesc(AggregationDesc aggrDesc, GenericUDAFEvaluator eval this.vecAggrClass = vecAggrClass; } - public AggregationDesc getAggrDesc() { - return aggrDesc; + public String getAggregationName() { + return aggregationName; + } + + public 
GenericUDAFEvaluator.Mode getUdafEvaluatorMode() { + return udafEvaluatorMode; } public TypeInfo getInputTypeInfo() { @@ -174,7 +182,6 @@ public String toString() { sb.append("/"); sb.append(outputDataTypePhysicalVariation); } - String aggregationName = aggrDesc.getGenericUDAFName(); if (GenericUDAFVariance.isVarianceFamilyName(aggregationName)) { sb.append(" aggregation: "); sb.append(aggregationName); diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorAggregateExpression.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorAggregateExpression.java index 3224557..2499f09 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorAggregateExpression.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorAggregateExpression.java @@ -84,7 +84,7 @@ public VectorAggregateExpression(VectorAggregationDesc vecAggrDesc) { outputTypeInfo = vecAggrDesc.getOutputTypeInfo(); outputDataTypePhysicalVariation = vecAggrDesc.getOutputDataTypePhysicalVariation(); - mode = vecAggrDesc.getAggrDesc().getMode(); + mode = vecAggrDesc.getUdafEvaluatorMode(); } public VectorExpression getInputExpression() { diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFCountMerge.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFCountMerge.java index 0463de5..bd781af 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFCountMerge.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFCountMerge.java @@ -117,15 +117,9 @@ public void aggregateInputSelection( } } else { if (inputVector.isRepeating) { - if (batch.selectedInUse) { - iterateHasNullsRepeatingSelectionWithAggregationSelection( - aggregationBufferSets, aggregateIndex, - vector[0], batchSize, batch.selected, inputVector.isNull); - } else { - iterateHasNullsRepeatingWithAggregationSelection( - aggregationBufferSets, aggregateIndex, - vector[0], batchSize, inputVector.isNull); - } + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, aggregateIndex, + vector[0], batchSize, inputVector.isNull); } else { if (batch.selectedInUse) { iterateHasNullsSelectionWithAggregationSelection( @@ -185,28 +179,6 @@ private void iterateNoNullsWithAggregationSelection( } } - private void iterateHasNullsRepeatingSelectionWithAggregationSelection( - VectorAggregationBufferRow[] aggregationBufferSets, - int aggregateIndex, - long value, - int batchSize, - int[] selection, - boolean[] isNull) { - - if (isNull[0]) { - return; - } - - for (int i=0; i < batchSize; ++i) { - Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, - aggregateIndex, - i); - myagg.value += value; - } - - } - private void iterateHasNullsRepeatingWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal.java index 315b72b..469f610 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal.java @@ -139,17 +139,10 @@ public void aggregateInputSelection( } } else { if (inputVector.isRepeating) { - if 
(batch.selectedInUse) { - iterateHasNullsRepeatingSelectionWithAggregationSelection( - aggregationBufferSets, aggregateIndex, - vector[0], - batchSize, batch.selected, inputVector.isNull); - } else { - iterateHasNullsRepeatingWithAggregationSelection( - aggregationBufferSets, aggregateIndex, - vector[0], - batchSize, inputVector.isNull); - } + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, aggregateIndex, + vector[0], + batchSize, inputVector.isNull); } else { if (batch.selectedInUse) { iterateHasNullsSelectionWithAggregationSelection( @@ -211,28 +204,6 @@ private void iterateNoNullsWithAggregationSelection( } } - private void iterateHasNullsRepeatingSelectionWithAggregationSelection( - VectorAggregationBufferRow[] aggregationBufferSets, - int aggregateIndex, - HiveDecimalWritable value, - int batchSize, - int[] selection, - boolean[] isNull) { - - if (isNull[0]) { - return; - } - - for (int i=0; i < batchSize; ++i) { - Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, - aggregateIndex, - i); - myagg.sumValue(value); - } - - } - private void iterateHasNullsRepeatingWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64.java index a503445..7f2a18a 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64.java @@ -164,15 +164,9 @@ public void aggregateInputSelection( } } else { if (inputVector.isRepeating) { - if (batch.selectedInUse) { - iterateHasNullsRepeatingSelectionWithAggregationSelection( - aggregationBufferSets, aggregateIndex, - vector[0], batchSize, batch.selected, inputVector.isNull); - } else { - iterateHasNullsRepeatingWithAggregationSelection( - aggregationBufferSets, aggregateIndex, - vector[0], batchSize, inputVector.isNull); - } + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, aggregateIndex, + vector[0], batchSize, inputVector.isNull); } else { if (batch.selectedInUse) { iterateHasNullsSelectionWithAggregationSelection( @@ -232,28 +226,6 @@ private void iterateNoNullsWithAggregationSelection( } } - private void iterateHasNullsRepeatingSelectionWithAggregationSelection( - VectorAggregationBufferRow[] aggregationBufferSets, - int aggregateIndex, - long value, - int batchSize, - int[] selection, - boolean[] isNull) { - - if (isNull[0]) { - return; - } - - for (int i=0; i < batchSize; ++i) { - Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, - aggregateIndex, - i); - myagg.sumValue(value); - } - - } - private void iterateHasNullsRepeatingWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64ToDecimal.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64ToDecimal.java index 117611e..a02bdf3 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64ToDecimal.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64ToDecimal.java @@ -189,15 +189,9 @@ public void aggregateInputSelection( } } 
else { if (inputVector.isRepeating) { - if (batch.selectedInUse) { - iterateHasNullsRepeatingSelectionWithAggregationSelection( - aggregationBufferSets, aggregateIndex, - vector[0], batchSize, batch.selected, inputVector.isNull); - } else { - iterateHasNullsRepeatingWithAggregationSelection( - aggregationBufferSets, aggregateIndex, - vector[0], batchSize, inputVector.isNull); - } + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, aggregateIndex, + vector[0], batchSize, inputVector.isNull); } else { if (batch.selectedInUse) { iterateHasNullsSelectionWithAggregationSelection( @@ -257,28 +251,6 @@ private void iterateNoNullsWithAggregationSelection( } } - private void iterateHasNullsRepeatingSelectionWithAggregationSelection( - VectorAggregationBufferRow[] aggregationBufferSets, - int aggregateIndex, - long value, - int batchSize, - int[] selection, - boolean[] isNull) { - - if (isNull[0]) { - return; - } - - for (int i=0; i < batchSize; ++i) { - Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, - aggregateIndex, - i); - myagg.sumValue(value); - } - - } - private void iterateHasNullsRepeatingWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumTimestamp.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumTimestamp.java index e542033..731a143 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumTimestamp.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumTimestamp.java @@ -131,15 +131,9 @@ public void aggregateInputSelection( } } else { if (inputVector.isRepeating) { - if (batch.selectedInUse) { - iterateHasNullsRepeatingSelectionWithAggregationSelection( - aggregationBufferSets, aggregateIndex, - inputVector.getDouble(0), batchSize, batch.selected, inputVector.isNull); - } else { - iterateHasNullsRepeatingWithAggregationSelection( - aggregationBufferSets, aggregateIndex, - inputVector.getDouble(0), batchSize, inputVector.isNull); - } + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, aggregateIndex, + inputVector.getDouble(0), batchSize, inputVector.isNull); } else { if (batch.selectedInUse) { iterateHasNullsSelectionWithAggregationSelection( @@ -199,28 +193,6 @@ private void iterateNoNullsWithAggregationSelection( } } - private void iterateHasNullsRepeatingSelectionWithAggregationSelection( - VectorAggregationBufferRow[] aggregationBufferSets, - int aggregateIndex, - double value, - int batchSize, - int[] selection, - boolean[] isNull) { - - if (isNull[0]) { - return; - } - - for (int i=0; i < batchSize; ++i) { - Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, - aggregateIndex, - i); - myagg.sumValue(value); - } - - } - private void iterateHasNullsRepeatingWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java index 7afbf04..7ec80e6 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java @@ -4183,7 +4183,7 @@ private boolean usesVectorUDFAdaptor(VectorExpression[] vecExprs) { AggregationDesc aggrDesc, VectorizationContext 
vContext) throws HiveException { String aggregateName = aggrDesc.getGenericUDAFName(); - ArrayList parameterList = aggrDesc.getParameters(); + List parameterList = aggrDesc.getParameters(); final int parameterCount = parameterList.size(); final GenericUDAFEvaluator.Mode udafEvaluatorMode = aggrDesc.getMode(); @@ -4192,10 +4192,9 @@ private boolean usesVectorUDFAdaptor(VectorExpression[] vecExprs) { */ GenericUDAFEvaluator evaluator = aggrDesc.getGenericUDAFEvaluator(); - ArrayList parameters = aggrDesc.getParameters(); ObjectInspector[] parameterObjectInspectors = new ObjectInspector[parameterCount]; for (int i = 0; i < parameterCount; i++) { - TypeInfo typeInfo = parameters.get(i).getTypeInfo(); + TypeInfo typeInfo = parameterList.get(i).getTypeInfo(); parameterObjectInspectors[i] = TypeInfoUtils .getStandardWritableObjectInspectorFromTypeInfo(typeInfo); } @@ -4207,18 +4206,30 @@ private boolean usesVectorUDFAdaptor(VectorExpression[] vecExprs) { aggrDesc.getMode(), parameterObjectInspectors); + final TypeInfo outputTypeInfo = TypeInfoUtils.getTypeInfoFromTypeString(returnOI.getTypeName()); + + return getVectorAggregationDesc( + aggregateName, parameterList, evaluator, outputTypeInfo, udafEvaluatorMode, vContext); + } + + public static ImmutablePair getVectorAggregationDesc( + String aggregationName, List parameterList, + GenericUDAFEvaluator evaluator, TypeInfo outputTypeInfo, + GenericUDAFEvaluator.Mode udafEvaluatorMode, + VectorizationContext vContext) + throws HiveException { + VectorizedUDAFs annotation = AnnotationUtils.getAnnotation(evaluator.getClass(), VectorizedUDAFs.class); if (annotation == null) { String issue = "Evaluator " + evaluator.getClass().getSimpleName() + " does not have a " + - "vectorized UDAF annotation (aggregation: \"" + aggregateName + "\"). " + + "vectorized UDAF annotation (aggregation: \"" + aggregationName + "\"). " + "Vectorization not supported"; return new ImmutablePair(null, issue); } final Class[] vecAggrClasses = annotation.value(); - final TypeInfo outputTypeInfo = TypeInfoUtils.getTypeInfoFromTypeString(returnOI.getTypeName()); // Not final since it may change later due to DECIMAL_64. 
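A note on the refactor in this hunk: the type-matching logic moves out of the private AggregationDesc-based path into a public static getVectorAggregationDesc that takes the aggregation name, parameter list, evaluator, output type, and evaluator mode explicitly, so tests can drive it without constructing an AggregationDesc. The result convention is an ImmutablePair in which exactly one side is non-null. A usage sketch, assuming parameterList, evaluator, outputTypeInfo, and vContext are already in scope (this mirrors the call made in AggregationBase further down):

    ImmutablePair<VectorAggregationDesc, String> pair =
        Vectorizer.getVectorAggregationDesc(
            "avg", parameterList, evaluator, outputTypeInfo,
            GenericUDAFEvaluator.Mode.PARTIAL1, vContext);
    if (pair.left == null) {
      // pair.right carries the human-readable reason vectorization was declined
      System.out.println("not vectorized: " + pair.right);
    } else {
      VectorAggregationDesc vecAggrDesc = pair.left;
      // ... instantiate vecAggrDesc.getVecAggrClass() and aggregate ...
    }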
ColumnVector.Type outputColVectorType = @@ -4233,6 +4244,7 @@ private boolean usesVectorUDFAdaptor(VectorExpression[] vecExprs) { VectorExpression inputExpression; ColumnVector.Type inputColVectorType; + final int parameterCount = parameterList.size(); if (parameterCount == 0) { // COUNT(*) @@ -4246,7 +4258,7 @@ private boolean usesVectorUDFAdaptor(VectorExpression[] vecExprs) { inputTypeInfo = exprNodeDesc.getTypeInfo(); if (inputTypeInfo == null) { String issue ="Aggregations with null parameter type not supported " + - aggregateName + "(" + parameterList.toString() + ")"; + aggregationName + "(" + parameterList.toString() + ")"; return new ImmutablePair(null, issue); } @@ -4260,12 +4272,12 @@ private boolean usesVectorUDFAdaptor(VectorExpression[] vecExprs) { exprNodeDesc, VectorExpressionDescriptor.Mode.PROJECTION); if (inputExpression == null) { String issue ="Parameter expression " + exprNodeDesc.toString() + " not supported " + - aggregateName + "(" + parameterList.toString() + ")"; + aggregationName + "(" + parameterList.toString() + ")"; return new ImmutablePair(null, issue); } if (inputExpression.getOutputTypeInfo() == null) { String issue ="Parameter expression " + exprNodeDesc.toString() + " with null type not supported " + - aggregateName + "(" + parameterList.toString() + ")"; + aggregationName + "(" + parameterList.toString() + ")"; return new ImmutablePair(null, issue); } inputColVectorType = inputExpression.getOutputColumnVectorType(); @@ -4273,7 +4285,7 @@ private boolean usesVectorUDFAdaptor(VectorExpression[] vecExprs) { // No multi-parameter aggregations supported. String issue ="Aggregations with > 1 parameter are not supported " + - aggregateName + "(" + parameterList.toString() + ")"; + aggregationName + "(" + parameterList.toString() + ")"; return new ImmutablePair(null, issue); } @@ -4291,12 +4303,13 @@ private boolean usesVectorUDFAdaptor(VectorExpression[] vecExprs) { // Try with DECIMAL_64 input and DECIMAL_64 output. final Class vecAggrClass = findVecAggrClass( - vecAggrClasses, aggregateName, inputColVectorType, + vecAggrClasses, aggregationName, inputColVectorType, ColumnVector.Type.DECIMAL_64, udafEvaluatorMode); if (vecAggrClass != null) { final VectorAggregationDesc vecAggrDesc = new VectorAggregationDesc( - aggrDesc, evaluator, inputTypeInfo, inputColVectorType, inputExpression, + aggregationName, evaluator, udafEvaluatorMode, + inputTypeInfo, inputColVectorType, inputExpression, outputTypeInfo, ColumnVector.Type.DECIMAL_64, vecAggrClass); return new ImmutablePair(vecAggrDesc, null); } @@ -4305,12 +4318,13 @@ private boolean usesVectorUDFAdaptor(VectorExpression[] vecExprs) { // Try with regular DECIMAL output type. final Class vecAggrClass = findVecAggrClass( - vecAggrClasses, aggregateName, inputColVectorType, + vecAggrClasses, aggregationName, inputColVectorType, outputColVectorType, udafEvaluatorMode); if (vecAggrClass != null) { final VectorAggregationDesc vecAggrDesc = new VectorAggregationDesc( - aggrDesc, evaluator, inputTypeInfo, inputColVectorType, inputExpression, + aggregationName, evaluator, udafEvaluatorMode, + inputTypeInfo, inputColVectorType, inputExpression, outputTypeInfo, outputColVectorType, vecAggrClass); return new ImmutablePair(vecAggrDesc, null); } @@ -4325,19 +4339,20 @@ private boolean usesVectorUDFAdaptor(VectorExpression[] vecExprs) { // Try with with DECIMAL_64 input and desired output type. 
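For readers unfamiliar with DECIMAL_64: decimals of precision 18 or less can be carried as unscaled longs in a long-backed column vector, so the fallback ladder here prefers, in order, a DECIMAL_64-in/DECIMAL_64-out aggregation class, then DECIMAL_64 input with a regular DECIMAL output, and only then the declared column vector types. A toy sketch of the representation, assuming scale 2 throughout; the widening call at the end is ordinary HiveDecimal construction:

    long a = 1234;                 // decimal(12,2) value 12.34, stored unscaled
    long b = 566;                  // 5.66
    long sum = a + b;              // 1800 -- the hot loop is plain long math
    HiveDecimal widened = HiveDecimal.create(java.math.BigDecimal.valueOf(sum, 2));
    System.out.println(widened);   // prints 18 (trailing zeros are trimmed)

This is also why the bloom_filter carve-out below converts the input back to regular DECIMAL: the build and probe sides must agree on a single representation.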
final Class vecAggrClass = findVecAggrClass( - vecAggrClasses, aggregateName, inputColVectorType, + vecAggrClasses, aggregationName, inputColVectorType, outputColVectorType, udafEvaluatorMode); if (vecAggrClass != null) { // for now, disable operating on decimal64 column vectors for semijoin reduction as // we have to make sure same decimal type should be used during bloom filter creation // and bloom filter probing - if (aggregateName.equals("bloom_filter")) { + if (aggregationName.equals("bloom_filter")) { inputExpression = vContext.wrapWithDecimal64ToDecimalConversion(inputExpression); inputColVectorType = ColumnVector.Type.DECIMAL; } final VectorAggregationDesc vecAggrDesc = new VectorAggregationDesc( - aggrDesc, evaluator, inputTypeInfo, inputColVectorType, inputExpression, + aggregationName, evaluator, udafEvaluatorMode, + inputTypeInfo, inputColVectorType, inputExpression, outputTypeInfo, outputColVectorType, vecAggrClass); return new ImmutablePair(vecAggrDesc, null); } @@ -4355,19 +4370,20 @@ private boolean usesVectorUDFAdaptor(VectorExpression[] vecExprs) { */ Class vecAggrClass = findVecAggrClass( - vecAggrClasses, aggregateName, inputColVectorType, + vecAggrClasses, aggregationName, inputColVectorType, outputColVectorType, udafEvaluatorMode); if (vecAggrClass != null) { final VectorAggregationDesc vecAggrDesc = new VectorAggregationDesc( - aggrDesc, evaluator, inputTypeInfo, inputColVectorType, inputExpression, + aggregationName, evaluator, udafEvaluatorMode, + inputTypeInfo, inputColVectorType, inputExpression, outputTypeInfo, outputColVectorType, vecAggrClass); return new ImmutablePair(vecAggrDesc, null); } // No match? String issue = - "Vector aggregation : \"" + aggregateName + "\" " + + "Vector aggregation : \"" + aggregationName + "\" " + "for input type: " + (inputColVectorType == null ? "any" : "\"" + inputColVectorType) + "\" " + "and output type: \"" + outputColVectorType + "\" " + diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java index d170d86..5cb7061 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java @@ -56,6 +56,7 @@ import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.util.StringUtils; @@ -250,6 +251,25 @@ protected BasePartitionEvaluator createPartitionEvaluator( VectorUDAFAvgDecimalPartial2.class, VectorUDAFAvgDecimalFinal.class}) public static class GenericUDAFAverageEvaluatorDecimal extends AbstractGenericUDAFAverageEvaluator { + private int resultPrecision = -1; + private int resultScale = -1; + + @Override + public ObjectInspector init(Mode m, ObjectInspector[] parameters) + throws HiveException { + + // Intercept result ObjectInspector so we can extract the DECIMAL precision and scale. 
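Why intercept the precision and scale: the averaging evaluator divides the accumulated decimal sum by the count at terminate time, and HiveDecimal division can produce more fractional digits than the declared result type decimal(p,s) permits, so the captured values are used to round the result back (see the doTerminate hunk below). A toy illustration, assuming a declared result type of decimal(6,4):

    HiveDecimalWritable result = new HiveDecimalWritable(HiveDecimal.ZERO);
    result.set(HiveDecimal.create("1").divide(HiveDecimal.create("3")));
    // the quotient carries more digits than decimal(6,4) allows ...
    result.mutateEnforcePrecisionScale(6, 4);
    System.out.println(result);    // 0.3333

Without the enforcement, the terminated value could carry a different scale than the declared output type, producing spurious mismatches against the vectorized path.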
+ ObjectInspector resultOI = super.init(m, parameters); + if (m == Mode.COMPLETE || m == Mode.FINAL) { + DecimalTypeInfo decimalTypeInfo = + (DecimalTypeInfo) + TypeInfoUtils.getTypeInfoFromObjectInspector(resultOI); + resultPrecision = decimalTypeInfo.getPrecision(); + resultScale = decimalTypeInfo.getScale(); + } + return resultOI; + } + @Override public void doReset(AverageAggregationBuffer aggregation) throws HiveException { aggregation.count = 0; @@ -336,6 +356,7 @@ protected Object doTerminate(AverageAggregationBuffer aggregation) } else { HiveDecimalWritable result = new HiveDecimalWritable(HiveDecimal.ZERO); result.set(aggregation.sum.divide(HiveDecimal.create(aggregation.count))); + result.mutateEnforcePrecisionScale(resultPrecision, resultScale); return result; } } diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFVariance.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFVariance.java index c9fb3df..bb55d88 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFVariance.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFVariance.java @@ -132,23 +132,29 @@ public static double calculateMerge( /* * Calculate the variance family {VARIANCE, VARIANCE_SAMPLE, STANDARD_DEVIATION, or - * STANDARD_DEVIATION_STAMPLE) result when count > 1. Public so vectorization code can + * STANDARD_DEVIATION_SAMPLE) result when count > 1. Public so vectorization code can * use it, etc. */ public static double calculateVarianceFamilyResult(double variance, long count, VarianceKind varianceKind) { + final double result; switch (varianceKind) { case VARIANCE: - return GenericUDAFVarianceEvaluator.calculateVarianceResult(variance, count); + result = GenericUDAFVarianceEvaluator.calculateVarianceResult(variance, count); + break; case VARIANCE_SAMPLE: - return GenericUDAFVarianceSampleEvaluator.calculateVarianceSampleResult(variance, count); + result = GenericUDAFVarianceSampleEvaluator.calculateVarianceSampleResult(variance, count); + break; case STANDARD_DEVIATION: - return GenericUDAFStdEvaluator.calculateStdResult(variance, count); + result = GenericUDAFStdEvaluator.calculateStdResult(variance, count); + break; case STANDARD_DEVIATION_SAMPLE: - return GenericUDAFStdSampleEvaluator.calculateStdSampleResult(variance, count); + result = GenericUDAFStdSampleEvaluator.calculateStdSampleResult(variance, count); + break; default: throw new RuntimeException("Unexpected variance kind " + varianceKind); } + return result; } @Override @@ -381,7 +387,8 @@ public void merge(AggregationBuffer agg, Object partial) throws HiveException { * Calculate the variance result when count > 1. Public so vectorization code can use it, etc. 
*/ public static double calculateVarianceResult(double variance, long count) { - return variance / count; + final double result = variance / count; + return result; } @Override diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java index ffdc410..fe1375b 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java @@ -216,7 +216,10 @@ private static AggregationDesc buildAggregationDescCountStar( vectorDesc.setVecAggrDescs( new VectorAggregationDesc[] { new VectorAggregationDesc( - agg, new GenericUDAFCount.GenericUDAFCountEvaluator(), null, ColumnVector.Type.NONE, null, + agg.getGenericUDAFName(), + new GenericUDAFCount.GenericUDAFCountEvaluator(), + agg.getMode(), + null, ColumnVector.Type.NONE, null, TypeInfoFactory.longTypeInfo, ColumnVector.Type.LONG, VectorUDAFCountStar.class)}); vectorDesc.setProcessingMode(VectorGroupByDesc.ProcessingMode.HASH); @@ -1555,7 +1558,7 @@ public void testAvgLongEmpty () throws HiveException { "avg", 2, Arrays.asList(new Long[]{}), - null); + 0.0); } @Test @@ -1564,12 +1567,12 @@ public void testAvgLongNulls () throws HiveException { "avg", 2, Arrays.asList(new Long[]{null}), - null); + 0.0); testAggregateLongAggregate( "avg", 2, Arrays.asList(new Long[]{null, null, null}), - null); + 0.0); testAggregateLongAggregate( "avg", 2, @@ -1601,7 +1604,7 @@ public void testAvgLongRepeatNulls () throws HiveException { null, 4096, 1024, - null); + 0.0); } @SuppressWarnings("unchecked") @@ -1632,7 +1635,7 @@ public void testVarianceLongEmpty () throws HiveException { "variance", 2, Arrays.asList(new Long[]{}), - null); + 0.0); } @Test @@ -1650,12 +1653,12 @@ public void testVarianceLongNulls () throws HiveException { "variance", 2, Arrays.asList(new Long[]{null}), - null); + 0.0); testAggregateLongAggregate( "variance", 2, Arrays.asList(new Long[]{null, null, null}), - null); + 0.0); testAggregateLongAggregate( "variance", 2, @@ -1680,7 +1683,7 @@ public void testVarPopLongRepeatNulls () throws HiveException { null, 4096, 1024, - null); + 0.0); } @Test @@ -1708,7 +1711,7 @@ public void testVarSampLongEmpty () throws HiveException { "var_samp", 2, Arrays.asList(new Long[]{}), - null); + 0.0); } @@ -1737,7 +1740,7 @@ public void testStdLongEmpty () throws HiveException { "std", 2, Arrays.asList(new Long[]{}), - null); + 0.0); } @@ -1758,7 +1761,7 @@ public void testStdDevLongRepeatNulls () throws HiveException { null, 4096, 1024, - null); + 0.0); } @@ -2236,14 +2239,21 @@ public void validate(String key, Object expected, Object result) { assertEquals (true, vals[0] instanceof LongWritable); LongWritable lw = (LongWritable) vals[0]; - assertFalse (lw.get() == 0L); if (vals[1] instanceof DoubleWritable) { DoubleWritable dw = (DoubleWritable) vals[1]; - assertEquals (key, expected, dw.get() / lw.get()); + if (lw.get() != 0L) { + assertEquals (key, expected, dw.get() / lw.get()); + } else { + assertEquals(key, expected, 0.0); + } } else if (vals[1] instanceof HiveDecimalWritable) { HiveDecimalWritable hdw = (HiveDecimalWritable) vals[1]; - assertEquals (key, expected, hdw.getHiveDecimal().divide(HiveDecimal.create(lw.get()))); + if (lw.get() != 0L) { + assertEquals (key, expected, hdw.getHiveDecimal().divide(HiveDecimal.create(lw.get()))); + } else { + assertEquals(key, expected, HiveDecimal.ZERO); + } } } } @@ -2271,10 +2281,14 @@ public 
void validate(String key, Object expected, Object result) { assertEquals (true, vals[1] instanceof DoubleWritable); assertEquals (true, vals[2] instanceof DoubleWritable); LongWritable cnt = (LongWritable) vals[0]; - DoubleWritable sum = (DoubleWritable) vals[1]; - DoubleWritable var = (DoubleWritable) vals[2]; - assertTrue (1 <= cnt.get()); - validateVariance (key, (Double) expected, cnt.get(), sum.get(), var.get()); + if (cnt.get() == 0) { + assertEquals(key, expected, 0.0); + } else { + DoubleWritable sum = (DoubleWritable) vals[1]; + DoubleWritable var = (DoubleWritable) vals[2]; + assertTrue (1 <= cnt.get()); + validateVariance (key, (Double) expected, cnt.get(), sum.get(), var.get()); + } } } } diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomBatchSource.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomBatchSource.java index 4c2f872..dd2f8e3 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomBatchSource.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomBatchSource.java @@ -167,6 +167,8 @@ private static VectorBatchPatterns chooseBatchPatterns( VectorRandomRowSource vectorRandomRowSource, Object[][] randomRows) { + final boolean allowNull = vectorRandomRowSource.getAllowNull(); + List vectorBatchPatternList = new ArrayList(); final int rowCount = randomRows.length; int rowIndex = 0; @@ -201,35 +203,38 @@ private static VectorBatchPatterns chooseBatchPatterns( */ while (true) { - // Repeated NULL permutations. long columnPermutation = 1; - while (true) { - if (columnPermutation > columnPermutationLimit) { - break; - } - final int maximumRowCount = Math.min(rowCount - rowIndex, VectorizedRowBatch.DEFAULT_SIZE); - if (maximumRowCount == 0) { - break; - } - int randomRowCount = 1 + random.nextInt(maximumRowCount); - final int rowLimit = rowIndex + randomRowCount; + if (allowNull) { - BitSet bitSet = BitSet.valueOf(new long[]{columnPermutation}); + // Repeated NULL permutations. + while (true) { + if (columnPermutation > columnPermutationLimit) { + break; + } + final int maximumRowCount = Math.min(rowCount - rowIndex, VectorizedRowBatch.DEFAULT_SIZE); + if (maximumRowCount == 0) { + break; + } + int randomRowCount = 1 + random.nextInt(maximumRowCount); + final int rowLimit = rowIndex + randomRowCount; - for (int columnNum = bitSet.nextSetBit(0); - columnNum >= 0; - columnNum = bitSet.nextSetBit(columnNum + 1)) { + BitSet bitSet = BitSet.valueOf(new long[]{columnPermutation}); - // Repeated NULL fill down column. - for (int r = rowIndex; r < rowLimit; r++) { - randomRows[r][columnNum] = null; + for (int columnNum = bitSet.nextSetBit(0); + columnNum >= 0; + columnNum = bitSet.nextSetBit(columnNum + 1)) { + + // Repeated NULL fill down column. + for (int r = rowIndex; r < rowLimit; r++) { + randomRows[r][columnNum] = null; + } } + vectorBatchPatternList.add( + VectorBatchPattern.createRepeatedBatch( + random, randomRowCount, bitSet, asSelected)); + columnPermutation++; + rowIndex = rowLimit; } - vectorBatchPatternList.add( - VectorBatchPattern.createRepeatedBatch( - random, randomRowCount, bitSet, asSelected)); - columnPermutation++; - rowIndex = rowLimit; } // Repeated non-NULL permutations. 
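The expectation changes in TestVectorGroupByOperator above track the behavioral change in the templates: a PARTIAL1 average or variance partial for a group whose inputs are all NULL (or empty) is now emitted unconditionally, with count 0 and zeroed sum/variance, instead of as a NULL struct, and NULL-ness is decided downstream from the count. A sketch of the convention the updated validate() encodes, with illustrative values for an all-NULL group:

    long count = 0;                // partial result from an all-NULL group
    double sum = 0.0;
    // downstream (FINAL/COMPLETE) decides NULL-ness from the count:
    Double finalAvg = (count == 0) ? null : sum / count;
    System.out.println(finalAvg);  // null -- the end-to-end SQL result stays NULL

The VectorRandomBatchSource change just above is the testing-side counterpart: repeated-NULL batch permutations are generated only when the row source allows NULLs at all.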
diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java index 6181ae8..a1cefaa 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java @@ -21,7 +21,6 @@ import java.text.DateFormat; import java.text.SimpleDateFormat; import java.text.ParseException; - import java.util.ArrayList; import java.util.HashSet; import java.util.List; @@ -29,7 +28,6 @@ import java.util.Set; import org.apache.commons.lang.StringUtils; - import org.apache.hadoop.hive.common.type.DataTypePhysicalVariation; import org.apache.hadoop.hive.common.type.Date; import org.apache.hadoop.hive.common.type.HiveChar; @@ -86,6 +84,7 @@ import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; import org.apache.hive.common.util.DateUtils; import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.BooleanWritable; import org.apache.hadoop.io.BytesWritable; import org.apache.hadoop.io.LongWritable; @@ -130,6 +129,10 @@ private boolean addEscapables; private String needsEscapeStr; + public boolean getAllowNull() { + return allowNull; + } + public static class StringGenerationOption { private boolean generateSentences; @@ -1021,43 +1024,141 @@ public static Object getWritablePrimitiveObject(PrimitiveTypeInfo primitiveTypeI switch (primitiveTypeInfo.getPrimitiveCategory()) { case BOOLEAN: - return ((WritableBooleanObjectInspector) objectInspector).create((boolean) object); + { + WritableBooleanObjectInspector writableOI = (WritableBooleanObjectInspector) objectInspector; + if (object instanceof Boolean) { + return writableOI.create((boolean) object); + } else { + return writableOI.copyObject(object); + } + } case BYTE: - return ((WritableByteObjectInspector) objectInspector).create((byte) object); + { + WritableByteObjectInspector writableOI = (WritableByteObjectInspector) objectInspector; + if (object instanceof Byte) { + return writableOI.create((byte) object); + } else { + return writableOI.copyObject(object); + } + } case SHORT: - return ((WritableShortObjectInspector) objectInspector).create((short) object); + { + WritableShortObjectInspector writableOI = (WritableShortObjectInspector) objectInspector; + if (object instanceof Short) { + return writableOI.create((short) object); + } else { + return writableOI.copyObject(object); + } + } case INT: - return ((WritableIntObjectInspector) objectInspector).create((int) object); + { + WritableIntObjectInspector writableOI = (WritableIntObjectInspector) objectInspector; + if (object instanceof Integer) { + return writableOI.create((int) object); + } else { + return writableOI.copyObject(object); + } + } case LONG: - return ((WritableLongObjectInspector) objectInspector).create((long) object); + { + WritableLongObjectInspector writableOI = (WritableLongObjectInspector) objectInspector; + if (object instanceof Long) { + return writableOI.create((long) object); + } else { + return writableOI.copyObject(object); + } + } case DATE: - return ((WritableDateObjectInspector) objectInspector).create((Date) object); + { + WritableDateObjectInspector writableOI = (WritableDateObjectInspector) objectInspector; + if (object instanceof Date) { + return writableOI.create((Date) object); + } else { + return writableOI.copyObject(object); + } + } case FLOAT: - return ((WritableFloatObjectInspector) objectInspector).create((float) object); + { + WritableFloatObjectInspector 
writableOI = (WritableFloatObjectInspector) objectInspector; + if (object instanceof Float) { + return writableOI.create((float) object); + } else { + return writableOI.copyObject(object); + } + } case DOUBLE: - return ((WritableDoubleObjectInspector) objectInspector).create((double) object); + { + WritableDoubleObjectInspector writableOI = (WritableDoubleObjectInspector) objectInspector; + if (object instanceof Double) { + return writableOI.create((double) object); + } else { + return writableOI.copyObject(object); + } + } case STRING: - return ((WritableStringObjectInspector) objectInspector).create((String) object); + { + WritableStringObjectInspector writableOI = (WritableStringObjectInspector) objectInspector; + if (object instanceof String) { + return writableOI.create((String) object); + } else { + return writableOI.copyObject(object); + } + } case CHAR: { WritableHiveCharObjectInspector writableCharObjectInspector = new WritableHiveCharObjectInspector( (CharTypeInfo) primitiveTypeInfo); - return writableCharObjectInspector.create((HiveChar) object); + if (object instanceof HiveChar) { + return writableCharObjectInspector.create((HiveChar) object); + } else { + return writableCharObjectInspector.copyObject(object); + } } case VARCHAR: { WritableHiveVarcharObjectInspector writableVarcharObjectInspector = new WritableHiveVarcharObjectInspector( (VarcharTypeInfo) primitiveTypeInfo); - return writableVarcharObjectInspector.create((HiveVarchar) object); + if (object instanceof HiveVarchar) { + return writableVarcharObjectInspector.create((HiveVarchar) object); + } else { + return writableVarcharObjectInspector.copyObject(object); + } } case BINARY: - return PrimitiveObjectInspectorFactory.writableBinaryObjectInspector.create((byte[]) object); + { + if (object instanceof byte[]) { + return PrimitiveObjectInspectorFactory.writableBinaryObjectInspector.create((byte[]) object); + } else { + return PrimitiveObjectInspectorFactory.writableBinaryObjectInspector.copyObject(object); + } + } case TIMESTAMP: - return ((WritableTimestampObjectInspector) objectInspector).create((Timestamp) object); + { + WritableTimestampObjectInspector writableOI = (WritableTimestampObjectInspector) objectInspector; + if (object instanceof Timestamp) { + return writableOI.create((Timestamp) object); + } else { + return writableOI.copyObject(object); + } + } case INTERVAL_YEAR_MONTH: - return ((WritableHiveIntervalYearMonthObjectInspector) objectInspector).create((HiveIntervalYearMonth) object); + { + WritableHiveIntervalYearMonthObjectInspector writableOI = (WritableHiveIntervalYearMonthObjectInspector) objectInspector; + if (object instanceof HiveIntervalYearMonth) { + return writableOI.create((HiveIntervalYearMonth) object); + } else { + return writableOI.copyObject(object); + } + } case INTERVAL_DAY_TIME: - return ((WritableHiveIntervalDayTimeObjectInspector) objectInspector).create((HiveIntervalDayTime) object); + { + WritableHiveIntervalDayTimeObjectInspector writableOI = (WritableHiveIntervalDayTimeObjectInspector) objectInspector; + if (object instanceof HiveIntervalDayTime) { + return writableOI.create((HiveIntervalDayTime) object); + } else { + return writableOI.copyObject(object); + } + } case DECIMAL: { if (dataTypePhysicalVariation == dataTypePhysicalVariation.DECIMAL_64) { @@ -1071,9 +1172,13 @@ public static Object getWritablePrimitiveObject(PrimitiveTypeInfo primitiveTypeI } return ((WritableLongObjectInspector) objectInspector).create(value); } else { - WritableHiveDecimalObjectInspector 
writableDecimalObjectInspector = + WritableHiveDecimalObjectInspector writableOI = new WritableHiveDecimalObjectInspector((DecimalTypeInfo) primitiveTypeInfo); - return writableDecimalObjectInspector.create((HiveDecimal) object); + if (object instanceof HiveDecimal) { + return writableOI.create((HiveDecimal) object); + } else { + return writableOI.copyObject(object); + } } } default: @@ -1081,6 +1186,116 @@ public static Object getWritablePrimitiveObject(PrimitiveTypeInfo primitiveTypeI } } + public static Object getNonWritablePrimitiveObject(Object object, TypeInfo typeInfo, + ObjectInspector objectInspector) { + + PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo; + switch (primitiveTypeInfo.getPrimitiveCategory()) { + case BOOLEAN: + if (object instanceof Boolean) { + return object; + } else { + return ((WritableBooleanObjectInspector) objectInspector).get(object); + } + case BYTE: + if (object instanceof Byte) { + return object; + } else { + return ((WritableByteObjectInspector) objectInspector).get(object); + } + case SHORT: + if (object instanceof Short) { + return object; + } else { + return ((WritableShortObjectInspector) objectInspector).get(object); + } + case INT: + if (object instanceof Integer) { + return object; + } else { + return ((WritableIntObjectInspector) objectInspector).get(object); + } + case LONG: + if (object instanceof Long) { + return object; + } else { + return ((WritableLongObjectInspector) objectInspector).get(object); + } + case FLOAT: + if (object instanceof Float) { + return object; + } else { + return ((WritableFloatObjectInspector) objectInspector).get(object); + } + case DOUBLE: + if (object instanceof Double) { + return object; + } else { + return ((WritableDoubleObjectInspector) objectInspector).get(object); + } + case STRING: + if (object instanceof String) { + return object; + } else { + return ((WritableStringObjectInspector) objectInspector).getPrimitiveJavaObject(object); + } + case DATE: + if (object instanceof Date) { + return object; + } else { + return ((WritableDateObjectInspector) objectInspector).getPrimitiveJavaObject(object); + } + case TIMESTAMP: + if (object instanceof Timestamp) { + return object; + } else if (object instanceof org.apache.hadoop.hive.common.type.Timestamp) { + return object; + } else { + return ((WritableTimestampObjectInspector) objectInspector).getPrimitiveJavaObject(object); + } + case DECIMAL: + if (object instanceof HiveDecimal) { + return object; + } else { + WritableHiveDecimalObjectInspector writableDecimalObjectInspector = + new WritableHiveDecimalObjectInspector((DecimalTypeInfo) primitiveTypeInfo); + return writableDecimalObjectInspector.getPrimitiveJavaObject(object); + } + case VARCHAR: + if (object instanceof HiveVarchar) { + return object; + } else { + WritableHiveVarcharObjectInspector writableVarcharObjectInspector = + new WritableHiveVarcharObjectInspector( (VarcharTypeInfo) primitiveTypeInfo); + return writableVarcharObjectInspector.getPrimitiveJavaObject(object); + } + case CHAR: + if (object instanceof HiveChar) { + return object; + } else { + WritableHiveCharObjectInspector writableCharObjectInspector = + new WritableHiveCharObjectInspector( (CharTypeInfo) primitiveTypeInfo); + return writableCharObjectInspector.getPrimitiveJavaObject(object); + } + case INTERVAL_YEAR_MONTH: + if (object instanceof HiveIntervalYearMonth) { + return object; + } else { + return ((WritableHiveIntervalYearMonthObjectInspector) objectInspector).getPrimitiveJavaObject(object); + } + case 
INTERVAL_DAY_TIME: + if (object instanceof HiveIntervalDayTime) { + return object; + } else { + return ((WritableHiveIntervalDayTimeObjectInspector) objectInspector).getPrimitiveJavaObject(object); + } + case BINARY: + default: + throw new RuntimeException( + "Unexpected primitive category " + primitiveTypeInfo.getPrimitiveCategory()); + } + } + public Object randomWritable(int column) { return randomWritable( typeInfos[column], objectInspectorList.get(column), dataTypePhysicalVariations[column], diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/aggregation/AggregationBase.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/aggregation/AggregationBase.java new file mode 100644 index 0000000..583241c --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/aggregation/AggregationBase.java @@ -0,0 +1,473 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec.vector.aggregation; + +import java.lang.reflect.Constructor; +import java.util.Arrays; +import java.util.List; + +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.hadoop.hive.common.type.DataTypePhysicalVariation; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.ql.exec.FunctionRegistry; +import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow; +import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationDesc; +import org.apache.hadoop.hive.ql.exec.vector.VectorExtractRow; +import org.apache.hadoop.hive.ql.exec.vector.VectorRandomBatchSource; +import org.apache.hadoop.hive.ql.exec.vector.VectorRandomRowSource; +import org.apache.hadoop.hive.ql.exec.vector.VectorizationContext; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatchCtx; +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.optimizer.physical.Vectorizer; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFResolver; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer; +import org.apache.hadoop.hive.serde2.io.ShortWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import 
org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; + +import junit.framework.Assert; + +public class AggregationBase { + + public enum AggregationTestMode { + ROW_MODE, + VECTOR_EXPRESSION; + + static final int count = values().length; + } + + public static GenericUDAFEvaluator getEvaluator(String aggregationFunctionName, + TypeInfo typeInfo) + throws SemanticException { + + GenericUDAFResolver resolver = + FunctionRegistry.getGenericUDAFResolver(aggregationFunctionName); + TypeInfo[] parameters = new TypeInfo[] { typeInfo }; + GenericUDAFEvaluator evaluator = resolver.getEvaluator(parameters); + return evaluator; + } + + protected static boolean doRowTest(TypeInfo typeInfo, + GenericUDAFEvaluator evaluator, TypeInfo outputTypeInfo, + GenericUDAFEvaluator.Mode udafEvaluatorMode, int maxKeyCount, + List columns, List children, + Object[][] randomRows, ObjectInspector rowInspector, + Object[] results) + throws Exception { + + /* + System.out.println( + "*DEBUG* typeInfo " + typeInfo.toString() + + " aggregationTestMode ROW_MODE" + + " outputTypeInfo " + outputTypeInfo.toString()); + */ + + // Last entry is for a NULL key. + AggregationBuffer[] aggregationBuffers = new AggregationBuffer[maxKeyCount + 1]; + + ObjectInspector objectInspector = TypeInfoUtils + .getStandardWritableObjectInspectorFromTypeInfo(outputTypeInfo); + + Object[] parameterArray = new Object[1]; + final int rowCount = randomRows.length; + for (int i = 0; i < rowCount; i++) { + Object[] row = randomRows[i]; + ShortWritable shortWritable = (ShortWritable) row[0]; + + final int key; + if (shortWritable == null) { + key = maxKeyCount; + } else { + key = shortWritable.get(); + } + AggregationBuffer aggregationBuffer = aggregationBuffers[key]; + if (aggregationBuffer == null) { + aggregationBuffer = evaluator.getNewAggregationBuffer(); + aggregationBuffers[key] = aggregationBuffer; + } + parameterArray[0] = row[1]; + evaluator.aggregate(aggregationBuffer, parameterArray); + } + + final boolean isPrimitive = (outputTypeInfo instanceof PrimitiveTypeInfo); + final boolean isPartial = + (udafEvaluatorMode == GenericUDAFEvaluator.Mode.PARTIAL1 || + udafEvaluatorMode == GenericUDAFEvaluator.Mode.PARTIAL2); + + for (short key = 0; key < maxKeyCount + 1; key++) { + AggregationBuffer aggregationBuffer = aggregationBuffers[key]; + if (aggregationBuffer != null) { + final Object result; + if (isPartial) { + result = evaluator.terminatePartial(aggregationBuffer); + } else { + result = evaluator.terminate(aggregationBuffer); + } + Object copyResult; + if (result == null) { + copyResult = null; + } else if (isPrimitive) { + copyResult = + VectorRandomRowSource.getWritablePrimitiveObject( + (PrimitiveTypeInfo) outputTypeInfo, objectInspector, result); + } else { + copyResult = + ObjectInspectorUtils.copyToStandardObject( + result, objectInspector, ObjectInspectorCopyOption.WRITABLE); + } + results[key] = copyResult; + } + } + + return true; + } + + private static void extractResultObjects(VectorizedRowBatch outputBatch, short[] keys, + VectorExtractRow resultVectorExtractRow, TypeInfo outputTypeInfo, Object[] scrqtchRow, + Object[] results) { + + final boolean isPrimitive = (outputTypeInfo instanceof PrimitiveTypeInfo); + ObjectInspector objectInspector; + if (isPrimitive) { + objectInspector = TypeInfoUtils + 
.getStandardWritableObjectInspectorFromTypeInfo(outputTypeInfo); + } else { + objectInspector = null; + } + + for (int batchIndex = 0; batchIndex < outputBatch.size; batchIndex++) { + resultVectorExtractRow.extractRow(outputBatch, batchIndex, scrqtchRow); + if (isPrimitive) { + Object copyResult = + ObjectInspectorUtils.copyToStandardObject( + scrqtchRow[0], objectInspector, ObjectInspectorCopyOption.WRITABLE); + results[keys[batchIndex]] = copyResult; + } else { + results[keys[batchIndex]] = scrqtchRow[0]; + } + } + } + + protected static boolean doVectorTest(String aggregationName, TypeInfo typeInfo, + GenericUDAFEvaluator evaluator, TypeInfo outputTypeInfo, + GenericUDAFEvaluator.Mode udafEvaluatorMode, int maxKeyCount, + List columns, String[] columnNames, + TypeInfo[] typeInfos, DataTypePhysicalVariation[] dataTypePhysicalVariations, + List parameterList, + VectorRandomBatchSource batchSource, + Object[] results) + throws Exception { + + HiveConf hiveConf = new HiveConf(); + + VectorizationContext vectorizationContext = + new VectorizationContext( + "name", + columns, + Arrays.asList(typeInfos), + Arrays.asList(dataTypePhysicalVariations), + hiveConf); + + ImmutablePair pair = + Vectorizer.getVectorAggregationDesc( + aggregationName, + parameterList, + evaluator, + outputTypeInfo, + udafEvaluatorMode, + vectorizationContext); + VectorAggregationDesc vecAggrDesc = pair.left; + if (vecAggrDesc == null) { + Assert.fail( + "No vector aggregation expression found for aggregationName " + aggregationName + + " udafEvaluatorMode " + udafEvaluatorMode + + " parameterList " + parameterList + + " outputTypeInfo " + outputTypeInfo); + } + + Class vecAggrClass = vecAggrDesc.getVecAggrClass(); + + Constructor ctor = null; + try { + ctor = vecAggrClass.getConstructor(VectorAggregationDesc.class); + } catch (Exception e) { + throw new HiveException("Constructor " + vecAggrClass.getSimpleName() + + "(VectorAggregationDesc) not available"); + } + VectorAggregateExpression vecAggrExpr = null; + try { + vecAggrExpr = ctor.newInstance(vecAggrDesc); + } catch (Exception e) { + + throw new HiveException("Failed to create " + vecAggrClass.getSimpleName() + + "(VectorAggregationDesc) object ", e); + } + VectorExpression.doTransientInit(vecAggrExpr.getInputExpression()); + + /* + System.out.println( + "*DEBUG* typeInfo " + typeInfo.toString() + + " aggregationTestMode VECTOR_MODE" + + " vecAggrExpr " + vecAggrExpr.getClass().getSimpleName()); + */ + + VectorRandomRowSource rowSource = batchSource.getRowSource(); + VectorizedRowBatchCtx batchContext = + new VectorizedRowBatchCtx( + columnNames, + rowSource.typeInfos(), + rowSource.dataTypePhysicalVariations(), + /* dataColumnNums */ null, + /* partitionColumnCount */ 0, + /* virtualColumnCount */ 0, + /* neededVirtualColumns */ null, + vectorizationContext.getScratchColumnTypeNames(), + vectorizationContext.getScratchDataTypePhysicalVariations()); + + VectorizedRowBatch batch = batchContext.createVectorizedRowBatch(); + + // Last entry is for a NULL key. 
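As in the row-mode path earlier, the harness sidesteps a real GROUP BY hash table: grouping keys are small shorts, so aggregation buffers live in a plain array indexed by key value, with one extra trailing slot standing in for a NULL key. A minimal sketch of that mapping, with illustrative values:

    final int maxKeyCount = 100;                    // illustrative
    Object[] results = new Object[maxKeyCount + 1]; // last slot is the NULL key
    ShortWritable keyWritable = null;               // a row whose key is NULL
    int key = (keyWritable == null) ? maxKeyCount : keyWritable.get();
    results[key] = "aggregate for the NULL-key group";

The vector-mode loop below applies the same rule to the long key column, mapping a NULL at the (repeating-aware) batch index to maxKeyCount.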
+ VectorAggregationBufferRow[] vectorAggregationBufferRows = + new VectorAggregationBufferRow[maxKeyCount + 1]; + + VectorAggregationBufferRow[] batchBufferRows; + + batchSource.resetBatchIteration(); + int rowIndex = 0; + while (true) { + if (!batchSource.fillNextBatch(batch)) { + break; + } + LongColumnVector keyLongColVector = (LongColumnVector) batch.cols[0]; + + batchBufferRows = + new VectorAggregationBufferRow[VectorizedRowBatch.DEFAULT_SIZE]; + + final int size = batch.size; + boolean selectedInUse = batch.selectedInUse; + int[] selected = batch.selected; + for (int logical = 0; logical < size; logical++) { + final int batchIndex = (selectedInUse ? selected[logical] : logical); + final int keyAdjustedBatchIndex; + if (keyLongColVector.isRepeating) { + keyAdjustedBatchIndex = 0; + } else { + keyAdjustedBatchIndex = batchIndex; + } + final short key; + if (keyLongColVector.noNulls || !keyLongColVector.isNull[keyAdjustedBatchIndex]) { + key = (short) keyLongColVector.vector[keyAdjustedBatchIndex]; + } else { + key = (short) maxKeyCount; + } + + VectorAggregationBufferRow bufferRow = vectorAggregationBufferRows[key]; + if (bufferRow == null) { + VectorAggregateExpression.AggregationBuffer aggregationBuffer = + vecAggrExpr.getNewAggregationBuffer(); + aggregationBuffer.reset(); + VectorAggregateExpression.AggregationBuffer[] aggregationBuffers = + new VectorAggregateExpression.AggregationBuffer[] { aggregationBuffer }; + bufferRow = new VectorAggregationBufferRow(aggregationBuffers); + vectorAggregationBufferRows[key] = bufferRow; + } + batchBufferRows[logical] = bufferRow; + } + + vecAggrExpr.aggregateInputSelection( + batchBufferRows, + 0, + batch); + + rowIndex += batch.size; + } + + String[] outputColumnNames = new String[] { "output" }; + + TypeInfo[] outputTypeInfos = new TypeInfo[] { outputTypeInfo }; + VectorizedRowBatchCtx outputBatchContext = + new VectorizedRowBatchCtx( + outputColumnNames, + outputTypeInfos, + null, + /* dataColumnNums */ null, + /* partitionColumnCount */ 0, + /* virtualColumnCount */ 0, + /* neededVirtualColumns */ null, + new String[0], + new DataTypePhysicalVariation[0]); + + VectorizedRowBatch outputBatch = outputBatchContext.createVectorizedRowBatch(); + + short[] keys = new short[VectorizedRowBatch.DEFAULT_SIZE]; + + VectorExtractRow resultVectorExtractRow = new VectorExtractRow(); + resultVectorExtractRow.init( + new TypeInfo[] { outputTypeInfo }, new int[] { 0 }); + Object[] scrqtchRow = new Object[1]; + + for (short key = 0; key < maxKeyCount + 1; key++) { + VectorAggregationBufferRow vectorAggregationBufferRow = vectorAggregationBufferRows[key]; + if (vectorAggregationBufferRow != null) { + if (outputBatch.size == VectorizedRowBatch.DEFAULT_SIZE) { + extractResultObjects(outputBatch, keys, resultVectorExtractRow, outputTypeInfo, + scrqtchRow, results); + outputBatch.reset(); + } + keys[outputBatch.size] = key; + VectorAggregateExpression.AggregationBuffer aggregationBuffer = + vectorAggregationBufferRow.getAggregationBuffer(0); + vecAggrExpr.assignRowColumn(outputBatch, outputBatch.size++, 0, aggregationBuffer); + } + } + if (outputBatch.size > 0) { + extractResultObjects(outputBatch, keys, resultVectorExtractRow, outputTypeInfo, + scrqtchRow, results); + } + + return true; + } + + private boolean compareObjects(Object object1, Object object2, TypeInfo typeInfo, + ObjectInspector objectInspector) { + if (typeInfo instanceof PrimitiveTypeInfo) { + return + VectorRandomRowSource.getWritablePrimitiveObject( + (PrimitiveTypeInfo) typeInfo, 
objectInspector, object1).equals( + VectorRandomRowSource.getWritablePrimitiveObject( + (PrimitiveTypeInfo) typeInfo, objectInspector, object2)); + } else { + return object1.equals(object2); + } + } + + protected void executeAggregationTests(String aggregationName, TypeInfo typeInfo, + GenericUDAFEvaluator evaluator, + TypeInfo outputTypeInfo, GenericUDAFEvaluator.Mode udafEvaluatorMode, + int maxKeyCount, List columns, String[] columnNames, + List parameters, Object[][] randomRows, + VectorRandomRowSource rowSource, VectorRandomBatchSource batchSource, + Object[] resultsArray) + throws Exception { + + for (int i = 0; i < AggregationTestMode.count; i++) { + + // Last entry is for a NULL key. + Object[] results = new Object[maxKeyCount + 1]; + resultsArray[i] = results; + + AggregationTestMode aggregationTestMode = AggregationTestMode.values()[i]; + switch (aggregationTestMode) { + case ROW_MODE: + if (!doRowTest( + typeInfo, + evaluator, + outputTypeInfo, + udafEvaluatorMode, + maxKeyCount, + columns, + parameters, + randomRows, + rowSource.rowStructObjectInspector(), + results)) { + return; + } + break; + case VECTOR_EXPRESSION: + if (!doVectorTest( + aggregationName, + typeInfo, + evaluator, + outputTypeInfo, + udafEvaluatorMode, + maxKeyCount, + columns, + columnNames, + rowSource.typeInfos(), + rowSource.dataTypePhysicalVariations(), + parameters, + batchSource, + results)) { + return; + } + break; + default: + throw new RuntimeException( + "Unexpected Hash Aggregation test mode " + aggregationTestMode); + } + } + } + + protected void verifyAggregationResults(TypeInfo typeInfo, TypeInfo outputTypeInfo, + int maxKeyCount, GenericUDAFEvaluator.Mode udafEvaluatorMode, + Object[] resultsArray) { + + // Row-mode is the expected results. + Object[] expectedResults = (Object[]) resultsArray[0]; + + ObjectInspector objectInspector = TypeInfoUtils + .getStandardWritableObjectInspectorFromTypeInfo(outputTypeInfo); + + for (int v = 1; v < AggregationTestMode.count; v++) { + Object[] vectorResults = (Object[]) resultsArray[v]; + + for (short key = 0; key < maxKeyCount + 1; key++) { + Object expectedResult = expectedResults[key]; + Object vectorResult = vectorResults[key]; + if (expectedResult == null || vectorResult == null) { + if (expectedResult != null || vectorResult != null) { + Assert.fail( + "Key " + key + + " typeName " + typeInfo.getTypeName() + + " outputTypeName " + outputTypeInfo.getTypeName() + + " " + AggregationTestMode.values()[v] + + " result is NULL " + (vectorResult == null ? "YES" : "NO result " + vectorResult.toString()) + + " does not match row-mode expected result is NULL " + + (expectedResult == null ? 
"YES" : "NO result " + expectedResult.toString()) + + " udafEvaluatorMode " + udafEvaluatorMode); + } + } else { + if (!compareObjects(expectedResult, vectorResult, outputTypeInfo, objectInspector)) { + Assert.fail( + "Key " + key + + " typeName " + typeInfo.getTypeName() + + " outputTypeName " + outputTypeInfo.getTypeName() + + " " + AggregationTestMode.values()[v] + + " result " + vectorResult.toString() + + " (" + vectorResult.getClass().getSimpleName() + ")" + + " does not match row-mode expected result " + expectedResult.toString() + + " (" + expectedResult.getClass().getSimpleName() + ")" + + " udafEvaluatorMode " + udafEvaluatorMode); + } + } + } + } + } +} \ No newline at end of file diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/aggregation/TestVectorAggregation.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/aggregation/TestVectorAggregation.java new file mode 100644 index 0000000..c5f0483 --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/aggregation/TestVectorAggregation.java @@ -0,0 +1,664 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.hive.ql.exec.vector.aggregation; + +import java.lang.reflect.Constructor; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.sql.Timestamp; + +import org.apache.hadoop.hive.common.type.DataTypePhysicalVariation; +import org.apache.hadoop.hive.ql.exec.vector.VectorRandomBatchSource; +import org.apache.hadoop.hive.ql.exec.vector.VectorRandomRowSource; +import org.apache.hadoop.hive.ql.exec.vector.VectorRandomRowSource.GenerationSpec; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression; +import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc; +import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFVariance; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableShortObjectInspector; +import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; +import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; +import org.apache.hadoop.hive.serde2.io.ShortWritable; + +import junit.framework.Assert; + +import org.junit.Ignore; +import org.junit.Test; + +public class TestVectorAggregation extends AggregationBase { + + @Test + public void testAvgIntegers() throws Exception { + Random random = new Random(7743); + + doIntegerTests("avg", random); + } + + @Test + public void testAvgFloating() throws Exception { + Random random = new Random(7743); + + doFloatingTests("avg", random); + } + + @Test + public void testAvgDecimal() throws Exception { + Random random = new Random(7743); + + doDecimalTests("avg", random); + } + + @Test + public void testAvgTimestamp() throws Exception { + Random random = new Random(7743); + + doTests( + random, "avg", TypeInfoFactory.timestampTypeInfo); + } + + @Test + public void testCount() throws Exception { + Random random = new Random(7743); + + doTests( + random, "count", TypeInfoFactory.shortTypeInfo); + doTests( + random, "count", TypeInfoFactory.longTypeInfo); + doTests( + random, "count", TypeInfoFactory.doubleTypeInfo); + doTests( + random, "count", new DecimalTypeInfo(18, 10)); + doTests( + random, "count", TypeInfoFactory.stringTypeInfo); + } + + @Test + public void testMax() throws Exception { + Random random = new Random(7743); + + doIntegerTests("max", random); + doFloatingTests("max", random); + doDecimalTests("max", random); + + doTests( + random, "max", TypeInfoFactory.timestampTypeInfo); + doTests( + random, "max", TypeInfoFactory.intervalDayTimeTypeInfo); + + doStringFamilyTests("max", random); + } + + @Test + public void testMin() throws Exception { + Random random = new Random(7743); + + doIntegerTests("min", random); + doFloatingTests("min", random); + doDecimalTests("min", random); + + doTests( + random, "min", TypeInfoFactory.timestampTypeInfo); + doTests( + random, "min", TypeInfoFactory.intervalDayTimeTypeInfo); + + doStringFamilyTests("min", random); + } + + @Test + public void testSum() throws 
Exception {
+    Random random = new Random(7743);
+
+    doTests(
+        random, "sum", TypeInfoFactory.shortTypeInfo);
+    doTests(
+        random, "sum", TypeInfoFactory.longTypeInfo);
+    doTests(
+        random, "sum", TypeInfoFactory.doubleTypeInfo);
+
+    doDecimalTests("sum", random);
+  }
+
+  private final static Set<String> varianceNames =
+      GenericUDAFVariance.VarianceKind.nameMap.keySet();
+
+  @Test
+  public void testVarianceIntegers() throws Exception {
+    Random random = new Random(7743);
+
+    for (String aggregationName : varianceNames) {
+      doIntegerTests(aggregationName, random);
+    }
+  }
+
+  @Test
+  public void testVarianceFloating() throws Exception {
+    Random random = new Random(7743);
+
+    for (String aggregationName : varianceNames) {
+      doFloatingTests(aggregationName, random);
+    }
+  }
+
+  @Test
+  public void testVarianceDecimal() throws Exception {
+    Random random = new Random(7743);
+
+    for (String aggregationName : varianceNames) {
+      doDecimalTests(aggregationName, random);
+    }
+  }
+
+  private static TypeInfo[] integerTypeInfos = new TypeInfo[] {
+      TypeInfoFactory.byteTypeInfo,
+      TypeInfoFactory.shortTypeInfo,
+      TypeInfoFactory.intTypeInfo,
+      TypeInfoFactory.longTypeInfo
+  };
+
+  // We have test failures with FLOAT. Ignoring this issue for now.
+  private static TypeInfo[] floatingTypeInfos = new TypeInfo[] {
+      // TypeInfoFactory.floatTypeInfo,
+      TypeInfoFactory.doubleTypeInfo
+  };
+
+  private void doIntegerTests(String aggregationName, Random random)
+      throws Exception {
+    for (TypeInfo typeInfo : integerTypeInfos) {
+      doTests(
+          random, aggregationName, typeInfo);
+    }
+  }
+
+  private void doFloatingTests(String aggregationName, Random random)
+      throws Exception {
+    for (TypeInfo typeInfo : floatingTypeInfos) {
+      doTests(
+          random, aggregationName, typeInfo);
+    }
+  }
+
+  private static TypeInfo[] decimalTypeInfos = new TypeInfo[] {
+      new DecimalTypeInfo(38, 18),
+      new DecimalTypeInfo(25, 2),
+      new DecimalTypeInfo(19, 4),
+      new DecimalTypeInfo(18, 10),
+      new DecimalTypeInfo(17, 3),
+      new DecimalTypeInfo(12, 2),
+      new DecimalTypeInfo(7, 1)
+  };
+
+  private void doDecimalTests(String aggregationName, Random random)
+      throws Exception {
+    for (TypeInfo typeInfo : decimalTypeInfos) {
+      doTests(
+          random, aggregationName, typeInfo);
+    }
+  }
+
+  private static TypeInfo[] stringFamilyTypeInfos = new TypeInfo[] {
+      TypeInfoFactory.stringTypeInfo,
+      new CharTypeInfo(25),
+      new CharTypeInfo(10),
+      new VarcharTypeInfo(20),
+      new VarcharTypeInfo(15)
+  };
+
+  private void doStringFamilyTests(String aggregationName, Random random)
+      throws Exception {
+    for (TypeInfo typeInfo : stringFamilyTypeInfos) {
+      doTests(
+          random, aggregationName, typeInfo);
+    }
+  }
+
+  public static int getLinearRandomNumber(Random random, int maxSize) {
+    // Get a linearly weighted random number: value k in [1, maxSize] is returned
+    // with probability proportional to (maxSize + 1 - k), so small keys are the
+    // most frequent and create skewed aggregation groups.
+    int randomMultiplier = maxSize * (maxSize + 1) / 2;
+    int randomInt = random.nextInt(randomMultiplier);
+
+    // Linearly iterate through the possible values to find the correct one.
+    int linearRandomNumber = 0;
+    for (int i = maxSize; randomInt >= 0; i--) {
+      randomInt -= i;
+      linearRandomNumber++;
+    }
+
+    return linearRandomNumber;
+  }
+
+  private static final int TEST_ROW_COUNT = 100000;
+
+  private void doMerge(GenericUDAFEvaluator.Mode mergeUdafEvaluatorMode,
+      Random random,
+      String aggregationName,
+      TypeInfo typeInfo,
+      GenerationSpec keyGenerationSpec,
+      List<String> columns, String[] columnNames,
+      int dataAggrMaxKeyCount, int reductionFactor,
+      TypeInfo partial1OutputTypeInfo,
+      Object[] partial1ResultsArray)
+          throws Exception {
+
+    List<GenerationSpec> mergeAggrGenerationSpecList = new ArrayList<GenerationSpec>();
+    List<DataTypePhysicalVariation> mergeDataTypePhysicalVariationList =
+        new ArrayList<DataTypePhysicalVariation>();
+
+    mergeAggrGenerationSpecList.add(keyGenerationSpec);
+    mergeDataTypePhysicalVariationList.add(DataTypePhysicalVariation.NONE);
+
+    // Use OMIT for both. We will fill in the data from the PARTIAL1 results.
+    GenerationSpec mergeGenerationSpec =
+        GenerationSpec.createOmitGeneration(partial1OutputTypeInfo);
+    mergeAggrGenerationSpecList.add(mergeGenerationSpec);
+    mergeDataTypePhysicalVariationList.add(DataTypePhysicalVariation.NONE);
+
+    ExprNodeColumnDesc mergeCol1Expr =
+        new ExprNodeColumnDesc(partial1OutputTypeInfo, "col1", "table", false);
+    List<ExprNodeDesc> mergeParameters = new ArrayList<ExprNodeDesc>();
+    mergeParameters.add(mergeCol1Expr);
+    final int mergeParameterCount = mergeParameters.size();
+    ObjectInspector[] mergeParameterObjectInspectors =
+        new ObjectInspector[mergeParameterCount];
+    for (int i = 0; i < mergeParameterCount; i++) {
+      TypeInfo paramTypeInfo = mergeParameters.get(i).getTypeInfo();
+      mergeParameterObjectInspectors[i] = TypeInfoUtils
+          .getStandardWritableObjectInspectorFromTypeInfo(paramTypeInfo);
+    }
+
+    VectorRandomRowSource mergeRowSource = new VectorRandomRowSource();
+
+    mergeRowSource.initGenerationSpecSchema(
+        random, mergeAggrGenerationSpecList, /* maxComplexDepth */ 0, /* allowNull */ false,
+        mergeDataTypePhysicalVariationList);
+
+    Object[][] mergeRandomRows = mergeRowSource.randomRows(TEST_ROW_COUNT);
+
+    // Reduce the key range to cause there to be work for each PARTIAL2 key.
+    final int mergeMaxKeyCount = dataAggrMaxKeyCount / reductionFactor;
+
+    Object[] partial1Results = (Object[]) partial1ResultsArray[0];
+
+    short partial1Key = 0;
+    for (int i = 0; i < mergeRandomRows.length; i++) {
+      // Find a non-NULL entry...
+      while (true) {
+        if (partial1Key >= dataAggrMaxKeyCount) {
+          partial1Key = 0;
+        }
+        if (partial1Results[partial1Key] != null) {
+          break;
+        }
+        partial1Key++;
+      }
+      final short mergeKey = (short) (partial1Key % mergeMaxKeyCount);
+      mergeRandomRows[i][0] = new ShortWritable(mergeKey);
+      mergeRandomRows[i][1] = partial1Results[partial1Key];
+      partial1Key++;
+    }
+
+    VectorRandomBatchSource mergeBatchSource =
+        VectorRandomBatchSource.createInterestingBatches(
+            random,
+            mergeRowSource,
+            mergeRandomRows,
+            null);
+
+    // We need to pass the original TypeInfo in for initializing the evaluator.
+    GenericUDAFEvaluator mergeEvaluator =
+        getEvaluator(aggregationName, typeInfo);
+
+    /*
+    System.out.println(
+        "*DEBUG* GenericUDAFEvaluator for " + aggregationName + ", " + typeInfo.getTypeName() + ": " +
+            mergeEvaluator.getClass().getSimpleName());
+    */
+
+    // The only way to get the return object inspector (and its return type) is to
+    // initialize it...
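+    // Per the GenericUDAFEvaluator.Mode contract, in PARTIAL2 and FINAL modes the
+    // evaluator consumes partial aggregations, so the single parameter object
+    // inspector passed below describes the PARTIAL1 output (e.g. for avg, a struct
+    // carrying the running count and sum) rather than the raw data column.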
+
+    ObjectInspector mergeReturnOI =
+        mergeEvaluator.init(
+            mergeUdafEvaluatorMode,
+            mergeParameterObjectInspectors);
+    TypeInfo mergeOutputTypeInfo =
+        TypeInfoUtils.getTypeInfoFromObjectInspector(mergeReturnOI);
+
+    Object[] mergeResultsArray = new Object[AggregationTestMode.count];
+
+    executeAggregationTests(
+        aggregationName,
+        partial1OutputTypeInfo,
+        mergeEvaluator,
+        mergeOutputTypeInfo,
+        mergeUdafEvaluatorMode,
+        mergeMaxKeyCount,
+        columns,
+        columnNames,
+        mergeParameters,
+        mergeRandomRows,
+        mergeRowSource,
+        mergeBatchSource,
+        mergeResultsArray);
+
+    verifyAggregationResults(
+        partial1OutputTypeInfo,
+        mergeOutputTypeInfo,
+        mergeMaxKeyCount,
+        mergeUdafEvaluatorMode,
+        mergeResultsArray);
+  }
+
+  private void doTests(Random random, String aggregationName, TypeInfo typeInfo)
+      throws Exception {
+
+    List<GenerationSpec> dataAggrGenerationSpecList = new ArrayList<GenerationSpec>();
+    List<DataTypePhysicalVariation> explicitDataTypePhysicalVariationList =
+        new ArrayList<DataTypePhysicalVariation>();
+
+    TypeInfo keyTypeInfo = TypeInfoFactory.shortTypeInfo;
+    GenerationSpec keyGenerationSpec = GenerationSpec.createOmitGeneration(keyTypeInfo);
+    dataAggrGenerationSpecList.add(keyGenerationSpec);
+    explicitDataTypePhysicalVariationList.add(DataTypePhysicalVariation.NONE);
+
+    GenerationSpec generationSpec = GenerationSpec.createSameType(typeInfo);
+    dataAggrGenerationSpecList.add(generationSpec);
+    explicitDataTypePhysicalVariationList.add(DataTypePhysicalVariation.NONE);
+
+    List<String> columns = new ArrayList<String>();
+    columns.add("col0");
+    columns.add("col1");
+
+    ExprNodeColumnDesc dataAggrCol1Expr = new ExprNodeColumnDesc(typeInfo, "col1", "table", false);
+    List<ExprNodeDesc> dataAggrParameters = new ArrayList<ExprNodeDesc>();
+    dataAggrParameters.add(dataAggrCol1Expr);
+    final int dataAggrParameterCount = dataAggrParameters.size();
+    ObjectInspector[] dataAggrParameterObjectInspectors = new ObjectInspector[dataAggrParameterCount];
+    for (int i = 0; i < dataAggrParameterCount; i++) {
+      TypeInfo paramTypeInfo = dataAggrParameters.get(i).getTypeInfo();
+      dataAggrParameterObjectInspectors[i] = TypeInfoUtils
+          .getStandardWritableObjectInspectorFromTypeInfo(paramTypeInfo);
+    }
+
+    String[] columnNames = columns.toArray(new String[0]);
+
+    final int dataAggrMaxKeyCount = 20000;
+    final int reductionFactor = 16;
+
+    ObjectInspector keyObjectInspector = VectorRandomRowSource.getObjectInspector(keyTypeInfo);
+
+    /*
+     * PARTIAL1.
+     */
+
+    VectorRandomRowSource partial1RowSource = new VectorRandomRowSource();
+
+    partial1RowSource.initGenerationSpecSchema(
+        random, dataAggrGenerationSpecList, /* maxComplexDepth */ 0, /* allowNull */ true,
+        explicitDataTypePhysicalVariationList);
+
+    Object[][] partial1RandomRows = partial1RowSource.randomRows(TEST_ROW_COUNT);
+
+    final int partial1RowCount = partial1RandomRows.length;
+    for (int i = 0; i < partial1RowCount; i++) {
+      final short shortKey = (short) getLinearRandomNumber(random, dataAggrMaxKeyCount);
+      partial1RandomRows[i][0] =
+          ((WritableShortObjectInspector) keyObjectInspector).create((short) shortKey);
+    }
+
+    VectorRandomBatchSource partial1BatchSource =
+        VectorRandomBatchSource.createInterestingBatches(
+            random,
+            partial1RowSource,
+            partial1RandomRows,
+            null);
+
+    GenericUDAFEvaluator partial1Evaluator = getEvaluator(aggregationName, typeInfo);
+
+    /*
+    System.out.println(
+        "*DEBUG* GenericUDAFEvaluator for " + aggregationName + ", " + typeInfo.getTypeName() + ": " +
+            partial1Evaluator.getClass().getSimpleName());
+    */
+
+    // The only way to get the return object inspector (and its return type) is to
+    // initialize it...
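+    // PARTIAL1 maps raw input rows to the intermediate aggregation type
+    // (iterate() plus terminatePartial()), so the output type captured below is
+    // the partial result type; doMerge() later replays these partial results
+    // through the merge-side evaluator modes.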
+    final GenericUDAFEvaluator.Mode partial1UdafEvaluatorMode = GenericUDAFEvaluator.Mode.PARTIAL1;
+    ObjectInspector partial1ReturnOI =
+        partial1Evaluator.init(
+            partial1UdafEvaluatorMode,
+            dataAggrParameterObjectInspectors);
+    TypeInfo partial1OutputTypeInfo =
+        TypeInfoUtils.getTypeInfoFromObjectInspector(partial1ReturnOI);
+
+    Object[] partial1ResultsArray = new Object[AggregationTestMode.count];
+
+    executeAggregationTests(
+        aggregationName,
+        typeInfo,
+        partial1Evaluator,
+        partial1OutputTypeInfo,
+        partial1UdafEvaluatorMode,
+        dataAggrMaxKeyCount,
+        columns,
+        columnNames,
+        dataAggrParameters,
+        partial1RandomRows,
+        partial1RowSource,
+        partial1BatchSource,
+        partial1ResultsArray);
+
+    verifyAggregationResults(
+        typeInfo,
+        partial1OutputTypeInfo,
+        dataAggrMaxKeyCount,
+        partial1UdafEvaluatorMode,
+        partial1ResultsArray);
+
+    final boolean hasDifferentCompleteExpr;
+    if (varianceNames.contains(aggregationName)) {
+      hasDifferentCompleteExpr = true;
+    } else {
+      switch (aggregationName) {
+      case "avg":
+        /*
+        if (typeInfo instanceof DecimalTypeInfo) {
+          // UNDONE: Row-mode GenericUDAFAverage does not call enforcePrecisionScale...
+          hasDifferentCompleteExpr = false;
+        } else {
+          hasDifferentCompleteExpr = true;
+        }
+        */
+        hasDifferentCompleteExpr = true;
+        break;
+      case "count":
+      case "max":
+      case "min":
+      case "sum":
+        hasDifferentCompleteExpr = false;
+        break;
+      default:
+        throw new RuntimeException("Unexpected aggregation name " + aggregationName);
+      }
+    }
+
+    if (hasDifferentCompleteExpr) {
+
+      /*
+       * COMPLETE.
+       */
+
+      VectorRandomRowSource completeRowSource = new VectorRandomRowSource();
+
+      completeRowSource.initGenerationSpecSchema(
+          random, dataAggrGenerationSpecList, /* maxComplexDepth */ 0, /* allowNull */ true,
+          explicitDataTypePhysicalVariationList);
+
+      Object[][] completeRandomRows = completeRowSource.randomRows(TEST_ROW_COUNT);
+
+      final int completeRowCount = completeRandomRows.length;
+      for (int i = 0; i < completeRowCount; i++) {
+        final short shortKey = (short) getLinearRandomNumber(random, dataAggrMaxKeyCount);
+        completeRandomRows[i][0] =
+            ((WritableShortObjectInspector) keyObjectInspector).create((short) shortKey);
+      }
+
+      VectorRandomBatchSource completeBatchSource =
+          VectorRandomBatchSource.createInterestingBatches(
+              random,
+              completeRowSource,
+              completeRandomRows,
+              null);
+
+      GenericUDAFEvaluator completeEvaluator = getEvaluator(aggregationName, typeInfo);
+
+      /*
+      System.out.println(
+          "*DEBUG* GenericUDAFEvaluator for " + aggregationName + ", " + typeInfo.getTypeName() + ": " +
+              completeEvaluator.getClass().getSimpleName());
+      */
+
+      // The only way to get the return object inspector (and its return type) is to
+      // initialize it...
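+      // COMPLETE maps raw input rows straight to the final result (iterate()
+      // plus terminate(), with no merge step), so its output type can differ from
+      // the PARTIAL1 intermediate type and is verified independently here.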
+      final GenericUDAFEvaluator.Mode completeUdafEvaluatorMode = GenericUDAFEvaluator.Mode.COMPLETE;
+      ObjectInspector completeReturnOI =
+          completeEvaluator.init(
+              completeUdafEvaluatorMode,
+              dataAggrParameterObjectInspectors);
+      TypeInfo completeOutputTypeInfo =
+          TypeInfoUtils.getTypeInfoFromObjectInspector(completeReturnOI);
+
+      Object[] completeResultsArray = new Object[AggregationTestMode.count];
+
+      executeAggregationTests(
+          aggregationName,
+          typeInfo,
+          completeEvaluator,
+          completeOutputTypeInfo,
+          completeUdafEvaluatorMode,
+          dataAggrMaxKeyCount,
+          columns,
+          columnNames,
+          dataAggrParameters,
+          completeRandomRows,
+          completeRowSource,
+          completeBatchSource,
+          completeResultsArray);
+
+      verifyAggregationResults(
+          typeInfo,
+          completeOutputTypeInfo,
+          dataAggrMaxKeyCount,
+          completeUdafEvaluatorMode,
+          completeResultsArray);
+    }
+
+    final boolean hasDifferentPartial2Expr;
+    if (varianceNames.contains(aggregationName)) {
+      hasDifferentPartial2Expr = true;
+    } else {
+      switch (aggregationName) {
+      case "avg":
+        hasDifferentPartial2Expr = true;
+        break;
+      case "count":
+      case "max":
+      case "min":
+      case "sum":
+        hasDifferentPartial2Expr = false;
+        break;
+      default:
+        throw new RuntimeException("Unexpected aggregation name " + aggregationName);
+      }
+    }
+
+    // NOTE: the "&& false" deliberately disables the PARTIAL2 merge pass for now;
+    // only the FINAL merge pass below is exercised.
+    if (hasDifferentPartial2Expr && false) {
+
+      /*
+       * PARTIAL2.
+       */
+
+      final GenericUDAFEvaluator.Mode mergeUdafEvaluatorMode = GenericUDAFEvaluator.Mode.PARTIAL2;
+
+      doMerge(
+          mergeUdafEvaluatorMode,
+          random,
+          aggregationName,
+          typeInfo,
+          keyGenerationSpec,
+          columns, columnNames,
+          dataAggrMaxKeyCount,
+          reductionFactor,
+          partial1OutputTypeInfo,
+          partial1ResultsArray);
+    }
+
+    final boolean hasDifferentFinalExpr;
+    if (varianceNames.contains(aggregationName)) {
+      hasDifferentFinalExpr = true;
+    } else {
+      switch (aggregationName) {
+      case "avg":
+        hasDifferentFinalExpr = true;
+        break;
+      case "count":
+        hasDifferentFinalExpr = true;
+        break;
+      case "max":
+      case "min":
+      case "sum":
+        hasDifferentFinalExpr = false;
+        break;
+      default:
+        throw new RuntimeException("Unexpected aggregation name " + aggregationName);
+      }
+    }
+    if (hasDifferentFinalExpr) {
+
+      /*
+       * FINAL.
+ */ + + final GenericUDAFEvaluator.Mode mergeUdafEvaluatorMode = GenericUDAFEvaluator.Mode.FINAL; + + doMerge( + mergeUdafEvaluatorMode, + random, + aggregationName, + typeInfo, + keyGenerationSpec, + columns, columnNames, + dataAggrMaxKeyCount, + reductionFactor, + partial1OutputTypeInfo, + partial1ResultsArray); + } + } +} \ No newline at end of file diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorDateAddSub.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorDateAddSub.java index f5deca5..c4146be 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorDateAddSub.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorDateAddSub.java @@ -370,6 +370,7 @@ private void doRowDateAddSubTest(TypeInfo dateTimeStringTypeInfo, TypeInfo integ Object[][] randomRows, ColumnScalarMode columnScalarMode, ObjectInspector rowInspector, Object[] resultObjects) throws Exception { + /* System.out.println( "*DEBUG* dateTimeStringTypeInfo " + dateTimeStringTypeInfo.toString() + " integerTypeInfo " + integerTypeInfo + @@ -377,6 +378,7 @@ private void doRowDateAddSubTest(TypeInfo dateTimeStringTypeInfo, TypeInfo integ " dateAddSubTestMode ROW_MODE" + " columnScalarMode " + columnScalarMode + " exprDesc " + exprDesc.toString()); + */ HiveConf hiveConf = new HiveConf(); ExprNodeEvaluator evaluator = diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorDateDiff.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorDateDiff.java index dce7ccf..b382c2a 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorDateDiff.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorDateDiff.java @@ -362,12 +362,14 @@ private void doRowDateAddSubTest(TypeInfo dateTimeStringTypeInfo1, Object[][] randomRows, ColumnScalarMode columnScalarMode, ObjectInspector rowInspector, Object[] resultObjects) throws Exception { + /* System.out.println( "*DEBUG* dateTimeStringTypeInfo " + dateTimeStringTypeInfo1.toString() + " dateTimeStringTypeInfo2 " + dateTimeStringTypeInfo2 + " dateDiffTestMode ROW_MODE" + " columnScalarMode " + columnScalarMode + " exprDesc " + exprDesc.toString()); + */ HiveConf hiveConf = new HiveConf(); ExprNodeEvaluator evaluator = diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorIfStatement.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorIfStatement.java index e54ccaa..9020016 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorIfStatement.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorIfStatement.java @@ -199,10 +199,12 @@ private void doIfTestsWithDiffColumnScalar(Random random, String typeName, boolean isNullScalar1, boolean isNullScalar2) throws Exception { + /* System.out.println("*DEBUG* typeName " + typeName + " columnScalarMode " + columnScalarMode + " isNullScalar1 " + isNullScalar1 + " isNullScalar2 " + isNullScalar2); + */ TypeInfo typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(typeName); @@ -457,11 +459,13 @@ private void doVectorIfTest(TypeInfo typeInfo, resultVectorExtractRow.init(new TypeInfo[] { typeInfo }, new int[] { columns.size() }); Object[] scrqtchRow = new Object[1]; + /* System.out.println( "*DEBUG* typeInfo " + typeInfo.toString() + " ifStmtTestMode " + ifStmtTestMode + " columnScalarMode " + columnScalarMode + " vectorExpression " + 
vectorExpression.getClass().getSimpleName()); + */ batchSource.resetBatchIteration(); int rowIndex = 0; diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorNegative.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorNegative.java index ce20f28..d43249e 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorNegative.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorNegative.java @@ -330,10 +330,12 @@ private void doRowArithmeticTest(TypeInfo typeInfo, ObjectInspector rowInspector, TypeInfo outputTypeInfo, Object[] resultObjects) throws Exception { + /* System.out.println( "*DEBUG* typeInfo " + typeInfo.toString() + " negativeTestMode ROW_MODE" + " exprDesc " + exprDesc.toString()); + */ HiveConf hiveConf = new HiveConf(); ExprNodeEvaluator evaluator = @@ -425,10 +427,13 @@ private void doVectorArithmeticTest(TypeInfo typeInfo, new TypeInfo[] { outputTypeInfo }, new int[] { vectorExpression.getOutputColumnNum() }); Object[] scrqtchRow = new Object[1]; + /* System.out.println( "*DEBUG* typeInfo " + typeInfo.toString() + " negativeTestMode " + negativeTestMode + " vectorExpression " + vectorExpression.toString()); + */ + batchSource.resetBatchIteration(); int rowIndex = 0; while (true) { diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorStringConcat.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorStringConcat.java index a87a8b4..f3050c2 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorStringConcat.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorStringConcat.java @@ -305,12 +305,14 @@ private void doRowStringConcatTest(TypeInfo stringTypeInfo, TypeInfo integerType ObjectInspector rowInspector, GenericUDF genericUdf, Object[] resultObjects) throws Exception { + /* System.out.println( "*DEBUG* stringTypeInfo " + stringTypeInfo.toString() + " integerTypeInfo " + integerTypeInfo + " stringConcatTestMode ROW_MODE" + " columnScalarMode " + columnScalarMode + " genericUdf " + genericUdf.toString()); + */ ExprNodeGenericFuncDesc exprDesc = new ExprNodeGenericFuncDesc(TypeInfoFactory.stringTypeInfo, genericUdf, children); @@ -405,12 +407,14 @@ private void doVectorStringConcatTest(TypeInfo stringTypeInfo1, TypeInfo stringT new TypeInfo[] { outputTypeInfo }, new int[] { columns.size() }); Object[] scrqtchRow = new Object[1]; + /* System.out.println( "*DEBUG* stringTypeInfo1 " + stringTypeInfo1.toString() + " stringTypeInfo2 " + stringTypeInfo2.toString() + " stringConcatTestMode " + stringConcatTestMode + " columnScalarMode " + columnScalarMode + " vectorExpression " + vectorExpression.toString()); + */ batchSource.resetBatchIteration(); int rowIndex = 0; diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorStringUnary.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorStringUnary.java index 90f7992..8df5595 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorStringUnary.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorStringUnary.java @@ -347,11 +347,13 @@ private void doVectorIfTest(TypeInfo typeInfo, TypeInfo targetTypeInfo, resultVectorExtractRow.init(new TypeInfo[] { targetTypeInfo }, new int[] { columns.size() }); Object[] scrqtchRow = new Object[1]; + /* System.out.println( "*DEBUG* typeInfo " + typeInfo.toString() + " targetTypeInfo " + 
targetTypeInfo.toString() + " stringUnaryTestMode " + stringUnaryTestMode + " vectorExpression " + vectorExpression.getClass().getSimpleName()); + */ batchSource.resetBatchIteration(); int rowIndex = 0; diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorSubStr.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorSubStr.java index 284a47a..b1344ab 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorSubStr.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorSubStr.java @@ -326,11 +326,13 @@ private void doVectorIfTest(TypeInfo typeInfo, TypeInfo targetTypeInfo, resultVectorExtractRow.init(new TypeInfo[] { targetTypeInfo }, new int[] { columns.size() }); Object[] scrqtchRow = new Object[1]; + /* System.out.println( "*DEBUG* typeInfo " + typeInfo.toString() + " targetTypeInfo " + targetTypeInfo.toString() + " subStrTestMode " + subStrTestMode + " vectorExpression " + vectorExpression.getClass().getSimpleName()); + */ batchSource.resetBatchIteration(); int rowIndex = 0; diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorTimestampExtract.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorTimestampExtract.java index 58e3fa3..e56a6c3 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorTimestampExtract.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorTimestampExtract.java @@ -283,10 +283,12 @@ private boolean doRowCastTest(TypeInfo dateTimeStringTypeInfo, Object[][] randomRows, ObjectInspector rowInspector, Object[] resultObjects) throws Exception { + /* System.out.println( "*DEBUG* dateTimeStringTypeInfo " + dateTimeStringTypeInfo.toString() + " timestampExtractTestMode ROW_MODE" + " exprDesc " + exprDesc.toString()); + */ HiveConf hiveConf = new HiveConf(); ExprNodeEvaluator evaluator = @@ -392,10 +394,12 @@ private boolean doVectorCastTest(TypeInfo dateTimeStringTypeInfo, VectorExpression vectorExpression = vectorizationContext.getVectorExpression(exprDesc); vectorExpression.transientInit(); + /* System.out.println( "*DEBUG* dateTimeStringTypeInfo " + dateTimeStringTypeInfo.toString() + " timestampExtractTestMode " + timestampExtractTestMode + " vectorExpression " + vectorExpression.getClass().getSimpleName()); + */ VectorRandomRowSource rowSource = batchSource.getRowSource(); VectorizedRowBatchCtx batchContext = diff --git ql/src/test/org/apache/hadoop/hive/ql/optimizer/physical/TestVectorizer.java ql/src/test/org/apache/hadoop/hive/ql/optimizer/physical/TestVectorizer.java index d9fc060..2a2bbe1 100644 --- ql/src/test/org/apache/hadoop/hive/ql/optimizer/physical/TestVectorizer.java +++ ql/src/test/org/apache/hadoop/hive/ql/optimizer/physical/TestVectorizer.java @@ -112,7 +112,10 @@ public void testAggregateOnUDF() throws HiveException, VectorizerCannotVectorize vectorDesc.setVecAggrDescs( new VectorAggregationDesc[] { new VectorAggregationDesc( - aggDesc, new GenericUDAFSum.GenericUDAFSumLong(), TypeInfoFactory.longTypeInfo, ColumnVector.Type.LONG, null, + aggDesc.getGenericUDAFName(), + new GenericUDAFSum.GenericUDAFSumLong(), + aggDesc.getMode(), + TypeInfoFactory.longTypeInfo, ColumnVector.Type.LONG, null, TypeInfoFactory.longTypeInfo, ColumnVector.Type.LONG, VectorUDAFCountStar.class)}); desc.setOutputColumnNames(outputColumnNames);