diff --git ql/src/gen/vectorization/ExpressionTemplates/ColumnDivideColumn.txt ql/src/gen/vectorization/ExpressionTemplates/ColumnDivideColumn.txt
index c4a76ae..954b90e 100644
--- ql/src/gen/vectorization/ExpressionTemplates/ColumnDivideColumn.txt
+++ ql/src/gen/vectorization/ExpressionTemplates/ColumnDivideColumn.txt
@@ -86,19 +86,40 @@ public class <ClassName> extends VectorExpression {
      */
     boolean hasDivBy0 = false;
     if (inputColVector1.isRepeating && inputColVector2.isRepeating) {
-      denom = vector2[0];
-      outputVector[0] = vector1[0] <OperatorSymbol> denom;
+      final <OperandType2> denom = vector2[0];
       hasDivBy0 = hasDivBy0 || (denom == 0);
+#IF MANUAL_DIVIDE_BY_ZERO_CHECK
+      if (denom != 0) {
+        outputVector[0] = vector1[0] <OperatorSymbol> denom;
+      }
+#ELSE
+      outputVector[0] = vector1[0] <OperatorSymbol> denom;
+#ENDIF MANUAL_DIVIDE_BY_ZERO_CHECK
     } else if (inputColVector1.isRepeating) {
       final <OperandType1> vector1Value = vector1[0];
       if (batch.selectedInUse) {
         for(int j = 0; j != n; j++) {
           int i = sel[j];
-          denom = vector2[i];
-          outputVector[i] = vector1Value <OperatorSymbol> denom;
+          final <OperandType2> denom = vector2[i];
           hasDivBy0 = hasDivBy0 || (denom == 0);
+#IF MANUAL_DIVIDE_BY_ZERO_CHECK
+          if (denom != 0) {
+            outputVector[i] = vector1Value <OperatorSymbol> denom;
+          }
+#ELSE
+          outputVector[i] = vector1Value <OperatorSymbol> denom;
+#ENDIF MANUAL_DIVIDE_BY_ZERO_CHECK
         }
       } else {
+#IF MANUAL_DIVIDE_BY_ZERO_CHECK
+        for(int i = 0; i != n; i++) {
+          final <OperandType2> denom = vector2[i];
+          hasDivBy0 = hasDivBy0 || (denom == 0);
+          if (denom != 0) {
+            outputVector[i] = vector1Value <OperatorSymbol> denom;
+          }
+        }
+#ELSE
         for(int i = 0; i != n; i++) {
           outputVector[i] = vector1Value <OperatorSymbol> vector2[i];
         }
@@ -106,6 +127,7 @@ public class <ClassName> extends VectorExpression {
         for(int i = 0; i != n; i++) {
           hasDivBy0 = hasDivBy0 || (vector2[i] == 0);
         }
+#ENDIF MANUAL_DIVIDE_BY_ZERO_CHECK
       }
     } else if (inputColVector2.isRepeating) {
       final <OperandType2> vector2Value = vector2[0];
@@ -128,11 +150,26 @@ public class <ClassName> extends VectorExpression {
       if (batch.selectedInUse) {
         for(int j = 0; j != n; j++) {
           int i = sel[j];
-          denom = vector2[i];
-          outputVector[i] = vector1[i] <OperatorSymbol> denom;
+          final <OperandType2> denom = vector2[i];
           hasDivBy0 = hasDivBy0 || (denom == 0);
+#IF MANUAL_DIVIDE_BY_ZERO_CHECK
+          if (denom != 0) {
+            outputVector[i] = vector1[i] <OperatorSymbol> denom;
+          }
+#ELSE
+          outputVector[i] = vector1[i] <OperatorSymbol> denom;
+#ENDIF MANUAL_DIVIDE_BY_ZERO_CHECK
         }
       } else {
+#IF MANUAL_DIVIDE_BY_ZERO_CHECK
+        for(int i = 0; i != n; i++) {
+          final <OperandType2> denom = vector2[i];
+          hasDivBy0 = hasDivBy0 || (denom == 0);
+          if (denom != 0) {
+            outputVector[i] = vector1[i] <OperatorSymbol> denom;
+          }
+        }
+#ELSE
         for(int i = 0; i != n; i++) {
           outputVector[i] = vector1[i] <OperatorSymbol> vector2[i];
         }
@@ -140,13 +177,14 @@ public class <ClassName> extends VectorExpression {
         for(int i = 0; i != n; i++) {
           hasDivBy0 = hasDivBy0 || (vector2[i] == 0);
         }
+#ENDIF MANUAL_DIVIDE_BY_ZERO_CHECK
       }
     }

 #IF CHECKED
-    //when operating in checked mode make sure we handle overflows similar to non-vectorized expression
-    OverflowUtils.accountForOverflow(getOutputTypeInfo(), outputColVector,
-      batch.selectedInUse, sel, n);
+    //when operating in checked mode make sure we handle overflows similar to non-vectorized expression
+    OverflowUtils.accountForOverflow(getOutputTypeInfo(), outputColVector,
+        batch.selectedInUse, sel, n);
 #ELSE
 #ENDIF CHECKED

     /* For the case when the output can have null values, follow
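To make the template above easier to review: a minimal sketch (not part of the patch) of what the new MANUAL_DIVIDE_BY_ZERO_CHECK branch expands to once GenVectorCode binds <OperandType1>/<OperandType2> to long and <OperatorSymbol> to % for the generated long-modulo kernel:

    // Illustrative expansion of the non-selected inner loop only.
    boolean hasDivBy0 = false;
    for (int i = 0; i != n; i++) {
      final long denom = vector2[i];
      hasDivBy0 = hasDivBy0 || (denom == 0);
      if (denom != 0) {
        // Guard: integer % 0 would throw ArithmeticException, unlike double % 0 (NaN).
        outputVector[i] = vector1[i] % denom;
      }
    }

This is also why only the long/long rows in the GenVectorCode table below request MANUAL_DIVIDE_BY_ZERO_CHECK: the double variants rely on IEEE NaN semantics instead.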
diff --git ql/src/gen/vectorization/ExpressionTemplates/ScalarDivideColumn.txt ql/src/gen/vectorization/ExpressionTemplates/ScalarDivideColumn.txt
index 95e4ce1..3cb7aaa 100644
--- ql/src/gen/vectorization/ExpressionTemplates/ScalarDivideColumn.txt
+++ ql/src/gen/vectorization/ExpressionTemplates/ScalarDivideColumn.txt
@@ -20,6 +20,7 @@
 package org.apache.hadoop.hive.ql.exec.vector.expressions.gen;

 import java.util.Arrays;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.OverflowUtils;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
 import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor;
@@ -94,9 +95,15 @@ public class <ClassName> extends VectorExpression {
     if (inputColVector.isRepeating) {
       if (inputColVector.noNulls || !inputIsNull[0]) {
         outputIsNull[0] = false;
-        denom = vector[0];
-        outputVector[0] = value <OperatorSymbol> denom;
+        final <OperandType2> denom = vector[0];
         hasDivBy0 = hasDivBy0 || (denom == 0);
+#IF MANUAL_DIVIDE_BY_ZERO_CHECK
+        if (denom != 0) {
+          outputVector[0] = value <OperatorSymbol> denom;
+        }
+#ELSE
+        outputVector[0] = value <OperatorSymbol> denom;
+#ENDIF MANUAL_DIVIDE_BY_ZERO_CHECK
       } else {
         outputIsNull[0] = true;
         outputColVector.noNulls = false;
@@ -112,15 +119,27 @@ public class <ClassName> extends VectorExpression {
           final int i = sel[j];
           outputIsNull[i] = false;
           denom = vector[i];
-          outputVector[i] = value <OperatorSymbol> denom;
           hasDivBy0 = hasDivBy0 || (denom == 0);
+#IF MANUAL_DIVIDE_BY_ZERO_CHECK
+          if (denom != 0) {
+            outputVector[i] = value <OperatorSymbol> denom;
+          }
+#ELSE
+          outputVector[i] = value <OperatorSymbol> denom;
+#ENDIF MANUAL_DIVIDE_BY_ZERO_CHECK
         }
       } else {
         for(int j = 0; j != n; j++) {
           final int i = sel[j];
-          denom = vector[i];
-          outputVector[i] = value <OperatorSymbol> denom;
+          final <OperandType2> denom = vector[i];
           hasDivBy0 = hasDivBy0 || (denom == 0);
+#IF MANUAL_DIVIDE_BY_ZERO_CHECK
+          if (denom != 0) {
+            outputVector[i] = value <OperatorSymbol> denom;
+          }
+#ELSE
+          outputVector[i] = value <OperatorSymbol> denom;
+#ENDIF MANUAL_DIVIDE_BY_ZERO_CHECK
         }
       }
     } else {
@@ -132,9 +151,15 @@ public class <ClassName> extends VectorExpression {
         outputColVector.noNulls = true;
       }
       for(int i = 0; i != n; i++) {
-        denom = vector[i];
-        outputVector[i] = value <OperatorSymbol> denom;
+        final <OperandType2> denom = vector[i];
         hasDivBy0 = hasDivBy0 || (denom == 0);
+#IF MANUAL_DIVIDE_BY_ZERO_CHECK
+        if (denom != 0) {
+          outputVector[i] = value <OperatorSymbol> denom;
+        }
+#ELSE
+        outputVector[i] = value <OperatorSymbol> denom;
+#ENDIF MANUAL_DIVIDE_BY_ZERO_CHECK
       }
     }
   } else /* there are NULLs in the inputColVector */ {
@@ -146,20 +171,38 @@ public class <ClassName> extends VectorExpression {
       for(int j = 0; j != n; j++) {
         int i = sel[j];
         outputIsNull[i] = inputIsNull[i];
-        denom = vector[i];
-        outputVector[i] = value <OperatorSymbol> denom;
+        final <OperandType2> denom = vector[i];
         hasDivBy0 = hasDivBy0 || (denom == 0);
+#IF MANUAL_DIVIDE_BY_ZERO_CHECK
+        if (denom != 0) {
+          outputVector[i] = value <OperatorSymbol> denom;
+        }
+#ELSE
+        outputVector[i] = value <OperatorSymbol> denom;
+#ENDIF MANUAL_DIVIDE_BY_ZERO_CHECK
       }
     } else {
       System.arraycopy(inputIsNull, 0, outputIsNull, 0, n);
       for(int i = 0; i != n; i++) {
-        denom = vector[i];
-        outputVector[i] = value <OperatorSymbol> denom;
+        final <OperandType2> denom = vector[i];
         hasDivBy0 = hasDivBy0 || (denom == 0);
+#IF MANUAL_DIVIDE_BY_ZERO_CHECK
+        if (denom != 0) {
+          outputVector[i] = value <OperatorSymbol> denom;
+        }
+#ELSE
+        outputVector[i] = value <OperatorSymbol> denom;
+#ENDIF MANUAL_DIVIDE_BY_ZERO_CHECK
       }
     }
   }
+#IF CHECKED
+    //when operating in checked mode make sure we handle overflows similar to non-vectorized expression
+    OverflowUtils.accountForOverflow(getOutputTypeInfo(), outputColVector,
+        batch.selectedInUse, sel, n);
+#ELSE
+#ENDIF CHECKED
   if (!hasDivBy0) {
     NullUtil.setNullOutputEntriesColScalar(outputColVector, batch.selectedInUse, sel, n);
   } else {
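Both templates hand off to NullUtil at the end (context lines above). Per the convention documented in the deleted class below — data slots for NULL or divide-by-zero entries are set to 1 for long (NaN for double), so that later expressions such as col2 % (col1 - 1) cannot hit a second divide-by-zero — the cleanup amounts to roughly the following. This is an illustrative sketch, not the actual NullUtil code:

    // Sketch of setNullAndDivBy0DataEntriesLong-style cleanup (illustrative only).
    for (int j = 0; j != n; j++) {
      int i = batch.selectedInUse ? sel[j] : j;
      if (denomColVector.vector[i] == 0) {   // denomColVector: the divisor column
        outputColVector.noNulls = false;
        outputColVector.isNull[i] = true;
        outputColVector.vector[i] = 1L;      // NaN in the double variant
      }
    }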
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/LongColModuloLongColumn.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/LongColModuloLongColumn.java
deleted file mode 100644
index 60faebb..0000000
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/LongColModuloLongColumn.java
+++ /dev/null
@@ -1,179 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.hadoop.hive.ql.exec.vector.expressions;
-
-import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
-import org.apache.hadoop.hive.ql.exec.vector.expressions.NullUtil;
-import org.apache.hadoop.hive.ql.exec.vector.*;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-
-/**
- * This operation is handled as a special case because Hive
- * long%long division returns needs special handling to avoid
- * for divide by zero exception
- */
-public class LongColModuloLongColumn extends VectorExpression {
-
-  private static final long serialVersionUID = 1L;
-
-  private final int colNum1;
-  private final int colNum2;
-
-  public LongColModuloLongColumn(int colNum1, int colNum2, int outputColumnNum) {
-    super(outputColumnNum);
-    this.colNum1 = colNum1;
-    this.colNum2 = colNum2;
-  }
-
-  public LongColModuloLongColumn() {
-    super();
-
-    // Dummy final assignments.
-    colNum1 = -1;
-    colNum2 = -1;
-  }
-
-  @Override
-  public void evaluate(VectorizedRowBatch batch) throws HiveException {
-
-    if (childExpressions != null) {
-      super.evaluateChildren(batch);
-    }
-
-    LongColumnVector inputColVector1 = (LongColumnVector) batch.cols[colNum1];
-    LongColumnVector inputColVector2 = (LongColumnVector) batch.cols[colNum2];
-    LongColumnVector outputColVector = (LongColumnVector) batch.cols[outputColumnNum];
-    int[] sel = batch.selected;
-    int n = batch.size;
-    long[] vector1 = inputColVector1.vector;
-    long[] vector2 = inputColVector2.vector;
-    long[] outputVector = outputColVector.vector;
-
-    // return immediately if batch is empty
-    if (n == 0) {
-      return;
-    }
-
-    /*
-     * Propagate null values for a two-input operator and set isRepeating and noNulls appropriately.
-     */
-    NullUtil.propagateNullsColCol(
-        inputColVector1, inputColVector2, outputColVector, sel, n, batch.selectedInUse);
-
-    /* Disregard nulls for processing. In other words,
-     * the arithmetic operation is performed even if one or
-     * more inputs are null. This is to improve speed by avoiding
-     * conditional checks in the inner loop.
-     */
-    boolean hasDivBy0 = false;
-    if (inputColVector1.isRepeating && inputColVector2.isRepeating) {
-      long denom = vector2[0];
-      hasDivBy0 = hasDivBy0 || (denom == 0);
-      if (denom != 0) {
-        outputVector[0] = vector1[0] % denom;
-      }
-    } else if (inputColVector1.isRepeating) {
-      final long vector1Value = vector1[0];
-      if (batch.selectedInUse) {
-        for(int j = 0; j != n; j++) {
-          int i = sel[j];
-          long denom = vector2[i];
-          hasDivBy0 = hasDivBy0 || (denom == 0);
-          if (denom != 0) {
-            outputVector[i] = vector1Value % denom;
-          }
-        }
-      } else {
-        for(int i = 0; i != n; i++) {
-          hasDivBy0 = hasDivBy0 || (vector2[i] == 0);
-          if (vector2[i] != 0) {
-            outputVector[i] = vector1Value % vector2[i];
-          }
-        }
-      }
-    } else if (inputColVector2.isRepeating) {
-      final long vector2Value = vector2[0];
-      if (vector2Value == 0) {
-        // Denominator is zero, convert the batch to nulls
-        outputColVector.noNulls = false;
-        outputColVector.isRepeating = true;
-        outputColVector.isNull[0] = true;
-      } else if (batch.selectedInUse) {
-        for(int j = 0; j != n; j++) {
-          int i = sel[j];
-          outputVector[i] = vector1[i] % vector2Value;
-        }
-      } else {
-        for(int i = 0; i != n; i++) {
-          outputVector[i] = vector1[i] % vector2Value;
-        }
-      }
-    } else {
-      if (batch.selectedInUse) {
-        for(int j = 0; j != n; j++) {
-          int i = sel[j];
-          long denom = vector2[i];
-          hasDivBy0 = hasDivBy0 || (denom == 0);
-          if (denom != 0) {
-            outputVector[i] = vector1[i] % denom;
-          }
-        }
-      } else {
-        for(int i = 0; i != n; i++) {
-          hasDivBy0 = hasDivBy0 || (vector2[i] == 0);
-          if (vector2[i] != 0) {
-            outputVector[i] = vector1[i] % vector2[i];
-          }
-        }
-      }
-    }
-
-    /* For the case when the output can have null values, follow
-     * the convention that the data values must be 1 for long and
-     * NaN for double. This is to prevent possible later zero-divide errors
-     * in complex arithmetic expressions like col2 % (col1 - 1)
-     * in the case when some col1 entries are null.
-     */
-    if (!hasDivBy0) {
-      NullUtil.setNullDataEntriesLong(outputColVector, batch.selectedInUse, sel, n);
-    } else {
-      NullUtil.setNullAndDivBy0DataEntriesLong(
-          outputColVector, batch.selectedInUse, sel, n, inputColVector2);
-    }
-  }
-
-  @Override
-  public String vectorExpressionParameters() {
-    return getColumnParamString(0, colNum1) + ", " + getColumnParamString(1, colNum2);
-  }
-
-  @Override
-  public VectorExpressionDescriptor.Descriptor getDescriptor() {
-    return (new VectorExpressionDescriptor.Builder())
-        .setMode(
-            VectorExpressionDescriptor.Mode.PROJECTION)
-        .setNumArguments(2)
-        .setArgumentTypes(
-            VectorExpressionDescriptor.ArgumentType.getType("long"),
-            VectorExpressionDescriptor.ArgumentType.getType("long"))
-        .setInputExpressionTypes(
-            VectorExpressionDescriptor.InputExpressionType.COLUMN,
-            VectorExpressionDescriptor.InputExpressionType.COLUMN).build();
-  }
-}
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/LongColModuloLongColumnChecked.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/LongColModuloLongColumnChecked.java
deleted file mode 100644
index 24a860a..0000000
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/LongColModuloLongColumnChecked.java
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.hadoop.hive.ql.exec.vector.expressions;
-
-import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
-import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-
-/**
- * This vector expression implements a Checked variant of LongColModuloLongColumn
- * If the outputTypeInfo is not long it casts the result column vector values to
- * the set outputType so as to have similar result when compared to non-vectorized UDF
- * execution.
- */
-public class LongColModuloLongColumnChecked extends LongColModuloLongColumn {
-  public LongColModuloLongColumnChecked(int colNum1, int colNum2, int outputColumnNum) {
-    super(colNum1, colNum2, outputColumnNum);
-  }
-
-  public LongColModuloLongColumnChecked() {
-    super();
-  }
-
-  @Override
-  public void evaluate(VectorizedRowBatch batch) throws HiveException {
-    super.evaluate(batch);
-    //checked for overflow based on the outputTypeInfo
-    OverflowUtils
-        .accountForOverflowLong(outputTypeInfo, (LongColumnVector) batch.cols[outputColumnNum], batch.selectedInUse,
-            batch.selected, batch.size);
-  }
-
-  @Override
-  public boolean supportsCheckedExecution() {
-    return true;
-  }
-}
diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPMod.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPMod.java
index 044fb06..bef32b4 100644
--- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPMod.java
+++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPMod.java
@@ -21,8 +21,6 @@
 import org.apache.hadoop.hive.common.type.HiveDecimal;
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedExpressions;
-import org.apache.hadoop.hive.ql.exec.vector.expressions.LongColModuloLongColumn;
-import org.apache.hadoop.hive.ql.exec.vector.expressions.LongColModuloLongColumnChecked;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.*;
 import org.apache.hadoop.hive.serde2.io.ByteWritable;
 import org.apache.hadoop.hive.serde2.io.DoubleWritable;
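The two deletions above are safe because the same kernels are now emitted into the ...expressions.gen package by the GenVectorCode table rows added near the end of this patch. The naming rule, quoted from generateColumnArithmeticColumn below, shows how a new table row maps back onto the old class name — which is also why the import fixups in the following test diffs switch from expressions to expressions.gen:

    // From the generator (see the GenVectorCode.java hunks below):
    //   {"ColumnDivideColumn", "Modulo", "long", "long", "%", "MANUAL_DIVIDE_BY_ZERO_CHECK,CHECKED"}
    //   -> "Long" + "Col" + "Modulo" + "Long" + "Column" + "Checked"
    //   -> org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColModuloLongColumnChecked
    String className = getCamelCaseType(operandType1) + "Col" + operatorName +
        getCamelCaseType(operandType2) + "Column" + (checked ? "Checked" : "");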
diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java
index 791ac82..f51b8bb 100644
--- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java
+++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java
@@ -54,7 +54,7 @@
 import org.apache.hadoop.hive.ql.exec.vector.expressions.IfExprVarCharScalarStringGroupColumn;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.IsNotNull;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.IsNull;
-import org.apache.hadoop.hive.ql.exec.vector.expressions.LongColModuloLongColumn;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColModuloLongColumn;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.LongColumnInList;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.LongColEqualLongScalar;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.LongColGreaterLongScalar;
diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java
index 0e4dcfd..ae91b73 100644
--- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java
+++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/VectorRandomRowSource.java
@@ -243,6 +243,10 @@
   public StructObjectInspector rowStructObjectInspector() {
     return rowStructObjectInspector;
   }

+  public List<ObjectInspector> objectInspectorList() {
+    return objectInspectorList;
+  }
+
   public StructObjectInspector partialRowStructObjectInspector(int partialFieldCount) {
     ArrayList<ObjectInspector> partialObjectInspectorList = new ArrayList<ObjectInspector>(partialFieldCount);
@@ -445,11 +449,11 @@ private String getDecoratedTypeName(String typeName,
     return getDecoratedTypeName(r, typeName, supportedTypes, allowedTypeNameSet, depth, maxDepth);
   }

-  private ObjectInspector getObjectInspector(TypeInfo typeInfo) {
+  public static ObjectInspector getObjectInspector(TypeInfo typeInfo) {
     return getObjectInspector(typeInfo, DataTypePhysicalVariation.NONE);
   }

-  private ObjectInspector getObjectInspector(TypeInfo typeInfo,
+  public static ObjectInspector getObjectInspector(TypeInfo typeInfo,
       DataTypePhysicalVariation dataTypePhysicalVariation) {

     final ObjectInspector objectInspector;
diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorArithmetic.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorArithmetic.java
new file mode 100644
index 0000000..5db6a99
--- /dev/null
+++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorArithmetic.java
@@ -0,0 +1,611 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.vector.expressions;
+
+import java.lang.reflect.Constructor;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+
+import org.apache.hadoop.hive.common.type.DataTypePhysicalVariation;
+import org.apache.hadoop.hive.common.type.HiveChar;
+import org.apache.hadoop.hive.common.type.HiveDecimal;
+import org.apache.hadoop.hive.common.type.HiveVarchar;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator;
+import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluatorFactory;
+import org.apache.hadoop.hive.ql.exec.FunctionInfo;
+import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
+import org.apache.hadoop.hive.ql.exec.vector.VectorExtractRow;
+import org.apache.hadoop.hive.ql.exec.vector.VectorRandomBatchSource;
+import org.apache.hadoop.hive.ql.exec.vector.VectorRandomRowSource;
+import org.apache.hadoop.hive.ql.exec.vector.VectorizationContext;
+import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
+import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatchCtx;
+import org.apache.hadoop.hive.ql.exec.vector.VectorRandomRowSource.GenerationSpec;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
+import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc;
+import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
+import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
+import org.apache.hadoop.hive.ql.session.SessionState;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFDateAdd;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFDateDiff;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFDateSub;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPDivide;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPMinus;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPMod;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPMultiply;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPPlus;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject;
+import org.apache.hadoop.hive.serde2.io.HiveCharWritable;
+import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
+import org.apache.hadoop.hive.serde2.io.HiveVarcharWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
+import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
+import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo;
+import org.apache.hadoop.hive.serde2.io.ShortWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+
+import junit.framework.Assert;
+
+import org.junit.Ignore;
+import org.junit.Test;
+
+public class TestVectorArithmetic {
+
+  public TestVectorArithmetic() {
+    // Arithmetic operations rely on getting conf from SessionState, need to initialize here.
+    SessionState ss = new SessionState(new HiveConf());
+    ss.getConf().setVar(HiveConf.ConfVars.HIVE_COMPAT, "latest");
+    SessionState.setCurrentSessionState(ss);
+  }
+
+  @Test
+  public void testIntegers() throws Exception {
+    Random random = new Random(7743);
+
+    doIntegerTests(random);
+  }
+
+  @Test
+  public void testIntegerFloating() throws Exception {
+    Random random = new Random(7743);
+
+    doIntegerFloatingTests(random);
+  }
+
+  @Test
+  public void testFloating() throws Exception {
+    Random random = new Random(7743);
+
+    doFloatingTests(random);
+  }
+
+  @Test
+  public void testDecimals() throws Exception {
+    Random random = new Random(7743);
+
+    doDecimalTests(random);
+  }
+
+  public enum ArithmeticTestMode {
+    ROW_MODE,
+    ADAPTOR,
+    VECTOR_EXPRESSION;
+
+    static final int count = values().length;
+  }
+
+  public enum ColumnScalarMode {
+    COLUMN_COLUMN,
+    COLUMN_SCALAR,
+    SCALAR_COLUMN;
+
+    static final int count = values().length;
+  }
+
+  private static TypeInfo[] integerTypeInfos = new TypeInfo[] {
+    TypeInfoFactory.byteTypeInfo,
+    TypeInfoFactory.shortTypeInfo,
+    TypeInfoFactory.intTypeInfo,
+    TypeInfoFactory.longTypeInfo
+  };
+
+  // We have test failures with FLOAT.  Ignoring this issue for now.
+  private static TypeInfo[] floatingTypeInfos = new TypeInfo[] {
+    // TypeInfoFactory.floatTypeInfo,
+    TypeInfoFactory.doubleTypeInfo
+  };
+
+  private void doIntegerTests(Random random)
+      throws Exception {
+    for (TypeInfo typeInfo : integerTypeInfos) {
+      for (ColumnScalarMode columnScalarMode : ColumnScalarMode.values()) {
+        doTestsWithDiffColumnScalar(
+            random, typeInfo, typeInfo, columnScalarMode);
+      }
+    }
+  }
+
+  private void doIntegerFloatingTests(Random random)
+      throws Exception {
+    for (TypeInfo typeInfo1 : integerTypeInfos) {
+      for (TypeInfo typeInfo2 : floatingTypeInfos) {
+        for (ColumnScalarMode columnScalarMode : ColumnScalarMode.values()) {
+          doTestsWithDiffColumnScalar(
+              random, typeInfo1, typeInfo2, columnScalarMode);
+        }
+      }
+    }
+    for (TypeInfo typeInfo1 : floatingTypeInfos) {
+      for (TypeInfo typeInfo2 : integerTypeInfos) {
+        for (ColumnScalarMode columnScalarMode : ColumnScalarMode.values()) {
+          doTestsWithDiffColumnScalar(
+              random, typeInfo1, typeInfo2, columnScalarMode);
+        }
+      }
+    }
+  }
+
+  private void doFloatingTests(Random random)
+      throws Exception {
+    for (TypeInfo typeInfo1 : floatingTypeInfos) {
+      for (TypeInfo typeInfo2 : floatingTypeInfos) {
+        for (ColumnScalarMode columnScalarMode : ColumnScalarMode.values()) {
+          doTestsWithDiffColumnScalar(
+              random, typeInfo1, typeInfo2, columnScalarMode);
+        }
+      }
+    }
+  }
+
+  private static TypeInfo[] decimalTypeInfos = new TypeInfo[] {
+    new DecimalTypeInfo(38, 18),
+    new DecimalTypeInfo(25, 2),
+    new DecimalTypeInfo(19, 4),
+    new DecimalTypeInfo(18, 10),
+    new DecimalTypeInfo(17, 3),
+    new DecimalTypeInfo(12, 2),
+    new DecimalTypeInfo(7, 1)
+  };
+
+  private void doDecimalTests(Random random)
+      throws Exception {
+    for (TypeInfo typeInfo : decimalTypeInfos) {
+      for (ColumnScalarMode columnScalarMode : ColumnScalarMode.values()) {
+        doTestsWithDiffColumnScalar(
+            random, typeInfo, typeInfo, columnScalarMode);
+      }
+    }
+  }
+
+  private TypeInfo getOutputTypeInfo(GenericUDF genericUdfClone,
+      List<ObjectInspector> objectInspectorList)
+    throws HiveException {
+    ObjectInspector[] array =
+        objectInspectorList.toArray(new ObjectInspector[objectInspectorList.size()]);
+    ObjectInspector outputObjectInspector = genericUdfClone.initialize(array);
+    return TypeInfoUtils.getTypeInfoFromObjectInspector(outputObjectInspector);
+  }
+
+  public enum Arithmetic {
+    ADD,
+    SUBTRACT,
+    MULTIPLY,
+    DIVIDE,
+    MODULUS;
+  }
+
+  private TypeInfo getDecimalScalarTypeInfo(Object scalarObject) {
+    HiveDecimal dec = (HiveDecimal) scalarObject;
+    int precision = dec.precision();
+    int scale = dec.scale();
+    return new DecimalTypeInfo(precision, scale);
+  }
+
+  private void doTestsWithDiffColumnScalar(Random random, TypeInfo typeInfo1, TypeInfo typeInfo2,
+      ColumnScalarMode columnScalarMode)
+          throws Exception {
+    for (Arithmetic arithmetic : Arithmetic.values()) {
+      doTestsWithDiffColumnScalar(random, typeInfo1, typeInfo2, columnScalarMode, arithmetic);
+    }
+  }
+
+  private void doTestsWithDiffColumnScalar(Random random, TypeInfo typeInfo1, TypeInfo typeInfo2,
+      ColumnScalarMode columnScalarMode, Arithmetic arithmetic)
+          throws Exception {
+
+    String typeName1 = typeInfo1.getTypeName();
+    PrimitiveCategory primitiveCategory1 =
+        ((PrimitiveTypeInfo) typeInfo1).getPrimitiveCategory();
+
+    String typeName2 = typeInfo2.getTypeName();
+    PrimitiveCategory primitiveCategory2 =
+        ((PrimitiveTypeInfo) typeInfo2).getPrimitiveCategory();
+
+    List<GenerationSpec> generationSpecList = new ArrayList<GenerationSpec>();
+    List<DataTypePhysicalVariation> explicitDataTypePhysicalVariationList =
+        new ArrayList<DataTypePhysicalVariation>();
+
+    List<String> columns = new ArrayList<String>();
+    int columnNum = 0;
+
+    ExprNodeDesc col1Expr;
+    Object scalar1Object = null;
+    if (columnScalarMode == ColumnScalarMode.COLUMN_COLUMN ||
+        columnScalarMode == ColumnScalarMode.COLUMN_SCALAR) {
+      generationSpecList.add(
+          GenerationSpec.createSameType(typeInfo1));
+      explicitDataTypePhysicalVariationList.add(DataTypePhysicalVariation.NONE);
+
+      String columnName = "col" + (columnNum++);
+      col1Expr = new ExprNodeColumnDesc(typeInfo1, columnName, "table", false);
+      columns.add(columnName);
+    } else {
+      scalar1Object =
+          VectorRandomRowSource.randomPrimitiveObject(
+              random, (PrimitiveTypeInfo) typeInfo1);
+
+      // Adjust the decimal type to the scalar's type...
+      if (typeInfo1 instanceof DecimalTypeInfo) {
+        typeInfo1 = getDecimalScalarTypeInfo(scalar1Object);
+      }
+
+      col1Expr = new ExprNodeConstantDesc(typeInfo1, scalar1Object);
+    }
+    ExprNodeDesc col2Expr;
+    Object scalar2Object = null;
+    if (columnScalarMode == ColumnScalarMode.COLUMN_COLUMN ||
+        columnScalarMode == ColumnScalarMode.SCALAR_COLUMN) {
+      generationSpecList.add(
+          GenerationSpec.createSameType(typeInfo2));
+
+      explicitDataTypePhysicalVariationList.add(DataTypePhysicalVariation.NONE);
+
+      String columnName = "col" + (columnNum++);
+      col2Expr = new ExprNodeColumnDesc(typeInfo2, columnName, "table", false);
+      columns.add(columnName);
+    } else {
+      scalar2Object =
+          VectorRandomRowSource.randomPrimitiveObject(
+              random, (PrimitiveTypeInfo) typeInfo2);
+
+      // Adjust the decimal type to the scalar's type...
+      if (typeInfo2 instanceof DecimalTypeInfo) {
+        typeInfo2 = getDecimalScalarTypeInfo(scalar2Object);
+      }
+
+      col2Expr = new ExprNodeConstantDesc(typeInfo2, scalar2Object);
+    }
+
+    List<ObjectInspector> objectInspectorList = new ArrayList<ObjectInspector>();
+    objectInspectorList.add(VectorRandomRowSource.getObjectInspector(typeInfo1));
+    objectInspectorList.add(VectorRandomRowSource.getObjectInspector(typeInfo2));
+
+    List<ExprNodeDesc> children = new ArrayList<ExprNodeDesc>();
+    children.add(col1Expr);
+    children.add(col2Expr);
+
+    //----------------------------------------------------------------------------------------------
+
+    String[] columnNames = columns.toArray(new String[0]);
+
+    VectorRandomRowSource rowSource = new VectorRandomRowSource();
+
+    rowSource.initGenerationSpecSchema(
+        random, generationSpecList, /* maxComplexDepth */ 0, /* allowNull */ true,
+        explicitDataTypePhysicalVariationList);
+
+    Object[][] randomRows = rowSource.randomRows(100000);
+
+    VectorRandomBatchSource batchSource =
+        VectorRandomBatchSource.createInterestingBatches(
+            random,
+            rowSource,
+            randomRows,
+            null);
+
+    GenericUDF genericUdf;
+    switch (arithmetic) {
+    case ADD:
+      genericUdf = new GenericUDFOPPlus();
+      break;
+    case SUBTRACT:
+      genericUdf = new GenericUDFOPMinus();
+      break;
+    case MULTIPLY:
+      genericUdf = new GenericUDFOPMultiply();
+      break;
+    case DIVIDE:
+      genericUdf = new GenericUDFOPDivide();
+      break;
+    case MODULUS:
+      genericUdf = new GenericUDFOPMod();
+      break;
+    default:
+      throw new RuntimeException("Unexpected arithmetic " + arithmetic);
+    }
+
+    ObjectInspector[] objectInspectors =
+        objectInspectorList.toArray(new ObjectInspector[objectInspectorList.size()]);
+    ObjectInspector outputObjectInspector = null;
+    try {
+      outputObjectInspector = genericUdf.initialize(objectInspectors);
+    } catch (Exception e) {
+      Assert.fail(e.toString());
+    }
+
+    TypeInfo outputTypeInfo = TypeInfoUtils.getTypeInfoFromObjectInspector(outputObjectInspector);
+
+    ExprNodeGenericFuncDesc exprDesc =
+        new ExprNodeGenericFuncDesc(outputTypeInfo, genericUdf, children);
+
+    final int rowCount = randomRows.length;
+    Object[][] resultObjectsArray = new Object[ArithmeticTestMode.count][];
+    for (int i = 0; i < ArithmeticTestMode.count; i++) {
+
+      Object[] resultObjects = new Object[rowCount];
+      resultObjectsArray[i] = resultObjects;
+
+      ArithmeticTestMode arithmeticTestMode = ArithmeticTestMode.values()[i];
+      switch (arithmeticTestMode) {
+      case ROW_MODE:
+        doRowArithmeticTest(
+            typeInfo1,
+            typeInfo2,
+            columns,
+            children,
+            exprDesc,
+            arithmetic,
+            randomRows,
+            columnScalarMode,
+            rowSource.rowStructObjectInspector(),
+            outputTypeInfo,
+            resultObjects);
+        break;
+      case ADAPTOR:
+      case VECTOR_EXPRESSION:
+        doVectorArithmeticTest(
+            typeInfo1,
+            typeInfo2,
+            columns,
+            columnNames,
+            rowSource.typeInfos(),
+            rowSource.dataTypePhysicalVariations(),
+            children,
+            exprDesc,
+            arithmetic,
+            arithmeticTestMode,
+            columnScalarMode,
+            batchSource,
+            exprDesc.getWritableObjectInspector(),
+            outputTypeInfo,
+            resultObjects);
+        break;
+      default:
+        throw new RuntimeException("Unexpected arithmetic test mode " + arithmeticTestMode);
+      }
+    }
+
+    for (int i = 0; i < rowCount; i++) {
+      // Row-mode is the expected value.
+      Object expectedResult = resultObjectsArray[0][i];
+
+      for (int v = 1; v < ArithmeticTestMode.count; v++) {
+        Object vectorResult = resultObjectsArray[v][i];
+        if (expectedResult == null || vectorResult == null) {
+          if (expectedResult != null || vectorResult != null) {
+            Assert.fail(
+                "Row " + i +
+                " typeName " + typeName1 +
+                " outputTypeName " + outputTypeInfo.getTypeName() +
+                " " + arithmetic +
+                " " + ArithmeticTestMode.values()[v] +
+                " " + columnScalarMode +
+                " result is NULL " + (vectorResult == null) +
+                " does not match row-mode expected result is NULL " + (expectedResult == null) +
+                (columnScalarMode == ColumnScalarMode.SCALAR_COLUMN ?
+                    " scalar1 " + scalar1Object.toString() : "") +
+                " row values " + Arrays.toString(randomRows[i]) +
+                (columnScalarMode == ColumnScalarMode.COLUMN_SCALAR ?
+                    " scalar2 " + scalar2Object.toString() : ""));
+          }
+        } else {
+
+          if (!expectedResult.equals(vectorResult)) {
+            Assert.fail(
+                "Row " + i +
+                " typeName " + typeName1 +
+                " outputTypeName " + outputTypeInfo.getTypeName() +
+                " " + arithmetic +
+                " " + ArithmeticTestMode.values()[v] +
+                " " + columnScalarMode +
+                " result " + vectorResult.toString() +
+                " (" + vectorResult.getClass().getSimpleName() + ")" +
+                " does not match row-mode expected result " + expectedResult.toString() +
+                " (" + expectedResult.getClass().getSimpleName() + ")" +
+                (columnScalarMode == ColumnScalarMode.SCALAR_COLUMN ?
+                    " scalar1 " + scalar1Object.toString() : "") +
+                " row values " + Arrays.toString(randomRows[i]) +
+                (columnScalarMode == ColumnScalarMode.COLUMN_SCALAR ?
+                    " scalar2 " + scalar2Object.toString() : ""));
+          }
+        }
+      }
+    }
+  }
+
+  private void doRowArithmeticTest(TypeInfo typeInfo1,
+      TypeInfo typeInfo2,
+      List<String> columns, List<ExprNodeDesc> children,
+      ExprNodeGenericFuncDesc exprDesc,
+      Arithmetic arithmetic,
+      Object[][] randomRows, ColumnScalarMode columnScalarMode,
+      ObjectInspector rowInspector,
+      TypeInfo outputTypeInfo, Object[] resultObjects) throws Exception {
+
+    System.out.println(
+        "*DEBUG* typeInfo " + typeInfo1.toString() +
+        " typeInfo2 " + typeInfo2 +
+        " arithmeticTestMode ROW_MODE" +
+        " columnScalarMode " + columnScalarMode +
+        " exprDesc " + exprDesc.toString());
+
+    HiveConf hiveConf = new HiveConf();
+    ExprNodeEvaluator evaluator =
+        ExprNodeEvaluatorFactory.get(exprDesc, hiveConf);
+    evaluator.initialize(rowInspector);
+
+    ObjectInspector objectInspector =
+        TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo(
+            outputTypeInfo);
+
+    final int rowCount = randomRows.length;
+    for (int i = 0; i < rowCount; i++) {
+      Object[] row = randomRows[i];
+      Object result = evaluator.evaluate(row);
+      Object copyResult =
+          ObjectInspectorUtils.copyToStandardObject(
+              result, objectInspector, ObjectInspectorCopyOption.WRITABLE);
+      resultObjects[i] = copyResult;
+    }
+  }
+
+  private void extractResultObjects(VectorizedRowBatch batch, int rowIndex,
+      VectorExtractRow resultVectorExtractRow, Object[] scratchRow,
+      ObjectInspector objectInspector, Object[] resultObjects) {
+
+    boolean selectedInUse = batch.selectedInUse;
+    int[] selected = batch.selected;
+    for (int logicalIndex = 0; logicalIndex < batch.size; logicalIndex++) {
+      final int batchIndex = (selectedInUse ? selected[logicalIndex] : logicalIndex);
+      resultVectorExtractRow.extractRow(batch, batchIndex, scratchRow);
+
+      Object copyResult =
+          ObjectInspectorUtils.copyToStandardObject(
+              scratchRow[0], objectInspector, ObjectInspectorCopyOption.WRITABLE);
+      resultObjects[rowIndex++] = copyResult;
+    }
+  }
+
+  private void doVectorArithmeticTest(TypeInfo typeInfo1,
+      TypeInfo typeInfo2,
+      List<String> columns,
+      String[] columnNames,
+      TypeInfo[] typeInfos, DataTypePhysicalVariation[] dataTypePhysicalVariations,
+      List<ExprNodeDesc> children,
+      ExprNodeGenericFuncDesc exprDesc,
+      Arithmetic arithmetic,
+      ArithmeticTestMode arithmeticTestMode, ColumnScalarMode columnScalarMode,
+      VectorRandomBatchSource batchSource,
+      ObjectInspector objectInspector,
+      TypeInfo outputTypeInfo, Object[] resultObjects)
+          throws Exception {
+
+    HiveConf hiveConf = new HiveConf();
+    if (arithmeticTestMode == ArithmeticTestMode.ADAPTOR) {
+      hiveConf.setBoolVar(HiveConf.ConfVars.HIVE_TEST_VECTOR_ADAPTOR_OVERRIDE, true);
+    }
+
+    VectorizationContext vectorizationContext =
+        new VectorizationContext(
+            "name",
+            columns,
+            Arrays.asList(typeInfos),
+            Arrays.asList(dataTypePhysicalVariations),
+            hiveConf);
+    VectorExpression vectorExpression = vectorizationContext.getVectorExpression(exprDesc);
+    vectorExpression.transientInit();
+
+    String[] outputScratchTypeNames = vectorizationContext.getScratchColumnTypeNames();
+
+    VectorizedRowBatchCtx batchContext =
+        new VectorizedRowBatchCtx(
+            columnNames,
+            typeInfos,
+            dataTypePhysicalVariations,
+            /* dataColumnNums */ null,
+            /* partitionColumnCount */ 0,
+            /* virtualColumnCount */ 0,
+            /* neededVirtualColumns */ null,
+            outputScratchTypeNames,
+            null);
+
+    VectorizedRowBatch batch = batchContext.createVectorizedRowBatch();
+
+    VectorExtractRow resultVectorExtractRow = new VectorExtractRow();
+    resultVectorExtractRow.init(
+        new TypeInfo[] { outputTypeInfo }, new int[] { vectorExpression.getOutputColumnNum() });
+    Object[] scratchRow = new Object[1];
+
+    System.out.println(
+        "*DEBUG* typeInfo1 " + typeInfo1.toString() +
+        " typeInfo2 " + typeInfo2.toString() +
+        " arithmeticTestMode " + arithmeticTestMode +
+        " columnScalarMode " + columnScalarMode +
+        " vectorExpression " + vectorExpression.toString());
+
+    batchSource.resetBatchIteration();
+    int rowIndex = 0;
+    while (true) {
+      if (!batchSource.fillNextBatch(batch)) {
+        break;
+      }
+      vectorExpression.evaluate(batch);
+      extractResultObjects(batch, rowIndex, resultVectorExtractRow, scratchRow,
+          objectInspector, resultObjects);
+      rowIndex += batch.size;
+    }
+  }
+}
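The shape of the new test, compressed: ROW_MODE evaluation through ExprNodeEvaluator is the oracle, and both vectorized paths (the VectorUDFAdaptor override and the native vector expression) must match it value-for-value and NULL-for-NULL. A condensed sketch of the comparison; the real loop above additionally builds detailed failure messages:

    // resultObjectsArray[0] <- ROW_MODE oracle; [1..] <- ADAPTOR, VECTOR_EXPRESSION.
    for (int row = 0; row < rowCount; row++) {
      Object expected = resultObjectsArray[0][row];
      for (int mode = 1; mode < ArithmeticTestMode.count; mode++) {
        Object actual = resultObjectsArray[mode][row];
        Assert.assertTrue(expected == null ? actual == null : expected.equals(actual));
      }
    }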
diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorArithmeticExpressions.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorArithmeticExpressions.java
index f5491af..a716224 100644
--- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorArithmeticExpressions.java
+++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorArithmeticExpressions.java
@@ -54,6 +54,7 @@
 import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DecimalScalarSubtractDecimalColumn;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DecimalScalarMultiplyDecimalColumn;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColAddLongScalarChecked;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColModuloLongColumn;
 import org.apache.hadoop.hive.ql.exec.vector.util.VectorizedRowGroupGenUtil;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
diff --git vector-code-gen/src/org/apache/hadoop/hive/tools/GenVectorCode.java vector-code-gen/src/org/apache/hadoop/hive/tools/GenVectorCode.java
index e81a1ac..6d7ed3e 100644
--- vector-code-gen/src/org/apache/hadoop/hive/tools/GenVectorCode.java
+++ vector-code-gen/src/org/apache/hadoop/hive/tools/GenVectorCode.java
@@ -269,14 +269,16 @@
       {"ColumnDivideScalar", "Modulo", "double", "long", "%", "CHECKED"},
       {"ColumnDivideScalar", "Modulo", "double", "double", "%"},
       {"ColumnDivideScalar", "Modulo", "double", "double", "%", "CHECKED"},
-      {"ScalarDivideColumn", "Modulo", "long", "long", "%"},
-      {"ScalarDivideColumn", "Modulo", "long", "long", "%", "CHECKED"},
+      {"ScalarDivideColumn", "Modulo", "long", "long", "%", "MANUAL_DIVIDE_BY_ZERO_CHECK"},
+      {"ScalarDivideColumn", "Modulo", "long", "long", "%", "MANUAL_DIVIDE_BY_ZERO_CHECK,CHECKED"},
       {"ScalarDivideColumn", "Modulo", "long", "double", "%"},
       {"ScalarDivideColumn", "Modulo", "long", "double", "%", "CHECKED"},
       {"ScalarDivideColumn", "Modulo", "double", "long", "%"},
       {"ScalarDivideColumn", "Modulo", "double", "long", "%", "CHECKED"},
       {"ScalarDivideColumn", "Modulo", "double", "double", "%"},
       {"ScalarDivideColumn", "Modulo", "double", "double", "%", "CHECKED"},
+      {"ColumnDivideColumn", "Modulo", "long", "long", "%", "MANUAL_DIVIDE_BY_ZERO_CHECK"},
+      {"ColumnDivideColumn", "Modulo", "long", "long", "%", "MANUAL_DIVIDE_BY_ZERO_CHECK,CHECKED"},
       {"ColumnDivideColumn", "Modulo", "long", "double", "%"},
       {"ColumnDivideColumn", "Modulo", "long", "double", "%", "CHECKED"},
       {"ColumnDivideColumn", "Modulo", "double", "long", "%"},
@@ -2124,7 +2126,7 @@ private void generateColumnUnaryMinus(String[] tdesc) throws Exception {
     String inputColumnVectorType = this.getColumnVectorType(operandType);
     String outputColumnVectorType = inputColumnVectorType;
     String returnType = operandType;
-    boolean checked = (tdesc.length == 3 && "CHECKED".equals(tdesc[2]));
+    boolean checked = (tdesc.length == 3 && tdesc[2].contains("CHECKED"));
     String className = getCamelCaseType(operandType) + "ColUnaryMinus" + (checked ? "Checked" : "");
     File templateFile = new File(joinPath(this.expressionTemplateDirectory, tdesc[0] + ".txt"));
@@ -2342,7 +2344,7 @@ private void generateColumnArithmeticColumn(String [] tdesc) throws Exception {
     String operatorName = tdesc[1];
     String operandType1 = tdesc[2];
     String operandType2 = tdesc[3];
-    boolean checked = tdesc.length == 6 && "CHECKED".equals(tdesc[5]);
+    boolean checked = tdesc.length == 6 && tdesc[5].contains("CHECKED");
     String className = getCamelCaseType(operandType1) + "Col" + operatorName +
         getCamelCaseType(operandType2) + "Column" + (checked ? "Checked" : "");
"Checked" : ""); @@ -2735,6 +2737,7 @@ private void generateColumnArithmeticOperatorColumn(String[] tdesc, String retur templateString = templateString.replaceAll("", operandType2); templateString = templateString.replaceAll("", returnType); templateString = templateString.replaceAll("", getCamelCaseType(returnType)); + templateString = evaluateIfDefined(templateString, ifDefined); writeFile(templateFile.lastModified(), expressionOutputDirectory, expressionClassesDirectory, @@ -2943,7 +2946,7 @@ private void generateColumnArithmeticScalar(String[] tdesc) throws Exception { String operatorName = tdesc[1]; String operandType1 = tdesc[2]; String operandType2 = tdesc[3]; - boolean checked = tdesc.length == 6 && "CHECKED".equals(tdesc[5]); + boolean checked = tdesc.length == 6 && tdesc[5].contains("CHECKED"); String className = getCamelCaseType(operandType1) + "Col" + operatorName + getCamelCaseType(operandType2) + "Scalar" + (checked ? "Checked" : ""); @@ -3039,7 +3042,7 @@ private void generateScalarArithmeticColumn(String[] tdesc) throws Exception { String operatorName = tdesc[1]; String operandType1 = tdesc[2]; String operandType2 = tdesc[3]; - boolean checked = (tdesc.length == 6 && "CHECKED".equals(tdesc[5])); + boolean checked = (tdesc.length == 6 && tdesc[5].contains("CHECKED")); String className = getCamelCaseType(operandType1) + "Scalar" + operatorName + getCamelCaseType(operandType2) + "Column" + (checked ? "Checked" : ""); @@ -3531,17 +3534,75 @@ private boolean containsDefinedStrings(Set defineSet, String commaDefine return result; } - private int doIfDefinedStatement(String[] lines, int index, Set definedSet, + private boolean matchesDefinedStrings(Set defineSet, Set newIfDefinedSet, + IfDefinedMode ifDefinedMode) { + switch (ifDefinedMode) { + case SINGLE: + case AND_ALL: + for (String candidateString : newIfDefinedSet) { + if (!defineSet.contains(candidateString)) { + return false; + } + } + return true; + case OR_ANY: + for (String candidateString : newIfDefinedSet) { + if (defineSet.contains(candidateString)) { + return true; + } + } + return false; + default: + throw new RuntimeException("Unexpected if defined mode " + ifDefinedMode); + } + } + + public enum IfDefinedMode { + SINGLE, + AND_ALL, + OR_ANY; + } + + private IfDefinedMode parseIfDefinedMode(String newIfDefinedString, Set newIfDefinedSet) { + final String[] newIfDefinedStrings; + final IfDefinedMode ifDefinedMode; + int index = newIfDefinedString.indexOf("&&"); + if (index != -1) { + newIfDefinedStrings = newIfDefinedString.split("&&"); + ifDefinedMode = IfDefinedMode.AND_ALL; + } else { + index = newIfDefinedString.indexOf("||"); + if (index == -1) { + + // One element. 
+        newIfDefinedSet.add(newIfDefinedString);
+        return IfDefinedMode.SINGLE;
+      } else {
+        newIfDefinedStrings = newIfDefinedString.split("\\|\\|");
+        ifDefinedMode = IfDefinedMode.OR_ANY;
+      }
+    }
+    for (String newDefinedString : newIfDefinedStrings) {
+      newIfDefinedSet.add(newDefinedString);
+    }
+    return ifDefinedMode;
+  }
+
+  private int doIfDefinedStatement(String[] lines, int index, Set<String> desiredIfDefinedSet,
       boolean outerInclude, StringBuilder sb) {
     String ifLine = lines[index];
     final int ifLineNumber = index + 1;
-    String commaDefinedString = ifLine.substring("#IF ".length());
-    boolean includeBody = containsDefinedStrings(definedSet, commaDefinedString);
+
+    String ifDefinedString = ifLine.substring("#IF ".length());
+    Set<String> ifDefinedSet = new HashSet<String>();
+    IfDefinedMode ifDefinedMode = parseIfDefinedMode(ifDefinedString, ifDefinedSet);
+    boolean includeBody = matchesDefinedStrings(desiredIfDefinedSet, ifDefinedSet, ifDefinedMode);
+
     index++;
     final int end = lines.length;
     while (true) {
       if (index >= end) {
-        throw new RuntimeException("Unmatched #IF at line " + index + " for " + commaDefinedString);
+        throw new RuntimeException("Unmatched #IF at line " + index + " for " + ifDefinedString);
       }
       String line = lines[index];
       if (line.length() == 0 || line.charAt(0) != '#') {
@@ -3556,7 +3617,9 @@ private int doIfDefinedStatement(String[] lines, int index, Set<String> definedS
       // A pound # statement (IF/ELSE/ENDIF).
       if (line.startsWith("#IF ")) {
         // Recurse.
-        index = doIfDefinedStatement(lines, index, definedSet, outerInclude && includeBody, sb);
+        index =
+            doIfDefinedStatement(
+                lines, index, desiredIfDefinedSet, outerInclude && includeBody, sb);
       } else if (line.equals("#ELSE")) {
         // Flip inclusion.
         includeBody = !includeBody;
@@ -3565,10 +3628,10 @@ private int doIfDefinedStatement(String[] lines, int index, Set<String> definedS
         throw new RuntimeException("Missing defined strings with #ENDIF on line " + (index + 1));
       } else if (line.startsWith("#ENDIF ")) {
         String endCommaDefinedString = line.substring("#ENDIF ".length());
-        if (!commaDefinedString.equals(endCommaDefinedString)) {
+        if (!ifDefinedString.equals(endCommaDefinedString)) {
           throw new RuntimeException(
               "#ENDIF defined names \"" + endCommaDefinedString + "\" (line " + ifLineNumber +
-              " do not match \"" + commaDefinedString + "\" (line " + (index + 1) + ")");
+              " do not match \"" + ifDefinedString + "\" (line " + (index + 1) + ")");
         }
         return ++index;
       } else {
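For reference, the three #IF forms the rewritten parser accepts, and a usage sketch of the two new methods — evaluating a template line against the defines that a table entry like "MANUAL_DIVIDE_BY_ZERO_CHECK,CHECKED" turns on (illustrative only; these are private helpers inside GenVectorCode):

    // #IF CHECKED                               -> SINGLE : include iff that one name is defined
    // #IF MANUAL_DIVIDE_BY_ZERO_CHECK&&CHECKED  -> AND_ALL: include iff every name is defined
    // #IF MANUAL_DIVIDE_BY_ZERO_CHECK||CHECKED  -> OR_ANY : include iff any name is defined
    Set<String> defined = new HashSet<String>(
        Arrays.asList("MANUAL_DIVIDE_BY_ZERO_CHECK", "CHECKED"));
    Set<String> names = new HashSet<String>();
    IfDefinedMode mode = parseIfDefinedMode("MANUAL_DIVIDE_BY_ZERO_CHECK&&CHECKED", names);
    boolean include = matchesDefinedStrings(defined, names, mode);  // true: both are defined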