diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/JoinTypeCheckCtx.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/JoinTypeCheckCtx.java index dccd1d9..f166bb6 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/JoinTypeCheckCtx.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/JoinTypeCheckCtx.java @@ -53,7 +53,7 @@ public JoinTypeCheckCtx(RowResolver leftRR, RowResolver rightRR, JoinType hiveJoinType) throws SemanticException { - super(RowResolver.getCombinedRR(leftRR, rightRR), true, false, false, false, false, false, false, + super(RowResolver.getCombinedRR(leftRR, rightRR), true, false, false, false, false, false, false, false, false, false); this.inputRRLst = ImmutableList.of(leftRR, rightRR); this.outerJoin = (hiveJoinType == JoinType.LEFTOUTER) || (hiveJoinType == JoinType.RIGHTOUTER) diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java index 2983d38..f79a525 100644 --- ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java +++ ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java @@ -3143,8 +3143,8 @@ private Operator genFilterPlan(QB qb, ASTNode condn, Operator input, boolean use OpParseContext inputCtx = opParseCtx.get(input); RowResolver inputRR = inputCtx.getRowResolver(); Operator output = putOpInsertMap(OperatorFactory.getAndMakeChild( - new FilterDesc(genExprNodeDesc(condn, inputRR, useCaching), false), new RowSchema( - inputRR.getColumnInfos()), input), inputRR); + new FilterDesc(genExprNodeDesc(condn, inputRR, useCaching, isCBOExecuted()), false), + new RowSchema(inputRR.getColumnInfos()), input), inputRR); if (LOG.isDebugEnabled()) { LOG.debug("Created Filter Plan for " + qb.getId() + " row schema: " @@ -4146,7 +4146,7 @@ static boolean isRegex(String pattern, HiveConf conf) { expr, col_list, null, inputRR, starRR, pos, out_rwsch, qb.getAliases(), false); } else { // Case when this is an expression - TypeCheckCtx tcCtx = new TypeCheckCtx(inputRR); + TypeCheckCtx tcCtx = new TypeCheckCtx(inputRR, true, isCBOExecuted()); // We allow stateful functions in the SELECT list (but nowhere else) tcCtx.setAllowStatefulFunctions(true); tcCtx.setAllowDistinctFunctions(false); @@ -7777,7 +7777,7 @@ private Operator genJoinOperatorChildren(QBJoinTree join, Operator left, List expressions = joinTree.getExpressions().get(i); joinKeys[i] = new ExprNodeDesc[expressions.size()]; for (int j = 0; j < joinKeys[i].length; j++) { - joinKeys[i][j] = genExprNodeDesc(expressions.get(j), inputRR); + joinKeys[i][j] = genExprNodeDesc(expressions.get(j), inputRR, true, isCBOExecuted()); } } // Type checking and implicit type conversion for join keys @@ -10999,12 +10999,17 @@ public ExprNodeDesc genExprNodeDesc(ASTNode expr, RowResolver input) throws SemanticException { // Since the user didn't supply a customized type-checking context, // use default settings. - return genExprNodeDesc(expr, input, true); + return genExprNodeDesc(expr, input, true, false); } public ExprNodeDesc genExprNodeDesc(ASTNode expr, RowResolver input, boolean useCaching) throws SemanticException { - TypeCheckCtx tcCtx = new TypeCheckCtx(input, useCaching); + return genExprNodeDesc(expr, input, useCaching, false); + } + + public ExprNodeDesc genExprNodeDesc(ASTNode expr, RowResolver input, boolean useCaching, + boolean foldExpr) throws SemanticException { + TypeCheckCtx tcCtx = new TypeCheckCtx(input, useCaching, foldExpr); return genExprNodeDesc(expr, input, tcCtx); } diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckCtx.java ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckCtx.java index de1c043..02896ff 100644 --- ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckCtx.java +++ ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckCtx.java @@ -37,6 +37,8 @@ private final boolean useCaching; + private final boolean foldExpr; + /** * Receives translations which will need to be applied during unparse. */ @@ -79,20 +81,21 @@ * The input row resolver of the previous operator. */ public TypeCheckCtx(RowResolver inputRR) { - this(inputRR, true); + this(inputRR, true, false); } - public TypeCheckCtx(RowResolver inputRR, boolean useCaching) { - this(inputRR, useCaching, false, true, true, true, true, true, true, true); + public TypeCheckCtx(RowResolver inputRR, boolean useCaching, boolean foldExpr) { + this(inputRR, useCaching, foldExpr, false, true, true, true, true, true, true, true); } - public TypeCheckCtx(RowResolver inputRR, boolean useCaching, boolean allowStatefulFunctions, - boolean allowDistinctFunctions, boolean allowGBExprElimination, boolean allowAllColRef, - boolean allowFunctionStar, boolean allowWindowing, + public TypeCheckCtx(RowResolver inputRR, boolean useCaching, boolean foldExpr, + boolean allowStatefulFunctions, boolean allowDistinctFunctions, boolean allowGBExprElimination, + boolean allowAllColRef, boolean allowFunctionStar, boolean allowWindowing, boolean allowIndexExpr, boolean allowSubQueryExpr) { setInputRR(inputRR); error = null; this.useCaching = useCaching; + this.foldExpr = foldExpr; this.allowStatefulFunctions = allowStatefulFunctions; this.allowDistinctFunctions = allowDistinctFunctions; this.allowGBExprElimination = allowGBExprElimination; @@ -209,4 +212,8 @@ public boolean getallowSubQueryExpr() { public boolean isUseCaching() { return useCaching; } + + public boolean isFoldExpr() { + return foldExpr; + } } diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java index da236d5..ceeb9b4 100644 --- ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java +++ ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java @@ -61,9 +61,12 @@ import org.apache.hadoop.hive.ql.udf.SettableUDF; import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBaseCompare; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFNvl; import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPAnd; import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqual; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNot; import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPOr; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFWhen; import org.apache.hadoop.hive.serde.serdeConstants; import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; @@ -1055,6 +1058,14 @@ protected ExprNodeDesc getXpathOrFuncExprNodeDesc(ASTNode expr, } desc = ExprNodeGenericFuncDesc.newInstance(genericUDF, funcText, childrenList); + } else if (ctx.isFoldExpr() && canConvertIntoNvl(genericUDF, children)) { + // Rewrite CASE into NVL + desc = ExprNodeGenericFuncDesc.newInstance(new GenericUDFNvl(), + Lists.newArrayList(children.get(0), new ExprNodeConstantDesc(false))); + if (Boolean.FALSE.equals(((ExprNodeConstantDesc) children.get(1)).getValue())) { + desc = ExprNodeGenericFuncDesc.newInstance(new GenericUDFOPNot(), + Lists.newArrayList(desc)); + } } else { desc = ExprNodeGenericFuncDesc.newInstance(genericUDF, funcText, children); @@ -1072,6 +1083,21 @@ protected ExprNodeDesc getXpathOrFuncExprNodeDesc(ASTNode expr, return desc; } + private boolean canConvertIntoNvl(GenericUDF genericUDF, ArrayList children) { + if (genericUDF instanceof GenericUDFWhen && children.size() == 3 && + children.get(1) instanceof ExprNodeConstantDesc && + children.get(2) instanceof ExprNodeConstantDesc) { + ExprNodeConstantDesc constThen = (ExprNodeConstantDesc) children.get(1); + ExprNodeConstantDesc constElse = (ExprNodeConstantDesc) children.get(2); + Object thenVal = constThen.getValue(); + Object elseVal = constElse.getValue(); + if (thenVal instanceof Boolean && elseVal instanceof Boolean) { + return true; + } + } + return false; + } + /** * Returns true if des is a descendant of ans (ancestor) */ diff --git ql/src/test/queries/clientpositive/constantPropWhen.q ql/src/test/queries/clientpositive/constantPropWhen.q index c1d4885..03bfd54 100644 --- ql/src/test/queries/clientpositive/constantPropWhen.q +++ ql/src/test/queries/clientpositive/constantPropWhen.q @@ -1,4 +1,5 @@ set hive.mapred.mode=nonstrict; +set hive.optimize.constant.propagation=false; drop table test_1; @@ -24,6 +25,7 @@ SELECT cast(CASE id when id2 THEN TRUE ELSE FALSE END AS BOOLEAN) AS b FROM test set hive.cbo.enable=false; +set hive.optimize.constant.propagation=true; explain SELECT cast(CASE WHEN id = id2 THEN FALSE ELSE TRUE END AS BOOLEAN) AS b FROM test_1;