diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java
index 887033c..2cc4ea8 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java
@@ -35,9 +35,7 @@
 import org.apache.hadoop.hive.ql.exec.vector.expressions.ConstantVectorExpression;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.FilterExprAndExpr;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.FilterExprOrExpr;
-import org.apache.hadoop.hive.ql.exec.vector.expressions.FilterNotExpr;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.IdentityExpression;
-import org.apache.hadoop.hive.ql.exec.vector.expressions.SelectColumnIsFalse;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.SelectColumnIsNotNull;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.SelectColumnIsNull;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.SelectColumnIsTrue;
@@ -159,7 +157,7 @@ int allocateOutputColumn(String columnType) {
   private int allocateOutputColumnInternal(String columnType) {
     for (int i = 0; i < outputColCount; i++) {
       if (usedOutputColumns.contains(i) ||
-          !(outputColumnsTypes)[i].equals(columnType)) {
+          !(outputColumnsTypes)[i].equalsIgnoreCase(columnType)) {
         continue;
       }
       //Use i
@@ -324,13 +322,14 @@ private VectorExpression getUnaryMinusExpression(List childExprLis
       colType = v1.getOutputType();
     } else if (childExpr instanceof ExprNodeColumnDesc) {
       ExprNodeColumnDesc colDesc = (ExprNodeColumnDesc) childExpr;
-      inputCol = columnMap.get(colDesc.getColumn());
+      inputCol = getInputColumnIndex(colDesc.getColumn());
       colType = colDesc.getTypeString();
     } else {
       throw new HiveException("Expression not supported: "+childExpr);
     }
-    int outputCol = ocm.allocateOutputColumn(colType);
-    String className = getNormalizedTypeName(colType) + "colUnaryMinus";
+    String outputColumnType = getNormalizedTypeName(colType);
+    int outputCol = ocm.allocateOutputColumn(outputColumnType);
+    String className = outputColumnType + "colUnaryMinus";
     VectorExpression expr;
     try {
       expr = (VectorExpression) Class.forName(className).
@@ -685,16 +684,17 @@ private VectorExpression getVectorExpression(GenericUDFOPOr udf,
 
   private VectorExpression getVectorExpression(GenericUDFOPNot udf,
       List<ExprNodeDesc> childExpr) throws HiveException {
-    ExprNodeDesc expr = childExpr.get(0);
-    if (expr instanceof ExprNodeColumnDesc) {
-      ExprNodeColumnDesc colDesc = (ExprNodeColumnDesc) expr;
-      int inputCol = getInputColumnIndex(colDesc.getColumn());
-      VectorExpression ve = new SelectColumnIsFalse(inputCol);
-      return ve;
-    } else {
-      VectorExpression ve = getVectorExpression(expr);
-      return new FilterNotExpr(ve);
-    }
+    throw new HiveException("Not is not supported");
+//    ExprNodeDesc expr = childExpr.get(0);
+//    if (expr instanceof ExprNodeColumnDesc) {
+//      ExprNodeColumnDesc colDesc = (ExprNodeColumnDesc) expr;
+//      int inputCol = getInputColumnIndex(colDesc.getColumn());
+//      VectorExpression ve = new SelectColumnIsFalse(inputCol);
+//      return ve;
+//    } else {
+//      VectorExpression ve = getVectorExpression(expr);
+//      return new FilterNotExpr(ve);
+//    }
   }
 
   private VectorExpression getVectorExpression(GenericUDFOPAnd udf,
@@ -916,7 +916,8 @@ private VectorExpression getVectorBinaryComparisonFilterExpression(String
     return expr;
   }
 
-  private String getNormalizedTypeName(String colType) {
+  private String getNormalizedTypeName(String colType) throws HiveException {
+    validateInputType(colType);
     String normalizedType = null;
     if (colType.equalsIgnoreCase("Double") || colType.equalsIgnoreCase("Float")) {
       normalizedType = "Double";
@@ -929,7 +930,7 @@ private String getNormalizedTypeName(String colType) {
   }
 
   private String getFilterColumnColumnExpressionClassName(String colType1,
-      String colType2, String opName) {
+      String colType2, String opName) throws HiveException {
     StringBuilder b = new StringBuilder();
     b.append("org.apache.hadoop.hive.ql.exec.vector.expressions.gen.");
     if (opType.equals(OperatorType.FILTER)) {
@@ -944,7 +945,7 @@ private String getFilterColumnColumnExpressionClassName(String colType1,
   }
 
   private String getFilterColumnScalarExpressionClassName(String colType, String
-      scalarType, String opName) {
+      scalarType, String opName) throws HiveException {
     StringBuilder b = new StringBuilder();
     b.append("org.apache.hadoop.hive.ql.exec.vector.expressions.gen.");
     if (opType.equals(OperatorType.FILTER)) {
@@ -959,7 +960,7 @@ private String getFilterColumnScalarExpressionClassName(String colType, String
   }
 
   private String getFilterScalarColumnExpressionClassName(String colType, String
-      scalarType, String opName) {
+      scalarType, String opName) throws HiveException {
     StringBuilder b = new StringBuilder();
     b.append("org.apache.hadoop.hive.ql.exec.vector.expressions.gen.");
     if (opType.equals(OperatorType.FILTER)) {
@@ -974,7 +975,7 @@ private String getFilterScalarColumnExpressionClassName(String colType, String
   }
 
   private String getBinaryColumnScalarExpressionClassName(String colType,
-      String scalarType, String method) {
+      String scalarType, String method) throws HiveException {
     StringBuilder b = new StringBuilder();
     String normColType = getNormalizedTypeName(colType);
     String normScalarType = getNormalizedTypeName(scalarType);
@@ -993,7 +994,7 @@ private String getBinaryColumnScalarExpressionClassName(String colType,
   }
 
   private String getBinaryScalarColumnExpressionClassName(String colType,
-      String scalarType, String method) {
+      String scalarType, String method) throws HiveException {
     StringBuilder b = new StringBuilder();
     String normColType = getNormalizedTypeName(colType);
     String normScalarType = getNormalizedTypeName(scalarType);
@@ -1012,7 +1013,7 @@ private String getBinaryScalarColumnExpressionClassName(String colType,
   }
 
   private String getBinaryColumnColumnExpressionClassName(String colType1,
-      String colType2, String method) {
+      String colType2, String method) throws HiveException {
     StringBuilder b = new StringBuilder();
     String normColType1 = getNormalizedTypeName(colType1);
     String normColType2 = getNormalizedTypeName(colType2);
@@ -1030,7 +1031,10 @@ private String getBinaryColumnColumnExpressionClassName(String colType1,
     return b.toString();
   }
 
-  private String getOutputColType(String inputType1, String inputType2, String method) {
+  private String getOutputColType(String inputType1, String inputType2, String method)
+      throws HiveException {
+    validateInputType(inputType1);
+    validateInputType(inputType2);
     if (method.equalsIgnoreCase("divide") || inputType1.equalsIgnoreCase("double") ||
         inputType2.equalsIgnoreCase("double") || inputType1.equalsIgnoreCase("float") ||
         inputType2.equalsIgnoreCase("float")) {
@@ -1044,7 +1048,25 @@ private String getOutputColType(String inputType1, String inputType2, String met
     }
   }
 
-  private String getOutputColType(String inputType, String method) {
+  private void validateInputType(String inputType) throws HiveException {
+    if (! (inputType.equalsIgnoreCase("float") ||
+        inputType.equalsIgnoreCase("double") ||
+        inputType.equalsIgnoreCase("string") ||
+        inputType.equalsIgnoreCase("tinyint") ||
+        inputType.equalsIgnoreCase("smallint") ||
+        inputType.equalsIgnoreCase("short") ||
+        inputType.equalsIgnoreCase("byte") ||
+        inputType.equalsIgnoreCase("int") ||
+        inputType.equalsIgnoreCase("long") ||
+        inputType.equalsIgnoreCase("bigint") ||
+        inputType.equalsIgnoreCase("boolean") ||
+        inputType.equalsIgnoreCase("timestamp") ) ) {
+      throw new HiveException("Unsupported input type: "+inputType);
+    }
+  }
+
+  private String getOutputColType(String inputType, String method) throws HiveException {
+    validateInputType(inputType);
     if (inputType.equalsIgnoreCase("float") || inputType.equalsIgnoreCase("double")) {
       return "double";
     } else if (inputType.equalsIgnoreCase("string")) {
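A minimal standalone sketch (not part of the patch) of the validation rule the patch introduces: validateInputType() accepts exactly the whitelisted Hive type names, compared case-insensitively, and rejects everything else. The class name InputTypeCheckDemo and the use of IllegalArgumentException in place of HiveException are illustrative assumptions so the demo is self-contained.

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

// Standalone demo mirroring the whitelist in validateInputType() above.
public class InputTypeCheckDemo {

  // The set of type names the patch accepts (matched ignoring case).
  private static final Set<String> SUPPORTED = new HashSet<>(Arrays.asList(
      "float", "double", "string", "tinyint", "smallint", "short",
      "byte", "int", "long", "bigint", "boolean", "timestamp"));

  // Mirrors validateInputType(); throws IllegalArgumentException here
  // where the patch throws HiveException.
  static void validateInputType(String inputType) {
    if (!SUPPORTED.contains(inputType.toLowerCase())) {
      throw new IllegalArgumentException("Unsupported input type: " + inputType);
    }
  }

  public static void main(String[] args) {
    validateInputType("BigInt");     // passes: matching is case-insensitive
    try {
      validateInputType("decimal");  // not in the whitelist
    } catch (IllegalArgumentException e) {
      System.out.println(e.getMessage()); // Unsupported input type: decimal
    }
  }
}

The same case-insensitivity presumably motivates the switch from equals to equalsIgnoreCase in allocateOutputColumnInternal: type names can arrive in mixed case, so a scratch column typed "Long" should be reusable for a request typed "long".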