diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConstantPropagateProcFactory.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConstantPropagateProcFactory.java
index b1d2ba4..e69873c 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConstantPropagateProcFactory.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConstantPropagateProcFactory.java
@@ -69,6 +69,13 @@
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqual;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNull;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPOr;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToBinary;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToChar;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToDate;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToDecimal;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToUnixTimeStamp;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToUtcTimestamp;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToVarchar;
 import org.apache.hadoop.hive.serde.serdeConstants;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
@@ -317,15 +324,21 @@ private static void propagate(GenericUDF udf, List<ExprNodeDesc> newExprs, RowRe
     if (udf instanceof GenericUDFOPEqual) {
       ExprNodeDesc lOperand = newExprs.get(0);
       ExprNodeDesc rOperand = newExprs.get(1);
-      ExprNodeColumnDesc c;
       ExprNodeConstantDesc v;
-      if (lOperand instanceof ExprNodeColumnDesc && rOperand instanceof ExprNodeConstantDesc) {
-        c = (ExprNodeColumnDesc) lOperand;
-        v = (ExprNodeConstantDesc) rOperand;
-      } else if (rOperand instanceof ExprNodeColumnDesc && lOperand instanceof ExprNodeConstantDesc) {
-        c = (ExprNodeColumnDesc) rOperand;
+      if (lOperand instanceof ExprNodeConstantDesc) {
         v = (ExprNodeConstantDesc) lOperand;
+      } else if (rOperand instanceof ExprNodeConstantDesc) {
+        v = (ExprNodeConstantDesc) rOperand;
       } else {
+        // we need a constant on one side.
+        return;
+      }
+      ExprNodeColumnDesc c = getColumnExpr(lOperand);
+      if (null == c) {
+        c = getColumnExpr(rOperand);
+      }
+      if (null == c) {
+        // we need a column expression on the other side.
         return;
       }
       ColumnInfo ci = resolveColumn(rr, c);
@@ -351,6 +364,28 @@ private static void propagate(GenericUDF udf, List<ExprNodeDesc> newExprs, RowRe
     }
   }
 
+  private static ExprNodeColumnDesc getColumnExpr(ExprNodeDesc expr) {
+    if (expr instanceof ExprNodeColumnDesc) {
+      return (ExprNodeColumnDesc)expr;
+    }
+    ExprNodeGenericFuncDesc funcDesc = null;
+    if (expr instanceof ExprNodeGenericFuncDesc) {
+      funcDesc = (ExprNodeGenericFuncDesc)expr;
+    }
+    if (null == funcDesc) {
+      return null;
+    }
+    GenericUDF udf = funcDesc.getGenericUDF();
+    // check if it's a simple cast expression.
+    if ((udf instanceof GenericUDFBridge || udf instanceof GenericUDFToBinary || udf instanceof GenericUDFToChar ||
+        udf instanceof GenericUDFToVarchar || udf instanceof GenericUDFToDecimal || udf instanceof GenericUDFToDate
+        || udf instanceof GenericUDFToUnixTimeStamp || udf instanceof GenericUDFToUtcTimestamp) &&
+        funcDesc.getChildren().size() == 1 && funcDesc.getChildren().get(0) instanceof ExprNodeColumnDesc) {
+      return (ExprNodeColumnDesc)expr.getChildren().get(0);
+    }
+    return null;
+  }
+
   private static ExprNodeDesc shortcutFunction(GenericUDF udf, List<ExprNodeDesc> newExprs) {
     if (udf instanceof GenericUDFOPAnd) {
       for (int i = 0; i < 2; i++) {
diff --git a/ql/src/test/queries/clientpositive/constprog2.q b/ql/src/test/queries/clientpositive/constprog2.q
index 72ce5a3..6001668 100644
--- a/ql/src/test/queries/clientpositive/constprog2.q
+++ b/ql/src/test/queries/clientpositive/constprog2.q
@@ -7,4 +7,10 @@ SELECT src1.key, src1.key + 1, src2.value
 
 SELECT src1.key, src1.key + 1, src2.value
        FROM src src1 join src src2 ON src1.key = src2.key AND src1.key = 86;
 
+EXPLAIN
+SELECT src1.key, src1.key + 1, src2.value
+       FROM src src1 join src src2 ON src1.key = src2.key AND cast(src1.key as double) = 86;
+
+SELECT src1.key, src1.key + 1, src2.value
+       FROM src src1 join src src2 ON src1.key = src2.key AND cast(src1.key as double) = 86;
diff --git a/ql/src/test/results/clientpositive/constprog2.q.out b/ql/src/test/results/clientpositive/constprog2.q.out
index 148d95b..50ff890 100644
--- a/ql/src/test/results/clientpositive/constprog2.q.out
+++ b/ql/src/test/results/clientpositive/constprog2.q.out
@@ -73,3 +73,78 @@ POSTHOOK: type: QUERY
 POSTHOOK: Input: default@src
 #### A masked pattern was here ####
 86	87.0	val_86
+PREHOOK: query: EXPLAIN
+SELECT src1.key, src1.key + 1, src2.value
+       FROM src src1 join src src2 ON src1.key = src2.key AND cast(src1.key as double) = 86
+PREHOOK: type: QUERY
+POSTHOOK: query: EXPLAIN
+SELECT src1.key, src1.key + 1, src2.value
+       FROM src src1 join src src2 ON src1.key = src2.key AND cast(src1.key as double) = 86
+POSTHOOK: type: QUERY
+STAGE DEPENDENCIES:
+  Stage-1 is a root stage
+  Stage-0 depends on stages: Stage-1
+
+STAGE PLANS:
+  Stage: Stage-1
+    Map Reduce
+      Map Operator Tree:
+          TableScan
+            alias: src2
+            Statistics: Num rows: 500 Data size: 5312 Basic stats: COMPLETE Column stats: NONE
+            Filter Operator
+              predicate: ((UDFToDouble(key) = 86) and key is not null) (type: boolean)
+              Statistics: Num rows: 125 Data size: 1328 Basic stats: COMPLETE Column stats: NONE
+              Reduce Output Operator
+                key expressions: '86' (type: string)
+                sort order: +
+                Statistics: Num rows: 125 Data size: 1328 Basic stats: COMPLETE Column stats: NONE
+                value expressions: value (type: string)
+          TableScan
+            alias: src1
+            Statistics: Num rows: 500 Data size: 5312 Basic stats: COMPLETE Column stats: NONE
+            Filter Operator
+              predicate: ((UDFToDouble(key) = 86) and key is not null) (type: boolean)
+              Statistics: Num rows: 125 Data size: 1328 Basic stats: COMPLETE Column stats: NONE
+              Reduce Output Operator
+                key expressions: '86' (type: string)
+                sort order: +
+                Statistics: Num rows: 125 Data size: 1328 Basic stats: COMPLETE Column stats: NONE
+      Reduce Operator Tree:
+        Join Operator
+          condition map:
+               Inner Join 0 to 1
+          condition expressions:
+            0 
+            1 {VALUE._col0}
+          outputColumnNames: _col6
+          Statistics: Num rows: 137 Data size: 1460 Basic stats: COMPLETE Column stats: NONE
+          Select Operator
+            expressions: '86' (type: string), 87.0 (type: double), _col6 (type: string)
+            outputColumnNames: _col0, _col1, _col2
+            Statistics: Num rows: 137 Data size: 1460 Basic stats: COMPLETE Column stats: NONE
+            File Output Operator
+              compressed: false
+              Statistics: Num rows: 137 Data size: 1460 Basic stats: COMPLETE Column stats: NONE
+              table:
+                  input format: org.apache.hadoop.mapred.TextInputFormat
+                  output format: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat
+                  serde: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe
+
+  Stage: Stage-0
+    Fetch Operator
+      limit: -1
+      Processor Tree:
+        ListSink
+
+PREHOOK: query: SELECT src1.key, src1.key + 1, src2.value
+       FROM src src1 join src src2 ON src1.key = src2.key AND cast(src1.key as double) = 86
+PREHOOK: type: QUERY
+PREHOOK: Input: default@src
+#### A masked pattern was here ####
+POSTHOOK: query: SELECT src1.key, src1.key + 1, src2.value
+       FROM src src1 join src src2 ON src1.key = src2.key AND cast(src1.key as double) = 86
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@src
+#### A masked pattern was here ####
+86	87.0	val_86
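
Note on the core change (not part of the patch): before this change, constant propagation for an equality required a bare column on one side, so a predicate like cast(src1.key as double) = 86 was left alone. The new getColumnExpr() helper looks through a single-argument cast UDF to find the column underneath, which is why the EXPLAIN output above now shows the join key folded to '86'. The standalone Java sketch below only illustrates that unwrapping idea; the Expr, ColumnExpr, ConstantExpr and FuncExpr types and the UDF-name list are simplified stand-ins invented for this example, not the actual Hive ExprNodeDesc classes.

// CastUnwrapSketch.java (Java 17): simplified model of "find the column under a cast".
import java.util.List;

interface Expr {}

// Stand-in for ExprNodeColumnDesc: a bare column reference.
record ColumnExpr(String column) implements Expr {}

// Stand-in for ExprNodeConstantDesc: a literal value.
record ConstantExpr(Object value) implements Expr {}

// Stand-in for ExprNodeGenericFuncDesc: a function call over child expressions.
record FuncExpr(String udfName, List<Expr> children) implements Expr {}

public class CastUnwrapSketch {

  // Names standing in for the cast-like UDFs the patch treats as transparent.
  private static final List<String> CAST_UDFS = List.of(
      "UDFToDouble", "ToDecimal", "ToChar", "ToVarchar", "ToDate", "ToBinary",
      "ToUnixTimeStamp", "ToUtcTimestamp");

  // Mirrors the shape of getColumnExpr(): return the column directly, or the
  // column hidden under a single cast; otherwise null.
  static ColumnExpr getColumnExpr(Expr expr) {
    if (expr instanceof ColumnExpr col) {
      return col;
    }
    if (expr instanceof FuncExpr func
        && CAST_UDFS.contains(func.udfName())
        && func.children().size() == 1
        && func.children().get(0) instanceof ColumnExpr col) {
      return col;
    }
    return null;
  }

  public static void main(String[] args) {
    // Models cast(src1.key as double) = 86.
    Expr casted = new FuncExpr("UDFToDouble", List.of(new ColumnExpr("key")));
    Expr constant = new ConstantExpr(86);

    // Before the change only a bare column qualified; now the cast is unwrapped.
    System.out.println(getColumnExpr(casted));   // ColumnExpr[column=key]
    System.out.println(getColumnExpr(constant)); // null
  }
}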