diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java index 887033c..9ccff90 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java @@ -35,9 +35,7 @@ import org.apache.hadoop.hive.ql.exec.vector.expressions.ConstantVectorExpression; import org.apache.hadoop.hive.ql.exec.vector.expressions.FilterExprAndExpr; import org.apache.hadoop.hive.ql.exec.vector.expressions.FilterExprOrExpr; -import org.apache.hadoop.hive.ql.exec.vector.expressions.FilterNotExpr; import org.apache.hadoop.hive.ql.exec.vector.expressions.IdentityExpression; -import org.apache.hadoop.hive.ql.exec.vector.expressions.SelectColumnIsFalse; import org.apache.hadoop.hive.ql.exec.vector.expressions.SelectColumnIsNotNull; import org.apache.hadoop.hive.ql.exec.vector.expressions.SelectColumnIsNull; import org.apache.hadoop.hive.ql.exec.vector.expressions.SelectColumnIsTrue; @@ -159,7 +157,7 @@ int allocateOutputColumn(String columnType) { private int allocateOutputColumnInternal(String columnType) { for (int i = 0; i < outputColCount; i++) { if (usedOutputColumns.contains(i) || - !(outputColumnsTypes)[i].equals(columnType)) { + !(outputColumnsTypes)[i].equalsIgnoreCase(columnType)) { continue; } //Use i @@ -324,13 +322,15 @@ private VectorExpression getUnaryMinusExpression(List childExprLis colType = v1.getOutputType(); } else if (childExpr instanceof ExprNodeColumnDesc) { ExprNodeColumnDesc colDesc = (ExprNodeColumnDesc) childExpr; - inputCol = columnMap.get(colDesc.getColumn()); + inputCol = getInputColumnIndex(colDesc.getColumn()); colType = colDesc.getTypeString(); } else { throw new HiveException("Expression not supported: "+childExpr); } - int outputCol = ocm.allocateOutputColumn(colType); - String className = getNormalizedTypeName(colType) + "colUnaryMinus"; + String outputColumnType = getNormalizedTypeName(colType); + int outputCol = ocm.allocateOutputColumn(outputColumnType); + String className = "org.apache.hadoop.hive.ql.exec.vector.expressions.gen." + + outputColumnType + "ColUnaryMinus"; VectorExpression expr; try { expr = (VectorExpression) Class.forName(className). @@ -685,16 +685,17 @@ private VectorExpression getVectorExpression(GenericUDFOPOr udf, private VectorExpression getVectorExpression(GenericUDFOPNot udf, List childExpr) throws HiveException { - ExprNodeDesc expr = childExpr.get(0); - if (expr instanceof ExprNodeColumnDesc) { - ExprNodeColumnDesc colDesc = (ExprNodeColumnDesc) expr; - int inputCol = getInputColumnIndex(colDesc.getColumn()); - VectorExpression ve = new SelectColumnIsFalse(inputCol); - return ve; - } else { - VectorExpression ve = getVectorExpression(expr); - return new FilterNotExpr(ve); - } + throw new HiveException("Not is not supported"); +// ExprNodeDesc expr = childExpr.get(0); +// if (expr instanceof ExprNodeColumnDesc) { +// ExprNodeColumnDesc colDesc = (ExprNodeColumnDesc) expr; +// int inputCol = getInputColumnIndex(colDesc.getColumn()); +// VectorExpression ve = new SelectColumnIsFalse(inputCol); +// return ve; +// } else { +// VectorExpression ve = getVectorExpression(expr); +// return new FilterNotExpr(ve); +// } } private VectorExpression getVectorExpression(GenericUDFOPAnd udf, @@ -916,7 +917,8 @@ private VectorExpression getVectorBinaryComparisonFilterExpression(String return expr; } - private String getNormalizedTypeName(String colType) { + private String getNormalizedTypeName(String colType) throws HiveException { + validateInputType(colType); String normalizedType = null; if (colType.equalsIgnoreCase("Double") || colType.equalsIgnoreCase("Float")) { normalizedType = "Double"; @@ -929,7 +931,7 @@ private String getNormalizedTypeName(String colType) { } private String getFilterColumnColumnExpressionClassName(String colType1, - String colType2, String opName) { + String colType2, String opName) throws HiveException { StringBuilder b = new StringBuilder(); b.append("org.apache.hadoop.hive.ql.exec.vector.expressions.gen."); if (opType.equals(OperatorType.FILTER)) { @@ -944,7 +946,7 @@ private String getFilterColumnColumnExpressionClassName(String colType1, } private String getFilterColumnScalarExpressionClassName(String colType, String - scalarType, String opName) { + scalarType, String opName) throws HiveException { StringBuilder b = new StringBuilder(); b.append("org.apache.hadoop.hive.ql.exec.vector.expressions.gen."); if (opType.equals(OperatorType.FILTER)) { @@ -959,7 +961,7 @@ private String getFilterColumnScalarExpressionClassName(String colType, String } private String getFilterScalarColumnExpressionClassName(String colType, String - scalarType, String opName) { + scalarType, String opName) throws HiveException { StringBuilder b = new StringBuilder(); b.append("org.apache.hadoop.hive.ql.exec.vector.expressions.gen."); if (opType.equals(OperatorType.FILTER)) { @@ -974,7 +976,7 @@ private String getFilterScalarColumnExpressionClassName(String colType, String } private String getBinaryColumnScalarExpressionClassName(String colType, - String scalarType, String method) { + String scalarType, String method) throws HiveException { StringBuilder b = new StringBuilder(); String normColType = getNormalizedTypeName(colType); String normScalarType = getNormalizedTypeName(scalarType); @@ -993,7 +995,7 @@ private String getBinaryColumnScalarExpressionClassName(String colType, } private String getBinaryScalarColumnExpressionClassName(String colType, - String scalarType, String method) { + String scalarType, String method) throws HiveException { StringBuilder b = new StringBuilder(); String normColType = getNormalizedTypeName(colType); String normScalarType = getNormalizedTypeName(scalarType); @@ -1012,7 +1014,7 @@ private String getBinaryScalarColumnExpressionClassName(String colType, } private String getBinaryColumnColumnExpressionClassName(String colType1, - String colType2, String method) { + String colType2, String method) throws HiveException { StringBuilder b = new StringBuilder(); String normColType1 = getNormalizedTypeName(colType1); String normColType2 = getNormalizedTypeName(colType2); @@ -1030,7 +1032,10 @@ private String getBinaryColumnColumnExpressionClassName(String colType1, return b.toString(); } - private String getOutputColType(String inputType1, String inputType2, String method) { + private String getOutputColType(String inputType1, String inputType2, String method) + throws HiveException { + validateInputType(inputType1); + validateInputType(inputType2); if (method.equalsIgnoreCase("divide") || inputType1.equalsIgnoreCase("double") || inputType2.equalsIgnoreCase("double") || inputType1.equalsIgnoreCase("float") || inputType2.equalsIgnoreCase("float")) { @@ -1044,7 +1049,25 @@ private String getOutputColType(String inputType1, String inputType2, String met } } - private String getOutputColType(String inputType, String method) { + private void validateInputType(String inputType) throws HiveException { + if (! (inputType.equalsIgnoreCase("float") || + inputType.equalsIgnoreCase("double") || + inputType.equalsIgnoreCase("string") || + inputType.equalsIgnoreCase("tinyint") || + inputType.equalsIgnoreCase("smallint") || + inputType.equalsIgnoreCase("short") || + inputType.equalsIgnoreCase("byte") || + inputType.equalsIgnoreCase("int") || + inputType.equalsIgnoreCase("long") || + inputType.equalsIgnoreCase("bigint") || + inputType.equalsIgnoreCase("boolean") || + inputType.equalsIgnoreCase("timestamp") ) ) { + throw new HiveException("Unsupported input type: "+inputType); + } + } + + private String getOutputColType(String inputType, String method) throws HiveException { + validateInputType(inputType); if (inputType.equalsIgnoreCase("float") || inputType.equalsIgnoreCase("double")) { return "double"; } else if (inputType.equalsIgnoreCase("string")) { diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java index ed20ecc..f7bfce1 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java @@ -11,6 +11,7 @@ import org.apache.hadoop.hive.ql.exec.vector.expressions.FilterExprAndExpr; import org.apache.hadoop.hive.ql.exec.vector.expressions.FilterExprOrExpr; import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpression; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DoubleColUnaryMinus; import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FilterDoubleColLessDoubleScalar; import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FilterLongColGreaterLongScalar; import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FilterStringColGreaterStringScalar; @@ -18,6 +19,7 @@ import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColModuloLongColumn; import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColMultiplyLongColumn; import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColSubtractLongColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColUnaryMinus; import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongScalarSubtractLongColumn; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc; @@ -28,6 +30,7 @@ import org.apache.hadoop.hive.ql.udf.UDFOPMinus; import org.apache.hadoop.hive.ql.udf.UDFOPMod; import org.apache.hadoop.hive.ql.udf.UDFOPMultiply; +import org.apache.hadoop.hive.ql.udf.UDFOPNegative; import org.apache.hadoop.hive.ql.udf.UDFOPPlus; import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBridge; @@ -287,4 +290,42 @@ public void testFilterWithNegativeScalar() throws HiveException { assertTrue(ve instanceof FilterLongColGreaterLongScalar); } + + @Test + public void testUnaryMinusColumnLong() throws HiveException { + ExprNodeColumnDesc col1Expr = new ExprNodeColumnDesc(Integer.class, "col1", "table", false); + ExprNodeGenericFuncDesc negExprDesc = new ExprNodeGenericFuncDesc(); + GenericUDF gudf = new GenericUDFBridge("-", true, UDFOPNegative.class); + negExprDesc.setGenericUDF(gudf); + List children = new ArrayList(1); + children.add(col1Expr); + negExprDesc.setChildExprs(children); + Map columnMap = new HashMap(); + columnMap.put("col1", 1); + VectorizationContext vc = new VectorizationContext(columnMap, 1); + vc.setOperatorType(OperatorType.SELECT); + + VectorExpression ve = vc.getVectorExpression(negExprDesc); + + assertTrue( ve instanceof LongColUnaryMinus); + } + + @Test + public void testUnaryMinusColumnDouble() throws HiveException { + ExprNodeColumnDesc col1Expr = new ExprNodeColumnDesc(Float.class, "col1", "table", false); + ExprNodeGenericFuncDesc negExprDesc = new ExprNodeGenericFuncDesc(); + GenericUDF gudf = new GenericUDFBridge("-", true, UDFOPNegative.class); + negExprDesc.setGenericUDF(gudf); + List children = new ArrayList(1); + children.add(col1Expr); + negExprDesc.setChildExprs(children); + Map columnMap = new HashMap(); + columnMap.put("col1", 1); + VectorizationContext vc = new VectorizationContext(columnMap, 1); + vc.setOperatorType(OperatorType.SELECT); + + VectorExpression ve = vc.getVectorExpression(negExprDesc); + + assertTrue( ve instanceof DoubleColUnaryMinus); + } }