diff --git ql/src/gen/vectorization/ExpressionTemplates/ColumnArithmeticColumnDecimal.txt ql/src/gen/vectorization/ExpressionTemplates/ColumnArithmeticColumnDecimal.txt index 699b7c5..2cb8ec0 100644 --- ql/src/gen/vectorization/ExpressionTemplates/ColumnArithmeticColumnDecimal.txt +++ ql/src/gen/vectorization/ExpressionTemplates/ColumnArithmeticColumnDecimal.txt @@ -37,6 +37,7 @@ public class extends VectorExpression { private int colNum1; private int colNum2; private int outputColumn; + private String outputType = "decimal"; public (int colNum1, int colNum2, int outputColumn) { this.colNum1 = colNum1; @@ -146,7 +147,12 @@ public class extends VectorExpression { @Override public String getOutputType() { - return "decimal"; + return outputType; + } + + @Override + public void setOutputType(String type) { + outputType = type; } public int getColNum1() { diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorExpressionDescriptor.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorExpressionDescriptor.java index 1c70387..0efeee6 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorExpressionDescriptor.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorExpressionDescriptor.java @@ -48,7 +48,11 @@ public int getValue() { } public static ArgumentType getType(String inType) { - return valueOf(VectorizationContext.getNormalizedTypeName(inType).toUpperCase()); + String type = VectorizationContext.getNormalizedTypeName(inType); + if (VectorizationContext.decimalTypePattern.matcher(type.toLowerCase()).matches()) { + type = "decimal"; + } + return valueOf(type.toUpperCase()); } } @@ -186,9 +190,36 @@ private Descriptor(Mode mode, int argCount, ArgumentType[] argTypes, InputExpres this.exprTypes = exprTypes.clone(); this.argCount = argCount; } + + @Override + public String toString() { + StringBuilder b = new StringBuilder("Argument Count = "); + b.append(argCount); + b.append(", mode = "); + b.append(mode); + b.append(", Argument 
Types = {"); + for (int i = 0; i < argCount; i++) { + if (i != 0) { + b.append(","); + } + b.append(argTypes[i]); + } + b.append("}"); + + b.append(", Input Expression Types = {"); + for (int i = 0; i < argCount; i++) { + if (i != 0) { + b.append(","); + } + b.append(exprTypes[i]); + } + b.append("}"); + return b.toString(); + } } public Class getVectorExpressionClass(Class udf, Descriptor descriptor) throws HiveException { + System.out.println("udf = "+udf.getSimpleName()+", Descriptor = "+descriptor.toString()); VectorizedExpressions annotation = udf.getAnnotation(VectorizedExpressions.class); if (annotation == null || annotation.value() == null) { return null; diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java index f5ab731..105710e 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizationContext.java @@ -27,9 +27,12 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.regex.Pattern; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.hive.common.type.Decimal128; +import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator; import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluatorFactory; import org.apache.hadoop.hive.ql.exec.FunctionInfo; @@ -90,6 +93,10 @@ import org.apache.hadoop.hive.ql.udf.generic.*; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.HiveDecimalUtils; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; /** * Context class 
for vectorization execution. @@ -109,6 +116,8 @@ private final int firstOutputColumnIndex; private final Mode operatorMode = Mode.PROJECTION; + public static final Pattern decimalTypePattern = Pattern.compile("decimal.*"); + //Map column number to type private final OutputColumnManager ocm; @@ -168,6 +177,10 @@ int allocateOutputColumn(String columnType) { private int allocateOutputColumnInternal(String columnType) { for (int i = 0; i < outputColCount; i++) { + + // Re-use an existing, available column of the same required type. + // For simplicity decimal columns are not being re-used because for decimal columns + // precision and scale should be matched as well. This should be fixed in future. if (usedOutputColumns.contains(i) || !(outputColumnsTypes)[i].equalsIgnoreCase(columnType)) { continue; @@ -202,6 +215,10 @@ void freeOutputColumn(int index) { usedOutputColumns.remove(index-initialOutputCol); } } + + public void updateColumnType(int outputColumn, String typeName) { + outputColumnsTypes[outputColumn] = typeName; + } } private VectorExpression getColumnVectorExpression(ExprNodeColumnDesc @@ -260,7 +277,7 @@ public VectorExpression getVectorExpression(ExprNodeDesc exprDesc, Mode mode) th ve = getCustomUDFExpression(expr); } else { ve = getGenericUdfVectorExpression(expr.getGenericUDF(), - expr.getChildren(), mode); + expr.getChildren(), mode, exprDesc.getTypeInfo()); } } else if (exprDesc instanceof ExprNodeConstantDesc) { ve = getConstantVectorExpression((ExprNodeConstantDesc) exprDesc, mode); @@ -455,8 +472,8 @@ private VectorExpression getIdentityExpression(List childExprList) return expr; } - private VectorExpression getVectorExpressionForUdf(Class udf, List childExpr, Mode mode) - throws HiveException { + private VectorExpression getVectorExpressionForUdf(Class udf, List childExpr, Mode mode, + TypeInfo returnType) throws HiveException { int numChildren = (childExpr == null) ? 
0 : childExpr.size(); if (numChildren > VectorExpressionDescriptor.MAX_NUM_ARGUMENTS) { return null; @@ -464,9 +481,14 @@ private VectorExpression getVectorExpressionForUdf(Class udf, List udf, List childExpr) { + if (childExpr == null) { + return false; + } + for (int i = 0; i < childExpr.size(); i++) { + ExprNodeDesc child = childExpr.get(i); + TypeInfo childTypeInfo = child.getTypeInfo(); + if (decimalTypePattern.matcher(childTypeInfo.getTypeName()).matches()) { + return true; + } + } + return false; } private VectorExpression createVectorExpression(Class vectorClass, List childExpr, - Mode childrenMode) throws HiveException { + Mode childrenMode, TypeInfo returnType) throws HiveException { int numChildren = childExpr == null ? 0: childExpr.size(); List children = new ArrayList(); Object[] arguments = new Object[numChildren]; try { + + // If any of the children is decimal type, cast the non-decimal output of other children + // to decimal. + boolean convertToDecimalType = isDecimalConversionNeeded(childExpr); + for (int i = 0; i < numChildren; i++) { ExprNodeDesc child = childExpr.get(i); if (child instanceof ExprNodeGenericFuncDesc) { VectorExpression vChild = getVectorExpression(child, childrenMode); - children.add(vChild); - arguments[i] = vChild.getOutputColumn(); + if ((convertToDecimalType) && + (!decimalTypePattern.matcher(child.getTypeInfo().getTypeName()).matches())) { + int inputColToCast = vChild.getOutputColumn(); + VectorExpression castExpression = getImplicitCastToDecimal(inputColToCast, child.getTypeInfo()); + ocm.freeOutputColumn(inputColToCast); + castExpression.setChildExpressions(new VectorExpression[] {vChild}); + children.add(castExpression); + arguments[i] = castExpression.getOutputColumn(); + } else { + children.add(vChild); + arguments[i] = vChild.getOutputColumn(); + } } else if (child instanceof ExprNodeColumnDesc) { int colIndex = getInputColumnIndex((ExprNodeColumnDesc) child); - if (childrenMode == Mode.FILTER) { - // In filter mode, 
the column must be a boolean - children.add(new SelectColumnIsTrue(colIndex)); + if ((convertToDecimalType) && + (!decimalTypePattern.matcher(child.getTypeInfo().getTypeName()).matches())) { + VectorExpression castExpression = getImplicitCastToDecimal(colIndex, child.getTypeInfo()); + children.add(castExpression); + arguments[i] = castExpression.getOutputColumn(); + } else { + if (childrenMode == Mode.FILTER) { + // In filter mode, the column must be a boolean + children.add(new SelectColumnIsTrue(colIndex)); + } + arguments[i] = colIndex; } - arguments[i] = colIndex; } else if (child instanceof ExprNodeConstantDesc) { - arguments[i] = getScalarValue((ExprNodeConstantDesc) child); + Object scalarValue = getScalarValue((ExprNodeConstantDesc) child); + if (convertToDecimalType) { + arguments[i] = castConstantToDecimal(scalarValue, child.getTypeInfo()); + } else { + arguments[i] = scalarValue; + } } else { - throw new HiveException("Cannot handle expression type: " - + child.getClass().getSimpleName()); + throw new HiveException("Cannot handle expression type: " + child.getClass().getSimpleName()); } } - VectorExpression vectorExpression = instantiateExpression(vectorClass, arguments); + VectorExpression vectorExpression = instantiateExpression(vectorClass, returnType, arguments); if ((vectorExpression != null) && !children.isEmpty()) { vectorExpression.setChildExpressions(children.toArray(new VectorExpression[0])); } @@ -533,32 +595,51 @@ private Mode getChildrenMode(Mode mode, Class udf) { return Mode.PROJECTION; } - private VectorExpression instantiateExpression(Class vclass, Object...args) + private VectorExpression instantiateExpression(Class vclass, TypeInfo returnType, Object...args) throws HiveException { + VectorExpression ve = null; Constructor ctor = getConstructor(vclass); int numParams = ctor.getParameterTypes().length; int argsLength = (args == null) ? 
0 : args.length; try { if (numParams == 0) { - return (VectorExpression) ctor.newInstance(); + ve = (VectorExpression) ctor.newInstance(); } else if (numParams == argsLength) { - return (VectorExpression) ctor.newInstance(args); + ve = (VectorExpression) ctor.newInstance(args); } else if (numParams == argsLength + 1) { // Additional argument is needed, which is the outputcolumn. - String outType = ((VectorExpression) vclass.newInstance()).getOutputType(); + String outType; + boolean returnTypeDecimal = false; + + // Special handling for decimal because decimal types need scale and precision parameter. + // This special handling should be avoided by using returnType uniformly for all cases. + if ( (returnType != null) && + decimalTypePattern.matcher(returnType.getTypeName()).matches()) { + outType = returnType.getTypeName(); + returnTypeDecimal = true; + System.out.println("decimal type = "+returnType.getTypeName()); + } else { + outType = ((VectorExpression) vclass.newInstance()).getOutputType(); + System.out.println("return type = null or not decimal: "+outType); + } int outputCol = ocm.allocateOutputColumn(outType); Object [] newArgs = Arrays.copyOf(args, numParams); newArgs[numParams-1] = outputCol; - return (VectorExpression) ctor.newInstance(newArgs); + ve = (VectorExpression) ctor.newInstance(newArgs); + + // For decimal types, fix the outputType in expression to include scale and precision. 
+ if (returnTypeDecimal) { + ve.setOutputType(outType); + } } } catch (Exception ex) { throw new HiveException("Could not instantiate " + vclass.getSimpleName(), ex); } - return null; + return ve; } private VectorExpression getGenericUdfVectorExpression(GenericUDF udf, - List childExpr, Mode mode) throws HiveException { + List childExpr, Mode mode, TypeInfo returnType) throws HiveException { //First handle special cases if (udf instanceof GenericUDFBetween) { return getBetweenFilterExpression(childExpr, mode); @@ -567,7 +648,7 @@ private VectorExpression getGenericUdfVectorExpression(GenericUDF udf, } else if (udf instanceof GenericUDFOPPositive) { return getIdentityExpression(childExpr); } else if (udf instanceof GenericUDFBridge) { - VectorExpression v = getGenericUDFBridgeVectorExpression((GenericUDFBridge) udf, childExpr, mode); + VectorExpression v = getGenericUDFBridgeVectorExpression((GenericUDFBridge) udf, childExpr, mode, returnType); if (v != null) { return v; } @@ -580,10 +661,12 @@ private VectorExpression getGenericUdfVectorExpression(GenericUDF udf, } List constantFoldedChildren = foldConstantsForUnaryExprs(childExpr); - VectorExpression ve = getVectorExpressionForUdf(udfClass, constantFoldedChildren, mode); + VectorExpression ve = getVectorExpressionForUdf(udfClass, constantFoldedChildren, mode, returnType); + if (ve == null) { throw new HiveException("Udf: "+udf.getClass().getSimpleName()+", is not supported"); } + return ve; } @@ -593,6 +676,7 @@ private VectorExpression getGenericUdfVectorExpression(GenericUDF udf, private VectorExpression getInExpression(List childExpr, Mode mode) throws HiveException { ExprNodeDesc colExpr = childExpr.get(0); + TypeInfo colTypeInfo = colExpr.getTypeInfo(); String colType = colExpr.getTypeString(); // prepare arguments for createVectorExpression @@ -617,7 +701,7 @@ private VectorExpression getInExpression(List childExpr, Mode mode for (int i = 0; i != inVals.length; i++) { inVals[i] = 
getIntFamilyScalarAsLong((ExprNodeConstantDesc) childrenForInList.get(i)); } - expr = createVectorExpression(cl, childExpr.subList(0, 1), Mode.PROJECTION); + expr = createVectorExpression(cl, childExpr.subList(0, 1), Mode.PROJECTION, colTypeInfo); ((ILongInExpr) expr).setInListValues(inVals); } else if (colType.equals("timestamp")) { cl = (mode == Mode.FILTER ? FilterLongColumnInList.class : LongColumnInList.class); @@ -625,7 +709,7 @@ private VectorExpression getInExpression(List childExpr, Mode mode for (int i = 0; i != inVals.length; i++) { inVals[i] = getTimestampScalar(childrenForInList.get(i)); } - expr = createVectorExpression(cl, childExpr.subList(0, 1), Mode.PROJECTION); + expr = createVectorExpression(cl, childExpr.subList(0, 1), Mode.PROJECTION, colTypeInfo); ((ILongInExpr) expr).setInListValues(inVals); } else if (colType.equals("string")) { cl = (mode == Mode.FILTER ? FilterStringColumnInList.class : StringColumnInList.class); @@ -633,7 +717,7 @@ private VectorExpression getInExpression(List childExpr, Mode mode for (int i = 0; i != inVals.length; i++) { inVals[i] = getStringScalarAsByteArray((ExprNodeConstantDesc) childrenForInList.get(i)); } - expr = createVectorExpression(cl, childExpr.subList(0, 1), Mode.PROJECTION); + expr = createVectorExpression(cl, childExpr.subList(0, 1), Mode.PROJECTION, colTypeInfo); ((IStringInExpr) expr).setInListValues(inVals); } else if (isFloatFamily(colType)) { cl = (mode == Mode.FILTER ? 
FilterDoubleColumnInList.class : DoubleColumnInList.class); @@ -641,7 +725,7 @@ private VectorExpression getInExpression(List childExpr, Mode mode for (int i = 0; i != inValsD.length; i++) { inValsD[i] = getNumericScalarAsDouble(childrenForInList.get(i)); } - expr = createVectorExpression(cl, childExpr.subList(0, 1), Mode.PROJECTION); + expr = createVectorExpression(cl, childExpr.subList(0, 1), Mode.PROJECTION, colTypeInfo); ((IDoubleInExpr) expr).setInListValues(inValsD); } @@ -664,7 +748,7 @@ private VectorExpression getInExpression(List childExpr, Mode mode * descriptor based lookup. */ private VectorExpression getGenericUDFBridgeVectorExpression(GenericUDFBridge udf, - List childExpr, Mode mode) throws HiveException { + List childExpr, Mode mode, TypeInfo returnType) throws HiveException { Class cl = udf.getUdfClass(); if (isCastToIntFamily(cl)) { return getCastToLongExpression(childExpr); @@ -674,18 +758,93 @@ private VectorExpression getGenericUDFBridgeVectorExpression(GenericUDFBridge ud return getCastToDoubleExpression(cl, childExpr); } else if (cl.equals(UDFToString.class)) { return getCastToString(childExpr); + } else if (cl.equals(GenericUDFToDecimal.class)) { + return getCastToDecimal(childExpr, returnType); } return null; } + private VectorExpression getCastToDecimal(List childExpr, TypeInfo returnType) + throws HiveException { + String inputType = childExpr.get(0).getTypeString(); + DecimalTypeInfo dtInfo = (DecimalTypeInfo) returnType; + if (isIntFamily(inputType)) { + return createVectorExpression(CastLongToDecimal.class, childExpr, Mode.PROJECTION, returnType); + } else if (isFloatFamily(inputType)) { + return createVectorExpression(CastDoubleToDecimal.class, childExpr, Mode.PROJECTION, returnType); + } else if (decimalTypePattern.matcher(inputType).matches()) { + return createVectorExpression(CastDecimalToDecimal.class, childExpr, Mode.PROJECTION, returnType); + } + throw new HiveException("Unhandled cast input type: " + inputType); + } + + 
private VectorExpression getImplicitCastToDecimal(int inputColumnIndex, TypeInfo inputType) + throws HiveException { + VectorExpression castExpression = null; + int precision = HiveDecimalUtils.getPrecisionForType((PrimitiveTypeInfo) inputType); + int scale = HiveDecimalUtils.getScaleForType((PrimitiveTypeInfo) inputType); + String outDType = DecimalTypeInfo.getQualifiedName(precision, scale); + int outputColumn = ocm.allocateOutputColumn(outDType); + String inputTypeName = inputType.getTypeName(); + if (isIntFamily(inputTypeName)) { + castExpression = new CastLongToDecimal(inputColumnIndex, outputColumn); + } else if (isFloatFamily(inputTypeName)) { + castExpression = new CastDoubleToDecimal(inputColumnIndex, outputColumn); + } + if (castExpression == null) { + throw new HiveException("Cannot cast type "+inputTypeName+" to decimal"); + } + return castExpression; + } + + private Decimal128 castConstantToDecimal(Object scalar, TypeInfo type) throws HiveException { + PrimitiveTypeInfo ptinfo = (PrimitiveTypeInfo) type; + String typename = type.getTypeName(); + Decimal128 d = new Decimal128(); + int scale = HiveDecimalUtils.getScaleForType(ptinfo); + switch (ptinfo.getPrimitiveCategory()) { + case FLOAT: + float floatVal = ((Float) scalar).floatValue(); + d.update(floatVal, (short) scale); + break; + case DOUBLE: + double doubleVal = ((Double) scalar).doubleValue(); + d.update(doubleVal, (short) scale); + break; + case BYTE: + byte byteVal = ((Byte) scalar).byteValue(); + d.update(byteVal, (short) scale); + break; + case SHORT: + short shortVal = ((Short) scalar).shortValue(); + d.update(shortVal, (short) scale); + break; + case INT: + int intVal = ((Integer) scalar).intValue(); + d.update(intVal, (short) scale); + break; + case LONG: + long longVal = ((Long) scalar).longValue(); + d.update(longVal, (short) scale); + break; + case DECIMAL: + HiveDecimal decimalVal = (HiveDecimal) scalar; + d.update(decimalVal.unscaledValue(), (short) scale); + break; + default: + 
throw new HiveException("Unsupported type "+typename+" for cast to Decimal128"); + } + return d; + } + private VectorExpression getCastToString(List childExpr) throws HiveException { String inputType = childExpr.get(0).getTypeString(); if (inputType.equals("boolean")) { // Boolean must come before the integer family. It's a special case. - return createVectorExpression(CastBooleanToStringViaLongToString.class, childExpr, Mode.PROJECTION); + return createVectorExpression(CastBooleanToStringViaLongToString.class, childExpr, Mode.PROJECTION, null); } else if (isIntFamily(inputType)) { - return createVectorExpression(CastLongToString.class, childExpr, Mode.PROJECTION); + return createVectorExpression(CastLongToString.class, childExpr, Mode.PROJECTION, null); } /* The string type is deliberately omitted -- the planner removes string to string casts. * Timestamp, float, and double types are handled by the legacy code path. See isLegacyPathUDF. @@ -698,9 +857,9 @@ private VectorExpression getCastToDoubleExpression(Class udf, List childExpr) if (inputType.equals("string")) { // string casts to false if it is 0 characters long, otherwise true VectorExpression lenExpr = createVectorExpression(StringLength.class, childExpr, - Mode.PROJECTION); + Mode.PROJECTION, null); int outputCol = ocm.allocateOutputColumn("integer"); VectorExpression lenToBoolExpr = @@ -804,7 +963,7 @@ private VectorExpression getBetweenFilterExpression(List childExpr } } - return createVectorExpression(cl, childrenAfterNot, Mode.PROJECTION); + return createVectorExpression(cl, childrenAfterNot, Mode.PROJECTION, null); } /* @@ -855,6 +1014,7 @@ private VectorExpression getCustomUDFExpression(ExprNodeGenericFuncDesc expr) int outputCol = -1; String resultType = expr.getTypeInfo().getTypeName(); String resultColVectorType = getNormalizedTypeName(resultType); + outputCol = ocm.allocateOutputColumn(resultColVectorType); // Make vectorized operator @@ -901,21 +1061,6 @@ public static boolean isIntFamily(String 
resultType) { || resultType.equalsIgnoreCase("long"); } - public static String mapJavaTypeToVectorType(String javaType) - throws HiveException { - if (isStringFamily(javaType)) { - return "string"; - } - if (isFloatFamily(javaType)) { - return "double"; - } - if (isIntFamily(javaType) || - isDatetimeFamily(javaType)) { - return "bigint"; - } - throw new HiveException("Unsuported type for vectorization: " + javaType); - } - private Object getScalarValue(ExprNodeConstantDesc constDesc) throws HiveException { if (constDesc.getTypeString().equalsIgnoreCase("String")) { @@ -931,6 +1076,11 @@ private Object getScalarValue(ExprNodeConstantDesc constDesc) } else { return 0; } + } else if (decimalTypePattern.matcher(constDesc.getTypeString()).matches()) { + HiveDecimal hd = (HiveDecimal) constDesc.getValue(); + Decimal128 dvalue = new Decimal128(); + dvalue.update(hd.unscaledValue(), (short) hd.scale()); + return dvalue; } else { return constDesc.getValue(); } @@ -1029,6 +1179,10 @@ static String getNormalizedTypeName(String colType) { normalizedType = "Double"; } else if (colType.equalsIgnoreCase("String")) { normalizedType = "String"; + } else if (decimalTypePattern.matcher(colType.toLowerCase()).matches()) { + + //Return the decimal type as is, it includes scale and precision. 
+ normalizedType = colType; } else { normalizedType = "Long"; } @@ -1110,16 +1264,11 @@ public VectorAggregateExpression getAggregatorExpression(AggregationDesc desc) "\" for type: \"" + inputType + ""); } - static Object[][] columnTypes = { - {"Double", DoubleColumnVector.class}, - {"Long", LongColumnVector.class}, - {"String", BytesColumnVector.class}, - }; - public Map getOutputColumnTypeMap() { Map map = new HashMap(); for (int i = 0; i < ocm.outputColCount; i++) { String type = ocm.outputColumnsTypes[i]; + System.out.println("index= "+(i+this.firstOutputColumnIndex)+", type="+type); map.put(i+this.firstOutputColumnIndex, type); } return map; @@ -1129,16 +1278,6 @@ public VectorAggregateExpression getAggregatorExpression(AggregationDesc desc) return columnMap; } - public static ColumnVector allocateColumnVector(String type, int defaultSize) { - if (isFloatFamily(type)) { - return new DoubleColumnVector(defaultSize); - } else if (isStringFamily(type)) { - return new BytesColumnVector(defaultSize); - } else { - return new LongColumnVector(defaultSize); - } - } - public void addToColumnMap(String columnName, int outputColumn) throws HiveException { if (columnMap.containsKey(columnName) && (columnMap.get(columnName) != outputColumn)) { throw new HiveException(String.format("Column %s is already mapped to %d. 
Cannot remap to %d.", @@ -1147,20 +1286,4 @@ public void addToColumnMap(String columnName, int outputColumn) throws HiveExcep columnMap.put(columnName, outputColumn); } - public Map getMapVectorExpressions( - Map> expressions) throws HiveException { - Map result = new HashMap(); - if (null != expressions) { - for(T key: expressions.keySet()) { - result.put(key, getVectorExpressions(expressions.get(key))); - } - } - return result; - } - - public void addOutputColumn(String columnName, String columnType) throws HiveException { - String vectorType = mapJavaTypeToVectorType(columnType); - int columnIndex = ocm.allocateOutputColumn(vectorType); - this.addToColumnMap(columnName, columnIndex); - } } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedRowBatchCtx.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedRowBatchCtx.java index 6e79979..7f76daf 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedRowBatchCtx.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorizedRowBatchCtx.java @@ -24,6 +24,8 @@ import java.util.List; import java.util.Map; import java.util.Properties; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants; @@ -370,11 +372,26 @@ private void addScratchColumnsToBatch(VectorizedRowBatch vrb) { } } + private int[] getScalePrecisionFromDecimalType(String decimalType) { + System.out.println("decimal type = "+decimalType); + Pattern p = Pattern.compile("\\d+"); + Matcher m = p.matcher(decimalType); + m.find(); + int precision = Integer.parseInt(m.group()); + m.find(); + int scale = Integer.parseInt(m.group()); + int [] precScale = { precision, scale }; + return precScale; + } + private ColumnVector allocateColumnVector(String type, int defaultSize) { if (type.equalsIgnoreCase("double")) { return new DoubleColumnVector(defaultSize); } else if 
(type.equalsIgnoreCase("string")) { return new BytesColumnVector(defaultSize); + } else if (VectorizationContext.decimalTypePattern.matcher(type.toLowerCase()).matches()){ + int [] precisionScale = getScalePrecisionFromDecimalType(type); + return new DecimalColumnVector(defaultSize, precisionScale[0], precisionScale[1]); } else { return new LongColumnVector(defaultSize); } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpression.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpression.java index d00d99b..81eea7a 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpression.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/VectorExpression.java @@ -53,6 +53,14 @@ public abstract String getOutputType(); /** + * Set type of the output column. Expressions that allow + * setting the output type should override this. + */ + public void setOutputType(String type) { + // Do nothing by default. + } + + /** * Initialize the child expressions. 
*/ public void setChildExpressions(VectorExpression [] ve) { diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java index e694db1..d7b21d2 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/physical/Vectorizer.java @@ -122,45 +122,7 @@ import org.apache.hadoop.hive.ql.udf.UDFToString; import org.apache.hadoop.hive.ql.udf.UDFWeekOfYear; import org.apache.hadoop.hive.ql.udf.UDFYear; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFAbs; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBetween; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBridge; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFCase; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFCeil; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFConcat; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFFloor; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFIf; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFIn; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFLTrim; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFLower; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPAnd; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPDivide; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqual; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqualOrGreaterThan; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqualOrLessThan; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPGreaterThan; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPLessThan; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPMinus; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPMod; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPMultiply; 
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNegative; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNot; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNotEqual; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNotNull; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNull; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPOr; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPPositive; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFPower; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFRound; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPPlus; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFPosMod; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFRTrim; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFTimestamp; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToUnixTimeStamp; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFTrim; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUpper; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFWhen; +import org.apache.hadoop.hive.ql.udf.generic.*; public class Vectorizer implements PhysicalPlanResolver { @@ -277,6 +239,7 @@ public Vectorizer() { supportedGenericUDFs.add(UDFToDouble.class); supportedGenericUDFs.add(UDFToString.class); supportedGenericUDFs.add(GenericUDFTimestamp.class); + supportedGenericUDFs.add(GenericUDFToDecimal.class); // For conditional expressions supportedGenericUDFs.add(GenericUDFIf.class); @@ -712,6 +675,7 @@ private boolean validateExprNodeDesc(ExprNodeDesc desc) { } boolean validateExprNodeDesc(ExprNodeDesc desc, VectorExpressionDescriptor.Mode mode) { + System.out.println("Validating expr ="+desc.toString()); if (!validateExprNodeDescRecursive(desc)) { return false; } diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPMinus.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPMinus.java index 6ee6f39..3eb605a 
100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPMinus.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPMinus.java @@ -21,18 +21,7 @@ import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.vector.VectorizedExpressions; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DoubleColSubtractDoubleColumn; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DoubleColSubtractDoubleScalar; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DoubleColSubtractLongColumn; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DoubleColSubtractLongScalar; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DoubleScalarSubtractDoubleColumn; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DoubleScalarSubtractLongColumn; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColSubtractDoubleColumn; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColSubtractDoubleScalar; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColSubtractLongColumn; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColSubtractLongScalar; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongScalarSubtractDoubleColumn; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongScalarSubtractLongColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.*; import org.apache.hadoop.hive.serde2.io.ByteWritable; import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; @@ -49,7 +38,9 @@ LongColSubtractLongScalar.class, LongColSubtractDoubleScalar.class, DoubleColSubtractLongScalar.class, DoubleColSubtractDoubleScalar.class, LongScalarSubtractLongColumn.class, LongScalarSubtractDoubleColumn.class, - DoubleScalarSubtractLongColumn.class, 
DoubleScalarSubtractDoubleColumn.class}) + DoubleScalarSubtractLongColumn.class, DoubleScalarSubtractDoubleColumn.class, + DecimalColSubtractDecimalColumn.class, DecimalColSubtractDecimalScalar.class, + DecimalScalarSubtractDecimalColumn.class}) public class GenericUDFOPMinus extends GenericUDFBaseNumeric { public GenericUDFOPMinus() { diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPMultiply.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPMultiply.java index e7a2a8d..7dc1f83 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPMultiply.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPMultiply.java @@ -21,18 +21,7 @@ import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.vector.VectorizedExpressions; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DoubleColMultiplyDoubleColumn; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DoubleColMultiplyDoubleScalar; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DoubleColMultiplyLongColumn; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DoubleColMultiplyLongScalar; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DoubleScalarMultiplyDoubleColumn; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DoubleScalarMultiplyLongColumn; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColMultiplyDoubleColumn; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColMultiplyDoubleScalar; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColMultiplyLongColumn; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColMultiplyLongScalar; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongScalarMultiplyDoubleColumn; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongScalarMultiplyLongColumn; +import 
org.apache.hadoop.hive.ql.exec.vector.expressions.gen.*; import org.apache.hadoop.hive.serde2.io.ByteWritable; import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; @@ -49,7 +38,9 @@ LongColMultiplyLongScalar.class, LongColMultiplyDoubleScalar.class, DoubleColMultiplyLongScalar.class, DoubleColMultiplyDoubleScalar.class, LongScalarMultiplyLongColumn.class, LongScalarMultiplyDoubleColumn.class, - DoubleScalarMultiplyLongColumn.class, DoubleScalarMultiplyDoubleColumn.class}) + DoubleScalarMultiplyLongColumn.class, DoubleScalarMultiplyDoubleColumn.class, + DecimalColMultiplyDecimalColumn.class, DecimalColMultiplyDecimalScalar.class, + DecimalScalarMultiplyDecimalColumn.class}) public class GenericUDFOPMultiply extends GenericUDFBaseNumeric { public GenericUDFOPMultiply() { diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPPlus.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPPlus.java index 26ac65c..2721e6b 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPPlus.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPPlus.java @@ -21,18 +21,7 @@ import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.vector.VectorizedExpressions; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DoubleColAddDoubleColumn; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DoubleColAddDoubleScalar; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DoubleColAddLongColumn; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DoubleColAddLongScalar; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DoubleScalarAddDoubleColumn; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.DoubleScalarAddLongColumn; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColAddDoubleColumn; -import 
org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColAddDoubleScalar; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColAddLongColumn; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongColAddLongScalar; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongScalarAddDoubleColumn; -import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.LongScalarAddLongColumn; +import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.*; import org.apache.hadoop.hive.serde2.io.ByteWritable; import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; @@ -57,7 +46,8 @@ DoubleColAddLongColumn.class, DoubleColAddDoubleColumn.class, LongColAddLongScalar.class, LongColAddDoubleScalar.class, DoubleColAddLongScalar.class, DoubleColAddDoubleScalar.class, LongScalarAddLongColumn.class, LongScalarAddDoubleColumn.class, DoubleScalarAddLongColumn.class, - DoubleScalarAddDoubleColumn.class}) + DoubleScalarAddDoubleColumn.class, DecimalScalarAddDecimalColumn.class, DecimalColAddDecimalColumn.class, + DecimalColAddDecimalScalar.class}) public class GenericUDFOPPlus extends GenericUDFBaseNumeric { public GenericUDFOPPlus() { diff --git ql/src/test/queries/clientpositive/vector_decimal_expressions.q ql/src/test/queries/clientpositive/vector_decimal_expressions.q new file mode 100644 index 0000000..f3b4c83 --- /dev/null +++ ql/src/test/queries/clientpositive/vector_decimal_expressions.q @@ -0,0 +1,4 @@ +CREATE TABLE decimal_test STORED AS ORC AS SELECT cdouble, CAST (((cdouble*22.1)/37) AS DECIMAL(20,10)) AS cdecimal1, CAST (((cdouble*9.3)/13) AS DECIMAL(23,14)) AS cdecimal2 FROM alltypesorc; +SET hive.vectorized.execution.enabled=true; +EXPLAIN SELECT cdecimal1 + cdecimal2, cdecimal1 - (2*cdecimal2), ((cdecimal1+2.34)/cdecimal2), (cdecimal1 * (cdecimal2/3.4)) from decimal_test where cdecimal1 > 0 AND cdecimal1 < 12345.5678 AND cdecimal2 != 0 AND cdouble IS NOT NULL LIMIT 10; +SELECT 
cdecimal1 + cdecimal2, cdecimal1 - (2*cdecimal2), ((cdecimal1+2.34)/cdecimal2), (cdecimal1 * (cdecimal2/3.4)) from decimal_test where cdecimal1 > 0 AND cdecimal1 < 12345.5678 AND cdecimal2 != 0 AND cdouble IS NOT NULL LIMIT 10;