diff --git ant/src/org/apache/hadoop/hive/ant/GenVectorCode.java ant/src/org/apache/hadoop/hive/ant/GenVectorCode.java index 1b76fc9..7744cb3 100644 --- ant/src/org/apache/hadoop/hive/ant/GenVectorCode.java +++ ant/src/org/apache/hadoop/hive/ant/GenVectorCode.java @@ -460,6 +460,11 @@ {"VectorUDAFMinMax", "VectorUDAFMaxDouble", "double", ">", "max", "_FUNC_(expr) - Returns the maximum value of expr (vectorized, type: double)"}, + {"VectorUDAFMinMaxDecimal", "VectorUDAFMaxDecimal", "<", "max", + "_FUNC_(expr) - Returns the maximum value of expr (vectorized, type: decimal)"}, + {"VectorUDAFMinMaxDecimal", "VectorUDAFMinDecimal", ">", "min", + "_FUNC_(expr) - Returns the minimum value of expr (vectorized, type: decimal)"}, + {"VectorUDAFMinMaxString", "VectorUDAFMinString", "<", "min", "_FUNC_(expr) - Returns the minimum value of expr (vectorized, type: string)"}, {"VectorUDAFMinMaxString", "VectorUDAFMaxString", ">", "max", @@ -479,24 +484,36 @@ {"VectorUDAFVar", "VectorUDAFVarPopDouble", "double", "myagg.variance / myagg.count", "variance, var_pop", "_FUNC_(x) - Returns the variance of a set of numbers (vectorized, double)"}, + {"VectorUDAFVarDecimal", "VectorUDAFVarPopDecimal", "myagg.variance / myagg.count", + "variance, var_pop", + "_FUNC_(x) - Returns the variance of a set of numbers (vectorized, decimal)"}, {"VectorUDAFVar", "VectorUDAFVarSampLong", "long", "myagg.variance / (myagg.count-1.0)", "var_samp", "_FUNC_(x) - Returns the sample variance of a set of numbers (vectorized, long)"}, {"VectorUDAFVar", "VectorUDAFVarSampDouble", "double", "myagg.variance / (myagg.count-1.0)", "var_samp", "_FUNC_(x) - Returns the sample variance of a set of numbers (vectorized, double)"}, + {"VectorUDAFVarDecimal", "VectorUDAFVarSampDecimal", "myagg.variance / (myagg.count-1.0)", + "var_samp", + "_FUNC_(x) - Returns the sample variance of a set of numbers (vectorized, decimal)"}, {"VectorUDAFVar", "VectorUDAFStdPopLong", "long", "Math.sqrt(myagg.variance / (myagg.count))", "std,stddev,stddev_pop", "_FUNC_(x) - Returns the standard deviation of a set of numbers (vectorized, long)"}, {"VectorUDAFVar", "VectorUDAFStdPopDouble", "double", "Math.sqrt(myagg.variance / (myagg.count))", "std,stddev,stddev_pop", "_FUNC_(x) - Returns the standard deviation of a set of numbers (vectorized, double)"}, + {"VectorUDAFVarDecimal", "VectorUDAFStdPopDecimal", + "Math.sqrt(myagg.variance / (myagg.count))", "std,stddev,stddev_pop", + "_FUNC_(x) - Returns the standard deviation of a set of numbers (vectorized, decimal)"}, {"VectorUDAFVar", "VectorUDAFStdSampLong", "long", "Math.sqrt(myagg.variance / (myagg.count-1.0))", "stddev_samp", "_FUNC_(x) - Returns the sample standard deviation of a set of numbers (vectorized, long)"}, {"VectorUDAFVar", "VectorUDAFStdSampDouble", "double", "Math.sqrt(myagg.variance / (myagg.count-1.0))", "stddev_samp", "_FUNC_(x) - Returns the sample standard deviation of a set of numbers (vectorized, double)"}, + {"VectorUDAFVarDecimal", "VectorUDAFStdSampDecimal", + "Math.sqrt(myagg.variance / (myagg.count-1.0))", "stddev_samp", + "_FUNC_(x) - Returns the sample standard deviation of a set of numbers (vectorized, decimal)"}, }; @@ -619,12 +636,16 @@ private void generate() throws Exception { generateVectorUDAFMinMax(tdesc); } else if (tdesc[0].equals("VectorUDAFMinMaxString")) { generateVectorUDAFMinMaxString(tdesc); + } else if (tdesc[0].equals("VectorUDAFMinMaxDecimal")) { + generateVectorUDAFMinMaxDecimal(tdesc); } else if (tdesc[0].equals("VectorUDAFSum")) { generateVectorUDAFSum(tdesc); } 
else if (tdesc[0].equals("VectorUDAFAvg")) { generateVectorUDAFAvg(tdesc); } else if (tdesc[0].equals("VectorUDAFVar")) { generateVectorUDAFVar(tdesc); + } else if (tdesc[0].equals("VectorUDAFVarDecimal")) { + generateVectorUDAFVarDecimal(tdesc); } else if (tdesc[0].equals("FilterStringColumnCompareScalar")) { generateFilterStringColumnCompareScalar(tdesc); } else if (tdesc[0].equals("FilterStringColumnBetween")) { @@ -675,7 +696,7 @@ private void generateFilterStringColumnBetween(String[] tdesc) throws IOExceptio className, templateString); } - private void generateFilterColumnBetween(String[] tdesc) throws IOException { + private void generateFilterColumnBetween(String[] tdesc) throws Exception { String operandType = tdesc[1]; String optionalNot = tdesc[2]; @@ -695,7 +716,7 @@ private void generateFilterColumnBetween(String[] tdesc) throws IOException { className, templateString); } - private void generateColumnCompareColumn(String[] tdesc) throws IOException { + private void generateColumnCompareColumn(String[] tdesc) throws Exception { //The variables are all same as ColumnCompareScalar except that //this template doesn't need a return type. Pass anything as return type. String operatorName = tdesc[1]; @@ -748,6 +769,23 @@ private void generateVectorUDAFMinMaxString(String[] tdesc) throws Exception { className, templateString); } + private void generateVectorUDAFMinMaxDecimal(String[] tdesc) throws Exception { + String className = tdesc[1]; + String operatorSymbol = tdesc[2]; + String descName = tdesc[3]; + String descValue = tdesc[4]; + + File templateFile = new File(joinPath(this.udafTemplateDirectory, tdesc[0] + ".txt")); + + String templateString = readFile(templateFile); + templateString = templateString.replaceAll("", className); + templateString = templateString.replaceAll("", operatorSymbol); + templateString = templateString.replaceAll("", descName); + templateString = templateString.replaceAll("", descValue); + writeFile(templateFile.lastModified(), udafOutputDirectory, udafClassesDirectory, + className, templateString); + } + private void generateVectorUDAFSum(String[] tdesc) throws Exception { //template, , , , String className = tdesc[1]; @@ -768,7 +806,7 @@ private void generateVectorUDAFSum(String[] tdesc) throws Exception { className, templateString); } - private void generateVectorUDAFAvg(String[] tdesc) throws IOException { + private void generateVectorUDAFAvg(String[] tdesc) throws Exception { String className = tdesc[1]; String valueType = tdesc[2]; String columnType = getColumnVectorType(valueType); @@ -783,7 +821,7 @@ private void generateVectorUDAFAvg(String[] tdesc) throws IOException { className, templateString); } - private void generateVectorUDAFVar(String[] tdesc) throws IOException { + private void generateVectorUDAFVar(String[] tdesc) throws Exception { String className = tdesc[1]; String valueType = tdesc[2]; String varianceFormula = tdesc[3]; @@ -804,6 +842,24 @@ private void generateVectorUDAFVar(String[] tdesc) throws IOException { className, templateString); } + private void generateVectorUDAFVarDecimal(String[] tdesc) throws Exception { + String className = tdesc[1]; + String varianceFormula = tdesc[2]; + String descriptionName = tdesc[3]; + String descriptionValue = tdesc[4]; + + File templateFile = new File(joinPath(this.udafTemplateDirectory, tdesc[0] + ".txt")); + + String templateString = readFile(templateFile); + templateString = templateString.replaceAll("", className); + templateString = templateString.replaceAll("", varianceFormula); + 
templateString = templateString.replaceAll("", descriptionName); + templateString = templateString.replaceAll("", descriptionValue); + writeFile(templateFile.lastModified(), udafOutputDirectory, udafClassesDirectory, + className, templateString); + } + + private void generateFilterStringScalarCompareColumn(String[] tdesc) throws IOException { String operatorName = tdesc[1]; String className = "FilterStringScalar" + operatorName + "StringColumn"; @@ -857,7 +913,7 @@ private void generateStringColumnCompareScalar(String[] tdesc, String className) className, templateString); } - private void generateFilterColumnCompareColumn(String[] tdesc) throws IOException { + private void generateFilterColumnCompareColumn(String[] tdesc) throws Exception { //The variables are all same as ColumnCompareScalar except that //this template doesn't need a return type. Pass anything as return type. String operatorName = tdesc[1]; @@ -868,7 +924,7 @@ private void generateFilterColumnCompareColumn(String[] tdesc) throws IOExceptio generateColumnBinaryOperatorColumn(tdesc, null, className); } - private void generateColumnUnaryMinus(String[] tdesc) throws IOException { + private void generateColumnUnaryMinus(String[] tdesc) throws Exception { String operandType = tdesc[1]; String inputColumnVectorType = this.getColumnVectorType(operandType); String outputColumnVectorType = inputColumnVectorType; @@ -886,7 +942,7 @@ private void generateColumnUnaryMinus(String[] tdesc) throws IOException { className, templateString); } - private void generateIfExprColumnColumn(String[] tdesc) throws IOException { + private void generateIfExprColumnColumn(String[] tdesc) throws Exception { String operandType = tdesc[1]; String inputColumnVectorType = this.getColumnVectorType(operandType); String outputColumnVectorType = inputColumnVectorType; @@ -904,7 +960,7 @@ private void generateIfExprColumnColumn(String[] tdesc) throws IOException { className, templateString); } - private void generateIfExprColumnScalar(String[] tdesc) throws IOException { + private void generateIfExprColumnScalar(String[] tdesc) throws Exception { String operandType2 = tdesc[1]; String operandType3 = tdesc[2]; String arg2ColumnVectorType = this.getColumnVectorType(operandType2); @@ -926,7 +982,7 @@ private void generateIfExprColumnScalar(String[] tdesc) throws IOException { className, templateString); } - private void generateIfExprScalarColumn(String[] tdesc) throws IOException { + private void generateIfExprScalarColumn(String[] tdesc) throws Exception { String operandType2 = tdesc[1]; String operandType3 = tdesc[2]; String arg3ColumnVectorType = this.getColumnVectorType(operandType3); @@ -948,7 +1004,7 @@ private void generateIfExprScalarColumn(String[] tdesc) throws IOException { className, templateString); } - private void generateIfExprScalarScalar(String[] tdesc) throws IOException { + private void generateIfExprScalarScalar(String[] tdesc) throws Exception { String operandType2 = tdesc[1]; String operandType3 = tdesc[2]; String arg3ColumnVectorType = this.getColumnVectorType(operandType3); @@ -970,7 +1026,7 @@ private void generateIfExprScalarScalar(String[] tdesc) throws IOException { } // template, , , , , , - private void generateColumnUnaryFunc(String[] tdesc) throws IOException { + private void generateColumnUnaryFunc(String[] tdesc) throws Exception { String classNamePrefix = tdesc[1]; String operandType = tdesc[3]; String inputColumnVectorType = this.getColumnVectorType(operandType); @@ -998,7 +1054,7 @@ private void 
generateColumnUnaryFunc(String[] tdesc) throws IOException { className, templateString); } - private void generateColumnArithmeticColumn(String [] tdesc) throws IOException { + private void generateColumnArithmeticColumn(String [] tdesc) throws Exception { String operatorName = tdesc[1]; String operandType1 = tdesc[2]; String operandType2 = tdesc[3]; @@ -1008,7 +1064,7 @@ private void generateColumnArithmeticColumn(String [] tdesc) throws IOException generateColumnBinaryOperatorColumn(tdesc, returnType, className); } - private void generateFilterColumnCompareScalar(String[] tdesc) throws IOException { + private void generateFilterColumnCompareScalar(String[] tdesc) throws Exception { //The variables are all same as ColumnCompareScalar except that //this template doesn't need a return type. Pass anything as return type. String operatorName = tdesc[1]; @@ -1019,7 +1075,7 @@ private void generateFilterColumnCompareScalar(String[] tdesc) throws IOExceptio generateColumnBinaryOperatorScalar(tdesc, null, className); } - private void generateFilterScalarCompareColumn(String[] tdesc) throws IOException { + private void generateFilterScalarCompareColumn(String[] tdesc) throws Exception { //this template doesn't need a return type. Pass anything as return type. String operatorName = tdesc[1]; String operandType1 = tdesc[2]; @@ -1029,7 +1085,7 @@ private void generateFilterScalarCompareColumn(String[] tdesc) throws IOExceptio generateScalarBinaryOperatorColumn(tdesc, null, className); } - private void generateColumnCompareScalar(String[] tdesc) throws IOException { + private void generateColumnCompareScalar(String[] tdesc) throws Exception { String operatorName = tdesc[1]; String operandType1 = tdesc[2]; String operandType2 = tdesc[3]; @@ -1039,7 +1095,7 @@ private void generateColumnCompareScalar(String[] tdesc) throws IOException { generateColumnBinaryOperatorScalar(tdesc, returnType, className); } - private void generateScalarCompareColumn(String[] tdesc) throws IOException { + private void generateScalarCompareColumn(String[] tdesc) throws Exception { String operatorName = tdesc[1]; String operandType1 = tdesc[2]; String operandType2 = tdesc[3]; @@ -1050,10 +1106,11 @@ private void generateScalarCompareColumn(String[] tdesc) throws IOException { } private void generateColumnBinaryOperatorColumn(String[] tdesc, String returnType, - String className) throws IOException { + String className) throws Exception { String operandType1 = tdesc[2]; String operandType2 = tdesc[3]; - String outputColumnVectorType = this.getColumnVectorType(returnType); + String outputColumnVectorType = this.getColumnVectorType( + returnType == null ? "long" : returnType); String inputColumnVectorType1 = this.getColumnVectorType(operandType1); String inputColumnVectorType2 = this.getColumnVectorType(operandType2); String operatorSymbol = tdesc[4]; @@ -1089,10 +1146,11 @@ private void generateColumnBinaryOperatorColumn(String[] tdesc, String returnTyp } private void generateColumnBinaryOperatorScalar(String[] tdesc, String returnType, - String className) throws IOException { + String className) throws Exception { String operandType1 = tdesc[2]; String operandType2 = tdesc[3]; - String outputColumnVectorType = this.getColumnVectorType(returnType); + String outputColumnVectorType = this.getColumnVectorType( + returnType == null ? 
"long" : returnType); String inputColumnVectorType = this.getColumnVectorType(operandType1); String operatorSymbol = tdesc[4]; @@ -1127,10 +1185,11 @@ private void generateColumnBinaryOperatorScalar(String[] tdesc, String returnTyp } private void generateScalarBinaryOperatorColumn(String[] tdesc, String returnType, - String className) throws IOException { + String className) throws Exception { String operandType1 = tdesc[2]; String operandType2 = tdesc[3]; - String outputColumnVectorType = this.getColumnVectorType(returnType); + String outputColumnVectorType = this.getColumnVectorType( + returnType == null ? "long" : returnType); String inputColumnVectorType = this.getColumnVectorType(operandType2); String operatorSymbol = tdesc[4]; @@ -1166,7 +1225,7 @@ private void generateScalarBinaryOperatorColumn(String[] tdesc, String returnTyp } //Binary arithmetic operator - private void generateColumnArithmeticScalar(String[] tdesc) throws IOException { + private void generateColumnArithmeticScalar(String[] tdesc) throws Exception { String operatorName = tdesc[1]; String operandType1 = tdesc[2]; String operandType2 = tdesc[3]; @@ -1260,7 +1319,7 @@ private void generateColumnDivideColumnDecimal(String[] tdesc) throws IOExceptio className, templateString); } - private void generateScalarArithmeticColumn(String[] tdesc) throws IOException { + private void generateScalarArithmeticColumn(String[] tdesc) throws Exception { String operatorName = tdesc[1]; String operandType1 = tdesc[2]; String operandType2 = tdesc[3]; @@ -1369,11 +1428,15 @@ private String getArithmeticReturnType(String operandType1, } } - private String getColumnVectorType(String primitiveType) { + private String getColumnVectorType(String primitiveType) throws Exception { if(primitiveType!=null && primitiveType.equals("double")) { return "DoubleColumnVector"; + } else if (primitiveType.equals("long")) { + return "LongColumnVector"; + } else if (primitiveType.equals("decimal")) { + return "DecimalColumnVector"; } - return "LongColumnVector"; + throw new Exception("Unimplemented primitive column vector type: " + primitiveType); } private String getOutputWritableType(String primitiveType) throws Exception { @@ -1381,6 +1444,8 @@ private String getOutputWritableType(String primitiveType) throws Exception { return "LongWritable"; } else if (primitiveType.equals("double")) { return "DoubleWritable"; + } else if (primitiveType.equals("decimal")) { + return "HiveDecimalWritable"; } throw new Exception("Unimplemented primitive output writable: " + primitiveType); } @@ -1390,6 +1455,8 @@ private String getOutputObjectInspector(String primitiveType) throws Exception { return "PrimitiveObjectInspectorFactory.writableLongObjectInspector"; } else if (primitiveType.equals("double")) { return "PrimitiveObjectInspectorFactory.writableDoubleObjectInspector"; + } else if (primitiveType.equals("decimal")) { + return "PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector"; } throw new Exception("Unimplemented primitive output inspector: " + primitiveType); } diff --git common/src/java/org/apache/hadoop/hive/common/type/Decimal128.java common/src/java/org/apache/hadoop/hive/common/type/Decimal128.java index c9ffc59..39e45ad 100644 --- common/src/java/org/apache/hadoop/hive/common/type/Decimal128.java +++ common/src/java/org/apache/hadoop/hive/common/type/Decimal128.java @@ -17,8 +17,11 @@ import java.math.BigDecimal; import java.math.BigInteger; +import java.nio.ByteBuffer; import java.nio.IntBuffer; +import 
org.apache.hive.common.util.Decimal128FastBuffer; + /** * This code was originally written for Microsoft PolyBase. *

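(Illustrative note, not part of the patch.) The Decimal128 changes below make the update(...) overloads chainable and add fastSerializeForHiveDecimal together with the new Decimal128FastBuffer scratch class, so the unscaled value can be handed off as BigInteger-compatible two's complement bytes without per-row allocation. A minimal round-trip sketch under those assumptions (non-negative scale, value within 38 digits); the class and variable names here are hypothetical:

import java.math.BigDecimal;
import java.math.BigInteger;
import org.apache.hadoop.hive.common.type.Decimal128;
import org.apache.hive.common.util.Decimal128FastBuffer;

public class Decimal128RoundTripSketch {
  public static void main(String[] args) {
    BigDecimal original = new BigDecimal("12345.6789");
    // update(BigDecimal) is added later in this diff and returns `this`, so calls can be chained.
    Decimal128 d = new Decimal128().update(original);
    // Reusable scratch buffers: one byte[]/ByteBuffer per possible int count (0..4).
    Decimal128FastBuffer scratch = new Decimal128FastBuffer();
    // Serializes the unscaled value as BigInteger-compatible two's complement bytes
    // and returns the index of the scratch buffer that was used.
    int used = d.fastSerializeForHiveDecimal(scratch);
    byte[] unscaledBytes = scratch.getBytes(used);
    BigDecimal restored = new BigDecimal(new BigInteger(unscaledBytes), d.getScale());
    System.out.println(original.compareTo(restored) == 0); // expected: true
  }
}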
@@ -65,13 +68,17 @@ /** Minimum value for #scale. */ public static final short MIN_SCALE = 0; + public static final Decimal128 ONE = new Decimal128().update(1); + /** Maximum value that can be represented in this class. */ public static final Decimal128 MAX_VALUE = new Decimal128( - UnsignedInt128.TEN_TO_THIRTYEIGHT, (short) 0, false); + UnsignedInt128.TEN_TO_THIRTYEIGHT, (short) 0, false).subtractDestructive( + Decimal128.ONE, (short) 0); /** Minimum value that can be represented in this class. */ public static final Decimal128 MIN_VALUE = new Decimal128( - UnsignedInt128.TEN_TO_THIRTYEIGHT, (short) 0, true); + UnsignedInt128.TEN_TO_THIRTYEIGHT, (short) 0, true).addDestructive( + Decimal128.ONE, (short) 0); /** For Serializable. */ private static final long serialVersionUID = 1L; @@ -234,9 +241,10 @@ public Decimal128(char[] str, int offset, int length, short scale) { } /** Reset the value of this object to zero. */ - public void zeroClear() { + public Decimal128 zeroClear() { this.unscaledValue.zeroClear(); this.signum = 0; + return this; } /** @return whether this value represents zero. */ @@ -252,10 +260,11 @@ public boolean isZero() { * @param o * object to copy from */ - public void update(Decimal128 o) { + public Decimal128 update(Decimal128 o) { this.unscaledValue.update(o.unscaledValue); this.scale = o.scale; this.signum = o.signum; + return this; } /** @@ -265,8 +274,8 @@ public void update(Decimal128 o) { * @param val * {@code long} value to be set to {@code Decimal128}. */ - public void update(long val) { - update(val, (short) 0); + public Decimal128 update(long val) { + return update(val, (short) 0); } /** @@ -278,7 +287,7 @@ public void update(long val) { * @param scale * scale of the {@code Decimal128}. */ - public void update(long val, short scale) { + public Decimal128 update(long val, short scale) { this.scale = 0; if (val < 0L) { this.unscaledValue.update(-val); @@ -293,6 +302,7 @@ public void update(long val, short scale) { if (scale != 0) { changeScaleDestructive(scale); } + return this; } /** @@ -311,7 +321,7 @@ public void update(long val, short scale) { * @param scale * scale of the {@code Decimal128}. */ - public void update(double val, short scale) { + public Decimal128 update(double val, short scale) { if (Double.isInfinite(val) || Double.isNaN(val)) { throw new NumberFormatException("Infinite or NaN"); } @@ -331,7 +341,7 @@ public void update(double val, short scale) { // zero check if (significand == 0) { zeroClear(); - return; + return this; } this.signum = sign; @@ -389,6 +399,7 @@ public void update(double val, short scale) { (short) (twoScaleDown - scale)); } } + return this; } /** @@ -400,12 +411,13 @@ public void update(double val, short scale) { * @param precision * 0 to 38. Decimal digits. 
*/ - public void update(IntBuffer buf, int precision) { + public Decimal128 update(IntBuffer buf, int precision) { int scaleAndSignum = buf.get(); this.scale = (short) (scaleAndSignum >> 16); this.signum = (byte) (scaleAndSignum & 0xFF); this.unscaledValue.update(buf, precision); assert ((signum == 0) == unscaledValue.isZero()); + return this; } /** @@ -415,12 +427,13 @@ public void update(IntBuffer buf, int precision) { * @param buf * ByteBuffer to read values from */ - public void update128(IntBuffer buf) { + public Decimal128 update128(IntBuffer buf) { int scaleAndSignum = buf.get(); this.scale = (short) (scaleAndSignum >> 16); this.signum = (byte) (scaleAndSignum & 0xFF); this.unscaledValue.update128(buf); assert ((signum == 0) == unscaledValue.isZero()); + return this; } /** @@ -430,12 +443,13 @@ public void update128(IntBuffer buf) { * @param buf * ByteBuffer to read values from */ - public void update96(IntBuffer buf) { + public Decimal128 update96(IntBuffer buf) { int scaleAndSignum = buf.get(); this.scale = (short) (scaleAndSignum >> 16); this.signum = (byte) (scaleAndSignum & 0xFF); this.unscaledValue.update96(buf); assert ((signum == 0) == unscaledValue.isZero()); + return this; } /** @@ -445,12 +459,13 @@ public void update96(IntBuffer buf) { * @param buf * ByteBuffer to read values from */ - public void update64(IntBuffer buf) { + public Decimal128 update64(IntBuffer buf) { int scaleAndSignum = buf.get(); this.scale = (short) (scaleAndSignum >> 16); this.signum = (byte) (scaleAndSignum & 0xFF); this.unscaledValue.update64(buf); assert ((signum == 0) == unscaledValue.isZero()); + return this; } /** @@ -460,12 +475,13 @@ public void update64(IntBuffer buf) { * @param buf * ByteBuffer to read values from */ - public void update32(IntBuffer buf) { + public Decimal128 update32(IntBuffer buf) { int scaleAndSignum = buf.get(); this.scale = (short) (scaleAndSignum >> 16); this.signum = (byte) (scaleAndSignum & 0xFF); this.unscaledValue.update32(buf); assert ((signum == 0) == unscaledValue.isZero()); + return this; } /** @@ -479,11 +495,12 @@ public void update32(IntBuffer buf) { * @param precision * 0 to 38. Decimal digits. 
*/ - public void update(int[] array, int offset, int precision) { + public Decimal128 update(int[] array, int offset, int precision) { int scaleAndSignum = array[offset]; this.scale = (short) (scaleAndSignum >> 16); this.signum = (byte) (scaleAndSignum & 0xFF); this.unscaledValue.update(array, offset + 1, precision); + return this; } /** @@ -495,11 +512,12 @@ public void update(int[] array, int offset, int precision) { * @param offset * offset of the int array */ - public void update128(int[] array, int offset) { + public Decimal128 update128(int[] array, int offset) { int scaleAndSignum = array[offset]; this.scale = (short) (scaleAndSignum >> 16); this.signum = (byte) (scaleAndSignum & 0xFF); this.unscaledValue.update128(array, offset + 1); + return this; } /** @@ -511,11 +529,12 @@ public void update128(int[] array, int offset) { * @param offset * offset of the int array */ - public void update96(int[] array, int offset) { + public Decimal128 update96(int[] array, int offset) { int scaleAndSignum = array[offset]; this.scale = (short) (scaleAndSignum >> 16); this.signum = (byte) (scaleAndSignum & 0xFF); this.unscaledValue.update96(array, offset + 1); + return this; } /** @@ -527,11 +546,12 @@ public void update96(int[] array, int offset) { * @param offset * offset of the int array */ - public void update64(int[] array, int offset) { + public Decimal128 update64(int[] array, int offset) { int scaleAndSignum = array[offset]; this.scale = (short) (scaleAndSignum >> 16); this.signum = (byte) (scaleAndSignum & 0xFF); this.unscaledValue.update64(array, offset + 1); + return this; } /** @@ -543,21 +563,31 @@ public void update64(int[] array, int offset) { * @param offset * offset of the int array */ - public void update32(int[] array, int offset) { + public Decimal128 update32(int[] array, int offset) { int scaleAndSignum = array[offset]; this.scale = (short) (scaleAndSignum >> 16); this.signum = (byte) (scaleAndSignum & 0xFF); this.unscaledValue.update32(array, offset + 1); + return this; } /** + * Updates the value of this object with the given {@link BigDecimal}. + * @param bigDecimal + * {@link java.math.BigDecimal} + */ + public Decimal128 update(BigDecimal bigDecimal) { + return update(bigDecimal.unscaledValue(), (short) bigDecimal.scale()); + } + + /** * Updates the value of this object with the given {@link BigInteger} and scale. * * @param bigInt * {@link java.math.BigInteger} * @param scale */ - public void update(BigInteger bigInt, short scale) { + public Decimal128 update(BigInteger bigInt, short scale) { this.scale = scale; this.signum = (byte) bigInt.compareTo(BigInteger.ZERO); if (signum == 0) { @@ -567,6 +597,7 @@ public void update(BigInteger bigInt, short scale) { } else { unscaledValue.update(bigInt); } + return this; } /** @@ -577,8 +608,8 @@ public void update(BigInteger bigInt, short scale) { * @param scale * scale of the {@code Decimal128}. */ - public void update(String str, short scale) { - update(str.toCharArray(), 0, str.length(), scale); + public Decimal128 update(String str, short scale) { + return update(str.toCharArray(), 0, str.length(), scale); } /** @@ -594,7 +625,7 @@ public void update(String str, short scale) { * @param scale * scale of the {@code Decimal128}. 
*/ - public void update(char[] str, int offset, int length, short scale) { + public Decimal128 update(char[] str, int offset, int length, short scale) { final int end = offset + length; assert (end <= str.length); int cursor = offset; @@ -616,7 +647,7 @@ public void update(char[] str, int offset, int length, short scale) { this.scale = scale; zeroClear(); if (cursor == end) { - return; + return this; } // "1234567" => unscaledValue=1234567, negative=false, @@ -695,6 +726,16 @@ public void update(char[] str, int offset, int length, short scale) { this.unscaledValue.scaleDownTenDestructive((short) -scaleAdjust); } this.signum = (byte) (this.unscaledValue.isZero() ? 0 : (negative ? -1 : 1)); + return this; + } + + /** + * Serializes the value in a format compatible with the BigDecimal's own representation + * @param bytes + * @param offset + */ + public int fastSerializeForHiveDecimal( Decimal128FastBuffer scratch) { + return this.unscaledValue.fastSerializeForHiveDecimal(scratch, this.signum); } /** @@ -914,15 +955,15 @@ public static void add(Decimal128 left, Decimal128 right, Decimal128 result, * @param scale * scale of the result. must be 0 or positive. */ - public void addDestructive(Decimal128 right, short scale) { + public Decimal128 addDestructive(Decimal128 right, short scale) { this.changeScaleDestructive(scale); if (right.signum == 0) { - return; + return this; } if (this.signum == 0) { this.update(right); this.changeScaleDestructive(scale); - return; + return this; } short rightScaleTen = (short) (scale - right.scale); @@ -946,6 +987,7 @@ public void addDestructive(Decimal128 right, short scale) { } this.unscaledValue.throwIfExceedsTenToThirtyEight(); + return this; } /** @@ -986,16 +1028,16 @@ public static void subtract(Decimal128 left, Decimal128 right, * @param scale * scale of the result. must be 0 or positive. */ - public void subtractDestructive(Decimal128 right, short scale) { + public Decimal128 subtractDestructive(Decimal128 right, short scale) { this.changeScaleDestructive(scale); if (right.signum == 0) { - return; + return this; } if (this.signum == 0) { this.update(right); this.changeScaleDestructive(scale); this.negateDestructive(); - return; + return this; } short rightScaleTen = (short) (scale - right.scale); @@ -1020,6 +1062,7 @@ public void subtractDestructive(Decimal128 right, short scale) { } this.unscaledValue.throwIfExceedsTenToThirtyEight(); + return this; } /** @@ -1790,4 +1833,56 @@ public void zeroFractionPart() { */ this.getUnscaledValue().scaleUpTenDestructive(placesToRemove); } + + /** + * Multiplies this with this, updating this + * + * @return self + */ + public Decimal128 squareDestructive() { + this.multiplyDestructive(this, this.getScale()); + return this; + } + + /** + * For UDAF variance we use the algorithm described by Chan, Golub, and LeVeque in + * "Algorithms for computing the sample variance: analysis and recommendations" + * The American Statistician, 37 (1983) pp. 242--247. + * + * variance = variance1 + variance2 + n/(m*(m+n)) * pow(((m/n)*t1 - t2),2) + * + * where: - variance is sum[x-avg^2] (this is actually n times the variance) + * and is updated at every step. - n is the count of elements in chunk1 - m is + * the count of elements in chunk2 - t1 = sum of elements in chunk1, t2 = + * sum of elements in chunk2. 
+ * + * This is a helper function doing the intermediate computation: + * t = myagg.count*value - myagg.sum; + * myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); + * + * @return self + */ + public Decimal128 updateVarianceDestructive( + Decimal128 scratch, Decimal128 value, Decimal128 sum, long count) { + scratch.update(count); + scratch.multiplyDestructive(value, value.getScale()); + scratch.subtractDestructive(sum, sum.getScale()); + scratch.squareDestructive(); + scratch.unscaledValue.divideDestructive(count * (count-1)); + this.addDestructive(scratch, getScale()); + return this; + } + + /** + * Fats update from BigInteger two's complement representation + * @param internalStorage BigInteger two's complement representation of the unscaled value + * @param scale + */ + public Decimal128 fastUpdateFromInternalStorage(byte[] internalStorage, short scale) { + this.scale = scale; + this.signum = this.unscaledValue.fastUpdateFromInternalStorage(internalStorage); + + return this; + } } + diff --git common/src/java/org/apache/hadoop/hive/common/type/UnsignedInt128.java common/src/java/org/apache/hadoop/hive/common/type/UnsignedInt128.java index fb3c346..74168bd 100644 --- common/src/java/org/apache/hadoop/hive/common/type/UnsignedInt128.java +++ common/src/java/org/apache/hadoop/hive/common/type/UnsignedInt128.java @@ -16,9 +16,12 @@ package org.apache.hadoop.hive.common.type; import java.math.BigInteger; +import java.nio.ByteBuffer; import java.nio.IntBuffer; import java.util.Arrays; +import org.apache.hive.common.util.Decimal128FastBuffer; + /** * This code was originally written for Microsoft PolyBase. * @@ -1373,6 +1376,30 @@ public int divideDestructive(int right) { } /** + * Divides this value with the given value. This version is destructive, + * meaning it modifies this object. + * + * @param right + * the value to divide + * @return remainder + */ + public long divideDestructive(long right) { + assert (right >= 0); + + long quotient; + long remainder = 0; + + for (int i = INT_COUNT - 1; i >= 0; --i) { + remainder = ((this.v[i] & SqlMathUtil.LONG_MASK) + (remainder << 32)); + quotient = remainder / right; + remainder %= right; + this.v[i] = (int) quotient; + } + updateCount(); + return remainder; + } + + /** * Right-shift for the given number of bits. This version is destructive, * meaning it modifies this object. * @@ -2405,4 +2432,206 @@ private void updateCount() { this.count = (byte) 0; } } + + /*(non-Javadoc) + * Serializes one int part into the given @{link #ByteBuffer} + * considering two's complement for negatives. + */ + private static void fastSerializeIntPartForHiveDecimal(ByteBuffer buf, + int pos, int value, byte signum, boolean isFirstNonZero) { + if (signum == -1 && value != 0) { + value = (isFirstNonZero ? -value : ~value); + } + buf.putInt(pos, value); + } + + /* (non-Javadoc) + * Serializes this value into the format used by @{link #java.math.BigInteger} + * This is used for fast assignment of a Decimal128 to a HiveDecimalWritable internal storage. + * See OpenJDK BigInteger.toByteArray for a reference implementation. + * @param scratch + * @param signum + * @return + */ + public int fastSerializeForHiveDecimal(Decimal128FastBuffer scratch, byte signum) { + int bufferUsed = this.count; + ByteBuffer buf = scratch.getByteBuffer(bufferUsed); + buf.put(0, (byte) (signum == 1 ? 
0 : signum)); + int pos = 1; + int firstNonZero = 0; + while(firstNonZero < this.count && v[firstNonZero] == 0) { + ++firstNonZero; + } + switch(this.count) { + case 4: + fastSerializeIntPartForHiveDecimal(buf, pos, v[3], signum, firstNonZero == 3); + pos+=4; + // intentional fall through + case 3: + fastSerializeIntPartForHiveDecimal(buf, pos, v[2], signum, firstNonZero == 2); + pos+=4; + // intentional fall through + case 2: + fastSerializeIntPartForHiveDecimal(buf, pos, v[1], signum, firstNonZero == 1); + pos+=4; + // intentional fall through + case 1: + fastSerializeIntPartForHiveDecimal(buf, pos, v[0], signum, true); + } + return bufferUsed; + } + + /** + * Updates this value from a serialized unscaled {@link java.math.BigInteger} representation. + * This is used for fast update of a Decimal128 from a HiveDecimalWritable internal storage. + * @param internalStorage + * @return + */ + public byte fastUpdateFromInternalStorage(byte[] internalStorage) { + byte signum = 0; + int skip = 0; + this.count = 0; + // Skip over any leading 0s or 0xFFs + byte firstByte = internalStorage[0]; + if (firstByte == 0 || firstByte == -1) { + while((skip < internalStorage.length) && + (internalStorage[skip] == firstByte)) { + ++skip; + } + } + if (skip == internalStorage.length) { + // The entire storage is 0x00s or 0xFFs + // 0x00s means is 0 + // 0xFFs means is -1 + assert (firstByte == 0 || firstByte == -1); + if (firstByte == -1) { + signum = -1; + this.count = 1; + this.v[0] = 1; + } + else { + signum = 0; + } + } + else { + // We skipped over leading 0x00s and 0xFFs + // Important, signum is given by the firstByte, not by byte[keep]! + signum = (firstByte < 0) ? (byte) -1 : (byte) 1; + + // Now we read the big-endian compacted two's complement int parts + // Compacted means they are stripped of leading 0x00s and 0xFFs + // This is why we do the intLength/pos tricks bellow + // 'length' is all the bytes we have to read, after we skip 'skip' + // 'pos' is where to start reading the current int + // 'intLength' is how many bytes we read for the current int + + int length = internalStorage.length - skip; + int pos = skip; + int intLength = 0; + switch(length) { + case 16: ++intLength; //intentional fall through + case 15: ++intLength; + case 14: ++intLength; + case 13: ++intLength; + v[3] = fastUpdateIntFromInternalStorage(internalStorage, signum, pos, intLength); + ++this.count; + pos += intLength; + intLength = 0; + //intentional fall through + case 12: ++intLength; //intentional fall through + case 11: ++intLength; + case 10: ++intLength; + case 9: ++intLength; + v[2] = fastUpdateIntFromInternalStorage(internalStorage, signum, pos, intLength); + ++this.count; + pos += intLength; + intLength = 0; + //intentional fall through + case 8: ++intLength; //intentional fall through + case 7: ++intLength; + case 6: ++intLength; + case 5: ++intLength; + v[1] = fastUpdateIntFromInternalStorage(internalStorage, signum, pos, intLength); + ++this.count; + pos += intLength; + intLength = 0; + //intentional fall through + case 4: ++intLength; //intentional fall through + case 3: ++intLength; + case 2: ++intLength; + case 1: ++intLength; + v[0] = fastUpdateIntFromInternalStorage(internalStorage, signum, pos, intLength); + ++this.count; + break; + default: + // This should not happen + throw new RuntimeException("Impossible HiveDecimal internal storage length!"); + } + if (signum == -1) { + // So far we've read the one's complement + // add 1 to turn it into two's complement + for(int i = 0; i < this.count; ++i) { + 
if (v[i] != 0) { + v[i] = (int)((v[i] & 0xFFFFFFFFL) + 1); + if (v[i] != 0) { + break; + } + } + } + } + } + return signum; + } + + /** + * reads one int part from the two's complement Big-Endian compacted representation, + * starting from index pos + * @param internalStorage {@link java.math.BigInteger} serialized representation + * @param pos + * @return + */ + private int fastUpdateIntFromInternalStorage(byte[] internalStorage, + byte signum, int pos, int length) { + // due to the way we use the allocation-free cast from HiveDecimalWriter to decimal128, + // we do not have the luxury of a ByteBuffer... + byte b0, b1, b2, b3; + if (signum == -1) { + b1=b2=b3 = (byte)-1; + } + else { + b1=b2=b3=0; + } + switch(length) { + case 4: + b3 = internalStorage[pos]; + ++pos; + //intentional fall through + case 3: + b2 = internalStorage[pos]; + ++pos; + //intentional fall through + case 2: + b1 = internalStorage[pos]; + ++pos; + //intentional fall through + case 1: + b0 = internalStorage[pos]; + break; + default: + // this should never happen + throw new RuntimeException("Impossible HiveDecimal internal storage position!"); + } + + int value = ((int)b0 & 0x000000FF) | + (((int)b1 << 8) & 0x0000FF00) | + (((int)b2 << 16) & 0x00FF0000) | + (((int)b3 << 24) & 0xFF000000); + + if (signum == -1 && value != 0) { + // Make one's complement, masked only for the bytes read + int mask = -1 >>> (8*(4-length)); + value = ~value & mask; + } + return value; + } } diff --git common/src/java/org/apache/hive/common/util/Decimal128FastBuffer.java common/src/java/org/apache/hive/common/util/Decimal128FastBuffer.java new file mode 100644 index 0000000..aeca82f --- /dev/null +++ common/src/java/org/apache/hive/common/util/Decimal128FastBuffer.java @@ -0,0 +1,51 @@ +/** + * + */ +package org.apache.hive.common.util; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** + * A helper class for fast serialization of decimal128 in the BigDecimal byte[] representation + * + */ +public class Decimal128FastBuffer { + + /** + * Preallocated byte[] for each Decimal128 size (0-4) + */ + private final byte[][] sumBytes; + + /** + * Preallocated ByteBuffer wrappers around each sumBytes + */ + private final ByteBuffer[] sumBuffer; + + public Decimal128FastBuffer() { + sumBytes = new byte[5][]; + sumBuffer = new ByteBuffer[5]; + sumBytes[0] = new byte[1]; + sumBuffer[0] = ByteBuffer.wrap(sumBytes[0]); + sumBytes[1] = new byte[5]; + sumBuffer[1] = ByteBuffer.wrap(sumBytes[1]); + sumBuffer[1].order(ByteOrder.BIG_ENDIAN); + sumBytes[2] = new byte[9]; + sumBuffer[2] = ByteBuffer.wrap(sumBytes[2]); + sumBuffer[2].order(ByteOrder.BIG_ENDIAN); + sumBytes[3] = new byte[13]; + sumBuffer[3] = ByteBuffer.wrap(sumBytes[3]); + sumBuffer[3].order(ByteOrder.BIG_ENDIAN); + sumBytes[4] = new byte[17]; + sumBuffer[4] = ByteBuffer.wrap(sumBytes[4]); + sumBuffer[4].order(ByteOrder.BIG_ENDIAN); + } + + public ByteBuffer getByteBuffer(int index) { + return sumBuffer[index]; + } + + public byte[] getBytes(int index) { + return sumBytes[index]; + } +} diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvg.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvg.txt index cb94145..547a60a 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvg.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvg.txt @@ -54,7 +54,11 @@ public class extends VectorAggregateExpression { transient private double sum; transient private long count; - transient private boolean isNull; + + /** + * Value is explicitly (re)initialized in reset() + 
*/ + transient private boolean isNull = true; public void sumValue( value) { if (isNull) { diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMax.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMax.txt index 2b0364c..dcc1dfb 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMax.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMax.txt @@ -49,7 +49,11 @@ public class extends VectorAggregateExpression { private static final long serialVersionUID = 1L; transient private value; - transient private boolean isNull; + + /** + * Value is explicitly (re)initialized in reset() + */ + transient private boolean isNull = true; public void checkValue( value) { if (isNull) { diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxDecimal.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxDecimal.txt new file mode 100644 index 0000000..de9a84c --- /dev/null +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxDecimal.txt @@ -0,0 +1,450 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen; + +import org.apache.hadoop.hive.common.type.Decimal128; +import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression; +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriter; +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriterFactory; +import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.plan.AggregationDesc; +import org.apache.hadoop.hive.ql.util.JavaDataModel; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; + +/** +* . Vectorized implementation for MIN/MAX aggregates. +*/ +@Description(name = "", + value = "") +public class extends VectorAggregateExpression { + + private static final long serialVersionUID = 1L; + + /** + * class for storing the current aggregate value. 
+ */ + static private final class Aggregation implements AggregationBuffer { + + private static final long serialVersionUID = 1L; + + transient private final Decimal128 value; + + /** + * Value is explicitly (re)initialized in reset() + */ + transient private boolean isNull = true; + + public Aggregation() { + value = new Decimal128(); + } + + public void checkValue(Decimal128 value) { + if (isNull) { + isNull = false; + this.value.update(value); + } else if (this.value.compareTo(value) 0) { + this.value.update(value); + } + } + + @Override + public int getVariableSize() { + throw new UnsupportedOperationException(); + } + } + + private VectorExpression inputExpression; + private transient VectorExpressionWriter resultWriter; + + public (VectorExpression inputExpression) { + this(); + this.inputExpression = inputExpression; + } + + public () { + super(); + } + + @Override + public void init(AggregationDesc desc) throws HiveException { + resultWriter = VectorExpressionWriterFactory.genVectorExpressionWritable( + desc.getParameters().get(0)); + } + + private Aggregation getCurrentAggregationBuffer( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregrateIndex, + int row) { + VectorAggregationBufferRow mySet = aggregationBufferSets[row]; + Aggregation myagg = (Aggregation) mySet.getAggregationBuffer(aggregrateIndex); + return myagg; + } + + @Override + public void aggregateInputSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregrateIndex, + VectorizedRowBatch batch) throws HiveException { + + int batchSize = batch.size; + + if (batchSize == 0) { + return; + } + + inputExpression.evaluate(batch); + + DecimalColumnVector inputVector = (DecimalColumnVector)batch. + cols[this.inputExpression.getOutputColumn()]; + Decimal128[] vector = inputVector.vector; + + if (inputVector.noNulls) { + if (inputVector.isRepeating) { + iterateNoNullsRepeatingWithAggregationSelection( + aggregationBufferSets, aggregrateIndex, + vector[0], batchSize); + } else { + if (batch.selectedInUse) { + iterateNoNullsSelectionWithAggregationSelection( + aggregationBufferSets, aggregrateIndex, + vector, batch.selected, batchSize); + } else { + iterateNoNullsWithAggregationSelection( + aggregationBufferSets, aggregrateIndex, + vector, batchSize); + } + } + } else { + if (inputVector.isRepeating) { + if (batch.selectedInUse) { + iterateHasNullsRepeatingSelectionWithAggregationSelection( + aggregationBufferSets, aggregrateIndex, + vector[0], batchSize, batch.selected, inputVector.isNull); + } else { + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, aggregrateIndex, + vector[0], batchSize, inputVector.isNull); + } + } else { + if (batch.selectedInUse) { + iterateHasNullsSelectionWithAggregationSelection( + aggregationBufferSets, aggregrateIndex, + vector, batchSize, batch.selected, inputVector.isNull); + } else { + iterateHasNullsWithAggregationSelection( + aggregationBufferSets, aggregrateIndex, + vector, batchSize, inputVector.isNull); + } + } + } + } + + private void iterateNoNullsRepeatingWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregrateIndex, + Decimal128 value, + int batchSize) { + + for (int i=0; i < batchSize; ++i) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregrateIndex, + i); + myagg.checkValue(value); + } + } + + private void iterateNoNullsSelectionWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregrateIndex, + Decimal128[] 
values, + int[] selection, + int batchSize) { + + for (int i=0; i < batchSize; ++i) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregrateIndex, + i); + myagg.checkValue(values[selection[i]]); + } + } + + private void iterateNoNullsWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregrateIndex, + Decimal128[] values, + int batchSize) { + for (int i=0; i < batchSize; ++i) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregrateIndex, + i); + myagg.checkValue(values[i]); + } + } + + private void iterateHasNullsRepeatingSelectionWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregrateIndex, + Decimal128 value, + int batchSize, + int[] selection, + boolean[] isNull) { + + for (int i=0; i < batchSize; ++i) { + if (!isNull[selection[i]]) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregrateIndex, + i); + myagg.checkValue(value); + } + } + + } + + private void iterateHasNullsRepeatingWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregrateIndex, + Decimal128 value, + int batchSize, + boolean[] isNull) { + + for (int i=0; i < batchSize; ++i) { + if (!isNull[i]) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregrateIndex, + i); + myagg.checkValue(value); + } + } + } + + private void iterateHasNullsSelectionWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregrateIndex, + Decimal128[] values, + int batchSize, + int[] selection, + boolean[] isNull) { + + for (int j=0; j < batchSize; ++j) { + int i = selection[j]; + if (!isNull[i]) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregrateIndex, + j); + myagg.checkValue(values[i]); + } + } + } + + private void iterateHasNullsWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregrateIndex, + Decimal128[] values, + int batchSize, + boolean[] isNull) { + + for (int i=0; i < batchSize; ++i) { + if (!isNull[i]) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregrateIndex, + i); + myagg.checkValue(values[i]); + } + } + } + + @Override + public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch) + throws HiveException { + + inputExpression.evaluate(batch); + + DecimalColumnVector inputVector = (DecimalColumnVector)batch. 
+ cols[this.inputExpression.getOutputColumn()]; + + int batchSize = batch.size; + + if (batchSize == 0) { + return; + } + + Aggregation myagg = (Aggregation)agg; + + Decimal128[] vector = inputVector.vector; + + if (inputVector.isRepeating) { + if (inputVector.noNulls && + (myagg.isNull || (myagg.value.compareTo(vector[0]) 0))) { + myagg.isNull = false; + myagg.value.update(vector[0]); + } + return; + } + + if (!batch.selectedInUse && inputVector.noNulls) { + iterateNoSelectionNoNulls(myagg, vector, batchSize); + } + else if (!batch.selectedInUse) { + iterateNoSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull); + } + else if (inputVector.noNulls){ + iterateSelectionNoNulls(myagg, vector, batchSize, batch.selected); + } + else { + iterateSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull, batch.selected); + } + } + + private void iterateSelectionHasNulls( + Aggregation myagg, + Decimal128[] vector, + int batchSize, + boolean[] isNull, + int[] selected) { + + for (int j=0; j< batchSize; ++j) { + int i = selected[j]; + if (!isNull[i]) { + Decimal128 value = vector[i]; + if (myagg.isNull) { + myagg.isNull = false; + myagg.value.update(value); + } + else if (myagg.value.compareTo(value) 0) { + myagg.value.update(value); + } + } + } + } + + private void iterateSelectionNoNulls( + Aggregation myagg, + Decimal128[] vector, + int batchSize, + int[] selected) { + + if (myagg.isNull) { + myagg.value.update(vector[selected[0]]); + myagg.isNull = false; + } + + for (int i=0; i< batchSize; ++i) { + Decimal128 value = vector[selected[i]]; + if (myagg.value.compareTo(value) 0) { + myagg.value.update(value); + } + } + } + + private void iterateNoSelectionHasNulls( + Aggregation myagg, + Decimal128[] vector, + int batchSize, + boolean[] isNull) { + + for(int i=0;i 0) { + myagg.value.update(value); + } + } + } + } + + private void iterateNoSelectionNoNulls( + Aggregation myagg, + Decimal128[] vector, + int batchSize) { + if (myagg.isNull) { + myagg.value.update(vector[0]); + myagg.isNull = false; + } + + for (int i=0;i 0) { + myagg.value.update(value); + } + } + } + + @Override + public AggregationBuffer getNewAggregationBuffer() throws HiveException { + return new Aggregation(); + } + + @Override + public void reset(AggregationBuffer agg) throws HiveException { + Aggregation myAgg = (Aggregation) agg; + myAgg.isNull = true; + } + + @Override + public Object evaluateOutput( + AggregationBuffer agg) throws HiveException { + Aggregation myagg = (Aggregation) agg; + if (myagg.isNull) { + return null; + } + else { + return resultWriter.writeValue(myagg.value); + } + } + + @Override + public ObjectInspector getOutputObjectInspector() { + return resultWriter.getObjectInspector(); + } + + @Override + public int getAggregationBufferFixedSize() { + JavaDataModel model = JavaDataModel.get(); + return JavaDataModel.alignUp( + model.object() + + model.primitive2(), + model.memoryAlign()); + } + + public VectorExpression getInputExpression() { + return inputExpression; + } + + public void setInputExpression(VectorExpression inputExpression) { + this.inputExpression = inputExpression; + } +} + diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxString.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxString.txt index 36f483e..1f8b28c 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxString.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxString.txt @@ -35,15 +35,15 @@ import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.apache.hadoop.io.Text; /** -* . Vectorized implementation for MIN/MAX aggregates. +* . Vectorized implementation for MIN/MAX aggregates. */ -@Description(name = "", +@Description(name = "", value = "") public class extends VectorAggregateExpression { private static final long serialVersionUID = 1L; - /** + /** * class for storing the current aggregate value. */ static private final class Aggregation implements AggregationBuffer { @@ -53,7 +53,11 @@ public class extends VectorAggregateExpression { transient private final static int MIN_BUFFER_SIZE = 16; transient private byte[] bytes = new byte[MIN_BUFFER_SIZE]; transient private int length; - transient private boolean isNull; + + /** + * Value is explicitly (re)initialized in reset() + */ + transient private boolean isNull = true; public void checkValue(byte[] bytes, int start, int length) { if (isNull) { @@ -65,7 +69,7 @@ public class extends VectorAggregateExpression { assign(bytes, start, length); } } - + public void assign(byte[] bytes, int start, int length) { // Avoid new allocation if possible if (this.bytes.length < length) { @@ -80,10 +84,10 @@ public class extends VectorAggregateExpression { return model.lengthForByteArrayOfSize(bytes.length); } } - + private VectorExpression inputExpression; transient private Text result; - + public (VectorExpression inputExpression) { this(); this.inputExpression = inputExpression; @@ -93,7 +97,7 @@ public class extends VectorAggregateExpression { super(); result = new Text(); } - + private Aggregation getCurrentAggregationBuffer( VectorAggregationBufferRow[] aggregationBufferSets, int aggregrateIndex, @@ -102,21 +106,21 @@ public class extends VectorAggregateExpression { Aggregation myagg = (Aggregation) mySet.getAggregationBuffer(aggregrateIndex); return myagg; } - + @Override public void aggregateInputSelection( VectorAggregationBufferRow[] aggregationBufferSets, - int aggregrateIndex, + int aggregrateIndex, VectorizedRowBatch batch) throws HiveException { - + int batchSize = batch.size; - + if (batchSize == 0) { return; } - + inputExpression.evaluate(batch); - + BytesColumnVector inputColumn = (BytesColumnVector)batch. 
cols[this.inputExpression.getOutputColumn()]; @@ -164,12 +168,12 @@ public class extends VectorAggregateExpression { int length = inputColumn.length[0]; for (int i=0; i < batchSize; ++i) { Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, + aggregationBufferSets, aggregrateIndex, i); myagg.checkValue(bytes, start, length); } - } + } private void iterateNoNullsSelectionWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, @@ -177,11 +181,11 @@ public class extends VectorAggregateExpression { BytesColumnVector inputColumn, int[] selection, int batchSize) { - + for (int i=0; i < batchSize; ++i) { int row = selection[i]; Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, + aggregationBufferSets, aggregrateIndex, i); myagg.checkValue(inputColumn.vector[row], @@ -197,7 +201,7 @@ public class extends VectorAggregateExpression { int batchSize) { for (int i=0; i < batchSize; ++i) { Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, + aggregationBufferSets, aggregrateIndex, i); myagg.checkValue(inputColumn.vector[i], @@ -217,7 +221,7 @@ public class extends VectorAggregateExpression { int row = selection[i]; if (!inputColumn.isNull[row]) { Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, + aggregationBufferSets, aggregrateIndex, i); myagg.checkValue(inputColumn.vector[row], @@ -236,7 +240,7 @@ public class extends VectorAggregateExpression { for (int i=0; i < batchSize; ++i) { if (!inputColumn.isNull[i]) { Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, + aggregationBufferSets, aggregrateIndex, i); myagg.checkValue(inputColumn.vector[i], @@ -245,24 +249,24 @@ public class extends VectorAggregateExpression { } } } - + @Override - public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch) + public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch) throws HiveException { - + inputExpression.evaluate(batch); - + BytesColumnVector inputColumn = (BytesColumnVector)batch. 
cols[this.inputExpression.getOutputColumn()]; - + int batchSize = batch.size; - + if (batchSize == 0) { return; } - + Aggregation myagg = (Aggregation)agg; - + if (inputColumn.isRepeating) { if (inputColumn.noNulls) { myagg.checkValue(inputColumn.vector[0], @@ -271,7 +275,7 @@ public class extends VectorAggregateExpression { } return; } - + if (!batch.selectedInUse && inputColumn.noNulls) { iterateNoSelectionNoNulls(myagg, inputColumn, batchSize); } @@ -285,13 +289,13 @@ public class extends VectorAggregateExpression { iterateSelectionHasNulls(myagg, inputColumn, batchSize, batch.selected); } } - + private void iterateSelectionHasNulls( - Aggregation myagg, - BytesColumnVector inputColumn, + Aggregation myagg, + BytesColumnVector inputColumn, int batchSize, int[] selected) { - + for (int j=0; j< batchSize; ++j) { int i = selected[j]; if (!inputColumn.isNull[i]) { @@ -303,11 +307,11 @@ public class extends VectorAggregateExpression { } private void iterateSelectionNoNulls( - Aggregation myagg, - BytesColumnVector inputColumn, - int batchSize, + Aggregation myagg, + BytesColumnVector inputColumn, + int batchSize, int[] selected) { - + for (int i=0; i< batchSize; ++i) { myagg.checkValue(inputColumn.vector[i], inputColumn.start[i], @@ -316,10 +320,10 @@ public class extends VectorAggregateExpression { } private void iterateNoSelectionHasNulls( - Aggregation myagg, - BytesColumnVector inputColumn, + Aggregation myagg, + BytesColumnVector inputColumn, int batchSize) { - + for (int i=0; i< batchSize; ++i) { if (!inputColumn.isNull[i]) { myagg.checkValue(inputColumn.vector[i], @@ -330,7 +334,7 @@ public class extends VectorAggregateExpression { } private void iterateNoSelectionNoNulls( - Aggregation myagg, + Aggregation myagg, BytesColumnVector inputColumn, int batchSize) { for (int i=0; i< batchSize; ++i) { @@ -363,7 +367,7 @@ public class extends VectorAggregateExpression { return result; } } - + @Override public ObjectInspector getOutputObjectInspector() { return PrimitiveObjectInspectorFactory.writableStringObjectInspector; @@ -378,7 +382,7 @@ public class extends VectorAggregateExpression { model.primitive1()*2, model.memoryAlign()); } - + @Override public boolean hasVariableSize() { return true; diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFSum.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFSum.txt index 3573997..cb0be33 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFSum.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFSum.txt @@ -50,7 +50,11 @@ public class extends VectorAggregateExpression { private static final long serialVersionUID = 1L; transient private sum; - transient private boolean isNull; + + /** + * Value is explicitly (re)initialized in reset() + */ + transient private boolean isNull = true; public void sumValue( value) { if (isNull) { diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVar.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVar.txt index 7c0e58f..49b0edd 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVar.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVar.txt @@ -38,16 +38,16 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; /** -* . Vectorized implementation for VARIANCE aggregates. +* . Vectorized implementation for VARIANCE aggregates. 
*/ @Description(name = "", value = "") public class extends VectorAggregateExpression { private static final long serialVersionUID = 1L; - - /** - /* class for storing the current aggregate value. + + /** + /* class for storing the current aggregate value. */ private static final class Aggregation implements AggregationBuffer { @@ -56,8 +56,12 @@ public class extends VectorAggregateExpression { transient private double sum; transient private long count; transient private double variance; - transient private boolean isNull; - + + /** + * Value is explicitly (re)initialized in reset() (despite the init() bellow...) + */ + transient private boolean isNull = true; + public void init() { isNull = false; sum = 0; @@ -70,16 +74,16 @@ public class extends VectorAggregateExpression { throw new UnsupportedOperationException(); } } - + private VectorExpression inputExpression; transient private LongWritable resultCount; transient private DoubleWritable resultSum; transient private DoubleWritable resultVariance; transient private Object[] partialResult; - + transient private ObjectInspector soi; - - + + public (VectorExpression inputExpression) { this(); this.inputExpression = inputExpression; @@ -120,32 +124,32 @@ public class extends VectorAggregateExpression { return myagg; } - + @Override public void aggregateInputSelection( VectorAggregationBufferRow[] aggregationBufferSets, - int aggregateIndex, + int aggregateIndex, VectorizedRowBatch batch) throws HiveException { - + inputExpression.evaluate(batch); - + inputVector = ()batch. cols[this.inputExpression.getOutputColumn()]; - + int batchSize = batch.size; - + if (batchSize == 0) { return; } - + [] vector = inputVector.vector; - + if (inputVector.isRepeating) { if (inputVector.noNulls || !inputVector.isNull[0]) { iterateRepeatingNoNullsWithAggregationSelection( aggregationBufferSets, aggregateIndex, vector[0], batchSize); } - } + } else if (!batch.selectedInUse && inputVector.noNulls) { iterateNoSelectionNoNullsWithAggregationSelection( aggregationBufferSets, aggregateIndex, vector, batchSize); @@ -160,46 +164,46 @@ public class extends VectorAggregateExpression { } else { iterateSelectionHasNullsWithAggregationSelection( - aggregationBufferSets, aggregateIndex, vector, batchSize, + aggregationBufferSets, aggregateIndex, vector, batchSize, inputVector.isNull, batch.selected); } - + } - + private void iterateRepeatingNoNullsWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, - int aggregateIndex, - double value, + int aggregateIndex, + double value, int batchSize) { for (int i=0; i 1) { double t = myagg.count*value - myagg.sum; myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); } } } - + private void iterateSelectionHasNullsWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, - int aggregateIndex, - [] vector, + int aggregateIndex, + [] vector, int batchSize, - boolean[] isNull, + boolean[] isNull, int[] selected) { - + for (int j=0; j< batchSize; ++j) { Aggregation myagg = getCurrentAggregationBuffer( - aggregationBufferSets, + aggregationBufferSets, aggregateIndex, j); int i = selected[j]; @@ -220,14 +224,14 @@ public class extends VectorAggregateExpression { private void iterateSelectionNoNullsWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, - int aggregateIndex, - [] vector, - int batchSize, + int aggregateIndex, + [] vector, + int batchSize, int[] selected) { for (int i=0; i< batchSize; ++i) { Aggregation myagg = getCurrentAggregationBuffer( - 
aggregationBufferSets, + aggregationBufferSets, aggregateIndex, i); double value = vector[selected[i]]; @@ -245,20 +249,20 @@ public class extends VectorAggregateExpression { private void iterateNoSelectionHasNullsWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, - int aggregateIndex, - [] vector, + int aggregateIndex, + [] vector, int batchSize, boolean[] isNull) { - + for(int i=0;i extends VectorAggregateExpression { private void iterateNoSelectionNoNullsWithAggregationSelection( VectorAggregationBufferRow[] aggregationBufferSets, - int aggregateIndex, - [] vector, + int aggregateIndex, + [] vector, int batchSize) { for (int i=0; i extends VectorAggregateExpression { } @Override - public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch) + public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch) throws HiveException { - + inputExpression.evaluate(batch); - + inputVector = ()batch. cols[this.inputExpression.getOutputColumn()]; - + int batchSize = batch.size; - + if (batchSize == 0) { return; } - + Aggregation myagg = (Aggregation)agg; [] vector = inputVector.vector; - + if (inputVector.isRepeating) { if (inputVector.noNulls) { iterateRepeatingNoNulls(myagg, vector[0], batchSize); } - } + } else if (!batch.selectedInUse && inputVector.noNulls) { iterateNoSelectionNoNulls(myagg, vector, batchSize); } @@ -333,40 +337,40 @@ public class extends VectorAggregateExpression { } private void iterateRepeatingNoNulls( - Aggregation myagg, - double value, + Aggregation myagg, + double value, int batchSize) { - + if (myagg.isNull) { myagg.init (); } - + // TODO: conjure a formula w/o iterating // - + myagg.sum += value; - myagg.count += 1; + myagg.count += 1; if(myagg.count > 1) { double t = myagg.count*value - myagg.sum; myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); } - + // We pulled out i=0 so we can remove the count > 1 check in the loop for (int i=1; i[] vector, + Aggregation myagg, + [] vector, int batchSize, - boolean[] isNull, + boolean[] isNull, int[] selected) { - + for (int j=0; j< batchSize; ++j) { int i = selected[j]; if (!isNull[i]) { @@ -385,11 +389,11 @@ public class extends VectorAggregateExpression { } private void iterateSelectionNoNulls( - Aggregation myagg, - [] vector, - int batchSize, + Aggregation myagg, + [] vector, + int batchSize, int[] selected) { - + if (myagg.isNull) { myagg.init (); } @@ -401,7 +405,7 @@ public class extends VectorAggregateExpression { double t = myagg.count*value - myagg.sum; myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); } - + // i=0 was pulled out to remove the count > 1 check in the loop // for (int i=1; i< batchSize; ++i) { @@ -414,16 +418,16 @@ public class extends VectorAggregateExpression { } private void iterateNoSelectionHasNulls( - Aggregation myagg, - [] vector, + Aggregation myagg, + [] vector, int batchSize, boolean[] isNull) { - + for(int i=0;i extends VectorAggregateExpression { } private void iterateNoSelectionNoNulls( - Aggregation myagg, - [] vector, + Aggregation myagg, + [] vector, int batchSize) { - + if (myagg.isNull) { myagg.init (); } @@ -447,12 +451,12 @@ public class extends VectorAggregateExpression { double value = vector[0]; myagg.sum += value; myagg.count += 1; - + if(myagg.count > 1) { double t = myagg.count*value - myagg.sum; myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); } - + // i=0 was pulled out to remove count > 1 check for (int i=1; i extends VectorAggregateExpression { public void 
setInputExpression(VectorExpression inputExpression) { this.inputExpression = inputExpression; - } + } } diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarDecimal.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarDecimal.txt new file mode 100644 index 0000000..e626161 --- /dev/null +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFVarDecimal.txt @@ -0,0 +1,472 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.hadoop.hive.common.type.Decimal128; +import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression; +import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.plan.AggregationDesc; +import org.apache.hadoop.hive.ql.util.JavaDataModel; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; + +/** +* . Vectorized implementation for VARIANCE aggregates. +*/ +@Description(name = "", + value = "") +public class extends VectorAggregateExpression { + + private static final long serialVersionUID = 1L; + + /** + /* class for storing the current aggregate value. 
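Both the long/double VAR templates and the new VectorUDAFVarDecimal.txt accumulate variance with the same streaming recurrence: after folding a value into sum and count, t = count*value - sum and variance += t*t / (count*(count-1)). The decimal template computes t with a scratch Decimal128 (destructive add/multiply/subtract, then doubleValue()), but the running variance itself stays a double. A plain-double sketch of the recurrence with illustrative names, checked against the var_pop/var_samp expectations that appear later in TestVectorGroupByOperator:

    // Plain-double sketch of the streaming variance update used by the VAR/STD
    // templates: after folding in each value, t = count*value - sum and
    // m2 += t*t / (count*(count-1)).  Population variance is then m2 / count.
    public class StreamingVarianceSketch {

        private double sum;
        private long count;
        private double m2;   // running sum of squared deviations ("variance" field in the templates)

        void add(double value) {
            sum += value;
            count += 1;
            if (count > 1) {
                double t = count * value - sum;
                m2 += (t * t) / ((double) count * (count - 1));
            }
        }

        double varPop()  { return m2 / count; }
        double varSamp() { return m2 / (count - 1); }

        public static void main(String[] args) {
            StreamingVarianceSketch v = new StreamingVarianceSketch();
            for (double x : new double[] {13, 5, 7, 19}) {
                v.add(x);
            }
            // Matches the expected values in TestVectorGroupByOperator:
            // var_pop = 30.0, var_samp = 40.0
            System.out.println(v.varPop());
            System.out.println(v.varSamp());
        }
    }

The generated VarPop/VarSamp/StdPop/StdSamp variants only differ in the final expression they plug into this template (dividing by count or count-1, with or without a square root), as listed in the GenVectorCode table.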
+ */ + private static final class Aggregation implements AggregationBuffer { + + private static final long serialVersionUID = 1L; + + transient private final Decimal128 sum; + transient private long count; + transient private double variance; + + /** + * Value is explicitly (re)initialized in reset() + */ + transient private boolean isNull = true; + + public Aggregation() { + sum = new Decimal128(); + } + + public void init() { + isNull = false; + sum.zeroClear(); + count = 0; + variance = 0f; + } + + @Override + public int getVariableSize() { + throw new UnsupportedOperationException(); + } + + public void updateValueWithCheckAndInit(Decimal128 scratch, Decimal128 value) { + if (this.isNull) { + this.init(); + } + this.sum.addDestructive(value, value.getScale()); + this.count += 1; + if(this.count > 1) { + scratch.update(count); + scratch.multiplyDestructive(value, value.getScale()); + scratch.subtractDestructive(sum, sum.getScale()); + double t = scratch.doubleValue(); + this.variance += (t*t) / ((double)this.count*(this.count-1)); + } + } + + public void updateValueNoCheck(Decimal128 scratch, Decimal128 value) { + this.sum.addDestructive(value, value.getScale()); + this.count += 1; + scratch.update(count); + scratch.multiplyDestructive(value, value.getScale()); + scratch.subtractDestructive(sum, sum.getScale()); + double t = scratch.doubleValue(); + this.variance += (t*t) / ((double)this.count*(this.count-1)); + } + + } + + private VectorExpression inputExpression; + transient private LongWritable resultCount; + transient private DoubleWritable resultSum; + transient private DoubleWritable resultVariance; + transient private Object[] partialResult; + + transient private ObjectInspector soi; + + transient private final Decimal128 scratchDecimal; + + + public (VectorExpression inputExpression) { + this(); + this.inputExpression = inputExpression; + } + + public () { + super(); + partialResult = new Object[3]; + resultCount = new LongWritable(); + resultSum = new DoubleWritable(); + resultVariance = new DoubleWritable(); + partialResult[0] = resultCount; + partialResult[1] = resultSum; + partialResult[2] = resultVariance; + initPartialResultInspector(); + scratchDecimal = new Decimal128(); + } + + private void initPartialResultInspector() { + List foi = new ArrayList(); + foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + + List fname = new ArrayList(); + fname.add("count"); + fname.add("sum"); + fname.add("variance"); + + soi = ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi); + } + + private Aggregation getCurrentAggregationBuffer( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + int row) { + VectorAggregationBufferRow mySet = aggregationBufferSets[row]; + Aggregation myagg = (Aggregation) mySet.getAggregationBuffer(aggregateIndex); + return myagg; + } + + + @Override + public void aggregateInputSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + VectorizedRowBatch batch) throws HiveException { + + inputExpression.evaluate(batch); + + DecimalColumnVector inputVector = (DecimalColumnVector)batch. 
+ cols[this.inputExpression.getOutputColumn()]; + + int batchSize = batch.size; + + if (batchSize == 0) { + return; + } + + Decimal128[] vector = inputVector.vector; + + if (inputVector.isRepeating) { + if (inputVector.noNulls || !inputVector.isNull[0]) { + iterateRepeatingNoNullsWithAggregationSelection( + aggregationBufferSets, aggregateIndex, vector[0], batchSize); + } + } + else if (!batch.selectedInUse && inputVector.noNulls) { + iterateNoSelectionNoNullsWithAggregationSelection( + aggregationBufferSets, aggregateIndex, vector, batchSize); + } + else if (!batch.selectedInUse) { + iterateNoSelectionHasNullsWithAggregationSelection( + aggregationBufferSets, aggregateIndex, vector, batchSize, inputVector.isNull); + } + else if (inputVector.noNulls){ + iterateSelectionNoNullsWithAggregationSelection( + aggregationBufferSets, aggregateIndex, vector, batchSize, batch.selected); + } + else { + iterateSelectionHasNullsWithAggregationSelection( + aggregationBufferSets, aggregateIndex, vector, batchSize, + inputVector.isNull, batch.selected); + } + + } + + private void iterateRepeatingNoNullsWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + Decimal128 value, + int batchSize) { + + for (int i=0; i 1 check in the loop + for (int i=1; i 1 check in the loop + // + for (int i=1; i< batchSize; ++i) { + value = vector[selected[i]]; + myagg.updateValueNoCheck(scratchDecimal, value); + } + } + + private void iterateNoSelectionHasNulls( + Aggregation myagg, + Decimal128[] vector, + int batchSize, + boolean[] isNull) { + + for(int i=0;i 1 check + for (int i=1; i { + protected void assignDecimal(HiveDecimal value, int index) { + outCol.vector[index].update(value.unscaledValue(), (byte) value.scale()); + } + + protected void assignDecimal(Decimal128 value, int index) { + outCol.vector[index].update(value); + } + protected void assignDecimal(HiveDecimalWritable hdw, int index) { + byte[] internalStorage = hdw.getInternalStorage(); + int scale = hdw.getScale(); + + outCol.vector[index].fastUpdateFromInternalStorage(internalStorage, (short)scale); + } + } + public static VectorColumnAssign[] buildAssigners(VectorizedRowBatch outputBatch) throws HiveException { @@ -175,6 +195,14 @@ protected void copyValue(BytesColumnVector src, int srcIndex, int destIndex) { } }.init(outputBatch, (BytesColumnVector) cv); } + else if (cv instanceof DecimalColumnVector) { + vca[i] = new VectorDecimalColumnAssign() { + @Override + protected void copyValue(DecimalColumnVector src, int srcIndex, int destIndex) { + assignDecimal(src.vector[srcIndex], destIndex); + } + }; + } else { throw new HiveException("Unimplemented vector column type: " + cv.getClass().getName()); } @@ -336,6 +364,27 @@ public void assignObjectValue(Object val, int destIndex) throws HiveException { poi.getPrimitiveCategory()); } } + else if (destCol instanceof DecimalColumnVector) { + switch(poi.getPrimitiveCategory()) { + case DECIMAL: + outVCA = new VectorDecimalColumnAssign() { + @Override + public void assignObjectValue(Object val, int destIndex) throws HiveException { + if (val == null) { + assignNull(destIndex); + } + else { + HiveDecimalWritable hdw = (HiveDecimalWritable) val; + assignDecimal(hdw, destIndex); + } + } + }.init(outputBatch, (DecimalColumnVector) destCol); + break; + default: + throw new HiveException("Incompatible Decimal vector column and primitive category " + + poi.getPrimitiveCategory()); + } + } else { throw new HiveException("Unknown vector column type " + 
destCol.getClass().getName()); } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorExpressionDescriptor.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorExpressionDescriptor.java index d9855c1..4de9f9f 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorExpressionDescriptor.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorExpressionDescriptor.java @@ -49,7 +49,7 @@ public int getValue() { public static ArgumentType getType(String inType) { String type = VectorizationContext.getNormalizedTypeName(inType); - if (VectorizationContext.decimalTypePattern.matcher(type.toLowerCase()).matches()) { + if (VectorizationContext.decimalTypePattern.matcher(type).matches()) { type = "decimal"; } return valueOf(type.toUpperCase()); diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapper.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapper.java index f083d86..a2a7266 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapper.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapper.java @@ -20,6 +20,7 @@ import java.util.Arrays; +import org.apache.hadoop.hive.common.type.Decimal128; import org.apache.hadoop.hive.ql.exec.KeyWrapper; import org.apache.hadoop.hive.ql.exec.vector.expressions.StringExpr; import org.apache.hadoop.hive.ql.metadata.HiveException; @@ -42,16 +43,23 @@ private int[] byteStarts; private int[] byteLengths; + private Decimal128[] decimalValues; + private boolean[] isNull; private int hashcode; - public VectorHashKeyWrapper(int longValuesCount, int doubleValuesCount, int byteValuesCount) { + public VectorHashKeyWrapper(int longValuesCount, int doubleValuesCount, + int byteValuesCount, int decimalValuesCount) { longValues = new long[longValuesCount]; doubleValues = new double[doubleValuesCount]; + decimalValues = new Decimal128[decimalValuesCount]; + for(int i = 0; i < decimalValuesCount; ++i) { + decimalValues[i] = new Decimal128(); + } byteValues = new byte[byteValuesCount][]; byteStarts = new int[byteValuesCount]; byteLengths = new int[byteValuesCount]; - isNull = new boolean[longValuesCount + doubleValuesCount + byteValuesCount]; + isNull = new boolean[longValuesCount + doubleValuesCount + byteValuesCount + decimalValuesCount]; } private VectorHashKeyWrapper() { @@ -66,6 +74,7 @@ public void getNewKey(Object row, ObjectInspector rowInspector) throws HiveExcep public void setHashKey() { hashcode = Arrays.hashCode(longValues) ^ Arrays.hashCode(doubleValues) ^ + Arrays.hashCode(decimalValues) ^ Arrays.hashCode(isNull); // This code, with branches and all, is not executed if there are no string keys @@ -104,6 +113,7 @@ public boolean equals(Object that) { return hashcode == keyThat.hashcode && Arrays.equals(longValues, keyThat.longValues) && Arrays.equals(doubleValues, keyThat.doubleValues) && + Arrays.equals(decimalValues, keyThat.decimalValues) && Arrays.equals(isNull, keyThat.isNull) && byteValues.length == keyThat.byteValues.length && (0 == byteValues.length || bytesEquals(keyThat)); @@ -137,6 +147,12 @@ protected Object clone() { clone.doubleValues = doubleValues.clone(); clone.isNull = isNull.clone(); + // Decimal128 requires deep clone + clone.decimalValues = new Decimal128[decimalValues.length]; + for(int i = 0; i < decimalValues.length; ++i) { + clone.decimalValues[i] = new Decimal128().update(decimalValues[i]); + } + clone.byteValues = new byte[byteValues.length][]; clone.byteStarts = new int[byteValues.length]; 
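The clone() change above deep-copies every Decimal128 key. This matters because the pre-allocated key wrappers are mutable and reused across batches: decimalValues.clone() alone would copy object references, and the next assignDecimal(...) on the reused wrapper would mutate a key that is already stored in the group-by hash map. A small self-contained illustration of the hazard, using a stand-in MutableDecimal class rather than the real Decimal128 (only its update(...) copy semantics are assumed from the hunk above):

    import java.util.HashMap;
    import java.util.Map;

    // Shows why VectorHashKeyWrapper.clone() must deep-copy its decimal keys.
    // MutableDecimal is a stand-in for Decimal128: a mutable value object that
    // the batch code reuses in place between rows.
    public class MutableKeyCloneSketch {

        static final class MutableDecimal {
            long unscaled;
            MutableDecimal(long unscaled) { this.unscaled = unscaled; }
            void update(MutableDecimal other) { this.unscaled = other.unscaled; }
            @Override public boolean equals(Object o) {
                return o instanceof MutableDecimal && ((MutableDecimal) o).unscaled == unscaled;
            }
            @Override public int hashCode() { return Long.hashCode(unscaled); }
        }

        public static void main(String[] args) {
            Map<MutableDecimal, Long> counts = new HashMap<>();

            // Shallow use: the reused key object itself goes into the map ...
            MutableDecimal reusedKey = new MutableDecimal(1);
            counts.put(reusedKey, 10L);
            // ... and is then mutated for the next row, corrupting the stored key.
            reusedKey.update(new MutableDecimal(2));
            System.out.println(counts.get(new MutableDecimal(1)));  // null - entry lost

            counts.clear();
            reusedKey = new MutableDecimal(1);

            // Deep copy, as clone() does with new Decimal128().update(decimalValues[i]):
            MutableDecimal frozenKey = new MutableDecimal(0);
            frozenKey.update(reusedKey);
            counts.put(frozenKey, 10L);
            reusedKey.update(new MutableDecimal(2));
            System.out.println(counts.get(new MutableDecimal(1)));  // 10 - entry survives
        }
    }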
clone.byteLengths = byteLengths.clone(); @@ -201,13 +217,22 @@ public void assignNullString(int index) { isNull[longValues.length + doubleValues.length + index] = true; } + public void assignDecimal(int index, Decimal128 value) { + decimalValues[index].update(value); + } + + public void assignNullDecimal(int index) { + isNull[longValues.length + doubleValues.length + byteValues.length + index] = true; + } + @Override public String toString() { - return String.format("%d[%s] %d[%s] %d[%s]", + return String.format("%d[%s] %d[%s] %d[%s] %d[%s]", longValues.length, Arrays.toString(longValues), doubleValues.length, Arrays.toString(doubleValues), - byteValues.length, Arrays.toString(byteValues)); + byteValues.length, Arrays.toString(byteValues), + decimalValues.length, Arrays.toString(decimalValues)); } public boolean getIsLongNull(int i) { @@ -222,7 +247,7 @@ public boolean getIsBytesNull(int i) { return isNull[longValues.length + doubleValues.length + i]; } - + public long getLongValue(int i) { return longValues[i]; } @@ -252,6 +277,12 @@ public int getVariableSize() { return variableSize; } + public boolean getIsDecimalNull(int i) { + return isNull[longValues.length + doubleValues.length + byteValues.length + i]; + } + public Decimal128 getDecimal(int i) { + return decimalValues[i]; + } } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapperBatch.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapperBatch.java index e978110..581046e 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapperBatch.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapperBatch.java @@ -40,9 +40,40 @@ private int longIndex; private int doubleIndex; private int stringIndex; + private int decimalIndex; + + private static final int INDEX_UNUSED = -1; + + private void resetIndices() { + this.longIndex = this.doubleIndex = this.stringIndex = this.decimalIndex = INDEX_UNUSED; + } + public void setLong(int index) { + resetIndices(); + this.longIndex= index; + } + + public void setDouble(int index) { + resetIndices(); + this.doubleIndex = index; + } + + public void setString(int index) { + resetIndices(); + this.stringIndex = index; + } + + public void setDecimal(int index) { + resetIndices(); + this.decimalIndex = index; + } } /** + * Number of object references in 'this' (for size computation) + */ + private static final int MODEL_REFERENCES_COUNT = 7; + + /** * The key expressions that require evaluation and output the primitive values for each key. */ private VectorExpression[] keyExpressions; @@ -63,6 +94,11 @@ private int[] stringIndices; /** + * indices of decimal primitive keys. + */ + private int[] decimalIndices; + + /** * Pre-allocated batch size vector of keys wrappers. * N.B. these keys are **mutable** and should never be used in a HashMap. * Always clone the key wrapper to obtain an immutable keywrapper suitable @@ -175,6 +211,28 @@ public void evaluateBatch(VectorizedRowBatch batch) throws HiveException { columnVector.noNulls, columnVector.isRepeating, batch.selectedInUse)); } } + for(int i=0;i= 0) { + return kw.getIsDecimalNull(klh.decimalIndex)? 
null : + keyOutputWriter.writeValue( + kw.getDecimal(klh.decimalIndex)); + } + else { throw new HiveException(String.format( "Internal inconsistent KeyLookupHelper at index [%d]:%d %d %d", i, klh.longIndex, klh.doubleIndex, klh.stringIndex)); diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorMapJoinOperator.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorMapJoinOperator.java index 036f080..2466a3b 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorMapJoinOperator.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorMapJoinOperator.java @@ -63,8 +63,8 @@ private int tagLen; private VectorExpression[] keyExpressions; - private VectorHashKeyWrapperBatch keyWrapperBatch; - private VectorExpressionWriter[] keyOutputWriters; + private transient VectorHashKeyWrapperBatch keyWrapperBatch; + private transient VectorExpressionWriter[] keyOutputWriters; private VectorExpression[] bigTableFilterExpressions; private VectorExpression[] bigTableValueExpressions; @@ -111,7 +111,6 @@ public VectorMapJoinOperator (VectorizationContext vContext, OperatorDesc conf) List keyDesc = desc.getKeys().get(posBigTable); keyExpressions = vContext.getVectorExpressions(keyDesc); - keyOutputWriters = VectorExpressionWriterFactory.getExpressionWriters(keyDesc); // We're only going to evaluate the big table vectorized expressions, Map> exprs = desc.getExprs(); @@ -135,6 +134,8 @@ public VectorMapJoinOperator (VectorizationContext vContext, OperatorDesc conf) public void initializeOp(Configuration hconf) throws HiveException { super.initializeOp(hconf); + List keyDesc = conf.getKeys().get(posBigTable); + keyOutputWriters = VectorExpressionWriterFactory.getExpressionWriters(keyDesc); vrbCtx = new VectorizedRowBatchCtx(); vrbCtx.init(hconf, this.fileKey, (StructObjectInspector) this.outputObjInspector); diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java index f69bfc0..153aca3 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java @@ -40,28 +40,37 @@ import org.apache.hadoop.hive.ql.exec.FunctionRegistry; import org.apache.hadoop.hive.ql.exec.UDF; import org.apache.hadoop.hive.ql.exec.vector.TimestampUtils; +import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor.ArgumentType; import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor.InputExpressionType; import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor.Mode; import org.apache.hadoop.hive.ql.exec.vector.expressions.*; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorUDAFAvgDecimal; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorUDAFCount; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorUDAFCountStar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorUDAFSumDecimal; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFAvgDouble; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFAvgLong; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMaxDecimal; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMaxDouble; import 
org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMaxLong; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMaxString; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMinDecimal; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMinDouble; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMinLong; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFMinString; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFStdPopDecimal; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFStdPopDouble; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFStdPopLong; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFStdSampDecimal; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFStdSampDouble; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFStdSampLong; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFSumDouble; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFSumLong; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFVarPopDecimal; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFVarPopDouble; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFVarPopLong; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFVarSampDecimal; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFVarSampDouble; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.gen.VectorUDAFVarSampLong; import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.CastLongToBooleanViaLongToLong; @@ -114,7 +123,8 @@ private final Map columnMap; private final int firstOutputColumnIndex; - public static final Pattern decimalTypePattern = Pattern.compile("decimal.*"); + public static final Pattern decimalTypePattern = Pattern.compile("decimal.*", + Pattern.CASE_INSENSITIVE); //Map column number to type private final OutputColumnManager ocm; @@ -869,7 +879,7 @@ private VectorExpression getInExpression(List childExpr, Mode mode ExprNodeDesc colExpr = childExpr.get(0); TypeInfo colTypeInfo = colExpr.getTypeInfo(); String colType = colExpr.getTypeString(); - + // prepare arguments for createVectorExpression List childrenForInList = foldConstantsForUnaryExprs(childExpr.subList(1, childExpr.size())); @@ -1111,7 +1121,7 @@ private VectorExpression getBetweenFilterExpression(List childExpr String colType = colExpr.getTypeString(); // prepare arguments for createVectorExpression - List childrenAfterNot = foldConstantsForUnaryExprs(childExpr.subList(1, 4)); + List childrenAfterNot = foldConstantsForUnaryExprs(childExpr.subList(1, 4));; // determine class Class cl = null; @@ -1241,6 +1251,10 @@ public static boolean isIntFamily(String resultType) { || resultType.equalsIgnoreCase("long"); } + public static boolean isDecimalFamily(String colType) { + return decimalTypePattern.matcher(colType).matches(); + } + private Object getScalarValue(ExprNodeConstantDesc constDesc) throws HiveException { if (constDesc.getTypeString().equalsIgnoreCase("String")) { @@ -1353,14 +1367,13 @@ private long evaluateCastToTimestamp(ExprNodeDesc expr) throws HiveException { 
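The decimalTypePattern change above compiles the regex with Pattern.CASE_INSENSITIVE, which is why the callers touched elsewhere in this patch (VectorExpressionDescriptor, VectorizedRowBatchCtx, getNormalizedTypeName) can drop their toLowerCase() calls; getAggregatorExpression further down then folds any decimal(p,s) input type into the single "Decimal" entry used by aggregatesDefinition. A quick standalone check of the behavior difference, using plain java.util.regex with the pattern string taken from the hunk:

    import java.util.regex.Pattern;

    // Demonstrates why Pattern.CASE_INSENSITIVE lets callers drop toLowerCase().
    public class DecimalTypePatternSketch {

        private static final Pattern OLD = Pattern.compile("decimal.*");
        private static final Pattern NEW = Pattern.compile("decimal.*", Pattern.CASE_INSENSITIVE);

        public static void main(String[] args) {
            String[] types = { "decimal(20,10)", "DECIMAL(23,14)", "Decimal", "double" };
            for (String t : types) {
                System.out.printf("%-15s old=%-5b new=%b%n",
                    t, OLD.matcher(t).matches(), NEW.matcher(t).matches());
            }
            // The old pattern matches only the lower-case spelling; the new one
            // matches every decimal spelling, and neither matches "double".
        }
    }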
} } - static String getNormalizedTypeName(String colType) { + static String getNormalizedTypeName(String colType){ String normalizedType = null; if (colType.equalsIgnoreCase("Double") || colType.equalsIgnoreCase("Float")) { normalizedType = "Double"; } else if (colType.equalsIgnoreCase("String")) { normalizedType = "String"; - } else if (decimalTypePattern.matcher(colType.toLowerCase()).matches()) { - + } else if (decimalTypePattern.matcher(colType).matches()) { //Return the decimal type as is, it includes scale and precision. normalizedType = colType; } else { @@ -1373,31 +1386,43 @@ static String getNormalizedTypeName(String colType) { {"min", "Long", VectorUDAFMinLong.class}, {"min", "Double", VectorUDAFMinDouble.class}, {"min", "String", VectorUDAFMinString.class}, + {"min", "Decimal",VectorUDAFMinDecimal.class}, {"max", "Long", VectorUDAFMaxLong.class}, {"max", "Double", VectorUDAFMaxDouble.class}, {"max", "String", VectorUDAFMaxString.class}, + {"max", "Decimal",VectorUDAFMaxDecimal.class}, {"count", null, VectorUDAFCountStar.class}, {"count", "Long", VectorUDAFCount.class}, {"count", "Double", VectorUDAFCount.class}, {"count", "String", VectorUDAFCount.class}, + {"count", "Decimal",VectorUDAFCount.class}, {"sum", "Long", VectorUDAFSumLong.class}, {"sum", "Double", VectorUDAFSumDouble.class}, + {"sum", "Decimal",VectorUDAFSumDecimal.class}, {"avg", "Long", VectorUDAFAvgLong.class}, {"avg", "Double", VectorUDAFAvgDouble.class}, + {"avg", "Decimal",VectorUDAFAvgDecimal.class}, {"variance", "Long", VectorUDAFVarPopLong.class}, {"var_pop", "Long", VectorUDAFVarPopLong.class}, {"variance", "Double", VectorUDAFVarPopDouble.class}, {"var_pop", "Double", VectorUDAFVarPopDouble.class}, + {"variance", "Decimal",VectorUDAFVarPopDecimal.class}, + {"var_pop", "Decimal",VectorUDAFVarPopDecimal.class}, {"var_samp", "Long", VectorUDAFVarSampLong.class}, {"var_samp" , "Double", VectorUDAFVarSampDouble.class}, + {"var_samp" , "Decimal",VectorUDAFVarSampDecimal.class}, {"std", "Long", VectorUDAFStdPopLong.class}, {"stddev", "Long", VectorUDAFStdPopLong.class}, {"stddev_pop","Long", VectorUDAFStdPopLong.class}, {"std", "Double", VectorUDAFStdPopDouble.class}, {"stddev", "Double", VectorUDAFStdPopDouble.class}, {"stddev_pop","Double", VectorUDAFStdPopDouble.class}, + {"std", "Decimal",VectorUDAFStdPopDecimal.class}, + {"stddev", "Decimal",VectorUDAFStdPopDecimal.class}, + {"stddev_pop","Decimal",VectorUDAFStdPopDecimal.class}, {"stddev_samp","Long", VectorUDAFStdSampLong.class}, {"stddev_samp","Double",VectorUDAFStdSampDouble.class}, + {"stddev_samp","Decimal",VectorUDAFStdSampDecimal.class}, }; public VectorAggregateExpression getAggregatorExpression(AggregationDesc desc) @@ -1417,6 +1442,9 @@ public VectorAggregateExpression getAggregatorExpression(AggregationDesc desc) if (paramDescList.size() > 0) { ExprNodeDesc inputExpr = paramDescList.get(0); inputType = getNormalizedTypeName(inputExpr.getTypeString()); + if (decimalTypePattern.matcher(inputType).matches()) { + inputType = "Decimal"; + } } for (Object[] aggDef : aggregatesDefinition) { diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedRowBatchCtx.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedRowBatchCtx.java index d409d44..a7efa41 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedRowBatchCtx.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedRowBatchCtx.java @@ -394,7 +394,7 @@ private ColumnVector allocateColumnVector(String type, int defaultSize) { return new 
DoubleColumnVector(defaultSize); } else if (type.equalsIgnoreCase("string")) { return new BytesColumnVector(defaultSize); - } else if (VectorizationContext.decimalTypePattern.matcher(type.toLowerCase()).matches()){ + } else if (VectorizationContext.decimalTypePattern.matcher(type).matches()){ int [] precisionScale = getScalePrecisionFromDecimalType(type); return new DecimalColumnVector(defaultSize, precisionScale[0], precisionScale[1]); } else { diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFAvgDecimal.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFAvgDecimal.java new file mode 100644 index 0000000..6f593f9 --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFAvgDecimal.java @@ -0,0 +1,516 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.hadoop.hive.common.type.Decimal128; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; +import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression; +import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.plan.AggregationDesc; +import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage; +import org.apache.hadoop.hive.ql.util.JavaDataModel; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hive.common.util.Decimal128FastBuffer; + +/** + * Generated from template VectorUDAFAvg.txt. + */ +@Description(name = "avg", + value = "_FUNC_(AVG) - Returns the average value of expr (vectorized, type: decimal)") +public class VectorUDAFAvgDecimal extends VectorAggregateExpression { + + private static final long serialVersionUID = 1L; + + /** class for storing the current aggregate value. 
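The VectorizedRowBatchCtx change above sizes the new DecimalColumnVector from the precision and scale carried in the type name (via getScalePrecisionFromDecimalType). A sketch of that parsing step, under the assumption that the type string has the decimal(precision,scale) shape accepted by decimalTypePattern; parsePrecisionScale is an illustrative name, not the Hive helper:

    import java.util.Arrays;
    import java.util.regex.Matcher;
    import java.util.regex.Pattern;

    // Sketch of extracting precision/scale from a Hive decimal type name, the piece
    // allocateColumnVector needs before it can size a DecimalColumnVector.
    public class DecimalTypeParseSketch {

        private static final Pattern DECIMAL_WITH_ARGS =
            Pattern.compile("decimal\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*\\)", Pattern.CASE_INSENSITIVE);

        // Returns {precision, scale}.
        static int[] parsePrecisionScale(String type) {
            Matcher m = DECIMAL_WITH_ARGS.matcher(type.trim());
            if (!m.matches()) {
                // The real helper also copes with a bare "decimal" by applying defaults;
                // that case is left out here to avoid guessing the default values.
                throw new IllegalArgumentException("not a parameterized decimal type: " + type);
            }
            return new int[] { Integer.parseInt(m.group(1)), Integer.parseInt(m.group(2)) };
        }

        public static void main(String[] args) {
            System.out.println(Arrays.toString(parsePrecisionScale("decimal(20,10)"))); // [20, 10]
            System.out.println(Arrays.toString(parsePrecisionScale("DECIMAL(23,14)"))); // [23, 14]
        }
    }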
*/ + static class Aggregation implements AggregationBuffer { + + private static final long serialVersionUID = 1L; + + transient private final Decimal128 sum = new Decimal128(); + transient private long count; + transient private boolean isNull; + + public void sumValueWithCheck(Decimal128 value, short scale) { + if (isNull) { + sum.update(value); + sum.changeScaleDestructive(scale); + count = 1; + isNull = false; + } else { + sum.addDestructive(value, scale); + count++; + } + } + + public void sumValueNoCheck(Decimal128 value, short scale) { + sum.addDestructive(value, scale); + count++; + } + + + @Override + public int getVariableSize() { + throw new UnsupportedOperationException(); + } + } + + private VectorExpression inputExpression; + transient private Object[] partialResult; + transient private LongWritable resultCount; + transient private HiveDecimalWritable resultSum; + transient private StructObjectInspector soi; + + transient private final Decimal128FastBuffer scratch; + + /** + * The scale of the SUM in the partial output + */ + private short sumScale; + + /** + * The precision of the SUM in the partial output + */ + private short sumPrecision; + + /** + * the scale of the input expression + */ + private short inputScale; + + /** + * the precision of the input expression + */ + private short inputPrecision; + + /** + * A value used as scratch to avoid allocating at runtime. + * Needed by computations like vector[0] * batchSize + */ + transient private Decimal128 scratchDecimal = new Decimal128(); + + public VectorUDAFAvgDecimal(VectorExpression inputExpression) { + this(); + this.inputExpression = inputExpression; + } + + public VectorUDAFAvgDecimal() { + super(); + partialResult = new Object[2]; + resultCount = new LongWritable(); + resultSum = new HiveDecimalWritable(); + partialResult[0] = resultCount; + partialResult[1] = resultSum; + scratch = new Decimal128FastBuffer(); + + } + + private void initPartialResultInspector() { + // the output type of the vectorized partial aggregate must match the + // expected type for the row-mode aggregation + // For decimal, the type is "same number of integer digits and 4 more decimal digits" + + DecimalTypeInfo dtiSum = GenericUDAFAverage.deriveSumTypeInfo(inputScale, inputPrecision); + this.sumScale = (short) dtiSum.scale(); + this.sumPrecision = (short) dtiSum.precision(); + + List foi = new ArrayList(); + foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + foi.add(PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(dtiSum)); + List fname = new ArrayList(); + fname.add("count"); + fname.add("sum"); + soi = ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi); + } + + private Aggregation getCurrentAggregationBuffer( + VectorAggregationBufferRow[] aggregationBufferSets, + int bufferIndex, + int row) { + VectorAggregationBufferRow mySet = aggregationBufferSets[row]; + Aggregation myagg = (Aggregation) mySet.getAggregationBuffer(bufferIndex); + return myagg; + } + + @Override + public void aggregateInputSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int bufferIndex, + VectorizedRowBatch batch) throws HiveException { + + int batchSize = batch.size; + + if (batchSize == 0) { + return; + } + + inputExpression.evaluate(batch); + + DecimalColumnVector inputVector = ( DecimalColumnVector)batch. 
+ cols[this.inputExpression.getOutputColumn()]; + Decimal128[] vector = inputVector.vector; + + if (inputVector.noNulls) { + if (inputVector.isRepeating) { + iterateNoNullsRepeatingWithAggregationSelection( + aggregationBufferSets, bufferIndex, + vector[0], batchSize); + } else { + if (batch.selectedInUse) { + iterateNoNullsSelectionWithAggregationSelection( + aggregationBufferSets, bufferIndex, + vector, batch.selected, batchSize); + } else { + iterateNoNullsWithAggregationSelection( + aggregationBufferSets, bufferIndex, + vector, batchSize); + } + } + } else { + if (inputVector.isRepeating) { + if (batch.selectedInUse) { + iterateHasNullsRepeatingSelectionWithAggregationSelection( + aggregationBufferSets, bufferIndex, + vector[0], batchSize, batch.selected, inputVector.isNull); + } else { + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, bufferIndex, + vector[0], batchSize, inputVector.isNull); + } + } else { + if (batch.selectedInUse) { + iterateHasNullsSelectionWithAggregationSelection( + aggregationBufferSets, bufferIndex, + vector, batchSize, batch.selected, inputVector.isNull); + } else { + iterateHasNullsWithAggregationSelection( + aggregationBufferSets, bufferIndex, + vector, batchSize, inputVector.isNull); + } + } + } + } + + private void iterateNoNullsRepeatingWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int bufferIndex, + Decimal128 value, + int batchSize) { + + for (int i=0; i < batchSize; ++i) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + bufferIndex, + i); + myagg.sumValueWithCheck(value, this.sumScale); + } + } + + private void iterateNoNullsSelectionWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int bufferIndex, + Decimal128[] values, + int[] selection, + int batchSize) { + + for (int i=0; i < batchSize; ++i) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + bufferIndex, + i); + myagg.sumValueWithCheck(values[selection[i]], this.sumScale); + } + } + + private void iterateNoNullsWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int bufferIndex, + Decimal128[] values, + int batchSize) { + for (int i=0; i < batchSize; ++i) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + bufferIndex, + i); + myagg.sumValueWithCheck(values[i], this.sumScale); + } + } + + private void iterateHasNullsRepeatingSelectionWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int bufferIndex, + Decimal128 value, + int batchSize, + int[] selection, + boolean[] isNull) { + + for (int i=0; i < batchSize; ++i) { + if (!isNull[selection[i]]) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + bufferIndex, + i); + myagg.sumValueWithCheck(value, this.sumScale); + } + } + + } + + private void iterateHasNullsRepeatingWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int bufferIndex, + Decimal128 value, + int batchSize, + boolean[] isNull) { + + for (int i=0; i < batchSize; ++i) { + if (!isNull[i]) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + bufferIndex, + i); + myagg.sumValueWithCheck(value, this.sumScale); + } + } + } + + private void iterateHasNullsSelectionWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int bufferIndex, + Decimal128[] values, + int batchSize, + int[] selection, + boolean[] isNull) { + + for (int j=0; j < 
batchSize; ++j) { + int i = selection[j]; + if (!isNull[i]) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + bufferIndex, + j); + myagg.sumValueWithCheck(values[i], this.sumScale); + } + } + } + + private void iterateHasNullsWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int bufferIndex, + Decimal128[] values, + int batchSize, + boolean[] isNull) { + + for (int i=0; i < batchSize; ++i) { + if (!isNull[i]) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + bufferIndex, + i); + myagg.sumValueWithCheck(values[i], this.sumScale); + } + } + } + + + @Override + public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch) + throws HiveException { + + inputExpression.evaluate(batch); + + DecimalColumnVector inputVector = + (DecimalColumnVector)batch.cols[this.inputExpression.getOutputColumn()]; + + int batchSize = batch.size; + + if (batchSize == 0) { + return; + } + + Aggregation myagg = (Aggregation)agg; + + Decimal128[] vector = inputVector.vector; + + if (inputVector.isRepeating) { + if (inputVector.noNulls) { + if (myagg.isNull) { + myagg.isNull = false; + myagg.sum.zeroClear(); + myagg.count = 0; + } + scratchDecimal.update(batchSize); + scratchDecimal.multiplyDestructive(vector[0], vector[0].getScale()); + myagg.sum.update(scratchDecimal); + myagg.count += batchSize; + } + return; + } + + if (!batch.selectedInUse && inputVector.noNulls) { + iterateNoSelectionNoNulls(myagg, vector, batchSize); + } + else if (!batch.selectedInUse) { + iterateNoSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull); + } + else if (inputVector.noNulls){ + iterateSelectionNoNulls(myagg, vector, batchSize, batch.selected); + } + else { + iterateSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull, batch.selected); + } + } + + private void iterateSelectionHasNulls( + Aggregation myagg, + Decimal128[] vector, + int batchSize, + boolean[] isNull, + int[] selected) { + + for (int j=0; j< batchSize; ++j) { + int i = selected[j]; + if (!isNull[i]) { + Decimal128 value = vector[i]; + myagg.sumValueWithCheck(value, this.sumScale); + } + } + } + + private void iterateSelectionNoNulls( + Aggregation myagg, + Decimal128[] vector, + int batchSize, + int[] selected) { + + if (myagg.isNull) { + myagg.isNull = false; + myagg.sum.zeroClear(); + myagg.count = 0; + } + + for (int i=0; i< batchSize; ++i) { + Decimal128 value = vector[selected[i]]; + myagg.sumValueNoCheck(value, this.sumScale); + } + } + + private void iterateNoSelectionHasNulls( + Aggregation myagg, + Decimal128[] vector, + int batchSize, + boolean[] isNull) { + + for(int i=0;i)aggregation); } } + + /** + * The result type has the same number of integer digits and 4 more decimal digits + * This is exposed as static so that the vectorized AVG operator use the same logic + * @param scale + * @param precision + * @return + */ + public static DecimalTypeInfo deriveSumTypeInfo(int scale, int precision) { + int intPart = precision - scale; + scale = Math.min(scale + 4, HiveDecimal.MAX_SCALE - intPart); + return TypeInfoFactory.getDecimalTypeInfo(intPart + scale, scale); + } } diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java index a2b45f8..81a0bb4 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java @@ 
-26,6 +26,8 @@ import java.lang.management.ManagementFactory; import java.lang.management.MemoryMXBean; import java.lang.reflect.Constructor; +import java.math.BigDecimal; +import java.math.BigInteger; import java.sql.Timestamp; import java.util.ArrayList; import java.util.Arrays; @@ -36,6 +38,8 @@ import java.util.Map; import java.util.Set; +import org.apache.hadoop.hive.common.type.Decimal128; +import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.ql.exec.vector.util.FakeCaptureOutputOperator; import org.apache.hadoop.hive.ql.exec.vector.util.FakeVectorRowBatchFromConcat; import org.apache.hadoop.hive.ql.exec.vector.util.FakeVectorRowBatchFromLongIterables; @@ -48,6 +52,7 @@ import org.apache.hadoop.hive.ql.plan.GroupByDesc; import org.apache.hadoop.hive.serde2.io.ByteWritable; import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.hive.serde2.io.ShortWritable; import org.apache.hadoop.hive.serde2.io.TimestampWritable; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; @@ -158,9 +163,9 @@ private static GroupByDesc buildKeyGroupByDesc( return desc; } - + long outputRowCount = 0; - + @Test public void testMemoryPressureFlush() throws HiveException { @@ -169,22 +174,22 @@ public void testMemoryPressureFlush() throws HiveException { mapColumnNames.put("Value", 1); VectorizationContext ctx = new VectorizationContext(mapColumnNames, 2); - GroupByDesc desc = buildKeyGroupByDesc (ctx, "max", - "Value", TypeInfoFactory.longTypeInfo, + GroupByDesc desc = buildKeyGroupByDesc (ctx, "max", + "Value", TypeInfoFactory.longTypeInfo, "Key", TypeInfoFactory.longTypeInfo); - + // Set the memory treshold so that we get 100Kb before we need to flush. MemoryMXBean memoryMXBean = ManagementFactory.getMemoryMXBean(); long maxMemory = memoryMXBean.getHeapMemoryUsage().getMax(); - + float treshold = 100.0f*1024.0f/maxMemory; desc.setMemoryThreshold(treshold); VectorGroupByOperator vgo = new VectorGroupByOperator(ctx, desc); - + FakeCaptureOutputOperator out = FakeCaptureOutputOperator.addCaptureOutputChild(vgo); vgo.initialize(null, null); - + this.outputRowCount = 0; out.setOutputInspector(new FakeCaptureOutputOperator.OutputInspector() { @Override @@ -192,7 +197,7 @@ public void inspectRow(Object row, int tag) throws HiveException { ++outputRowCount; } }); - + Iterable it = new Iterable() { @Override public Iterator iterator() { @@ -215,7 +220,7 @@ public void remove() { }; } }; - + FakeVectorRowBatchFromObjectIterables data = new FakeVectorRowBatchFromObjectIterables( 100, new String[] {"long", "long"}, @@ -223,7 +228,7 @@ public void remove() { it); // The 'it' data source will produce data w/o ever ending - // We want to see that memory pressure kicks in and some + // We want to see that memory pressure kicks in and some // entries in the VGBY are flushed. 
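Stepping back to the GenericUDAFAverage hunk above: deriveSumTypeInfo is exposed as public static so that VectorUDAFAvgDecimal.initPartialResultInspector can derive the same partial SUM type as the row-mode AVG, keeping the (count, sum) partial-result struct compatible between the two code paths. A worked sketch of that arithmetic, assuming HiveDecimal.MAX_SCALE is 38 (the constant the helper references); the sample inputs are the two decimal types created by the new qtests:

    // Worked sketch of GenericUDAFAverage.deriveSumTypeInfo(scale, precision):
    // keep the integer digits, add 4 fractional digits, but cap the scale so the
    // total number of digits never exceeds MAX_SCALE (mirroring HiveDecimal.MAX_SCALE).
    public class DeriveSumTypeSketch {

        private static final int MAX_SCALE = 38;

        // Returns {precision, scale} of the partial SUM type.
        static int[] deriveSumType(int scale, int precision) {
            int intPart = precision - scale;
            int newScale = Math.min(scale + 4, MAX_SCALE - intPart);
            return new int[] { intPart + newScale, newScale };
        }

        public static void main(String[] args) {
            // decimal(20,10) input -> partial sum is decimal(24,14)
            int[] a = deriveSumType(10, 20);
            System.out.println("decimal(" + a[0] + "," + a[1] + ")");

            // decimal(23,14) input -> partial sum is decimal(27,18)
            int[] b = deriveSumType(14, 23);
            System.out.println("decimal(" + b[0] + "," + b[1] + ")");

            // decimal(38,2) input -> only 2 digits of scale are left after the
            // 36 integer digits, so the sum stays decimal(38,2).
            int[] c = deriveSumType(2, 38);
            System.out.println("decimal(" + c[0] + "," + c[1] + ")");
        }
    }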
long countRowsProduced = 0; for (VectorizedRowBatch unit: data) { @@ -237,7 +242,7 @@ public void remove() { // It should not go beyond 100k/16 (key+data) assertTrue(countRowsProduced < 100*1024/16); } - + assertTrue(0 < outputRowCount); } @@ -596,6 +601,178 @@ public void testCountStar() throws HiveException { } @Test + public void testCountDecimal() throws HiveException { + testAggregateDecimal( + "count", + 2, + Arrays.asList(new Object[]{ + new Decimal128(1), + new Decimal128(2), + new Decimal128(3)}), + 3L); + } + + @Test + public void testMaxDecimal() throws HiveException { + testAggregateDecimal( + "max", + 2, + Arrays.asList(new Object[]{ + new Decimal128(1), + new Decimal128(2), + new Decimal128(3)}), + new Decimal128(3)); + testAggregateDecimal( + "max", + 2, + Arrays.asList(new Object[]{ + new Decimal128(3), + new Decimal128(2), + new Decimal128(1)}), + new Decimal128(3)); + testAggregateDecimal( + "max", + 2, + Arrays.asList(new Object[]{ + new Decimal128(2), + new Decimal128(3), + new Decimal128(1)}), + new Decimal128(3)); + } + + @Test + public void testMinDecimal() throws HiveException { + testAggregateDecimal( + "min", + 2, + Arrays.asList(new Object[]{ + new Decimal128(1), + new Decimal128(2), + new Decimal128(3)}), + new Decimal128(1)); + testAggregateDecimal( + "min", + 2, + Arrays.asList(new Object[]{ + new Decimal128(3), + new Decimal128(2), + new Decimal128(1)}), + new Decimal128(1)); + + testAggregateDecimal( + "min", + 2, + Arrays.asList(new Object[]{ + new Decimal128(2), + new Decimal128(1), + new Decimal128(3)}), + new Decimal128(1)); + } + + @Test + public void testSumDecimal() throws HiveException { + testAggregateDecimal( + "sum", + 2, + Arrays.asList(new Object[]{ + new Decimal128(1), + new Decimal128(2), + new Decimal128(3)}), + new Decimal128(1+2+3)); + } + + @Test + public void testAvgDecimal() throws HiveException { + testAggregateDecimal( + "avg", + 2, + Arrays.asList(new Object[]{ + new Decimal128(1), + new Decimal128(2), + new Decimal128(3)}), + HiveDecimal.create((1+2+3)/3)); + } + + @Test + public void testAvgDecimalNegative() throws HiveException { + testAggregateDecimal( + "avg", + 2, + Arrays.asList(new Object[]{ + new Decimal128(-1), + new Decimal128(-2), + new Decimal128(-3)}), + HiveDecimal.create((-1-2-3)/3)); + } + + @Test + public void testVarianceDecimal () throws HiveException { + testAggregateDecimal( + "variance", + 2, + Arrays.asList(new Object[]{ + new Decimal128(13), + new Decimal128(5), + new Decimal128(7), + new Decimal128(19)}), + (double) 30); + } + + @Test + public void testVarSampDecimal () throws HiveException { + testAggregateDecimal( + "var_samp", + 2, + Arrays.asList(new Object[]{ + new Decimal128(13), + new Decimal128(5), + new Decimal128(7), + new Decimal128(19)}), + (double) 40); + } + + @Test + public void testStdPopDecimal () throws HiveException { + testAggregateDecimal( + "stddev_pop", + 2, + Arrays.asList(new Object[]{ + new Decimal128(13), + new Decimal128(5), + new Decimal128(7), + new Decimal128(19)}), + (double) Math.sqrt(30)); + } + + @Test + public void testStdSampDecimal () throws HiveException { + testAggregateDecimal( + "stddev_samp", + 2, + Arrays.asList(new Object[]{ + new Decimal128(13), + new Decimal128(5), + new Decimal128(7), + new Decimal128(19)}), + (double) Math.sqrt(40)); + } + + @Test + public void testDecimalKeyTypeAggregate() throws HiveException { + testKeyTypeAggregate( + "sum", + new FakeVectorRowBatchFromObjectIterables( + 2, + new String[] {"decimal(38,0)", "bigint"}, + Arrays.asList(new 
Object[]{ + new Decimal128(1),null, + new Decimal128(1), null}), + Arrays.asList(new Object[]{13L,null,7L, 19L})), + buildHashMap(HiveDecimal.create(1), 20L, null, 19L)); + } + + + @Test public void testCountString() throws HiveException { testAggregateString( "count", @@ -1655,6 +1832,9 @@ public void inspectRow(Object row, int tag) throws HiveException { } else if (key instanceof BooleanWritable) { BooleanWritable bwKey = (BooleanWritable)key; keyValue = bwKey.get(); + } else if (key instanceof HiveDecimalWritable) { + HiveDecimalWritable hdwKey = (HiveDecimalWritable)key; + keyValue = hdwKey.getHiveDecimal(); } else { Assert.fail(String.format("Not implemented key output type %s: %s", key.getClass().getName(), key)); @@ -1755,6 +1935,19 @@ public void testAggregateLongKeyAggregate ( testAggregateLongKeyIterable (aggregateName, fdr, expected); } + public void testAggregateDecimal ( + String aggregateName, + int batchSize, + Iterable values, + Object expected) throws HiveException { + + @SuppressWarnings("unchecked") + FakeVectorRowBatchFromObjectIterables fdr = new FakeVectorRowBatchFromObjectIterables( + batchSize, new String[] {"Decimal"}, values); + testAggregateDecimalIterable (aggregateName, fdr, expected); + } + + public void testAggregateString ( String aggregateName, int batchSize, @@ -1832,6 +2025,15 @@ public void validate(String key, Object expected, Object result) { assertEquals (key, (Double) expected, (Double) arr[0]); } else if (arr[0] instanceof Long) { assertEquals (key, (Long) expected, (Long) arr[0]); + } else if (arr[0] instanceof HiveDecimalWritable) { + HiveDecimalWritable hdw = (HiveDecimalWritable) arr[0]; + HiveDecimal hd = hdw.getHiveDecimal(); + Decimal128 d128 = (Decimal128)expected; + assertEquals (key, d128.toBigDecimal(), hd.bigDecimalValue()); + } else if (arr[0] instanceof HiveDecimal) { + HiveDecimal hd = (HiveDecimal) arr[0]; + Decimal128 d128 = (Decimal128)expected; + assertEquals (key, d128.toBigDecimal(), hd.bigDecimalValue()); } else { Assert.fail("Unsupported result type: " + arr[0].getClass().getName()); } @@ -1853,11 +2055,16 @@ public void validate(String key, Object expected, Object result) { assertEquals (2, vals.length); assertEquals (true, vals[0] instanceof LongWritable); - assertEquals (true, vals[1] instanceof DoubleWritable); LongWritable lw = (LongWritable) vals[0]; - DoubleWritable dw = (DoubleWritable) vals[1]; assertFalse (lw.get() == 0L); - assertEquals (key, (Double) expected, (Double) (dw.get() / lw.get())); + + if (vals[1] instanceof DoubleWritable) { + DoubleWritable dw = (DoubleWritable) vals[1]; + assertEquals (key, (Double) expected, (Double) (dw.get() / lw.get())); + } else if (vals[1] instanceof HiveDecimalWritable) { + HiveDecimalWritable hdw = (HiveDecimalWritable) vals[1]; + assertEquals (key, (HiveDecimal) expected, hdw.getHiveDecimal().divide(HiveDecimal.create(lw.get()))); + } } } @@ -1935,6 +2142,7 @@ void validateVariance(String key, double expected, long cnt, double sum, double {"var_samp", VarianceSampValidator.class}, {"std", StdValidator.class}, {"stddev", StdValidator.class}, + {"stddev_pop", StdValidator.class}, {"stddev_samp", StdSampValidator.class}, }; @@ -2015,6 +2223,38 @@ public void testAggregateStringIterable ( validator.validate("_total", expected, result); } + public void testAggregateDecimalIterable ( + String aggregateName, + Iterable data, + Object expected) throws HiveException { + Map mapColumnNames = new HashMap(); + mapColumnNames.put("A", 0); + VectorizationContext ctx = new 
VectorizationContext(mapColumnNames, 1); + + GroupByDesc desc = buildGroupByDescType(ctx, aggregateName, "A", + TypeInfoFactory.getDecimalTypeInfo(30, 4)); + + VectorGroupByOperator vgo = new VectorGroupByOperator(ctx, desc); + + FakeCaptureOutputOperator out = FakeCaptureOutputOperator.addCaptureOutputChild(vgo); + vgo.initialize(null, null); + + for (VectorizedRowBatch unit: data) { + vgo.processOp(unit, 0); + } + vgo.close(false); + + List outBatchList = out.getCapturedRows(); + assertNotNull(outBatchList); + assertEquals(1, outBatchList.size()); + + Object result = outBatchList.get(0); + + Validator validator = getValidator(aggregateName); + validator.validate("_total", expected, result); + } + + public void testAggregateDoubleIterable ( String aggregateName, Iterable data, diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromObjectIterables.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromObjectIterables.java index c8eaea1..ba7b0f9 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromObjectIterables.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeVectorRowBatchFromObjectIterables.java @@ -23,8 +23,10 @@ import java.util.Iterator; import java.util.List; +import org.apache.hadoop.hive.common.type.Decimal128; import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; import org.apache.hadoop.hive.ql.exec.vector.TimestampUtils; @@ -138,6 +140,18 @@ public void assign( dcv.vector[row] = Double.valueOf(value.toString()); } }; + } else if (types[i].toLowerCase().startsWith("decimal")) { + batch.cols[i] = new DecimalColumnVector(batchSize, 38, 0); + columnAssign[i] = new ColumnVectorAssign() { + @Override + public void assign( + ColumnVector columnVector, + int row, + Object value) { + DecimalColumnVector dcv = (DecimalColumnVector) columnVector; + dcv.vector[row] = (Decimal128)value; + } + }; } else { throw new HiveException("Unimplemented type " + types[i]); } diff --git ql/src/test/queries/clientpositive/vector_decimal_aggregate.q ql/src/test/queries/clientpositive/vector_decimal_aggregate.q new file mode 100644 index 0000000..eb9146e --- /dev/null +++ ql/src/test/queries/clientpositive/vector_decimal_aggregate.q @@ -0,0 +1,20 @@ +CREATE TABLE decimal_vgby STORED AS ORC AS + SELECT cdouble, CAST (((cdouble*22.1)/37) AS DECIMAL(20,10)) AS cdecimal1, + CAST (((cdouble*9.3)/13) AS DECIMAL(23,14)) AS cdecimal2, + cint + FROM alltypesorc; + +SET hive.vectorized.execution.enabled=true; + +EXPLAIN SELECT cint, + COUNT(cdecimal1), MAX(cdecimal1), MIN(cdecimal1), SUM(cdecimal1), AVG(cdecimal1), STDDEV_POP(cdecimal1), STDDEV_SAMP(cdecimal1), + COUNT(cdecimal2), MAX(cdecimal2), MIN(cdecimal2), SUM(cdecimal2), AVG(cdecimal2), STDDEV_POP(cdecimal2), STDDEV_SAMP(cdecimal2) + FROM decimal_vgby + GROUP BY cint + HAVING COUNT(*) > 1; +SELECT cint, + COUNT(cdecimal1), MAX(cdecimal1), MIN(cdecimal1), SUM(cdecimal1), AVG(cdecimal1), STDDEV_POP(cdecimal1), STDDEV_SAMP(cdecimal1), + COUNT(cdecimal2), MAX(cdecimal2), MIN(cdecimal2), SUM(cdecimal2), AVG(cdecimal2), STDDEV_POP(cdecimal2), STDDEV_SAMP(cdecimal2) + FROM decimal_vgby + GROUP BY cint + HAVING COUNT(*) > 1; \ No newline at end of file diff --git 
ql/src/test/queries/clientpositive/vector_decimal_mapjoin.q ql/src/test/queries/clientpositive/vector_decimal_mapjoin.q new file mode 100644 index 0000000..d8b3d1a --- /dev/null +++ ql/src/test/queries/clientpositive/vector_decimal_mapjoin.q @@ -0,0 +1,19 @@ +CREATE TABLE decimal_mapjoin STORED AS ORC AS + SELECT cdouble, CAST (((cdouble*22.1)/37) AS DECIMAL(20,10)) AS cdecimal1, + CAST (((cdouble*9.3)/13) AS DECIMAL(23,14)) AS cdecimal2, + cint + FROM alltypesorc; + +SET hive.auto.convert.join=true; +SET hive.auto.convert.join.nonconditionaltask=true; +SET hive.auto.convert.join.nonconditionaltask.size=1000000000; +SET hive.vectorized.execution.enabled=true; + +EXPLAIN SELECT l.cint, r.cint, l.cdecimal1, r.cdecimal2 + FROM decimal_mapjoin l + JOIN decimal_mapjoin r ON l.cint = r.cint + WHERE l.cint = 6981; +SELECT l.cint, r.cint, l.cdecimal1, r.cdecimal2 + FROM decimal_mapjoin l + JOIN decimal_mapjoin r ON l.cint = r.cint + WHERE l.cint = 6981; \ No newline at end of file diff --git ql/src/test/results/clientpositive/vector_decimal_aggregate.q.out ql/src/test/results/clientpositive/vector_decimal_aggregate.q.out new file mode 100644 index 0000000..8b73971 --- /dev/null +++ ql/src/test/results/clientpositive/vector_decimal_aggregate.q.out @@ -0,0 +1,109 @@ +PREHOOK: query: CREATE TABLE decimal_vgby STORED AS ORC AS + SELECT cdouble, CAST (((cdouble*22.1)/37) AS DECIMAL(20,10)) AS cdecimal1, + CAST (((cdouble*9.3)/13) AS DECIMAL(23,14)) AS cdecimal2, + cint + FROM alltypesorc +PREHOOK: type: CREATETABLE_AS_SELECT +PREHOOK: Input: default@alltypesorc +POSTHOOK: query: CREATE TABLE decimal_vgby STORED AS ORC AS + SELECT cdouble, CAST (((cdouble*22.1)/37) AS DECIMAL(20,10)) AS cdecimal1, + CAST (((cdouble*9.3)/13) AS DECIMAL(23,14)) AS cdecimal2, + cint + FROM alltypesorc +POSTHOOK: type: CREATETABLE_AS_SELECT +POSTHOOK: Input: default@alltypesorc +POSTHOOK: Output: default@decimal_vgby +PREHOOK: query: EXPLAIN SELECT cint, + COUNT(cdecimal1), MAX(cdecimal1), MIN(cdecimal1), SUM(cdecimal1), AVG(cdecimal1), STDDEV_POP(cdecimal1), STDDEV_SAMP(cdecimal1), + COUNT(cdecimal2), MAX(cdecimal2), MIN(cdecimal2), SUM(cdecimal2), AVG(cdecimal2), STDDEV_POP(cdecimal2), STDDEV_SAMP(cdecimal2) + FROM decimal_vgby + GROUP BY cint + HAVING COUNT(*) > 1 +PREHOOK: type: QUERY +POSTHOOK: query: EXPLAIN SELECT cint, + COUNT(cdecimal1), MAX(cdecimal1), MIN(cdecimal1), SUM(cdecimal1), AVG(cdecimal1), STDDEV_POP(cdecimal1), STDDEV_SAMP(cdecimal1), + COUNT(cdecimal2), MAX(cdecimal2), MIN(cdecimal2), SUM(cdecimal2), AVG(cdecimal2), STDDEV_POP(cdecimal2), STDDEV_SAMP(cdecimal2) + FROM decimal_vgby + GROUP BY cint + HAVING COUNT(*) > 1 +POSTHOOK: type: QUERY +STAGE DEPENDENCIES: + Stage-1 is a root stage + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-1 + Map Reduce + Map Operator Tree: + TableScan + alias: decimal_vgby + Statistics: Num rows: 12288 Data size: 2165060 Basic stats: COMPLETE Column stats: NONE + Select Operator + expressions: cint (type: int), cdecimal1 (type: decimal(20,10)), cdecimal2 (type: decimal(23,14)) + outputColumnNames: cint, cdecimal1, cdecimal2 + Statistics: Num rows: 12288 Data size: 2165060 Basic stats: COMPLETE Column stats: NONE + Group By Operator + aggregations: count(cdecimal1), max(cdecimal1), min(cdecimal1), sum(cdecimal1), avg(cdecimal1), stddev_pop(cdecimal1), stddev_samp(cdecimal1), count(cdecimal2), max(cdecimal2), min(cdecimal2), sum(cdecimal2), avg(cdecimal2), stddev_pop(cdecimal2), stddev_samp(cdecimal2), count() + keys: cint (type: int) + mode: hash + 
outputColumnNames: _col0, _col1, _col2, _col3, _col4, _col5, _col6, _col7, _col8, _col9, _col10, _col11, _col12, _col13, _col14, _col15 + Statistics: Num rows: 12288 Data size: 2165060 Basic stats: COMPLETE Column stats: NONE + Reduce Output Operator + key expressions: _col0 (type: int) + sort order: + + Map-reduce partition columns: _col0 (type: int) + Statistics: Num rows: 12288 Data size: 2165060 Basic stats: COMPLETE Column stats: NONE + value expressions: _col1 (type: bigint), _col2 (type: decimal(20,10)), _col3 (type: decimal(20,10)), _col4 (type: decimal(30,10)), _col5 (type: struct), _col6 (type: struct), _col7 (type: struct), _col8 (type: bigint), _col9 (type: decimal(23,14)), _col10 (type: decimal(23,14)), _col11 (type: decimal(33,14)), _col12 (type: struct), _col13 (type: struct), _col14 (type: struct), _col15 (type: bigint) + Execution mode: vectorized + Reduce Operator Tree: + Group By Operator + aggregations: count(VALUE._col0), max(VALUE._col1), min(VALUE._col2), sum(VALUE._col3), avg(VALUE._col4), stddev_pop(VALUE._col5), stddev_samp(VALUE._col6), count(VALUE._col7), max(VALUE._col8), min(VALUE._col9), sum(VALUE._col10), avg(VALUE._col11), stddev_pop(VALUE._col12), stddev_samp(VALUE._col13), count(VALUE._col14) + keys: KEY._col0 (type: int) + mode: mergepartial + outputColumnNames: _col0, _col1, _col2, _col3, _col4, _col5, _col6, _col7, _col8, _col9, _col10, _col11, _col12, _col13, _col14, _col15 + Statistics: Num rows: 6144 Data size: 1082530 Basic stats: COMPLETE Column stats: NONE + Filter Operator + predicate: (_col15 > 1) (type: boolean) + Statistics: Num rows: 2048 Data size: 360843 Basic stats: COMPLETE Column stats: NONE + Select Operator + expressions: _col0 (type: int), _col1 (type: bigint), _col2 (type: decimal(20,10)), _col3 (type: decimal(20,10)), _col4 (type: decimal(30,10)), _col5 (type: decimal(24,14)), _col6 (type: double), _col7 (type: double), _col8 (type: bigint), _col9 (type: decimal(23,14)), _col10 (type: decimal(23,14)), _col11 (type: decimal(33,14)), _col12 (type: decimal(27,18)), _col13 (type: double), _col14 (type: double) + outputColumnNames: _col0, _col1, _col2, _col3, _col4, _col5, _col6, _col7, _col8, _col9, _col10, _col11, _col12, _col13, _col14 + Statistics: Num rows: 2048 Data size: 360843 Basic stats: COMPLETE Column stats: NONE + File Output Operator + compressed: false + Statistics: Num rows: 2048 Data size: 360843 Basic stats: COMPLETE Column stats: NONE + table: + input format: org.apache.hadoop.mapred.TextInputFormat + output format: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat + serde: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe + + Stage: Stage-0 + Fetch Operator + limit: -1 + +PREHOOK: query: SELECT cint, + COUNT(cdecimal1), MAX(cdecimal1), MIN(cdecimal1), SUM(cdecimal1), AVG(cdecimal1), STDDEV_POP(cdecimal1), STDDEV_SAMP(cdecimal1), + COUNT(cdecimal2), MAX(cdecimal2), MIN(cdecimal2), SUM(cdecimal2), AVG(cdecimal2), STDDEV_POP(cdecimal2), STDDEV_SAMP(cdecimal2) + FROM decimal_vgby + GROUP BY cint + HAVING COUNT(*) > 1 +PREHOOK: type: QUERY +PREHOOK: Input: default@decimal_vgby +#### A masked pattern was here #### +POSTHOOK: query: SELECT cint, + COUNT(cdecimal1), MAX(cdecimal1), MIN(cdecimal1), SUM(cdecimal1), AVG(cdecimal1), STDDEV_POP(cdecimal1), STDDEV_SAMP(cdecimal1), + COUNT(cdecimal2), MAX(cdecimal2), MIN(cdecimal2), SUM(cdecimal2), AVG(cdecimal2), STDDEV_POP(cdecimal2), STDDEV_SAMP(cdecimal2) + FROM decimal_vgby + GROUP BY cint + HAVING COUNT(*) > 1 +POSTHOOK: type: QUERY +POSTHOOK: Input: 
default@decimal_vgby +#### A masked pattern was here #### +NULL 3072 9318.4351351351 -4298.1513513514 5018444.1081079808 1633.60810810806667 5695.4830821353335 5696.410307714474 3072 11160.715384615385 -5147.907692307693 6010604.3076923073536 1956.576923076922966667 6821.495748565141 6822.606289190906 +-3728 6 5831542.269248378 -3367.6517567568 5817556.0411483778 16510.89638306946651 2174330.2092403853 2381859.406131774 6 6984454.211097692 -4033.445769230769 6967702.8672438458471 1161283.811207307641183333 2604201.2704476737 2852759.5602156054 +-563 2 -515.621072973 -3367.6517567568 -3883.2728297298 -1941.6364148649 1426.0153418919 2016.6902366556312 2 -617.5607769230769 -4033.445769230769 -4651.0065461538459 -2325.50327307692295 1707.9424961538462 2415.395441814127 +762 2 5831542.269248378 1531.2194054054 5833073.4886537834 2916536.7443268917 2915005.524921486 4122440.347736469 2 6984454.211097692 1833.9456923076925 6986288.1567899996925 3493144.07839499984625 3491310.132702692 4937458.140118757 +6981 3 5831542.269248378 -515.621072973 5830511.027102432 1943503.67570081066667 2749258.4550124914 3367140.192906513 3 6984454.211097692 -617.5607769230769 6983219.0895438458462 2327739.696514615282066667 3292794.4113115156 4032833.0678006653 +253665376 1024 9767.0054054054 -9779.5486486487 -347484.0818378374 -339.33992366976309 5708.9563478862 5711.745967572779 1024 11697.969230769231 -11712.99230769231 -416182.64030769233089 -406.428359675480791885 6837.632716002934 6840.973851172274 +528534767 1024 5831542.269248378 -9777.1594594595 11646372.8607481068 11373.41099682432305 257528.9298820665 257654.76860439766 1024 6984454.211097692 -11710.130769230771 13948892.79980307629003 13621.965624807691689482 308443.1074570801 308593.82484083984 +626923679 1024 9723.4027027027 -9778.9513513514 10541.0525297287 10.29399661106318 5742.09145323734 5744.897264034267 1024 11645.746153846154 -11712.276923076923 12625.04759999997746 12.329148046874977988 6877.318722794877 6880.679250101604 diff --git ql/src/test/results/clientpositive/vector_decimal_mapjoin.q.out ql/src/test/results/clientpositive/vector_decimal_mapjoin.q.out new file mode 100644 index 0000000..5fc2235 --- /dev/null +++ ql/src/test/results/clientpositive/vector_decimal_mapjoin.q.out @@ -0,0 +1,206 @@ +PREHOOK: query: CREATE TABLE decimal_mapjoin STORED AS ORC AS + SELECT cdouble, CAST (((cdouble*22.1)/37) AS DECIMAL(20,10)) AS cdecimal1, + CAST (((cdouble*9.3)/13) AS DECIMAL(23,14)) AS cdecimal2, + cint + FROM alltypesorc +PREHOOK: type: CREATETABLE_AS_SELECT +PREHOOK: Input: default@alltypesorc +POSTHOOK: query: CREATE TABLE decimal_mapjoin STORED AS ORC AS + SELECT cdouble, CAST (((cdouble*22.1)/37) AS DECIMAL(20,10)) AS cdecimal1, + CAST (((cdouble*9.3)/13) AS DECIMAL(23,14)) AS cdecimal2, + cint + FROM alltypesorc +POSTHOOK: type: CREATETABLE_AS_SELECT +POSTHOOK: Input: default@alltypesorc +POSTHOOK: Output: default@decimal_mapjoin +PREHOOK: query: EXPLAIN SELECT l.cint, r.cint, l.cdecimal1, r.cdecimal2 + FROM decimal_mapjoin l + JOIN decimal_mapjoin r ON l.cint = r.cint + WHERE l.cint = 6981 +PREHOOK: type: QUERY +POSTHOOK: query: EXPLAIN SELECT l.cint, r.cint, l.cdecimal1, r.cdecimal2 + FROM decimal_mapjoin l + JOIN decimal_mapjoin r ON l.cint = r.cint + WHERE l.cint = 6981 +POSTHOOK: type: QUERY +STAGE DEPENDENCIES: + Stage-4 is a root stage + Stage-3 depends on stages: Stage-4 + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-4 + Map Reduce Local Work + Alias -> Map Local Tables: + l + Fetch Operator + limit: -1 + Alias -> Map 
Local Operator Tree: + l + TableScan + alias: l + Statistics: Num rows: 12288 Data size: 2165060 Basic stats: COMPLETE Column stats: NONE + Filter Operator + predicate: (cint = 6981) (type: boolean) + Statistics: Num rows: 6144 Data size: 1082530 Basic stats: COMPLETE Column stats: NONE + HashTable Sink Operator + condition expressions: + 0 {cdecimal1} {cint} + 1 {cdecimal2} {cint} + keys: + 0 cint (type: int) + 1 cint (type: int) + + Stage: Stage-3 + Map Reduce + Map Operator Tree: + TableScan + alias: r + Statistics: Num rows: 12288 Data size: 2165060 Basic stats: COMPLETE Column stats: NONE + Filter Operator + predicate: (cint = 6981) (type: boolean) + Statistics: Num rows: 6144 Data size: 1082530 Basic stats: COMPLETE Column stats: NONE + Map Join Operator + condition map: + Inner Join 0 to 1 + condition expressions: + 0 {cdecimal1} {cint} + 1 {cdecimal2} {cint} + keys: + 0 cint (type: int) + 1 cint (type: int) + outputColumnNames: _col1, _col3, _col8, _col9 + Statistics: Num rows: 6758 Data size: 1190783 Basic stats: COMPLETE Column stats: NONE + Select Operator + expressions: _col3 (type: int), _col9 (type: int), _col1 (type: decimal(20,10)), _col8 (type: decimal(23,14)) + outputColumnNames: _col0, _col1, _col2, _col3 + Statistics: Num rows: 6758 Data size: 1190783 Basic stats: COMPLETE Column stats: NONE + File Output Operator + compressed: false + Statistics: Num rows: 6758 Data size: 1190783 Basic stats: COMPLETE Column stats: NONE + table: + input format: org.apache.hadoop.mapred.TextInputFormat + output format: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat + serde: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe + Local Work: + Map Reduce Local Work + Execution mode: vectorized + + Stage: Stage-0 + Fetch Operator + limit: -1 + +PREHOOK: query: SELECT l.cint, r.cint, l.cdecimal1, r.cdecimal2 + FROM decimal_mapjoin l + JOIN decimal_mapjoin r ON l.cint = r.cint + WHERE l.cint = 6981 +PREHOOK: type: QUERY +PREHOOK: Input: default@decimal_mapjoin +#### A masked pattern was here #### +POSTHOOK: query: SELECT l.cint, r.cint, l.cdecimal1, r.cdecimal2 + FROM decimal_mapjoin l + JOIN decimal_mapjoin r ON l.cint = r.cint + WHERE l.cint = 6981 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@decimal_mapjoin +#### A masked pattern was here #### +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 5831542.269248378 NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 -515.621072973 NULL +6981 6981 -515.621072973 NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 5831542.269248378 NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 -515.621072973 NULL +6981 6981 -515.621072973 NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 5831542.269248378 NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 -515.621072973 NULL +6981 6981 -515.621072973 NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 5831542.269248378 NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 -515.621072973 NULL +6981 6981 -515.621072973 NULL +6981 6981 NULL 6984454.211097692 +6981 6981 NULL 6984454.211097692 +6981 6981 NULL 6984454.211097692 +6981 6981 NULL 6984454.211097692 +6981 6981 5831542.269248378 6984454.211097692 +6981 6981 NULL 6984454.211097692 +6981 6981 NULL 6984454.211097692 +6981 6981 NULL 
6984454.211097692 +6981 6981 -515.621072973 6984454.211097692 +6981 6981 -515.621072973 6984454.211097692 +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 5831542.269248378 NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 -515.621072973 NULL +6981 6981 -515.621072973 NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 5831542.269248378 NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 -515.621072973 NULL +6981 6981 -515.621072973 NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 5831542.269248378 NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 NULL NULL +6981 6981 -515.621072973 NULL +6981 6981 -515.621072973 NULL +6981 6981 NULL -617.5607769230769 +6981 6981 NULL -617.5607769230769 +6981 6981 NULL -617.5607769230769 +6981 6981 NULL -617.5607769230769 +6981 6981 5831542.269248378 -617.5607769230769 +6981 6981 NULL -617.5607769230769 +6981 6981 NULL -617.5607769230769 +6981 6981 NULL -617.5607769230769 +6981 6981 -515.621072973 -617.5607769230769 +6981 6981 -515.621072973 -617.5607769230769 +6981 6981 NULL -617.5607769230769 +6981 6981 NULL -617.5607769230769 +6981 6981 NULL -617.5607769230769 +6981 6981 NULL -617.5607769230769 +6981 6981 5831542.269248378 -617.5607769230769 +6981 6981 NULL -617.5607769230769 +6981 6981 NULL -617.5607769230769 +6981 6981 NULL -617.5607769230769 +6981 6981 -515.621072973 -617.5607769230769 +6981 6981 -515.621072973 -617.5607769230769 diff --git serde/src/java/org/apache/hadoop/hive/serde2/io/HiveDecimalWritable.java serde/src/java/org/apache/hadoop/hive/serde2/io/HiveDecimalWritable.java index 008fda3..67cb1e8 100644 --- serde/src/java/org/apache/hadoop/hive/serde2/io/HiveDecimalWritable.java +++ serde/src/java/org/apache/hadoop/hive/serde2/io/HiveDecimalWritable.java @@ -148,4 +148,21 @@ public boolean equals(Object other) { public int hashCode() { return getHiveDecimal().hashCode(); } + + /* (non-Javadoc) + * In order to update a Decimal128 fast (w/o allocation) we need to expose access to the + * internal storage bytes and scale. + * @return + */ + public byte[] getInternalStorage() { + return internalStorage; + } + + /* (non-Javadoc) + * In order to update a Decimal128 fast (w/o allocation) we need to expose access to the + * internal storage bytes and scale. + */ + public int getScale() { + return scale; + } } diff --git serde/src/test/org/apache/hadoop/hive/serde2/io/TestHiveDecimalWritable.java serde/src/test/org/apache/hadoop/hive/serde2/io/TestHiveDecimalWritable.java new file mode 100644 index 0000000..ff11be2 --- /dev/null +++ serde/src/test/org/apache/hadoop/hive/serde2/io/TestHiveDecimalWritable.java @@ -0,0 +1,220 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.serde2.io; + +import junit.framework.Assert; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; + +import org.apache.hadoop.hive.common.type.Decimal128; +import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.apache.hive.common.util.Decimal128FastBuffer; +import org.junit.Before; +import org.junit.Test; + +/** + * Unit tests for tsting the fast allocation-free conversion + * between HiveDecimalWritable and Decimal128 + */ +public class TestHiveDecimalWritable { + + private Decimal128FastBuffer scratch; + + @Before + public void setUp() throws Exception { + scratch = new Decimal128FastBuffer(); + } + + private void doTestFastStreamForHiveDecimal(String valueString) { + BigDecimal value = new BigDecimal(valueString); + Decimal128 dec = new Decimal128(); + dec.update(value); + + HiveDecimalWritable witness = new HiveDecimalWritable(); + witness.set(HiveDecimal.create(value)); + + int bufferUsed = dec.fastSerializeForHiveDecimal(scratch); + HiveDecimalWritable hdw = new HiveDecimalWritable(); + hdw.set(scratch.getBytes(bufferUsed), dec.getScale()); + + HiveDecimal hd = hdw.getHiveDecimal(); + + BigDecimal readValue = hd.bigDecimalValue(); + + Assert.assertEquals(value, readValue); + + // Now test fastUpdate from the same serialized HiveDecimal + Decimal128 decRead = new Decimal128().fastUpdateFromInternalStorage( + witness.getInternalStorage(), (short) witness.getScale()); + + Assert.assertEquals(dec, decRead); + + // Test fastUpdate from it's own (not fully compacted) serialized output + Decimal128 decReadSelf = new Decimal128().fastUpdateFromInternalStorage( + hdw.getInternalStorage(), (short) hdw.getScale()); + Assert.assertEquals(dec, decReadSelf); + } + + @Test + public void testFastStreamForHiveDecimal() { + + doTestFastStreamForHiveDecimal("0"); + doTestFastStreamForHiveDecimal("-0"); + doTestFastStreamForHiveDecimal("1"); + doTestFastStreamForHiveDecimal("-1"); + doTestFastStreamForHiveDecimal("2"); + doTestFastStreamForHiveDecimal("-2"); + doTestFastStreamForHiveDecimal("127"); + doTestFastStreamForHiveDecimal("-127"); + doTestFastStreamForHiveDecimal("128"); + doTestFastStreamForHiveDecimal("-128"); + doTestFastStreamForHiveDecimal("255"); + doTestFastStreamForHiveDecimal("-255"); + doTestFastStreamForHiveDecimal("256"); + doTestFastStreamForHiveDecimal("-256"); + doTestFastStreamForHiveDecimal("65535"); + doTestFastStreamForHiveDecimal("-65535"); + doTestFastStreamForHiveDecimal("65536"); + doTestFastStreamForHiveDecimal("-65536"); + + doTestFastStreamForHiveDecimal("10"); + doTestFastStreamForHiveDecimal("1000"); + doTestFastStreamForHiveDecimal("1000000"); + doTestFastStreamForHiveDecimal("1000000000"); + doTestFastStreamForHiveDecimal("1000000000000"); + doTestFastStreamForHiveDecimal("1000000000000000"); + doTestFastStreamForHiveDecimal("1000000000000000000"); + doTestFastStreamForHiveDecimal("1000000000000000000000"); + doTestFastStreamForHiveDecimal("1000000000000000000000000"); + doTestFastStreamForHiveDecimal("1000000000000000000000000000"); + doTestFastStreamForHiveDecimal("1000000000000000000000000000000"); + + doTestFastStreamForHiveDecimal("-10"); + doTestFastStreamForHiveDecimal("-1000"); + doTestFastStreamForHiveDecimal("-1000000"); + doTestFastStreamForHiveDecimal("-1000000000"); + 
doTestFastStreamForHiveDecimal("-1000000000000"); + doTestFastStreamForHiveDecimal("-1000000000000000000"); + doTestFastStreamForHiveDecimal("-1000000000000000000000"); + doTestFastStreamForHiveDecimal("-1000000000000000000000000"); + doTestFastStreamForHiveDecimal("-1000000000000000000000000000"); + doTestFastStreamForHiveDecimal("-1000000000000000000000000000000"); + + + doTestFastStreamForHiveDecimal("0.01"); + doTestFastStreamForHiveDecimal("-0.01"); + doTestFastStreamForHiveDecimal("0.02"); + doTestFastStreamForHiveDecimal("-0.02"); + doTestFastStreamForHiveDecimal("0.0127"); + doTestFastStreamForHiveDecimal("-0.0127"); + doTestFastStreamForHiveDecimal("0.0128"); + doTestFastStreamForHiveDecimal("-0.0128"); + doTestFastStreamForHiveDecimal("0.0255"); + doTestFastStreamForHiveDecimal("-0.0255"); + doTestFastStreamForHiveDecimal("0.0256"); + doTestFastStreamForHiveDecimal("-0.0256"); + doTestFastStreamForHiveDecimal("0.065535"); + doTestFastStreamForHiveDecimal("-0.065535"); + doTestFastStreamForHiveDecimal("0.065536"); + doTestFastStreamForHiveDecimal("-0.065536"); + + doTestFastStreamForHiveDecimal("0.101"); + doTestFastStreamForHiveDecimal("0.10001"); + doTestFastStreamForHiveDecimal("0.10000001"); + doTestFastStreamForHiveDecimal("0.10000000001"); + doTestFastStreamForHiveDecimal("0.10000000000001"); + doTestFastStreamForHiveDecimal("0.10000000000000001"); + doTestFastStreamForHiveDecimal("0.10000000000000000001"); + doTestFastStreamForHiveDecimal("0.10000000000000000000001"); + doTestFastStreamForHiveDecimal("0.10000000000000000000000001"); + doTestFastStreamForHiveDecimal("0.10000000000000000000000000001"); + doTestFastStreamForHiveDecimal("0.10000000000000000000000000000001"); + + doTestFastStreamForHiveDecimal("-0.101"); + doTestFastStreamForHiveDecimal("-0.10001"); + doTestFastStreamForHiveDecimal("-0.10000001"); + doTestFastStreamForHiveDecimal("-0.10000000001"); + doTestFastStreamForHiveDecimal("-0.10000000000001"); + doTestFastStreamForHiveDecimal("-0.10000000000000000001"); + doTestFastStreamForHiveDecimal("-0.10000000000000000000001"); + doTestFastStreamForHiveDecimal("-0.10000000000000000000000001"); + doTestFastStreamForHiveDecimal("-0.10000000000000000000000000001"); + doTestFastStreamForHiveDecimal("-0.10000000000000000000000000000001"); + + doTestFastStreamForHiveDecimal(Integer.toString(Integer.MAX_VALUE)); + doTestFastStreamForHiveDecimal(Integer.toString(Integer.MIN_VALUE)); + doTestFastStreamForHiveDecimal(Long.toString(Long.MAX_VALUE)); + doTestFastStreamForHiveDecimal(Long.toString(Long.MIN_VALUE)); + doTestFastStreamForHiveDecimal(Decimal128.MAX_VALUE.toFormalString()); + doTestFastStreamForHiveDecimal(Decimal128.MIN_VALUE.toFormalString()); + + // Test known serialization tricky values + int[] values = new int[] { + 0x80, + 0x8000, + 0x800000, + 0x80000000, + 0x81, + 0x8001, + 0x800001, + 0x80000001, + 0x7f, + 0x7fff, + 0x7fffff, + 0x7fffffff, + 0xff, + 0xffff, + 0xffffff, + 0xffffffff}; + + + for(int value: values) { + for (int i = 0; i < 4; ++i) { + int[] pos = new int[] {1, 0, 0, 0, 0}; + int[] neg = new int[] {0xff, 0, 0, 0, 0}; + + pos[i+1] = neg[i+1] = value; + + doTestDecimalWithBoundsCheck(new Decimal128().update32(pos, 0)); + doTestDecimalWithBoundsCheck(new Decimal128().update32(neg, 0)); + doTestDecimalWithBoundsCheck(new Decimal128().update64(pos, 0)); + doTestDecimalWithBoundsCheck(new Decimal128().update64(neg, 0)); + doTestDecimalWithBoundsCheck(new Decimal128().update96(pos, 0)); + doTestDecimalWithBoundsCheck(new Decimal128().update96(neg, 
0));
+        doTestDecimalWithBoundsCheck(new Decimal128().update128(pos, 0));
+        doTestDecimalWithBoundsCheck(new Decimal128().update128(neg, 0));
+      }
+    }
+  }
+
+  void doTestDecimalWithBoundsCheck(Decimal128 value) {
+    if ((value.compareTo(Decimal128.MAX_VALUE)) > 0 ||
+        (value.compareTo(Decimal128.MIN_VALUE)) < 0) {
+      // Ignore this one, out of bounds and HiveDecimal will NPE
+      return;
+    }
+    doTestFastStreamForHiveDecimal(value.toFormalString());
+  }
+
+}
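
The round trip exercised by doTestFastStreamForHiveDecimal is the reason HiveDecimalWritable now exposes getInternalStorage() and getScale(): a Decimal128 can be written into a reusable Decimal128FastBuffer and read back without going through intermediate HiveDecimal or BigDecimal allocations. A condensed sketch of that round trip, using only calls that appear in the test above (the class name and the value "1234.5678" are illustrative):

import java.math.BigDecimal;

import org.apache.hadoop.hive.common.type.Decimal128;
import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
import org.apache.hive.common.util.Decimal128FastBuffer;

public class Decimal128RoundTripSketch {
  public static void main(String[] args) {
    Decimal128FastBuffer scratch = new Decimal128FastBuffer();

    Decimal128 dec = new Decimal128();
    dec.update(new BigDecimal("1234.5678"));          // illustrative value

    // Decimal128 -> HiveDecimalWritable through the reusable scratch buffer
    int bufferUsed = dec.fastSerializeForHiveDecimal(scratch);
    HiveDecimalWritable writable = new HiveDecimalWritable();
    writable.set(scratch.getBytes(bufferUsed), dec.getScale());

    // HiveDecimalWritable -> Decimal128 via the accessors added in this patch
    Decimal128 back = new Decimal128().fastUpdateFromInternalStorage(
        writable.getInternalStorage(), (short) writable.getScale());

    System.out.println(dec.equals(back));             // expected: true
  }
}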
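
On the test-harness side, the new "decimal" branch in FakeVectorRowBatchFromObjectIterables allocates a DecimalColumnVector(batchSize, 38, 0) and stores the supplied Decimal128 objects directly in its vector, so the group-by tests can feed decimal columns the same way they already feed long, double and string columns. A minimal usage sketch under the signatures visible in this patch; the class name is hypothetical, and it assumes the batch source iterates as VectorizedRowBatch and that its constructor declares HiveException, as the tests above rely on:

import java.util.Arrays;

import org.apache.hadoop.hive.common.type.Decimal128;
import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
import org.apache.hadoop.hive.ql.exec.vector.util.FakeVectorRowBatchFromObjectIterables;
import org.apache.hadoop.hive.ql.metadata.HiveException;

public class FakeDecimalBatchSketch {
  static void iterateDecimalBatches() throws HiveException {
    // Three rows split into batches of two; any type string starting with "decimal"
    // selects the decimal branch (the harness always allocates precision 38, scale 0).
    FakeVectorRowBatchFromObjectIterables source =
        new FakeVectorRowBatchFromObjectIterables(
            2,
            new String[] {"decimal(38,0)"},
            Arrays.asList(new Object[]{
                new Decimal128(1), new Decimal128(2), new Decimal128(3)}));

    for (VectorizedRowBatch batch : source) {
      DecimalColumnVector col = (DecimalColumnVector) batch.cols[0];
      // col.vector[row] holds the Decimal128 instances assigned by the harness
    }
  }
}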
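
The expected values in testVarianceDecimal, testVarSampDecimal, testStdPopDecimal and testStdSampDecimal all derive from the same four inputs 13, 5, 7 and 19: the mean is 11, the sum of squared deviations is 4 + 36 + 16 + 64 = 120, so the population variance is 120/4 = 30, the sample variance is 120/3 = 40, and the stddev expectations are their square roots. A small stand-alone check of that arithmetic (the class name is illustrative, not part of the patch):

import java.util.Arrays;

public class VarianceExpectationCheck {
  public static void main(String[] args) {
    double[] values = {13, 5, 7, 19};
    double mean = Arrays.stream(values).average().getAsDouble();       // 11.0
    double ssd = Arrays.stream(values)
        .map(v -> (v - mean) * (v - mean))
        .sum();                                                        // 120.0

    System.out.println(ssd / values.length);                  // 30.0 -> variance
    System.out.println(ssd / (values.length - 1));            // 40.0 -> var_samp
    System.out.println(Math.sqrt(ssd / values.length));       // stddev_pop expectation
    System.out.println(Math.sqrt(ssd / (values.length - 1))); // stddev_samp expectation
  }
}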
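
The avg validator change follows the existing long/double pattern: the captured partial result is a (count, sum) pair, and the decimal branch recomputes the expected average as sum divided by count with HiveDecimal arithmetic. Note that the testAvgDecimal expectation HiveDecimal.create((1+2+3)/3) uses integer division, which is only exact here because the average of 1, 2, 3 is the whole number 2. A reduced sketch of the decimal check, with hypothetical local names standing in for the validator's vals array and expected object:

// vals[0] is the group row count, vals[1] the decimal sum, as in the validator above.
long count = ((LongWritable) vals[0]).get();                        // 3 for testAvgDecimal
HiveDecimal sum = ((HiveDecimalWritable) vals[1]).getHiveDecimal(); // 6
HiveDecimal avg = sum.divide(HiveDecimal.create(count));            // 2
assertEquals(expected, avg);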
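
testDecimalKeyTypeAggregate exercises the decimal-keyed path rather than a decimal aggregate: the key column interleaves Decimal128(1) and null, the bigint value column is 13, null, 7, 19, so sum produces 13 + 7 = 20 under the key 1 and 19 under the null key (the null value row is ignored by sum), matching buildHashMap(HiveDecimal.create(1), 20L, null, 19L). The same expectation written out directly, as a sketch with a plain HashMap standing in for the test's buildHashMap helper:

import java.util.HashMap;
import java.util.Map;

import org.apache.hadoop.hive.common.type.HiveDecimal;

// key column:   1,  null, 1, null
// value column: 13, null, 7, 19
Map<Object, Object> expected = new HashMap<Object, Object>();
expected.put(HiveDecimal.create(1), 20L);  // 13 + 7
expected.put(null, 19L);                   // null value ignored by sum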