diff --git a/common/src/java/org/apache/hadoop/hive/common/type/Decimal128.java b/common/src/java/org/apache/hadoop/hive/common/type/Decimal128.java index a5d7399..8578d2f 100644 --- a/common/src/java/org/apache/hadoop/hive/common/type/Decimal128.java +++ b/common/src/java/org/apache/hadoop/hive/common/type/Decimal128.java @@ -261,8 +261,19 @@ public boolean isZero() { * object to copy from */ public Decimal128 update(Decimal128 o) { + update(o, o.scale); + return this; + } + + /** + * Copy the value of given object and assigns a custom scale. + * + * @param o + * object to copy from + */ + public Decimal128 update(Decimal128 o, short scale) { this.unscaledValue.update(o.unscaledValue); - this.scale = o.scale; + this.scale = scale; this.signum = o.signum; return this; } diff --git a/ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxDecimal.txt b/ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxDecimal.txt index de9a84c..37ce103 100644 --- a/ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxDecimal.txt +++ b/ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxDecimal.txt @@ -60,12 +60,12 @@ public class extends VectorAggregateExpression { value = new Decimal128(); } - public void checkValue(Decimal128 value) { + public void checkValue(Decimal128 value, short scale) { if (isNull) { isNull = false; this.value.update(value); } else if (this.value.compareTo(value) 0) { - this.value.update(value); + this.value.update(value, scale); } } @@ -124,16 +124,16 @@ public class extends VectorAggregateExpression { if (inputVector.isRepeating) { iterateNoNullsRepeatingWithAggregationSelection( aggregationBufferSets, aggregrateIndex, - vector[0], batchSize); + vector[0], inputVector.scale, batchSize); } else { if (batch.selectedInUse) { iterateNoNullsSelectionWithAggregationSelection( aggregationBufferSets, aggregrateIndex, - vector, batch.selected, batchSize); + vector, inputVector.scale, batch.selected, batchSize); } else { iterateNoNullsWithAggregationSelection( aggregationBufferSets, aggregrateIndex, - vector, batchSize); + vector, inputVector.scale, batchSize); } } } else { @@ -141,21 +141,21 @@ public class extends VectorAggregateExpression { if (batch.selectedInUse) { iterateHasNullsRepeatingSelectionWithAggregationSelection( aggregationBufferSets, aggregrateIndex, - vector[0], batchSize, batch.selected, inputVector.isNull); + vector[0], inputVector.scale, batchSize, batch.selected, inputVector.isNull); } else { iterateHasNullsRepeatingWithAggregationSelection( aggregationBufferSets, aggregrateIndex, - vector[0], batchSize, inputVector.isNull); + vector[0], inputVector.scale, batchSize, inputVector.isNull); } } else { if (batch.selectedInUse) { iterateHasNullsSelectionWithAggregationSelection( aggregationBufferSets, aggregrateIndex, - vector, batchSize, batch.selected, inputVector.isNull); + vector, inputVector.scale, batchSize, batch.selected, inputVector.isNull); } else { iterateHasNullsWithAggregationSelection( aggregationBufferSets, aggregrateIndex, - vector, batchSize, inputVector.isNull); + vector, inputVector.scale, batchSize, inputVector.isNull); } } } @@ -165,6 +165,7 @@ public class extends VectorAggregateExpression { VectorAggregationBufferRow[] aggregationBufferSets, int aggregrateIndex, Decimal128 value, + short scale, int batchSize) { for (int i=0; i < batchSize; ++i) { @@ -172,7 +173,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(value); + myagg.checkValue(value, scale); } } @@ -180,6 +181,7 @@ public class extends VectorAggregateExpression { VectorAggregationBufferRow[] aggregationBufferSets, int aggregrateIndex, Decimal128[] values, + short scale, int[] selection, int batchSize) { @@ -188,7 +190,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(values[selection[i]]); + myagg.checkValue(values[selection[i]], scale); } } @@ -196,13 +198,14 @@ public class extends VectorAggregateExpression { VectorAggregationBufferRow[] aggregationBufferSets, int aggregrateIndex, Decimal128[] values, + short scale, int batchSize) { for (int i=0; i < batchSize; ++i) { Aggregation myagg = getCurrentAggregationBuffer( aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(values[i]); + myagg.checkValue(values[i], scale); } } @@ -210,6 +213,7 @@ public class extends VectorAggregateExpression { VectorAggregationBufferRow[] aggregationBufferSets, int aggregrateIndex, Decimal128 value, + short scale, int batchSize, int[] selection, boolean[] isNull) { @@ -220,7 +224,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(value); + myagg.checkValue(value, scale); } } @@ -230,6 +234,7 @@ public class extends VectorAggregateExpression { VectorAggregationBufferRow[] aggregationBufferSets, int aggregrateIndex, Decimal128 value, + short scale, int batchSize, boolean[] isNull) { @@ -239,7 +244,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(value); + myagg.checkValue(value, scale); } } } @@ -248,6 +253,7 @@ public class extends VectorAggregateExpression { VectorAggregationBufferRow[] aggregationBufferSets, int aggregrateIndex, Decimal128[] values, + short scale, int batchSize, int[] selection, boolean[] isNull) { @@ -259,7 +265,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, j); - myagg.checkValue(values[i]); + myagg.checkValue(values[i], scale); } } } @@ -268,6 +274,7 @@ public class extends VectorAggregateExpression { VectorAggregationBufferRow[] aggregationBufferSets, int aggregrateIndex, Decimal128[] values, + short scale, int batchSize, boolean[] isNull) { @@ -277,7 +284,7 @@ public class extends VectorAggregateExpression { aggregationBufferSets, aggregrateIndex, i); - myagg.checkValue(values[i]); + myagg.checkValue(values[i], scale); } } } @@ -305,28 +312,31 @@ public class extends VectorAggregateExpression { if (inputVector.noNulls && (myagg.isNull || (myagg.value.compareTo(vector[0]) 0))) { myagg.isNull = false; - myagg.value.update(vector[0]); + myagg.value.update(vector[0], inputVector.scale); } return; } if (!batch.selectedInUse && inputVector.noNulls) { - iterateNoSelectionNoNulls(myagg, vector, batchSize); + iterateNoSelectionNoNulls(myagg, vector, inputVector.scale, batchSize); } else if (!batch.selectedInUse) { - iterateNoSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull); + iterateNoSelectionHasNulls(myagg, vector, inputVector.scale, + batchSize, inputVector.isNull); } else if (inputVector.noNulls){ - iterateSelectionNoNulls(myagg, vector, batchSize, batch.selected); + iterateSelectionNoNulls(myagg, vector, inputVector.scale, batchSize, batch.selected); } else { - iterateSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull, batch.selected); + iterateSelectionHasNulls(myagg, vector, inputVector.scale, + batchSize, inputVector.isNull, batch.selected); } } private void iterateSelectionHasNulls( Aggregation myagg, Decimal128[] vector, + short scale, int batchSize, boolean[] isNull, int[] selected) { @@ -340,7 +350,7 @@ public class extends VectorAggregateExpression { myagg.value.update(value); } else if (myagg.value.compareTo(value) 0) { - myagg.value.update(value); + myagg.value.update(value, scale); } } } @@ -349,6 +359,7 @@ public class extends VectorAggregateExpression { private void iterateSelectionNoNulls( Aggregation myagg, Decimal128[] vector, + short scale, int batchSize, int[] selected) { @@ -360,7 +371,7 @@ public class extends VectorAggregateExpression { for (int i=0; i< batchSize; ++i) { Decimal128 value = vector[selected[i]]; if (myagg.value.compareTo(value) 0) { - myagg.value.update(value); + myagg.value.update(value, scale); } } } @@ -368,6 +379,7 @@ public class extends VectorAggregateExpression { private void iterateNoSelectionHasNulls( Aggregation myagg, Decimal128[] vector, + short scale, int batchSize, boolean[] isNull) { @@ -375,11 +387,11 @@ public class extends VectorAggregateExpression { if (!isNull[i]) { Decimal128 value = vector[i]; if (myagg.isNull) { - myagg.value.update(value); + myagg.value.update(value, scale); myagg.isNull = false; } else if (myagg.value.compareTo(value) 0) { - myagg.value.update(value); + myagg.value.update(value, scale); } } } @@ -388,6 +400,7 @@ public class extends VectorAggregateExpression { private void iterateNoSelectionNoNulls( Aggregation myagg, Decimal128[] vector, + short scale, int batchSize) { if (myagg.isNull) { myagg.value.update(vector[0]); @@ -397,7 +410,7 @@ public class extends VectorAggregateExpression { for (int i=0;i 0) { - myagg.value.update(value); + myagg.value.update(value, scale); } } } diff --git a/ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarDecimal.txt b/ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarDecimal.txt index e626161..c5af930 100644 --- a/ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarDecimal.txt +++ b/ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarDecimal.txt @@ -79,27 +79,27 @@ public class extends VectorAggregateExpression { throw new UnsupportedOperationException(); } - public void updateValueWithCheckAndInit(Decimal128 scratch, Decimal128 value) { + public void updateValueWithCheckAndInit(Decimal128 scratch, Decimal128 value, short scale) { if (this.isNull) { this.init(); } - this.sum.addDestructive(value, value.getScale()); + this.sum.addDestructive(value, scale); this.count += 1; if(this.count > 1) { scratch.update(count); - scratch.multiplyDestructive(value, value.getScale()); - scratch.subtractDestructive(sum, sum.getScale()); + scratch.multiplyDestructive(value, scale); + scratch.subtractDestructive(sum, scale); double t = scratch.doubleValue(); this.variance += (t*t) / ((double)this.count*(this.count-1)); } } - public void updateValueNoCheck(Decimal128 scratch, Decimal128 value) { - this.sum.addDestructive(value, value.getScale()); + public void updateValueNoCheck(Decimal128 scratch, Decimal128 value, short scale) { + this.sum.addDestructive(value, scale); this.count += 1; scratch.update(count); - scratch.multiplyDestructive(value, value.getScale()); - scratch.subtractDestructive(sum, sum.getScale()); + scratch.multiplyDestructive(value, scale); + scratch.subtractDestructive(sum, scale); double t = scratch.doubleValue(); this.variance += (t*t) / ((double)this.count*(this.count-1)); } @@ -181,24 +181,26 @@ public class extends VectorAggregateExpression { if (inputVector.isRepeating) { if (inputVector.noNulls || !inputVector.isNull[0]) { iterateRepeatingNoNullsWithAggregationSelection( - aggregationBufferSets, aggregateIndex, vector[0], batchSize); + aggregationBufferSets, aggregateIndex, vector[0], inputVector.scale, batchSize); } } else if (!batch.selectedInUse && inputVector.noNulls) { iterateNoSelectionNoNullsWithAggregationSelection( - aggregationBufferSets, aggregateIndex, vector, batchSize); + aggregationBufferSets, aggregateIndex, vector, inputVector.scale, batchSize); } else if (!batch.selectedInUse) { iterateNoSelectionHasNullsWithAggregationSelection( - aggregationBufferSets, aggregateIndex, vector, batchSize, inputVector.isNull); + aggregationBufferSets, aggregateIndex, vector, inputVector.scale, + batchSize, inputVector.isNull); } else if (inputVector.noNulls){ iterateSelectionNoNullsWithAggregationSelection( - aggregationBufferSets, aggregateIndex, vector, batchSize, batch.selected); + aggregationBufferSets, aggregateIndex, vector, inputVector.scale, + batchSize, batch.selected); } else { iterateSelectionHasNullsWithAggregationSelection( - aggregationBufferSets, aggregateIndex, vector, batchSize, + aggregationBufferSets, aggregateIndex, vector, inputVector.scale, batchSize, inputVector.isNull, batch.selected); } @@ -208,6 +210,7 @@ public class extends VectorAggregateExpression { VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, Decimal128 value, + short scale, int batchSize) { for (int i=0; i extends VectorAggregateExpression { aggregationBufferSets, aggregateIndex, i); - myagg.updateValueWithCheckAndInit(scratchDecimal, value); + myagg.updateValueWithCheckAndInit(scratchDecimal, value, scale); } } @@ -223,6 +226,7 @@ public class extends VectorAggregateExpression { VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, Decimal128[] vector, + short scale, int batchSize, boolean[] isNull, int[] selected) { @@ -235,7 +239,7 @@ public class extends VectorAggregateExpression { int i = selected[j]; if (!isNull[i]) { Decimal128 value = vector[i]; - myagg.updateValueWithCheckAndInit(scratchDecimal, value); + myagg.updateValueWithCheckAndInit(scratchDecimal, value, scale); } } } @@ -244,6 +248,7 @@ public class extends VectorAggregateExpression { VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, Decimal128[] vector, + short scale, int batchSize, int[] selected) { @@ -253,7 +258,7 @@ public class extends VectorAggregateExpression { aggregateIndex, i); Decimal128 value = vector[selected[i]]; - myagg.updateValueWithCheckAndInit(scratchDecimal, value); + myagg.updateValueWithCheckAndInit(scratchDecimal, value, scale); } } @@ -261,6 +266,7 @@ public class extends VectorAggregateExpression { VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, Decimal128[] vector, + short scale, int batchSize, boolean[] isNull) { @@ -271,7 +277,7 @@ public class extends VectorAggregateExpression { aggregateIndex, i); Decimal128 value = vector[i]; - myagg.updateValueWithCheckAndInit(scratchDecimal, value); + myagg.updateValueWithCheckAndInit(scratchDecimal, value, scale); } } } @@ -280,6 +286,7 @@ public class extends VectorAggregateExpression { VectorAggregationBufferRow[] aggregationBufferSets, int aggregateIndex, Decimal128[] vector, + short scale, int batchSize) { for (int i=0; i extends VectorAggregateExpression { aggregateIndex, i); Decimal128 value = vector[i]; - myagg.updateValueWithCheckAndInit(scratchDecimal, value); + myagg.updateValueWithCheckAndInit(scratchDecimal, value, scale); } } @@ -313,42 +320,45 @@ public class extends VectorAggregateExpression { if (inputVector.isRepeating) { if (inputVector.noNulls) { - iterateRepeatingNoNulls(myagg, vector[0], batchSize); + iterateRepeatingNoNulls(myagg, vector[0], inputVector.scale, batchSize); } } else if (!batch.selectedInUse && inputVector.noNulls) { - iterateNoSelectionNoNulls(myagg, vector, batchSize); + iterateNoSelectionNoNulls(myagg, vector, inputVector.scale, batchSize); } else if (!batch.selectedInUse) { - iterateNoSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull); + iterateNoSelectionHasNulls(myagg, vector, inputVector.scale, batchSize, inputVector.isNull); } else if (inputVector.noNulls){ - iterateSelectionNoNulls(myagg, vector, batchSize, batch.selected); + iterateSelectionNoNulls(myagg, vector, inputVector.scale, batchSize, batch.selected); } else { - iterateSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull, batch.selected); + iterateSelectionHasNulls(myagg, vector, inputVector.scale, + batchSize, inputVector.isNull, batch.selected); } } private void iterateRepeatingNoNulls( Aggregation myagg, Decimal128 value, + short scale, int batchSize) { // TODO: conjure a formula w/o iterating // - myagg.updateValueWithCheckAndInit(scratchDecimal, value); + myagg.updateValueWithCheckAndInit(scratchDecimal, value, scale); // We pulled out i=0 so we can remove the count > 1 check in the loop for (int i=1; i extends VectorAggregateExpression { int i = selected[j]; if (!isNull[i]) { Decimal128 value = vector[i]; - myagg.updateValueWithCheckAndInit(scratchDecimal, value); + myagg.updateValueWithCheckAndInit(scratchDecimal, value, scale); } } } @@ -365,6 +375,7 @@ public class extends VectorAggregateExpression { private void iterateSelectionNoNulls( Aggregation myagg, Decimal128[] vector, + short scale, int batchSize, int[] selected) { @@ -373,26 +384,27 @@ public class extends VectorAggregateExpression { } Decimal128 value = vector[selected[0]]; - myagg.updateValueWithCheckAndInit(scratchDecimal, value); + myagg.updateValueWithCheckAndInit(scratchDecimal, value, scale); // i=0 was pulled out to remove the count > 1 check in the loop // for (int i=1; i< batchSize; ++i) { value = vector[selected[i]]; - myagg.updateValueNoCheck(scratchDecimal, value); + myagg.updateValueNoCheck(scratchDecimal, value, scale); } } private void iterateNoSelectionHasNulls( Aggregation myagg, Decimal128[] vector, + short scale, int batchSize, boolean[] isNull) { for(int i=0;i extends VectorAggregateExpression { private void iterateNoSelectionNoNulls( Aggregation myagg, Decimal128[] vector, + short scale, int batchSize) { if (myagg.isNull) { @@ -407,12 +420,12 @@ public class extends VectorAggregateExpression { } Decimal128 value = vector[0]; - myagg.updateValueWithCheckAndInit(scratchDecimal, value); + myagg.updateValueWithCheckAndInit(scratchDecimal, value, scale); // i=0 was pulled out to remove count > 1 check for (int i=1; i values, - Object expected) throws HiveException { + String typeName, + String aggregateName, + int batchSize, + Iterable values, + Object expected) throws HiveException { @SuppressWarnings("unchecked") FakeVectorRowBatchFromObjectIterables fdr = new FakeVectorRowBatchFromObjectIterables( - batchSize, new String[] {"Decimal"}, values); + batchSize, new String[] {typeName}, values); testAggregateDecimalIterable (aggregateName, fdr, expected); } diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromObjectIterables.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromObjectIterables.java index ba7b0f9..eab051e 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromObjectIterables.java +++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromObjectIterables.java @@ -22,6 +22,9 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.regex.MatchResult; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import org.apache.hadoop.hive.common.type.Decimal128; import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; @@ -141,7 +144,23 @@ public void assign( } }; } else if (types[i].toLowerCase().startsWith("decimal")) { - batch.cols[i] = new DecimalColumnVector(batchSize, 38, 0); + Pattern decimalPattern = Pattern.compile( + "decimal(?:\\((\\d+)(?:\\,(\\d+))?\\))?", Pattern.CASE_INSENSITIVE); + Matcher mr = decimalPattern.matcher(types[i]); + int precission = 38; + int scale = 0; + if (mr.matches()) { + String typePrecission = mr.group(1); + if (typePrecission != null) { + precission = Integer.parseInt(typePrecission); + } + String typeScale = mr.group(2); + if (typeScale != null) { + scale = Integer.parseInt(typeScale); + } + } + + batch.cols[i] = new DecimalColumnVector(batchSize, precission, scale); columnAssign[i] = new ColumnVectorAssign() { @Override public void assign(