diff --git a/ant/src/org/apache/hadoop/hive/ant/GenVectorCode.java b/ant/src/org/apache/hadoop/hive/ant/GenVectorCode.java index f3f6bd6..4f386b9 100644 --- a/ant/src/org/apache/hadoop/hive/ant/GenVectorCode.java +++ b/ant/src/org/apache/hadoop/hive/ant/GenVectorCode.java @@ -233,6 +233,13 @@ {"FilterStringScalarCompareColumn", "Greater", ">"}, {"FilterStringScalarCompareColumn", "GreaterEqual", ">="}, + {"FilterDecimalColumnCompareScalar", "Equal", "=="}, + {"FilterDecimalColumnCompareScalar", "NotEqual", "!="}, + {"FilterDecimalColumnCompareScalar", "Less", "<"}, + {"FilterDecimalColumnCompareScalar", "LessEqual", "<="}, + {"FilterDecimalColumnCompareScalar", "Greater", ">"}, + {"FilterDecimalColumnCompareScalar", "GreaterEqual", ">="}, + {"StringScalarCompareColumn", "Equal", "=="}, {"StringScalarCompareColumn", "NotEqual", "!="}, {"StringScalarCompareColumn", "Less", "<"}, @@ -591,6 +598,8 @@ private void generate() throws Exception { generateIfExprScalarColumn(tdesc); } else if (tdesc[0].equals("IfExprScalarScalar")) { generateIfExprScalarScalar(tdesc); + } else if (tdesc[0].equals("FilterDecimalColumnCompareScalar")) { + generateFilterDecimalColumnCompareScalar(tdesc); } else { continue; } @@ -1123,6 +1132,27 @@ private void generateScalarArithmeticColumn(String[] tdesc) throws IOException { generateScalarBinaryOperatorColumn(tdesc, returnType, className); } + private void generateFilterDecimalColumnCompareScalar(String[] tdesc) throws IOException { + String operatorName = tdesc[1]; + String className = "FilterDecimalCol" + operatorName + "DecimalScalar"; + generateDecimalColumnCompareScalar(tdesc, className); + } + + private void generateDecimalColumnCompareScalar(String[] tdesc, String className) + throws IOException { + String operatorSymbol = tdesc[2]; + + // Read the template into a string; + File templateFile = new File(joinPath(this.expressionTemplateDirectory, tdesc[0] + ".txt")); + String templateString = readFile(templateFile); + + // Expand, and write result + templateString = templateString.replaceAll("", className); + templateString = templateString.replaceAll("", operatorSymbol); + writeFile(templateFile.lastModified(), expressionOutputDirectory, expressionClassesDirectory, + className, templateString); + } + static void writeFile(long templateTime, String outputDir, String classesDir, String className, String str) throws IOException { File outputFile = new File(outputDir, className + ".java"); diff --git a/ql/src/gen/vectorization/ExpressionTemplates/FilterDecimalColumnCompareScalar.txt b/ql/src/gen/vectorization/ExpressionTemplates/FilterDecimalColumnCompareScalar.txt new file mode 100644 index 0000000..34d0438 --- /dev/null +++ b/ql/src/gen/vectorization/ExpressionTemplates/FilterDecimalColumnCompareScalar.txt @@ -0,0 +1,174 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec.vector.expressions.gen; + +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor; +import org.apache.hadoop.hive.common.type.Decimal128; + +/** + * This is a generated class to evaluate a comparison on a vector of decimal + * values. + */ +public class extends VectorExpression { + + private static final long serialVersionUID = 1L; + + private int colNum; + private Decimal128 value; + + public (int colNum, Decimal128 value) { + this.colNum = colNum; + this.value = value; + } + + public () { + } + + @Override + public void evaluate(VectorizedRowBatch batch) { + if (childExpressions != null) { + super.evaluateChildren(batch); + } + DecimalColumnVector inputColVector = (DecimalColumnVector) batch.cols[colNum]; + int[] sel = batch.selected; + boolean[] nullPos = inputColVector.isNull; + int n = batch.size; + Decimal128[] vector = inputColVector.vector; + + // return immediately if batch is empty + if (n == 0) { + return; + } + + if (inputColVector.noNulls) { + if (inputColVector.isRepeating) { + + // All must be selected otherwise size would be zero. Repeating property will not change. + if (!(vector[0].compareTo(value) 0)) { + + // Entire batch is filtered out. + batch.size = 0; + } + } else if (batch.selectedInUse) { + int newSize = 0; + for(int j = 0; j != n; j++) { + int i = sel[j]; + if (vector[i].compareTo(value) 0) { + sel[newSize++] = i; + } + } + batch.size = newSize; + } else { + int newSize = 0; + for(int i = 0; i != n; i++) { + if (vector[i].compareTo(value) 0) { + sel[newSize++] = i; + } + } + if (newSize < n) { + batch.size = newSize; + batch.selectedInUse = true; + } + } + } else { + if (inputColVector.isRepeating) { + + // All must be selected otherwise size would be zero. Repeating property will not change. + if (!nullPos[0]) { + if (!(vector[0].compareTo(value) 0)) { + + // Entire batch is filtered out. + batch.size = 0; + } + } else { + batch.size = 0; + } + } else if (batch.selectedInUse) { + int newSize = 0; + for(int j = 0; j != n; j++) { + int i = sel[j]; + if (!nullPos[i]) { + if (vector[i].compareTo(value) 0) { + sel[newSize++] = i; + } + } + } + + // Change the selected vector + batch.size = newSize; + } else { + int newSize = 0; + for(int i = 0; i != n; i++) { + if (!nullPos[i]) { + if (vector[i].compareTo(value) 0) { + sel[newSize++] = i; + } + } + } + if (newSize < n) { + batch.size = newSize; + batch.selectedInUse = true; + } + } + } + } + + @Override + public int getOutputColumn() { + return -1; + } + + @Override + public String getOutputType() { + return "boolean"; + } + + public int getColNum() { + return colNum; + } + + public void setColNum(int colNum) { + this.colNum = colNum; + } + + public Decimal128 getValue() { + return value; + } + + public void setValue(Decimal128 value) { + this.value = value; + } + + @Override + public VectorExpressionDescriptor.Descriptor getDescriptor() { + return (new VectorExpressionDescriptor.Builder()) + .setMode( + VectorExpressionDescriptor.Mode.FILTER) + .setNumArguments(2) + .setArgumentTypes( + VectorExpressionDescriptor.ArgumentType.getType("decimal"), + VectorExpressionDescriptor.ArgumentType.getType("decimal")) + .setInputExpressionTypes( + VectorExpressionDescriptor.InputExpressionType.COLUMN, + VectorExpressionDescriptor.InputExpressionType.SCALAR).build(); + } +} diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorFilterExpressions.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorFilterExpressions.java index 8560117..4a81a1e 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorFilterExpressions.java +++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorFilterExpressions.java @@ -24,7 +24,9 @@ import java.sql.Timestamp; +import org.apache.hadoop.hive.common.type.Decimal128; import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; import org.apache.hadoop.hive.ql.exec.vector.TimestampUtils; @@ -41,6 +43,7 @@ import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FilterLongScalarLessLongColumn; import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FilterStringColumnBetween; import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FilterStringColumnNotBetween; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FilterDecimalColEqualDecimalScalar; import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColAddLongScalar; import org.apache.hadoop.hive.ql.exec.vector.util.VectorizedRowGroupGenUtil; import org.junit.Assert; @@ -784,4 +787,62 @@ public void testFilterStringIn() { expr.evaluate(vrb); assertEquals(0, vrb.size); } + + /** + * This tests the template for Decimal-Scalar comparison filters, + * called FilterDecimalColumnCompareScalar.txt. + * Only equal is tested because the logic is the same for <, >, <=, >=, == and !=. + */ + @Test + public void testFilterDecimalColEqualDecimalScalar() { + VectorizedRowBatch b = getVectorizedRowBatch1DecimalCol(); + Decimal128 scalar = new Decimal128(); + scalar.update("-3.30", (short) 2); + VectorExpression expr = new FilterDecimalColEqualDecimalScalar(0, scalar); + expr.evaluate(b); + + // check that right row(s) are selected + assertTrue(b.selectedInUse); + assertEquals(1, b.selected[0]); + assertEquals(1, b.size); + + // try again with a null value + b = getVectorizedRowBatch1DecimalCol(); + b.cols[0].noNulls = false; + b.cols[0].isNull[1] = true; + expr.evaluate(b); + + // verify that no rows were selected + assertEquals(0, b.size); + + // try the repeating case + b = getVectorizedRowBatch1DecimalCol(); + b.cols[0].isRepeating = true; + expr.evaluate(b); + + // verify that no rows were selected + assertEquals(0, b.size); + + // try the repeating null case + b = getVectorizedRowBatch1DecimalCol(); + b.cols[0].isRepeating = true; + b.cols[0].noNulls = false; + b.cols[0].isNull[0] = true; + expr.evaluate(b); + + // verify that no rows were selected + assertEquals(0, b.size); + } + + private VectorizedRowBatch getVectorizedRowBatch1DecimalCol() { + VectorizedRowBatch b = new VectorizedRowBatch(1); + DecimalColumnVector v0; + b.cols[0] = v0 = new DecimalColumnVector(18, 2); + v0.vector[0].update("1.20", (short) 2); + v0.vector[1].update("-3.30", (short) 2); + v0.vector[2].update("0", (short) 2); + + b.size = 3; + return b; + } }