diff --git a/ant/src/org/apache/hadoop/hive/ant/GenVectorCode.java b/ant/src/org/apache/hadoop/hive/ant/GenVectorCode.java
index 1d3c5c4..a286024 100644
--- a/ant/src/org/apache/hadoop/hive/ant/GenVectorCode.java
+++ b/ant/src/org/apache/hadoop/hive/ant/GenVectorCode.java
@@ -386,9 +386,25 @@
       // See org.apache.hadoop.hive.ql.exec.vector.expressions for remaining cast VectorExpression
       // classes
-      {"ColumnUnaryMinus", "long"},
-      {"ColumnUnaryMinus", "double"},
-
+      {"ColumnUnaryMinus", "long"},
+      {"ColumnUnaryMinus", "double"},
+
+      // IF conditional expression
+      // fileHeader, resultType, arg2Type, arg3Type
+      {"IfExprColumnColumn", "long"},
+      {"IfExprColumnColumn", "double"},
+      {"IfExprColumnScalar", "long", "long"},
+      {"IfExprColumnScalar", "double", "long"},
+      {"IfExprColumnScalar", "long", "double"},
+      {"IfExprColumnScalar", "double", "double"},
+      {"IfExprScalarColumn", "long", "long"},
+      {"IfExprScalarColumn", "double", "long"},
+      {"IfExprScalarColumn", "long", "double"},
+      {"IfExprScalarColumn", "double", "double"},
+      {"IfExprScalarScalar", "long", "long"},
+      {"IfExprScalarScalar", "double", "long"},
+      {"IfExprScalarScalar", "long", "double"},
+      {"IfExprScalarScalar", "double", "double"},
 
       // template, <ClassName>, <ValueType>, <OperatorSymbol>, <DescriptionName>, <DescriptionValue>
       {"VectorUDAFMinMax", "VectorUDAFMinLong", "long", "<", "min",
@@ -567,6 +583,14 @@ private void generate() throws Exception {
         generateFilterStringColumnCompareColumn(tdesc);
       } else if (tdesc[0].equals("StringColumnCompareColumn")) {
         generateStringColumnCompareColumn(tdesc);
+      } else if (tdesc[0].equals("IfExprColumnColumn")) {
+        generateIfExprColumnColumn(tdesc);
+      } else if (tdesc[0].equals("IfExprColumnScalar")) {
+        generateIfExprColumnScalar(tdesc);
+      } else if (tdesc[0].equals("IfExprScalarColumn")) {
+        generateIfExprScalarColumn(tdesc);
+      } else if (tdesc[0].equals("IfExprScalarScalar")) {
+        generateIfExprScalarScalar(tdesc);
       } else {
         continue;
       }
@@ -800,6 +824,89 @@ private void generateColumnUnaryMinus(String[] tdesc) throws IOException {
         className, templateString);
   }
 
+  private void generateIfExprColumnColumn(String[] tdesc) throws IOException {
+    String operandType = tdesc[1];
+    String inputColumnVectorType = this.getColumnVectorType(operandType);
+    String outputColumnVectorType = inputColumnVectorType;
+    String returnType = operandType;
+    String className = "IfExpr" + getCamelCaseType(operandType) + "Column"
+        + getCamelCaseType(operandType) + "Column";
+    String outputFile = joinPath(this.expressionOutputDirectory, className + ".java");
+    File templateFile = new File(joinPath(this.expressionTemplateDirectory, tdesc[0] + ".txt"));
+    String templateString = readFile(templateFile);
+    // Expand, and write result
+    templateString = templateString.replaceAll("<ClassName>", className);
+    templateString = templateString.replaceAll("<InputColumnVectorType>", inputColumnVectorType);
+    templateString = templateString.replaceAll("<OperandType>", operandType);
+    writeFile(templateFile.lastModified(), expressionOutputDirectory, expressionClassesDirectory,
+        className, templateString);
+  }
+
+  private void generateIfExprColumnScalar(String[] tdesc) throws IOException {
+    String operandType2 = tdesc[1];
+    String operandType3 = tdesc[2];
+    String arg2ColumnVectorType = this.getColumnVectorType(operandType2);
+    String returnType = getArithmeticReturnType(operandType2, operandType3);
+    String outputColumnVectorType = getColumnVectorType(returnType);
+    String className = "IfExpr" + getCamelCaseType(operandType2) + "Column"
+        + getCamelCaseType(operandType3) + "Scalar";
+    String outputFile = joinPath(this.expressionOutputDirectory, className + ".java");
+    File templateFile = new File(joinPath(this.expressionTemplateDirectory, tdesc[0] + ".txt"));
+    String templateString = readFile(templateFile);
+    // Expand, and write result
+    templateString = templateString.replaceAll("<ClassName>", className);
+    templateString = templateString.replaceAll("<Arg2ColumnVectorType>", arg2ColumnVectorType);
+    templateString = templateString.replaceAll("<ReturnType>", returnType);
+    templateString = templateString.replaceAll("<OperandType2>", operandType2);
+    templateString = templateString.replaceAll("<OperandType3>", operandType3);
+    templateString = templateString.replaceAll("<OutputColumnVectorType>", outputColumnVectorType);
+    writeFile(templateFile.lastModified(), expressionOutputDirectory, expressionClassesDirectory,
+        className, templateString);
+  }
+
+  private void generateIfExprScalarColumn(String[] tdesc) throws IOException {
+    String operandType2 = tdesc[1];
+    String operandType3 = tdesc[2];
+    String arg3ColumnVectorType = this.getColumnVectorType(operandType3);
+    String returnType = getArithmeticReturnType(operandType2, operandType3);
+    String outputColumnVectorType = getColumnVectorType(returnType);
+    String className = "IfExpr" + getCamelCaseType(operandType2) + "Scalar"
+        + getCamelCaseType(operandType3) + "Column";
+    String outputFile = joinPath(this.expressionOutputDirectory, className + ".java");
+    File templateFile = new File(joinPath(this.expressionTemplateDirectory, tdesc[0] + ".txt"));
+    String templateString = readFile(templateFile);
+    // Expand, and write result
+    templateString = templateString.replaceAll("<ClassName>", className);
+    templateString = templateString.replaceAll("<Arg3ColumnVectorType>", arg3ColumnVectorType);
+    templateString = templateString.replaceAll("<ReturnType>", returnType);
+    templateString = templateString.replaceAll("<OperandType2>", operandType2);
+    templateString = templateString.replaceAll("<OperandType3>", operandType3);
+    templateString = templateString.replaceAll("<OutputColumnVectorType>", outputColumnVectorType);
+    writeFile(templateFile.lastModified(), expressionOutputDirectory, expressionClassesDirectory,
+        className, templateString);
+  }
+
+  private void generateIfExprScalarScalar(String[] tdesc) throws IOException {
+    String operandType2 = tdesc[1];
+    String operandType3 = tdesc[2];
+    String arg3ColumnVectorType = this.getColumnVectorType(operandType3);
+    String returnType = getArithmeticReturnType(operandType2, operandType3);
+    String outputColumnVectorType = getColumnVectorType(returnType);
+    String className = "IfExpr" + getCamelCaseType(operandType2) + "Scalar"
+        + getCamelCaseType(operandType3) + "Scalar";
+    String outputFile = joinPath(this.expressionOutputDirectory, className + ".java");
+    File templateFile = new File(joinPath(this.expressionTemplateDirectory, tdesc[0] + ".txt"));
+    String templateString = readFile(templateFile);
+    // Expand, and write result
+    templateString = templateString.replaceAll("<ClassName>", className);
+    templateString = templateString.replaceAll("<ReturnType>", returnType);
+    templateString = templateString.replaceAll("<OperandType2>", operandType2);
+    templateString = templateString.replaceAll("<OperandType3>", operandType3);
+    templateString = templateString.replaceAll("<OutputColumnVectorType>", outputColumnVectorType);
+    writeFile(templateFile.lastModified(), expressionOutputDirectory, expressionClassesDirectory,
+        className, templateString);
+  }
+
   // template, <ClassName>, <ReturnType>, <OperandType>, <FuncName>, <OperandCast>, <ResultCast>
   private void generateColumnUnaryFunc(String[] tdesc) throws IOException {
     String classNamePrefix = tdesc[1];
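The expansion step itself is plain token substitution over the template text. A minimal runnable sketch of that mechanism (hypothetical class and inline template string, assuming the angle-bracket placeholder convention above):

    public class ExpandSketch {
      public static void main(String[] args) {
        // readFile() would return the raw template text; each <Token> is then substituted.
        String templateString = "public class <ClassName> extends VectorExpression { /* ... */ }";
        // String.replaceAll() treats its first argument as a regex; the angle-bracket
        // tokens contain no regex metacharacters, so they substitute literally.
        templateString = templateString.replaceAll("<ClassName>", "IfExprLongColumnLongColumn");
        System.out.println(templateString);
      }
    }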
diff --git a/ql/src/gen/vectorization/ExpressionTemplates/IfExprColumnColumn.txt b/ql/src/gen/vectorization/ExpressionTemplates/IfExprColumnColumn.txt
new file mode 100644
index 0000000..d43e044
--- /dev/null
+++ b/ql/src/gen/vectorization/ExpressionTemplates/IfExprColumnColumn.txt
@@ -0,0 +1,187 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec.vector.expressions.gen; + +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; +import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor; + +/** + * Compute IF(expr1, expr2, expr3) for 3 input column expressions. + * The first is always a boolean (LongColumnVector). + * The second and third are long columns or long expression results. + */ +public class extends VectorExpression { + + private static final long serialVersionUID = 1L; + + private int arg1Column, arg2Column, arg3Column; + private int outputColumn; + + public (int arg1Column, int arg2Column, int arg3Column, int outputColumn) { + this.arg1Column = arg1Column; + this.arg2Column = arg2Column; + this.arg3Column = arg3Column; + this.outputColumn = outputColumn; + } + + public () { + } + + @Override + public void evaluate(VectorizedRowBatch batch) { + + if (childExpressions != null) { + super.evaluateChildren(batch); + } + + LongColumnVector arg1ColVector = (LongColumnVector) batch.cols[arg1Column]; + arg2ColVector = () batch.cols[arg2Column]; + arg3ColVector = () batch.cols[arg3Column]; + outputColVector = () batch.cols[outputColumn]; + int[] sel = batch.selected; + boolean[] outputIsNull = outputColVector.isNull; + outputColVector.noNulls = arg2ColVector.noNulls && arg3ColVector.noNulls; + outputColVector.isRepeating = false; // may override later + int n = batch.size; + long[] vector1 = arg1ColVector.vector; + [] vector2 = arg2ColVector.vector; + [] vector3 = arg3ColVector.vector; + [] outputVector = outputColVector.vector; + + // return immediately if batch is empty + if (n == 0) { + return; + } + + /* All the code paths below propagate nulls even if neither arg2 nor arg3 + * have nulls. This is to reduce the number of code paths and shorten the + * code, at the expense of maybe doing unnecessary work if neither input + * has nulls. This could be improved in the future by expanding the number + * of code paths. + */ + if (arg1ColVector.isRepeating) { + if (vector1[0] == 1) { + arg2ColVector.copySelected(batch.selectedInUse, sel, n, outputColVector); + } else { + arg3ColVector.copySelected(batch.selectedInUse, sel, n, outputColVector); + } + return; + } + + // extend any repeating values and noNulls indicator in the inputs + arg2ColVector.flatten(batch.selectedInUse, sel, n); + arg3ColVector.flatten(batch.selectedInUse, sel, n); + + if (arg1ColVector.noNulls) { + if (batch.selectedInUse) { + for(int j = 0; j != n; j++) { + int i = sel[j]; + outputVector[i] = (vector1[i] == 1 ? vector2[i] : vector3[i]); + outputIsNull[i] = (vector1[i] == 1 ? + arg2ColVector.isNull[i] : arg3ColVector.isNull[i]); + } + } else { + for(int i = 0; i != n; i++) { + outputVector[i] = (vector1[i] == 1 ? vector2[i] : vector3[i]); + outputIsNull[i] = (vector1[i] == 1 ? 
+ arg2ColVector.isNull[i] : arg3ColVector.isNull[i]); + } + } + } else /* there are nulls */ { + if (batch.selectedInUse) { + for(int j = 0; j != n; j++) { + int i = sel[j]; + outputVector[i] = (!arg1ColVector.isNull[i] && vector1[i] == 1 ? + vector2[i] : vector3[i]); + outputIsNull[i] = (!arg1ColVector.isNull[i] && vector1[i] == 1 ? + arg2ColVector.isNull[i] : arg3ColVector.isNull[i]); + } + } else { + for(int i = 0; i != n; i++) { + outputVector[i] = (!arg1ColVector.isNull[i] && vector1[i] == 1 ? + vector2[i] : vector3[i]); + outputIsNull[i] = (!arg1ColVector.isNull[i] && vector1[i] == 1 ? + arg2ColVector.isNull[i] : arg3ColVector.isNull[i]); + } + } + } + + // restore repeating and no nulls indicators + arg2ColVector.unFlatten(); + arg3ColVector.unFlatten(); + + } + + @Override + public int getOutputColumn() { + return outputColumn; + } + + @Override + public String getOutputType() { + return ""; + } + + public int getArg1Column() { + return arg1Column; + } + + public void setArg1Column(int colNum) { + this.arg1Column = colNum; + } + + public int getArg2Column() { + return arg2Column; + } + + public void setArg2Column(int colNum) { + this.arg2Column = colNum; + } + + public int getArg3Column() { + return arg3Column; + } + + public void setArg3Column(int colNum) { + this.arg3Column = colNum; + } + + public void setOutputColumn(int outputColumn) { + this.outputColumn = outputColumn; + } + + @Override + public VectorExpressionDescriptor.Descriptor getDescriptor() { + return (new VectorExpressionDescriptor.Builder()) + .setMode( + VectorExpressionDescriptor.Mode.PROJECTION) + .setNumArguments(3) + .setArgumentTypes( + VectorExpressionDescriptor.ArgumentType.getType("long"), + VectorExpressionDescriptor.ArgumentType.getType(""), + VectorExpressionDescriptor.ArgumentType.getType("")) + .setInputExpressionTypes( + VectorExpressionDescriptor.InputExpressionType.COLUMN, + VectorExpressionDescriptor.InputExpressionType.COLUMN, + VectorExpressionDescriptor.InputExpressionType.COLUMN).build(); + } +} diff --git a/ql/src/gen/vectorization/ExpressionTemplates/IfExprColumnScalar.txt b/ql/src/gen/vectorization/ExpressionTemplates/IfExprColumnScalar.txt new file mode 100644 index 0000000..5515c5e --- /dev/null +++ b/ql/src/gen/vectorization/ExpressionTemplates/IfExprColumnScalar.txt @@ -0,0 +1,177 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
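All of the IfExpr templates implement the same row-level rule: a true condition selects arg2, a false or NULL condition selects arg3, and the chosen branch's null flag propagates to the output. A boxed-type reference model (illustration only, not part of the patch) makes that rule explicit:

final class IfRowSemantics {
  // null stands in for SQL NULL; a NULL condition behaves like false
  static Long ifExpr(Long cond, Long thenVal, Long elseVal) {
    return (cond != null && cond == 1L) ? thenVal : elseVal;
  }

  public static void main(String[] args) {
    System.out.println(ifExpr(1L, 10L, 20L));   // 10
    System.out.println(ifExpr(0L, 10L, null));  // null propagates from the chosen branch
    System.out.println(ifExpr(null, 10L, 20L)); // 20: NULL condition selects arg3
  }
}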
+ */ + +package org.apache.hadoop.hive.ql.exec.vector.expressions.gen; + +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; +import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.exec.vector.expressions.NullUtil; +import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor; + +/** + * Compute IF(expr1, expr2, expr3) for 3 input column expressions. + * The first is always a boolean (LongColumnVector). + * The second is a column or non-constant expression result. + * The third is a constant value. + */ +public class extends VectorExpression { + + private static final long serialVersionUID = 1L; + + private int arg1Column, arg2Column; + private arg3Scalar; + private int outputColumn; + + public (int arg1Column, int arg2Column, arg3Scalar, + int outputColumn) { + this.arg1Column = arg1Column; + this.arg2Column = arg2Column; + this.arg3Scalar = arg3Scalar; + this.outputColumn = outputColumn; + } + + public () { + } + + @Override + public void evaluate(VectorizedRowBatch batch) { + + if (childExpressions != null) { + super.evaluateChildren(batch); + } + + LongColumnVector arg1ColVector = (LongColumnVector) batch.cols[arg1Column]; + arg2ColVector = () batch.cols[arg2Column]; + outputColVector = () batch.cols[outputColumn]; + int[] sel = batch.selected; + boolean[] outputIsNull = outputColVector.isNull; + outputColVector.noNulls = arg2ColVector.noNulls; // nulls can only come from arg2 + outputColVector.isRepeating = false; // may override later + int n = batch.size; + long[] vector1 = arg1ColVector.vector; + [] vector2 = arg2ColVector.vector; + [] outputVector = outputColVector.vector; + + // return immediately if batch is empty + if (n == 0) { + return; + } + + if (arg1ColVector.isRepeating) { + if (vector1[0] == 1) { + arg2ColVector.copySelected(batch.selectedInUse, sel, n, outputColVector); + } else { + outputColVector.fill(arg3Scalar); + } + return; + } + + // Extend any repeating values and noNulls indicator in the inputs to + // reduce the number of code paths needed below. + arg2ColVector.flatten(batch.selectedInUse, sel, n); + + if (arg1ColVector.noNulls) { + if (batch.selectedInUse) { + for(int j = 0; j != n; j++) { + int i = sel[j]; + outputVector[i] = (vector1[i] == 1 ? vector2[i] : arg3Scalar); + } + } else { + for(int i = 0; i != n; i++) { + outputVector[i] = (vector1[i] == 1 ? vector2[i] : arg3Scalar); + } + } + } else /* there are nulls */ { + if (batch.selectedInUse) { + for(int j = 0; j != n; j++) { + int i = sel[j]; + outputVector[i] = (!arg1ColVector.isNull[i] && vector1[i] == 1 ? + vector2[i] : arg3Scalar); + outputIsNull[i] = (!arg1ColVector.isNull[i] && vector1[i] == 1 ? + arg2ColVector.isNull[i] : false); + } + } else { + for(int i = 0; i != n; i++) { + outputVector[i] = (!arg1ColVector.isNull[i] && vector1[i] == 1 ? + vector2[i] : arg3Scalar); + outputIsNull[i] = (!arg1ColVector.isNull[i] && vector1[i] == 1 ? 
+            arg2ColVector.isNull[i] : false);
+        }
+      }
+    }
+
+    // restore repeating and no nulls indicators
+    arg2ColVector.unFlatten();
+  }
+
+  @Override
+  public int getOutputColumn() {
+    return outputColumn;
+  }
+
+  @Override
+  public String getOutputType() {
+    return "<ReturnType>";
+  }
+
+  public int getArg1Column() {
+    return arg1Column;
+  }
+
+  public void setArg1Column(int colNum) {
+    this.arg1Column = colNum;
+  }
+
+  public int getArg2Column() {
+    return arg2Column;
+  }
+
+  public void setArg2Column(int colNum) {
+    this.arg2Column = colNum;
+  }
+
+  public <OperandType3> getArg3Scalar() {
+    return arg3Scalar;
+  }
+
+  public void setArg3Scalar(<OperandType3> value) {
+    this.arg3Scalar = value;
+  }
+
+  public void setOutputColumn(int outputColumn) {
+    this.outputColumn = outputColumn;
+  }
+
+  @Override
+  public VectorExpressionDescriptor.Descriptor getDescriptor() {
+    return (new VectorExpressionDescriptor.Builder())
+        .setMode(
+            VectorExpressionDescriptor.Mode.PROJECTION)
+        .setNumArguments(3)
+        .setArgumentTypes(
+            VectorExpressionDescriptor.ArgumentType.getType("long"),
+            VectorExpressionDescriptor.ArgumentType.getType("<OperandType2>"),
+            VectorExpressionDescriptor.ArgumentType.getType("<OperandType3>"))
+        .setInputExpressionTypes(
+            VectorExpressionDescriptor.InputExpressionType.COLUMN,
+            VectorExpressionDescriptor.InputExpressionType.COLUMN,
+            VectorExpressionDescriptor.InputExpressionType.SCALAR).build();
+  }
+}
diff --git a/ql/src/gen/vectorization/ExpressionTemplates/IfExprScalarColumn.txt b/ql/src/gen/vectorization/ExpressionTemplates/IfExprScalarColumn.txt
new file mode 100644
index 0000000..4dae9a2
--- /dev/null
+++ b/ql/src/gen/vectorization/ExpressionTemplates/IfExprScalarColumn.txt
@@ -0,0 +1,179 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.vector.expressions.gen;
+
+import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
+import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
+import org.apache.hadoop.hive.ql.exec.vector.expressions.NullUtil;
+import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor;
+
+/**
+ * Compute IF(expr1, expr2, expr3) for 3 input column expressions.
+ * The first is always a boolean (LongColumnVector).
+ * The second is a constant value.
+ * The third is a column or non-constant expression result.
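Once the generator has run, the emitted classes are used like any other VectorExpression. A usage sketch, assuming the descriptor table above produced IfExprLongColumnLongScalar (the constructor signature follows the template):

import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprLongColumnLongScalar;

public class IfExprUsageSketch {
  public static void main(String[] args) {
    VectorizedRowBatch batch = new VectorizedRowBatch(3);
    batch.cols[0] = new LongColumnVector(); // boolean condition
    batch.cols[1] = new LongColumnVector(); // THEN values
    batch.cols[2] = new LongColumnVector(); // scratch output
    batch.size = 3;

    LongColumnVector cond = (LongColumnVector) batch.cols[0];
    LongColumnVector thenCol = (LongColumnVector) batch.cols[1];
    cond.vector[0] = 1; cond.vector[1] = 0; cond.vector[2] = 1;
    thenCol.vector[0] = 10; thenCol.vector[1] = 20; thenCol.vector[2] = 30;

    // IF(cond, thenCol, 99) projected into column 2
    IfExprLongColumnLongScalar expr = new IfExprLongColumnLongScalar(0, 1, 99, 2);
    expr.evaluate(batch);

    LongColumnVector out = (LongColumnVector) batch.cols[2];
    System.out.println(out.vector[0] + " " + out.vector[1] + " " + out.vector[2]); // 10 99 30
  }
}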
+ */ +public class extends VectorExpression { + + private static final long serialVersionUID = 1L; + + private int arg1Column, arg3Column; + private arg2Scalar; + private int outputColumn; + + public (int arg1Column, arg2Scalar, int arg3Column, + int outputColumn) { + this.arg1Column = arg1Column; + this.arg2Scalar = arg2Scalar; + this.arg3Column = arg3Column; + this.outputColumn = outputColumn; + } + + public () { + } + + @Override + public void evaluate(VectorizedRowBatch batch) { + + if (childExpressions != null) { + super.evaluateChildren(batch); + } + + LongColumnVector arg1ColVector = (LongColumnVector) batch.cols[arg1Column]; + arg3ColVector = () batch.cols[arg3Column]; + outputColVector = () batch.cols[outputColumn]; + int[] sel = batch.selected; + boolean[] outputIsNull = outputColVector.isNull; + outputColVector.noNulls = arg3ColVector.noNulls; // nulls can only come from arg3 column vector + outputColVector.isRepeating = false; // may override later + int n = batch.size; + long[] vector1 = arg1ColVector.vector; + [] vector3 = arg3ColVector.vector; + [] outputVector = outputColVector.vector; + + // return immediately if batch is empty + if (n == 0) { + return; + } + + if (arg1ColVector.isRepeating) { + if (vector1[0] == 1) { + outputColVector.fill(arg2Scalar); + } else { + arg3ColVector.copySelected(batch.selectedInUse, sel, n, outputColVector); + } + return; + } + + // Extend any repeating values and noNulls indicator in the inputs to + // reduce the number of code paths needed below. + // This could be optimized in the future by having separate paths + // for when arg3ColVector is repeating or has no nulls. + arg3ColVector.flatten(batch.selectedInUse, sel, n); + + if (arg1ColVector.noNulls) { + if (batch.selectedInUse) { + for(int j = 0; j != n; j++) { + int i = sel[j]; + outputVector[i] = (vector1[i] == 1 ? arg2Scalar : vector3[i]); + } + } else { + for(int i = 0; i != n; i++) { + outputVector[i] = (vector1[i] == 1 ? arg2Scalar : vector3[i]); + } + } + } else /* there are nulls */ { + if (batch.selectedInUse) { + for(int j = 0; j != n; j++) { + int i = sel[j]; + outputVector[i] = (!arg1ColVector.isNull[i] && vector1[i] == 1 ? + arg2Scalar : vector3[i]); + outputIsNull[i] = (!arg1ColVector.isNull[i] && vector1[i] == 1 ? + false : arg3ColVector.isNull[i]); + } + } else { + for(int i = 0; i != n; i++) { + outputVector[i] = (!arg1ColVector.isNull[i] && vector1[i] == 1 ? + arg2Scalar : vector3[i]); + outputIsNull[i] = (!arg1ColVector.isNull[i] && vector1[i] == 1 ? 
+ false : arg3ColVector.isNull[i]); + } + } + } + + // restore repeating and no nulls indicators + arg3ColVector.unFlatten(); + } + + @Override + public int getOutputColumn() { + return outputColumn; + } + + @Override + public String getOutputType() { + return ""; + } + + public int getArg1Column() { + return arg1Column; + } + + public void setArg1Column(int colNum) { + this.arg1Column = colNum; + } + + public int getArg3Column() { + return arg3Column; + } + + public void setArg3Column(int colNum) { + this.arg3Column = colNum; + } + + public getArg2Scalar() { + return arg2Scalar; + } + + public void setArg2Scalar( value) { + this.arg2Scalar = value; + } + + public void setOutputColumn(int outputColumn) { + this.outputColumn = outputColumn; + } + + @Override + public VectorExpressionDescriptor.Descriptor getDescriptor() { + return (new VectorExpressionDescriptor.Builder()) + .setMode( + VectorExpressionDescriptor.Mode.PROJECTION) + .setNumArguments(3) + .setArgumentTypes( + VectorExpressionDescriptor.ArgumentType.getType("long"), + VectorExpressionDescriptor.ArgumentType.getType(""), + VectorExpressionDescriptor.ArgumentType.getType("")) + .setInputExpressionTypes( + VectorExpressionDescriptor.InputExpressionType.COLUMN, + VectorExpressionDescriptor.InputExpressionType.SCALAR, + VectorExpressionDescriptor.InputExpressionType.COLUMN).build(); + } +} diff --git a/ql/src/gen/vectorization/ExpressionTemplates/IfExprScalarScalar.txt b/ql/src/gen/vectorization/ExpressionTemplates/IfExprScalarScalar.txt new file mode 100644 index 0000000..692a1d1 --- /dev/null +++ b/ql/src/gen/vectorization/ExpressionTemplates/IfExprScalarScalar.txt @@ -0,0 +1,164 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec.vector.expressions.gen; + +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; +import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor; +import java.util.Arrays; + +/** + * Compute IF(expr1, expr2, expr3) for 3 input expressions. + * The first is always a boolean (LongColumnVector). + * The second is a constant value. + * The third is a constant value. 
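When the condition vector is repeating, every template above short-circuits to a single branch; a scalar result then becomes one fill() call rather than a per-row loop. A quick check of the fill() semantics this patch adds to the column vectors (illustrative):

import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;

public class FillFastPathSketch {
  public static void main(String[] args) {
    LongColumnVector out = new LongColumnVector();
    out.fill(42L); // one write covers the whole batch
    System.out.println(out.isRepeating + " " + out.noNulls + " " + out.vector[0]); // true true 42
  }
}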
+ */
+public class <ClassName> extends VectorExpression {
+
+  private static final long serialVersionUID = 1L;
+
+  private int arg1Column;
+  private <OperandType2> arg2Scalar;
+  private <OperandType3> arg3Scalar;
+  private int outputColumn;
+
+  public <ClassName>(int arg1Column, <OperandType2> arg2Scalar, <OperandType3> arg3Scalar,
+      int outputColumn) {
+    this.arg1Column = arg1Column;
+    this.arg2Scalar = arg2Scalar;
+    this.arg3Scalar = arg3Scalar;
+    this.outputColumn = outputColumn;
+  }
+
+  public <ClassName>() {
+  }
+
+  @Override
+  public void evaluate(VectorizedRowBatch batch) {
+
+    if (childExpressions != null) {
+      super.evaluateChildren(batch);
+    }
+
+    LongColumnVector arg1ColVector = (LongColumnVector) batch.cols[arg1Column];
+    <OutputColumnVectorType> outputColVector = (<OutputColumnVectorType>) batch.cols[outputColumn];
+    int[] sel = batch.selected;
+    boolean[] outputIsNull = outputColVector.isNull;
+    outputColVector.noNulls = true; // output is a scalar which we know is non null
+    outputColVector.isRepeating = false; // may override later
+    int n = batch.size;
+    long[] vector1 = arg1ColVector.vector;
+    <ReturnType>[] outputVector = outputColVector.vector;
+
+    // return immediately if batch is empty
+    if (n == 0) {
+      return;
+    }
+
+    if (arg1ColVector.isRepeating) {
+      if (vector1[0] == 1) {
+        outputColVector.fill(arg2Scalar);
+      } else {
+        outputColVector.fill(arg3Scalar);
+      }
+    } else if (arg1ColVector.noNulls) {
+      if (batch.selectedInUse) {
+        for(int j = 0; j != n; j++) {
+          int i = sel[j];
+          outputVector[i] = (vector1[i] == 1 ? arg2Scalar : arg3Scalar);
+        }
+      } else {
+        for(int i = 0; i != n; i++) {
+          outputVector[i] = (vector1[i] == 1 ? arg2Scalar : arg3Scalar);
+        }
+      }
+    } else /* there are nulls */ {
+      if (batch.selectedInUse) {
+        for(int j = 0; j != n; j++) {
+          int i = sel[j];
+          outputVector[i] = (!arg1ColVector.isNull[i] && vector1[i] == 1 ?
+              arg2Scalar : arg3Scalar);
+          outputIsNull[i] = false;
+        }
+      } else {
+        for(int i = 0; i != n; i++) {
+          outputVector[i] = (!arg1ColVector.isNull[i] && vector1[i] == 1 ?
+ arg2Scalar : arg3Scalar); + } + Arrays.fill(outputIsNull, 0, n, false); + } + } + } + + @Override + public int getOutputColumn() { + return outputColumn; + } + + @Override + public String getOutputType() { + return ""; + } + + public int getArg1Column() { + return arg1Column; + } + + public void setArg1Column(int colNum) { + this.arg1Column = colNum; + } + + public getArg2Scalar() { + return arg2Scalar; + } + + public void setArg2Scalar( value) { + this.arg2Scalar = value; + } + + public getArg3Scalar() { + return arg3Scalar; + } + + public void setArg3Scalar( value) { + this.arg3Scalar = value; + } + + public void setOutputColumn(int outputColumn) { + this.outputColumn = outputColumn; + } + + @Override + public VectorExpressionDescriptor.Descriptor getDescriptor() { + return (new VectorExpressionDescriptor.Builder()) + .setMode( + VectorExpressionDescriptor.Mode.PROJECTION) + .setNumArguments(3) + .setArgumentTypes( + VectorExpressionDescriptor.ArgumentType.getType("long"), + VectorExpressionDescriptor.ArgumentType.getType(""), + VectorExpressionDescriptor.ArgumentType.getType("")) + .setInputExpressionTypes( + VectorExpressionDescriptor.InputExpressionType.COLUMN, + VectorExpressionDescriptor.InputExpressionType.SCALAR, + VectorExpressionDescriptor.InputExpressionType.SCALAR).build(); + } +} diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/BytesColumnVector.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/BytesColumnVector.java index e1d4543..a10feb7 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/BytesColumnVector.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/BytesColumnVector.java @@ -18,6 +18,8 @@ package org.apache.hadoop.hive.ql.exec.vector; +import java.util.Arrays; + import org.apache.hadoop.io.NullWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.Writable; @@ -219,4 +221,93 @@ public Writable getWritableObject(int index) { } return result; } + + /** Copy the current object contents into the output. Only copy selected entries, + * as indicated by selectedInUse and the sel array. + */ + public void copySelected( + boolean selectedInUse, int[] sel, int size, BytesColumnVector output) { + + // Output has nulls if and only if input has nulls. + output.noNulls = noNulls; + output.isRepeating = false; + + // Handle repeating case + if (isRepeating) { + output.setVal(0, vector[0], start[0], length[0]); + output.isNull[0] = isNull[0]; + output.isRepeating = true; + return; + } + + // Handle normal case + + // Copy data values over + if (selectedInUse) { + for (int j = 0; j < size; j++) { + int i = sel[j]; + output.setVal(i, vector[i], start[i], length[i]); + } + } + else { + for (int i = 0; i < size; i++) { + output.setVal(i, vector[i], start[i], length[i]); + } + } + + // Copy nulls over if needed + if (!noNulls) { + if (selectedInUse) { + for (int j = 0; j < size; j++) { + int i = sel[j]; + output.isNull[i] = isNull[i]; + } + } + else { + System.arraycopy(isNull, 0, output.isNull, 0, size); + } + } + } + + /** Simplify vector by brute-force flattening noNulls and isRepeating + * This can be used to reduce combinatorial explosion of code paths in VectorExpressions + * with many arguments, at the expense of loss of some performance. + */ + public void flatten(boolean selectedInUse, int[] sel, int size) { + flattenPush(); + if (isRepeating) { + isRepeating = false; + + // setRef is used below and this is safe, because the reference + // is to data owned by this column vector. 
If this column vector + // gets re-used, the whole thing is re-used together so there + // is no danger of a dangling reference. + + // Only copy data values if entry is not null. The string value + // at position 0 is undefined if the position 0 value is null. + if (noNulls || (!noNulls && !isNull[0])) { + + // loops start at position 1 because position 0 is already set + if (selectedInUse) { + for (int j = 1; j < size; j++) { + int i = sel[j]; + this.setRef(i, vector[0], start[0], length[0]); + } + } else { + for (int i = 1; i < size; i++) { + this.setRef(i, vector[0], start[0], length[0]); + } + } + } + flattenRepeatingNulls(selectedInUse, sel, size); + } + flattenNoNulls(selectedInUse, sel, size); + } + + // Fill the all the vector entries with provided value + public void fill(byte[] value) { + noNulls = true; + isRepeating = true; + setRef(0, value, 0, value.length); + } } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/ColumnVector.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/ColumnVector.java index 48b87ea..9d8b2de 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/ColumnVector.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/ColumnVector.java @@ -49,6 +49,11 @@ * If so, vector[0] holds the repeating value. */ public boolean isRepeating; + + // Variables to hold state from before flattening so it can be easily restored. + private boolean preFlattenIsRepeating; + private boolean preFlattenNoNulls; + public abstract Writable getWritableObject(int index); /** @@ -76,5 +81,66 @@ public void reset() { noNulls = true; isRepeating = false; } + + abstract public void flatten(boolean selectedInUse, int[] sel, int size); + + // Simplify vector by brute-force flattening noNulls if isRepeating + // This can be used to reduce combinatorial explosion of code paths in VectorExpressions + // with many arguments. + public void flattenRepeatingNulls(boolean selectedInUse, int[] sel, int size) { + + boolean nullFillValue; + + if (noNulls) { + nullFillValue = false; + } else { + nullFillValue = isNull[0]; + } + + if (selectedInUse) { + for (int j = 0; j < size; j++) { + int i = sel[j]; + isNull[i] = nullFillValue; + } + } else { + Arrays.fill(isNull, 0, size, nullFillValue); + } + + // all nulls are now explicit + noNulls = false; + } + + public void flattenNoNulls(boolean selectedInUse, int[] sel, int size) { + if (noNulls) { + noNulls = false; + if (selectedInUse) { + for (int j = 0; j < size; j++) { + int i = sel[j]; + isNull[i] = false; + } + } else { + Arrays.fill(isNull, 0, size, false); + } + } + } + + /** + * Restore the state of isRepeating and noNulls to what it was + * before flattening. This must only be called just after flattening + * and then evaluating a VectorExpression on the column vector. + * It is an optimization that allows other operations on the same + * column to continue to benefit from the isRepeating and noNulls + * indicators. + */ + public void unFlatten() { + isRepeating = preFlattenIsRepeating; + noNulls = preFlattenNoNulls; + } + + // Record repeating and no nulls state to be restored later. 
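flatten() and unFlatten() are intended as a strict bracket around a single expression evaluation: flatten materializes the repeating value and explicit null flags, and unFlatten restores the compact encoding afterwards. A small sketch of the contract on LongColumnVector (illustration only):

import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;

public class FlattenBracketSketch {
  public static void main(String[] args) {
    LongColumnVector v = new LongColumnVector();
    v.isRepeating = true;
    v.vector[0] = 7;

    v.flatten(false, null, 4); // expand: vector[0..3] == 7, nulls made explicit
    System.out.println(v.isRepeating + " " + v.vector[3]); // false 7

    v.unFlatten(); // restore the cheap repeating representation
    System.out.println(v.isRepeating + " " + v.noNulls);   // true true
  }
}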
+ protected void flattenPush() { + preFlattenIsRepeating = isRepeating; + preFlattenNoNulls = noNulls; + } } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/DoubleColumnVector.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/DoubleColumnVector.java index d3bb28e..cb23129 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/DoubleColumnVector.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/DoubleColumnVector.java @@ -17,6 +17,8 @@ */ package org.apache.hadoop.hive.ql.exec.vector; +import java.util.Arrays; + import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.io.NullWritable; import org.apache.hadoop.io.Writable; @@ -67,4 +69,76 @@ public Writable getWritableObject(int index) { return writableObj; } } + + // Copy the current object contents into the output. Only copy selected entries, + // as indicated by selectedInUse and the sel array. + public void copySelected( + boolean selectedInUse, int[] sel, int size, DoubleColumnVector output) { + + // Output has nulls if and only if input has nulls. + output.noNulls = noNulls; + output.isRepeating = false; + + // Handle repeating case + if (isRepeating) { + output.vector[0] = vector[0]; + output.isNull[0] = isNull[0]; + output.isRepeating = true; + return; + } + + // Handle normal case + + // Copy data values over + if (selectedInUse) { + for (int j = 0; j < size; j++) { + int i = sel[j]; + output.vector[i] = vector[i]; + } + } + else { + System.arraycopy(vector, 0, output.vector, 0, size); + } + + // Copy nulls over if needed + if (!noNulls) { + if (selectedInUse) { + for (int j = 0; j < size; j++) { + int i = sel[j]; + output.isNull[i] = isNull[i]; + } + } + else { + System.arraycopy(isNull, 0, output.isNull, 0, size); + } + } + } + + // Fill the column vector with the provided value + public void fill(double value) { + noNulls = true; + isRepeating = true; + vector[0] = value; + } + + // Simplify vector by brute-force flattening noNulls and isRepeating + // This can be used to reduce combinatorial explosion of code paths in VectorExpressions + // with many arguments. + public void flatten(boolean selectedInUse, int[] sel, int size) { + flattenPush(); + if (isRepeating) { + isRepeating = false; + double repeatVal = vector[0]; + if (selectedInUse) { + for (int j = 0; j < size; j++) { + int i = sel[j]; + vector[i] = repeatVal; + } + } else { + Arrays.fill(vector, 0, size, repeatVal); + } + flattenRepeatingNulls(selectedInUse, sel, size); + } + flattenNoNulls(selectedInUse, sel, size); + } } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/LongColumnVector.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/LongColumnVector.java index f65e8fa..aa05b19 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/LongColumnVector.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/LongColumnVector.java @@ -17,6 +17,8 @@ */ package org.apache.hadoop.hive.ql.exec.vector; +import java.util.Arrays; + import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.NullWritable; import org.apache.hadoop.io.Writable; @@ -67,4 +69,120 @@ public Writable getWritableObject(int index) { return writableObj; } } + + // Copy the current object contents into the output. Only copy selected entries, + // as indicated by selectedInUse and the sel array. + public void copySelected( + boolean selectedInUse, int[] sel, int size, LongColumnVector output) { + + // Output has nulls if and only if input has nulls. 
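copySelected() is the other half of the repeating fast path: it clones the input's values and null flags into the output, preserving the repeating encoding when it can. An illustrative check of the dense and selected paths on DoubleColumnVector:

import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector;

public class CopySelectedSketch {
  public static void main(String[] args) {
    DoubleColumnVector in = new DoubleColumnVector();
    DoubleColumnVector out = new DoubleColumnVector();
    in.vector[0] = 1.5; in.vector[1] = 2.5; in.vector[2] = 3.5;

    in.copySelected(false, null, 3, out); // dense copy of rows 0..2
    System.out.println(out.vector[1]);    // 2.5

    int[] sel = {0, 2};
    in.copySelected(true, sel, 2, out);   // copy only the selected rows
    System.out.println(out.vector[2]);    // 3.5
  }
}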
+ output.noNulls = noNulls; + output.isRepeating = false; + + // Handle repeating case + if (isRepeating) { + output.vector[0] = vector[0]; + output.isNull[0] = isNull[0]; + output.isRepeating = true; + return; + } + + // Handle normal case + + // Copy data values over + if (selectedInUse) { + for (int j = 0; j < size; j++) { + int i = sel[j]; + output.vector[i] = vector[i]; + } + } + else { + System.arraycopy(vector, 0, output.vector, 0, size); + } + + // Copy nulls over if needed + if (!noNulls) { + if (selectedInUse) { + for (int j = 0; j < size; j++) { + int i = sel[j]; + output.isNull[i] = isNull[i]; + } + } + else { + System.arraycopy(isNull, 0, output.isNull, 0, size); + } + } + } + + // Copy the current object contents into the output. Only copy selected entries, + // as indicated by selectedInUse and the sel array. + public void copySelected( + boolean selectedInUse, int[] sel, int size, DoubleColumnVector output) { + + // Output has nulls if and only if input has nulls. + output.noNulls = noNulls; + output.isRepeating = false; + + // Handle repeating case + if (isRepeating) { + output.vector[0] = vector[0]; // automatic conversion to double is done here + output.isNull[0] = isNull[0]; + output.isRepeating = true; + return; + } + + // Handle normal case + + // Copy data values over + if (selectedInUse) { + for (int j = 0; j < size; j++) { + int i = sel[j]; + output.vector[i] = vector[i]; + } + } + else { + System.arraycopy(vector, 0, output.vector, 0, size); + } + + // Copy nulls over if needed + if (!noNulls) { + if (selectedInUse) { + for (int j = 0; j < size; j++) { + int i = sel[j]; + output.isNull[i] = isNull[i]; + } + } + else { + System.arraycopy(isNull, 0, output.isNull, 0, size); + } + } + } + + // Fill the column vector with the provided value + public void fill(long value) { + noNulls = true; + isRepeating = true; + vector[0] = value; + } + + // Simplify vector by brute-force flattening noNulls and isRepeating + // This can be used to reduce combinatorial explosion of code paths in VectorExpressions + // with many arguments. + public void flatten(boolean selectedInUse, int[] sel, int size) { + flattenPush(); + if (isRepeating) { + isRepeating = false; + long repeatVal = vector[0]; + if (selectedInUse) { + for (int j = 0; j < size; j++) { + int i = sel[j]; + vector[i] = repeatVal; + } + } else { + Arrays.fill(vector, 0, size, repeatVal); + } + flattenRepeatingNulls(selectedInUse, sel, size); + } + flattenNoNulls(selectedInUse, sel, size); + } } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/IfExprStringColumnStringColumn.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/IfExprStringColumnStringColumn.java new file mode 100644 index 0000000..c321ad0 --- /dev/null +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/IfExprStringColumnStringColumn.java @@ -0,0 +1,205 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
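The string variants below copy per-row column values by value with setVal(), into a buffer prepared by initBuffer(), but merely reference constant scalars with setRef(). The distinction matters because row data may be overwritten when the source batch is reused, while a scalar never changes; sketched:

import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;

public class BytesValVsRefSketch {
  public static void main(String[] args) {
    BytesColumnVector out = new BytesColumnVector();
    out.initBuffer();

    byte[] rowData = {'a', 'b', 'c'};
    out.setVal(0, rowData, 0, 3); // copied into the vector's own buffer
    rowData[0] = 'z';             // later mutation cannot corrupt row 0

    byte[] scalar = {'x', 'y'};
    out.setRef(1, scalar, 0, 2);  // referenced: safe, the constant is never mutated

    System.out.println(new String(out.vector[0], out.start[0], out.length[0])); // abc
    System.out.println(new String(out.vector[1], out.start[1], out.length[1])); // xy
  }
}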
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec.vector.expressions; + +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; +import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor; + +/** + * Compute IF(expr1, expr2, expr3) for 3 input column expressions. + * The first is always a boolean (LongColumnVector). + * The second and third are string columns or string expression results. + */ +public class IfExprStringColumnStringColumn extends VectorExpression { + + private static final long serialVersionUID = 1L; + + private int arg1Column, arg2Column, arg3Column; + private int outputColumn; + + public IfExprStringColumnStringColumn(int arg1Column, int arg2Column, int arg3Column, int outputColumn) { + this.arg1Column = arg1Column; + this.arg2Column = arg2Column; + this.arg3Column = arg3Column; + this.outputColumn = outputColumn; + } + + public IfExprStringColumnStringColumn() { + } + + @Override + public void evaluate(VectorizedRowBatch batch) { + + if (childExpressions != null) { + super.evaluateChildren(batch); + } + + LongColumnVector arg1ColVector = (LongColumnVector) batch.cols[arg1Column]; + BytesColumnVector arg2ColVector = (BytesColumnVector) batch.cols[arg2Column]; + BytesColumnVector arg3ColVector = (BytesColumnVector) batch.cols[arg3Column]; + BytesColumnVector outputColVector = (BytesColumnVector) batch.cols[outputColumn]; + int[] sel = batch.selected; + boolean[] outputIsNull = outputColVector.isNull; + outputColVector.noNulls = arg2ColVector.noNulls && arg3ColVector.noNulls; + outputColVector.isRepeating = false; // may override later + int n = batch.size; + long[] vector1 = arg1ColVector.vector; + + // return immediately if batch is empty + if (n == 0) { + return; + } + + outputColVector.initBuffer(); + + /* All the code paths below propagate nulls even if neither arg2 nor arg3 + * have nulls. This is to reduce the number of code paths and shorten the + * code, at the expense of maybe doing unnecessary work if neither input + * has nulls. This could be improved in the future by expanding the number + * of code paths. + */ + if (arg1ColVector.isRepeating) { + if (vector1[0] == 1) { + arg2ColVector.copySelected(batch.selectedInUse, sel, n, outputColVector); + } else { + arg3ColVector.copySelected(batch.selectedInUse, sel, n, outputColVector); + } + return; + } + + // extend any repeating values and noNulls indicator in the inputs + arg2ColVector.flatten(batch.selectedInUse, sel, n); + arg3ColVector.flatten(batch.selectedInUse, sel, n); + + if (arg1ColVector.noNulls) { + if (batch.selectedInUse) { + for(int j = 0; j != n; j++) { + int i = sel[j]; + if (vector1[i] == 1) { + outputColVector.setVal( + i, arg2ColVector.vector[i], arg2ColVector.start[i], arg2ColVector.length[i]); + } else { + outputColVector.setVal( + i, arg3ColVector.vector[i], arg3ColVector.start[i], arg3ColVector.length[i]); + } + outputIsNull[i] = (vector1[i] == 1 ? 
+ arg2ColVector.isNull[i] : arg3ColVector.isNull[i]); + } + } else { + for(int i = 0; i != n; i++) { + if (vector1[i] == 1) { + outputColVector.setVal( + i, arg2ColVector.vector[i], arg2ColVector.start[i], arg2ColVector.length[i]); + } else { + outputColVector.setVal( + i, arg3ColVector.vector[i], arg3ColVector.start[i], arg3ColVector.length[i]); + } + outputIsNull[i] = (vector1[i] == 1 ? + arg2ColVector.isNull[i] : arg3ColVector.isNull[i]); + } + } + } else /* there are nulls */ { + if (batch.selectedInUse) { + for(int j = 0; j != n; j++) { + int i = sel[j]; + if (!arg1ColVector.isNull[i] && vector1[i] == 1) { + outputColVector.setVal( + i, arg2ColVector.vector[i], arg2ColVector.start[i], arg2ColVector.length[i]); + } else { + outputColVector.setVal( + i, arg3ColVector.vector[i], arg3ColVector.start[i], arg3ColVector.length[i]); + } + outputIsNull[i] = (!arg1ColVector.isNull[i] && vector1[i] == 1 ? + arg2ColVector.isNull[i] : arg3ColVector.isNull[i]); + } + } else { + for(int i = 0; i != n; i++) { + if (!arg1ColVector.isNull[i] && vector1[i] == 1) { + outputColVector.setVal( + i, arg2ColVector.vector[i], arg2ColVector.start[i], arg2ColVector.length[i]); + } else { + outputColVector.setVal( + i, arg3ColVector.vector[i], arg3ColVector.start[i], arg3ColVector.length[i]); + } + outputIsNull[i] = (!arg1ColVector.isNull[i] && vector1[i] == 1 ? + arg2ColVector.isNull[i] : arg3ColVector.isNull[i]); + } + } + } + arg2ColVector.unFlatten(); + arg3ColVector.unFlatten(); + } + + @Override + public int getOutputColumn() { + return outputColumn; + } + + @Override + public String getOutputType() { + return "String"; + } + + public int getArg1Column() { + return arg1Column; + } + + public void setArg1Column(int colNum) { + this.arg1Column = colNum; + } + + public int getArg2Column() { + return arg2Column; + } + + public void setArg2Column(int colNum) { + this.arg2Column = colNum; + } + + public int getArg3Column() { + return arg3Column; + } + + public void setArg3Column(int colNum) { + this.arg3Column = colNum; + } + + public void setOutputColumn(int outputColumn) { + this.outputColumn = outputColumn; + } + + @Override + public VectorExpressionDescriptor.Descriptor getDescriptor() { + return (new VectorExpressionDescriptor.Builder()) + .setMode( + VectorExpressionDescriptor.Mode.PROJECTION) + .setNumArguments(3) + .setArgumentTypes( + VectorExpressionDescriptor.ArgumentType.getType("long"), + VectorExpressionDescriptor.ArgumentType.getType("string"), + VectorExpressionDescriptor.ArgumentType.getType("string")) + .setInputExpressionTypes( + VectorExpressionDescriptor.InputExpressionType.COLUMN, + VectorExpressionDescriptor.InputExpressionType.COLUMN, + VectorExpressionDescriptor.InputExpressionType.COLUMN).build(); + } +} diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/IfExprStringColumnStringScalar.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/IfExprStringColumnStringScalar.java new file mode 100644 index 0000000..627319a --- /dev/null +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/IfExprStringColumnStringScalar.java @@ -0,0 +1,200 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
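A minimal end-to-end check of the class defined below, exercising the NULL-condition path (illustrative; assumes the ql exec classes are on the classpath):

import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
import org.apache.hadoop.hive.ql.exec.vector.expressions.IfExprStringColumnStringScalar;

public class IfExprStringSmoke {
  public static void main(String[] args) {
    VectorizedRowBatch batch = new VectorizedRowBatch(3);
    batch.cols[0] = new LongColumnVector();
    batch.cols[1] = new BytesColumnVector();
    batch.cols[2] = new BytesColumnVector();
    batch.size = 2;

    LongColumnVector cond = (LongColumnVector) batch.cols[0];
    BytesColumnVector thenCol = (BytesColumnVector) batch.cols[1];
    byte[] yes = "yes".getBytes();
    thenCol.setRef(0, yes, 0, yes.length);
    thenCol.setRef(1, yes, 0, yes.length);

    cond.vector[0] = 1;    // true: take the column value
    cond.noNulls = false;
    cond.isNull[1] = true; // NULL condition: take the ELSE scalar

    IfExprStringColumnStringScalar expr =
        new IfExprStringColumnStringScalar(0, 1, "no".getBytes(), 2);
    expr.evaluate(batch);

    BytesColumnVector out = (BytesColumnVector) batch.cols[2];
    System.out.println(new String(out.vector[0], out.start[0], out.length[0])); // yes
    System.out.println(new String(out.vector[1], out.start[1], out.length[1])); // no
  }
}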
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec.vector.expressions; + +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; +import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor; + +/** + * Compute IF(expr1, expr2, expr3) for 3 input expressions. + * The first is always a boolean (LongColumnVector). + * The second is a string column expression. + * The third is a string scalar. + */ +public class IfExprStringColumnStringScalar extends VectorExpression { + + private static final long serialVersionUID = 1L; + + private int arg1Column, arg2Column; + private byte[] arg3Scalar; + private int outputColumn; + + public IfExprStringColumnStringScalar(int arg1Column, int arg2Column, byte[] arg3Scalar, int outputColumn) { + this.arg1Column = arg1Column; + this.arg2Column = arg2Column; + this.arg3Scalar = arg3Scalar; + this.outputColumn = outputColumn; + } + + public IfExprStringColumnStringScalar() { + } + + @Override + public void evaluate(VectorizedRowBatch batch) { + + if (childExpressions != null) { + super.evaluateChildren(batch); + } + + LongColumnVector arg1ColVector = (LongColumnVector) batch.cols[arg1Column]; + BytesColumnVector arg2ColVector = (BytesColumnVector) batch.cols[arg2Column]; + BytesColumnVector outputColVector = (BytesColumnVector) batch.cols[outputColumn]; + int[] sel = batch.selected; + boolean[] outputIsNull = outputColVector.isNull; + outputColVector.noNulls = arg2ColVector.noNulls; + outputColVector.isRepeating = false; // may override later + int n = batch.size; + long[] vector1 = arg1ColVector.vector; + + // return immediately if batch is empty + if (n == 0) { + return; + } + + outputColVector.initBuffer(); + + /* All the code paths below propagate nulls even if arg2 has no nulls. + * This is to reduce the number of code paths and shorten the + * code, at the expense of maybe doing unnecessary work if neither input + * has nulls. This could be improved in the future by expanding the number + * of code paths. + */ + if (arg1ColVector.isRepeating) { + if (vector1[0] == 1) { + arg2ColVector.copySelected(batch.selectedInUse, sel, n, outputColVector); + } else { + outputColVector.fill(arg3Scalar); + } + return; + } + + // extend any repeating values and noNulls indicator in the inputs + arg2ColVector.flatten(batch.selectedInUse, sel, n); + + if (arg1ColVector.noNulls) { + if (batch.selectedInUse) { + for(int j = 0; j != n; j++) { + int i = sel[j]; + if (vector1[i] == 1) { + outputColVector.setVal( + i, arg2ColVector.vector[i], arg2ColVector.start[i], arg2ColVector.length[i]); + } else { + outputColVector.setRef(i, arg3Scalar, 0, arg3Scalar.length); + } + outputIsNull[i] = (vector1[i] == 1 ? 
arg2ColVector.isNull[i] : false); + } + } else { + for(int i = 0; i != n; i++) { + if (vector1[i] == 1) { + outputColVector.setVal( + i, arg2ColVector.vector[i], arg2ColVector.start[i], arg2ColVector.length[i]); + } else { + outputColVector.setRef(i, arg3Scalar, 0, arg3Scalar.length); + } + outputIsNull[i] = (vector1[i] == 1 ? arg2ColVector.isNull[i] : false); + } + } + } else /* there are nulls */ { + if (batch.selectedInUse) { + for(int j = 0; j != n; j++) { + int i = sel[j]; + if (!arg1ColVector.isNull[i] && vector1[i] == 1) { + outputColVector.setVal( + i, arg2ColVector.vector[i], arg2ColVector.start[i], arg2ColVector.length[i]); + } else { + outputColVector.setRef(i, arg3Scalar, 0, arg3Scalar.length); + } + outputIsNull[i] = (!arg1ColVector.isNull[i] && vector1[i] == 1 ? + arg2ColVector.isNull[i] : false); + } + } else { + for(int i = 0; i != n; i++) { + if (!arg1ColVector.isNull[i] && vector1[i] == 1) { + outputColVector.setVal( + i, arg2ColVector.vector[i], arg2ColVector.start[i], arg2ColVector.length[i]); + } else { + outputColVector.setRef(i, arg3Scalar, 0, arg3Scalar.length); + } + outputIsNull[i] = (!arg1ColVector.isNull[i] && vector1[i] == 1 ? + arg2ColVector.isNull[i] : false); + } + } + } + + // restore state of repeating and non nulls indicators + arg2ColVector.unFlatten(); + } + + @Override + public int getOutputColumn() { + return outputColumn; + } + + @Override + public String getOutputType() { + return "String"; + } + + public int getArg1Column() { + return arg1Column; + } + + public void setArg1Column(int colNum) { + this.arg1Column = colNum; + } + + public int getArg2Column() { + return arg2Column; + } + + public void setArg2Column(int colNum) { + this.arg2Column = colNum; + } + + public byte[] getArg3Scalar() { + return arg3Scalar; + } + + public void setArg3Scalar(byte[] value) { + this.arg3Scalar = value; + } + + public void setOutputColumn(int outputColumn) { + this.outputColumn = outputColumn; + } + + @Override + public VectorExpressionDescriptor.Descriptor getDescriptor() { + return (new VectorExpressionDescriptor.Builder()) + .setMode( + VectorExpressionDescriptor.Mode.PROJECTION) + .setNumArguments(3) + .setArgumentTypes( + VectorExpressionDescriptor.ArgumentType.getType("long"), + VectorExpressionDescriptor.ArgumentType.getType("string"), + VectorExpressionDescriptor.ArgumentType.getType("string")) + .setInputExpressionTypes( + VectorExpressionDescriptor.InputExpressionType.COLUMN, + VectorExpressionDescriptor.InputExpressionType.COLUMN, + VectorExpressionDescriptor.InputExpressionType.SCALAR).build(); + } +} diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/IfExprStringScalarStringColumn.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/IfExprStringScalarStringColumn.java new file mode 100644 index 0000000..37636e2 --- /dev/null +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/IfExprStringScalarStringColumn.java @@ -0,0 +1,200 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
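In the scalar branches above, every selected row is pointed at the same constant byte array through setRef(); that is safe because the scalar is never mutated, and it avoids a per-row copy. Illustration:

import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;

public class ScalarRefAliasingSketch {
  public static void main(String[] args) {
    BytesColumnVector out = new BytesColumnVector();
    byte[] scalar = "const".getBytes();
    out.setRef(0, scalar, 0, scalar.length);
    out.setRef(1, scalar, 0, scalar.length);
    // both rows alias the one array; no buffer space is consumed
    System.out.println(out.vector[0] == out.vector[1]); // true
  }
}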
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec.vector.expressions; + +import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; +import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor; + +/** + * Compute IF(expr1, expr2, expr3) for 3 input column expressions. + * The first is always a boolean (LongColumnVector). + * The second is a string scalar. + * The third is a string column or non-constant expression result. + */ +public class IfExprStringScalarStringColumn extends VectorExpression { + + private static final long serialVersionUID = 1L; + + private int arg1Column, arg3Column; + private byte[] arg2Scalar; + private int outputColumn; + + public IfExprStringScalarStringColumn(int arg1Column, byte[] arg2Scalar, int arg3Column, int outputColumn) { + this.arg1Column = arg1Column; + this.arg2Scalar = arg2Scalar; + this.arg3Column = arg3Column; + this.outputColumn = outputColumn; + } + + public IfExprStringScalarStringColumn() { + } + + @Override + public void evaluate(VectorizedRowBatch batch) { + + if (childExpressions != null) { + super.evaluateChildren(batch); + } + + LongColumnVector arg1ColVector = (LongColumnVector) batch.cols[arg1Column]; + BytesColumnVector arg3ColVector = (BytesColumnVector) batch.cols[arg3Column]; + BytesColumnVector outputColVector = (BytesColumnVector) batch.cols[outputColumn]; + int[] sel = batch.selected; + boolean[] outputIsNull = outputColVector.isNull; + outputColVector.noNulls = arg3ColVector.noNulls; + outputColVector.isRepeating = false; // may override later + int n = batch.size; + long[] vector1 = arg1ColVector.vector; + + // return immediately if batch is empty + if (n == 0) { + return; + } + + outputColVector.initBuffer(); + + /* All the code paths below propagate nulls even arg3 has no + * nulls. This is to reduce the number of code paths and shorten the + * code, at the expense of maybe doing unnecessary work if neither input + * has nulls. This could be improved in the future by expanding the number + * of code paths. + */ + if (arg1ColVector.isRepeating) { + if (vector1[0] == 1) { + outputColVector.fill(arg2Scalar); + } else { + arg3ColVector.copySelected(batch.selectedInUse, sel, n, outputColVector); + } + return; + } + + // extend any repeating values and noNulls indicator in the input + arg3ColVector.flatten(batch.selectedInUse, sel, n); + + if (arg1ColVector.noNulls) { + if (batch.selectedInUse) { + for(int j = 0; j != n; j++) { + int i = sel[j]; + if (vector1[i] == 1) { + outputColVector.setRef(i, arg2Scalar, 0, arg2Scalar.length); + } else { + outputColVector.setVal( + i, arg3ColVector.vector[i], arg3ColVector.start[i], arg3ColVector.length[i]); + } + outputIsNull[i] = (vector1[i] == 1 ? 
false : arg3ColVector.isNull[i]); + } + } else { + for(int i = 0; i != n; i++) { + if (vector1[i] == 1) { + outputColVector.setRef(i, arg2Scalar, 0, arg2Scalar.length); + } else { + outputColVector.setVal( + i, arg3ColVector.vector[i], arg3ColVector.start[i], arg3ColVector.length[i]); + } + outputIsNull[i] = (vector1[i] == 1 ? false : arg3ColVector.isNull[i]); + } + } + } else /* there are nulls */ { + if (batch.selectedInUse) { + for(int j = 0; j != n; j++) { + int i = sel[j]; + if (!arg1ColVector.isNull[i] && vector1[i] == 1) { + outputColVector.setRef(i, arg2Scalar, 0, arg2Scalar.length); + } else { + outputColVector.setVal( + i, arg3ColVector.vector[i], arg3ColVector.start[i], arg3ColVector.length[i]); + } + outputIsNull[i] = (!arg1ColVector.isNull[i] && vector1[i] == 1 ? + false : arg3ColVector.isNull[i]); + } + } else { + for(int i = 0; i != n; i++) { + if (!arg1ColVector.isNull[i] && vector1[i] == 1) { + outputColVector.setRef(i, arg2Scalar, 0, arg2Scalar.length); + } else { + outputColVector.setVal( + i, arg3ColVector.vector[i], arg3ColVector.start[i], arg3ColVector.length[i]); + } + outputIsNull[i] = (!arg1ColVector.isNull[i] && vector1[i] == 1 ? + false : arg3ColVector.isNull[i]); + } + } + } + + // restore state of repeating and non nulls indicators + arg3ColVector.unFlatten(); + } + + @Override + public int getOutputColumn() { + return outputColumn; + } + + @Override + public String getOutputType() { + return "String"; + } + + public int getArg1Column() { + return arg1Column; + } + + public void setArg1Column(int colNum) { + this.arg1Column = colNum; + } + + public byte[] getArg2Scalar() { + return arg2Scalar; + } + + public void setArg2Scalar(byte[] value) { + this.arg2Scalar = value; + } + + public int getArg3Column() { + return arg3Column; + } + + public void setArg3Column(int colNum) { + this.arg3Column = colNum; + } + + public void setOutputColumn(int outputColumn) { + this.outputColumn = outputColumn; + } + + @Override + public VectorExpressionDescriptor.Descriptor getDescriptor() { + return (new VectorExpressionDescriptor.Builder()) + .setMode( + VectorExpressionDescriptor.Mode.PROJECTION) + .setNumArguments(3) + .setArgumentTypes( + VectorExpressionDescriptor.ArgumentType.getType("long"), + VectorExpressionDescriptor.ArgumentType.getType("string"), + VectorExpressionDescriptor.ArgumentType.getType("string")) + .setInputExpressionTypes( + VectorExpressionDescriptor.InputExpressionType.COLUMN, + VectorExpressionDescriptor.InputExpressionType.SCALAR, + VectorExpressionDescriptor.InputExpressionType.COLUMN).build(); + } +} diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/IfExprStringScalarStringScalar.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/IfExprStringScalarStringScalar.java new file mode 100644 index 0000000..f6fcfea --- /dev/null +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/IfExprStringScalarStringScalar.java @@ -0,0 +1,178 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.vector.expressions;
+
+import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression;
+import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
+import org.apache.hadoop.hive.ql.exec.vector.VectorExpressionDescriptor;
+
+/**
+ * Compute IF(expr1, expr2, expr3) for 3 input expressions.
+ * The first is always a boolean (LongColumnVector).
+ * The second is a string scalar.
+ * The third is a string scalar.
+ */
+public class IfExprStringScalarStringScalar extends VectorExpression {
+
+  private static final long serialVersionUID = 1L;
+
+  private int arg1Column;
+  private byte[] arg2Scalar;
+  private byte[] arg3Scalar;
+  private int outputColumn;
+
+  public IfExprStringScalarStringScalar(
+      int arg1Column, byte[] arg2Scalar, byte[] arg3Scalar, int outputColumn) {
+    this.arg1Column = arg1Column;
+    this.arg2Scalar = arg2Scalar;
+    this.arg3Scalar = arg3Scalar;
+    this.outputColumn = outputColumn;
+  }
+
+  public IfExprStringScalarStringScalar() {
+  }
+
+  @Override
+  public void evaluate(VectorizedRowBatch batch) {
+
+    if (childExpressions != null) {
+      super.evaluateChildren(batch);
+    }
+
+    LongColumnVector arg1ColVector = (LongColumnVector) batch.cols[arg1Column];
+    BytesColumnVector outputColVector = (BytesColumnVector) batch.cols[outputColumn];
+    int[] sel = batch.selected;
+    outputColVector.noNulls = true; // output must be a scalar and neither one is null
+    outputColVector.isRepeating = false; // may override later
+    int n = batch.size;
+    long[] vector1 = arg1ColVector.vector;
+
+    // return immediately if batch is empty
+    if (n == 0) {
+      return;
+    }
+
+    outputColVector.initBuffer();
+
+    if (arg1ColVector.isRepeating) {
+      if (vector1[0] == 1) {
+        outputColVector.fill(arg2Scalar);
+      } else {
+        outputColVector.fill(arg3Scalar);
+      }
+      return;
+    }
+
+    if (arg1ColVector.noNulls) {
+      if (batch.selectedInUse) {
+        for(int j = 0; j != n; j++) {
+          int i = sel[j];
+          if (vector1[i] == 1) {
+            outputColVector.setRef(i, arg2Scalar, 0, arg2Scalar.length);
+          } else {
+            outputColVector.setRef(i, arg3Scalar, 0, arg3Scalar.length);
+          }
+        }
+      } else {
+        for(int i = 0; i != n; i++) {
+          if (vector1[i] == 1) {
+            outputColVector.setRef(i, arg2Scalar, 0, arg2Scalar.length);
+          } else {
+            outputColVector.setRef(i, arg3Scalar, 0, arg3Scalar.length);
+          }
+        }
+      }
+    } else /* there are nulls */ {
+      if (batch.selectedInUse) {
+        for(int j = 0; j != n; j++) {
+          int i = sel[j];
+          if (!arg1ColVector.isNull[i] && vector1[i] == 1) {
+            outputColVector.setRef(i, arg2Scalar, 0, arg2Scalar.length);
+          } else {
+            outputColVector.setRef(i, arg3Scalar, 0, arg3Scalar.length);
+          }
+        }
+      } else {
+        for(int i = 0; i != n; i++) {
+          if (!arg1ColVector.isNull[i] && vector1[i] == 1) {
+            outputColVector.setRef(i, arg2Scalar, 0, arg2Scalar.length);
+          } else {
+            outputColVector.setRef(i, arg3Scalar, 0, arg3Scalar.length);
+          }
+        }
+      }
+    }
+  }
+
+  @Override
+  public int getOutputColumn() {
+    return outputColumn;
+  }
+
+  @Override
+  public
+  public String getOutputType() {
+    return "String";
+  }
+
+  public int getArg1Column() {
+    return arg1Column;
+  }
+
+  public void setArg1Column(int colNum) {
+    this.arg1Column = colNum;
+  }
+
+  public byte[] getArg2Scalar() {
+    return arg2Scalar;
+  }
+
+  public void setArg2Scalar(byte[] value) {
+    this.arg2Scalar = value;
+  }
+
+  public byte[] getArg3Scalar() {
+    return arg3Scalar;
+  }
+
+  public void setArg3Scalar(byte[] value) {
+    this.arg3Scalar = value;
+  }
+
+  public void setOutputColumn(int outputColumn) {
+    this.outputColumn = outputColumn;
+  }
+
+  @Override
+  public VectorExpressionDescriptor.Descriptor getDescriptor() {
+    return (new VectorExpressionDescriptor.Builder())
+        .setMode(
+            VectorExpressionDescriptor.Mode.PROJECTION)
+        .setNumArguments(3)
+        .setArgumentTypes(
+            VectorExpressionDescriptor.ArgumentType.getType("long"),
+            VectorExpressionDescriptor.ArgumentType.getType("string"),
+            VectorExpressionDescriptor.ArgumentType.getType("string"))
+        .setInputExpressionTypes(
+            VectorExpressionDescriptor.InputExpressionType.COLUMN,
+            VectorExpressionDescriptor.InputExpressionType.SCALAR,
+            VectorExpressionDescriptor.InputExpressionType.SCALAR).build();
+  }
+}
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java
index 5c7617e..7392a9e 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java
@@ -126,6 +126,7 @@
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFCeil;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFConcat;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFFloor;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFIf;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFIn;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFLower;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPAnd;
@@ -261,6 +262,9 @@ public Vectorizer() {
     supportedGenericUDFs.add(UDFToString.class);
     supportedGenericUDFs.add(GenericUDFTimestamp.class);
 
+    // For conditional expressions
+    supportedGenericUDFs.add(GenericUDFIf.class);
+
     supportedAggregationUdfs.add("min");
     supportedAggregationUdfs.add("max");
     supportedAggregationUdfs.add("count");
@@ -347,17 +351,17 @@ private void vectorizeMRTask(MapRedTask mrTask) throws SemanticException {
       topNodes.addAll(mapWork.getAliasToWork().values());
       HashMap<Node, Object> nodeOutput = new HashMap<Node, Object>();
       ogw.startWalking(topNodes, nodeOutput);
-
+
       Map<String, Map<Integer, String>> columnVectorTypes = vnp.getScratchColumnVectorTypes();
       mapWork.setScratchColumnVectorTypes(columnVectorTypes);
       Map<String, Map<String, Integer>> columnMap = vnp.getScratchColumnMap();
       mapWork.setScratchColumnMap(columnMap);
-
+
       if (LOG.isDebugEnabled()) {
         LOG.debug(String.format("vectorTypes: %s", columnVectorTypes.toString()));
         LOG.debug(String.format("columnMap: %s", columnMap.toString()));
       }
-
+
       return;
     }
   }
@@ -426,9 +430,9 @@ public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
         Object... nodeOutputs) throws SemanticException {
 
       Operator<? extends OperatorDesc> op = (Operator<? extends OperatorDesc>) nd;
-
-      VectorizationContext vContext = null;
-
+
+      VectorizationContext vContext = null;
+
       if (op instanceof TableScanOperator) {
         vContext = getVectorizationContext(op, physicalContext);
         for (String onefile : mWork.getPathToAliases().keySet()) {
@@ -458,9 +462,9 @@ public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
           --i;
         }
       }
-
+
       assert vContext != null;
-
+
       if (op.getType().equals(OperatorType.REDUCESINK) &&
           op.getParentOperators().get(0).getType().equals(OperatorType.GROUPBY)) {
         // No need to vectorize
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFIf.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFIf.java
index 0c7e61c..adf55c8 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFIf.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFIf.java
@@ -21,11 +21,30 @@ import
org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedExpressions; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde.serdeConstants; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprLongColumnLongColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprDoubleColumnDoubleColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprLongColumnLongScalar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprDoubleColumnDoubleScalar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprDoubleColumnLongScalar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprLongColumnDoubleScalar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprLongScalarLongColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprDoubleScalarDoubleColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprDoubleScalarLongColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprLongScalarDoubleColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprLongScalarLongScalar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprDoubleScalarDoubleScalar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprDoubleScalarLongScalar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprLongScalarDoubleScalar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.IfExprStringColumnStringColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.IfExprStringColumnStringScalar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.IfExprStringScalarStringColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.IfExprStringScalarStringScalar; /** * IF(expr1,expr2,expr3)
@@ -33,6 +52,17 @@ * otherwise it returns expr3. IF() returns a numeric or string value, depending * on the context in which it is used. */ +@VectorizedExpressions({ + IfExprLongColumnLongColumn.class, IfExprDoubleColumnDoubleColumn.class, + IfExprLongColumnLongScalar.class, IfExprDoubleColumnDoubleScalar.class, + IfExprLongColumnDoubleScalar.class, IfExprDoubleColumnLongScalar.class, + IfExprLongScalarLongColumn.class, IfExprDoubleScalarDoubleColumn.class, + IfExprLongScalarDoubleColumn.class, IfExprDoubleScalarLongColumn.class, + IfExprLongScalarLongScalar.class, IfExprDoubleScalarDoubleScalar.class, + IfExprLongScalarDoubleScalar.class, IfExprDoubleScalarLongScalar.class, + IfExprStringColumnStringColumn.class, IfExprStringColumnStringScalar.class, + IfExprStringScalarStringColumn.class, IfExprStringScalarStringScalar.class +}) public class GenericUDFIf extends GenericUDF { private transient ObjectInspector[] argumentOIs; private transient GenericUDFUtils.ReturnObjectInspectorResolver returnOIResolver; @@ -94,5 +124,4 @@ public String getDisplayString(String[] children) { sb.append(children[2]).append(")"); return sb.toString(); } - } diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java index 720ca54..60e562b 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java +++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java @@ -37,6 +37,10 @@ import org.apache.hadoop.hive.ql.exec.vector.expressions.FuncLogWithBaseDoubleToDouble; import org.apache.hadoop.hive.ql.exec.vector.expressions.FuncLogWithBaseLongToDouble; import org.apache.hadoop.hive.ql.exec.vector.expressions.FuncPowerDoubleToDouble; +import org.apache.hadoop.hive.ql.exec.vector.expressions.IfExprStringColumnStringColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.IfExprStringColumnStringScalar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.IfExprStringScalarStringColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.IfExprStringScalarStringScalar; import org.apache.hadoop.hive.ql.exec.vector.expressions.IsNotNull; import org.apache.hadoop.hive.ql.exec.vector.expressions.IsNull; import org.apache.hadoop.hive.ql.exec.vector.expressions.LongColumnInList; @@ -56,6 +60,15 @@ import org.apache.hadoop.hive.ql.exec.vector.expressions.FilterStringColumnInList; import org.apache.hadoop.hive.ql.exec.vector.expressions.FilterLongColumnInList; import org.apache.hadoop.hive.ql.exec.vector.expressions.FilterDoubleColumnInList; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprLongColumnLongColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprLongColumnLongScalar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprLongScalarLongScalar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprLongScalarLongColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprDoubleColumnDoubleColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprDoubleColumnDoubleScalar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprDoubleScalarDoubleColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprDoubleScalarDoubleScalar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprDoubleScalarLongColumn; import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DoubleColUnaryMinus; 
import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FilterDoubleColLessDoubleScalar;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FilterDoubleColumnBetween;
@@ -93,6 +106,7 @@
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBetween;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBridge;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFIf;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFIn;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFLower;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPAnd;
@@ -111,6 +125,7 @@
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFRound;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPPlus;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToUnixTimeStamp;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFTimestamp;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
 import org.junit.Test;
@@ -1004,4 +1019,177 @@ public void testInFiltersAndExprs() throws HiveException {
     ve = vc.getVectorExpression(exprDesc, VectorExpressionDescriptor.Mode.PROJECTION);
     assertTrue(ve instanceof DoubleColumnInList);
   }
+
+  /**
+   * Test that the correct VectorExpression classes are chosen for the
+   * IF (expr1, expr2, expr3) conditional expression for integer, float,
+   * boolean, timestamp and string input types. expr1 is always an input column
+   * expression of type long. expr2 and expr3 can be column expressions or
+   * constants of the other listed types, but both must have the same type.
+   */
+  @Test
+  public void testIfConditionalExprs() throws HiveException {
+    ExprNodeColumnDesc col1Expr = new ExprNodeColumnDesc(Long.class, "col1", "table", false);
+    ExprNodeColumnDesc col2Expr = new ExprNodeColumnDesc(Long.class, "col2", "table", false);
+    ExprNodeColumnDesc col3Expr = new ExprNodeColumnDesc(Long.class, "col3", "table", false);
+
+    ExprNodeConstantDesc constDesc2 = new ExprNodeConstantDesc(new Integer(1));
+    ExprNodeConstantDesc constDesc3 = new ExprNodeConstantDesc(new Integer(2));
+
+    // long column/column IF
+    GenericUDFIf udf = new GenericUDFIf();
+    ExprNodeGenericFuncDesc exprDesc = new ExprNodeGenericFuncDesc();
+    exprDesc.setGenericUDF(udf);
+    List<ExprNodeDesc> children1 = new ArrayList<ExprNodeDesc>();
+    children1.add(col1Expr);
+    children1.add(col2Expr);
+    children1.add(col3Expr);
+    exprDesc.setChildren(children1);
+
+    Map<String, Integer> columnMap = new HashMap<String, Integer>();
+    columnMap.put("col1", 1);
+    columnMap.put("col2", 2);
+    columnMap.put("col3", 3);
+    VectorizationContext vc = new VectorizationContext(columnMap, 3);
+    VectorExpression ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprLongColumnLongColumn);
+
+    // long column/scalar IF
+    children1.set(2, new ExprNodeConstantDesc(1L));
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprLongColumnLongScalar);
+
+    // long scalar/scalar IF
+    children1.set(1, new ExprNodeConstantDesc(1L));
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprLongScalarLongScalar);
+
+    // long scalar/column IF
+    children1.set(2, col3Expr);
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprLongScalarLongColumn);
+
+    // test for double type
+    col2Expr = new ExprNodeColumnDesc(Double.class, "col2", "table", false);
+    col3Expr = new ExprNodeColumnDesc(Double.class, "col3", "table", false);
+
+    // double column/column IF
+    children1.set(1, col2Expr);
+    children1.set(2, col3Expr);
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprDoubleColumnDoubleColumn);
+
+    // double column/scalar IF
+    children1.set(2, new ExprNodeConstantDesc(1D));
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprDoubleColumnDoubleScalar);
+
+    // double scalar/scalar IF
+    children1.set(1, new ExprNodeConstantDesc(1D));
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprDoubleScalarDoubleScalar);
+
+    // double scalar/column IF
+    children1.set(2, col3Expr);
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprDoubleScalarDoubleColumn);
+
+    // double scalar/long column IF
+    children1.set(2, new ExprNodeColumnDesc(Long.class, "col3", "table", false));
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprDoubleScalarLongColumn);
+
+    // Additional combinations of (long, double) X (column, scalar) for the second
+    // and third arguments are omitted; the cases above already cover all the
+    // source templates.
+
+    // test for timestamp type
+    col2Expr = new ExprNodeColumnDesc(Timestamp.class, "col2", "table", false);
+    col3Expr = new ExprNodeColumnDesc(Timestamp.class, "col3", "table", false);
+
+    // timestamp column/column IF
+    children1.set(1, col2Expr);
+    children1.set(2, col3Expr);
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprLongColumnLongColumn);
+
+    // timestamp column/scalar IF, where the scalar is really a CAST of a constant to timestamp
+    ExprNodeGenericFuncDesc f = new ExprNodeGenericFuncDesc();
+    f.setGenericUDF(new GenericUDFTimestamp());
+    f.setTypeInfo(TypeInfoFactory.timestampTypeInfo);
+    List<ExprNodeDesc> children2 = new ArrayList<ExprNodeDesc>();
+    f.setChildren(children2);
+    children2.add(new ExprNodeConstantDesc("2013-11-05 00:00:00.000"));
+    children1.set(2, f);
+    ve = vc.getVectorExpression(exprDesc);
+
+    // We check for two different classes below because today the result is
+    // IfExprLongColumnLongColumn, but if the system is later enhanced with
+    // constant folding, the result will be IfExprLongColumnLongScalar.
+    assertTrue(IfExprLongColumnLongColumn.class == ve.getClass()
+        || IfExprLongColumnLongScalar.class == ve.getClass());
+
+    // timestamp scalar/scalar
+    children1.set(1, f);
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(IfExprLongColumnLongColumn.class == ve.getClass()
+        || IfExprLongScalarLongScalar.class == ve.getClass());
+
+    // timestamp scalar/column
+    children1.set(2, col3Expr);
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(IfExprLongColumnLongColumn.class == ve.getClass()
+        || IfExprLongScalarLongColumn.class == ve.getClass());
+
+    // test for boolean type
+    col2Expr = new ExprNodeColumnDesc(Boolean.class, "col2", "table", false);
+    col3Expr = new ExprNodeColumnDesc(Boolean.class, "col3", "table", false);
+
+    // column/column
+    children1.set(1, col2Expr);
+    children1.set(2, col3Expr);
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprLongColumnLongColumn);
+
+    // column/scalar IF
+    children1.set(2, new ExprNodeConstantDesc(true));
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprLongColumnLongScalar);
+
+    // scalar/scalar IF
+    children1.set(1, new ExprNodeConstantDesc(true));
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprLongScalarLongScalar);
+
+    // scalar/column IF
+    children1.set(2, col3Expr);
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprLongScalarLongColumn);
+
+    // test for string type
+    constDesc2 = new ExprNodeConstantDesc("Alpha");
+    constDesc3 = new ExprNodeConstantDesc("Bravo");
+    col2Expr = new ExprNodeColumnDesc(String.class, "col2", "table", false);
+    col3Expr = new ExprNodeColumnDesc(String.class, "col3", "table", false);
+
+    // column/column
+    children1.set(1, col2Expr);
+    children1.set(2, col3Expr);
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprStringColumnStringColumn);
+
+    // column/scalar
+    children1.set(2, constDesc3);
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprStringColumnStringScalar);
+
+    // scalar/scalar
+    children1.set(1, constDesc2);
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprStringScalarStringScalar);
+
+    // scalar/column
+    children1.set(2, col3Expr);
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprStringScalarStringColumn);
+  }
 }
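The evaluate loops in the IfExprString* classes above all encode one subtle rule: a NULL condition selects the ELSE branch, and the result is null only when the value supplied by the chosen branch is null. A minimal row-level restatement of that rule in plain Java, outside the patch (the class and method names here are illustrative, not part of the patch):

// Illustrative only: the per-row semantics the vectorized IfExpr* classes
// implement over whole column vectors.
public class IfNullRuleSketch {

  // Mirrors (!arg1ColVector.isNull[i] && vector1[i] == 1): a NULL condition
  // behaves like false and falls through to the else branch.
  static String ifExpr(Boolean cond, String thenVal, String elseVal) {
    boolean takeThen = cond != null && cond;
    return takeThen ? thenVal : elseVal; // null only if the chosen branch is null
  }

  public static void main(String[] args) {
    System.out.println(ifExpr(null, "a", "b"));   // b    (NULL condition selects else)
    System.out.println(ifExpr(true, "a", null));  // null (chosen branch is null)
    System.out.println(ifExpr(false, null, "b")); // b    (unused null branch is ignored)
  }
}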
lessExprDesc.setChildren(children2); + + GenericUDFOPAnd andUdf = new GenericUDFOPAnd(); + ExprNodeGenericFuncDesc andExprDesc = new ExprNodeGenericFuncDesc(); + andExprDesc.setTypeInfo(TypeInfoFactory.booleanTypeInfo); + andExprDesc.setGenericUDF(andUdf); + List children3 = new ArrayList(2); + children3.add(greaterExprDesc); + children3.add(lessExprDesc); + andExprDesc.setChildren(children3); + + Map columnMap = new HashMap(); + columnMap.put("col1", 0); + columnMap.put("col2", 1); + + VectorizationContext vc = new VectorizationContext(columnMap, 2); + + VectorExpression ve = vc.getVectorExpression(andExprDesc, VectorExpressionDescriptor.Mode.FILTER); + + assertEquals(ve.getClass(), FilterExprAndExpr.class); + assertEquals(ve.getChildExpressions()[0].getClass(), FilterLongColGreaterLongScalar.class); + assertEquals(ve.getChildExpressions()[1].getClass(), FilterDoubleColLessDoubleScalar.class); + + GenericUDFOPOr orUdf = new GenericUDFOPOr(); + ExprNodeGenericFuncDesc orExprDesc = new ExprNodeGenericFuncDesc(); + orExprDesc.setTypeInfo(TypeInfoFactory.booleanTypeInfo); + orExprDesc.setGenericUDF(orUdf); + List children4 = new ArrayList(2); + children4.add(greaterExprDesc); + children4.add(lessExprDesc); + orExprDesc.setChildren(children4); + VectorExpression veOr = vc.getVectorExpression(orExprDesc, VectorExpressionDescriptor.Mode.FILTER); + assertEquals(veOr.getClass(), FilterExprOrExpr.class); + assertEquals(veOr.getChildExpressions()[0].getClass(), FilterLongColGreaterLongScalar.class); + assertEquals(veOr.getChildExpressions()[1].getClass(), FilterDoubleColLessDoubleScalar.class); + } + + @Test + public void testVectorizeAndOrProjectionExpression() throws HiveException { + ExprNodeColumnDesc col1Expr = new ExprNodeColumnDesc(Integer.class, "col1", "table", false); + ExprNodeConstantDesc constDesc = new ExprNodeConstantDesc(new Integer(10)); + + GenericUDFOPGreaterThan udf = new GenericUDFOPGreaterThan(); + ExprNodeGenericFuncDesc greaterExprDesc = new ExprNodeGenericFuncDesc(); + greaterExprDesc.setTypeInfo(TypeInfoFactory.booleanTypeInfo); + greaterExprDesc.setGenericUDF(udf); + List children1 = new ArrayList(2); + children1.add(col1Expr); + children1.add(constDesc); + greaterExprDesc.setChildren(children1); + + ExprNodeColumnDesc col2Expr = new ExprNodeColumnDesc(Boolean.class, "col2", "table", false); + + GenericUDFOPAnd andUdf = new GenericUDFOPAnd(); + ExprNodeGenericFuncDesc andExprDesc = new ExprNodeGenericFuncDesc(); + andExprDesc.setTypeInfo(TypeInfoFactory.booleanTypeInfo); + andExprDesc.setGenericUDF(andUdf); + List children3 = new ArrayList(2); + children3.add(greaterExprDesc); + children3.add(col2Expr); + andExprDesc.setChildren(children3); + + Map columnMap = new HashMap(); + columnMap.put("col1", 0); + columnMap.put("col2", 1); + + VectorizationContext vc = new VectorizationContext(columnMap, 2); + VectorExpression veAnd = vc.getVectorExpression(andExprDesc, VectorExpressionDescriptor.Mode.FILTER); + assertEquals(veAnd.getClass(), FilterExprAndExpr.class); + assertEquals(veAnd.getChildExpressions()[0].getClass(), FilterLongColGreaterLongScalar.class); + assertEquals(veAnd.getChildExpressions()[1].getClass(), SelectColumnIsTrue.class); + + veAnd = vc.getVectorExpression(andExprDesc, VectorExpressionDescriptor.Mode.PROJECTION); + assertEquals(veAnd.getClass(), ColAndCol.class); + assertEquals(1, veAnd.getChildExpressions().length); + assertEquals(veAnd.getChildExpressions()[0].getClass(), LongColGreaterLongScalar.class); + assertEquals(2, ((ColAndCol) veAnd).getColNum1()); 
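+    // Added explanatory note (derived from the asserts in this test): the column
+    // numbers follow from scratch-column allocation. With inputs col1 -> 0 and
+    // col2 -> 1, the boolean result of col1 > 10 lands in the first scratch
+    // column (2), col2 itself supplies the other operand (1), and the AND
+    // output takes the next scratch column (3).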
+    assertEquals(1, ((ColAndCol) veAnd).getColNum2());
+    assertEquals(3, ((ColAndCol) veAnd).getOutputColumn());
+
+    //OR
+    GenericUDFOPOr orUdf = new GenericUDFOPOr();
+    ExprNodeGenericFuncDesc orExprDesc = new ExprNodeGenericFuncDesc();
+    orExprDesc.setTypeInfo(TypeInfoFactory.booleanTypeInfo);
+    orExprDesc.setGenericUDF(orUdf);
+    List children4 = new ArrayList(2);
+    children4.add(greaterExprDesc);
+    children4.add(col2Expr);
+    orExprDesc.setChildren(children4);
+
+    //Allocate new Vectorization context to reset the intermediate columns.
+    vc = new VectorizationContext(columnMap, 2);
+    VectorExpression veOr = vc.getVectorExpression(orExprDesc, VectorExpressionDescriptor.Mode.FILTER);
+    assertEquals(veOr.getClass(), FilterExprOrExpr.class);
+    assertEquals(veOr.getChildExpressions()[0].getClass(), FilterLongColGreaterLongScalar.class);
+    assertEquals(veOr.getChildExpressions()[1].getClass(), SelectColumnIsTrue.class);
+
+    veOr = vc.getVectorExpression(orExprDesc, VectorExpressionDescriptor.Mode.PROJECTION);
+    assertEquals(veOr.getClass(), ColOrCol.class);
+    assertEquals(1, veOr.getChildExpressions().length);
+    assertEquals(veOr.getChildExpressions()[0].getClass(), LongColGreaterLongScalar.class);
+    assertEquals(2, ((ColOrCol) veOr).getColNum1());
+    assertEquals(1, ((ColOrCol) veOr).getColNum2());
+    assertEquals(3, ((ColOrCol) veOr).getOutputColumn());
+  }
+
+  @Test
+  public void testNotExpression() throws HiveException {
+    ExprNodeColumnDesc col1Expr = new ExprNodeColumnDesc(Integer.class, "col1", "table", false);
+    ExprNodeConstantDesc constDesc = new ExprNodeConstantDesc(new Integer(10));
+
+    GenericUDFOPGreaterThan udf = new GenericUDFOPGreaterThan();
+    ExprNodeGenericFuncDesc greaterExprDesc = new ExprNodeGenericFuncDesc();
+    greaterExprDesc.setTypeInfo(TypeInfoFactory.booleanTypeInfo);
+    greaterExprDesc.setGenericUDF(udf);
+    List children1 = new ArrayList(2);
+    children1.add(col1Expr);
+    children1.add(constDesc);
+    greaterExprDesc.setChildren(children1);
+
+    ExprNodeGenericFuncDesc notExpr = new ExprNodeGenericFuncDesc();
+    notExpr.setTypeInfo(TypeInfoFactory.booleanTypeInfo);
+    GenericUDFOPNot notUdf = new GenericUDFOPNot();
+    notExpr.setGenericUDF(notUdf);
+    List childOfNot = new ArrayList();
+    childOfNot.add(greaterExprDesc);
+    notExpr.setChildren(childOfNot);
+
+    Map columnMap = new HashMap();
+    columnMap.put("col1", 0);
+    columnMap.put("col2", 1);
+
+    VectorizationContext vc = new VectorizationContext(columnMap, 2);
+
+    VectorExpression ve = vc.getVectorExpression(notExpr, VectorExpressionDescriptor.Mode.FILTER);
+
+    assertEquals(ve.getClass(), SelectColumnIsFalse.class);
+    assertEquals(ve.getChildExpressions()[0].getClass(), LongColGreaterLongScalar.class);
+
+    ve = vc.getVectorExpression(notExpr, VectorExpressionDescriptor.Mode.PROJECTION);
+    assertEquals(ve.getClass(), NotCol.class);
+    assertEquals(ve.getChildExpressions()[0].getClass(), LongColGreaterLongScalar.class);
+  }
+
+  @Test
+  public void testNullExpressions() throws HiveException {
+    ExprNodeColumnDesc col1Expr = new ExprNodeColumnDesc(Integer.class, "col1", "table", false);
+    ExprNodeConstantDesc constDesc = new ExprNodeConstantDesc(new Integer(10));
+
+    GenericUDFOPGreaterThan udf = new GenericUDFOPGreaterThan();
+    ExprNodeGenericFuncDesc greaterExprDesc = new ExprNodeGenericFuncDesc();
+    greaterExprDesc.setTypeInfo(TypeInfoFactory.booleanTypeInfo);
+    greaterExprDesc.setGenericUDF(udf);
+    List children1 = new ArrayList(2);
+    children1.add(col1Expr);
+    children1.add(constDesc);
+
greaterExprDesc.setChildren(children1); + + ExprNodeGenericFuncDesc isNullExpr = new ExprNodeGenericFuncDesc(); + isNullExpr.setTypeInfo(TypeInfoFactory.booleanTypeInfo); + GenericUDFOPNull isNullUdf = new GenericUDFOPNull(); + isNullExpr.setGenericUDF(isNullUdf); + List childOfIsNull = new ArrayList(); + childOfIsNull.add(greaterExprDesc); + isNullExpr.setChildren(childOfIsNull); + + Map columnMap = new HashMap(); + columnMap.put("col1", 0); + columnMap.put("col2", 1); + + VectorizationContext vc = new VectorizationContext(columnMap, 2); + + VectorExpression ve = vc.getVectorExpression(isNullExpr, VectorExpressionDescriptor.Mode.FILTER); + + assertEquals(ve.getClass(), SelectColumnIsNull.class); + assertEquals(ve.getChildExpressions()[0].getClass(), LongColGreaterLongScalar.class); + assertEquals(2, ve.getChildExpressions()[0].getOutputColumn()); + assertEquals(2, ((SelectColumnIsNull) ve).getColNum()); + + ve = vc.getVectorExpression(isNullExpr, VectorExpressionDescriptor.Mode.PROJECTION); + assertEquals(ve.getClass(), IsNull.class); + assertEquals(2, ((IsNull) ve).getColNum()); + assertEquals(ve.getChildExpressions()[0].getClass(), LongColGreaterLongScalar.class); + } + + @Test + public void testNotNullExpressions() throws HiveException { + ExprNodeColumnDesc col1Expr = new ExprNodeColumnDesc(Integer.class, "col1", "table", false); + ExprNodeConstantDesc constDesc = new ExprNodeConstantDesc(new Integer(10)); + + GenericUDFOPGreaterThan udf = new GenericUDFOPGreaterThan(); + ExprNodeGenericFuncDesc greaterExprDesc = new ExprNodeGenericFuncDesc(); + greaterExprDesc.setTypeInfo(TypeInfoFactory.booleanTypeInfo); + greaterExprDesc.setGenericUDF(udf); + List children1 = new ArrayList(2); + children1.add(col1Expr); + children1.add(constDesc); + greaterExprDesc.setChildren(children1); + + ExprNodeGenericFuncDesc isNotNullExpr = new ExprNodeGenericFuncDesc(); + isNotNullExpr.setTypeInfo(TypeInfoFactory.booleanTypeInfo); + GenericUDFOPNotNull notNullUdf = new GenericUDFOPNotNull(); + isNotNullExpr.setGenericUDF(notNullUdf); + List childOfNot = new ArrayList(); + childOfNot.add(greaterExprDesc); + isNotNullExpr.setChildren(childOfNot); + + Map columnMap = new HashMap(); + columnMap.put("col1", 0); + columnMap.put("col2", 1); + + VectorizationContext vc = new VectorizationContext(columnMap, 2); + + VectorExpression ve = vc.getVectorExpression(isNotNullExpr, VectorExpressionDescriptor.Mode.FILTER); + + assertEquals(ve.getClass(), SelectColumnIsNotNull.class); + assertEquals(2, ((SelectColumnIsNotNull) ve).getColNum()); + assertEquals(ve.getChildExpressions()[0].getClass(), LongColGreaterLongScalar.class); + + ve = vc.getVectorExpression(isNotNullExpr, VectorExpressionDescriptor.Mode.PROJECTION); + assertEquals(ve.getClass(), IsNotNull.class); + assertEquals(2, ((IsNotNull) ve).getColNum()); + assertEquals(ve.getChildExpressions()[0].getClass(), LongColGreaterLongScalar.class); + } + + @Test + public void testVectorizeScalarColumnExpression() throws HiveException { + ExprNodeGenericFuncDesc scalarMinusConstant = new ExprNodeGenericFuncDesc(); + GenericUDFOPMinus gudf = new GenericUDFOPMinus(); + scalarMinusConstant.setGenericUDF(gudf); + List children = new ArrayList(2); + ExprNodeConstantDesc constDesc = new ExprNodeConstantDesc(TypeInfoFactory.longTypeInfo, 20); + ExprNodeColumnDesc colDesc = new ExprNodeColumnDesc(Long.class, "a", "table", false); + + children.add(constDesc); + children.add(colDesc); + + scalarMinusConstant.setChildren(children); + + Map columnMap = new HashMap(); + 
columnMap.put("a", 0); + + VectorizationContext vc = new VectorizationContext(columnMap, 2); + VectorExpression ve = vc.getVectorExpression(scalarMinusConstant, VectorExpressionDescriptor.Mode.PROJECTION); + + assertEquals(ve.getClass(), LongScalarSubtractLongColumn.class); + } + + @Test + public void testFilterWithNegativeScalar() throws HiveException { + ExprNodeColumnDesc col1Expr = new ExprNodeColumnDesc(Integer.class, "col1", "table", false); + ExprNodeConstantDesc constDesc = new ExprNodeConstantDesc(new Integer(-10)); + + GenericUDFOPGreaterThan udf = new GenericUDFOPGreaterThan(); + ExprNodeGenericFuncDesc exprDesc = new ExprNodeGenericFuncDesc(); + exprDesc.setGenericUDF(udf); + List children1 = new ArrayList(2); + children1.add(col1Expr); + children1.add(constDesc); + exprDesc.setChildren(children1); + + Map columnMap = new HashMap(); + columnMap.put("col1", 1); + columnMap.put("col2", 2); + + VectorizationContext vc = new VectorizationContext(columnMap, 2); + + VectorExpression ve = vc.getVectorExpression(exprDesc, VectorExpressionDescriptor.Mode.FILTER); + + assertTrue(ve instanceof FilterLongColGreaterLongScalar); + } + + @Test + public void testUnaryMinusColumnLong() throws HiveException { + ExprNodeColumnDesc col1Expr = new ExprNodeColumnDesc(Integer.class, "col1", "table", false); + ExprNodeGenericFuncDesc negExprDesc = new ExprNodeGenericFuncDesc(); + GenericUDF gudf = new GenericUDFOPNegative(); + negExprDesc.setGenericUDF(gudf); + List children = new ArrayList(1); + children.add(col1Expr); + negExprDesc.setChildren(children); + Map columnMap = new HashMap(); + columnMap.put("col1", 1); + VectorizationContext vc = new VectorizationContext(columnMap, 1); + + VectorExpression ve = vc.getVectorExpression(negExprDesc, VectorExpressionDescriptor.Mode.PROJECTION); + + assertTrue( ve instanceof LongColUnaryMinus); + } + + @Test + public void testUnaryMinusColumnDouble() throws HiveException { + ExprNodeColumnDesc col1Expr = new ExprNodeColumnDesc(Float.class, "col1", "table", false); + ExprNodeGenericFuncDesc negExprDesc = new ExprNodeGenericFuncDesc(); + GenericUDF gudf = new GenericUDFOPNegative(); + negExprDesc.setGenericUDF(gudf); + List children = new ArrayList(1); + children.add(col1Expr); + negExprDesc.setChildren(children); + Map columnMap = new HashMap(); + columnMap.put("col1", 1); + VectorizationContext vc = new VectorizationContext(columnMap, 1); + + VectorExpression ve = vc.getVectorExpression(negExprDesc, VectorExpressionDescriptor.Mode.PROJECTION); + + assertTrue( ve instanceof DoubleColUnaryMinus); + } + + @Test + public void testFilterScalarCompareColumn() throws HiveException { + ExprNodeGenericFuncDesc scalarGreaterColExpr = new ExprNodeGenericFuncDesc(); + GenericUDFOPGreaterThan gudf = new GenericUDFOPGreaterThan(); + scalarGreaterColExpr.setGenericUDF(gudf); + List children = new ArrayList(2); + ExprNodeConstantDesc constDesc = + new ExprNodeConstantDesc(TypeInfoFactory.longTypeInfo, 20); + ExprNodeColumnDesc colDesc = + new ExprNodeColumnDesc(Long.class, "a", "table", false); + + children.add(constDesc); + children.add(colDesc); + + scalarGreaterColExpr.setChildren(children); + + Map columnMap = new HashMap(); + columnMap.put("a", 0); + + VectorizationContext vc = new VectorizationContext(columnMap, 2); + VectorExpression ve = vc.getVectorExpression(scalarGreaterColExpr, VectorExpressionDescriptor.Mode.FILTER); + assertEquals(FilterLongScalarGreaterLongColumn.class, ve.getClass()); + } + + @Test + public void testFilterBooleanColumnCompareBooleanScalar() 
throws HiveException { + ExprNodeGenericFuncDesc colEqualScalar = new ExprNodeGenericFuncDesc(); + GenericUDFOPEqual gudf = new GenericUDFOPEqual(); + colEqualScalar.setGenericUDF(gudf); + List children = new ArrayList(2); + ExprNodeConstantDesc constDesc = + new ExprNodeConstantDesc(TypeInfoFactory.booleanTypeInfo, 20); + ExprNodeColumnDesc colDesc = + new ExprNodeColumnDesc(Boolean.class, "a", "table", false); + + children.add(colDesc); + children.add(constDesc); + + colEqualScalar.setChildren(children); + + Map columnMap = new HashMap(); + columnMap.put("a", 0); + + VectorizationContext vc = new VectorizationContext(columnMap, 2); + VectorExpression ve = vc.getVectorExpression(colEqualScalar, VectorExpressionDescriptor.Mode.FILTER); + assertEquals(FilterLongColEqualLongScalar.class, ve.getClass()); + } + + @Test + public void testBooleanColumnCompareBooleanScalar() throws HiveException { + ExprNodeGenericFuncDesc colEqualScalar = new ExprNodeGenericFuncDesc(); + GenericUDFOPEqual gudf = new GenericUDFOPEqual(); + colEqualScalar.setGenericUDF(gudf); + List children = new ArrayList(2); + ExprNodeConstantDesc constDesc = + new ExprNodeConstantDesc(TypeInfoFactory.booleanTypeInfo, 20); + ExprNodeColumnDesc colDesc = + new ExprNodeColumnDesc(Boolean.class, "a", "table", false); + + children.add(colDesc); + children.add(constDesc); + + colEqualScalar.setChildren(children); + + Map columnMap = new HashMap(); + columnMap.put("a", 0); + VectorizationContext vc = new VectorizationContext(columnMap, 2); + VectorExpression ve = vc.getVectorExpression(colEqualScalar, VectorExpressionDescriptor.Mode.PROJECTION); + assertEquals(LongColEqualLongScalar.class, ve.getClass()); + } + + @Test + public void testUnaryStringExpressions() throws HiveException { + ExprNodeGenericFuncDesc stringUnary = new ExprNodeGenericFuncDesc(); + stringUnary.setTypeInfo(TypeInfoFactory.stringTypeInfo); + ExprNodeColumnDesc colDesc = new ExprNodeColumnDesc(String.class, "a", "table", false); + List children = new ArrayList(); + children.add(colDesc); + stringUnary.setChildren(children); + + Map columnMap = new HashMap(); + columnMap.put("b", 0); + columnMap.put("a", 1); + VectorizationContext vc = new VectorizationContext(columnMap, 2); + + GenericUDF stringLower = new GenericUDFLower(); + stringUnary.setGenericUDF(stringLower); + + VectorExpression ve = vc.getVectorExpression(stringUnary); + + assertEquals(StringLower.class, ve.getClass()); + assertEquals(1, ((StringLower) ve).getColNum()); + assertEquals(2, ((StringLower) ve).getOutputColumn()); + + vc = new VectorizationContext(columnMap, 2); + + ExprNodeGenericFuncDesc anotherUnary = new ExprNodeGenericFuncDesc(); + anotherUnary.setTypeInfo(TypeInfoFactory.stringTypeInfo); + List children2 = new ArrayList(); + children2.add(stringUnary); + anotherUnary.setChildren(children2); + GenericUDFBridge udfbridge = new GenericUDFBridge("ltrim", false, UDFLTrim.class.getName()); + anotherUnary.setGenericUDF(udfbridge); + + ve = vc.getVectorExpression(anotherUnary); + VectorExpression childVe = ve.getChildExpressions()[0]; + assertEquals(StringLower.class, childVe.getClass()); + assertEquals(1, ((StringLower) childVe).getColNum()); + assertEquals(2, ((StringLower) childVe).getOutputColumn()); + + assertEquals(StringLTrim.class, ve.getClass()); + assertEquals(2, ((StringLTrim) ve).getInputColumn()); + assertEquals(3, ((StringLTrim) ve).getOutputColumn()); + } + + @Test + public void testMathFunctions() throws HiveException { + ExprNodeGenericFuncDesc mathFuncExpr = new 
ExprNodeGenericFuncDesc(); + mathFuncExpr.setTypeInfo(TypeInfoFactory.doubleTypeInfo); + ExprNodeColumnDesc colDesc1 = new ExprNodeColumnDesc(Integer.class, "a", "table", false); + ExprNodeColumnDesc colDesc2 = new ExprNodeColumnDesc(Double.class, "b", "table", false); + List children1 = new ArrayList(); + List children2 = new ArrayList(); + children1.add(colDesc1); + children2.add(colDesc2); + + Map columnMap = new HashMap(); + columnMap.put("b", 0); + columnMap.put("a", 1); + VectorizationContext vc = new VectorizationContext(columnMap, 2); + + // Sin(double) + GenericUDFBridge gudfBridge = new GenericUDFBridge("sin", false, UDFSin.class.getName()); + mathFuncExpr.setGenericUDF(gudfBridge); + mathFuncExpr.setChildren(children2); + VectorExpression ve = vc.getVectorExpression(mathFuncExpr, VectorExpressionDescriptor.Mode.PROJECTION); + Assert.assertEquals(FuncSinDoubleToDouble.class, ve.getClass()); + + // Round without digits + GenericUDFRound udfRound = new GenericUDFRound(); + mathFuncExpr.setGenericUDF(udfRound); + mathFuncExpr.setChildren(children2); + ve = vc.getVectorExpression(mathFuncExpr); + Assert.assertEquals(FuncRoundDoubleToDouble.class, ve.getClass()); + + // Round with digits + mathFuncExpr.setGenericUDF(udfRound); + children2.add(new ExprNodeConstantDesc(4)); + mathFuncExpr.setChildren(children2); + ve = vc.getVectorExpression(mathFuncExpr); + Assert.assertEquals(RoundWithNumDigitsDoubleToDouble.class, ve.getClass()); + Assert.assertEquals(4, ((RoundWithNumDigitsDoubleToDouble) ve).getDecimalPlaces().get()); + + // Log with int base + gudfBridge = new GenericUDFBridge("log", false, UDFLog.class.getName()); + mathFuncExpr.setGenericUDF(gudfBridge); + children2.clear(); + children2.add(new ExprNodeConstantDesc(4.0)); + children2.add(colDesc2); + mathFuncExpr.setChildren(children2); + ve = vc.getVectorExpression(mathFuncExpr); + Assert.assertEquals(FuncLogWithBaseDoubleToDouble.class, ve.getClass()); + Assert.assertTrue(4 == ((FuncLogWithBaseDoubleToDouble) ve).getBase()); + + // Log with default base + children2.clear(); + children2.add(colDesc2); + mathFuncExpr.setChildren(children2); + ve = vc.getVectorExpression(mathFuncExpr); + Assert.assertEquals(FuncLnDoubleToDouble.class, ve.getClass()); + + //Log with double base + children2.clear(); + children2.add(new ExprNodeConstantDesc(4.5)); + children2.add(colDesc2); + mathFuncExpr.setChildren(children2); + ve = vc.getVectorExpression(mathFuncExpr); + Assert.assertEquals(FuncLogWithBaseDoubleToDouble.class, ve.getClass()); + Assert.assertTrue(4.5 == ((FuncLogWithBaseDoubleToDouble) ve).getBase()); + + //Log with int input and double base + children2.clear(); + children2.add(new ExprNodeConstantDesc(4.5)); + children2.add(colDesc1); + mathFuncExpr.setChildren(children2); + ve = vc.getVectorExpression(mathFuncExpr); + Assert.assertEquals(FuncLogWithBaseLongToDouble.class, ve.getClass()); + Assert.assertTrue(4.5 == ((FuncLogWithBaseLongToDouble) ve).getBase()); + + //Power with double power + children2.clear(); + children2.add(colDesc2); + children2.add(new ExprNodeConstantDesc(4.5)); + mathFuncExpr.setGenericUDF(new GenericUDFPower()); + mathFuncExpr.setChildren(children2); + ve = vc.getVectorExpression(mathFuncExpr); + Assert.assertEquals(FuncPowerDoubleToDouble.class, ve.getClass()); + Assert.assertTrue(4.5 == ((FuncPowerDoubleToDouble) ve).getPower()); + + //Round with default decimal places + mathFuncExpr.setGenericUDF(udfRound); + children2.clear(); + children2.add(colDesc2); + mathFuncExpr.setChildren(children2); + ve = 
vc.getVectorExpression(mathFuncExpr); + Assert.assertEquals(FuncRoundDoubleToDouble.class, ve.getClass()); + } + + @Test + public void testTimeStampUdfs() throws HiveException { + ExprNodeGenericFuncDesc tsFuncExpr = new ExprNodeGenericFuncDesc(); + tsFuncExpr.setTypeInfo(TypeInfoFactory.intTypeInfo); + ExprNodeColumnDesc colDesc1 = new ExprNodeColumnDesc( + TypeInfoFactory.timestampTypeInfo, "a", "table", false); + List children = new ArrayList(); + children.add(colDesc1); + + Map columnMap = new HashMap(); + columnMap.put("b", 0); + columnMap.put("a", 1); + VectorizationContext vc = new VectorizationContext(columnMap, 2); + + //UDFYear + GenericUDFBridge gudfBridge = new GenericUDFBridge("year", false, UDFYear.class.getName()); + tsFuncExpr.setGenericUDF(gudfBridge); + tsFuncExpr.setChildren(children); + VectorExpression ve = vc.getVectorExpression(tsFuncExpr); + Assert.assertEquals(VectorUDFYearLong.class, ve.getClass()); + + //GenericUDFToUnixTimeStamp + GenericUDFToUnixTimeStamp gudf = new GenericUDFToUnixTimeStamp(); + tsFuncExpr.setGenericUDF(gudf); + tsFuncExpr.setTypeInfo(TypeInfoFactory.longTypeInfo); + ve = vc.getVectorExpression(tsFuncExpr); + Assert.assertEquals(VectorUDFUnixTimeStampLong.class, ve.getClass()); + } + + @Test + public void testBetweenFilters() throws HiveException { + ExprNodeColumnDesc col1Expr = new ExprNodeColumnDesc(String.class, "col1", "table", false); + ExprNodeConstantDesc constDesc = new ExprNodeConstantDesc("Alpha"); + ExprNodeConstantDesc constDesc2 = new ExprNodeConstantDesc("Bravo"); + + // string BETWEEN + GenericUDFBetween udf = new GenericUDFBetween(); + ExprNodeGenericFuncDesc exprDesc = new ExprNodeGenericFuncDesc(); + exprDesc.setGenericUDF(udf); + List children1 = new ArrayList(); + children1.add(new ExprNodeConstantDesc(new Boolean(false))); // no NOT keyword + children1.add(col1Expr); + children1.add(constDesc); + children1.add(constDesc2); + exprDesc.setChildren(children1); + + Map columnMap = new HashMap(); + columnMap.put("col1", 1); + columnMap.put("col2", 2); + VectorizationContext vc = new VectorizationContext(columnMap, 2); + VectorExpression ve = vc.getVectorExpression(exprDesc, VectorExpressionDescriptor.Mode.FILTER); + assertTrue(ve instanceof FilterStringColumnBetween); + + // string NOT BETWEEN + children1.set(0, new ExprNodeConstantDesc(new Boolean(true))); // has NOT keyword + ve = vc.getVectorExpression(exprDesc, VectorExpressionDescriptor.Mode.FILTER); + assertTrue(ve instanceof FilterStringColumnNotBetween); + + // long BETWEEN + children1.set(0, new ExprNodeConstantDesc(new Boolean(false))); + children1.set(1, new ExprNodeColumnDesc(Long.class, "col1", "table", false)); + children1.set(2, new ExprNodeConstantDesc(10)); + children1.set(3, new ExprNodeConstantDesc(20)); + ve = vc.getVectorExpression(exprDesc, VectorExpressionDescriptor.Mode.FILTER); + assertTrue(ve instanceof FilterLongColumnBetween); + + // long NOT BETWEEN + children1.set(0, new ExprNodeConstantDesc(new Boolean(true))); + ve = vc.getVectorExpression(exprDesc, VectorExpressionDescriptor.Mode.FILTER); + assertTrue(ve instanceof FilterLongColumnNotBetween); + + // double BETWEEN + children1.set(0, new ExprNodeConstantDesc(new Boolean(false))); + children1.set(1, new ExprNodeColumnDesc(Double.class, "col1", "table", false)); + children1.set(2, new ExprNodeConstantDesc(10.0d)); + children1.set(3, new ExprNodeConstantDesc(20.0d)); + ve = vc.getVectorExpression(exprDesc, VectorExpressionDescriptor.Mode.FILTER); + assertTrue(ve instanceof 
FilterDoubleColumnBetween); + + // double NOT BETWEEN + children1.set(0, new ExprNodeConstantDesc(new Boolean(true))); + ve = vc.getVectorExpression(exprDesc, VectorExpressionDescriptor.Mode.FILTER); + assertTrue(ve instanceof FilterDoubleColumnNotBetween); + + // timestamp BETWEEN + children1.set(0, new ExprNodeConstantDesc(new Boolean(false))); + children1.set(1, new ExprNodeColumnDesc(Timestamp.class, "col1", "table", false)); + children1.set(2, new ExprNodeConstantDesc("2013-11-05 00:00:00.000")); + children1.set(3, new ExprNodeConstantDesc("2013-11-06 00:00:00.000")); + ve = vc.getVectorExpression(exprDesc, VectorExpressionDescriptor.Mode.FILTER); + assertTrue(ve instanceof FilterLongColumnBetween); + + // timestamp NOT BETWEEN + children1.set(0, new ExprNodeConstantDesc(new Boolean(true))); + ve = vc.getVectorExpression(exprDesc, VectorExpressionDescriptor.Mode.FILTER); + assertTrue(ve instanceof FilterLongColumnNotBetween); + } + + // Test translation of both IN filters and boolean-valued IN expressions (non-filters). + @Test + public void testInFiltersAndExprs() throws HiveException { + ExprNodeColumnDesc col1Expr = new ExprNodeColumnDesc(String.class, "col1", "table", false); + ExprNodeConstantDesc constDesc = new ExprNodeConstantDesc("Alpha"); + ExprNodeConstantDesc constDesc2 = new ExprNodeConstantDesc("Bravo"); + + // string IN + GenericUDFIn udf = new GenericUDFIn(); + ExprNodeGenericFuncDesc exprDesc = new ExprNodeGenericFuncDesc(); + exprDesc.setGenericUDF(udf); + List children1 = new ArrayList(); + children1.add(col1Expr); + children1.add(constDesc); + children1.add(constDesc2); + exprDesc.setChildren(children1); + + Map columnMap = new HashMap(); + columnMap.put("col1", 1); + columnMap.put("col2", 2); + VectorizationContext vc = new VectorizationContext(columnMap, 2); + VectorExpression ve = vc.getVectorExpression(exprDesc, VectorExpressionDescriptor.Mode.FILTER); + assertTrue(ve instanceof FilterStringColumnInList); + ve = vc.getVectorExpression(exprDesc, VectorExpressionDescriptor.Mode.PROJECTION); + assertTrue(ve instanceof StringColumnInList); + + // long IN + children1.set(0, new ExprNodeColumnDesc(Long.class, "col1", "table", false)); + children1.set(1, new ExprNodeConstantDesc(10)); + children1.set(2, new ExprNodeConstantDesc(20)); + ve = vc.getVectorExpression(exprDesc, VectorExpressionDescriptor.Mode.FILTER); + assertTrue(ve instanceof FilterLongColumnInList); + ve = vc.getVectorExpression(exprDesc, VectorExpressionDescriptor.Mode.PROJECTION); + assertTrue(ve instanceof LongColumnInList); + + // double IN + children1.set(0, new ExprNodeColumnDesc(Double.class, "col1", "table", false)); + children1.set(1, new ExprNodeConstantDesc(10d)); + children1.set(2, new ExprNodeConstantDesc(20d)); + ve = vc.getVectorExpression(exprDesc, VectorExpressionDescriptor.Mode.FILTER); + assertTrue(ve instanceof FilterDoubleColumnInList); + ve = vc.getVectorExpression(exprDesc, VectorExpressionDescriptor.Mode.PROJECTION); + assertTrue(ve instanceof DoubleColumnInList); + } + + /** + * Test that correct VectorExpression classes are chosen for the + * IF (expr1, expr2, expr3) conditional expression for integer, float, + * boolean and timestamp input types. expr1 is always an input column expression of type + * long. expr2 and expr3 can be column expressions or constants of other types + * but must have the same type. 
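+   * For example, IF(col1, col2, col3) over long columns is expected to map to
+   * IfExprLongColumnLongColumn, and replacing an argument with a constant picks
+   * the corresponding Scalar variant (e.g. IfExprLongColumnLongScalar). Boolean
+   * and timestamp arguments are expected to reuse the long implementations,
+   * since both types are stored in LongColumnVector.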
+ */ + @Test + public void testIfConditionalExprs() throws HiveException { + ExprNodeColumnDesc col1Expr = new ExprNodeColumnDesc(Long.class, "col1", "table", false); + ExprNodeColumnDesc col2Expr = new ExprNodeColumnDesc(Long.class, "col2", "table", false); + ExprNodeColumnDesc col3Expr = new ExprNodeColumnDesc(Long.class, "col3", "table", false); + + ExprNodeConstantDesc constDesc2 = new ExprNodeConstantDesc(new Integer(1)); + ExprNodeConstantDesc constDesc3 = new ExprNodeConstantDesc(new Integer(2)); + + // long column/column IF + GenericUDFIf udf = new GenericUDFIf(); + ExprNodeGenericFuncDesc exprDesc = new ExprNodeGenericFuncDesc(); + exprDesc.setGenericUDF(udf); + List children1 = new ArrayList(); + children1.add(col1Expr); + children1.add(col2Expr); + children1.add(col3Expr); + exprDesc.setChildren(children1); + + Map columnMap = new HashMap(); + columnMap.put("col1", 1); + columnMap.put("col2", 2); + columnMap.put("col3", 3); + VectorizationContext vc = new VectorizationContext(columnMap, 3); + VectorExpression ve = vc.getVectorExpression(exprDesc); + assertTrue(ve instanceof IfExprLongColumnLongColumn); + + // long column/scalar IF + children1.set(2, new ExprNodeConstantDesc(1L)); + ve = vc.getVectorExpression(exprDesc); + assertTrue(ve instanceof IfExprLongColumnLongScalar); + + // long scalar/scalar IF + children1.set(1, new ExprNodeConstantDesc(1L)); + ve = vc.getVectorExpression(exprDesc); + assertTrue(ve instanceof IfExprLongScalarLongScalar); + + // long scalar/column IF + children1.set(2, col3Expr); + ve = vc.getVectorExpression(exprDesc); + assertTrue(ve instanceof IfExprLongScalarLongColumn); + + // test for double type + col2Expr = new ExprNodeColumnDesc(Double.class, "col2", "table", false); + col3Expr = new ExprNodeColumnDesc(Double.class, "col3", "table", false); + + // double column/column IF + children1.set(1, col2Expr); + children1.set(2, col3Expr); + ve = vc.getVectorExpression(exprDesc); + assertTrue(ve instanceof IfExprDoubleColumnDoubleColumn); + + // double column/scalar IF + children1.set(2, new ExprNodeConstantDesc(1D)); + ve = vc.getVectorExpression(exprDesc); + assertTrue(ve instanceof IfExprDoubleColumnDoubleScalar); + + // double scalar/scalar IF + children1.set(1, new ExprNodeConstantDesc(1D)); + ve = vc.getVectorExpression(exprDesc); + assertTrue(ve instanceof IfExprDoubleScalarDoubleScalar); + + // double scalar/column IF + children1.set(2, col3Expr); + ve = vc.getVectorExpression(exprDesc); + assertTrue(ve instanceof IfExprDoubleScalarDoubleColumn); + + // double scalar/long column IF + children1.set(2, new ExprNodeColumnDesc(Long.class, "col3", "table", false)); + ve = vc.getVectorExpression(exprDesc); + assertTrue(ve instanceof IfExprDoubleScalarLongColumn); + + // Additional combinations of (long,double)X(column,scalar) for each of the second + // and third arguments are omitted. We have coverage of all the source templates + // already. + + // test for timestamp type + col2Expr = new ExprNodeColumnDesc(Timestamp.class, "col2", "table", false); + col3Expr = new ExprNodeColumnDesc(Timestamp.class, "col3", "table", false); + + // timestamp column/column IF + children1.set(1, col2Expr); + children1.set(2, col3Expr); + ve = vc.getVectorExpression(exprDesc); + assertTrue(ve instanceof IfExprLongColumnLongColumn); + + // timestamp column/scalar IF where scalar is really a CAST of a constant to timestamp. 
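+    // Note (added for clarity): the function expression built next is the tree
+    // for CAST('2013-11-05 00:00:00.000' AS TIMESTAMP); GenericUDFTimestamp is
+    // the UDF behind the timestamp cast, so the third IF argument is a function
+    // call rather than a plain constant, which is why no Scalar variant is
+    // guaranteed below.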
+    ExprNodeGenericFuncDesc f = new ExprNodeGenericFuncDesc();
+    f.setGenericUDF(new GenericUDFTimestamp());
+    f.setTypeInfo(TypeInfoFactory.timestampTypeInfo);
+    List children2 = new ArrayList();
+    f.setChildren(children2);
+    children2.add(new ExprNodeConstantDesc("2013-11-05 00:00:00.000"));
+    children1.set(2, f);
+    ve = vc.getVectorExpression(exprDesc);
+
+    // We check for two different classes below because initially the result
+    // is IfExprLongColumnLongColumn but in the future if the system is enhanced
+    // with constant folding then the result will be IfExprLongColumnLongScalar.
+    assertTrue(IfExprLongColumnLongColumn.class == ve.getClass()
+        || IfExprLongColumnLongScalar.class == ve.getClass());
+
+    // timestamp scalar/scalar
+    children1.set(1, f);
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(IfExprLongColumnLongColumn.class == ve.getClass()
+        || IfExprLongScalarLongScalar.class == ve.getClass());
+
+    // timestamp scalar/column
+    children1.set(2, col3Expr);
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(IfExprLongColumnLongColumn.class == ve.getClass()
+        || IfExprLongScalarLongColumn.class == ve.getClass());
+
+    // test for boolean type
+    col2Expr = new ExprNodeColumnDesc(Boolean.class, "col2", "table", false);
+    col3Expr = new ExprNodeColumnDesc(Boolean.class, "col3", "table", false);
+
+    // boolean column/column IF
+    children1.set(1, col2Expr);
+    children1.set(2, col3Expr);
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprLongColumnLongColumn);
+
+    // boolean column/scalar IF
+    children1.set(2, new ExprNodeConstantDesc(true));
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprLongColumnLongScalar);
+
+    // boolean scalar/scalar IF
+    children1.set(1, new ExprNodeConstantDesc(true));
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprLongScalarLongScalar);
+
+    // boolean scalar/column IF
+    children1.set(2, col3Expr);
+    ve = vc.getVectorExpression(exprDesc);
+    assertTrue(ve instanceof IfExprLongScalarLongColumn);
+  }
+}
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizedRowBatch.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizedRowBatch.java
index a250c9d..51a73c1 100644
--- a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizedRowBatch.java
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizedRowBatch.java
@@ -218,4 +218,64 @@
 public static void setRepeatingDoubleCol(DoubleColumnVector col) {
   col.isRepeating = true;
   col.vector[0] = 50.0;
 }
+
+  @Test
+  public void testFlatten() {
+    verifyFlatten(new LongColumnVector());
+    verifyFlatten(new DoubleColumnVector());
+    verifyFlatten(new BytesColumnVector());
+  }
+
+  private void verifyFlatten(ColumnVector v) {
+
+    // verify that flattening and unflattening no-nulls works
+    v.noNulls = true;
+    v.isNull[1] = true;
+    int[] sel = {0, 2};
+    int size = 2;
+    v.flatten(true, sel, size);
+    Assert.assertFalse(v.noNulls);
+    Assert.assertFalse(v.isNull[0] || v.isNull[2]);
+    v.unFlatten();
+    Assert.assertTrue(v.noNulls);
+
+    // verify that flattening and unflattening "isRepeating" works
+    v.isRepeating = true;
+    v.noNulls = false;
+    v.isNull[0] = true;
+    v.flatten(true, sel, 2);
+    Assert.assertFalse(v.noNulls);
+    Assert.assertTrue(v.isNull[0] && v.isNull[2]);
+    Assert.assertFalse(v.isRepeating);
+    v.unFlatten();
+    Assert.assertFalse(v.noNulls);
+    Assert.assertTrue(v.isRepeating);
+
+    // verify extension of values in the array
+    v.noNulls = true;
+    if (v instanceof LongColumnVector) {
+      ((LongColumnVector) v).vector[0] = 100;
+      v.flatten(true, sel,
2); + Assert.assertTrue(((LongColumnVector) v).vector[2] == 100); + } else if (v instanceof DoubleColumnVector) { + ((DoubleColumnVector) v).vector[0] = 200d; + v.flatten(true, sel, 2); + Assert.assertTrue(((DoubleColumnVector) v).vector[2] == 200d); + } else if (v instanceof BytesColumnVector) { + BytesColumnVector bv = (BytesColumnVector) v; + byte[] b = null; + try { + b = "foo".getBytes("UTF-8"); + } catch (Exception e) { + ; // eat it + } + bv.setRef(0, b, 0, b.length); + bv.flatten(true, sel, 2); + Assert.assertEquals(bv.vector[0], bv.vector[2]); + Assert.assertEquals(bv.start[0], bv.start[2]); + Assert.assertEquals(bv.length[0], bv.length[2]); + } + } + + } diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorConditionalExpressions.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorConditionalExpressions.java new file mode 100644 index 0000000..3914245 --- /dev/null +++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorConditionalExpressions.java @@ -0,0 +1,517 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec.vector.expressions; + +import static org.junit.Assert.*; + +import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprLongColumnLongColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprDoubleColumnDoubleColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprLongColumnLongScalar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprLongScalarLongColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprLongScalarLongScalar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprDoubleScalarDoubleScalar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprDoubleScalarDoubleColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.IfExprDoubleColumnDoubleScalar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.IfExprStringColumnStringColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.IfExprStringColumnStringScalar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.IfExprStringScalarStringScalar; +import org.apache.hadoop.hive.ql.exec.vector.expressions.IfExprStringScalarStringColumn; + +import org.junit.Test; + +/** + * Test vectorized conditional expression handling. 
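+ * The helper methods below all build batches with the same layout: column 0
+ * holds the long 0/1 flags used as the first IF argument, columns 1 and 2 hold
+ * the THEN and ELSE arguments, and column 3 receives the result.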
+ */ +public class TestVectorConditionalExpressions { + + private VectorizedRowBatch getBatch4LongVectors() { + VectorizedRowBatch batch = new VectorizedRowBatch(4); + LongColumnVector v = new LongColumnVector(); + + // set first argument to IF -- boolean flag + v.vector[0] = 0; + v.vector[1] = 0; + v.vector[2] = 1; + v.vector[3] = 1; + batch.cols[0] = v; + + // set second argument to IF + v = new LongColumnVector(); + v.vector[0] = -1; + v.vector[1] = -2; + v.vector[2] = -3; + v.vector[3] = -4; + batch.cols[1] = v; + + // set third argument to IF + v = new LongColumnVector(); + v.vector[0] = 1; + v.vector[1] = 2; + v.vector[2] = 3; + v.vector[3] = 4; + batch.cols[2] = v; + + // set output column + batch.cols[3] = new LongColumnVector(); + + batch.size = 4; + return batch; + } + + private VectorizedRowBatch getBatch1Long3DoubleVectors() { + VectorizedRowBatch batch = new VectorizedRowBatch(4); + LongColumnVector lv = new LongColumnVector(); + + // set first argument to IF -- boolean flag + lv.vector[0] = 0; + lv.vector[1] = 0; + lv.vector[2] = 1; + lv.vector[3] = 1; + batch.cols[0] = lv; + + // set second argument to IF + DoubleColumnVector v = new DoubleColumnVector(); + v.vector[0] = -1; + v.vector[1] = -2; + v.vector[2] = -3; + v.vector[3] = -4; + batch.cols[1] = v; + + // set third argument to IF + v = new DoubleColumnVector(); + v.vector[0] = 1; + v.vector[1] = 2; + v.vector[2] = 3; + v.vector[3] = 4; + batch.cols[2] = v; + + // set output column + batch.cols[3] = new DoubleColumnVector(); + + batch.size = 4; + return batch; + } + + private VectorizedRowBatch getBatch1Long3BytesVectors() { + VectorizedRowBatch batch = new VectorizedRowBatch(4); + LongColumnVector lv = new LongColumnVector(); + + // set first argument to IF -- boolean flag + lv.vector[0] = 0; + lv.vector[1] = 0; + lv.vector[2] = 1; + lv.vector[3] = 1; + batch.cols[0] = lv; + + // set second argument to IF + BytesColumnVector v = new BytesColumnVector(); + v.initBuffer(); + setString(v, 0, "arg2_0"); + setString(v, 1, "arg2_1"); + setString(v, 2, "arg2_2"); + setString(v, 3, "arg2_3"); + + batch.cols[1] = v; + + // set third argument to IF + v = new BytesColumnVector(); + v.initBuffer(); + setString(v, 0, "arg3_0"); + setString(v, 1, "arg3_1"); + setString(v, 2, "arg3_2"); + setString(v, 3, "arg3_3"); + batch.cols[2] = v; + + // set output column + v = new BytesColumnVector(); + v.initBuffer(); + batch.cols[3] = v; + batch.size = 4; + return batch; + } + + private void setString(BytesColumnVector v, int i, String s) { + byte[] b = getUTF8Bytes(s); + v.setVal(i, b, 0, b.length); + } + + private byte[] getUTF8Bytes(String s) { + byte[] b = null; + try { + b = s.getBytes("UTF-8"); + } catch (Exception e) { + ; // eat it + } + return b; + } + + private String getString(BytesColumnVector v, int i) { + String s = null; + try { + s = new String(v.vector[i], v.start[i], v.length[i], "UTF-8"); + } catch (Exception e) { + ; // eat it + } + return s; + } + + @Test + public void testLongColumnColumnIfExpr() { + VectorizedRowBatch batch = getBatch4LongVectors(); + VectorExpression expr = new IfExprLongColumnLongColumn(0, 1, 2, 3); + expr.evaluate(batch); + + // get result vector + LongColumnVector r = (LongColumnVector) batch.cols[3]; + + // verify standard case + assertEquals(1, r.vector[0]); + assertEquals(2, r.vector[1]); + assertEquals(-3, r.vector[2]); + assertEquals(-4, r.vector[3]); + assertEquals(true, r.noNulls); + assertEquals(false, r.isRepeating); + + // verify when first argument (boolean flags) is repeating + batch = 
getBatch4LongVectors(); + r = (LongColumnVector) batch.cols[3]; + batch.cols[0].isRepeating = true; + expr.evaluate(batch); + assertEquals(1, r.vector[0]); + assertEquals(4, r.vector[3]); + + // verify when second argument is repeating + batch = getBatch4LongVectors(); + r = (LongColumnVector) batch.cols[3]; + batch.cols[1].isRepeating = true; + expr.evaluate(batch); + assertEquals(1, r.vector[0]); + assertEquals(2, r.vector[1]); + assertEquals(-1, r.vector[2]); + assertEquals(-1, r.vector[3]); + + // verify when third argument is repeating + batch = getBatch4LongVectors(); + r = (LongColumnVector) batch.cols[3]; + batch.cols[2].isRepeating = true; + expr.evaluate(batch); + assertEquals(1, r.vector[0]); + assertEquals(1, r.vector[1]); + assertEquals(-3, r.vector[2]); + assertEquals(-4, r.vector[3]); + + // test when first argument has nulls + batch = getBatch4LongVectors(); + r = (LongColumnVector) batch.cols[3]; + batch.cols[0].noNulls = false; + batch.cols[0].isNull[1] = true; + batch.cols[0].isNull[2] = true; + expr.evaluate(batch); + assertEquals(1, r.vector[0]); + assertEquals(2, r.vector[1]); + assertEquals(3, r.vector[2]); + assertEquals(-4, r.vector[3]); + assertEquals(true, r.noNulls); + assertEquals(false, r.isRepeating); + + // test when second argument has nulls + batch = getBatch4LongVectors(); + r = (LongColumnVector) batch.cols[3]; + batch.cols[1].noNulls = false; + batch.cols[1].isNull[1] = true; + batch.cols[1].isNull[2] = true; + expr.evaluate(batch); + assertEquals(1, r.vector[0]); + assertEquals(2, r.vector[1]); + assertEquals(true, r.isNull[2]); + assertEquals(-4, r.vector[3]); + assertEquals(false, r.noNulls); + assertEquals(false, r.isRepeating); + + // test when third argument has nulls + batch = getBatch4LongVectors(); + r = (LongColumnVector) batch.cols[3]; + batch.cols[2].noNulls = false; + batch.cols[2].isNull[1] = true; + batch.cols[2].isNull[2] = true; + expr.evaluate(batch); + assertEquals(1, r.vector[0]); + assertEquals(true, r.isNull[1]); + assertEquals(-3, r.vector[2]); + assertEquals(-4, r.vector[3]); + assertEquals(false, r.noNulls); + assertEquals(false, r.isRepeating); + + + // test when second argument has nulls and repeats + batch = getBatch4LongVectors(); + r = (LongColumnVector) batch.cols[3]; + batch.cols[1].noNulls = false; + batch.cols[1].isNull[0] = true; + batch.cols[1].isRepeating = true; + expr.evaluate(batch); + assertEquals(1, r.vector[0]); + assertEquals(2, r.vector[1]); + assertEquals(true, r.isNull[2]); + assertEquals(true, r.isNull[3]); + assertEquals(false, r.noNulls); + assertEquals(false, r.isRepeating); + + // test when third argument has nulls and repeats + batch = getBatch4LongVectors(); + r = (LongColumnVector) batch.cols[3]; + batch.cols[2].noNulls = false; + batch.cols[2].isNull[0] = true; + batch.cols[2].isRepeating = true; + expr.evaluate(batch); + assertEquals(true, r.isNull[0]); + assertEquals(true, r.isNull[1]); + assertEquals(-3, r.vector[2]); + assertEquals(-4, r.vector[3]); + assertEquals(false, r.noNulls); + assertEquals(false, r.isRepeating); + } + + @Test + public void testDoubleColumnColumnIfExpr() { + // Just spot check because we already checked the logic for long. + // The code is from the same template file. 
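+    // With flags {0, 0, 1, 1}, the expected result is {1.0, 2.0, -3.0, -4.0}:
+    // rows 0 and 1 take the ELSE (third) argument and rows 2 and 3 take the
+    // THEN (second) argument.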
+ + VectorizedRowBatch batch = getBatch1Long3DoubleVectors(); + VectorExpression expr = new IfExprDoubleColumnDoubleColumn(0, 1, 2, 3); + expr.evaluate(batch); + + // get result vector + DoubleColumnVector r = (DoubleColumnVector) batch.cols[3]; + + // verify standard case + assertEquals(true, 1d == r.vector[0]); + assertEquals(true, 2d == r.vector[1]); + assertEquals(true, -3d == r.vector[2]); + assertEquals(true, -4d == r.vector[3]); + assertEquals(true, r.noNulls); + assertEquals(false, r.isRepeating); + } + + @Test + public void testLongColumnScalarIfExpr() { + VectorizedRowBatch batch = getBatch4LongVectors(); + VectorExpression expr = new IfExprLongColumnLongScalar(0, 1, 100, 3); + LongColumnVector r = (LongColumnVector) batch.cols[3]; + expr.evaluate(batch); + assertEquals(100, r.vector[0]); + assertEquals(100, r.vector[1]); + assertEquals(-3, r.vector[2]); + assertEquals(-4, r.vector[3]); + } + + @Test + public void testLongScalarColumnIfExpr() { + VectorizedRowBatch batch = getBatch4LongVectors(); + VectorExpression expr = new IfExprLongScalarLongColumn(0, 100, 2, 3); + LongColumnVector r = (LongColumnVector) batch.cols[3]; + expr.evaluate(batch); + assertEquals(1, r.vector[0]); + assertEquals(2, r.vector[1]); + assertEquals(100, r.vector[2]); + assertEquals(100, r.vector[3]); + } + + @Test + public void testLongScalarScalarIfExpr() { + VectorizedRowBatch batch = getBatch4LongVectors(); + VectorExpression expr = new IfExprLongScalarLongScalar(0, 100, 200, 3); + LongColumnVector r = (LongColumnVector) batch.cols[3]; + expr.evaluate(batch); + assertEquals(200, r.vector[0]); + assertEquals(200, r.vector[1]); + assertEquals(100, r.vector[2]); + assertEquals(100, r.vector[3]); + } + + @Test + public void testDoubleScalarScalarIfExpr() { + VectorizedRowBatch batch = getBatch1Long3DoubleVectors(); + VectorExpression expr = new IfExprDoubleScalarDoubleScalar(0, 100.0d, 200.0d, 3); + DoubleColumnVector r = (DoubleColumnVector) batch.cols[3]; + expr.evaluate(batch); + assertEquals(true, 200d == r.vector[0]); + assertEquals(true, 200d == r.vector[1]); + assertEquals(true, 100d == r.vector[2]); + assertEquals(true, 100d == r.vector[3]); + } + + @Test + public void testDoubleScalarColumnIfExpr() { + VectorizedRowBatch batch = getBatch1Long3DoubleVectors(); + VectorExpression expr = new IfExprDoubleScalarDoubleColumn(0, 100.0d, 2, 3); + DoubleColumnVector r = (DoubleColumnVector) batch.cols[3]; + expr.evaluate(batch); + assertEquals(true, 1d == r.vector[0]); + assertEquals(true, 2d == r.vector[1]); + assertEquals(true, 100d == r.vector[2]); + assertEquals(true, 100d == r.vector[3]); + } + + @Test + public void testDoubleColumnScalarIfExpr() { + VectorizedRowBatch batch = getBatch1Long3DoubleVectors(); + VectorExpression expr = new IfExprDoubleColumnDoubleScalar(0, 1, 200d, 3); + DoubleColumnVector r = (DoubleColumnVector) batch.cols[3]; + expr.evaluate(batch); + assertEquals(true, 200d == r.vector[0]); + assertEquals(true, 200d == r.vector[1]); + assertEquals(true, -3d == r.vector[2]); + assertEquals(true, -4d == r.vector[3]); + } + + @Test + public void testIfExprStringColumnStringColumn() { + VectorizedRowBatch batch = getBatch1Long3BytesVectors(); + VectorExpression expr = new IfExprStringColumnStringColumn(0, 1, 2, 3); + BytesColumnVector r = (BytesColumnVector) batch.cols[3]; + expr.evaluate(batch); + assertTrue(getString(r, 0).equals("arg3_0")); + assertTrue(getString(r, 1).equals("arg3_1")); + assertTrue(getString(r, 2).equals("arg2_2")); + assertTrue(getString(r, 3).equals("arg2_3")); + 
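+
+    // Note (added for clarity): the repeating cases below rely on the
+    // vectorization convention that a repeating vector stores its single value
+    // at position 0, so a repeating flag vector whose value is 0 selects the
+    // ELSE column for every row.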
+ // test first IF argument repeating + batch = getBatch1Long3BytesVectors(); + batch.cols[0].isRepeating = true; + r = (BytesColumnVector) batch.cols[3]; + expr.evaluate(batch); + assertTrue(getString(r, 0).equals("arg3_0")); + assertTrue(getString(r, 1).equals("arg3_1")); + assertTrue(getString(r, 2).equals("arg3_2")); + assertTrue(getString(r, 3).equals("arg3_3")); + + // test second IF argument repeating + batch = getBatch1Long3BytesVectors(); + batch.cols[1].isRepeating = true; + r = (BytesColumnVector) batch.cols[3]; + expr.evaluate(batch); + assertTrue(getString(r, 0).equals("arg3_0")); + assertTrue(getString(r, 1).equals("arg3_1")); + assertTrue(getString(r, 2).equals("arg2_0")); + assertTrue(getString(r, 3).equals("arg2_0")); + + // test third IF argument repeating + batch = getBatch1Long3BytesVectors(); + batch.cols[2].isRepeating = true; + r = (BytesColumnVector) batch.cols[3]; + expr.evaluate(batch); + assertTrue(getString(r, 0).equals("arg3_0")); + assertTrue(getString(r, 1).equals("arg3_0")); + assertTrue(getString(r, 2).equals("arg2_2")); + assertTrue(getString(r, 3).equals("arg2_3")); + + // test second IF argument with nulls + batch = getBatch1Long3BytesVectors(); + batch.cols[1].noNulls = false; + batch.cols[1].isNull[2] = true; + r = (BytesColumnVector) batch.cols[3]; + expr.evaluate(batch); + assertTrue(getString(r, 0).equals("arg3_0")); + assertTrue(getString(r, 1).equals("arg3_1")); + assertTrue(!r.noNulls && r.isNull[2]); + assertTrue(getString(r, 3).equals("arg2_3")); + assertFalse(r.isNull[0] || r.isNull[1] || r.isNull[3]); + + // test third IF argument with nulls + batch = getBatch1Long3BytesVectors(); + batch.cols[2].noNulls = false; + batch.cols[2].isNull[0] = true; + r = (BytesColumnVector) batch.cols[3]; + expr.evaluate(batch); + assertTrue(!r.noNulls && r.isNull[0]); + assertTrue(getString(r, 1).equals("arg3_1")); + assertTrue(getString(r, 2).equals("arg2_2")); + assertTrue(getString(r, 3).equals("arg2_3")); + assertFalse(r.isNull[1] || r.isNull[2] || r.isNull[3]); + + // test second IF argument with nulls and repeating + batch = getBatch1Long3BytesVectors(); + batch.cols[1].noNulls = false; + batch.cols[1].isNull[0] = true; + batch.cols[1].isRepeating = true; + r = (BytesColumnVector) batch.cols[3]; + expr.evaluate(batch); + assertTrue(getString(r, 0).equals("arg3_0")); + assertTrue(getString(r, 1).equals("arg3_1")); + assertTrue(!r.noNulls && r.isNull[2]); + assertTrue(!r.noNulls && r.isNull[3]); + assertFalse(r.isNull[0] || r.isNull[1]); + } + + @Test + public void testIfExprStringColumnStringScalar() { + VectorizedRowBatch batch = getBatch1Long3BytesVectors(); + byte[] scalar = getUTF8Bytes("scalar"); + VectorExpression expr = new IfExprStringColumnStringScalar(0, 1, scalar, 3); + BytesColumnVector r = (BytesColumnVector) batch.cols[3]; + expr.evaluate(batch); + assertTrue(getString(r, 0).equals("scalar")); + assertTrue(getString(r, 1).equals("scalar")); + assertTrue(getString(r, 2).equals("arg2_2")); + assertTrue(getString(r, 3).equals("arg2_3")); + } + + @Test + public void testIfExprStringScalarStringColumn() { + VectorizedRowBatch batch = getBatch1Long3BytesVectors(); + byte[] scalar = getUTF8Bytes("scalar"); + VectorExpression expr = new IfExprStringScalarStringColumn(0,scalar, 2, 3); + BytesColumnVector r = (BytesColumnVector) batch.cols[3]; + expr.evaluate(batch); + assertTrue(getString(r, 0).equals("arg3_0")); + assertTrue(getString(r, 1).equals("arg3_1")); + assertTrue(getString(r, 2).equals("scalar")); + assertTrue(getString(r, 
3).equals("scalar")); + } + + @Test + public void testIfExprStringScalarStringScalar() { + + // standard case + VectorizedRowBatch batch = getBatch1Long3BytesVectors(); + byte[] scalar1 = getUTF8Bytes("scalar1"); + byte[] scalar2 = getUTF8Bytes("scalar2"); + VectorExpression expr = new IfExprStringScalarStringScalar(0,scalar1, scalar2, 3); + BytesColumnVector r = (BytesColumnVector) batch.cols[3]; + expr.evaluate(batch); + assertTrue(getString(r, 0).equals("scalar2")); + assertTrue(getString(r, 1).equals("scalar2")); + assertTrue(getString(r, 2).equals("scalar1")); + assertTrue(getString(r, 3).equals("scalar1")); + assertFalse(r.isRepeating); + + // repeating case for first (boolean flag) argument to IF + batch = getBatch1Long3BytesVectors(); + batch.cols[0].isRepeating = true; + expr.evaluate(batch); + r = (BytesColumnVector) batch.cols[3]; + assertTrue(r.isRepeating); + assertTrue(getString(r, 0).equals("scalar2")); + } +}