diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java index 9179efd..78ec197 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java @@ -443,23 +443,23 @@ private VectorExpression getBinaryArithmeticExpression(String method, } catch (Exception ex) { throw new HiveException(ex); } - } else if ( (rightExpr instanceof ExprNodeColumnDesc) && - (leftExpr instanceof ExprNodeConstantDesc) ) { + } else if ( (leftExpr instanceof ExprNodeConstantDesc) && + (rightExpr instanceof ExprNodeColumnDesc) ) { ExprNodeColumnDesc rightColDesc = (ExprNodeColumnDesc) rightExpr; ExprNodeConstantDesc constDesc = (ExprNodeConstantDesc) leftExpr; int inputCol = getInputColumnIndex(rightColDesc.getColumn()); String colType = rightColDesc.getTypeString(); String scalarType = constDesc.getTypeString(); - String className = getBinaryColumnScalarExpressionClassName(colType, + String className = getBinaryScalarColumnExpressionClassName(colType, scalarType, method); String outputColType = getOutputColType(colType, scalarType, method); int outputCol = ocm.allocateOutputColumn(outputColType); try { expr = (VectorExpression) Class.forName(className). - getDeclaredConstructors()[0].newInstance(inputCol, - getScalarValue(constDesc), outputCol); + getDeclaredConstructors()[0].newInstance(getScalarValue(constDesc), + inputCol, outputCol); } catch (Exception ex) { - throw new HiveException(ex); + throw new HiveException("Could not instantiate: "+className, ex); } } else if ( (rightExpr instanceof ExprNodeColumnDesc) && (leftExpr instanceof ExprNodeColumnDesc) ) { diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java index 246423c..ed20ecc 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorizationContext.java @@ -18,6 +18,7 @@ import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColModuloLongColumn; import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColMultiplyLongColumn; import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColSubtractLongColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongScalarSubtractLongColumn; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc; import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc; @@ -28,11 +29,13 @@ import org.apache.hadoop.hive.ql.udf.UDFOPMod; import org.apache.hadoop.hive.ql.udf.UDFOPMultiply; import org.apache.hadoop.hive.ql.udf.UDFOPPlus; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBridge; import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPAnd; import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPGreaterThan; import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPLessThan; import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPOr; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; import org.junit.Test; public class TestVectorizationContext { @@ -237,6 +240,28 @@ public void testVectorizeAndOrExpression() throws HiveException { assertEquals(veOr.getChildExpressions()[1].getClass(), FilterDoubleColLessDoubleScalar.class); } + @Test + public void testVectorizeScalarColumnExpression() throws HiveException { + ExprNodeGenericFuncDesc scalarMinusConstant = new ExprNodeGenericFuncDesc(); + GenericUDF gudf = new GenericUDFBridge("-", true, UDFOPMinus.class); + scalarMinusConstant.setGenericUDF(gudf); + List children = new ArrayList(2); + ExprNodeConstantDesc constDesc = new ExprNodeConstantDesc(TypeInfoFactory.longTypeInfo, 20); + ExprNodeColumnDesc colDesc = new ExprNodeColumnDesc(Long.class, "a", "table", false); + + children.add(constDesc); + children.add(colDesc); + + scalarMinusConstant.setChildExprs(children); + + Map columnMap = new HashMap(); + columnMap.put("a", 0); + + VectorizationContext vc = new VectorizationContext(columnMap, 2); + VectorExpression ve = vc.getVectorExpression(scalarMinusConstant); + + assertEquals(ve.getClass(), LongScalarSubtractLongColumn.class); + } @Test public void testFilterWithNegativeScalar() throws HiveException {