diff --git ant/src/org/apache/hadoop/hive/ant/GenVectorCode.java ant/src/org/apache/hadoop/hive/ant/GenVectorCode.java index 1b76fc9..330e91d 100644 --- ant/src/org/apache/hadoop/hive/ant/GenVectorCode.java +++ ant/src/org/apache/hadoop/hive/ant/GenVectorCode.java @@ -459,6 +459,11 @@ "_FUNC_(expr) - Returns the maximum value of expr (vectorized, type: long)"}, {"VectorUDAFMinMax", "VectorUDAFMaxDouble", "double", ">", "max", "_FUNC_(expr) - Returns the maximum value of expr (vectorized, type: double)"}, + + {"VectorUDAFMinMaxDecimal", "VectorUDAFMaxDecimal", "<", "max", + "_FUNC_(expr) - Returns the maximum value of expr (vectorized, type: decimal)"}, + {"VectorUDAFMinMaxDecimal", "VectorUDAFMinDecimal", ">", "min", + "_FUNC_(expr) - Returns the minimum value of expr (vectorized, type: decimal)"}, {"VectorUDAFMinMaxString", "VectorUDAFMinString", "<", "min", "_FUNC_(expr) - Returns the minimum value of expr (vectorized, type: string)"}, @@ -479,24 +484,36 @@ {"VectorUDAFVar", "VectorUDAFVarPopDouble", "double", "myagg.variance / myagg.count", "variance, var_pop", "_FUNC_(x) - Returns the variance of a set of numbers (vectorized, double)"}, + {"VectorUDAFVarDecimal", "VectorUDAFVarPopDecimal", "myagg.variance / myagg.count", + "variance, var_pop", + "_FUNC_(x) - Returns the variance of a set of numbers (vectorized, decimal)"}, {"VectorUDAFVar", "VectorUDAFVarSampLong", "long", "myagg.variance / (myagg.count-1.0)", "var_samp", "_FUNC_(x) - Returns the sample variance of a set of numbers (vectorized, long)"}, {"VectorUDAFVar", "VectorUDAFVarSampDouble", "double", "myagg.variance / (myagg.count-1.0)", "var_samp", "_FUNC_(x) - Returns the sample variance of a set of numbers (vectorized, double)"}, + {"VectorUDAFVarDecimal", "VectorUDAFVarSampDecimal", "myagg.variance / (myagg.count-1.0)", + "var_samp", + "_FUNC_(x) - Returns the sample variance of a set of numbers (vectorized, decimal)"}, {"VectorUDAFVar", "VectorUDAFStdPopLong", "long", "Math.sqrt(myagg.variance / (myagg.count))", "std,stddev,stddev_pop", "_FUNC_(x) - Returns the standard deviation of a set of numbers (vectorized, long)"}, {"VectorUDAFVar", "VectorUDAFStdPopDouble", "double", "Math.sqrt(myagg.variance / (myagg.count))", "std,stddev,stddev_pop", "_FUNC_(x) - Returns the standard deviation of a set of numbers (vectorized, double)"}, + {"VectorUDAFVarDecimal", "VectorUDAFStdPopDecimal", + "Math.sqrt(myagg.variance / (myagg.count))", "std,stddev,stddev_pop", + "_FUNC_(x) - Returns the standard deviation of a set of numbers (vectorized, decimal)"}, {"VectorUDAFVar", "VectorUDAFStdSampLong", "long", "Math.sqrt(myagg.variance / (myagg.count-1.0))", "stddev_samp", "_FUNC_(x) - Returns the sample standard deviation of a set of numbers (vectorized, long)"}, {"VectorUDAFVar", "VectorUDAFStdSampDouble", "double", "Math.sqrt(myagg.variance / (myagg.count-1.0))", "stddev_samp", "_FUNC_(x) - Returns the sample standard deviation of a set of numbers (vectorized, double)"}, + {"VectorUDAFVarDecimal", "VectorUDAFStdSampDecimal", + "Math.sqrt(myagg.variance / (myagg.count-1.0))", "stddev_samp", + "_FUNC_(x) - Returns the sample standard deviation of a set of numbers (vectorized, decimal)"}, }; @@ -619,12 +636,16 @@ private void generate() throws Exception { generateVectorUDAFMinMax(tdesc); } else if (tdesc[0].equals("VectorUDAFMinMaxString")) { generateVectorUDAFMinMaxString(tdesc); + } else if (tdesc[0].equals("VectorUDAFMinMaxDecimal")) { + generateVectorUDAFMinMaxDecimal(tdesc); } else if (tdesc[0].equals("VectorUDAFSum")) { 
generateVectorUDAFSum(tdesc); } else if (tdesc[0].equals("VectorUDAFAvg")) { generateVectorUDAFAvg(tdesc); } else if (tdesc[0].equals("VectorUDAFVar")) { generateVectorUDAFVar(tdesc); + } else if (tdesc[0].equals("VectorUDAFVarDecimal")) { + generateVectorUDAFVarDecimal(tdesc); } else if (tdesc[0].equals("FilterStringColumnCompareScalar")) { generateFilterStringColumnCompareScalar(tdesc); } else if (tdesc[0].equals("FilterStringColumnBetween")) { @@ -675,7 +696,7 @@ private void generateFilterStringColumnBetween(String[] tdesc) throws IOExceptio className, templateString); } - private void generateFilterColumnBetween(String[] tdesc) throws IOException { + private void generateFilterColumnBetween(String[] tdesc) throws Exception { String operandType = tdesc[1]; String optionalNot = tdesc[2]; @@ -695,7 +716,7 @@ private void generateFilterColumnBetween(String[] tdesc) throws IOException { className, templateString); } - private void generateColumnCompareColumn(String[] tdesc) throws IOException { + private void generateColumnCompareColumn(String[] tdesc) throws Exception { //The variables are all same as ColumnCompareScalar except that //this template doesn't need a return type. Pass anything as return type. String operatorName = tdesc[1]; @@ -748,6 +769,24 @@ private void generateVectorUDAFMinMaxString(String[] tdesc) throws Exception { className, templateString); } + private void generateVectorUDAFMinMaxDecimal(String[] tdesc) throws Exception { + String className = tdesc[1]; + String operatorSymbol = tdesc[2]; + String descName = tdesc[3]; + String descValue = tdesc[4]; + + File templateFile = new File(joinPath(this.udafTemplateDirectory, tdesc[0] + ".txt")); + + String templateString = readFile(templateFile); + templateString = templateString.replaceAll("<ClassName>", className); + templateString = templateString.replaceAll("<OperatorSymbol>", operatorSymbol); + templateString = templateString.replaceAll("<DescriptionName>", descName); + templateString = templateString.replaceAll("<DescriptionValue>", descValue); + writeFile(templateFile.lastModified(), udafOutputDirectory, udafClassesDirectory, + className, templateString); + } + + private void generateVectorUDAFSum(String[] tdesc) throws Exception { //template, <ClassName>, <ValueType>, <DescriptionName>, <DescriptionValue> String className = tdesc[1]; @@ -768,7 +807,7 @@ private void generateVectorUDAFSum(String[] tdesc) throws Exception { className, templateString); } - private void generateVectorUDAFAvg(String[] tdesc) throws IOException { + private void generateVectorUDAFAvg(String[] tdesc) throws Exception { String className = tdesc[1]; String valueType = tdesc[2]; String columnType = getColumnVectorType(valueType); @@ -783,7 +822,7 @@ private void generateVectorUDAFAvg(String[] tdesc) throws IOException { className, templateString); } - private void generateVectorUDAFVar(String[] tdesc) throws IOException { + private void generateVectorUDAFVar(String[] tdesc) throws Exception { String className = tdesc[1]; String valueType = tdesc[2]; String varianceFormula = tdesc[3]; @@ -804,6 +843,24 @@ private void generateVectorUDAFVar(String[] tdesc) throws IOException { className, templateString); } + private void generateVectorUDAFVarDecimal(String[] tdesc) throws Exception { + String className = tdesc[1]; + String varianceFormula = tdesc[2]; + String descriptionName = tdesc[3]; + String descriptionValue = tdesc[4]; + + File templateFile = new File(joinPath(this.udafTemplateDirectory, tdesc[0] + ".txt")); + + String templateString = readFile(templateFile); + templateString = templateString.replaceAll("<ClassName>", className); + templateString =
templateString.replaceAll("<VarianceFormula>", varianceFormula); + templateString = templateString.replaceAll("<DescriptionName>", descriptionName); + templateString = templateString.replaceAll("<DescriptionValue>", descriptionValue); + writeFile(templateFile.lastModified(), udafOutputDirectory, udafClassesDirectory, + className, templateString); + } + + + private void generateFilterStringScalarCompareColumn(String[] tdesc) throws IOException { String operatorName = tdesc[1]; String className = "FilterStringScalar" + operatorName + "StringColumn"; @@ -857,7 +914,7 @@ private void generateStringColumnCompareScalar(String[] tdesc, String className) className, templateString); } - private void generateFilterColumnCompareColumn(String[] tdesc) throws IOException { + private void generateFilterColumnCompareColumn(String[] tdesc) throws Exception { //The variables are all same as ColumnCompareScalar except that //this template doesn't need a return type. Pass anything as return type. String operatorName = tdesc[1]; @@ -868,7 +925,7 @@ private void generateFilterColumnCompareColumn(String[] tdesc) throws IOExceptio generateColumnBinaryOperatorColumn(tdesc, null, className); } - private void generateColumnUnaryMinus(String[] tdesc) throws IOException { + private void generateColumnUnaryMinus(String[] tdesc) throws Exception { String operandType = tdesc[1]; String inputColumnVectorType = this.getColumnVectorType(operandType); String outputColumnVectorType = inputColumnVectorType; @@ -886,7 +943,7 @@ private void generateColumnUnaryMinus(String[] tdesc) throws IOException { className, templateString); } - private void generateIfExprColumnColumn(String[] tdesc) throws IOException { + private void generateIfExprColumnColumn(String[] tdesc) throws Exception { String operandType = tdesc[1]; String inputColumnVectorType = this.getColumnVectorType(operandType); String outputColumnVectorType = inputColumnVectorType; @@ -904,7 +961,7 @@ private void generateIfExprColumnColumn(String[] tdesc) throws IOException { className, templateString); } - private void generateIfExprColumnScalar(String[] tdesc) throws IOException { + private void generateIfExprColumnScalar(String[] tdesc) throws Exception { String operandType2 = tdesc[1]; String operandType3 = tdesc[2]; String arg2ColumnVectorType = this.getColumnVectorType(operandType2); @@ -926,7 +983,7 @@ private void generateIfExprColumnScalar(String[] tdesc) throws IOException { className, templateString); } - private void generateIfExprScalarColumn(String[] tdesc) throws IOException { + private void generateIfExprScalarColumn(String[] tdesc) throws Exception { String operandType2 = tdesc[1]; String operandType3 = tdesc[2]; String arg3ColumnVectorType = this.getColumnVectorType(operandType3); @@ -948,7 +1005,7 @@ private void generateIfExprScalarColumn(String[] tdesc) throws IOException { className, templateString); } - private void generateIfExprScalarScalar(String[] tdesc) throws IOException { + private void generateIfExprScalarScalar(String[] tdesc) throws Exception { String operandType2 = tdesc[1]; String operandType3 = tdesc[2]; String arg3ColumnVectorType = this.getColumnVectorType(operandType3); @@ -970,7 +1027,7 @@ private void generateIfExprScalarScalar(String[] tdesc) throws IOException { } // template, , , , , , - private void generateColumnUnaryFunc(String[] tdesc) throws IOException { + private void generateColumnUnaryFunc(String[] tdesc) throws Exception { String classNamePrefix = tdesc[1]; String operandType = tdesc[3]; String inputColumnVectorType = this.getColumnVectorType(operandType); @@ -998,7
+1055,7 @@ private void generateColumnUnaryFunc(String[] tdesc) throws IOException { className, templateString); } - private void generateColumnArithmeticColumn(String [] tdesc) throws IOException { + private void generateColumnArithmeticColumn(String [] tdesc) throws Exception { String operatorName = tdesc[1]; String operandType1 = tdesc[2]; String operandType2 = tdesc[3]; @@ -1008,7 +1065,7 @@ private void generateColumnArithmeticColumn(String [] tdesc) throws IOException generateColumnBinaryOperatorColumn(tdesc, returnType, className); } - private void generateFilterColumnCompareScalar(String[] tdesc) throws IOException { + private void generateFilterColumnCompareScalar(String[] tdesc) throws Exception { //The variables are all same as ColumnCompareScalar except that //this template doesn't need a return type. Pass anything as return type. String operatorName = tdesc[1]; @@ -1019,7 +1076,7 @@ private void generateFilterColumnCompareScalar(String[] tdesc) throws IOExceptio generateColumnBinaryOperatorScalar(tdesc, null, className); } - private void generateFilterScalarCompareColumn(String[] tdesc) throws IOException { + private void generateFilterScalarCompareColumn(String[] tdesc) throws Exception { //this template doesn't need a return type. Pass anything as return type. String operatorName = tdesc[1]; String operandType1 = tdesc[2]; @@ -1029,7 +1086,7 @@ private void generateFilterScalarCompareColumn(String[] tdesc) throws IOExceptio generateScalarBinaryOperatorColumn(tdesc, null, className); } - private void generateColumnCompareScalar(String[] tdesc) throws IOException { + private void generateColumnCompareScalar(String[] tdesc) throws Exception { String operatorName = tdesc[1]; String operandType1 = tdesc[2]; String operandType2 = tdesc[3]; @@ -1039,7 +1096,7 @@ private void generateColumnCompareScalar(String[] tdesc) throws IOException { generateColumnBinaryOperatorScalar(tdesc, returnType, className); } - private void generateScalarCompareColumn(String[] tdesc) throws IOException { + private void generateScalarCompareColumn(String[] tdesc) throws Exception { String operatorName = tdesc[1]; String operandType1 = tdesc[2]; String operandType2 = tdesc[3]; @@ -1050,10 +1107,11 @@ private void generateScalarCompareColumn(String[] tdesc) throws IOException { } private void generateColumnBinaryOperatorColumn(String[] tdesc, String returnType, - String className) throws IOException { + String className) throws Exception { String operandType1 = tdesc[2]; String operandType2 = tdesc[3]; - String outputColumnVectorType = this.getColumnVectorType(returnType); + String outputColumnVectorType = this.getColumnVectorType( + returnType == null ? "long" : returnType); String inputColumnVectorType1 = this.getColumnVectorType(operandType1); String inputColumnVectorType2 = this.getColumnVectorType(operandType2); String operatorSymbol = tdesc[4]; @@ -1089,10 +1147,11 @@ private void generateColumnBinaryOperatorColumn(String[] tdesc, String returnTyp } private void generateColumnBinaryOperatorScalar(String[] tdesc, String returnType, - String className) throws IOException { + String className) throws Exception { String operandType1 = tdesc[2]; String operandType2 = tdesc[3]; - String outputColumnVectorType = this.getColumnVectorType(returnType); + String outputColumnVectorType = this.getColumnVectorType( + returnType == null ? 
"long" : returnType); String inputColumnVectorType = this.getColumnVectorType(operandType1); String operatorSymbol = tdesc[4]; @@ -1127,10 +1186,11 @@ private void generateColumnBinaryOperatorScalar(String[] tdesc, String returnTyp } private void generateScalarBinaryOperatorColumn(String[] tdesc, String returnType, - String className) throws IOException { + String className) throws Exception { String operandType1 = tdesc[2]; String operandType2 = tdesc[3]; - String outputColumnVectorType = this.getColumnVectorType(returnType); + String outputColumnVectorType = this.getColumnVectorType( + returnType == null ? "long" : returnType); String inputColumnVectorType = this.getColumnVectorType(operandType2); String operatorSymbol = tdesc[4]; @@ -1166,7 +1226,7 @@ private void generateScalarBinaryOperatorColumn(String[] tdesc, String returnTyp } //Binary arithmetic operator - private void generateColumnArithmeticScalar(String[] tdesc) throws IOException { + private void generateColumnArithmeticScalar(String[] tdesc) throws Exception { String operatorName = tdesc[1]; String operandType1 = tdesc[2]; String operandType2 = tdesc[3]; @@ -1260,7 +1320,7 @@ private void generateColumnDivideColumnDecimal(String[] tdesc) throws IOExceptio className, templateString); } - private void generateScalarArithmeticColumn(String[] tdesc) throws IOException { + private void generateScalarArithmeticColumn(String[] tdesc) throws Exception { String operatorName = tdesc[1]; String operandType1 = tdesc[2]; String operandType2 = tdesc[3]; @@ -1369,11 +1429,15 @@ private String getArithmeticReturnType(String operandType1, } } - private String getColumnVectorType(String primitiveType) { + private String getColumnVectorType(String primitiveType) throws Exception { if(primitiveType!=null && primitiveType.equals("double")) { return "DoubleColumnVector"; + } else if (primitiveType.equals("long")) { + return "LongColumnVector"; + } else if (primitiveType.equals("decimal")) { + return "DecimalColumnVector"; } - return "LongColumnVector"; + throw new Exception("Unimplemented primitive column vector type: " + primitiveType); } private String getOutputWritableType(String primitiveType) throws Exception { @@ -1381,6 +1445,8 @@ private String getOutputWritableType(String primitiveType) throws Exception { return "LongWritable"; } else if (primitiveType.equals("double")) { return "DoubleWritable"; + } else if (primitiveType.equals("decimal")) { + return "HiveDecimalWritable"; } throw new Exception("Unimplemented primitive output writable: " + primitiveType); } @@ -1390,6 +1456,8 @@ private String getOutputObjectInspector(String primitiveType) throws Exception { return "PrimitiveObjectInspectorFactory.writableLongObjectInspector"; } else if (primitiveType.equals("double")) { return "PrimitiveObjectInspectorFactory.writableDoubleObjectInspector"; + } else if (primitiveType.equals("decimal")) { + return "PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector"; } throw new Exception("Unimplemented primitive output inspector: " + primitiveType); } diff --git common/src/java/org/apache/hadoop/hive/common/type/Decimal128.java common/src/java/org/apache/hadoop/hive/common/type/Decimal128.java index 2e0f058..a5bf24b 100644 --- common/src/java/org/apache/hadoop/hive/common/type/Decimal128.java +++ common/src/java/org/apache/hadoop/hive/common/type/Decimal128.java @@ -17,8 +17,11 @@ import java.math.BigDecimal; import java.math.BigInteger; +import java.nio.ByteBuffer; import java.nio.IntBuffer; +import 
org.apache.hive.common.util.Decimal128FastBuffer; + /** * This code was originally written for Microsoft PolyBase. *

@@ -64,14 +67,18 @@ /** Minimum value for #scale. */ public static final short MIN_SCALE = 0; + + public static final Decimal128 ONE = new Decimal128().update(1); /** Maximum value that can be represented in this class. */ public static final Decimal128 MAX_VALUE = new Decimal128( - UnsignedInt128.TEN_TO_THIRTYEIGHT, (short) 0, false); + UnsignedInt128.TEN_TO_THIRTYEIGHT, (short) 0, false).subtractDestructive( + Decimal128.ONE, (short) 0); /** Minimum value that can be represented in this class. */ public static final Decimal128 MIN_VALUE = new Decimal128( - UnsignedInt128.TEN_TO_THIRTYEIGHT, (short) 0, true); + UnsignedInt128.TEN_TO_THIRTYEIGHT, (short) 0, true).addDestructive( + Decimal128.ONE, (short) 0); /** For Serializable. */ private static final long serialVersionUID = 1L; @@ -234,9 +241,10 @@ public Decimal128(char[] str, int offset, int length, short scale) { } /** Reset the value of this object to zero. */ - public void zeroClear() { + public Decimal128 zeroClear() { this.unscaledValue.zeroClear(); this.signum = 0; + return this; } /** @return whether this value represents zero. */ @@ -252,10 +260,11 @@ public boolean isZero() { * @param o * object to copy from */ - public void update(Decimal128 o) { + public Decimal128 update(Decimal128 o) { this.unscaledValue.update(o.unscaledValue); this.scale = o.scale; this.signum = o.signum; + return this; } /** @@ -265,8 +274,8 @@ public void update(Decimal128 o) { * @param val * {@code long} value to be set to {@code Decimal128}. */ - public void update(long val) { - update(val, (short) 0); + public Decimal128 update(long val) { + return update(val, (short) 0); } /** @@ -278,7 +287,7 @@ public void update(long val) { * @param scale * scale of the {@code Decimal128}. */ - public void update(long val, short scale) { + public Decimal128 update(long val, short scale) { this.scale = 0; if (val < 0L) { this.unscaledValue.update(-val); @@ -293,6 +302,7 @@ public void update(long val, short scale) { if (scale != 0) { changeScaleDestructive(scale); } + return this; } /** @@ -311,7 +321,7 @@ public void update(long val, short scale) { * @param scale * scale of the {@code Decimal128}. */ - public void update(double val, short scale) { + public Decimal128 update(double val, short scale) { if (Double.isInfinite(val) || Double.isNaN(val)) { throw new NumberFormatException("Infinite or NaN"); } @@ -331,7 +341,7 @@ public void update(double val, short scale) { // zero check if (significand == 0) { zeroClear(); - return; + return this; } this.signum = sign; @@ -389,6 +399,7 @@ public void update(double val, short scale) { (short) (twoScaleDown - scale)); } } + return this; } /** @@ -400,12 +411,13 @@ public void update(double val, short scale) { * @param precision * 0 to 38. Decimal digits. 
*/ - public void update(IntBuffer buf, int precision) { + public Decimal128 update(IntBuffer buf, int precision) { int scaleAndSignum = buf.get(); this.scale = (short) (scaleAndSignum >> 16); this.signum = (byte) (scaleAndSignum & 0xFF); this.unscaledValue.update(buf, precision); assert ((signum == 0) == unscaledValue.isZero()); + return this; } /** @@ -415,12 +427,13 @@ public void update(IntBuffer buf, int precision) { * @param buf * ByteBuffer to read values from */ - public void update128(IntBuffer buf) { + public Decimal128 update128(IntBuffer buf) { int scaleAndSignum = buf.get(); this.scale = (short) (scaleAndSignum >> 16); this.signum = (byte) (scaleAndSignum & 0xFF); this.unscaledValue.update128(buf); assert ((signum == 0) == unscaledValue.isZero()); + return this; } /** @@ -430,12 +443,13 @@ public void update128(IntBuffer buf) { * @param buf * ByteBuffer to read values from */ - public void update96(IntBuffer buf) { + public Decimal128 update96(IntBuffer buf) { int scaleAndSignum = buf.get(); this.scale = (short) (scaleAndSignum >> 16); this.signum = (byte) (scaleAndSignum & 0xFF); this.unscaledValue.update96(buf); assert ((signum == 0) == unscaledValue.isZero()); + return this; } /** @@ -445,12 +459,13 @@ public void update96(IntBuffer buf) { * @param buf * ByteBuffer to read values from */ - public void update64(IntBuffer buf) { + public Decimal128 update64(IntBuffer buf) { int scaleAndSignum = buf.get(); this.scale = (short) (scaleAndSignum >> 16); this.signum = (byte) (scaleAndSignum & 0xFF); this.unscaledValue.update64(buf); assert ((signum == 0) == unscaledValue.isZero()); + return this; } /** @@ -460,12 +475,13 @@ public void update64(IntBuffer buf) { * @param buf * ByteBuffer to read values from */ - public void update32(IntBuffer buf) { + public Decimal128 update32(IntBuffer buf) { int scaleAndSignum = buf.get(); this.scale = (short) (scaleAndSignum >> 16); this.signum = (byte) (scaleAndSignum & 0xFF); this.unscaledValue.update32(buf); assert ((signum == 0) == unscaledValue.isZero()); + return this; } /** @@ -479,11 +495,12 @@ public void update32(IntBuffer buf) { * @param precision * 0 to 38. Decimal digits. 
*/ - public void update(int[] array, int offset, int precision) { + public Decimal128 update(int[] array, int offset, int precision) { int scaleAndSignum = array[offset]; this.scale = (short) (scaleAndSignum >> 16); this.signum = (byte) (scaleAndSignum & 0xFF); this.unscaledValue.update(array, offset + 1, precision); + return this; } /** @@ -495,11 +512,12 @@ public void update(int[] array, int offset, int precision) { * @param offset * offset of the int array */ - public void update128(int[] array, int offset) { + public Decimal128 update128(int[] array, int offset) { int scaleAndSignum = array[offset]; this.scale = (short) (scaleAndSignum >> 16); this.signum = (byte) (scaleAndSignum & 0xFF); this.unscaledValue.update128(array, offset + 1); + return this; } /** @@ -511,11 +529,12 @@ public void update128(int[] array, int offset) { * @param offset * offset of the int array */ - public void update96(int[] array, int offset) { + public Decimal128 update96(int[] array, int offset) { int scaleAndSignum = array[offset]; this.scale = (short) (scaleAndSignum >> 16); this.signum = (byte) (scaleAndSignum & 0xFF); this.unscaledValue.update96(array, offset + 1); + return this; } /** @@ -527,11 +546,12 @@ public void update96(int[] array, int offset) { * @param offset * offset of the int array */ - public void update64(int[] array, int offset) { + public Decimal128 update64(int[] array, int offset) { int scaleAndSignum = array[offset]; this.scale = (short) (scaleAndSignum >> 16); this.signum = (byte) (scaleAndSignum & 0xFF); this.unscaledValue.update64(array, offset + 1); + return this; } /** @@ -543,12 +563,22 @@ public void update64(int[] array, int offset) { * @param offset * offset of the int array */ - public void update32(int[] array, int offset) { + public Decimal128 update32(int[] array, int offset) { int scaleAndSignum = array[offset]; this.scale = (short) (scaleAndSignum >> 16); this.signum = (byte) (scaleAndSignum & 0xFF); this.unscaledValue.update32(array, offset + 1); + return this; } + + /** + * Updates the value of this object with the given {@link BigDecimal}. + * @param bigDecimal + * {@link java.math.BigDecimal} + */ + public Decimal128 update(BigDecimal bigDecimal) { + return update(bigDecimal.unscaledValue(), (short) bigDecimal.scale()); + } /** * Updates the value of this object with the given {@link BigInteger} and scale. @@ -557,7 +587,7 @@ public void update32(int[] array, int offset) { * {@link java.math.BigInteger} * @param scale */ - public void update(BigInteger bigInt, short scale) { + public Decimal128 update(BigInteger bigInt, short scale) { this.scale = scale; this.signum = (byte) bigInt.compareTo(BigInteger.ZERO); if (signum == 0) { @@ -567,6 +597,7 @@ public void update(BigInteger bigInt, short scale) { } else { unscaledValue.update(bigInt); } + return this; } /** @@ -577,8 +608,8 @@ public void update(BigInteger bigInt, short scale) { * @param scale * scale of the {@code Decimal128}. */ - public void update(String str, short scale) { - update(str.toCharArray(), 0, str.length(), scale); + public Decimal128 update(String str, short scale) { + return update(str.toCharArray(), 0, str.length(), scale); } /** @@ -594,7 +625,7 @@ public void update(String str, short scale) { * @param scale * scale of the {@code Decimal128}. 
*/ - public void update(char[] str, int offset, int length, short scale) { + public Decimal128 update(char[] str, int offset, int length, short scale) { final int end = offset + length; assert (end <= str.length); int cursor = offset; @@ -616,7 +647,7 @@ public void update(char[] str, int offset, int length, short scale) { this.scale = scale; zeroClear(); if (cursor == end) { - return; + return this; } // "1234567" => unscaledValue=1234567, negative=false, @@ -695,9 +726,20 @@ public void update(char[] str, int offset, int length, short scale) { this.unscaledValue.scaleDownTenDestructive((short) -scaleAdjust); } this.signum = (byte) (this.unscaledValue.isZero() ? 0 : (negative ? -1 : 1)); + return this; } + /** + * Serializes the value in a format compatible with the BigDecimal's own representation + * @param bytes + * @param offset + */ + public int fastSerializeForHiveDecimal( Decimal128FastBuffer scratch) { + return this.unscaledValue.fastSerializeForHiveDecimal(scratch, this.signum); + } + + /** * Serialize this object to the given array, putting the required number of * ints for the given precision. * @@ -914,15 +956,15 @@ public static void add(Decimal128 left, Decimal128 right, Decimal128 result, * @param scale * scale of the result. must be 0 or positive. */ - public void addDestructive(Decimal128 right, short scale) { + public Decimal128 addDestructive(Decimal128 right, short scale) { this.changeScaleDestructive(scale); if (right.signum == 0) { - return; + return this; } if (this.signum == 0) { this.update(right); this.changeScaleDestructive(scale); - return; + return this; } short rightScaleTen = (short) (scale - right.scale); @@ -946,6 +988,7 @@ public void addDestructive(Decimal128 right, short scale) { } this.unscaledValue.throwIfExceedsTenToThirtyEight(); + return this; } /** @@ -986,16 +1029,16 @@ public static void subtract(Decimal128 left, Decimal128 right, * @param scale * scale of the result. must be 0 or positive. */ - public void subtractDestructive(Decimal128 right, short scale) { + public Decimal128 subtractDestructive(Decimal128 right, short scale) { this.changeScaleDestructive(scale); if (right.signum == 0) { - return; + return this; } if (this.signum == 0) { this.update(right); this.changeScaleDestructive(scale); this.negateDestructive(); - return; + return this; } short rightScaleTen = (short) (scale - right.scale); @@ -1020,6 +1063,7 @@ public void subtractDestructive(Decimal128 right, short scale) { } this.unscaledValue.throwIfExceedsTenToThirtyEight(); + return this; } /** @@ -1759,4 +1803,44 @@ public void zeroFractionPart() { */ this.getUnscaledValue().scaleUpTenDestructive(placesToRemove); } + + /** + * Multiplies this with this, updating this + * + * @return self + */ + public Decimal128 squareDestructive() { + this.multiplyDestructive(this, this.getScale()); + return this; + } + + + /** + * For UDAF variance we use the algorithm described by Chan, Golub, and LeVeque in + * "Algorithms for computing the sample variance: analysis and recommendations" + * The American Statistician, 37 (1983) pp. 242--247. + * + * variance = variance1 + variance2 + n/(m*(m+n)) * pow(((m/n)*t1 - t2),2) + * + * where: - variance is sum[x-avg^2] (this is actually n times the variance) + * and is updated at every step. - n is the count of elements in chunk1 - m is + * the count of elements in chunk2 - t1 = sum of elements in chunk1, t2 = + * sum of elements in chunk2. 
+ * + * This is a helper function doing the intermediate computation: + * t = myagg.count*value - myagg.sum; + * myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); + * + * @return self + */ + public Decimal128 updateVarianceDestructive( + Decimal128 scratch, Decimal128 value, Decimal128 sum, long count) { + scratch.update(count); + scratch.multiplyDestructive(value, value.getScale()); + scratch.subtractDestructive(sum, sum.getScale()); + scratch.squareDestructive(); + scratch.unscaledValue.divideDestructive(count * (count-1)); + this.addDestructive(scratch, getScale()); + return this; + } } diff --git common/src/java/org/apache/hadoop/hive/common/type/HiveDecimal.java common/src/java/org/apache/hadoop/hive/common/type/HiveDecimal.java index 29c5168..8101be9 100644 --- common/src/java/org/apache/hadoop/hive/common/type/HiveDecimal.java +++ common/src/java/org/apache/hadoop/hive/common/type/HiveDecimal.java @@ -274,4 +274,12 @@ public static BigDecimal enforcePrecisionScale(BigDecimal bd, int maxPrecision, return bd; } + /** + * Sets the {@link BigDecimal} value in this object. + * @param bigDecimal + */ + public void setNormalize(BigDecimal bigDecimal) { + BigDecimal value = normalize(bigDecimal, true); + this.bd = value; + } } diff --git common/src/java/org/apache/hadoop/hive/common/type/UnsignedInt128.java common/src/java/org/apache/hadoop/hive/common/type/UnsignedInt128.java index fb3c346..28daa21 100644 --- common/src/java/org/apache/hadoop/hive/common/type/UnsignedInt128.java +++ common/src/java/org/apache/hadoop/hive/common/type/UnsignedInt128.java @@ -16,9 +16,12 @@ package org.apache.hadoop.hive.common.type; import java.math.BigInteger; +import java.nio.ByteBuffer; import java.nio.IntBuffer; import java.util.Arrays; +import org.apache.hive.common.util.Decimal128FastBuffer; + /** * This code was originally written for Microsoft PolyBase. * @@ -1373,6 +1376,31 @@ public int divideDestructive(int right) { } /** + * Divides this value with the given value. This version is destructive, + * meaning it modifies this object. + * + * @param right + * the value to divide + * @return remainder + */ + public long divideDestructive(long right) { + assert (right >= 0); + + long quotient; + long remainder = 0; + + for (int i = INT_COUNT - 1; i >= 0; --i) { + remainder = ((this.v[i] & SqlMathUtil.LONG_MASK) + (remainder << 32)); + quotient = remainder / right; + remainder %= right; + this.v[i] = (int) quotient; + } + updateCount(); + return remainder; + } + + + /** * Right-shift for the given number of bits. This version is destructive, * meaning it modifies this object. * @@ -2405,4 +2433,60 @@ private void updateCount() { this.count = (byte) 0; } } + + /** + * Serializes one int part into the given @{link ByteBuffer} + * considering two's complement for negatives + * @param buf + * @param pos + * @param value + * @param signum + */ + private static void fastSerializeIntPartForHiveDecimal(ByteBuffer buf, + int pos, int value, byte signum, long c2) { + if ((signum == 1) ) { + // Because of the extra 0x00 at the beginning, the value will be read as positive + buf.putInt(pos, value); + } + else if (signum == -1) { + // For negatives there will be an 0xff in the beginning + // and this part must be written in two's complement + // the special negative 0x80000000 has no compliment... 
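+ // Example: for the lowest word (where the caller passes c2 == 0x100000000L), a negative value
+ // whose word is 5 is written as (int) (0x100000000L - 5L) = 0xFFFFFFFB, which matches the
+ // two's complement layout java.math.BigInteger uses for -5.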
+ if (value != 0x80000000) { + value = (int)(c2 - ((long)value & 0xFFFFFFFF)); + } + buf.putInt(pos, value); + } + } + + /** + * Serializes this value into the format used by @{link java.math.BigInteger} + * This is used for fast assignment of a Decimal128 to a HiveDecimalWritable internal storage. + * @param scratch + * @param signum + * @return + */ + public int fastSerializeForHiveDecimal(Decimal128FastBuffer scratch, byte signum) { + int bufferUsed = this.count; + ByteBuffer buf = scratch.getByteBuffer(bufferUsed); + buf.put(0, (byte) (signum == 1 ? 0 : signum)); + int pos = 1; + switch(this.count) { + case 4: + fastSerializeIntPartForHiveDecimal(buf, pos, v[3], signum, 0xFFFFFFFFL); + pos+=4; + // intentional fall through + case 3: + fastSerializeIntPartForHiveDecimal(buf, pos, v[2], signum, 0xFFFFFFFFL); + pos+=4; + // intentional fall through + case 2: + fastSerializeIntPartForHiveDecimal(buf, pos, v[1], signum, 0xFFFFFFFFL); + pos+=4; + // intentional fall through + case 1: + fastSerializeIntPartForHiveDecimal(buf, pos, v[0], signum, 0x100000000L); + } + return bufferUsed; + } } diff --git common/src/java/org/apache/hive/common/util/Decimal128FastBuffer.java common/src/java/org/apache/hive/common/util/Decimal128FastBuffer.java new file mode 100644 index 0000000..72c5f53 --- /dev/null +++ common/src/java/org/apache/hive/common/util/Decimal128FastBuffer.java @@ -0,0 +1,51 @@ +/** + * + */ +package org.apache.hive.common.util; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** + * A helper class for fast serialization of decimal128 in the BigDecimal byte[] representation + * + */ +public class Decimal128FastBuffer { + + /** + * Preallocated byte[] for each Decimal128 size (0-4) + */ + private final byte[][] sumBytes; + + /** + * Preallocated ByteBuffer wrappers around each sumBytes + */ + private final ByteBuffer[] sumBuffer; + + public Decimal128FastBuffer() { + sumBytes = new byte[5][]; + sumBuffer = new ByteBuffer[5]; + sumBytes[0] = new byte[1]; + sumBuffer[0] = ByteBuffer.wrap(sumBytes[0]); + sumBytes[1] = new byte[5]; + sumBuffer[1] = ByteBuffer.wrap(sumBytes[1]); + sumBuffer[1].order(ByteOrder.BIG_ENDIAN); + sumBytes[2] = new byte[9]; + sumBuffer[2] = ByteBuffer.wrap(sumBytes[2]); + sumBuffer[2].order(ByteOrder.BIG_ENDIAN); + sumBytes[3] = new byte[13]; + sumBuffer[3] = ByteBuffer.wrap(sumBytes[3]); + sumBuffer[3].order(ByteOrder.BIG_ENDIAN); + sumBytes[4] = new byte[17]; + sumBuffer[4] = ByteBuffer.wrap(sumBytes[4]); + sumBuffer[4].order(ByteOrder.BIG_ENDIAN); + } + + public ByteBuffer getByteBuffer(int index) { + return sumBuffer[index]; + } + + public byte[] getBytes(int index) { + return sumBytes[index]; + } +} diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxDecimal.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxDecimal.txt new file mode 100644 index 0000000..f5ca641 --- /dev/null +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxDecimal.txt @@ -0,0 +1,446 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen; + +import org.apache.hadoop.hive.common.type.Decimal128; +import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression; +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriter; +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriterFactory; +import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.plan.AggregationDesc; +import org.apache.hadoop.hive.ql.util.JavaDataModel; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; + +/** +* <ClassName>. Vectorized implementation for MIN/MAX aggregates. +*/ +@Description(name = "<DescriptionName>", + value = "<DescriptionValue>") +public class <ClassName> extends VectorAggregateExpression { + + private static final long serialVersionUID = 1L; + + /** + * class for storing the current aggregate value. + */ + static private final class Aggregation implements AggregationBuffer { + + private static final long serialVersionUID = 1L; + + transient private final Decimal128 value; + transient private boolean isNull; + + public Aggregation() { + value = new Decimal128(); + } + + public void checkValue(Decimal128 value) { + if (isNull) { + isNull = false; + this.value.update(value); + } else if (this.value.compareTo(value) <OperatorSymbol> 0) { + this.value.update(value); + } + } + + @Override + public int getVariableSize() { + throw new UnsupportedOperationException(); + } + } + + private VectorExpression inputExpression; + private transient VectorExpressionWriter resultWriter; + + public <ClassName>(VectorExpression inputExpression) { + this(); + this.inputExpression = inputExpression; + } + + public <ClassName>() { + super(); + } + + @Override + public void init(AggregationDesc desc) throws HiveException { + resultWriter = VectorExpressionWriterFactory.genVectorExpressionWritable( + desc.getParameters().get(0)); + } + + private Aggregation getCurrentAggregationBuffer( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregrateIndex, + int row) { + VectorAggregationBufferRow mySet = aggregationBufferSets[row]; + Aggregation myagg = (Aggregation) mySet.getAggregationBuffer(aggregrateIndex); + return myagg; + } + + @Override + public void aggregateInputSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregrateIndex, + VectorizedRowBatch batch) throws HiveException { + + int batchSize = batch.size; + + if (batchSize == 0) { + return; + } + + inputExpression.evaluate(batch); + + DecimalColumnVector inputVector = (DecimalColumnVector)batch.
+ cols[this.inputExpression.getOutputColumn()]; + Decimal128[] vector = inputVector.vector; + + if (inputVector.noNulls) { + if (inputVector.isRepeating) { + iterateNoNullsRepeatingWithAggregationSelection( + aggregationBufferSets, aggregrateIndex, + vector[0], batchSize); + } else { + if (batch.selectedInUse) { + iterateNoNullsSelectionWithAggregationSelection( + aggregationBufferSets, aggregrateIndex, + vector, batch.selected, batchSize); + } else { + iterateNoNullsWithAggregationSelection( + aggregationBufferSets, aggregrateIndex, + vector, batchSize); + } + } + } else { + if (inputVector.isRepeating) { + if (batch.selectedInUse) { + iterateHasNullsRepeatingSelectionWithAggregationSelection( + aggregationBufferSets, aggregrateIndex, + vector[0], batchSize, batch.selected, inputVector.isNull); + } else { + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, aggregrateIndex, + vector[0], batchSize, inputVector.isNull); + } + } else { + if (batch.selectedInUse) { + iterateHasNullsSelectionWithAggregationSelection( + aggregationBufferSets, aggregrateIndex, + vector, batchSize, batch.selected, inputVector.isNull); + } else { + iterateHasNullsWithAggregationSelection( + aggregationBufferSets, aggregrateIndex, + vector, batchSize, inputVector.isNull); + } + } + } + } + + private void iterateNoNullsRepeatingWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregrateIndex, + Decimal128 value, + int batchSize) { + + for (int i=0; i < batchSize; ++i) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregrateIndex, + i); + myagg.checkValue(value); + } + } + + private void iterateNoNullsSelectionWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregrateIndex, + Decimal128[] values, + int[] selection, + int batchSize) { + + for (int i=0; i < batchSize; ++i) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregrateIndex, + i); + myagg.checkValue(values[selection[i]]); + } + } + + private void iterateNoNullsWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregrateIndex, + Decimal128[] values, + int batchSize) { + for (int i=0; i < batchSize; ++i) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregrateIndex, + i); + myagg.checkValue(values[i]); + } + } + + private void iterateHasNullsRepeatingSelectionWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregrateIndex, + Decimal128 value, + int batchSize, + int[] selection, + boolean[] isNull) { + + for (int i=0; i < batchSize; ++i) { + if (!isNull[selection[i]]) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregrateIndex, + i); + myagg.checkValue(value); + } + } + + } + + private void iterateHasNullsRepeatingWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregrateIndex, + Decimal128 value, + int batchSize, + boolean[] isNull) { + + for (int i=0; i < batchSize; ++i) { + if (!isNull[i]) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregrateIndex, + i); + myagg.checkValue(value); + } + } + } + + private void iterateHasNullsSelectionWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregrateIndex, + Decimal128[] values, + int batchSize, + int[] selection, + boolean[] isNull) { + + for (int j=0; j < batchSize; ++j) { + int i = selection[j]; + if 
(!isNull[i]) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregrateIndex, + j); + myagg.checkValue(values[i]); + } + } + } + + private void iterateHasNullsWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregrateIndex, + Decimal128[] values, + int batchSize, + boolean[] isNull) { + + for (int i=0; i < batchSize; ++i) { + if (!isNull[i]) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregrateIndex, + i); + myagg.checkValue(values[i]); + } + } + } + + @Override + public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch) + throws HiveException { + + inputExpression.evaluate(batch); + + DecimalColumnVector inputVector = (DecimalColumnVector)batch. + cols[this.inputExpression.getOutputColumn()]; + + int batchSize = batch.size; + + if (batchSize == 0) { + return; + } + + Aggregation myagg = (Aggregation)agg; + + Decimal128[] vector = inputVector.vector; + + if (inputVector.isRepeating) { + if (inputVector.noNulls && + (myagg.isNull || (myagg.value.compareTo(vector[0]) <OperatorSymbol> 0))) { + myagg.isNull = false; + myagg.value.update(vector[0]); + } + return; + } + + if (!batch.selectedInUse && inputVector.noNulls) { + iterateNoSelectionNoNulls(myagg, vector, batchSize); + } + else if (!batch.selectedInUse) { + iterateNoSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull); + } + else if (inputVector.noNulls){ + iterateSelectionNoNulls(myagg, vector, batchSize, batch.selected); + } + else { + iterateSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull, batch.selected); + } + } + + private void iterateSelectionHasNulls( + Aggregation myagg, + Decimal128[] vector, + int batchSize, + boolean[] isNull, + int[] selected) { + + for (int j=0; j< batchSize; ++j) { + int i = selected[j]; + if (!isNull[i]) { + Decimal128 value = vector[i]; + if (myagg.isNull) { + myagg.isNull = false; + myagg.value.update(value); + } + else if (myagg.value.compareTo(value) <OperatorSymbol> 0) { + myagg.value.update(value); + } + } + } + } + + private void iterateSelectionNoNulls( + Aggregation myagg, + Decimal128[] vector, + int batchSize, + int[] selected) { + + if (myagg.isNull) { + myagg.value.update(vector[selected[0]]); + myagg.isNull = false; + } + + for (int i=0; i< batchSize; ++i) { + Decimal128 value = vector[selected[i]]; + if (myagg.value.compareTo(value) <OperatorSymbol> 0) { + myagg.value.update(value); + } + } + } + + private void iterateNoSelectionHasNulls( + Aggregation myagg, + Decimal128[] vector, + int batchSize, + boolean[] isNull) { + + for(int i=0;i<batchSize;++i) { + if (!isNull[i]) { + Decimal128 value = vector[i]; + if (myagg.isNull) { + myagg.isNull = false; + myagg.value.update(value); + } + else if (myagg.value.compareTo(value) <OperatorSymbol> 0) { + myagg.value.update(value); + } + } + } + } + + private void iterateNoSelectionNoNulls( + Aggregation myagg, + Decimal128[] vector, + int batchSize) { + if (myagg.isNull) { + myagg.value.update(vector[0]); + myagg.isNull = false; + } + + for (int i=0;i<batchSize;++i) { + Decimal128 value = vector[i]; + if (myagg.value.compareTo(value) <OperatorSymbol> 0) { + myagg.value.update(value); + } + } + } + + @Override + public AggregationBuffer getNewAggregationBuffer() throws HiveException { + return new Aggregation(); + } + + @Override + public void reset(AggregationBuffer agg) throws HiveException { + Aggregation myAgg = (Aggregation) agg; + myAgg.isNull = true; + } + + @Override + public Object evaluateOutput( + AggregationBuffer agg) throws HiveException { + Aggregation myagg = (Aggregation) agg; + if (myagg.isNull) { + return null; + } + else { + return resultWriter.writeValue(myagg.value); + } + } + + @Override + public ObjectInspector getOutputObjectInspector() { + return resultWriter.getObjectInspector(); + } + + @Override + public int
getAggregationBufferFixedSize() { + JavaDataModel model = JavaDataModel.get(); + return JavaDataModel.alignUp( + model.object() + + model.primitive2(), + model.memoryAlign()); + } + + public VectorExpression getInputExpression() { + return inputExpression; + } + + public void setInputExpression(VectorExpression inputExpression) { + this.inputExpression = inputExpression; + } +} + diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarDecimal.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarDecimal.txt new file mode 100644 index 0000000..995d5c9 --- /dev/null +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarDecimal.txt @@ -0,0 +1,468 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.hadoop.hive.common.type.Decimal128; +import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression; +import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.plan.AggregationDesc; +import org.apache.hadoop.hive.ql.util.JavaDataModel; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; + +/** +* . Vectorized implementation for VARIANCE aggregates. +*/ +@Description(name = "", + value = "") +public class extends VectorAggregateExpression { + + private static final long serialVersionUID = 1L; + + /** + /* class for storing the current aggregate value. 
+ */ + private static final class Aggregation implements AggregationBuffer { + + private static final long serialVersionUID = 1L; + + transient private final Decimal128 sum; + transient private long count; + transient private double variance; + transient private boolean isNull; + + public Aggregation() { + sum = new Decimal128(); + } + + public void init() { + isNull = false; + sum.zeroClear(); + count = 0; + variance = 0f; + } + + @Override + public int getVariableSize() { + throw new UnsupportedOperationException(); + } + + public void updateValueWithCheckAndInit(Decimal128 scratch, Decimal128 value) { + if (this.isNull) { + this.init(); + } + this.sum.addDestructive(value, value.getScale()); + this.count += 1; + if(this.count > 1) { + scratch.update(count); + scratch.multiplyDestructive(value, value.getScale()); + scratch.subtractDestructive(sum, sum.getScale()); + double t = scratch.doubleValue(); + this.variance += (t*t) / ((double)this.count*(this.count-1)); + } + } + + public void updateValueNoCheck(Decimal128 scratch, Decimal128 value) { + this.sum.addDestructive(value, value.getScale()); + this.count += 1; + scratch.update(count); + scratch.multiplyDestructive(value, value.getScale()); + scratch.subtractDestructive(sum, sum.getScale()); + double t = scratch.doubleValue(); + this.variance += (t*t) / ((double)this.count*(this.count-1)); + } + + } + + private VectorExpression inputExpression; + transient private LongWritable resultCount; + transient private DoubleWritable resultSum; + transient private DoubleWritable resultVariance; + transient private Object[] partialResult; + + transient private ObjectInspector soi; + + transient private final Decimal128 scratchDecimal; + + + public (VectorExpression inputExpression) { + this(); + this.inputExpression = inputExpression; + } + + public () { + super(); + partialResult = new Object[3]; + resultCount = new LongWritable(); + resultSum = new DoubleWritable(); + resultVariance = new DoubleWritable(); + partialResult[0] = resultCount; + partialResult[1] = resultSum; + partialResult[2] = resultVariance; + initPartialResultInspector(); + scratchDecimal = new Decimal128(); + } + + private void initPartialResultInspector() { + List foi = new ArrayList(); + foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + + List fname = new ArrayList(); + fname.add("count"); + fname.add("sum"); + fname.add("variance"); + + soi = ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi); + } + + private Aggregation getCurrentAggregationBuffer( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + int row) { + VectorAggregationBufferRow mySet = aggregationBufferSets[row]; + Aggregation myagg = (Aggregation) mySet.getAggregationBuffer(aggregateIndex); + return myagg; + } + + + @Override + public void aggregateInputSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + VectorizedRowBatch batch) throws HiveException { + + inputExpression.evaluate(batch); + + DecimalColumnVector inputVector = (DecimalColumnVector)batch. 
+ cols[this.inputExpression.getOutputColumn()]; + + int batchSize = batch.size; + + if (batchSize == 0) { + return; + } + + Decimal128[] vector = inputVector.vector; + + if (inputVector.isRepeating) { + if (inputVector.noNulls || !inputVector.isNull[0]) { + iterateRepeatingNoNullsWithAggregationSelection( + aggregationBufferSets, aggregateIndex, vector[0], batchSize); + } + } + else if (!batch.selectedInUse && inputVector.noNulls) { + iterateNoSelectionNoNullsWithAggregationSelection( + aggregationBufferSets, aggregateIndex, vector, batchSize); + } + else if (!batch.selectedInUse) { + iterateNoSelectionHasNullsWithAggregationSelection( + aggregationBufferSets, aggregateIndex, vector, batchSize, inputVector.isNull); + } + else if (inputVector.noNulls){ + iterateSelectionNoNullsWithAggregationSelection( + aggregationBufferSets, aggregateIndex, vector, batchSize, batch.selected); + } + else { + iterateSelectionHasNullsWithAggregationSelection( + aggregationBufferSets, aggregateIndex, vector, batchSize, + inputVector.isNull, batch.selected); + } + + } + + private void iterateRepeatingNoNullsWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + Decimal128 value, + int batchSize) { + + for (int i=0; i 1 check in the loop + for (int i=1; i 1 check in the loop + // + for (int i=1; i< batchSize; ++i) { + value = vector[selected[i]]; + myagg.updateValueNoCheck(scratchDecimal, value); + } + } + + private void iterateNoSelectionHasNulls( + Aggregation myagg, + Decimal128[] vector, + int batchSize, + boolean[] isNull) { + + for(int i=0;i 1 check + for (int i=1; i= 0) { + return kw.getIsDecimalNull(klh.decimalIndex)? null : + keyOutputWriter.writeValue( + kw.getDecimal(klh.decimalIndex)); + } + else { throw new HiveException(String.format( "Internal inconsistent KeyLookupHelper at index [%d]:%d %d %d", i, klh.longIndex, klh.doubleIndex, klh.stringIndex)); diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java index f5ab731..82a6c97 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java @@ -40,24 +40,32 @@ import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor.Mode; import org.apache.hadoop.hive.ql.exec.vector.expressions.*; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorUDAFAvgDecimal; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorUDAFCount; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorUDAFCountStar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorUDAFSumDecimal; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFAvgDouble; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFAvgLong; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMaxDecimal; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMaxDouble; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMaxLong; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMaxString; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMinDecimal; import 
org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMinDouble; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMinLong; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMinString; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFStdPopDecimal; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFStdPopDouble; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFStdPopLong; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFStdSampDecimal; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFStdSampDouble; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFStdSampLong; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFSumDouble; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFSumLong; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFVarPopDecimal; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFVarPopDouble; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFVarPopLong; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFVarSampDecimal; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFVarSampDouble; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFVarSampLong; import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.CastLongToBooleanViaLongToLong; @@ -901,6 +909,10 @@ public static boolean isIntFamily(String resultType) { || resultType.equalsIgnoreCase("long"); } + public static boolean isDecimalFamily(String outputType) { + return outputType.startsWith("decimal"); + } + public static String mapJavaTypeToVectorType(String javaType) throws HiveException { if (isStringFamily(javaType)) { @@ -1023,12 +1035,14 @@ private long evaluateCastToTimestamp(ExprNodeDesc expr) throws HiveException { } } - static String getNormalizedTypeName(String colType) { + static String getNormalizedTypeName(String colType){ String normalizedType = null; if (colType.equalsIgnoreCase("Double") || colType.equalsIgnoreCase("Float")) { normalizedType = "Double"; } else if (colType.equalsIgnoreCase("String")) { normalizedType = "String"; + } else if (colType.toLowerCase().startsWith("decimal")) { + normalizedType = "Decimal"; } else { normalizedType = "Long"; } @@ -1039,31 +1053,43 @@ static String getNormalizedTypeName(String colType) { {"min", "Long", VectorUDAFMinLong.class}, {"min", "Double", VectorUDAFMinDouble.class}, {"min", "String", VectorUDAFMinString.class}, + {"min", "Decimal",VectorUDAFMinDecimal.class}, {"max", "Long", VectorUDAFMaxLong.class}, {"max", "Double", VectorUDAFMaxDouble.class}, {"max", "String", VectorUDAFMaxString.class}, + {"max", "Decimal",VectorUDAFMaxDecimal.class}, {"count", null, VectorUDAFCountStar.class}, {"count", "Long", VectorUDAFCount.class}, {"count", "Double", VectorUDAFCount.class}, {"count", "String", VectorUDAFCount.class}, + {"count", "Decimal",VectorUDAFCount.class}, {"sum", "Long", VectorUDAFSumLong.class}, {"sum", "Double", VectorUDAFSumDouble.class}, + {"sum", "Decimal",VectorUDAFSumDecimal.class}, {"avg", "Long", VectorUDAFAvgLong.class}, {"avg", "Double", VectorUDAFAvgDouble.class}, + {"avg", "Decimal",VectorUDAFAvgDecimal.class}, 
{"variance", "Long", VectorUDAFVarPopLong.class}, {"var_pop", "Long", VectorUDAFVarPopLong.class}, {"variance", "Double", VectorUDAFVarPopDouble.class}, {"var_pop", "Double", VectorUDAFVarPopDouble.class}, + {"variance", "Decimal",VectorUDAFVarPopDecimal.class}, + {"var_pop", "Decimal",VectorUDAFVarPopDecimal.class}, {"var_samp", "Long", VectorUDAFVarSampLong.class}, {"var_samp" , "Double", VectorUDAFVarSampDouble.class}, + {"var_samp" , "Decimal",VectorUDAFVarSampDecimal.class}, {"std", "Long", VectorUDAFStdPopLong.class}, {"stddev", "Long", VectorUDAFStdPopLong.class}, {"stddev_pop","Long", VectorUDAFStdPopLong.class}, {"std", "Double", VectorUDAFStdPopDouble.class}, {"stddev", "Double", VectorUDAFStdPopDouble.class}, {"stddev_pop","Double", VectorUDAFStdPopDouble.class}, + {"std", "Decimal",VectorUDAFStdPopDecimal.class}, + {"stddev", "Decimal",VectorUDAFStdPopDecimal.class}, + {"stddev_pop","Decimal",VectorUDAFStdPopDecimal.class}, {"stddev_samp","Long", VectorUDAFStdSampLong.class}, {"stddev_samp","Double",VectorUDAFStdSampDouble.class}, + {"stddev_samp","Decimal",VectorUDAFStdSampDecimal.class}, }; public VectorAggregateExpression getAggregatorExpression(AggregationDesc desc) diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedRowBatchCtx.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedRowBatchCtx.java index f513188..6e79979 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedRowBatchCtx.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedRowBatchCtx.java @@ -41,6 +41,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.StructField; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; import org.apache.hadoop.io.Writable; import org.apache.hadoop.mapred.FileSplit; @@ -233,7 +234,8 @@ public VectorizedRowBatch createVectorizedRowBatch() throws HiveException case PRIMITIVE: { PrimitiveObjectInspector poi = (PrimitiveObjectInspector) foi; // Vectorization currently only supports the following data types: - // BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, STRING and TIMESTAMP + // BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, STRING, TIMESTAMP, + // DATE and DECIMAL switch (poi.getPrimitiveCategory()) { case BOOLEAN: case BYTE: @@ -241,6 +243,7 @@ public VectorizedRowBatch createVectorizedRowBatch() throws HiveException case INT: case LONG: case TIMESTAMP: + case DATE: result.cols[j] = new LongColumnVector(VectorizedRowBatch.DEFAULT_SIZE); break; case FLOAT: @@ -250,6 +253,11 @@ public VectorizedRowBatch createVectorizedRowBatch() throws HiveException case STRING: result.cols[j] = new BytesColumnVector(VectorizedRowBatch.DEFAULT_SIZE); break; + case DECIMAL: + DecimalTypeInfo tInfo = (DecimalTypeInfo) poi.getTypeInfo(); + result.cols[j] = new DecimalColumnVector(VectorizedRowBatch.DEFAULT_SIZE, + tInfo.precision(), tInfo.scale()); + break; default: throw new RuntimeException("Vectorizaton is not supported for datatype:" + poi.getPrimitiveCategory()); diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriter.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriter.java index e5c3aa4..be5cea8 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriter.java +++ 
ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriter.java @@ -18,6 +18,7 @@ package org.apache.hadoop.hive.ql.exec.vector.expressions; +import org.apache.hadoop.hive.common.type.Decimal128; import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; @@ -33,6 +34,7 @@ Object writeValue(long value) throws HiveException; Object writeValue(double value) throws HiveException; Object writeValue(byte[] value, int start, int length) throws HiveException; + Object writeValue(Decimal128 value) throws HiveException; Object setValue(Object row, ColumnVector column, int columnRow) throws HiveException; Object initValue(Object ost) throws HiveException; } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriterFactory.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriterFactory.java index a242fef..868f13e 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriterFactory.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpressionWriterFactory.java @@ -18,6 +18,8 @@ package org.apache.hadoop.hive.ql.exec.vector.expressions; +import java.math.BigDecimal; +import java.sql.Date; import java.sql.Timestamp; import java.util.ArrayList; import java.util.Arrays; @@ -25,18 +27,13 @@ import org.apache.commons.lang.ArrayUtils; import org.apache.commons.lang.StringUtils; +import org.apache.hadoop.hive.common.type.Decimal128; +import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.common.type.HiveVarchar; -import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; -import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; -import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; -import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; -import org.apache.hadoop.hive.ql.exec.vector.TimestampUtils; +import org.apache.hadoop.hive.ql.exec.vector.*; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; -import org.apache.hadoop.hive.serde2.io.ByteWritable; -import org.apache.hadoop.hive.serde2.io.DoubleWritable; -import org.apache.hadoop.hive.serde2.io.ShortWritable; -import org.apache.hadoop.hive.serde2.io.TimestampWritable; +import org.apache.hadoop.hive.serde2.io.*; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; @@ -143,6 +140,21 @@ public Object writeValue(byte[] value, int start, int length) throws HiveExcepti public Object setValue(Object field, byte[] value, int start, int length) throws HiveException { throw new HiveException("Internal error: should not reach here"); } + + /** + * The base implementation must be overridden by the Decimal specialization + */ + @Override + public Object writeValue(Decimal128 value) throws HiveException { + throw new HiveException("Internal error: should not reach here"); + } + + /** + * The base implementation must be overridden by the Decimal specialization + */ + public Object setValue(Object field, Decimal128 value) throws HiveException { + throw new HiveException("Internal error: should not reach here"); + } } /** @@ -272,7 +284,7 @@ public Object writeValue(ColumnVector column, int row) throws HiveException 
{ "Incorrect null/repeating: row:%d noNulls:%b isRepeating:%b isNull[row]:%b isNull[0]:%b", row, bcv.noNulls, bcv.isRepeating, bcv.isNull[row], bcv.isNull[0])); } - + @Override public Object setValue(Object field, ColumnVector column, int row) throws HiveException { BytesColumnVector bcv = (BytesColumnVector) column; @@ -294,7 +306,58 @@ public Object setValue(Object field, ColumnVector column, int row) throws HiveEx "Incorrect null/repeating: row:%d noNulls:%b isRepeating:%b isNull[row]:%b isNull[0]:%b", row, bcv.noNulls, bcv.isRepeating, bcv.isNull[row], bcv.isNull[0])); } - } + } + + + /** + * Specialized writer for DecimalColumnVector. Will throw cast exception + * if the wrong vector column is used. + */ + private static abstract class VectorExpressionWriterDecimal extends VectorExpressionWriterBase { + @Override + public Object writeValue(ColumnVector column, int row) throws HiveException { + DecimalColumnVector dcv = (DecimalColumnVector) column; + if (dcv.noNulls && !dcv.isRepeating) { + return writeValue(dcv.vector[row]); + } else if (dcv.noNulls && dcv.isRepeating) { + return writeValue(dcv.vector[0]); + } else if (!dcv.noNulls && !dcv.isRepeating && !dcv.isNull[row]) { + return writeValue(dcv.vector[row]); + } else if (!dcv.noNulls && dcv.isRepeating && !dcv.isNull[0]) { + return writeValue(dcv.vector[0]); + } else if (!dcv.noNulls && dcv.isRepeating && dcv.isNull[0]) { + return null; + } else if (!dcv.noNulls && !dcv.isRepeating && dcv.isNull[row]) { + return null; + } + throw new HiveException( + String.format( + "Incorrect null/repeating: row:%d noNulls:%b isRepeating:%b isNull[row]:%b isNull[0]:%b", + row, dcv.noNulls, dcv.isRepeating, dcv.isNull[row], dcv.isNull[0])); + } + + @Override + public Object setValue(Object field, ColumnVector column, int row) throws HiveException { + DecimalColumnVector dcv = (DecimalColumnVector) column; + if (dcv.noNulls && !dcv.isRepeating) { + return setValue(field, dcv.vector[row]); + } else if (dcv.noNulls && dcv.isRepeating) { + return setValue(field, dcv.vector[0]); + } else if (!dcv.noNulls && !dcv.isRepeating && !dcv.isNull[row]) { + return setValue(field, dcv.vector[row]); + } else if (!dcv.noNulls && !dcv.isRepeating && dcv.isNull[row]) { + return null; + } else if (!dcv.noNulls && dcv.isRepeating && !dcv.isNull[0]) { + return setValue(field, dcv.vector[0]); + } else if (!dcv.noNulls && dcv.isRepeating && dcv.isNull[0]) { + return null; + } + throw new HiveException( + String.format( + "Incorrect null/repeating: row:%d noNulls:%b isRepeating:%b isNull[row]:%b isNull[0]:%b", + row, dcv.noNulls, dcv.isRepeating, dcv.isNull[row], dcv.isNull[0])); + } + } /** * Compiles the appropriate vector expression writer based on an expression info (ExprNodeDesc) @@ -381,17 +444,78 @@ public static VectorExpressionWriter genVectorExpressionWritable( } private static VectorExpressionWriter genVectorExpressionWritableDecimal( - SettableHiveDecimalObjectInspector fieldObjInspector) throws HiveException { - - // We should never reach this, the compile validation should guard us - throw new HiveException("DECIMAL primitive type not supported in vectorization."); - } + SettableHiveDecimalObjectInspector fieldObjInspector) throws HiveException { + + return new VectorExpressionWriterDecimal() { + private HiveDecimal hd; + private Object obj; + + public VectorExpressionWriter init(SettableHiveDecimalObjectInspector objInspector) throws HiveException { + super.init(objInspector); + hd = HiveDecimal.create(BigDecimal.ZERO); + obj = initValue(null); + return 
this; + } + + @Override + public Object writeValue(Decimal128 value) throws HiveException { + hd.setNormalize(value.toBigDecimal()); + ((SettableHiveDecimalObjectInspector) this.objectInspector).set(obj, hd); + return obj; + } + + @Override + public Object setValue(Object field, Decimal128 value) { + hd.setNormalize(value.toBigDecimal()); + ((SettableHiveDecimalObjectInspector) this.objectInspector).set(field, hd); + return field; + } + + @Override + public Object initValue(Object ignored) throws HiveException { + return ((SettableHiveDecimalObjectInspector) this.objectInspector).create( + HiveDecimal.create(BigDecimal.ZERO)); + } + }.init(fieldObjInspector); + } private static VectorExpressionWriter genVectorExpressionWritableDate( - SettableDateObjectInspector fieldObjInspector) throws HiveException { - // We should never reach this, the compile validation should guard us - throw new HiveException("DATE primitive type not supported in vectorization."); - } + SettableDateObjectInspector fieldObjInspector) throws HiveException { + return new VectorExpressionWriterLong() { + private Date dt; + private Object obj; + + public VectorExpressionWriter init(SettableDateObjectInspector objInspector) throws HiveException { + super.init(objInspector); + dt = new Date(0); + obj = initValue(null); + return this; + } + + @Override + public Object writeValue(long value) { + dt.setTime(DateWritable.daysToMillis((int) value)); + ((SettableDateObjectInspector) this.objectInspector).set(obj, dt); + return obj; + } + + @Override + public Object setValue(Object field, long value) { + if (null == field) { + field = initValue(null); + } + dt.setTime(DateWritable.daysToMillis((int) value)); + ((SettableDateObjectInspector) this.objectInspector).set(field, dt); + return field; + } + + @Override + public Object initValue(Object ignored) { + return ((SettableDateObjectInspector) this.objectInspector).create(new Date(0)); + } + + }.init(fieldObjInspector); + } private static VectorExpressionWriter genVectorExpressionWritableTimestamp( SettableTimestampObjectInspector fieldObjInspector) throws HiveException { diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFAvgDecimal.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFAvgDecimal.java new file mode 100644 index 0000000..09f71f6 --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFAvgDecimal.java @@ -0,0 +1,520 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +import org.apache.hadoop.hive.common.type.Decimal128; +import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.vector.expressions.DecimalUtil; +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression; +import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.plan.AggregationDesc; +import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage; +import org.apache.hadoop.hive.ql.util.JavaDataModel; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hive.common.util.Decimal128FastBuffer; + +/** + * Generated from template VectorUDAFAvg.txt. + */ +@Description(name = "avg", + value = "_FUNC_(AVG) - Returns the average value of expr (vectorized, type: decimal)") +public class VectorUDAFAvgDecimal extends VectorAggregateExpression { + + private static final long serialVersionUID = 1L; + + /** class for storing the current aggregate value. */ + static class Aggregation implements AggregationBuffer { + + private static final long serialVersionUID = 1L; + + transient private final Decimal128 sum = new Decimal128(); + transient private long count; + transient private boolean isNull; + + public void sumValueWithCheck(Decimal128 value, short scale) { + if (isNull) { + sum.update(value); + sum.changeScaleDestructive(scale); + count = 1; + isNull = false; + } else { + sum.addDestructive(value, scale); + count++; + } + } + + public void sumValueNoCheck(Decimal128 value, short scale) { + sum.addDestructive(value, scale); + count++; + } + + + @Override + public int getVariableSize() { + throw new UnsupportedOperationException(); + } + } + + private VectorExpression inputExpression; + transient private Object[] partialResult; + transient private LongWritable resultCount; + transient private HiveDecimalWritable resultSum; + transient private StructObjectInspector soi; + + transient private final Decimal128FastBuffer scratch; + + /** + * The scale of the SUM in the partial output + */ + private short sumScale; + + /** + * The precision of the SUM in the partial output + */ + private short sumPrecision; + + /** + * the scale of the input expression + */ + private short inputScale; + + /** + * the precision of the input expression + */ + private short inputPrecision; + + /** + * A value used as scratch to avoid allocating at runtime. 
+ * Needed by computations like vector[0] * batchSize + */ + transient private Decimal128 scratchDecimal = new Decimal128(); + + public VectorUDAFAvgDecimal(VectorExpression inputExpression) { + this(); + this.inputExpression = inputExpression; + } + + public VectorUDAFAvgDecimal() { + super(); + partialResult = new Object[2]; + resultCount = new LongWritable(); + resultSum = new HiveDecimalWritable(); + partialResult[0] = resultCount; + partialResult[1] = resultSum; + scratch = new Decimal128FastBuffer(); + + } + + private void initPartialResultInspector() { + // the output type of the vectorized partial aggregate must match the + // expected type for the row-mode aggregation + // For decimal, the type is "same number of integer digits and 4 more decimal digits" + + DecimalTypeInfo dtiSum = GenericUDAFAverage.deriveSumTypeInfo(inputScale, inputPrecision); + this.sumScale = (short) dtiSum.scale(); + this.sumPrecision = (short) dtiSum.precision(); + + List<ObjectInspector> foi = new ArrayList<ObjectInspector>(); + foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + foi.add(PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(dtiSum)); + List<String> fname = new ArrayList<String>(); + fname.add("count"); + fname.add("sum"); + soi = ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi); + } + + private Aggregation getCurrentAggregationBuffer( + VectorAggregationBufferRow[] aggregationBufferSets, + int bufferIndex, + int row) { + VectorAggregationBufferRow mySet = aggregationBufferSets[row]; + Aggregation myagg = (Aggregation) mySet.getAggregationBuffer(bufferIndex); + return myagg; + } + + @Override + public void aggregateInputSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int bufferIndex, + VectorizedRowBatch batch) throws HiveException { + + int batchSize = batch.size; + + if (batchSize == 0) { + return; + } + + inputExpression.evaluate(batch); + + DecimalColumnVector inputVector = ( DecimalColumnVector)batch.
+ cols[this.inputExpression.getOutputColumn()]; + Decimal128[] vector = inputVector.vector; + + if (inputVector.noNulls) { + if (inputVector.isRepeating) { + iterateNoNullsRepeatingWithAggregationSelection( + aggregationBufferSets, bufferIndex, + vector[0], batchSize); + } else { + if (batch.selectedInUse) { + iterateNoNullsSelectionWithAggregationSelection( + aggregationBufferSets, bufferIndex, + vector, batch.selected, batchSize); + } else { + iterateNoNullsWithAggregationSelection( + aggregationBufferSets, bufferIndex, + vector, batchSize); + } + } + } else { + if (inputVector.isRepeating) { + if (batch.selectedInUse) { + iterateHasNullsRepeatingSelectionWithAggregationSelection( + aggregationBufferSets, bufferIndex, + vector[0], batchSize, batch.selected, inputVector.isNull); + } else { + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, bufferIndex, + vector[0], batchSize, inputVector.isNull); + } + } else { + if (batch.selectedInUse) { + iterateHasNullsSelectionWithAggregationSelection( + aggregationBufferSets, bufferIndex, + vector, batchSize, batch.selected, inputVector.isNull); + } else { + iterateHasNullsWithAggregationSelection( + aggregationBufferSets, bufferIndex, + vector, batchSize, inputVector.isNull); + } + } + } + } + + private void iterateNoNullsRepeatingWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int bufferIndex, + Decimal128 value, + int batchSize) { + + for (int i=0; i < batchSize; ++i) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + bufferIndex, + i); + myagg.sumValueWithCheck(value, this.sumScale); + } + } + + private void iterateNoNullsSelectionWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int bufferIndex, + Decimal128[] values, + int[] selection, + int batchSize) { + + for (int i=0; i < batchSize; ++i) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + bufferIndex, + i); + myagg.sumValueWithCheck(values[selection[i]], this.sumScale); + } + } + + private void iterateNoNullsWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int bufferIndex, + Decimal128[] values, + int batchSize) { + for (int i=0; i < batchSize; ++i) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + bufferIndex, + i); + myagg.sumValueWithCheck(values[i], this.sumScale); + } + } + + private void iterateHasNullsRepeatingSelectionWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int bufferIndex, + Decimal128 value, + int batchSize, + int[] selection, + boolean[] isNull) { + + for (int i=0; i < batchSize; ++i) { + if (!isNull[selection[i]]) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + bufferIndex, + i); + myagg.sumValueWithCheck(value, this.sumScale); + } + } + + } + + private void iterateHasNullsRepeatingWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int bufferIndex, + Decimal128 value, + int batchSize, + boolean[] isNull) { + + for (int i=0; i < batchSize; ++i) { + if (!isNull[i]) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + bufferIndex, + i); + myagg.sumValueWithCheck(value, this.sumScale); + } + } + } + + private void iterateHasNullsSelectionWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int bufferIndex, + Decimal128[] values, + int batchSize, + int[] selection, + boolean[] isNull) { + + for (int j=0; j < 
batchSize; ++j) { + int i = selection[j]; + if (!isNull[i]) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + bufferIndex, + j); + myagg.sumValueWithCheck(values[i], this.sumScale); + } + } + } + + private void iterateHasNullsWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int bufferIndex, + Decimal128[] values, + int batchSize, + boolean[] isNull) { + + for (int i=0; i < batchSize; ++i) { + if (!isNull[i]) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + bufferIndex, + i); + myagg.sumValueWithCheck(values[i], this.sumScale); + } + } + } + + + @Override + public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch) + throws HiveException { + + inputExpression.evaluate(batch); + + DecimalColumnVector inputVector = + (DecimalColumnVector)batch.cols[this.inputExpression.getOutputColumn()]; + + int batchSize = batch.size; + + if (batchSize == 0) { + return; + } + + Aggregation myagg = (Aggregation)agg; + + Decimal128[] vector = inputVector.vector; + + if (inputVector.isRepeating) { + if (inputVector.noNulls) { + if (myagg.isNull) { + myagg.isNull = false; + myagg.sum.zeroClear(); + myagg.count = 0; + } + scratchDecimal.update(batchSize); + scratchDecimal.multiplyDestructive(vector[0], vector[0].getScale()); + myagg.sum.update(scratchDecimal); + myagg.count += batchSize; + } + return; + } + + if (!batch.selectedInUse && inputVector.noNulls) { + iterateNoSelectionNoNulls(myagg, vector, batchSize); + } + else if (!batch.selectedInUse) { + iterateNoSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull); + } + else if (inputVector.noNulls){ + iterateSelectionNoNulls(myagg, vector, batchSize, batch.selected); + } + else { + iterateSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull, batch.selected); + } + } + + private void iterateSelectionHasNulls( + Aggregation myagg, + Decimal128[] vector, + int batchSize, + boolean[] isNull, + int[] selected) { + + for (int j=0; j< batchSize; ++j) { + int i = selected[j]; + if (!isNull[i]) { + Decimal128 value = vector[i]; + myagg.sumValueWithCheck(value, this.sumScale); + } + } + } + + private void iterateSelectionNoNulls( + Aggregation myagg, + Decimal128[] vector, + int batchSize, + int[] selected) { + + if (myagg.isNull) { + myagg.isNull = false; + myagg.sum.zeroClear(); + myagg.count = 0; + } + + for (int i=0; i< batchSize; ++i) { + Decimal128 value = vector[selected[i]]; + myagg.sumValueNoCheck(value, this.sumScale); + } + } + + private void iterateNoSelectionHasNulls( + Aggregation myagg, + Decimal128[] vector, + int batchSize, + boolean[] isNull) { + + for(int i=0;i supportedDataTypes = new HashSet(); + Pattern supportedDataTypesPattern; List> vectorizableTasks = new ArrayList>(); Set> supportedGenericUDFs = new HashSet>(); @@ -175,19 +176,25 @@ private PhysicalContext physicalContext = null;; public Vectorizer() { - supportedDataTypes.add("int"); - supportedDataTypes.add("smallint"); - supportedDataTypes.add("tinyint"); - supportedDataTypes.add("bigint"); - supportedDataTypes.add("integer"); - supportedDataTypes.add("long"); - supportedDataTypes.add("short"); - supportedDataTypes.add("timestamp"); - supportedDataTypes.add("boolean"); - supportedDataTypes.add("string"); - supportedDataTypes.add("byte"); - supportedDataTypes.add("float"); - supportedDataTypes.add("double"); + + StringBuilder patternBuilder = new StringBuilder(); + patternBuilder.append("int"); + patternBuilder.append("|smallint"); + 
patternBuilder.append("|tinyint"); + patternBuilder.append("|bigint"); + patternBuilder.append("|integer"); + patternBuilder.append("|long"); + patternBuilder.append("|short"); + patternBuilder.append("|timestamp"); + patternBuilder.append("|boolean"); + patternBuilder.append("|string"); + patternBuilder.append("|byte"); + patternBuilder.append("|float"); + patternBuilder.append("|double"); + patternBuilder.append("|date"); + patternBuilder.append("|decimal.*"); + + supportedDataTypesPattern = Pattern.compile(patternBuilder.toString()); supportedGenericUDFs.add(GenericUDFOPPlus.class); supportedGenericUDFs.add(GenericUDFOPMinus.class); @@ -747,7 +754,7 @@ private boolean validateAggregationDesc(AggregationDesc aggDesc) { } private boolean validateDataType(String type) { - return supportedDataTypes.contains(type.toLowerCase()); + return supportedDataTypesPattern.matcher(type.toLowerCase()).matches(); } private VectorizationContext getVectorizationContext(Operator op, diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java index 1a00800..72e9d7a 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java @@ -174,15 +174,9 @@ protected ObjectInspector getSumFieldWritableObjectInspector() { return PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(typeInfo); } - /** - * The result type has the same number of integer digits and 4 more decimal digits. - */ private DecimalTypeInfo deriveResultDecimalTypeInfo() { if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) { - int scale = inputOI.scale(); - int intPart = inputOI.precision() - scale; - scale = Math.min(scale + 4, HiveDecimal.MAX_SCALE - intPart); - return TypeInfoFactory.getDecimalTypeInfo(intPart + scale, scale); + return GenericUDAFAverage.deriveSumTypeInfo(inputOI.scale(), inputOI.precision()); } else { PrimitiveObjectInspector sfOI = (PrimitiveObjectInspector) sumFieldOI; return (DecimalTypeInfo) sfOI.getTypeInfo(); @@ -367,4 +361,17 @@ public Object terminate(AggregationBuffer aggregation) throws HiveException { return doTerminate((AverageAggregationBuffer)aggregation); } } + + /** + * The result type has the same number of integer digits and 4 more decimal digits + * This is exposed as static so that the vectorized AVG operator use the same logic + * @param scale + * @param precision + * @return + */ + public static DecimalTypeInfo deriveSumTypeInfo(int scale, int precision) { + int intPart = precision - scale; + scale = Math.min(scale + 4, HiveDecimal.MAX_SCALE - intPart); + return TypeInfoFactory.getDecimalTypeInfo(intPart + scale, scale); + } } diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java index a2b45f8..e9f1652 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java @@ -26,6 +26,8 @@ import java.lang.management.ManagementFactory; import java.lang.management.MemoryMXBean; import java.lang.reflect.Constructor; +import java.math.BigDecimal; +import java.math.BigInteger; import java.sql.Timestamp; import java.util.ArrayList; import java.util.Arrays; @@ -36,6 +38,8 @@ import java.util.Map; import java.util.Set; +import org.apache.hadoop.hive.common.type.Decimal128; +import 
org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.ql.exec.vector.util.FakeCaptureOutputOperator; import org.apache.hadoop.hive.ql.exec.vector.util.FakeVectorRowBatchFromConcat; import org.apache.hadoop.hive.ql.exec.vector.util.FakeVectorRowBatchFromLongIterables; @@ -48,6 +52,7 @@ import org.apache.hadoop.hive.ql.plan.GroupByDesc; import org.apache.hadoop.hive.serde2.io.ByteWritable; import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.hive.serde2.io.ShortWritable; import org.apache.hadoop.hive.serde2.io.TimestampWritable; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; @@ -596,6 +601,178 @@ public void testCountStar() throws HiveException { } @Test + public void testCountDecimal() throws HiveException { + testAggregateDecimal( + "count", + 2, + Arrays.asList(new Object[]{ + new Decimal128(1), + new Decimal128(2), + new Decimal128(3)}), + 3L); + } + + @Test + public void testMaxDecimal() throws HiveException { + testAggregateDecimal( + "max", + 2, + Arrays.asList(new Object[]{ + new Decimal128(1), + new Decimal128(2), + new Decimal128(3)}), + new Decimal128(3)); + testAggregateDecimal( + "max", + 2, + Arrays.asList(new Object[]{ + new Decimal128(3), + new Decimal128(2), + new Decimal128(1)}), + new Decimal128(3)); + testAggregateDecimal( + "max", + 2, + Arrays.asList(new Object[]{ + new Decimal128(2), + new Decimal128(3), + new Decimal128(1)}), + new Decimal128(3)); + } + + @Test + public void testMinDecimal() throws HiveException { + testAggregateDecimal( + "min", + 2, + Arrays.asList(new Object[]{ + new Decimal128(1), + new Decimal128(2), + new Decimal128(3)}), + new Decimal128(1)); + testAggregateDecimal( + "min", + 2, + Arrays.asList(new Object[]{ + new Decimal128(3), + new Decimal128(2), + new Decimal128(1)}), + new Decimal128(1)); + + testAggregateDecimal( + "min", + 2, + Arrays.asList(new Object[]{ + new Decimal128(2), + new Decimal128(1), + new Decimal128(3)}), + new Decimal128(1)); + } + + @Test + public void testSumDecimal() throws HiveException { + testAggregateDecimal( + "sum", + 2, + Arrays.asList(new Object[]{ + new Decimal128(1), + new Decimal128(2), + new Decimal128(3)}), + new Decimal128(1+2+3)); + } + + @Test + public void testAvgDecimal() throws HiveException { + testAggregateDecimal( + "avg", + 2, + Arrays.asList(new Object[]{ + new Decimal128(1), + new Decimal128(2), + new Decimal128(3)}), + HiveDecimal.create((1+2+3)/3)); + } + + @Test + public void testAvgDecimalNegative() throws HiveException { + testAggregateDecimal( + "avg", + 2, + Arrays.asList(new Object[]{ + new Decimal128(-1), + new Decimal128(-2), + new Decimal128(-3)}), + HiveDecimal.create((-1-2-3)/3)); + } + + @Test + public void testVarianceDecimal () throws HiveException { + testAggregateDecimal( + "variance", + 2, + Arrays.asList(new Object[]{ + new Decimal128(13), + new Decimal128(5), + new Decimal128(7), + new Decimal128(19)}), + (double) 30); + } + + @Test + public void testVarSampDecimal () throws HiveException { + testAggregateDecimal( + "var_samp", + 2, + Arrays.asList(new Object[]{ + new Decimal128(13), + new Decimal128(5), + new Decimal128(7), + new Decimal128(19)}), + (double) 40); + } + + @Test + public void testStdPopDecimal () throws HiveException { + testAggregateDecimal( + "stddev_pop", + 2, + Arrays.asList(new Object[]{ + new Decimal128(13), + new Decimal128(5), + new Decimal128(7), + new Decimal128(19)}), + (double) Math.sqrt(30)); + } + + @Test + public void 
testStdSampDecimal () throws HiveException { + testAggregateDecimal( + "stddev_samp", + 2, + Arrays.asList(new Object[]{ + new Decimal128(13), + new Decimal128(5), + new Decimal128(7), + new Decimal128(19)}), + (double) Math.sqrt(40)); + } + + @Test + public void testDecimalKeyTypeAggregate() throws HiveException { + testKeyTypeAggregate( + "sum", + new FakeVectorRowBatchFromObjectIterables( + 2, + new String[] {"decimal(38,0)", "bigint"}, + Arrays.asList(new Object[]{ + new Decimal128(1),null, + new Decimal128(1), null}), + Arrays.asList(new Object[]{13L,null,7L, 19L})), + buildHashMap(HiveDecimal.create(1), 20L, null, 19L)); + } + + + @Test public void testCountString() throws HiveException { testAggregateString( "count", @@ -1655,6 +1832,9 @@ public void inspectRow(Object row, int tag) throws HiveException { } else if (key instanceof BooleanWritable) { BooleanWritable bwKey = (BooleanWritable)key; keyValue = bwKey.get(); + } else if (key instanceof HiveDecimalWritable) { + HiveDecimalWritable hdwKey = (HiveDecimalWritable)key; + keyValue = hdwKey.getHiveDecimal(); } else { Assert.fail(String.format("Not implemented key output type %s: %s", key.getClass().getName(), key)); @@ -1755,6 +1935,19 @@ public void testAggregateLongKeyAggregate ( testAggregateLongKeyIterable (aggregateName, fdr, expected); } + public void testAggregateDecimal ( + String aggregateName, + int batchSize, + Iterable values, + Object expected) throws HiveException { + + @SuppressWarnings("unchecked") + FakeVectorRowBatchFromObjectIterables fdr = new FakeVectorRowBatchFromObjectIterables( + batchSize, new String[] {"Decimal"}, values); + testAggregateDecimalIterable (aggregateName, fdr, expected); + } + + public void testAggregateString ( String aggregateName, int batchSize, @@ -1832,6 +2025,15 @@ public void validate(String key, Object expected, Object result) { assertEquals (key, (Double) expected, (Double) arr[0]); } else if (arr[0] instanceof Long) { assertEquals (key, (Long) expected, (Long) arr[0]); + } else if (arr[0] instanceof HiveDecimalWritable) { + HiveDecimalWritable hdw = (HiveDecimalWritable) arr[0]; + HiveDecimal hd = hdw.getHiveDecimal(); + Decimal128 d128 = (Decimal128)expected; + assertEquals (key, d128.toBigDecimal(), hd.bigDecimalValue()); + } else if (arr[0] instanceof HiveDecimal) { + HiveDecimal hd = (HiveDecimal) arr[0]; + Decimal128 d128 = (Decimal128)expected; + assertEquals (key, d128.toBigDecimal(), hd.bigDecimalValue()); } else { Assert.fail("Unsupported result type: " + arr[0].getClass().getName()); } @@ -1853,11 +2055,16 @@ public void validate(String key, Object expected, Object result) { assertEquals (2, vals.length); assertEquals (true, vals[0] instanceof LongWritable); - assertEquals (true, vals[1] instanceof DoubleWritable); LongWritable lw = (LongWritable) vals[0]; - DoubleWritable dw = (DoubleWritable) vals[1]; assertFalse (lw.get() == 0L); - assertEquals (key, (Double) expected, (Double) (dw.get() / lw.get())); + + if (vals[1] instanceof DoubleWritable) { + DoubleWritable dw = (DoubleWritable) vals[1]; + assertEquals (key, (Double) expected, (Double) (dw.get() / lw.get())); + } else if (vals[1] instanceof HiveDecimalWritable) { + HiveDecimalWritable hdw = (HiveDecimalWritable) vals[1]; + assertEquals (key, (HiveDecimal) expected, hdw.getHiveDecimal().divide(HiveDecimal.create(lw.get()))); + } } } @@ -1935,6 +2142,7 @@ void validateVariance(String key, double expected, long cnt, double sum, double {"var_samp", VarianceSampValidator.class}, {"std", StdValidator.class}, {"stddev", 
StdValidator.class}, + {"stddev_pop", StdValidator.class}, {"stddev_samp", StdSampValidator.class}, }; @@ -2015,6 +2223,38 @@ public void testAggregateStringIterable ( validator.validate("_total", expected, result); } + public void testAggregateDecimalIterable ( + String aggregateName, + Iterable data, + Object expected) throws HiveException { + Map mapColumnNames = new HashMap(); + mapColumnNames.put("A", 0); + VectorizationContext ctx = new VectorizationContext(mapColumnNames, 1); + + GroupByDesc desc = buildGroupByDescType(ctx, aggregateName, "A", + TypeInfoFactory.getDecimalTypeInfo(30, 4)); + + VectorGroupByOperator vgo = new VectorGroupByOperator(ctx, desc); + + FakeCaptureOutputOperator out = FakeCaptureOutputOperator.addCaptureOutputChild(vgo); + vgo.initialize(null, null); + + for (VectorizedRowBatch unit: data) { + vgo.processOp(unit, 0); + } + vgo.close(false); + + List outBatchList = out.getCapturedRows(); + assertNotNull(outBatchList); + assertEquals(1, outBatchList.size()); + + Object result = outBatchList.get(0); + + Validator validator = getValidator(aggregateName); + validator.validate("_total", expected, result); + } + + public void testAggregateDoubleIterable ( String aggregateName, Iterable data, diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromObjectIterables.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromObjectIterables.java index c8eaea1..ba7b0f9 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromObjectIterables.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromObjectIterables.java @@ -23,8 +23,10 @@ import java.util.Iterator; import java.util.List; +import org.apache.hadoop.hive.common.type.Decimal128; import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; import org.apache.hadoop.hive.ql.exec.vector.TimestampUtils; @@ -138,6 +140,18 @@ public void assign( dcv.vector[row] = Double.valueOf(value.toString()); } }; + } else if (types[i].toLowerCase().startsWith("decimal")) { + batch.cols[i] = new DecimalColumnVector(batchSize, 38, 0); + columnAssign[i] = new ColumnVectorAssign() { + @Override + public void assign( + ColumnVector columnVector, + int row, + Object value) { + DecimalColumnVector dcv = (DecimalColumnVector) columnVector; + dcv.vector[row] = (Decimal128)value; + } + }; } else { throw new HiveException("Unimplemented type " + types[i]); } diff --git serde/src/test/org/apache/hadoop/hive/serde2/io/TestHiveDecimalWritable.java serde/src/test/org/apache/hadoop/hive/serde2/io/TestHiveDecimalWritable.java new file mode 100644 index 0000000..0f9f122 --- /dev/null +++ serde/src/test/org/apache/hadoop/hive/serde2/io/TestHiveDecimalWritable.java @@ -0,0 +1,147 @@ +/** + * + */ +package org.apache.hadoop.hive.serde2.io; + +import junit.framework.Assert; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +import org.apache.hadoop.hive.common.type.Decimal128; +import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.apache.hive.common.util.Decimal128FastBuffer; +import org.junit.Before; +import org.junit.Test; + +/** + * + * + */ +public class TestHiveDecimalWritable { + + private 
Decimal128FastBuffer scratch; + + @Before + public void setUp() throws Exception { + scratch = new Decimal128FastBuffer(); + } + + private void doTestFastStreamForHiveDecimal(String valueString) { + BigDecimal value = new BigDecimal(valueString); + Decimal128 dec = new Decimal128(); + dec.update(value); + + HiveDecimalWritable witness = new HiveDecimalWritable(); + witness.set(HiveDecimal.create(value)); + + int bufferUsed = dec.fastSerializeForHiveDecimal(scratch); + HiveDecimalWritable hdw = new HiveDecimalWritable(); + hdw.set(scratch.getBytes(bufferUsed), dec.getScale()); + + HiveDecimal hd = hdw.getHiveDecimal(); + + BigDecimal readValue = hd.bigDecimalValue(); + + Assert.assertEquals(value, readValue); + } + + @Test + public void testFastStreamForHiveDecimal() { + + doTestFastStreamForHiveDecimal("0"); + doTestFastStreamForHiveDecimal("-0"); + doTestFastStreamForHiveDecimal("1"); + doTestFastStreamForHiveDecimal("-1"); + doTestFastStreamForHiveDecimal("2"); + doTestFastStreamForHiveDecimal("-2"); + doTestFastStreamForHiveDecimal("127"); + doTestFastStreamForHiveDecimal("-127"); + doTestFastStreamForHiveDecimal("128"); + doTestFastStreamForHiveDecimal("-128"); + doTestFastStreamForHiveDecimal("255"); + doTestFastStreamForHiveDecimal("-255"); + doTestFastStreamForHiveDecimal("256"); + doTestFastStreamForHiveDecimal("-256"); + doTestFastStreamForHiveDecimal("65535"); + doTestFastStreamForHiveDecimal("-65535"); + doTestFastStreamForHiveDecimal("65536"); + doTestFastStreamForHiveDecimal("-65536"); + + doTestFastStreamForHiveDecimal("10"); + doTestFastStreamForHiveDecimal("1000"); + doTestFastStreamForHiveDecimal("1000000"); + doTestFastStreamForHiveDecimal("1000000000"); + doTestFastStreamForHiveDecimal("1000000000000"); + doTestFastStreamForHiveDecimal("1000000000000000"); + doTestFastStreamForHiveDecimal("1000000000000000000"); + doTestFastStreamForHiveDecimal("1000000000000000000000"); + doTestFastStreamForHiveDecimal("1000000000000000000000000"); + doTestFastStreamForHiveDecimal("1000000000000000000000000000"); + doTestFastStreamForHiveDecimal("1000000000000000000000000000000"); + + doTestFastStreamForHiveDecimal("-10"); + doTestFastStreamForHiveDecimal("-1000"); + doTestFastStreamForHiveDecimal("-1000000"); + doTestFastStreamForHiveDecimal("-1000000000"); + doTestFastStreamForHiveDecimal("-1000000000000"); + doTestFastStreamForHiveDecimal("-1000000000000000000"); + doTestFastStreamForHiveDecimal("-1000000000000000000000"); + doTestFastStreamForHiveDecimal("-1000000000000000000000000"); + doTestFastStreamForHiveDecimal("-1000000000000000000000000000"); + doTestFastStreamForHiveDecimal("-1000000000000000000000000000000"); + + + doTestFastStreamForHiveDecimal("0.01"); + doTestFastStreamForHiveDecimal("-0.01"); + doTestFastStreamForHiveDecimal("0.02"); + doTestFastStreamForHiveDecimal("-0.02"); + doTestFastStreamForHiveDecimal("0.0127"); + doTestFastStreamForHiveDecimal("-0.0127"); + doTestFastStreamForHiveDecimal("0.0128"); + doTestFastStreamForHiveDecimal("-0.0128"); + doTestFastStreamForHiveDecimal("0.0255"); + doTestFastStreamForHiveDecimal("-0.0255"); + doTestFastStreamForHiveDecimal("0.0256"); + doTestFastStreamForHiveDecimal("-0.0256"); + doTestFastStreamForHiveDecimal("0.065535"); + doTestFastStreamForHiveDecimal("-0.065535"); + doTestFastStreamForHiveDecimal("0.065536"); + doTestFastStreamForHiveDecimal("-0.065536"); + + doTestFastStreamForHiveDecimal("0.101"); + doTestFastStreamForHiveDecimal("0.10001"); + doTestFastStreamForHiveDecimal("0.10000001"); + 
doTestFastStreamForHiveDecimal("0.10000000001"); + doTestFastStreamForHiveDecimal("0.10000000000001"); + doTestFastStreamForHiveDecimal("0.10000000000000001"); + doTestFastStreamForHiveDecimal("0.10000000000000000001"); + doTestFastStreamForHiveDecimal("0.10000000000000000000001"); + doTestFastStreamForHiveDecimal("0.10000000000000000000000001"); + doTestFastStreamForHiveDecimal("0.10000000000000000000000000001"); + doTestFastStreamForHiveDecimal("0.10000000000000000000000000000001"); + + doTestFastStreamForHiveDecimal("-0.101"); + doTestFastStreamForHiveDecimal("-0.10001"); + doTestFastStreamForHiveDecimal("-0.10000001"); + doTestFastStreamForHiveDecimal("-0.10000000001"); + doTestFastStreamForHiveDecimal("-0.10000000000001"); + doTestFastStreamForHiveDecimal("-0.10000000000000000001"); + doTestFastStreamForHiveDecimal("-0.10000000000000000000001"); + doTestFastStreamForHiveDecimal("-0.10000000000000000000000001"); + doTestFastStreamForHiveDecimal("-0.10000000000000000000000000001"); + doTestFastStreamForHiveDecimal("-0.10000000000000000000000000000001"); + + doTestFastStreamForHiveDecimal(Integer.toString(Integer.MAX_VALUE)); + doTestFastStreamForHiveDecimal(Integer.toString(Integer.MIN_VALUE)); + doTestFastStreamForHiveDecimal(Long.toString(Long.MAX_VALUE)); + doTestFastStreamForHiveDecimal(Long.toString(Long.MIN_VALUE)); + doTestFastStreamForHiveDecimal(Decimal128.MAX_VALUE.toFormalString()); + doTestFastStreamForHiveDecimal(Decimal128.MIN_VALUE.toFormalString()); + + + } + +}
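Editor's note on the specialized VectorExpressionWriterDecimal added above: its writeValue/setValue methods enumerate six combinations of noNulls, isRepeating, and isNull. The combinations reduce to one small rule, restated in the illustrative helper below. This is not Hive code; plain arguments stand in for the DecimalColumnVector fields, and the class and method names are invented for the sketch.

// Illustrative restatement of the writer's branch logic: read slot 0 for a
// repeating vector, the row index otherwise, and treat the row as SQL NULL
// when nulls are present and the chosen slot is flagged.
public class ColumnVectorNullRuleSketch {

  /** Returns the slot to read, or -1 when the row value is null. */
  static int resolveSlot(boolean noNulls, boolean isRepeating, boolean[] isNull, int row) {
    int slot = isRepeating ? 0 : row;
    if (!noNulls && isNull[slot]) {
      return -1;
    }
    return slot;
  }

  public static void main(String[] args) {
    boolean[] isNull = new boolean[] { false, true, false };
    System.out.println(resolveSlot(true,  false, isNull, 2)); //  2: no nulls, not repeating
    System.out.println(resolveSlot(false, true,  isNull, 2)); //  0: repeating, slot 0 not null
    System.out.println(resolveSlot(false, false, isNull, 1)); // -1: this row is null
  }
}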
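Editor's note on the partial-result type of the vectorized decimal AVG: both GenericUDAFAverage.deriveSumTypeInfo and VectorUDAFAvgDecimal.initPartialResultInspector apply the rule "same number of integer digits, four more fractional digits, capped at the maximum scale". The standalone sketch below shows the arithmetic; it is not part of the patch, and the MAX_SCALE constant of 38 is assumed to match HiveDecimal.MAX_SCALE.

// Sketch of the scale/precision derivation used for the partial SUM field.
// Class and method names are illustrative only; MAX_SCALE == 38 is an assumption.
public class DeriveSumTypeSketch {
  private static final int MAX_SCALE = 38;

  // Returns {precision, scale} of the partial sum for an input decimal(precision, scale).
  static int[] deriveSumType(int precision, int scale) {
    int intPart = precision - scale;                          // integer digits are preserved
    int newScale = Math.min(scale + 4, MAX_SCALE - intPart);  // add 4 fractional digits, capped
    return new int[] { intPart + newScale, newScale };
  }

  public static void main(String[] args) {
    int[] t = deriveSumType(10, 2);
    // Prints decimal(14,6): a decimal(10,2) input keeps 8 integer digits and gains 4 of scale.
    System.out.println("decimal(" + t[0] + "," + t[1] + ")");
  }
}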
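Editor's note on the Vectorizer change: the patch replaces the exact-match set of supported type names with a regular expression because decimal column types carry precision and scale (for example "decimal(38,0)" in the new testDecimalKeyTypeAggregate), so an exact string match can no longer recognize them. A minimal sketch of the resulting check, with the pattern text mirroring what the patch assembles via StringBuilder:

// Minimal sketch of Vectorizer.validateDataType after the patch; class name is illustrative.
import java.util.regex.Pattern;

public class SupportedTypePatternSketch {
  private static final Pattern SUPPORTED = Pattern.compile(
      "int|smallint|tinyint|bigint|integer|long|short|timestamp|boolean"
      + "|string|byte|float|double|date|decimal.*");

  static boolean validateDataType(String type) {
    return SUPPORTED.matcher(type.toLowerCase()).matches();
  }

  public static void main(String[] args) {
    System.out.println(validateDataType("decimal(10,2)"));   // true: "decimal.*" matches
    System.out.println(validateDataType("DECIMAL(38,0)"));   // true: lower-cased before matching
    System.out.println(validateDataType("map<string,int>")); // false: not a supported scalar type
  }
}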