diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveRexUtil.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveRexUtil.java index 1b327fe..c079ed2 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveRexUtil.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveRexUtil.java @@ -318,12 +318,12 @@ private static RexNode simplifyCase(RexBuilder rexBuilder, RexCall call, if (newOperands.size() == 1 || values.size() == 1) { return rexBuilder.makeCast(call.getType(), newOperands.get(newOperands.size() - 1)); } - trueFalse: if (call.getType().getSqlTypeName() == SqlTypeName.BOOLEAN) { // Optimize CASE where every branch returns constant true or constant // false. final List> pairs = casePairs(rexBuilder, newOperands); + RexNode result; // 1) Possible simplification if unknown is treated as false: // CASE // WHEN p1 THEN TRUE @@ -332,25 +332,9 @@ private static RexNode simplifyCase(RexBuilder rexBuilder, RexCall call, // END // can be rewritten to: (p1 or p2) if (unknownAsFalse) { - final List terms = new ArrayList<>(); - int pos = 0; - for (; pos < pairs.size(); pos++) { - // True block - Pair pair = pairs.get(pos); - if (!pair.getValue().isAlwaysTrue()) { - break; - } - terms.add(pair.getKey()); - } - for (; pos < pairs.size(); pos++) { - // False block - Pair pair = pairs.get(pos); - if (!pair.getValue().isAlwaysFalse() && !RexUtil.isNull(pair.getValue())) { - break; - } - } - if (pos == pairs.size()) { - return RexUtil.composeDisjunction(rexBuilder, terms, false); + result = simplifyBooleanCase1(rexBuilder, pairs); + if (result != null) { + return result; } } // 2) Another simplification @@ -360,27 +344,24 @@ private static RexNode simplifyCase(RexBuilder rexBuilder, RexCall call, // WHEN p3 THEN TRUE // ELSE FALSE // END + // can be rewritten to: (p1 or (p3 and not(p2))) // if p1...pn cannot be nullable - for (Ord> pair : Ord.zip(pairs)) { - if (pair.e.getKey().getType().isNullable()) { - break trueFalse; - } - if 
(!pair.e.getValue().isAlwaysTrue() - && !pair.e.getValue().isAlwaysFalse() - && (!unknownAsFalse || !RexUtil.isNull(pair.e.getValue()))) { - break trueFalse; - } + result = simplifyBooleanCase2(rexBuilder, pairs, unknownAsFalse); + if (result != null) { + return result; } - final List terms = new ArrayList<>(); - final List notTerms = new ArrayList<>(); - for (Ord> pair : Ord.zip(pairs)) { - if (pair.e.getValue().isAlwaysTrue()) { - terms.add(RexUtil.andNot(rexBuilder, pair.e.getKey(), notTerms)); - } else { - notTerms.add(pair.e.getKey()); - } + // 3) Another simplification + // CASE + // WHEN p1 THEN x + // WHEN p2 THEN y + // ELSE TRUE + // END + // can be rewritten to: (p1 and x) or (p2 and y and not(p1)) or (not(p1) and not(p2)) + // if p1...pn cannot be nullable + result = simplifyBooleanCase3(rexBuilder, pairs); + if (result != null) { + return result; } - return RexUtil.composeDisjunction(rexBuilder, terms, false); } if (newOperands.equals(operands)) { return call; @@ -402,6 +383,73 @@ private static RexNode simplifyCase(RexBuilder rexBuilder, RexCall call, return builder.build(); } + private static RexNode simplifyBooleanCase1(RexBuilder rexBuilder, + List> pairs) { + final List terms = new ArrayList<>(); + int pos = 0; + for (; pos < pairs.size(); pos++) { + // True block: leading branches that always return TRUE contribute their condition + Pair pair = pairs.get(pos); + if (!pair.getValue().isAlwaysTrue()) { + break; + } + terms.add(pair.getKey()); + } + for (; pos < pairs.size(); pos++) { + // False block: every remaining branch must return constant FALSE or NULL + Pair pair = pairs.get(pos); + if (!pair.getValue().isAlwaysFalse() && !RexUtil.isNull(pair.getValue())) { + break; + } + } + if (pos == pairs.size()) { + return RexUtil.composeDisjunction(rexBuilder, terms, false); + } + return null; + } + + private static RexNode simplifyBooleanCase2(RexBuilder rexBuilder, + List> pairs, boolean unknownAsFalse) { + for (Ord> pair : Ord.zip(pairs)) { + if (pair.e.getKey().getType().isNullable()) { + return null; + } + if (!pair.e.getValue().isAlwaysTrue() + && !pair.e.getValue().isAlwaysFalse() + && 
(!unknownAsFalse || !RexUtil.isNull(pair.e.getValue()))) { + return null; + } + } + final List terms = new ArrayList<>(); + final List notTerms = new ArrayList<>(); + for (Ord> pair : Ord.zip(pairs)) { + if (pair.e.getValue().isAlwaysTrue()) { + terms.add(RexUtil.andNot(rexBuilder, pair.e.getKey(), notTerms)); + } else { + notTerms.add(pair.e.getKey()); + } + } + return RexUtil.composeDisjunction(rexBuilder, terms, false); + } + + private static RexNode simplifyBooleanCase3(RexBuilder rexBuilder, + List> pairs) { + for (Ord> pair : Ord.zip(pairs)) { + if (pair.e.getKey().getType().isNullable()) { + return null; + } + } + final List terms = new ArrayList<>(); + final List notTerms = new ArrayList<>(); + for (Ord> pair : Ord.zip(pairs)) { + terms.add(RexUtil.andNot(rexBuilder, + rexBuilder.makeCall(SqlStdOperatorTable.AND, pair.e.getKey(), pair.e.getValue()), + notTerms)); + notTerms.add(pair.e.getKey()); + } + return RexUtil.composeDisjunction(rexBuilder, terms, false); + } + public static RexNode simplifyAnd(RexBuilder rexBuilder, RexCall e, boolean unknownAsFalse) { final List terms = new ArrayList<>(); diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/RexNodeConverter.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/RexNodeConverter.java index 479070b..3287ec1 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/RexNodeConverter.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/RexNodeConverter.java @@ -230,9 +230,13 @@ private RexNode convert(ExprNodeGenericFuncDesc func) throws SemanticException { retType = TypeConverter.convert(func.getTypeInfo(), cluster.getTypeFactory()); SqlOperator calciteOp = SqlFunctionConverter.getCalciteOperator(func.getFuncText(), func.getGenericUDF(), argTypeBldr.build(), retType); - // If it is a case operator, we need to rewrite it if (calciteOp.getKind() == SqlKind.CASE) { + // If it is a case operator, we need to rewrite 
it childRexNodeLst = rewriteCaseChildren(func, childRexNodeLst); + } else if (FunctionRegistry.getNormalizedFunctionName(func.getFuncText()).equals("coalesce")) { + // If it is a coalesce operator, we rewrite it into CASE + calciteOp = SqlStdOperatorTable.CASE; + childRexNodeLst = rewriteCoalesceChildren(func, childRexNodeLst); } expr = cluster.getRexBuilder().makeCall(calciteOp, childRexNodeLst); } else { @@ -340,6 +344,21 @@ private RexNode handleExplicitCast(ExprNodeGenericFuncDesc func, List c return newChildRexNodeLst; } + private List rewriteCoalesceChildren(ExprNodeGenericFuncDesc func, List childRexNodeLst) { + List newChildRexNodeLst = new ArrayList(); + // Convert each of the first n-1 expressions into an (IS NOT NULL condition, value) pair + for (int i = 0; (i + 1) < childRexNodeLst.size(); ++i) { + RexNode child = childRexNodeLst.get(i); + RexNode childCond = cluster.getRexBuilder().makeCall( + SqlStdOperatorTable.IS_NOT_NULL, child); + newChildRexNodeLst.add(childCond); + newChildRexNodeLst.add(child); + } + // The n-th (last) expression is for the else clause; NOTE(review): assumes at least one child (confirm) + newChildRexNodeLst.add(childRexNodeLst.get(childRexNodeLst.size()-1)); + return newChildRexNodeLst; + } + private static boolean checkForStatefulFunctions(List list) { for (ExprNodeDesc node : list) { if (node instanceof ExprNodeGenericFuncDesc) {