diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java
index b94f790..08e1136 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/FunctionRegistry.java
@@ -500,7 +500,7 @@ public static void registerGenericUDF(boolean isNative, String functionName,
       Class<? extends GenericUDF> genericUDFClass) {
     if (GenericUDF.class.isAssignableFrom(genericUDFClass)) {
       FunctionInfo fI = new FunctionInfo(isNative, functionName,
-          (GenericUDF) ReflectionUtils.newInstance(genericUDFClass, null));
+          ReflectionUtils.newInstance(genericUDFClass, null));
       mFunctions.put(functionName.toLowerCase(), fI);
       registerNativeStatus(fI);
     } else {
@@ -523,7 +523,7 @@ public static void registerGenericUDTF(boolean isNative, String functionName,
       Class<? extends GenericUDTF> genericUDTFClass) {
     if (GenericUDTF.class.isAssignableFrom(genericUDTFClass)) {
       FunctionInfo fI = new FunctionInfo(isNative, functionName,
-          (GenericUDTF) ReflectionUtils.newInstance(genericUDTFClass, null));
+          ReflectionUtils.newInstance(genericUDTFClass, null));
       mFunctions.put(functionName.toLowerCase(), fI);
       registerNativeStatus(fI);
     } else {
@@ -534,7 +534,7 @@ public static void registerGenericUDTF(boolean isNative, String functionName,
 
   private static FunctionInfo getFunctionInfoFromMetastore(String functionName) {
     FunctionInfo ret = null;
-    
+
     try {
       String dbName;
       String fName;
@@ -577,7 +577,7 @@ private static FunctionInfo getFunctionInfoFromMetastore(String functionName) {
       // Lookup of UDf class failed
       LOG.error("Unable to load UDF class: " + e);
     }
-    
+
     return ret;
   }
 
@@ -599,7 +599,7 @@ private static FunctionInfo getFunctionInfoFromMetastore(String functionName) {
     if (functionInfo != null) {
      loadFunctionResourcesIfNecessary(functionName, functionInfo);
     }
-    
+
     return functionInfo;
   }
 
@@ -1018,7 +1018,7 @@ public static PrimitiveCategory getCommonCategory(TypeInfo a, TypeInfo b) {
       // If either is not a numeric type, return null.
       return null;
     }
-    
+
     return (ai > bi) ? pcA : pcB;
   }
 
@@ -1223,7 +1223,7 @@ public static void registerUDAF(boolean isNative, String functionName,
       Class<? extends UDAF> udafClass) {
     FunctionInfo fi = new FunctionInfo(isNative, functionName.toLowerCase(),
         new GenericUDAFBridge(
-            (UDAF) ReflectionUtils.newInstance(udafClass, null)));
+            ReflectionUtils.newInstance(udafClass, null)));
     mFunctions.put(functionName.toLowerCase(), fi);
 
     // All aggregate functions should also be usable as window functions
@@ -1571,7 +1571,7 @@ public static GenericUDF cloneGenericUDF(GenericUDF genericUDF) {
       clonedUDF = new GenericUDFMacro(bridge.getMacroName(), bridge.getBody(),
           bridge.getColNames(), bridge.getColTypes());
     } else {
-      clonedUDF = (GenericUDF) ReflectionUtils
+      clonedUDF = ReflectionUtils
          .newInstance(genericUDF.getClass(), null);
     }
 
@@ -1610,7 +1610,7 @@ public static GenericUDTF cloneGenericUDTF(GenericUDTF genericUDTF) {
     if (null == genericUDTF) {
       return null;
     }
-    return (GenericUDTF) ReflectionUtils.newInstance(genericUDTF.getClass(),
+    return ReflectionUtils.newInstance(genericUDTF.getClass(),
         null);
   }
 
@@ -1735,7 +1735,7 @@ public static boolean isOpPositive(ExprNodeDesc desc) {
   /**
    * Returns whether the exprNodeDesc is node of "cast".
    */
-  private static boolean isOpCast(ExprNodeDesc desc) {
+  public static boolean isOpCast(ExprNodeDesc desc) {
     if (!(desc instanceof ExprNodeGenericFuncDesc)) {
       return false;
     }
@@ -1968,7 +1968,7 @@ public static TableFunctionResolver getWindowingTableFunction() {
     return getTableFunctionResolver(WINDOWING_TABLE_FUNCTION);
   }
 
-  
+
   public static boolean isNoopFunction(String fnName) {
     fnName = fnName.toLowerCase();
     return fnName.equals(NOOP_MAP_TABLE_FUNCTION) ||
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConstantPropagateProcFactory.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConstantPropagateProcFactory.java
index b1d2ba4..9ca194a 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConstantPropagateProcFactory.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConstantPropagateProcFactory.java
@@ -33,6 +33,7 @@
 import org.apache.hadoop.hive.ql.exec.ColumnInfo;
 import org.apache.hadoop.hive.ql.exec.FileSinkOperator;
 import org.apache.hadoop.hive.ql.exec.FilterOperator;
+import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
 import org.apache.hadoop.hive.ql.exec.GroupByOperator;
 import org.apache.hadoop.hive.ql.exec.JoinOperator;
 import org.apache.hadoop.hive.ql.exec.Operator;
@@ -58,7 +59,6 @@
 import org.apache.hadoop.hive.ql.plan.GroupByDesc;
 import org.apache.hadoop.hive.ql.plan.JoinCondDesc;
 import org.apache.hadoop.hive.ql.plan.JoinDesc;
-import org.apache.hadoop.hive.ql.plan.PlanUtils;
 import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
 import org.apache.hadoop.hive.ql.plan.TableScanDesc;
 import org.apache.hadoop.hive.ql.udf.UDFType;
@@ -317,15 +317,22 @@ private static void propagate(GenericUDF udf, List<ExprNodeDesc> newExprs, RowResolver rr) {
     if (udf instanceof GenericUDFOPEqual) {
       ExprNodeDesc lOperand = newExprs.get(0);
       ExprNodeDesc rOperand = newExprs.get(1);
-      ExprNodeColumnDesc c;
       ExprNodeConstantDesc v;
-      if (lOperand instanceof ExprNodeColumnDesc && rOperand instanceof ExprNodeConstantDesc) {
-        c = (ExprNodeColumnDesc) lOperand;
-        v = (ExprNodeConstantDesc) rOperand;
-      } else if (rOperand instanceof ExprNodeColumnDesc && lOperand instanceof ExprNodeConstantDesc) {
-        c = (ExprNodeColumnDesc) rOperand;
+      if (lOperand instanceof ExprNodeConstantDesc) {
         v = (ExprNodeConstantDesc) lOperand;
+      } else if (rOperand instanceof ExprNodeConstantDesc) {
+        v = (ExprNodeConstantDesc) rOperand;
       } else {
+        // we need a constant on one side.
+        return;
+      }
+      // If both sides are constants, there is nothing to propagate
+      ExprNodeColumnDesc c = getColumnExpr(lOperand);
+      if (null == c) {
+        c = getColumnExpr(rOperand);
+      }
+      if (null == c) {
+        // we need a column expression on other side.
         return;
       }
       ColumnInfo ci = resolveColumn(rr, c);
@@ -351,6 +358,16 @@ private static void propagate(GenericUDF udf, List<ExprNodeDesc> newExprs, RowResolver rr) {
     }
   }
 
+  private static ExprNodeColumnDesc getColumnExpr(ExprNodeDesc expr) {
+    if (expr instanceof ExprNodeColumnDesc) {
+      return (ExprNodeColumnDesc)expr;
+    }
+    if (FunctionRegistry.isOpCast(expr)) {
+      return (ExprNodeColumnDesc)expr.getChildren().get(0);
+    }
+    return null;
+  }
+
   private static ExprNodeDesc shortcutFunction(GenericUDF udf, List<ExprNodeDesc> newExprs) {
     if (udf instanceof GenericUDFOPAnd) {
       for (int i = 0; i < 2; i++) {
diff --git a/ql/src/test/queries/clientpositive/constprog2.q b/ql/src/test/queries/clientpositive/constprog2.q
index 72ce5a3..6001668 100644
--- a/ql/src/test/queries/clientpositive/constprog2.q
+++ b/ql/src/test/queries/clientpositive/constprog2.q
@@ -7,4 +7,10 @@ SELECT src1.key, src1.key + 1, src2.value
 
 SELECT src1.key, src1.key + 1, src2.value
        FROM src src1 join src src2 ON src1.key = src2.key AND src1.key = 86;
 
+EXPLAIN
+SELECT src1.key, src1.key + 1, src2.value
+       FROM src src1 join src src2 ON src1.key = src2.key AND cast(src1.key as double) = 86;
+
+SELECT src1.key, src1.key + 1, src2.value
+       FROM src src1 join src src2 ON src1.key = src2.key AND cast(src1.key as double) = 86;
diff --git a/ql/src/test/results/clientpositive/constprog2.q.out b/ql/src/test/results/clientpositive/constprog2.q.out
index 148d95b..50ff890 100644
--- a/ql/src/test/results/clientpositive/constprog2.q.out
+++ b/ql/src/test/results/clientpositive/constprog2.q.out
@@ -73,3 +73,78 @@ POSTHOOK: type: QUERY
 POSTHOOK: Input: default@src
 #### A masked pattern was here ####
 86 87.0 val_86
+PREHOOK: query: EXPLAIN
+SELECT src1.key, src1.key + 1, src2.value
+       FROM src src1 join src src2 ON src1.key = src2.key AND cast(src1.key as double) = 86
+PREHOOK: type: QUERY
+POSTHOOK: query: EXPLAIN
+SELECT src1.key, src1.key + 1, src2.value
+       FROM src src1 join src src2 ON src1.key = src2.key AND cast(src1.key as double) = 86
+POSTHOOK: type: QUERY
+STAGE DEPENDENCIES:
+  Stage-1 is a root stage
+  Stage-0 depends on stages: Stage-1
+
+STAGE PLANS:
+  Stage: Stage-1
+    Map Reduce
+      Map Operator Tree:
+          TableScan
+            alias: src2
+            Statistics: Num rows: 500 Data size: 5312 Basic stats: COMPLETE Column stats: NONE
+            Filter Operator
+              predicate: ((UDFToDouble(key) = 86) and key is not null) (type: boolean)
+              Statistics: Num rows: 125 Data size: 1328 Basic stats: COMPLETE Column stats: NONE
+              Reduce Output Operator
+                key expressions: '86' (type: string)
+                sort order: +
+                Statistics: Num rows: 125 Data size: 1328 Basic stats: COMPLETE Column stats: NONE
+                value expressions: value (type: string)
+          TableScan
+            alias: src1
+            Statistics: Num rows: 500 Data size: 5312 Basic stats: COMPLETE Column stats: NONE
+            Filter Operator
+              predicate: ((UDFToDouble(key) = 86) and key is not null) (type: boolean)
+              Statistics: Num rows: 125 Data size: 1328 Basic stats: COMPLETE Column stats: NONE
+              Reduce Output Operator
+                key expressions: '86' (type: string)
+                sort order: +
+                Statistics: Num rows: 125 Data size: 1328 Basic stats: COMPLETE Column stats: NONE
+      Reduce Operator Tree:
+        Join Operator
+          condition map:
+               Inner Join 0 to 1
+          condition expressions:
+            0 
+            1 {VALUE._col0}
+          outputColumnNames: _col6
+          Statistics: Num rows: 137 Data size: 1460 Basic stats: COMPLETE Column stats: NONE
+          Select Operator
+            expressions: '86' (type: string), 87.0 (type: double), _col6 (type: string)
+            outputColumnNames: _col0, _col1, _col2
+            Statistics: Num rows: 137 Data size: 1460 Basic stats: COMPLETE Column stats: NONE
+            File Output Operator
+              compressed: false
+              Statistics: Num rows: 137 Data size: 1460 Basic stats: COMPLETE Column stats: NONE
+              table:
+                  input format: org.apache.hadoop.mapred.TextInputFormat
+                  output format: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat
+                  serde: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe
+
+  Stage: Stage-0
+    Fetch Operator
+      limit: -1
+      Processor Tree:
+        ListSink
+
+PREHOOK: query: SELECT src1.key, src1.key + 1, src2.value
+       FROM src src1 join src src2 ON src1.key = src2.key AND cast(src1.key as double) = 86
+PREHOOK: type: QUERY
+PREHOOK: Input: default@src
+#### A masked pattern was here ####
+POSTHOOK: query: SELECT src1.key, src1.key + 1, src2.value
+       FROM src src1 join src src2 ON src1.key = src2.key AND cast(src1.key as double) = 86
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@src
+#### A masked pattern was here ####
+86 87.0 val_86
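
For context, the shape of the ConstantPropagateProcFactory change can be illustrated outside of Hive. Below is a minimal, self-contained sketch of the same idea -- an equality is only considered for propagation when one side is a constant and the other side is either a bare column or a column wrapped in a cast -- written against hypothetical stand-in types (Expr, ColumnExpr, ConstantExpr, CastExpr) rather than Hive's ExprNodeDesc hierarchy, so it is illustrative only and not part of the patch.

/**
 * Sketch of the cast-unwrapping idea from the ConstantPropagateProcFactory change.
 * Expr, ColumnExpr, ConstantExpr and CastExpr are hypothetical stand-ins, not Hive classes.
 */
public class CastUnwrapSketch {

  static class Expr {}
  static class ColumnExpr extends Expr {
    final String name;
    ColumnExpr(String name) { this.name = name; }
  }
  static class ConstantExpr extends Expr {
    final Object value;
    ConstantExpr(Object value) { this.value = value; }
  }
  static class CastExpr extends Expr {
    final Expr child;
    CastExpr(Expr child) { this.child = child; }
  }

  // Mirrors the new getColumnExpr(): a bare column, or a column unwrapped from a cast.
  // (Slightly more defensive than the patch, which casts the child of a cast unconditionally.)
  static ColumnExpr getColumnExpr(Expr expr) {
    if (expr instanceof ColumnExpr) {
      return (ColumnExpr) expr;
    }
    if (expr instanceof CastExpr && ((CastExpr) expr).child instanceof ColumnExpr) {
      return (ColumnExpr) ((CastExpr) expr).child;
    }
    return null;
  }

  // Mirrors the rewritten equality handling in propagate(): require a constant on one
  // side and a (possibly cast-wrapped) column on the other; otherwise bail out.
  static String propagatableColumn(Expr left, Expr right) {
    if (!(left instanceof ConstantExpr) && !(right instanceof ConstantExpr)) {
      return null;
    }
    ColumnExpr c = getColumnExpr(left);
    if (c == null) {
      c = getColumnExpr(right);
    }
    return c == null ? null : c.name;
  }

  public static void main(String[] args) {
    // cast(src1.key as double) = 86 is now recognized, as in the new constprog2.q case.
    Expr castKey = new CastExpr(new ColumnExpr("key"));
    Expr eightySix = new ConstantExpr(86.0);
    System.out.println(propagatableColumn(castKey, eightySix)); // prints: key
  }
}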