diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/translator/RelNodeConverter.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/translator/RelNodeConverter.java
index cf65e10..a731f0b 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/translator/RelNodeConverter.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/translator/RelNodeConverter.java
@@ -87,7 +87,6 @@
       .put("count", (Aggregation) SqlStdOperatorTable.COUNT)
       .put("sum", SqlStdOperatorTable.SUM).put("min", SqlStdOperatorTable.MIN)
       .put("max", SqlStdOperatorTable.MAX).put("avg", SqlStdOperatorTable.AVG)
-      .put("stddev_samp", SqlFunctionConverter.hiveAggFunction("stddev_samp"))
       .build();

   public static RelNode convert(Operator sinkOp, RelOptCluster cluster,
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/translator/RexNodeConverter.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/translator/RexNodeConverter.java
index 91708ce..507095b 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/translator/RexNodeConverter.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/translator/RexNodeConverter.java
@@ -18,6 +18,7 @@
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBaseCompare;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBaseNumeric;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBridge;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqual;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToBinary;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToChar;
@@ -37,8 +38,11 @@
 import org.eigenbase.rex.RexCall;
 import org.eigenbase.rex.RexNode;
 import org.eigenbase.sql.SqlOperator;
+import org.eigenbase.sql.fun.SqlCastFunction;
+import org.eigenbase.sql.type.SqlTypeName;

 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableList.Builder;
 import com.google.common.collect.ImmutableMap;

 public class RexNodeConverter {
@@ -49,9 +53,8 @@
     private final RowResolver m_hiveRR;
     private final int m_offsetInOptiqSchema;

-    private InputCtx(RelDataType optiqInpDataType,
-        ImmutableMap<String, Integer> hiveNameToPosMap, RowResolver hiveRR,
-        int offsetInOptiqSchema) {
+    private InputCtx(RelDataType optiqInpDataType, ImmutableMap<String, Integer> hiveNameToPosMap,
+        RowResolver hiveRR, int offsetInOptiqSchema) {
       m_optiqInpDataType = optiqInpDataType;
       m_hiveNameToPosMap = hiveNameToPosMap;
       m_hiveRR = hiveRR;
@@ -64,16 +67,13 @@ private InputCtx(RelDataType optiqInpDataType,
   private final boolean m_flattenExpr;

   public RexNodeConverter(RelOptCluster cluster, RelDataType inpDataType,
-      ImmutableMap<String, Integer> nameToPosMap, int offset,
-      boolean flattenExpr) {
+      ImmutableMap<String, Integer> nameToPosMap, int offset, boolean flattenExpr) {
     this.m_cluster = cluster;
-    m_inputCtxs = ImmutableList.of(new InputCtx(inpDataType, nameToPosMap,
-        null, offset));
+    m_inputCtxs = ImmutableList.of(new InputCtx(inpDataType, nameToPosMap, null, offset));
     m_flattenExpr = flattenExpr;
   }

-  public RexNodeConverter(RelOptCluster cluster, List<InputCtx> inpCtxLst,
-      boolean flattenExpr) {
+  public RexNodeConverter(RelOptCluster cluster, List<InputCtx> inpCtxLst, boolean flattenExpr) {
     this.m_cluster = cluster;
     m_inputCtxs = ImmutableList.<InputCtx> builder().addAll(inpCtxLst).build();
     m_flattenExpr = flattenExpr;
@@ -93,72 +93,93 @@ public RexNode convert(ExprNodeDesc expr) throws SemanticException {
     // ExprNodeColumnListDesc
   }

-  private RexNode convert(final ExprNodeGenericFuncDesc func)
-      throws SemanticException {
-
+  private RexNode convert(final ExprNodeGenericFuncDesc func) throws SemanticException {
     ExprNodeDesc tmpExprNode;
     RexNode tmpRN;
     TypeInfo tgtDT = null;
-    SqlOperator optiqOp = SqlFunctionConverter.getOptiqOperator(func
-        .getGenericUDF());
     List<RexNode> childRexNodeLst = new LinkedList<RexNode>();
+    Builder<RelDataType> argTypeBldr = ImmutableList.<RelDataType> builder();

-    // TODO: 1) Expand to other functions as needed 2) What about types other
+    // TODO: 1) Expand to other functions as needed 2) What about types
+    // other
     // than primitive
     if (func.getGenericUDF() instanceof GenericUDFBaseNumeric) {
       tgtDT = func.getTypeInfo();
     } else if (func.getGenericUDF() instanceof GenericUDFBaseCompare) {
       if (func.getChildren().size() == 2) {
-        tgtDT = FunctionRegistry.getCommonClassForComparison(func.getChildren()
-            .get(0).getTypeInfo(), func.getChildren().get(1).getTypeInfo());
+        tgtDT = FunctionRegistry.getCommonClassForComparison(func.getChildren().get(0)
+            .getTypeInfo(), func.getChildren().get(1).getTypeInfo());
       }
     }

     for (ExprNodeDesc childExpr : func.getChildren()) {
       tmpExprNode = childExpr;
       if (tgtDT != null
-          && TypeInfoUtils.isConversionRequiredForComparison(tgtDT,
-              childExpr.getTypeInfo())) {
-        tmpExprNode = ParseUtils.createConversionCast(childExpr,
-            (PrimitiveTypeInfo) tgtDT);
+          && TypeInfoUtils.isConversionRequiredForComparison(tgtDT, childExpr.getTypeInfo())) {
+        tmpExprNode = ParseUtils.createConversionCast(childExpr, (PrimitiveTypeInfo) tgtDT);
       }
+      argTypeBldr.add(TypeConverter.convert(tmpExprNode.getTypeInfo(), m_cluster.getTypeFactory()));
       tmpRN = convert(tmpExprNode);
       childRexNodeLst.add(tmpRN);
     }

-    RexNode expr = null;
-    // This is an explicit cast
+    RexNode expr = null;
     expr = handleExplicitCast(func, childRexNodeLst);

-    if (expr == null)
+    if (expr == null) {
+      // expr is known to be null here, so the return type always comes from
+      // the Hive function's TypeInfo
+      RelDataType retType = TypeConverter.convert(func.getTypeInfo(), m_cluster.getTypeFactory());
+      SqlOperator optiqOp = SqlFunctionConverter.getOptiqOperator(func.getGenericUDF(),
+          argTypeBldr.build(), retType);
       expr = m_cluster.getRexBuilder().makeCall(optiqOp, childRexNodeLst);
+    }

-    if (m_flattenExpr && expr instanceof RexCall) {
+    // TODO: the cast function in Optiq has a bug where type inference on a
+    // cast throws an exception, so do not flatten cast calls
+    if (m_flattenExpr && (expr instanceof RexCall)
+        && !(((RexCall) expr).getOperator() instanceof SqlCastFunction)) {
       RexCall call = (RexCall) expr;
-      expr = m_cluster.getRexBuilder().makeFlatCall(call.getOperator(),
-          call.getOperands());
+      expr = m_cluster.getRexBuilder().makeFlatCall(call.getOperator(), call.getOperands());
     }

     return expr;
   }

-  private RexNode handleExplicitCast(ExprNodeGenericFuncDesc func,
-      List<RexNode> childRexNodeLst) {
+  private boolean castExprUsingUDFBridge(GenericUDF gUDF) {
+    boolean castExpr = false;
+    if (gUDF != null && gUDF instanceof GenericUDFBridge) {
+      String udfClassName = ((GenericUDFBridge) gUDF).getUdfClassName();
+      if (udfClassName != null) {
+        int sp = udfClassName.lastIndexOf('.');
+        // TODO: add a method to GenericUDFBridge to say whether it is a cast func
+        if (sp >= 0 && (sp + 1) < udfClassName.length()) {
+          udfClassName = udfClassName.substring(sp + 1);
+          if (udfClassName.equals("UDFToBoolean") || udfClassName.equals("UDFToByte")
+              || udfClassName.equals("UDFToDouble") || udfClassName.equals("UDFToInteger")
+              || udfClassName.equals("UDFToLong") || udfClassName.equals("UDFToShort")
+              || udfClassName.equals("UDFToFloat") || udfClassName.equals("UDFToString"))
+            castExpr = true;
+        }
+      }
+    }
+
+    return castExpr;
+  }
+
+  private RexNode handleExplicitCast(ExprNodeGenericFuncDesc func, List<RexNode> childRexNodeLst) {
     RexNode castExpr = null;

     if (childRexNodeLst != null && childRexNodeLst.size() == 1) {
       GenericUDF udf = func.getGenericUDF();
-      if ((udf instanceof GenericUDFToChar)
-          || (udf instanceof GenericUDFToVarchar)
-          || (udf instanceof GenericUDFToDecimal)
-          || (udf instanceof GenericUDFToDate)
-          || (udf instanceof GenericUDFToBinary)
-          || (udf instanceof GenericUDFToUnixTimeStamp)) {
+      if ((udf instanceof GenericUDFToChar) || (udf instanceof GenericUDFToVarchar)
+          || (udf instanceof GenericUDFToDecimal) || (udf instanceof GenericUDFToDate)
+          || (udf instanceof GenericUDFToBinary) || (udf instanceof GenericUDFToUnixTimeStamp)
+          || castExprUsingUDFBridge(udf)) {
         castExpr = m_cluster.getRexBuilder().makeCast(
-            TypeConverter.convert(func.getTypeInfo(),
-                m_cluster.getTypeFactory()), childRexNodeLst.get(0));
+            TypeConverter.convert(func.getTypeInfo(), m_cluster.getTypeFactory()),
+            childRexNodeLst.get(0));
       }
     }
@@ -194,8 +215,7 @@ protected RexNode convert(ExprNodeColumnDesc col) throws SemanticException {
     InputCtx ic = getInputCtx(col);
     int pos = ic.m_hiveNameToPosMap.get(col.getColumn());
     return m_cluster.getRexBuilder().makeInputRef(
-        ic.m_optiqInpDataType.getFieldList().get(pos).getType(),
-        pos + ic.m_offsetInOptiqSchema);
+        ic.m_optiqInpDataType.getFieldList().get(pos).getType(), pos + ic.m_offsetInOptiqSchema);
   }

   protected RexNode convert(ExprNodeConstantDesc literal) {
@@ -220,8 +240,7 @@ protected RexNode convert(ExprNodeConstantDesc literal) {
       optiqLiteral = rexBuilder.makeExactLiteral(new BigDecimal((Short) value));
       break;
     case INT:
-      optiqLiteral = rexBuilder
-          .makeExactLiteral(new BigDecimal((Integer) value));
+      optiqLiteral = rexBuilder.makeExactLiteral(new BigDecimal((Integer) value));
       break;
     case LONG:
       optiqLiteral = rexBuilder.makeBigintLiteral(new BigDecimal((Long) value));
@@ -231,12 +250,10 @@ protected RexNode convert(ExprNodeConstantDesc literal) {
       optiqLiteral = rexBuilder.makeExactLiteral((BigDecimal) value);
       break;
     case FLOAT:
-      optiqLiteral = rexBuilder.makeApproxLiteral(
-          new BigDecimal((Float) value), optiqDataType);
+      optiqLiteral = rexBuilder.makeApproxLiteral(new BigDecimal((Float) value), optiqDataType);
       break;
     case DOUBLE:
-      optiqLiteral = rexBuilder.makeApproxLiteral(
-          new BigDecimal((Double) value), optiqDataType);
+      optiqLiteral = rexBuilder.makeApproxLiteral(new BigDecimal((Double) value), optiqDataType);
       break;
     case STRING:
       optiqLiteral = rexBuilder.makeLiteral((String) value);
@@ -254,8 +271,9 @@ protected RexNode convert(ExprNodeConstantDesc literal) {
   }

   public static RexNode getAlwaysTruePredicate(RelOptCluster cluster) {
-    SqlOperator optiqOp = SqlFunctionConverter
-        .getOptiqOperator(new GenericUDFOPEqual());
+    RelDataType dt = cluster.getTypeFactory().createSqlType(SqlTypeName.BOOLEAN);
+    SqlOperator optiqOp = SqlFunctionConverter.getOptiqOperator(new GenericUDFOPEqual(),
+        ImmutableList.<RelDataType> of(dt), dt);
     List<RexNode> childRexNodeLst = new LinkedList<RexNode>();
     childRexNodeLst.add(cluster.getRexBuilder().makeLiteral(true));
     childRexNodeLst.add(cluster.getRexBuilder().makeLiteral(true));
@@ -263,21 +281,19 @@ public static RexNode getAlwaysTruePredicate(RelOptCluster cluster) {
     return cluster.getRexBuilder().makeCall(optiqOp, childRexNodeLst);
   }

-  public static RexNode convert(RelOptCluster cluster,
-      ExprNodeDesc joinCondnExprNode, List<RelNode> inputRels,
-      LinkedHashMap<RelNode, RowResolver> relToHiveRR,
-      Map<RelNode, ImmutableMap<String, Integer>> relToHiveColNameOptiqPosMap,
-      boolean flattenExpr) throws SemanticException {
+  public static RexNode convert(RelOptCluster cluster, ExprNodeDesc joinCondnExprNode,
+      List<RelNode> inputRels, LinkedHashMap<RelNode, RowResolver> relToHiveRR,
+      Map<RelNode, ImmutableMap<String, Integer>> relToHiveColNameOptiqPosMap, boolean flattenExpr)
+      throws SemanticException {
     List<InputCtx> inputCtxLst = new ArrayList<InputCtx>();

     int offSet = 0;
     for (RelNode r : inputRels) {
-      inputCtxLst.add(new InputCtx(r.getRowType(), relToHiveColNameOptiqPosMap
-          .get(r), relToHiveRR.get(r), offSet));
+      inputCtxLst.add(new InputCtx(r.getRowType(), relToHiveColNameOptiqPosMap.get(r), relToHiveRR
+          .get(r), offSet));
       offSet += r.getRowType().getFieldCount();
     }

-    return (new RexNodeConverter(cluster, inputCtxLst, flattenExpr))
-        .convert(joinCondnExprNode);
+    return (new RexNodeConverter(cluster, inputCtxLst, flattenExpr)).convert(joinCondnExprNode);
   }
 }
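Reviewer note (illustrative, not part of the patch): the operator lookup no longer consults a fixed name-to-operator map; the operand types and return type now travel with the lookup, so any UDF can be resolved. A minimal Java sketch of the new contract, assuming a RelDataTypeFactory named typeFactory is in scope:

    // Mirrors the call RexNodeConverter.convert(ExprNodeGenericFuncDesc)
    // now makes; "typeFactory" is an assumed in-scope RelDataTypeFactory.
    RelDataType intType = typeFactory.createSqlType(SqlTypeName.INTEGER);
    RelDataType boolType = typeFactory.createSqlType(SqlTypeName.BOOLEAN);
    // Resolve Hive's "=" with (INTEGER, INTEGER) -> BOOLEAN:
    SqlOperator eq = SqlFunctionConverter.getOptiqOperator(new GenericUDFOPEqual(),
        ImmutableList.<RelDataType> of(intType, intType), boolType);

Previously getOptiqOperator(udf) returned null for any UDF missing from the static map; with this change the types are enough to synthesize an operator on the fly.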
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/translator/SqlFunctionConverter.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/translator/SqlFunctionConverter.java
index 84540c4..15ebdc7 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/translator/SqlFunctionConverter.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/translator/SqlFunctionConverter.java
@@ -10,7 +10,6 @@
 import org.apache.hadoop.hive.ql.parse.ParseDriver;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBridge;
-import org.apache.hadoop.hive.serde.serdeConstants;
 import org.eigenbase.reltype.RelDataType;
 import org.eigenbase.reltype.RelDataTypeFactory;
 import org.eigenbase.sql.SqlAggFunction;
@@ -19,31 +18,35 @@
 import org.eigenbase.sql.SqlKind;
 import org.eigenbase.sql.SqlOperator;
 import org.eigenbase.sql.fun.SqlStdOperatorTable;
+import org.eigenbase.sql.type.InferTypes;
 import org.eigenbase.sql.type.OperandTypes;
 import org.eigenbase.sql.type.ReturnTypes;
+import org.eigenbase.sql.type.SqlOperandTypeChecker;
+import org.eigenbase.sql.type.SqlOperandTypeInference;
 import org.eigenbase.sql.type.SqlReturnTypeInference;
-import org.eigenbase.sql.type.SqlTypeName;
+import org.eigenbase.sql.type.SqlTypeFamily;
+import org.eigenbase.util.Util;

 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Maps;

 public class SqlFunctionConverter {
-  static final Map<String, SqlOperator> operatorMap;
   static final Map<String, SqlOperator> hiveToOptiq;
   static final Map<SqlOperator, HiveToken> optiqToHiveToken;

   static {
     Builder builder = new Builder();
-    operatorMap = ImmutableMap.copyOf(builder.operatorMap);
-    hiveToOptiq = ImmutableMap.copyOf(builder.hiveToOptiq);
-    optiqToHiveToken = ImmutableMap.copyOf(builder.optiqToHiveToken);
+    hiveToOptiq = builder.hiveToOptiq;
+    optiqToHiveToken = builder.optiqToHiveToken;
   }

-  public static SqlOperator getOptiqOperator(GenericUDF hiveUDF) {
-    return hiveToOptiq.get(getName(hiveUDF));
+  public static SqlOperator getOptiqOperator(GenericUDF hiveUDF,
+      ImmutableList<RelDataType> optiqArgTypes, RelDataType retType) {
+    return getOptiqFn(getName(hiveUDF), optiqArgTypes, retType);
   }

+  // TODO: 1) handle agg func name translation 2) is it correct to add func
+  // args as children of the func?
   public static ASTNode buildAST(SqlOperator op, List<ASTNode> children) {
     HiveToken hToken = optiqToHiveToken.get(op);
     ASTNode node;
@@ -52,8 +55,7 @@ public static ASTNode buildAST(SqlOperator op, List<ASTNode> children) {
     } else {
       node = (ASTNode) ParseDriver.adaptor.create(HiveParser.TOK_FUNCTION, "TOK_FUNCTION");
       if (op.kind != SqlKind.CAST)
-        node.addChild((ASTNode) ParseDriver.adaptor.create(
-            HiveParser.Identifier, op.getName()));
+        node.addChild((ASTNode) ParseDriver.adaptor.create(HiveParser.Identifier, op.getName()));
     }

     for (ASTNode c : children) {
@@ -89,135 +91,28 @@ private static String getName(GenericUDF hiveUDF) {
   }

   private static class Builder {
-    final Map<String, SqlOperator> operatorMap = Maps.newHashMap();
     final Map<String, SqlOperator> hiveToOptiq = Maps.newHashMap();
     final Map<SqlOperator, HiveToken> optiqToHiveToken = Maps.newHashMap();

     Builder() {
-      registerFunction("concat", SqlStdOperatorTable.CONCAT, null);
-      registerFunction("substr", SqlStdOperatorTable.SUBSTRING, null);
-      registerFunction("substring", SqlStdOperatorTable.SUBSTRING, null);
-      stringFunction("space");
-      stringFunction("repeat");
-      numericFunction("ascii");
-      stringFunction("repeat");
-
-      numericFunction("size");
-
-      numericFunction("round");
-      registerFunction("floor", SqlStdOperatorTable.FLOOR, null);
-      registerFunction("sqrt", SqlStdOperatorTable.SQRT, null);
-      registerFunction("ceil", SqlStdOperatorTable.CEIL, null);
-      registerFunction("ceiling", SqlStdOperatorTable.CEIL, null);
-      numericFunction("rand");
-      operatorMap.put("abs", SqlStdOperatorTable.ABS);
-      numericFunction("pmod");
-
-      numericFunction("ln");
-      numericFunction("log2");
-      numericFunction("sin");
-      numericFunction("asin");
-      numericFunction("cos");
-      numericFunction("acos");
-      registerFunction("log10", SqlStdOperatorTable.LOG10, null);
-      numericFunction("log");
-      numericFunction("exp");
-      numericFunction("power");
-      numericFunction("pow");
-      numericFunction("sign");
-      numericFunction("pi");
-      numericFunction("degrees");
-      numericFunction("atan");
-      numericFunction("tan");
-      numericFunction("e");
-
-      registerFunction("upper", SqlStdOperatorTable.UPPER, null);
-      registerFunction("lower", SqlStdOperatorTable.LOWER, null);
-      registerFunction("ucase", SqlStdOperatorTable.UPPER, null);
-      registerFunction("lcase", SqlStdOperatorTable.LOWER, null);
-      registerFunction("trim", SqlStdOperatorTable.TRIM, null);
-      stringFunction("ltrim");
-      stringFunction("rtrim");
-      numericFunction("length");
-
-      stringFunction("like");
-      stringFunction("rlike");
-      stringFunction("regexp");
-      stringFunction("regexp_replace");
-
-      stringFunction("regexp_extract");
-      stringFunction("parse_url");
-
-      numericFunction("day");
-      numericFunction("dayofmonth");
-      numericFunction("month");
-      numericFunction("year");
-      numericFunction("hour");
-      numericFunction("minute");
-      numericFunction("second");
-
       registerFunction("+", SqlStdOperatorTable.PLUS, hToken(HiveParser.PLUS, "+"));
       registerFunction("-", SqlStdOperatorTable.MINUS, hToken(HiveParser.MINUS, "-"));
       registerFunction("*", SqlStdOperatorTable.MULTIPLY, hToken(HiveParser.STAR, "*"));
       registerFunction("/", SqlStdOperatorTable.DIVIDE, hToken(HiveParser.STAR, "/"));
       registerFunction("%", SqlStdOperatorTable.MOD, hToken(HiveParser.STAR, "%"));
-      numericFunction("div");
-
-      numericFunction("isnull");
-      numericFunction("isnotnull");
-
-      numericFunction("if");
-      numericFunction("in");
       registerFunction("and", SqlStdOperatorTable.AND, hToken(HiveParser.KW_AND, "and"));
       registerFunction("or", SqlStdOperatorTable.OR, hToken(HiveParser.KW_OR, "or"));
       registerFunction("=", SqlStdOperatorTable.EQUALS, hToken(HiveParser.EQUAL, "="));
-      // numericFunction("==");
-      numericFunction("<=>");
-      numericFunction("!=");
-
-      numericFunction("<>");
       registerFunction("<", SqlStdOperatorTable.LESS_THAN, hToken(HiveParser.LESSTHAN, "<"));
       registerFunction("<=", SqlStdOperatorTable.LESS_THAN_OR_EQUAL,
           hToken(HiveParser.LESSTHANOREQUALTO, "<="));
       registerFunction(">", SqlStdOperatorTable.GREATER_THAN, hToken(HiveParser.GREATERTHAN, ">"));
       registerFunction(">=", SqlStdOperatorTable.GREATER_THAN_OR_EQUAL,
           hToken(HiveParser.GREATERTHANOREQUALTO, ">="));
-      numericFunction("not");
       registerFunction("!", SqlStdOperatorTable.NOT, hToken(HiveParser.KW_NOT, "not"));
-      numericFunction("between");
-
-      registerFunction("case", SqlStdOperatorTable.CASE, null);
-      numericFunction("when");
-
-      // implicit convert methods
-      numericFunction(serdeConstants.BOOLEAN_TYPE_NAME);
-      numericFunction(serdeConstants.TINYINT_TYPE_NAME);
-      numericFunction(serdeConstants.SMALLINT_TYPE_NAME);
-      numericFunction(serdeConstants.INT_TYPE_NAME);
-      numericFunction(serdeConstants.BIGINT_TYPE_NAME);
-      numericFunction(serdeConstants.FLOAT_TYPE_NAME);
-      numericFunction(serdeConstants.DOUBLE_TYPE_NAME);
-      stringFunction(serdeConstants.STRING_TYPE_NAME);
-    }
-
-    private void stringFunction(String name) {
-      registerFunction(name, SqlFunctionCategory.STRING, ReturnTypes.explicit(SqlTypeName.VARCHAR));
-    }
-
-    private void numericFunction(String name) {
-      registerFunction(name, SqlFunctionCategory.NUMERIC, ReturnTypes.explicit(SqlTypeName.DECIMAL));
-    }
-
-    private void registerFunction(String name, SqlFunctionCategory cat, SqlReturnTypeInference rti) {
-      SqlOperator optiqFn = new SqlFunction(name.toUpperCase(), SqlKind.OTHER_FUNCTION, rti, null,
-          null, cat);
-      registerFunction(name, optiqFn, null);
     }

-    private void registerFunction(String name, SqlOperator optiqFn,
-        HiveToken hiveToken) {
-      operatorMap.put(name, optiqFn);
-
+    private void registerFunction(String name, SqlOperator optiqFn, HiveToken hiveToken) {
       FunctionInfo hFn = FunctionRegistry.getFunctionInfo(name);
       if (hFn != null) {
         String hFnName = getName(hFn.getGenericUDF());
@@ -234,30 +129,90 @@ private static HiveToken hToken(int type, String text) {
     return new HiveToken(type, text);
   }

-  public static SqlAggFunction hiveAggFunction(String name) {
-    return new HiveAggFunction(name);
+  public static class OptiqUDAF extends SqlAggFunction {
+    final ImmutableList<RelDataType> m_argTypes;
+    final RelDataType m_retType;
+
+    public OptiqUDAF(String opName, SqlReturnTypeInference returnTypeInference,
+        SqlOperandTypeInference operandTypeInference, SqlOperandTypeChecker operandTypeChecker,
+        ImmutableList<RelDataType> argTypes, RelDataType retType) {
+      super(opName, SqlKind.OTHER_FUNCTION, returnTypeInference, operandTypeInference,
+          operandTypeChecker, SqlFunctionCategory.USER_DEFINED_FUNCTION);
+      m_argTypes = argTypes;
+      m_retType = retType;
+    }
+
+    public List<RelDataType> getParameterTypes(final RelDataTypeFactory typeFactory) {
+      return m_argTypes;
+    }
+
+    public RelDataType getReturnType(final RelDataTypeFactory typeFactory) {
+      return m_retType;
+    }
   }

-  static class HiveAggFunction extends SqlAggFunction {
+  private static class OptiqUDFInfo {
+    private String m_udfName;
+    private SqlReturnTypeInference m_returnTypeInference;
+    private SqlOperandTypeInference m_operandTypeInference;
+    private SqlOperandTypeChecker m_operandTypeChecker;
+    private ImmutableList<RelDataType> m_argTypes;
+    private RelDataType m_retType;
+  }

-    public HiveAggFunction(String name) {
-      super(name, SqlKind.OTHER_FUNCTION, ReturnTypes.BIGINT, null,
-          OperandTypes.ANY, SqlFunctionCategory.NUMERIC);
+  private static OptiqUDFInfo getUDFInfo(String hiveUdfName,
+      ImmutableList<RelDataType> optiqArgTypes, RelDataType optiqRetType) {
+    OptiqUDFInfo udfInfo = new OptiqUDFInfo();
+    udfInfo.m_udfName = hiveUdfName;
+    udfInfo.m_returnTypeInference = ReturnTypes.explicit(optiqRetType);
+    udfInfo.m_operandTypeInference = InferTypes.explicit(optiqArgTypes);
+    ImmutableList.Builder<SqlTypeFamily> typeFamilyBuilder = new ImmutableList.Builder<SqlTypeFamily>();
+    for (RelDataType at : optiqArgTypes) {
+      typeFamilyBuilder.add(Util.first(at.getSqlTypeName().getFamily(), SqlTypeFamily.ANY));
     }
+    udfInfo.m_operandTypeChecker = OperandTypes.family(typeFamilyBuilder.build());
+
+    udfInfo.m_argTypes = ImmutableList.<RelDataType> copyOf(optiqArgTypes);
+    udfInfo.m_retType = optiqRetType;
+
+    return udfInfo;
+  }

-    public List<RelDataType> getParameterTypes(RelDataTypeFactory typeFactory) {
-      return ImmutableList.of(typeFactory.createSqlType(SqlTypeName.ANY));
+  public static SqlOperator getOptiqFn(String hiveUdfName,
+      ImmutableList<RelDataType> optiqArgTypes, RelDataType optiqRetType) {
+    SqlOperator optiqOp = hiveToOptiq.get(hiveUdfName);
+    if (optiqOp == null) {
+      OptiqUDFInfo uInf = getUDFInfo(hiveUdfName, optiqArgTypes, optiqRetType);
+      optiqOp = new SqlFunction(uInf.m_udfName, SqlKind.OTHER_FUNCTION, uInf.m_returnTypeInference,
+          uInf.m_operandTypeInference, uInf.m_operandTypeChecker,
+          SqlFunctionCategory.USER_DEFINED_FUNCTION);
+      hiveToOptiq.put(hiveUdfName, optiqOp);
+      HiveToken ht = hToken(HiveParser.TOK_FUNCTION, "TOK_FUNCTION");
+      optiqToHiveToken.put(optiqOp, ht);
     }

-    public RelDataType getReturnType(RelDataTypeFactory typeFactory) {
-      return typeFactory.createSqlType(SqlTypeName.BIGINT);
+    return optiqOp;
+  }
+
+  public static SqlAggFunction getOptiqAggFn(String hiveUdfName,
+      ImmutableList<RelDataType> optiqArgTypes, RelDataType optiqRetType) {
+    SqlAggFunction optiqAggFn = (SqlAggFunction) hiveToOptiq.get(hiveUdfName);
+    if (optiqAggFn == null) {
+      OptiqUDFInfo uInf = getUDFInfo(hiveUdfName, optiqArgTypes, optiqRetType);
+
+      optiqAggFn = new OptiqUDAF(uInf.m_udfName, uInf.m_returnTypeInference,
+          uInf.m_operandTypeInference, uInf.m_operandTypeChecker, uInf.m_argTypes, uInf.m_retType);
+      hiveToOptiq.put(hiveUdfName, optiqAggFn);
+      HiveToken ht = hToken(HiveParser.TOK_FUNCTION, "TOK_FUNCTION");
+      optiqToHiveToken.put(optiqAggFn, ht);
     }
+    return optiqAggFn;
   }

   static class HiveToken {
-    int type;
-    String text;
+    int      type;
+    String   text;
     String[] args;

     HiveToken(int type, String text, String... args) {
@@ -265,5 +220,5 @@ public RelDataType getReturnType(RelDataTypeFactory typeFactory) {
       this.text = text;
       this.args = args;
     }
-  }
+  }
 }
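Reviewer note (illustrative, not part of the patch): getOptiqFn memoizes the synthesized SqlFunction by Hive UDF name, so repeated lookups return the same operator instance. A sketch, where "my_udf" is a hypothetical function name and typeFactory an assumed in-scope RelDataTypeFactory:

    RelDataType dblType = typeFactory.createSqlType(SqlTypeName.DOUBLE);
    SqlOperator first = SqlFunctionConverter.getOptiqFn("my_udf",
        ImmutableList.<RelDataType> of(dblType), dblType);
    SqlOperator second = SqlFunctionConverter.getOptiqFn("my_udf",
        ImmutableList.<RelDataType> of(dblType), dblType);
    assert first == second; // cached in hiveToOptiq

One consequence worth flagging in review: the cache is keyed by name alone, so the first signature seen for a name wins, and getOptiqAggFn shares the same map, where its cast to SqlAggFunction would fail if the name was first registered as a scalar function.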
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java
index 03f134d..07bd138 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java
@@ -242,6 +242,7 @@
 import org.eigenbase.relopt.hep.HepPlanner;
 import org.eigenbase.relopt.hep.HepProgramBuilder;
 import org.eigenbase.reltype.RelDataType;
+import org.eigenbase.reltype.RelDataTypeFactory;
 import org.eigenbase.reltype.RelDataTypeField;
 import org.eigenbase.rex.RexBuilder;
 import org.eigenbase.rex.RexInputRef;
@@ -250,6 +251,7 @@
 import org.eigenbase.util.CompositeList;

 import com.google.common.base.Function;
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Lists;

@@ -12170,29 +12172,6 @@ private RelNode genFilterLogicalPlan(QB qb, RelNode srcRel)
     return filterRel;
   }

-  private final Map<String, Aggregation> AGG_MAP = ImmutableMap
-      .<String, Aggregation> builder()
-      .put(
-          "count",
-          (Aggregation) SqlStdOperatorTable.COUNT)
-      .put(
-          "sum",
-          SqlStdOperatorTable.SUM)
-      .put(
-          "min",
-          SqlStdOperatorTable.MIN)
-      .put(
-          "max",
-          SqlStdOperatorTable.MAX)
-      .put(
-          "avg",
-          SqlStdOperatorTable.AVG)
-      .put(
-          "stddev_samp",
-          SqlFunctionConverter
-              .hiveAggFunction("stddev_samp"))
-      .build();
-
   /**
    * Class to store GenericUDAF related information.
    */
@@ -12211,23 +12190,22 @@ private AggInfo(List<ExprNodeDesc> aggParams, TypeInfo returnType,
     }
   }

-  private AggregateCall convertAgg(AggInfo agg, RelNode input,
-      List<RexNode> gbChildProjLst, RexNodeConverter converter,
-      HashMap<String, Integer> rexNodeToPosMap, Integer childProjLstIndx)
-      throws SemanticException {
-    final Aggregation aggregation = AGG_MAP.get(agg.m_udfName);
-    if (aggregation == null) {
-      throw new AssertionError("agg not found: " + agg.m_udfName);
-    }
+  private AggregateCall convertGBAgg(AggInfo agg, RelNode input, List<RexNode> gbChildProjLst,
+      RexNodeConverter converter, HashMap<String, Integer> rexNodeToPosMap,
+      Integer childProjLstIndx) throws SemanticException {

-    List<Integer> argList = new ArrayList<Integer>();
-    RelDataType type = TypeConverter.convert(agg.m_returnType,
+    // 1. Get agg fn ret type in Optiq
+    RelDataType aggFnRetType = TypeConverter.convert(agg.m_returnType,
         this.m_cluster.getTypeFactory());

+    // 2. Convert agg fn args, and the types of the args, to Optiq
     // TODO: Does HQL allow expressions as aggregate args or can it only be
     // projections from child?
     Integer inputIndx;
+    List<Integer> argList = new ArrayList<Integer>();
     RexNode rexNd = null;
+    RelDataTypeFactory dtFactory = this.m_cluster.getTypeFactory();
+    ImmutableList.Builder<RelDataType> aggArgRelDTBldr = new ImmutableList.Builder<RelDataType>();
     for (ExprNodeDesc expr : agg.m_aggParams) {
       rexNd = converter.convert(expr);
       inputIndx = rexNodeToPosMap.get(rexNd.toString());
@@ -12238,17 +12216,17 @@
         childProjLstIndx++;
       }
       argList.add(inputIndx);
-    }

-    /*
-     * set the type to the first arg, if there is one; because the RTi set on
-     * the Aggregation call assumes this is the output type.
-     */
-    if (argList.size() > 0) {
-      RexNode rex = converter.convert(agg.m_aggParams.get(0));
-      type = rex.getType();
+      // TODO: does the arg need a type cast?
+      aggArgRelDTBldr.add(TypeConverter.convert(expr.getTypeInfo(), dtFactory));
     }
-    return new AggregateCall(aggregation, agg.m_distinct, argList, type, null);
+
+    // 3. Get the aggregation fn from Optiq, given the name, ret type and
+    // input arg types
+    final Aggregation aggregation = SqlFunctionConverter.getOptiqAggFn(agg.m_udfName,
+        aggArgRelDTBldr.build(), aggFnRetType);
+
+    return new AggregateCall(aggregation, agg.m_distinct, argList, aggFnRetType, null);
   }

   private RelNode genGBRelNode(List<ExprNodeDesc> gbExprs,
@@ -12276,7 +12254,7 @@ private RelNode genGBRelNode(List<ExprNodeDesc> gbExprs,
     List<AggregateCall> aggregateCalls = Lists.newArrayList();
     int i = aggInfoLst.size();
     for (AggInfo agg : aggInfoLst) {
-      aggregateCalls.add(convertAgg(agg, srcRel, gbChildProjLst, converter,
+      aggregateCalls.add(convertGBAgg(agg, srcRel, gbChildProjLst, converter,
           rexNodeToPosMap, gbChildProjLst.size()));
     }

@@ -12341,6 +12319,39 @@ private void addToGBExpr(RowResolver groupByOutputRowResolver,
         groupByOutputRowResolver);
   }

+  private AggInfo getHiveAggInfo(ASTNode aggAst, int aggFnLstArgIndx, RowResolver inputRR)
+      throws SemanticException {
+    AggInfo aInfo = null;
+
+    // 1. Convert UDAF params to ExprNodeDesc
+    ArrayList<ExprNodeDesc> aggParameters = new ArrayList<ExprNodeDesc>();
+    for (int i = 1; i <= aggFnLstArgIndx; i++) {
+      ASTNode paraExpr = (ASTNode) aggAst.getChild(i);
+      ExprNodeDesc paraExprNode = genExprNodeDesc(paraExpr, inputRR);
+      aggParameters.add(paraExprNode);
+    }
+
+    // 2. Determine the type of UDAF
+    // This is the GenericUDAF name
+    String aggName = unescapeIdentifier(aggAst.getChild(0).getText());
+    boolean isDistinct = aggAst.getType() == HiveParser.TOK_FUNCTIONDI;
+    boolean isAllColumns = aggAst.getType() == HiveParser.TOK_FUNCTIONSTAR;
+
+    // 3. Get the UDAF evaluator
+    Mode amode = groupByDescModeToUDAFMode(GroupByDesc.Mode.COMPLETE, isDistinct);
+    GenericUDAFEvaluator genericUDAFEvaluator = getGenericUDAFEvaluator(aggName, aggParameters,
+        aggAst, isDistinct, isAllColumns);
+    assert (genericUDAFEvaluator != null);
+
+    // 4. Get UDAF info using the UDAF evaluator
+    GenericUDAFInfo udaf = getGenericUDAFInfo(genericUDAFEvaluator, amode, aggParameters);
+
+    // 5. Construct AggInfo
+    aInfo = new AggInfo(aggParameters, udaf.returnType, aggName, isDistinct);
+
+    return aInfo;
+  }
+
   /**
    * Generate GB plan.
    *
@@ -12398,7 +12409,6 @@ private RelNode genGBLogicalPlan(QB qb, RelNode srcRel) throws SemanticException
       boolean isDistinct = value.getType() == HiveParser.TOK_FUNCTIONDI;
       boolean isAllColumns = value.getType() == HiveParser.TOK_FUNCTIONSTAR;
       if (isDistinct) {
-        // continue;
         numDistinctUDFs++;
       }
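Reviewer note (illustrative, not part of the patch): how convertGBAgg would now resolve a count(c_int) aggregate, assuming a RelDataTypeFactory named dtFactory is in scope. The AggregateCall constructor arguments mirror the call made above; the argument position 0 is hypothetical:

    RelDataType intType = dtFactory.createSqlType(SqlTypeName.INTEGER);
    RelDataType longType = dtFactory.createSqlType(SqlTypeName.BIGINT);
    // Resolve count(INTEGER) -> BIGINT via the new shared lookup:
    Aggregation countAgg = SqlFunctionConverter.getOptiqAggFn("count",
        ImmutableList.<RelDataType> of(intType), longType);
    // distinct = false; arg is column 0 of the child project
    AggregateCall call = new AggregateCall(countAgg, false,
        ImmutableList.of(0), longType, null);

This replaces the old fixed AGG_MAP plus the first-argument-type workaround, so UDAFs outside the handful in that map no longer hit the "agg not found" assertion.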
diff --git a/ql/src/test/queries/clientpositive/cbo_correctness.q b/ql/src/test/queries/clientpositive/cbo_correctness.q
index f10f302..65ed130 100644
--- a/ql/src/test/queries/clientpositive/cbo_correctness.q
+++ b/ql/src/test/queries/clientpositive/cbo_correctness.q
@@ -162,3 +162,13 @@ select * from (select key as a, c_int+1 as b, sum(c_int) as c from t1 where (t1.
 select * from (select key as a, c_int+1 as b, sum(c_int) as c from t1 where (t1.c_int + 1 >= 0) and (t1.c_int > 0 or t1.c_float >= 0) group by c_float, t1.c_int, key having t1.c_float > 0 and (c_int >=1 or c_float >= 1) and (c_int + c_float) >= 0 order by b % c asc, b desc limit 5) t1 left outer join (select key as p, c_int+1 as q, sum(c_int) as r from t2 where (t2.c_int + 1 >= 0) and (t2.c_int > 0 or t2.c_float >= 0) group by c_float, t2.c_int, key having t2.c_float > 0 and (c_int >=1 or c_float >= 1) and (c_int + c_float) >= 0 limit 5) t2 on t1.a=p left outer join t3 on t1.a=key where (b + t2.q >= 0) and (b > 0 or c_int >= 0) group by t3.c_int, c having t3.c_int > 0 and (c_int >=1 or c >= 1) and (c_int + c) >= 0 order by t3.c_int % c asc, t3.c_int desc limit 5;
+
+-- 8. Test UDAF
+select count(*), count(c_int), sum(c_int), avg(c_int), max(c_int), min(c_int) from t1;
+select * from (select count(*) as a, count(distinct c_int) as b, sum(c_int) as c, avg(c_int) as d, max(c_int) as e, min(c_int) as f from t1) t1;
+select f,a,e,b from (select count(*) as a, count(c_int) as b, sum(c_int) as c, avg(c_int) as d, max(c_int) as e, min(c_int) as f from t1) t1;
+select f,a,e,b from (select count(*) as a, count(distinct c_int) as b, sum(distinct c_int) as c, avg(distinct c_int) as d, max(distinct c_int) as e, min(distinct c_int) as f from t1) t1;
+select count(c_int) as a, avg(c_float), key from t1 group by key;
+select count(distinct c_int) as a, avg(c_float) from t1 group by c_float;
+select count(distinct c_int) as a, avg(c_float) from t1 group by c_int;
+select count(distinct c_int) as a, avg(c_float) from t1 group by c_float, c_int;
diff --git a/ql/src/test/results/clientpositive/cbo_correctness.q.out b/ql/src/test/results/clientpositive/cbo_correctness.q.out
index fed415b..49f2661 100644
--- a/ql/src/test/results/clientpositive/cbo_correctness.q.out
+++ b/ql/src/test/results/clientpositive/cbo_correctness.q.out
@@ -15695,3 +15695,84 @@
 POSTHOOK: Input: default@t3
 #### A masked pattern was here ####
 1 12 1 2
+PREHOOK: query: -- 8. Test UDAF
+select count(*), count(c_int), sum(c_int), avg(c_int), max(c_int), min(c_int) from t1
+PREHOOK: type: QUERY
+PREHOOK: Input: default@t1
+#### A masked pattern was here ####
+POSTHOOK: query: -- 8. Test UDAF
+select count(*), count(c_int), sum(c_int), avg(c_int), max(c_int), min(c_int) from t1
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@t1
+#### A masked pattern was here ####
+20 18 18 1.0 1 1
+PREHOOK: query: select * from (select count(*) as a, count(distinct c_int) as b, sum(c_int) as c, avg(c_int) as d, max(c_int) as e, min(c_int) as f from t1) t1
+PREHOOK: type: QUERY
+PREHOOK: Input: default@t1
+#### A masked pattern was here ####
+POSTHOOK: query: select * from (select count(*) as a, count(distinct c_int) as b, sum(c_int) as c, avg(c_int) as d, max(c_int) as e, min(c_int) as f from t1) t1
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@t1
+#### A masked pattern was here ####
+20 1 18 1.0 1 1
+PREHOOK: query: select f,a,e,b from (select count(*) as a, count(c_int) as b, sum(c_int) as c, avg(c_int) as d, max(c_int) as e, min(c_int) as f from t1) t1
+PREHOOK: type: QUERY
+PREHOOK: Input: default@t1
+#### A masked pattern was here ####
+POSTHOOK: query: select f,a,e,b from (select count(*) as a, count(c_int) as b, sum(c_int) as c, avg(c_int) as d, max(c_int) as e, min(c_int) as f from t1) t1
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@t1
+#### A masked pattern was here ####
+1 20 1 18
+PREHOOK: query: select f,a,e,b from (select count(*) as a, count(distinct c_int) as b, sum(distinct c_int) as c, avg(distinct c_int) as d, max(distinct c_int) as e, min(distinct c_int) as f from t1) t1
+PREHOOK: type: QUERY
+PREHOOK: Input: default@t1
+#### A masked pattern was here ####
+POSTHOOK: query: select f,a,e,b from (select count(*) as a, count(distinct c_int) as b, sum(distinct c_int) as c, avg(distinct c_int) as d, max(distinct c_int) as e, min(distinct c_int) as f from t1) t1
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@t1
+#### A masked pattern was here ####
+1 20 1 1
+PREHOOK: query: select count(c_int) as a, avg(c_float), key from t1 group by key
+PREHOOK: type: QUERY
+PREHOOK: Input: default@t1
+#### A masked pattern was here ####
+POSTHOOK: query: select count(c_int) as a, avg(c_float), key from t1 group by key
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@t1
+#### A masked pattern was here ####
+2 1.0 1
+2 1.0 1
+12 1.0 1
+2 1.0 1
+0 NULL null
+PREHOOK: query: select count(distinct c_int) as a, avg(c_float) from t1 group by c_float
+PREHOOK: type: QUERY
+PREHOOK: Input: default@t1
+#### A masked pattern was here ####
+POSTHOOK: query: select count(distinct c_int) as a, avg(c_float) from t1 group by c_float
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@t1
+#### A masked pattern was here ####
+0 NULL
+1 1.0
+PREHOOK: query: select count(distinct c_int) as a, avg(c_float) from t1 group by c_int
+PREHOOK: type: QUERY
+PREHOOK: Input: default@t1
+#### A masked pattern was here ####
+POSTHOOK: query: select count(distinct c_int) as a, avg(c_float) from t1 group by c_int
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@t1
+#### A masked pattern was here ####
+0 NULL
+1 1.0
+PREHOOK: query: select count(distinct c_int) as a, avg(c_float) from t1 group by c_float, c_int
+PREHOOK: type: QUERY
+PREHOOK: Input: default@t1
+#### A masked pattern was here ####
+POSTHOOK: query: select count(distinct c_int) as a, avg(c_float) from t1 group by c_float, c_int
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@t1
+#### A masked pattern was here ####
+0 NULL
+1 1.0
diff --git a/ql/src/test/results/clientpositive/tez/cbo_correctness.q.out b/ql/src/test/results/clientpositive/tez/cbo_correctness.q.out
index fed415b..49f2661 100644
--- a/ql/src/test/results/clientpositive/tez/cbo_correctness.q.out
+++ b/ql/src/test/results/clientpositive/tez/cbo_correctness.q.out
@@ -15695,3 +15695,84 @@
 POSTHOOK: Input: default@t3
 #### A masked pattern was here ####
 1 12 1 2
+PREHOOK: query: -- 8. Test UDAF
+select count(*), count(c_int), sum(c_int), avg(c_int), max(c_int), min(c_int) from t1
+PREHOOK: type: QUERY
+PREHOOK: Input: default@t1
+#### A masked pattern was here ####
+POSTHOOK: query: -- 8. Test UDAF
+select count(*), count(c_int), sum(c_int), avg(c_int), max(c_int), min(c_int) from t1
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@t1
+#### A masked pattern was here ####
+20 18 18 1.0 1 1
+PREHOOK: query: select * from (select count(*) as a, count(distinct c_int) as b, sum(c_int) as c, avg(c_int) as d, max(c_int) as e, min(c_int) as f from t1) t1
+PREHOOK: type: QUERY
+PREHOOK: Input: default@t1
+#### A masked pattern was here ####
+POSTHOOK: query: select * from (select count(*) as a, count(distinct c_int) as b, sum(c_int) as c, avg(c_int) as d, max(c_int) as e, min(c_int) as f from t1) t1
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@t1
+#### A masked pattern was here ####
+20 1 18 1.0 1 1
+PREHOOK: query: select f,a,e,b from (select count(*) as a, count(c_int) as b, sum(c_int) as c, avg(c_int) as d, max(c_int) as e, min(c_int) as f from t1) t1
+PREHOOK: type: QUERY
+PREHOOK: Input: default@t1
+#### A masked pattern was here ####
+POSTHOOK: query: select f,a,e,b from (select count(*) as a, count(c_int) as b, sum(c_int) as c, avg(c_int) as d, max(c_int) as e, min(c_int) as f from t1) t1
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@t1
+#### A masked pattern was here ####
+1 20 1 18
+PREHOOK: query: select f,a,e,b from (select count(*) as a, count(distinct c_int) as b, sum(distinct c_int) as c, avg(distinct c_int) as d, max(distinct c_int) as e, min(distinct c_int) as f from t1) t1
+PREHOOK: type: QUERY
+PREHOOK: Input: default@t1
+#### A masked pattern was here ####
+POSTHOOK: query: select f,a,e,b from (select count(*) as a, count(distinct c_int) as b, sum(distinct c_int) as c, avg(distinct c_int) as d, max(distinct c_int) as e, min(distinct c_int) as f from t1) t1
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@t1
+#### A masked pattern was here ####
+1 20 1 1
+PREHOOK: query: select count(c_int) as a, avg(c_float), key from t1 group by key
+PREHOOK: type: QUERY
+PREHOOK: Input: default@t1
+#### A masked pattern was here ####
+POSTHOOK: query: select count(c_int) as a, avg(c_float), key from t1 group by key
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@t1
+#### A masked pattern was here ####
+2 1.0 1
+2 1.0 1
+12 1.0 1
+2 1.0 1
+0 NULL null
+PREHOOK: query: select count(distinct c_int) as a, avg(c_float) from t1 group by c_float
+PREHOOK: type: QUERY
+PREHOOK: Input: default@t1
+#### A masked pattern was here ####
+POSTHOOK: query: select count(distinct c_int) as a, avg(c_float) from t1 group by c_float
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@t1
+#### A masked pattern was here ####
+0 NULL
+1 1.0
+PREHOOK: query: select count(distinct c_int) as a, avg(c_float) from t1 group by c_int
+PREHOOK: type: QUERY
+PREHOOK: Input: default@t1
+#### A masked pattern was here ####
+POSTHOOK: query: select count(distinct c_int) as a, avg(c_float) from t1 group by c_int
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@t1
+#### A masked pattern was here ####
+0 NULL
+1 1.0
+PREHOOK: query: select count(distinct c_int) as a, avg(c_float) from t1 group by c_float, c_int
+PREHOOK: type: QUERY
+PREHOOK: Input: default@t1
+#### A masked pattern was here ####
+POSTHOOK: query: select count(distinct c_int) as a, avg(c_float) from t1 group by c_float, c_int
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@t1
+#### A masked pattern was here ####
+0 NULL
+1 1.0