diff --git common/src/java/org/apache/hadoop/hive/conf/HiveConf.java common/src/java/org/apache/hadoop/hive/conf/HiveConf.java index 44b9eb2..ef73f89 100644 --- common/src/java/org/apache/hadoop/hive/conf/HiveConf.java +++ common/src/java/org/apache/hadoop/hive/conf/HiveConf.java @@ -3550,6 +3550,10 @@ private static void populateLlapDaemonVarsSet(Set llapDaemonVarsSetLocal "1. chosen : use VectorUDFAdaptor for a small set of UDFs that were chosen for good performance\n" + "2. all : use VectorUDFAdaptor for all UDFs" ), + HIVE_TEST_VECTOR_ADAPTOR_OVERRIDE("hive.test.vectorized.adaptor.override", false, + "Internal use only; forces the VectorUDFAdaptor to always be used.\n" + + "The default is false.", + true), HIVE_VECTORIZATION_PTF_ENABLED("hive.vectorized.execution.ptf.enabled", true, "This flag should be set to true to enable vectorized mode of the PTF of query execution.\n" + "The default value is true."), diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgDecimal.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgDecimal.txt index fa72171..f512639 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgDecimal.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgDecimal.txt @@ -522,7 +522,7 @@ public class extends VectorAggregateExpression { fields[AVERAGE_COUNT_FIELD_INDEX].isNull[batchIndex] = false; ((LongColumnVector) fields[AVERAGE_COUNT_FIELD_INDEX]).vector[batchIndex] = myagg.count; fields[AVERAGE_SUM_FIELD_INDEX].isNull[batchIndex] = false; - ((DecimalColumnVector) fields[AVERAGE_SUM_FIELD_INDEX]).vector[batchIndex].set(myagg.sum); + ((DecimalColumnVector) fields[AVERAGE_SUM_FIELD_INDEX]).set(batchIndex, myagg.sum); // NULL out useless source field. ColumnVector sourceColVector = (ColumnVector) fields[AVERAGE_SOURCE_FIELD_INDEX]; diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgDecimalMerge.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgDecimalMerge.txt index e273d07..5fe9256 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgDecimalMerge.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFAvgDecimalMerge.txt @@ -532,7 +532,7 @@ public class extends VectorAggregateExpression { fields[AVERAGE_COUNT_FIELD_INDEX].isNull[batchIndex] = false; ((LongColumnVector) fields[AVERAGE_COUNT_FIELD_INDEX]).vector[batchIndex] = myagg.mergeCount; fields[AVERAGE_SUM_FIELD_INDEX].isNull[batchIndex] = false; - ((DecimalColumnVector) fields[AVERAGE_SUM_FIELD_INDEX]).vector[batchIndex].set(myagg.mergeSum); + ((DecimalColumnVector) fields[AVERAGE_SUM_FIELD_INDEX]).set(batchIndex, myagg.mergeSum); // NULL out useless source field. 
ColumnVector sourceColVector = (ColumnVector) fields[AVERAGE_SOURCE_FIELD_INDEX]; diff --git ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxDecimal.txt ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxDecimal.txt index 9fe85d3..9c8ebcc 100644 --- ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxDecimal.txt +++ ql/src/gen/vectorization/UDAFTemplates/VectorUDAFMinMaxDecimal.txt @@ -471,6 +471,6 @@ public class extends VectorAggregateExpression { return; } outputColVector.isNull[batchIndex] = false; - outputColVector.vector[batchIndex].set(myagg.value); + outputColVector.set(batchIndex, myagg.value); } } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapperBatch.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapperBatch.java index 0e6f8c5..689d3c3 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapperBatch.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorHashKeyWrapperBatch.java @@ -1037,7 +1037,7 @@ public void assignRowColumn(VectorizedRowBatch batch, int batchIndex, int keyInd kw.getByteLength(columnTypeSpecificIndex)); break; case DECIMAL: - ((DecimalColumnVector) colVector).vector[batchIndex].set( + ((DecimalColumnVector) colVector).set(batchIndex, kw.getDecimal(columnTypeSpecificIndex)); break; case TIMESTAMP: diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java index 491a6b1..18b4857 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java @@ -133,6 +133,7 @@ public static HiveVectorAdaptorUsageMode getHiveConfValue(HiveConf hiveConf) { } private HiveVectorAdaptorUsageMode hiveVectorAdaptorUsageMode; + private boolean testVectorAdaptorOverride; public enum HiveVectorIfStmtMode { ADAPTOR, @@ -158,6 +159,8 @@ public static HiveVectorIfStmtMode getHiveConfValue(HiveConf hiveConf) { private void setHiveConfVars(HiveConf hiveConf) { hiveVectorAdaptorUsageMode = HiveVectorAdaptorUsageMode.getHiveConfValue(hiveConf); + testVectorAdaptorOverride = + HiveConf.getBoolVar(hiveConf, ConfVars.HIVE_TEST_VECTOR_ADAPTOR_OVERRIDE); hiveVectorIfStmtMode = HiveVectorIfStmtMode.getHiveConfValue(hiveConf); this.reuseScratchColumns = HiveConf.getBoolVar(hiveConf, ConfVars.HIVE_VECTORIZATION_TESTING_REUSE_SCRATCH_COLUMNS); @@ -171,8 +174,11 @@ private void setHiveConfVars(HiveConf hiveConf) { private void copyHiveConfVars(VectorizationContext vContextEnvironment) { hiveVectorAdaptorUsageMode = vContextEnvironment.hiveVectorAdaptorUsageMode; + testVectorAdaptorOverride = vContextEnvironment.testVectorAdaptorOverride; hiveVectorIfStmtMode = vContextEnvironment.hiveVectorIfStmtMode; this.reuseScratchColumns = vContextEnvironment.reuseScratchColumns; + useCheckedVectorExpressions = vContextEnvironment.useCheckedVectorExpressions; + adaptorSuppressEvaluateExceptions = vContextEnvironment.adaptorSuppressEvaluateExceptions; this.ocm.setReuseColumns(reuseScratchColumns); } @@ -801,8 +807,12 @@ public VectorExpression getVectorExpression(ExprNodeDesc exprDesc, VectorExpress // Note: this is a no-op for custom UDFs List childExpressions = getChildExpressionsWithImplicitCast(expr.getGenericUDF(), exprDesc.getChildren(), exprDesc.getTypeInfo()); - ve = getGenericUdfVectorExpression(expr.getGenericUDF(), - childExpressions, mode, exprDesc.getTypeInfo()); + + // Are we forcing the usage of 
VectorUDFAdaptor for test purposes? + if (!testVectorAdaptorOverride) { + ve = getGenericUdfVectorExpression(expr.getGenericUDF(), + childExpressions, mode, exprDesc.getTypeInfo()); + } if (ve == null) { // Ok, no vectorized class available. No problem -- try to use the VectorUDFAdaptor // when configured. @@ -1104,7 +1114,7 @@ private int getPrecisionForType(PrimitiveTypeInfo typeInfo) { return HiveDecimalUtils.getPrecisionForType(typeInfo); } - private GenericUDF getGenericUDFForCast(TypeInfo castType) throws HiveException { + public static GenericUDF getGenericUDFForCast(TypeInfo castType) throws HiveException { UDF udfClass = null; GenericUDF genericUdf = null; switch (((PrimitiveTypeInfo) castType).getPrimitiveCategory()) { @@ -1165,8 +1175,10 @@ private GenericUDF getGenericUDFForCast(TypeInfo castType) throws HiveException if (udfClass == null) { throw new HiveException("Could not add implicit cast for type "+castType.getTypeName()); } - genericUdf = new GenericUDFBridge(); - ((GenericUDFBridge) genericUdf).setUdfClassName(udfClass.getClass().getName()); + GenericUDFBridge genericUDFBridge = new GenericUDFBridge(); + genericUDFBridge.setUdfClassName(udfClass.getClass().getName()); + genericUDFBridge.setUdfName(udfClass.getClass().getSimpleName()); + genericUdf = genericUDFBridge; } if (genericUdf instanceof SettableUDF) { ((SettableUDF) genericUdf).setTypeInfo(castType); diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedBatchUtil.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedBatchUtil.java index d92ec32..d51d44a 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedBatchUtil.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedBatchUtil.java @@ -880,6 +880,10 @@ public static VectorizedRowBatch makeLike(VectorizedRowBatch batch) throws HiveE return newBatch; } + public static Writable getPrimitiveWritable(TypeInfo typeInfo) { + return getPrimitiveWritable(((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory()); + } + public static Writable getPrimitiveWritable(PrimitiveCategory primitiveCategory) { switch (primitiveCategory) { case VOID: diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastDateToBoolean.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastDateToBoolean.java new file mode 100644 index 0000000..117e814 --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastDateToBoolean.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.hive.ql.exec.vector.expressions; + +import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor; + +/* + * Comment from BooleanWritable evaluate(DateWritable d) + * // date value to boolean doesn't make any sense. + * So, we always set the output to NULL. + */ +public class CastDateToBoolean extends NullVectorExpression { + private static final long serialVersionUID = 1L; + + private final int colNum; + + public CastDateToBoolean(int colNum, int outputColumnNum) { + super(outputColumnNum); + this.colNum = colNum; + } + + public CastDateToBoolean() { + super(); + + // Dummy final assignments. + colNum = -1; + } + + @Override + public String vectorExpressionParameters() { + return getColumnParamString(0, colNum); + } + + @Override + public VectorExpressionDescriptor.Descriptor getDescriptor() { + return (new VectorExpressionDescriptor.Builder()) + .setMode( + VectorExpressionDescriptor.Mode.PROJECTION) + .setNumArguments(1) + .setArgumentTypes( + VectorExpressionDescriptor.ArgumentType.getType("date")) + .setInputExpressionTypes( + VectorExpressionDescriptor.InputExpressionType.COLUMN).build(); + } +} diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastDecimalToDecimal.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastDecimalToDecimal.java index 5e0d570..bcf55cd 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastDecimalToDecimal.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastDecimalToDecimal.java @@ -57,7 +57,7 @@ public CastDecimalToDecimal() { */ protected void convert(DecimalColumnVector outputColVector, DecimalColumnVector inputColVector, int i) { // The set routine enforces precision and scale. - outputColVector.vector[i].set(inputColVector.vector[i]); + outputColVector.set(i, inputColVector.vector[i]); } /** diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastDoubleToDecimal.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastDoubleToDecimal.java index 4619724..de7b6de 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastDoubleToDecimal.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastDoubleToDecimal.java @@ -41,7 +41,7 @@ public CastDoubleToDecimal(int inputColumn, int outputColumnNum) { protected void func(DecimalColumnVector outV, DoubleColumnVector inV, int i) { HiveDecimalWritable decWritable = outV.vector[i]; decWritable.setFromDouble(inV.vector[i]); - if (!decWritable.isSet()) { + if (!decWritable.mutateEnforcePrecisionScale(outV.precision, outV.scale)) { outV.isNull[i] = true; outV.noNulls = false; } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastLongToDecimal.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastLongToDecimal.java index f8edbd9..fa88e3f 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastLongToDecimal.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastLongToDecimal.java @@ -41,6 +41,6 @@ public CastLongToDecimal(int inputColumn, int outputColumnNum) { @Override protected void func(DecimalColumnVector outV, LongColumnVector inV, int i) { - outV.vector[i].set(HiveDecimal.create(inV.vector[i])); + outV.set(i, HiveDecimal.create(inV.vector[i])); } } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastStringToDecimal.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastStringToDecimal.java 
index d8d7dae..7dc322e 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastStringToDecimal.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastStringToDecimal.java @@ -63,7 +63,7 @@ protected void func(DecimalColumnVector outputColVector, BytesColumnVector input * making a new string. */ s = new String(inputColVector.vector[i], inputColVector.start[i], inputColVector.length[i], "UTF-8"); - outputColVector.vector[i].set(HiveDecimal.create(s)); + outputColVector.set(i, HiveDecimal.create(s)); } catch (Exception e) { // for any exception in conversion to decimal, produce NULL diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastTimestampToLong.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastTimestampToLong.java index 42e005e..3f5f25d 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastTimestampToLong.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastTimestampToLong.java @@ -24,12 +24,16 @@ import org.apache.hadoop.hive.ql.exec.vector.expressions.MathExpr; import org.apache.hadoop.hive.ql.exec.vector.*; import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; public class CastTimestampToLong extends VectorExpression { private static final long serialVersionUID = 1L; private int colNum; + private transient PrimitiveCategory integerPrimitiveCategory; + public CastTimestampToLong(int colNum, int outputColumnNum) { super(outputColumnNum); this.colNum = colNum; @@ -40,6 +44,41 @@ public CastTimestampToLong() { } @Override + public void transientInit() throws HiveException { + integerPrimitiveCategory = ((PrimitiveTypeInfo) outputTypeInfo).getPrimitiveCategory(); + } + + private void setIntegerFromTimestamp(TimestampColumnVector inputColVector, + LongColumnVector outputColVector, int batchIndex) { + + final long longValue = inputColVector.getTimestampAsLong(batchIndex); + + boolean isInRange; + switch (integerPrimitiveCategory) { + case BYTE: + isInRange = ((byte) longValue) == longValue; + break; + case SHORT: + isInRange = ((short) longValue) == longValue; + break; + case INT: + isInRange = ((int) longValue) == longValue; + break; + case LONG: + isInRange = true; + break; + default: + throw new RuntimeException("Unexpected integer primitive category " + integerPrimitiveCategory); + } + if (isInRange) { + outputColVector.vector[batchIndex] = longValue; + } else { + outputColVector.isNull[batchIndex] = true; + outputColVector.noNulls = false; + } + } + + @Override public void evaluate(VectorizedRowBatch batch) throws HiveException { if (childExpressions != null) { @@ -52,7 +91,6 @@ public void evaluate(VectorizedRowBatch batch) throws HiveException { boolean[] inputIsNull = inputColVector.isNull; boolean[] outputIsNull = outputColVector.isNull; int n = batch.size; - long[] outputVector = outputColVector.vector; // return immediately if batch is empty if (n == 0) { @@ -65,7 +103,7 @@ public void evaluate(VectorizedRowBatch batch) throws HiveException { if (inputColVector.isRepeating) { if (inputColVector.noNulls || !inputIsNull[0]) { outputIsNull[0] = false; - outputVector[0] = inputColVector.getTimestampAsLong(0); + setIntegerFromTimestamp(inputColVector, outputColVector, 0); } else { outputIsNull[0] = true; outputColVector.noNulls = false; @@ -84,12 +122,12 @@ public void 
evaluate(VectorizedRowBatch batch) throws HiveException { final int i = sel[j]; // Set isNull before call in case it changes it mind. outputIsNull[i] = false; - outputVector[i] = inputColVector.getTimestampAsLong(i); + setIntegerFromTimestamp(inputColVector, outputColVector, i); } } else { for(int j = 0; j != n; j++) { final int i = sel[j]; - outputVector[i] = inputColVector.getTimestampAsLong(i); + setIntegerFromTimestamp(inputColVector, outputColVector, i); } } } else { @@ -101,7 +139,7 @@ public void evaluate(VectorizedRowBatch batch) throws HiveException { outputColVector.noNulls = true; } for(int i = 0; i != n; i++) { - outputVector[i] = inputColVector.getTimestampAsLong(i); + setIntegerFromTimestamp(inputColVector, outputColVector, i); } } } else /* there are NULLs in the inputColVector */ { @@ -114,20 +152,20 @@ public void evaluate(VectorizedRowBatch batch) throws HiveException { for(int j = 0; j != n; j++) { int i = sel[j]; if (!inputIsNull[i]) { - inputIsNull[i] = false; - outputVector[i] = inputColVector.getTimestampAsLong(i); + outputIsNull[i] = false; + setIntegerFromTimestamp(inputColVector, outputColVector, i); } else { - inputIsNull[i] = true; + outputIsNull[i] = true; outputColVector.noNulls = false; } } } else { for(int i = 0; i != n; i++) { if (!inputIsNull[i]) { - inputIsNull[i] = false; - outputVector[i] = inputColVector.getTimestampAsLong(i); + outputIsNull[i] = false; + setIntegerFromTimestamp(inputColVector, outputColVector, i); } else { - inputIsNull[i] = true; + outputIsNull[i] = true; outputColVector.noNulls = false; } } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/NullVectorExpression.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/NullVectorExpression.java new file mode 100644 index 0000000..b7bfe1e --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/NullVectorExpression.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.hive.ql.exec.vector.expressions; + +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; +import org.apache.hadoop.hive.ql.exec.vector.*; +import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor.Descriptor; +import org.apache.hadoop.hive.ql.metadata.HiveException; + +public class NullVectorExpression extends VectorExpression { + private static final long serialVersionUID = 1L; + + public NullVectorExpression(int outputColumnNum) { + super(outputColumnNum); + } + + public NullVectorExpression() { + super(); + } + + + @Override + public String vectorExpressionParameters() { + return null; + } + + @Override + public void evaluate(VectorizedRowBatch batch) throws HiveException { + ColumnVector colVector = batch.cols[outputColumnNum]; + colVector.isNull[0] = true; + colVector.noNulls = false; + colVector.isRepeating = true; + } + + @Override + public Descriptor getDescriptor() { + // Not applicable. + return null; + } +} diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal.java index 95703b0..315b72b 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal.java @@ -460,6 +460,6 @@ public void assignRowColumn(VectorizedRowBatch batch, int batchIndex, int column } outputColVector.isNull[batchIndex] = false; - outputColVector.vector[batchIndex].set(myagg.sum); + outputColVector.set(batchIndex, myagg.sum); } } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64ToDecimal.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64ToDecimal.java index d091f3f..117611e 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64ToDecimal.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/aggregates/VectorUDAFSumDecimal64ToDecimal.java @@ -516,6 +516,6 @@ public void assignRowColumn(VectorizedRowBatch batch, int batchIndex, int column } outputColVector.isNull[batchIndex] = false; - outputColVector.vector[batchIndex].set(myagg.regularDecimalSum); + outputColVector.set(batchIndex, myagg.regularDecimalSum); } } \ No newline at end of file diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/ptf/VectorPTFEvaluatorDecimalFirstValue.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/ptf/VectorPTFEvaluatorDecimalFirstValue.java index ce118bc..dc037ae 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/ptf/VectorPTFEvaluatorDecimalFirstValue.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/ptf/VectorPTFEvaluatorDecimalFirstValue.java @@ -102,7 +102,7 @@ public void evaluateGroupBatch(VectorizedRowBatch batch, boolean isLastGroupBatc outputColVector.isNull[0] = true; } else { outputColVector.isNull[0] = false; - outputColVector.vector[0].set(firstValue); + outputColVector.set(0, firstValue); } } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/ptf/VectorPTFGroupBatches.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/ptf/VectorPTFGroupBatches.java index 573910e..a39da0d 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/ptf/VectorPTFGroupBatches.java +++ 
ql/src/java/org/apache/hadoop/hive/ql/exec/vector/ptf/VectorPTFGroupBatches.java @@ -206,7 +206,7 @@ private void fillGroupResults(VectorizedRowBatch batch) { ((DoubleColumnVector) outputColVector).vector[0] = evaluator.getDoubleGroupResult(); break; case DECIMAL: - ((DecimalColumnVector) outputColVector).vector[0].set(evaluator.getDecimalGroupResult()); + ((DecimalColumnVector) outputColVector).set(0, evaluator.getDecimalGroupResult()); break; default: throw new RuntimeException("Unexpected column vector type " + evaluator.getResultColumnVectorType()); diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/UDFToBoolean.java ql/src/java/org/apache/hadoop/hive/ql/udf/UDFToBoolean.java index d7d8bcc..3ac7a06 100755 --- ql/src/java/org/apache/hadoop/hive/ql/udf/UDFToBoolean.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/UDFToBoolean.java @@ -27,7 +27,7 @@ import org.apache.hadoop.hive.ql.exec.vector.expressions.CastStringToLong; import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.CastDoubleToBooleanViaDoubleToLong; import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.CastLongToBooleanViaLongToLong; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.CastDateToBooleanViaLongToLong; +import org.apache.hadoop.hive.ql.exec.vector.expressions.CastDateToBoolean; import org.apache.hadoop.hive.ql.exec.vector.expressions.CastTimestampToBoolean; import org.apache.hadoop.hive.serde2.io.ByteWritable; import org.apache.hadoop.hive.serde2.io.DateWritable; @@ -48,7 +48,7 @@ * */ @VectorizedExpressions({CastLongToBooleanViaLongToLong.class, - CastDateToBooleanViaLongToLong.class, CastTimestampToBoolean.class, CastStringToBoolean.class, + CastDateToBoolean.class, CastTimestampToBoolean.class, CastStringToBoolean.class, CastDoubleToBooleanViaDoubleToLong.class, CastDecimalToBoolean.class, CastStringToLong.class}) public class UDFToBoolean extends UDF { private final BooleanWritable booleanWritable = new BooleanWritable(); diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/UDFToByte.java ql/src/java/org/apache/hadoop/hive/ql/udf/UDFToByte.java index 8c6629e..1128b32 100755 --- ql/src/java/org/apache/hadoop/hive/ql/udf/UDFToByte.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/UDFToByte.java @@ -187,7 +187,12 @@ public ByteWritable evaluate(TimestampWritable i) { if (i == null) { return null; } else { - byteWritable.set((byte)i.getSeconds()); + final long longValue = i.getSeconds(); + final byte byteValue = (byte) longValue; + if (byteValue != longValue) { + return null; + } + byteWritable.set(byteValue); return byteWritable; } } diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/UDFToInteger.java ql/src/java/org/apache/hadoop/hive/ql/udf/UDFToInteger.java index 9540449..748a688 100755 --- ql/src/java/org/apache/hadoop/hive/ql/udf/UDFToInteger.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/UDFToInteger.java @@ -197,7 +197,12 @@ public IntWritable evaluate(TimestampWritable i) { if (i == null) { return null; } else { - intWritable.set((int) i.getSeconds()); + final long longValue = i.getSeconds(); + final int intValue = (int) longValue; + if (intValue != longValue) { + return null; + } + intWritable.set(intValue); return intWritable; } } diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/UDFToShort.java ql/src/java/org/apache/hadoop/hive/ql/udf/UDFToShort.java index 94bbe82..e003ff3 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/UDFToShort.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/UDFToShort.java @@ -189,7 +189,12 @@ public ShortWritable 
evaluate(TimestampWritable i) { if (i == null) { return null; } else { - shortWritable.set((short) i.getSeconds()); + final long longValue = i.getSeconds(); + final short shortValue = (short) longValue; + if (shortValue != longValue) { + return null; + } + shortWritable.set(shortValue); return shortWritable; } } diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java index fa5c775..7877532 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java @@ -222,21 +222,23 @@ public void initExplicitSchema(Random r, List explicitTypeNameList, int "map" }; - private String getRandomTypeName(SupportedTypes supportedTypes, Set allowedTypeNameSet) { + private static String getRandomTypeName(Random random, SupportedTypes supportedTypes, + Set allowedTypeNameSet) { + String typeName = null; do { - if (r.nextInt(10 ) != 0) { - typeName = possibleHivePrimitiveTypeNames[r.nextInt(possibleHivePrimitiveTypeNames.length)]; + if (random.nextInt(10 ) != 0) { + typeName = possibleHivePrimitiveTypeNames[random.nextInt(possibleHivePrimitiveTypeNames.length)]; } else { switch (supportedTypes) { case PRIMITIVES: - typeName = possibleHivePrimitiveTypeNames[r.nextInt(possibleHivePrimitiveTypeNames.length)]; + typeName = possibleHivePrimitiveTypeNames[random.nextInt(possibleHivePrimitiveTypeNames.length)]; break; case ALL_EXCEPT_MAP: - typeName = possibleHiveComplexTypeNames[r.nextInt(possibleHiveComplexTypeNames.length - 1)]; + typeName = possibleHiveComplexTypeNames[random.nextInt(possibleHiveComplexTypeNames.length - 1)]; break; case ALL: - typeName = possibleHiveComplexTypeNames[r.nextInt(possibleHiveComplexTypeNames.length)]; + typeName = possibleHiveComplexTypeNames[random.nextInt(possibleHiveComplexTypeNames.length)]; break; } } @@ -244,17 +246,22 @@ private String getRandomTypeName(SupportedTypes supportedTypes, Set allo return typeName; } - private String getDecoratedTypeName(String typeName, SupportedTypes supportedTypes, - Set allowedTypeNameSet, int depth, int maxDepth) { + public static String getDecoratedTypeName(Random random, String typeName) { + return getDecoratedTypeName(random, typeName, null, null, 0, 1); + } + + private static String getDecoratedTypeName(Random random, String typeName, + SupportedTypes supportedTypes, Set allowedTypeNameSet, int depth, int maxDepth) { + depth++; if (depth < maxDepth) { supportedTypes = SupportedTypes.PRIMITIVES; } if (typeName.equals("char")) { - final int maxLength = 1 + r.nextInt(100); + final int maxLength = 1 + random.nextInt(100); typeName = String.format("char(%d)", maxLength); } else if (typeName.equals("varchar")) { - final int maxLength = 1 + r.nextInt(100); + final int maxLength = 1 + random.nextInt(100); typeName = String.format("varchar(%d)", maxLength); } else if (typeName.equals("decimal")) { typeName = @@ -263,26 +270,34 @@ private String getDecoratedTypeName(String typeName, SupportedTypes supportedTyp HiveDecimal.SYSTEM_DEFAULT_PRECISION, HiveDecimal.SYSTEM_DEFAULT_SCALE); } else if (typeName.equals("array")) { - String elementTypeName = getRandomTypeName(supportedTypes, allowedTypeNameSet); + String elementTypeName = getRandomTypeName(random, supportedTypes, allowedTypeNameSet); elementTypeName = - getDecoratedTypeName(elementTypeName, supportedTypes, allowedTypeNameSet, depth, maxDepth); + getDecoratedTypeName(random, 
elementTypeName, supportedTypes, allowedTypeNameSet, depth, maxDepth); typeName = String.format("array<%s>", elementTypeName); } else if (typeName.equals("map")) { - String keyTypeName = getRandomTypeName(SupportedTypes.PRIMITIVES, allowedTypeNameSet); + String keyTypeName = + getRandomTypeName( + random, SupportedTypes.PRIMITIVES, allowedTypeNameSet); keyTypeName = - getDecoratedTypeName(keyTypeName, supportedTypes, allowedTypeNameSet, depth, maxDepth); - String valueTypeName = getRandomTypeName(supportedTypes, allowedTypeNameSet); + getDecoratedTypeName( + random, keyTypeName, supportedTypes, allowedTypeNameSet, depth, maxDepth); + String valueTypeName = + getRandomTypeName( + random, supportedTypes, allowedTypeNameSet); valueTypeName = - getDecoratedTypeName(valueTypeName, supportedTypes, allowedTypeNameSet, depth, maxDepth); + getDecoratedTypeName( + random, valueTypeName, supportedTypes, allowedTypeNameSet, depth, maxDepth); typeName = String.format("map<%s,%s>", keyTypeName, valueTypeName); } else if (typeName.equals("struct")) { - final int fieldCount = 1 + r.nextInt(10); + final int fieldCount = 1 + random.nextInt(10); final StringBuilder sb = new StringBuilder(); for (int i = 0; i < fieldCount; i++) { - String fieldTypeName = getRandomTypeName(supportedTypes, allowedTypeNameSet); + String fieldTypeName = + getRandomTypeName( + random, supportedTypes, allowedTypeNameSet); fieldTypeName = getDecoratedTypeName( - fieldTypeName, supportedTypes, allowedTypeNameSet, depth, maxDepth); + random, fieldTypeName, supportedTypes, allowedTypeNameSet, depth, maxDepth); if (i > 0) { sb.append(","); } @@ -294,13 +309,15 @@ private String getDecoratedTypeName(String typeName, SupportedTypes supportedTyp typeName = String.format("struct<%s>", sb.toString()); } else if (typeName.equals("struct") || typeName.equals("uniontype")) { - final int fieldCount = 1 + r.nextInt(10); + final int fieldCount = 1 + random.nextInt(10); final StringBuilder sb = new StringBuilder(); for (int i = 0; i < fieldCount; i++) { - String fieldTypeName = getRandomTypeName(supportedTypes, allowedTypeNameSet); + String fieldTypeName = + getRandomTypeName( + random, supportedTypes, allowedTypeNameSet); fieldTypeName = getDecoratedTypeName( - fieldTypeName, supportedTypes, allowedTypeNameSet, depth, maxDepth); + random, fieldTypeName, supportedTypes, allowedTypeNameSet, depth, maxDepth); if (i > 0) { sb.append(","); } @@ -311,6 +328,11 @@ private String getDecoratedTypeName(String typeName, SupportedTypes supportedTyp return typeName; } + private String getDecoratedTypeName(String typeName, + SupportedTypes supportedTypes, Set allowedTypeNameSet, int depth, int maxDepth) { + return getDecoratedTypeName(r, typeName, supportedTypes, allowedTypeNameSet, depth, maxDepth); + } + private ObjectInspector getObjectInspector(TypeInfo typeInfo) { return getObjectInspector(typeInfo, DataTypePhysicalVariation.NONE); } @@ -454,7 +476,7 @@ private void chooseSchema(SupportedTypes supportedTypes, Set allowedType typeName = explicitTypeNameList.get(c); dataTypePhysicalVariation = explicitDataTypePhysicalVariationList.get(c); } else if (onlyOne || allowedTypeNameSet != null) { - typeName = getRandomTypeName(supportedTypes, allowedTypeNameSet); + typeName = getRandomTypeName(r, supportedTypes, allowedTypeNameSet); } else { int typeNum; if (allTypes) { diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorCastStatement.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorCastStatement.java new 
file mode 100644 index 0000000..3354185 --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorCastStatement.java @@ -0,0 +1,494 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec.vector.expressions; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; + +import org.apache.hadoop.hive.common.type.DataTypePhysicalVariation; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator; +import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluatorFactory; +import org.apache.hadoop.hive.ql.exec.vector.VectorExtractRow; +import org.apache.hadoop.hive.ql.exec.vector.VectorRandomBatchSource; +import org.apache.hadoop.hive.ql.exec.vector.VectorRandomRowSource; +import org.apache.hadoop.hive.ql.exec.vector.VectorizationContext; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatchCtx; +import org.apache.hadoop.hive.ql.exec.vector.expressions.IdentityExpression; +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.metadata.VirtualColumn; +import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc; +import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc; +import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; +import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFIf; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFWhen; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; +import org.apache.hadoop.io.LongWritable; + +import junit.framework.Assert; + +import org.junit.Ignore; +import org.junit.Test; + +public class TestVectorCastStatement { + + @Test + public void testBoolean() throws Exception { + Random random = new Random(12882); + + doIfTests(random, 
"boolean"); + } + + @Test + public void testTinyInt() throws Exception { + Random random = new Random(5371); + + doIfTests(random, "tinyint"); + } + + @Test + public void testSmallInt() throws Exception { + Random random = new Random(2772); + + doIfTests(random, "smallint"); + } + + @Test + public void testInt() throws Exception { + Random random = new Random(12882); + + doIfTests(random, "int"); + } + + @Test + public void testBigInt() throws Exception { + Random random = new Random(12882); + + doIfTests(random, "bigint"); + } + + @Test + public void testString() throws Exception { + Random random = new Random(12882); + + doIfTests(random, "string"); + } + + @Test + public void testTimestamp() throws Exception { + Random random = new Random(12882); + + doIfTests(random, "timestamp"); + } + + @Test + public void testDate() throws Exception { + Random random = new Random(12882); + + doIfTests(random, "date"); + } + + @Ignore("HIVE-19108") + @Test + public void testFloat() throws Exception { + Random random = new Random(7322); + + doIfTests(random, "float"); + } + + @Test + public void testDouble() throws Exception { + Random random = new Random(12882); + + doIfTests(random, "double"); + } + + @Test + public void testChar() throws Exception { + Random random = new Random(12882); + + doIfTests(random, "char(10)"); + } + + @Test + public void testVarchar() throws Exception { + Random random = new Random(12882); + + doIfTests(random, "varchar(15)"); + } + + @Test + public void testBinary() throws Exception { + Random random = new Random(12882); + + doIfTests(random, "binary"); + } + + @Test + public void testDecimal() throws Exception { + Random random = new Random(9300); + + doIfTests(random, "decimal(38,18)"); + doIfTests(random, "decimal(38,0)"); + doIfTests(random, "decimal(20,8)"); + doIfTests(random, "decimal(10,4)"); + } + + public enum CastStmtTestMode { + ROW_MODE, + ADAPTOR, + VECTOR_EXPRESSION; + + static final int count = values().length; + } + + private void doIfTests(Random random, String typeName) + throws Exception { + doIfTests(random, typeName, DataTypePhysicalVariation.NONE); + } + + private void doIfTests(Random random, String typeName, + DataTypePhysicalVariation dataTypePhysicalVariation) + throws Exception { + + TypeInfo typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(typeName); + PrimitiveCategory primitiveCategory = ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory(); + + for (PrimitiveCategory targetPrimitiveCategory : PrimitiveCategory.values()) { + + if (targetPrimitiveCategory == PrimitiveCategory.VOID || + targetPrimitiveCategory == PrimitiveCategory.INTERVAL_YEAR_MONTH || + targetPrimitiveCategory == PrimitiveCategory.INTERVAL_DAY_TIME || + targetPrimitiveCategory == PrimitiveCategory.TIMESTAMPLOCALTZ || + targetPrimitiveCategory == PrimitiveCategory.UNKNOWN) { + continue; + } + + // BINARY conversions supported by GenericUDFDecimal, GenericUDFTimestamp. + if (primitiveCategory == PrimitiveCategory.BINARY) { + if (targetPrimitiveCategory == PrimitiveCategory.DECIMAL || + targetPrimitiveCategory == PrimitiveCategory.TIMESTAMP) { + continue; + } + } + + // DATE conversions supported by GenericUDFDecimal. 
+ if (primitiveCategory == PrimitiveCategory.DATE) { + if (targetPrimitiveCategory == PrimitiveCategory.DECIMAL) { + continue; + } + } + + if (primitiveCategory == targetPrimitiveCategory) { + if (primitiveCategory != PrimitiveCategory.CHAR && + primitiveCategory != PrimitiveCategory.VARCHAR && + primitiveCategory != PrimitiveCategory.DECIMAL) { + continue; + } + } + + doIfTestOneCast(random, typeName, dataTypePhysicalVariation, targetPrimitiveCategory); + } + } + + private void doIfTestOneCast(Random random, String typeName, + DataTypePhysicalVariation dataTypePhysicalVariation, + PrimitiveCategory targetPrimitiveCategory) + throws Exception { + + TypeInfo typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(typeName); + + boolean isDecimal64 = (dataTypePhysicalVariation == DataTypePhysicalVariation.DECIMAL_64); + final int decimal64Scale = + (isDecimal64 ? ((DecimalTypeInfo) typeInfo).getScale() : 0); + + List explicitTypeNameList = new ArrayList(); + List explicitDataTypePhysicalVariationList = new ArrayList(); + explicitTypeNameList.add(typeName); + explicitDataTypePhysicalVariationList.add(dataTypePhysicalVariation); + + VectorRandomRowSource rowSource = new VectorRandomRowSource(); + + rowSource.initExplicitSchema( + random, explicitTypeNameList, /* maxComplexDepth */ 0, /* allowNull */ true, + explicitDataTypePhysicalVariationList); + + List columns = new ArrayList(); + columns.add("col0"); + ExprNodeColumnDesc col1Expr = new ExprNodeColumnDesc(typeInfo, "col0", "table", false); + + List children = new ArrayList(); + children.add(col1Expr); + + //---------------------------------------------------------------------------------------------- + + String targetTypeName; + if (targetPrimitiveCategory == PrimitiveCategory.BYTE) { + targetTypeName = "tinyint"; + } else if (targetPrimitiveCategory == PrimitiveCategory.SHORT) { + targetTypeName = "smallint"; + } else if (targetPrimitiveCategory == PrimitiveCategory.LONG) { + targetTypeName = "bigint"; + } else { + targetTypeName = targetPrimitiveCategory.name().toLowerCase(); + } + targetTypeName = VectorRandomRowSource.getDecoratedTypeName(random, targetTypeName); + TypeInfo targetTypeInfo = TypeInfoUtils.getTypeInfoFromTypeString(targetTypeName); + + //---------------------------------------------------------------------------------------------- + + String[] columnNames = columns.toArray(new String[0]); + + Object[][] randomRows = rowSource.randomRows(100000); + + VectorRandomBatchSource batchSource = + VectorRandomBatchSource.createInterestingBatches( + random, + rowSource, + randomRows, + null); + + final int rowCount = randomRows.length; + Object[][] resultObjectsArray = new Object[CastStmtTestMode.count][]; + for (int i = 0; i < CastStmtTestMode.count; i++) { + + Object[] resultObjects = new Object[rowCount]; + resultObjectsArray[i] = resultObjects; + + CastStmtTestMode ifStmtTestMode = CastStmtTestMode.values()[i]; + switch (ifStmtTestMode) { + case ROW_MODE: + if (!doRowCastTest( + typeInfo, + targetTypeInfo, + columns, + children, + randomRows, + rowSource.rowStructObjectInspector(), + resultObjects)) { + return; + } + break; + case ADAPTOR: + case VECTOR_EXPRESSION: + if (!doVectorCastTest( + typeInfo, + targetTypeInfo, + columns, + columnNames, + rowSource.typeInfos(), + rowSource.dataTypePhysicalVariations(), + children, + ifStmtTestMode, + batchSource, + resultObjects)) { + return; + } + break; + default: + throw new RuntimeException("Unexpected IF statement test mode " + ifStmtTestMode); + } + } + + for (int i = 0; i < rowCount; 
i++) { + // Row-mode is the expected value. + Object expectedResult = resultObjectsArray[0][i]; + + for (int v = 1; v < CastStmtTestMode.count; v++) { + Object vectorResult = resultObjectsArray[v][i]; + if (expectedResult == null || vectorResult == null) { + if (expectedResult != null || vectorResult != null) { + Assert.fail( + "Row " + i + + " sourceTypeName " + typeName + + " targetTypeName " + targetTypeName + + " " + CastStmtTestMode.values()[v] + + " result is NULL " + (vectorResult == null ? "YES" : "NO") + + " does not match row-mode expected result is NULL " + + (expectedResult == null ? "YES" : "NO")); + } + } else { + + if (isDecimal64 && expectedResult instanceof LongWritable) { + + HiveDecimalWritable expectedHiveDecimalWritable = new HiveDecimalWritable(0); + expectedHiveDecimalWritable.deserialize64( + ((LongWritable) expectedResult).get(), decimal64Scale); + expectedResult = expectedHiveDecimalWritable; + } + + if (!expectedResult.equals(vectorResult)) { + Assert.fail( + "Row " + i + + " sourceTypeName " + typeName + + " targetTypeName " + targetTypeName + + " " + CastStmtTestMode.values()[v] + + " result " + vectorResult.toString() + + " (" + vectorResult.getClass().getSimpleName() + ")" + + " does not match row-mode expected result " + expectedResult.toString() + + " (" + expectedResult.getClass().getSimpleName() + ")"); + } + } + } + } + } + + private boolean doRowCastTest(TypeInfo typeInfo, TypeInfo targetTypeInfo, + List columns, List children, + Object[][] randomRows, ObjectInspector rowInspector, Object[] resultObjects) + throws Exception { + + GenericUDF udf; + try { + udf = VectorizationContext.getGenericUDFForCast(targetTypeInfo); + } catch (HiveException e) { + return false; + } + + ExprNodeGenericFuncDesc exprDesc = + new ExprNodeGenericFuncDesc(targetTypeInfo, udf, children); + HiveConf hiveConf = new HiveConf(); + ExprNodeEvaluator evaluator = + ExprNodeEvaluatorFactory.get(exprDesc, hiveConf); + try { + evaluator.initialize(rowInspector); + } catch (HiveException e) { + return false; + } + + ObjectInspector objectInspector = TypeInfoUtils + .getStandardWritableObjectInspectorFromTypeInfo(targetTypeInfo); + + final int rowCount = randomRows.length; + for (int i = 0; i < rowCount; i++) { + Object[] row = randomRows[i]; + Object result = evaluator.evaluate(row); + Object copyResult = + ObjectInspectorUtils.copyToStandardObject( + result, objectInspector, ObjectInspectorCopyOption.WRITABLE); + resultObjects[i] = copyResult; + } + + return true; + } + + private void extractResultObjects(VectorizedRowBatch batch, int rowIndex, + VectorExtractRow resultVectorExtractRow, Object[] scrqtchRow, Object[] resultObjects) { + // UNDONE: selectedInUse + for (int i = 0; i < batch.size; i++) { + resultVectorExtractRow.extractRow(batch, i, scrqtchRow); + + // UNDONE: Need to copy the object. 
+ resultObjects[rowIndex++] = scrqtchRow[0]; + } + } + + private boolean doVectorCastTest(TypeInfo typeInfo, TypeInfo targetTypeInfo, + List columns, String[] columnNames, + TypeInfo[] typeInfos, DataTypePhysicalVariation[] dataTypePhysicalVariations, + List children, + CastStmtTestMode castStmtTestMode, + VectorRandomBatchSource batchSource, + Object[] resultObjects) + throws Exception { + + GenericUDF udf; + try { + udf = VectorizationContext.getGenericUDFForCast(targetTypeInfo); + } catch (HiveException e) { + return false; + } + + ExprNodeGenericFuncDesc exprDesc = + new ExprNodeGenericFuncDesc(targetTypeInfo, udf, children); + + HiveConf hiveConf = new HiveConf(); + if (castStmtTestMode == CastStmtTestMode.ADAPTOR) { + hiveConf.setBoolVar(HiveConf.ConfVars.HIVE_TEST_VECTOR_ADAPTOR_OVERRIDE, true); + } + + VectorizationContext vectorizationContext = + new VectorizationContext( + "name", + columns, + Arrays.asList(typeInfos), + Arrays.asList(dataTypePhysicalVariations), + hiveConf); + VectorExpression vectorExpression = vectorizationContext.getVectorExpression(exprDesc); + vectorExpression.transientInit(); + + /* + System.out.println( + "*DEBUG* typeInfo " + typeInfo.toString() + + " targetTypeInfo " + targetTypeInfo + + " castStmtTestMode " + castStmtTestMode + + " vectorExpression " + vectorExpression.getClass().getSimpleName()); + */ + + VectorRandomRowSource rowSource = batchSource.getRowSource(); + VectorizedRowBatchCtx batchContext = + new VectorizedRowBatchCtx( + columnNames, + rowSource.typeInfos(), + rowSource.dataTypePhysicalVariations(), + /* dataColumnNums */ null, + /* partitionColumnCount */ 0, + /* virtualColumnCount */ 0, + /* neededVirtualColumns */ null, + vectorizationContext.getScratchColumnTypeNames(), + vectorizationContext.getScratchDataTypePhysicalVariations()); + + VectorizedRowBatch batch = batchContext.createVectorizedRowBatch(); + + VectorExtractRow resultVectorExtractRow = new VectorExtractRow(); + + resultVectorExtractRow.init( + new TypeInfo[] { targetTypeInfo }, new int[] { vectorExpression.getOutputColumnNum() }); + Object[] scrqtchRow = new Object[1]; + + batchSource.resetBatchIteration(); + int rowIndex = 0; + while (true) { + if (!batchSource.fillNextBatch(batch)) { + break; + } + vectorExpression.evaluate(batch); + extractResultObjects(batch, rowIndex, resultVectorExtractRow, scrqtchRow, resultObjects); + rowIndex += batch.size; + } + + return true; + } +} diff --git ql/src/test/results/clientpositive/llap/vector_decimal_aggregate.q.out ql/src/test/results/clientpositive/llap/vector_decimal_aggregate.q.out index 902d137..6cd1e8d 100644 --- ql/src/test/results/clientpositive/llap/vector_decimal_aggregate.q.out +++ ql/src/test/results/clientpositive/llap/vector_decimal_aggregate.q.out @@ -806,7 +806,7 @@ POSTHOOK: Input: default@decimal_vgby_small 626923679 1024 9723.40270 -9778.95135 10541.05247 10.29399655273437500000000000000 5742.091453325365 5744.897264122335 1024 11646 -11712 12641 12.3447 6877.306686989158 6880.6672084147185 6981 2 -515.62107 -515.62107 -1031.24214 -515.62107000000000000000000000000 0.0 0.0 3 6984454 -618 6983218 2327739.3333 3292794.518850853 4032833.1995089175 762 1 1531.21941 1531.21941 1531.21941 1531.21941000000000000000000000000 0.0 NULL 2 6984454 1834 6986288 3493144.0000 3491310.0 4937457.95244881 -NULL 3072 9318.43514 -4298.15135 5018444.11392 1633.60811000000000000000000000000 5695.4830839098695 5696.410309489299 3072 11161 -5148 6010880 1956.6667 6821.647911041892 6822.758476439734 +NULL 3072 9318.43514 -4298.15135 
5018444.11392 NULL 5695.4830839098695 5696.410309489299 3072 11161 -5148 6010880 1956.6667 6821.647911041892 6822.758476439734 PREHOOK: query: SELECT SUM(HASH(*)) FROM (SELECT cint, COUNT(cdecimal1), MAX(cdecimal1), MIN(cdecimal1), SUM(cdecimal1), AVG(cdecimal1), STDDEV_POP(cdecimal1), STDDEV_SAMP(cdecimal1), @@ -825,4 +825,4 @@ FROM (SELECT cint, POSTHOOK: type: QUERY POSTHOOK: Input: default@decimal_vgby_small #### A masked pattern was here #### -96966670826 +96673467876 diff --git ql/src/test/results/clientpositive/spark/vector_decimal_aggregate.q.out ql/src/test/results/clientpositive/spark/vector_decimal_aggregate.q.out index d37a27e..c46b607 100644 --- ql/src/test/results/clientpositive/spark/vector_decimal_aggregate.q.out +++ ql/src/test/results/clientpositive/spark/vector_decimal_aggregate.q.out @@ -796,7 +796,7 @@ POSTHOOK: Input: default@decimal_vgby_small 626923679 1024 9723.40270 -9778.95135 10541.05247 10.29399655273437500000000000000 5742.091453325365 5744.897264122335 1024 11646 -11712 12641 12.3447 6877.306686989158 6880.6672084147185 6981 2 -515.62107 -515.62107 -1031.24214 -515.62107000000000000000000000000 0.0 0.0 3 6984454 -618 6983218 2327739.3333 3292794.518850853 4032833.1995089175 762 1 1531.21941 1531.21941 1531.21941 1531.21941000000000000000000000000 0.0 NULL 2 6984454 1834 6986288 3493144.0000 3491310.0 4937457.95244881 -NULL 3072 9318.43514 -4298.15135 5018444.11392 1633.60811000000000000000000000000 5695.4830839098695 5696.410309489299 3072 11161 -5148 6010880 1956.6667 6821.647911041892 6822.758476439734 +NULL 3072 9318.43514 -4298.15135 5018444.11392 NULL 5695.4830839098695 5696.410309489299 3072 11161 -5148 6010880 1956.6667 6821.647911041892 6822.758476439734 PREHOOK: query: SELECT SUM(HASH(*)) FROM (SELECT cint, COUNT(cdecimal1), MAX(cdecimal1), MIN(cdecimal1), SUM(cdecimal1), AVG(cdecimal1), STDDEV_POP(cdecimal1), STDDEV_SAMP(cdecimal1), @@ -815,4 +815,4 @@ FROM (SELECT cint, POSTHOOK: type: QUERY POSTHOOK: Input: default@decimal_vgby_small #### A masked pattern was here #### -96966670826 +96673467876 diff --git ql/src/test/results/clientpositive/vector_decimal_aggregate.q.out ql/src/test/results/clientpositive/vector_decimal_aggregate.q.out index 16c80f0..04c534e 100644 --- ql/src/test/results/clientpositive/vector_decimal_aggregate.q.out +++ ql/src/test/results/clientpositive/vector_decimal_aggregate.q.out @@ -379,7 +379,7 @@ STAGE PLANS: Map Operator Tree: TableScan alias: decimal_vgby_small - Statistics: Num rows: 12289 Data size: 346472 Basic stats: COMPLETE Column stats: NONE + Statistics: Num rows: 12289 Data size: 346462 Basic stats: COMPLETE Column stats: NONE TableScan Vectorization: native: true vectorizationSchemaColumns: [0:cdouble:double, 1:cdecimal1:decimal(11,5)/DECIMAL_64, 2:cdecimal2:decimal(16,0)/DECIMAL_64, 3:cint:int, 4:ROW__ID:struct] @@ -390,7 +390,7 @@ STAGE PLANS: className: VectorSelectOperator native: true projectedOutputColumnNums: [1, 2, 3] - Statistics: Num rows: 12289 Data size: 346472 Basic stats: COMPLETE Column stats: NONE + Statistics: Num rows: 12289 Data size: 346462 Basic stats: COMPLETE Column stats: NONE Group By Operator aggregations: count(cdecimal1), max(cdecimal1), min(cdecimal1), sum(cdecimal1), count(cdecimal2), max(cdecimal2), min(cdecimal2), sum(cdecimal2), count() Group By Vectorization: @@ -404,7 +404,7 @@ STAGE PLANS: keys: cint (type: int) mode: hash outputColumnNames: _col0, _col1, _col2, _col3, _col4, _col5, _col6, _col7, _col8, _col9 - Statistics: Num rows: 12289 Data size: 346472 Basic stats: 
COMPLETE Column stats: NONE + Statistics: Num rows: 12289 Data size: 346462 Basic stats: COMPLETE Column stats: NONE Reduce Output Operator key expressions: _col0 (type: int) sort order: + @@ -414,7 +414,7 @@ STAGE PLANS: native: false nativeConditionsMet: hive.vectorized.execution.reducesink.new.enabled IS true, No PTF TopN IS true, No DISTINCT columns IS true, BinarySortableSerDe for keys IS true, LazyBinarySerDe for values IS true nativeConditionsNotMet: hive.execution.engine mr IN [tez, spark] IS false - Statistics: Num rows: 12289 Data size: 346472 Basic stats: COMPLETE Column stats: NONE + Statistics: Num rows: 12289 Data size: 346462 Basic stats: COMPLETE Column stats: NONE value expressions: _col1 (type: bigint), _col2 (type: decimal(11,5)), _col3 (type: decimal(11,5)), _col4 (type: decimal(21,5)), _col5 (type: bigint), _col6 (type: decimal(16,0)), _col7 (type: decimal(16,0)), _col8 (type: decimal(26,0)), _col9 (type: bigint) Execution mode: vectorized Map Vectorization: @@ -442,17 +442,17 @@ STAGE PLANS: keys: KEY._col0 (type: int) mode: mergepartial outputColumnNames: _col0, _col1, _col2, _col3, _col4, _col5, _col6, _col7, _col8, _col9 - Statistics: Num rows: 6144 Data size: 173221 Basic stats: COMPLETE Column stats: NONE + Statistics: Num rows: 6144 Data size: 173216 Basic stats: COMPLETE Column stats: NONE Filter Operator predicate: (_col9 > 1L) (type: boolean) - Statistics: Num rows: 2048 Data size: 57740 Basic stats: COMPLETE Column stats: NONE + Statistics: Num rows: 2048 Data size: 57738 Basic stats: COMPLETE Column stats: NONE Select Operator expressions: _col0 (type: int), _col1 (type: bigint), _col2 (type: decimal(11,5)), _col3 (type: decimal(11,5)), _col4 (type: decimal(21,5)), _col5 (type: bigint), _col6 (type: decimal(16,0)), _col7 (type: decimal(16,0)), _col8 (type: decimal(26,0)) outputColumnNames: _col0, _col1, _col2, _col3, _col4, _col5, _col6, _col7, _col8 - Statistics: Num rows: 2048 Data size: 57740 Basic stats: COMPLETE Column stats: NONE + Statistics: Num rows: 2048 Data size: 57738 Basic stats: COMPLETE Column stats: NONE File Output Operator compressed: false - Statistics: Num rows: 2048 Data size: 57740 Basic stats: COMPLETE Column stats: NONE + Statistics: Num rows: 2048 Data size: 57738 Basic stats: COMPLETE Column stats: NONE table: input format: org.apache.hadoop.mapred.SequenceFileInputFormat output format: org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat @@ -539,7 +539,7 @@ STAGE PLANS: Map Operator Tree: TableScan alias: decimal_vgby_small - Statistics: Num rows: 12289 Data size: 346472 Basic stats: COMPLETE Column stats: NONE + Statistics: Num rows: 12289 Data size: 346462 Basic stats: COMPLETE Column stats: NONE TableScan Vectorization: native: true vectorizationSchemaColumns: [0:cdouble:double, 1:cdecimal1:decimal(11,5)/DECIMAL_64, 2:cdecimal2:decimal(16,0)/DECIMAL_64, 3:cint:int, 4:ROW__ID:struct] @@ -551,7 +551,7 @@ STAGE PLANS: native: true projectedOutputColumnNums: [3, 1, 2, 6, 9, 7, 12] selectExpressions: CastDecimalToDouble(col 5:decimal(11,5))(children: ConvertDecimal64ToDecimal(col 1:decimal(11,5)/DECIMAL_64) -> 5:decimal(11,5)) -> 6:double, DoubleColMultiplyDoubleColumn(col 7:double, col 8:double)(children: CastDecimalToDouble(col 5:decimal(11,5))(children: ConvertDecimal64ToDecimal(col 1:decimal(11,5)/DECIMAL_64) -> 5:decimal(11,5)) -> 7:double, CastDecimalToDouble(col 5:decimal(11,5))(children: ConvertDecimal64ToDecimal(col 1:decimal(11,5)/DECIMAL_64) -> 5:decimal(11,5)) -> 8:double) -> 9:double, CastDecimalToDouble(col 
10:decimal(16,0))(children: ConvertDecimal64ToDecimal(col 2:decimal(16,0)/DECIMAL_64) -> 10:decimal(16,0)) -> 7:double, DoubleColMultiplyDoubleColumn(col 8:double, col 11:double)(children: CastDecimalToDouble(col 10:decimal(16,0))(children: ConvertDecimal64ToDecimal(col 2:decimal(16,0)/DECIMAL_64) -> 10:decimal(16,0)) -> 8:double, CastDecimalToDouble(col 10:decimal(16,0))(children: ConvertDecimal64ToDecimal(col 2:decimal(16,0)/DECIMAL_64) -> 10:decimal(16,0)) -> 11:double) -> 12:double - Statistics: Num rows: 12289 Data size: 346472 Basic stats: COMPLETE Column stats: NONE + Statistics: Num rows: 12289 Data size: 346462 Basic stats: COMPLETE Column stats: NONE Group By Operator aggregations: count(_col1), max(_col1), min(_col1), sum(_col1), sum(_col4), sum(_col3), count(_col2), max(_col2), min(_col2), sum(_col2), sum(_col6), sum(_col5), count() Group By Vectorization: @@ -565,7 +565,7 @@ STAGE PLANS: keys: _col0 (type: int) mode: hash outputColumnNames: _col0, _col1, _col2, _col3, _col4, _col5, _col6, _col7, _col8, _col9, _col10, _col11, _col12, _col13 - Statistics: Num rows: 12289 Data size: 346472 Basic stats: COMPLETE Column stats: NONE + Statistics: Num rows: 12289 Data size: 346462 Basic stats: COMPLETE Column stats: NONE Reduce Output Operator key expressions: _col0 (type: int) sort order: + @@ -575,7 +575,7 @@ STAGE PLANS: native: false nativeConditionsMet: hive.vectorized.execution.reducesink.new.enabled IS true, No PTF TopN IS true, No DISTINCT columns IS true, BinarySortableSerDe for keys IS true, LazyBinarySerDe for values IS true nativeConditionsNotMet: hive.execution.engine mr IN [tez, spark] IS false - Statistics: Num rows: 12289 Data size: 346472 Basic stats: COMPLETE Column stats: NONE + Statistics: Num rows: 12289 Data size: 346462 Basic stats: COMPLETE Column stats: NONE value expressions: _col1 (type: bigint), _col2 (type: decimal(11,5)), _col3 (type: decimal(11,5)), _col4 (type: decimal(21,5)), _col5 (type: double), _col6 (type: double), _col7 (type: bigint), _col8 (type: decimal(16,0)), _col9 (type: decimal(16,0)), _col10 (type: decimal(26,0)), _col11 (type: double), _col12 (type: double), _col13 (type: bigint) Execution mode: vectorized Map Vectorization: @@ -603,17 +603,17 @@ STAGE PLANS: keys: KEY._col0 (type: int) mode: mergepartial outputColumnNames: _col0, _col1, _col2, _col3, _col4, _col5, _col6, _col7, _col8, _col9, _col10, _col11, _col12, _col13 - Statistics: Num rows: 6144 Data size: 173221 Basic stats: COMPLETE Column stats: NONE + Statistics: Num rows: 6144 Data size: 173216 Basic stats: COMPLETE Column stats: NONE Filter Operator predicate: (_col13 > 1L) (type: boolean) - Statistics: Num rows: 2048 Data size: 57740 Basic stats: COMPLETE Column stats: NONE + Statistics: Num rows: 2048 Data size: 57738 Basic stats: COMPLETE Column stats: NONE Select Operator expressions: _col0 (type: int), _col1 (type: bigint), _col2 (type: decimal(11,5)), _col3 (type: decimal(11,5)), _col4 (type: decimal(21,5)), (CAST( _col4 AS decimal(15,9)) / _col1) (type: decimal(35,29)), power(((_col5 - ((_col6 * _col6) / _col1)) / _col1), 0.5) (type: double), power(((_col5 - ((_col6 * _col6) / _col1)) / CASE WHEN ((_col1 = 1L)) THEN (null) ELSE ((_col1 - 1)) END), 0.5) (type: double), _col7 (type: bigint), _col8 (type: decimal(16,0)), _col9 (type: decimal(16,0)), _col10 (type: decimal(26,0)), CAST( (CAST( _col10 AS decimal(20,4)) / _col7) AS decimal(20,4)) (type: decimal(20,4)), power(((_col11 - ((_col12 * _col12) / _col7)) / _col7), 0.5) (type: double), power(((_col11 - ((_col12 * 
_col12) / _col7)) / CASE WHEN ((_col7 = 1L)) THEN (null) ELSE ((_col7 - 1)) END), 0.5) (type: double) outputColumnNames: _col0, _col1, _col2, _col3, _col4, _col5, _col6, _col7, _col8, _col9, _col10, _col11, _col12, _col13, _col14 - Statistics: Num rows: 2048 Data size: 57740 Basic stats: COMPLETE Column stats: NONE + Statistics: Num rows: 2048 Data size: 57738 Basic stats: COMPLETE Column stats: NONE File Output Operator compressed: false - Statistics: Num rows: 2048 Data size: 57740 Basic stats: COMPLETE Column stats: NONE + Statistics: Num rows: 2048 Data size: 57738 Basic stats: COMPLETE Column stats: NONE table: input format: org.apache.hadoop.mapred.SequenceFileInputFormat output format: org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat diff --git vector-code-gen/src/org/apache/hadoop/hive/tools/GenVectorCode.java vector-code-gen/src/org/apache/hadoop/hive/tools/GenVectorCode.java index b5220a0..45fa739 100644 --- vector-code-gen/src/org/apache/hadoop/hive/tools/GenVectorCode.java +++ vector-code-gen/src/org/apache/hadoop/hive/tools/GenVectorCode.java @@ -1035,8 +1035,6 @@ "", "", ""}, {"ColumnUnaryFunc", "CastLongToBooleanVia", "long", "long", "MathExpr.toBool", "", "", "", ""}, - {"ColumnUnaryFunc", "CastDateToBooleanVia", "long", "long", "MathExpr.toBool", "", - "", "", "date"}, // Boolean to long is done with an IdentityExpression // Boolean to double is done with standard Long to Double cast
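
Note on CastDateToBoolean and NullVectorExpression: the removed CastDateToBooleanVia template converted the date's underlying long through MathExpr.toBool, but, as the comment carried over from UDFToBoolean says, a date-to-boolean conversion does not make sense, so the vectorized path now returns NULL for every row. The batch-level idiom for an all-NULL output column is to mark entry 0 NULL and flag the column as repeating. A minimal sketch; the class and method names below are illustrative, not part of the patch:

import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;

public final class AllNullColumnSketch {

  // Marks every row of the output column NULL without touching individual rows;
  // isRepeating tells readers to consult only entry 0.
  static void setAllNull(VectorizedRowBatch batch, int outputColumnNum) {
    ColumnVector colVector = batch.cols[outputColumnNum];
    colVector.isNull[0] = true;
    colVector.noNulls = false;
    colVector.isRepeating = true;
  }
}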
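
Note on the timestamp-to-integer casts: CastTimestampToLong.setIntegerFromTimestamp() and the UDFToByte/UDFToShort/UDFToInteger row-mode changes apply the same rule: if the seconds value does not fit the target integer width, the result is NULL instead of a silently wrapped number. A self-contained sketch of that round-trip range check; the class name and main() driver are illustrative only:

import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;

public final class IntegerRangeCheckSketch {

  // A long fits the target integer width only if narrowing it and widening it
  // back reproduces the original value; otherwise the cast result becomes NULL.
  static boolean isInRange(long longValue, PrimitiveCategory target) {
    switch (target) {
      case BYTE:  return ((byte) longValue) == longValue;
      case SHORT: return ((short) longValue) == longValue;
      case INT:   return ((int) longValue) == longValue;
      case LONG:  return true;
      default:
        throw new IllegalArgumentException("Unexpected integer primitive category " + target);
    }
  }

  public static void main(String[] args) {
    // A seconds value far beyond Integer.MAX_VALUE no longer wraps into an int.
    System.out.println(isInRange(1_000_000_000_000L, PrimitiveCategory.INT));  // false
    System.out.println(isInRange(1_000_000_000_000L, PrimitiveCategory.LONG)); // true
  }
}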
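
Note on the DecimalColumnVector changes: replacing outputColVector.vector[i].set(value) with outputColVector.set(i, value) routes every write through the column's declared precision and scale, so a value that does not fit is turned into SQL NULL rather than stored unenforced. That also appears to explain the vector_decimal_aggregate.q.out updates, where an AVG whose intermediate cast overflows decimal(15,9) now reports NULL and the SUM(HASH(*)) checksum changes accordingly. A minimal sketch of the two write idioms the patch uses, under the assumption that DecimalColumnVector.set() nulls out rows whose value cannot be represented; class and method names are illustrative:

import org.apache.hadoop.hive.common.type.HiveDecimal;
import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector;
import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;

public class DecimalWriteSketch {

  // Idiom 1: let DecimalColumnVector.set() enforce the output precision/scale.
  static void writeValue(DecimalColumnVector outV, int i, HiveDecimal value) {
    outV.isNull[i] = false;
    outV.set(i, value);
    if (outV.isNull[i]) {
      // set() rejected the value; keep the batch-level null flag consistent.
      outV.noNulls = false;
    }
  }

  // Idiom 2 (as in CastDoubleToDecimal): update the writable in place, then
  // explicitly enforce the output precision/scale and NULL the row on failure.
  static void writeFromDouble(DecimalColumnVector outV, double d, int i) {
    HiveDecimalWritable decWritable = outV.vector[i];
    decWritable.setFromDouble(d);
    if (!decWritable.mutateEnforcePrecisionScale(outV.precision, outV.scale)) {
      outV.isNull[i] = true;
      outV.noNulls = false;
    }
  }
}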
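
Note on hive.test.vectorized.adaptor.override: the flag is internal-only and exists so a test can force VectorizationContext to skip the dedicated vectorized classes and route every UDF through VectorUDFAdaptor, which is how TestVectorCastStatement drives its ADAPTOR mode and compares it against ROW_MODE and VECTOR_EXPRESSION results. A hedged sketch of building such a configuration; the helper class is hypothetical:

import org.apache.hadoop.hive.conf.HiveConf;

public final class AdaptorOverrideConfSketch {

  // Builds a HiveConf that forces expressions through VectorUDFAdaptor,
  // mirroring what TestVectorCastStatement does for its ADAPTOR test mode.
  static HiveConf adaptorOverrideConf() {
    HiveConf hiveConf = new HiveConf();
    hiveConf.setBoolVar(HiveConf.ConfVars.HIVE_TEST_VECTOR_ADAPTOR_OVERRIDE, true);
    return hiveConf;
  }
}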
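
Note on getGenericUDFForCast(): it is now public static so tests such as TestVectorCastStatement can obtain the same cast UDF the compiler would pick, and the GenericUDFBridge built for legacy UDF-based casts now carries the UDF's simple name in addition to its class name. A small sketch of that construction; the wrapper class and method are illustrative, not part of the patch:

import org.apache.hadoop.hive.ql.udf.UDFToByte;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBridge;

public final class CastBridgeSketch {

  // Wraps a legacy UDF-style cast (here UDFToByte) in a GenericUDFBridge the
  // way the patched getGenericUDFForCast() does, setting both names.
  static GenericUDFBridge bridgeFor(Class<?> udfClass) {
    GenericUDFBridge genericUDFBridge = new GenericUDFBridge();
    genericUDFBridge.setUdfClassName(udfClass.getName());
    genericUDFBridge.setUdfName(udfClass.getSimpleName());
    return genericUDFBridge;
  }

  public static void main(String[] args) {
    System.out.println(bridgeFor(UDFToByte.class).getUdfName());
  }
}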