diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveSubQRemoveRelBuilder.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveSubQRemoveRelBuilder.java index c6a5ce261a..a8b408a633 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveSubQRemoveRelBuilder.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveSubQRemoveRelBuilder.java @@ -100,6 +100,9 @@ * because CALCITE-1493 hasn't been fixed yet * This should be deleted and replaced with RelBuilder in SubqueryRemoveRule * once CALCITE-1493 is fixed. + * EDIT: Although CALCITE-1493 has been fixed and released, Hive now has special handling + * in join (it takes a flag indicating whether a semi-join should be created), so we still + * cannot replace this with Calcite's RelBuilder. * *

{@code RelBuilder} does not make possible anything that you could not * also accomplish by calling the factory methods of the particular relational @@ -116,14 +119,14 @@ */ public class HiveSubQRemoveRelBuilder { private static final Function FN_TYPE = - new Function() { - public String apply(RexNode input) { - return input + ": " + input.getType(); - } - }; + new Function() { + public String apply(RexNode input) { + return input + ": " + input.getType(); + } + }; - protected final RelOptCluster cluster; - protected final RelOptSchema relOptSchema; + private final RelOptCluster cluster; + private final RelOptSchema relOptSchema; private final RelFactories.FilterFactory filterFactory; private final RelFactories.ProjectFactory projectFactory; private final RelFactories.AggregateFactory aggregateFactory; @@ -137,57 +140,57 @@ public String apply(RexNode input) { private final Deque stack = new ArrayDeque<>(); public HiveSubQRemoveRelBuilder(Context context, RelOptCluster cluster, - RelOptSchema relOptSchema) { + RelOptSchema relOptSchema) { this.cluster = cluster; this.relOptSchema = relOptSchema; if (context == null) { context = Contexts.EMPTY_CONTEXT; } this.aggregateFactory = - Util.first(context.unwrap(RelFactories.AggregateFactory.class), - HiveRelFactories.HIVE_AGGREGATE_FACTORY); + Util.first(context.unwrap(RelFactories.AggregateFactory.class), + HiveRelFactories.HIVE_AGGREGATE_FACTORY); this.filterFactory = - Util.first(context.unwrap(RelFactories.FilterFactory.class), - HiveRelFactories.HIVE_FILTER_FACTORY); + Util.first(context.unwrap(RelFactories.FilterFactory.class), + HiveRelFactories.HIVE_FILTER_FACTORY); this.projectFactory = - Util.first(context.unwrap(RelFactories.ProjectFactory.class), - HiveRelFactories.HIVE_PROJECT_FACTORY); + Util.first(context.unwrap(RelFactories.ProjectFactory.class), + HiveRelFactories.HIVE_PROJECT_FACTORY); this.sortFactory = - Util.first(context.unwrap(RelFactories.SortFactory.class), - HiveRelFactories.HIVE_SORT_FACTORY); + Util.first(context.unwrap(RelFactories.SortFactory.class), + HiveRelFactories.HIVE_SORT_FACTORY); this.setOpFactory = - Util.first(context.unwrap(RelFactories.SetOpFactory.class), - HiveRelFactories.HIVE_SET_OP_FACTORY); + Util.first(context.unwrap(RelFactories.SetOpFactory.class), + HiveRelFactories.HIVE_SET_OP_FACTORY); this.joinFactory = - Util.first(context.unwrap(RelFactories.JoinFactory.class), - HiveRelFactories.HIVE_JOIN_FACTORY); + Util.first(context.unwrap(RelFactories.JoinFactory.class), + HiveRelFactories.HIVE_JOIN_FACTORY); this.semiJoinFactory = - Util.first(context.unwrap(RelFactories.SemiJoinFactory.class), - HiveRelFactories.HIVE_SEMI_JOIN_FACTORY); + Util.first(context.unwrap(RelFactories.SemiJoinFactory.class), + HiveRelFactories.HIVE_SEMI_JOIN_FACTORY); this.correlateFactory = - Util.first(context.unwrap(RelFactories.CorrelateFactory.class), - RelFactories.DEFAULT_CORRELATE_FACTORY); + Util.first(context.unwrap(RelFactories.CorrelateFactory.class), + RelFactories.DEFAULT_CORRELATE_FACTORY); this.valuesFactory = - Util.first(context.unwrap(RelFactories.ValuesFactory.class), - RelFactories.DEFAULT_VALUES_FACTORY); + Util.first(context.unwrap(RelFactories.ValuesFactory.class), + RelFactories.DEFAULT_VALUES_FACTORY); this.scanFactory = - Util.first(context.unwrap(RelFactories.TableScanFactory.class), - RelFactories.DEFAULT_TABLE_SCAN_FACTORY); + Util.first(context.unwrap(RelFactories.TableScanFactory.class), + RelFactories.DEFAULT_TABLE_SCAN_FACTORY); } - /** Creates a RelBuilder. */ + /** Creates a RelBuilder. 
*/ public static HiveSubQRemoveRelBuilder create(FrameworkConfig config) { final RelOptCluster[] clusters = {null}; final RelOptSchema[] relOptSchemas = {null}; Frameworks.withPrepare( - new Frameworks.PrepareAction(config) { - public Void apply(RelOptCluster cluster, RelOptSchema relOptSchema, - SchemaPlus rootSchema, CalciteServerStatement statement) { - clusters[0] = cluster; - relOptSchemas[0] = relOptSchema; - return null; - } - }); + new Frameworks.PrepareAction(config) { + public Void apply(RelOptCluster cluster, RelOptSchema relOptSchema, + SchemaPlus rootSchema, CalciteServerStatement statement) { + clusters[0] = cluster; + relOptSchemas[0] = relOptSchema; + return null; + } + }); return new HiveSubQRemoveRelBuilder(config.getContext(), clusters[0], relOptSchemas[0]); } @@ -286,15 +289,15 @@ public RexNode literal(Object value) { return rexBuilder.makeExactLiteral((BigDecimal) value); } else if (value instanceof Float || value instanceof Double) { return rexBuilder.makeApproxLiteral( - BigDecimal.valueOf(((Number) value).doubleValue())); + BigDecimal.valueOf(((Number) value).doubleValue())); } else if (value instanceof Number) { return rexBuilder.makeExactLiteral( - BigDecimal.valueOf(((Number) value).longValue())); + BigDecimal.valueOf(((Number) value).longValue())); } else if (value instanceof String) { return rexBuilder.makeLiteral((String) value); } else { throw new IllegalArgumentException("cannot convert " + value - + " (" + value.getClass() + ") to a constant"); + + " (" + value.getClass() + ") to a constant"); } } @@ -323,7 +326,7 @@ public RexInputRef field(int inputCount, int inputOrdinal, String fieldName) { return field(inputCount, inputOrdinal, i); } else { throw new IllegalArgumentException("field [" + fieldName - + "] not found; input fields are: " + fieldNames); + + "] not found; input fields are: " + fieldNames); } } @@ -359,12 +362,12 @@ private RexNode field(int inputCount, int inputOrdinal, int fieldOrdinal, final RelDataType rowType = input.getRowType(); if (fieldOrdinal < 0 || fieldOrdinal > rowType.getFieldCount()) { throw new IllegalArgumentException("field ordinal [" + fieldOrdinal - + "] out of range; input fields are: " + rowType.getFieldNames()); + + "] out of range; input fields are: " + rowType.getFieldNames()); } final RelDataTypeField field = rowType.getFieldList().get(fieldOrdinal); final int offset = inputOffset(inputCount, inputOrdinal); final RexInputRef ref = cluster.getRexBuilder() - .makeInputRef(field.getType(), offset + fieldOrdinal); + .makeInputRef(field.getType(), offset + fieldOrdinal); final RelDataTypeField aliasField = frame.fields().get(fieldOrdinal); if (!alias || field.getName().equals(aliasField.getName())) { return ref; @@ -388,15 +391,15 @@ public RexNode field(String alias, String fieldName) { return field(offset + i); } else { throw new IllegalArgumentException("no field '" + fieldName - + "' in relation '" + alias - + "'; fields are: " + pair.right.getFieldNames()); + + "' in relation '" + alias + + "'; fields are: " + pair.right.getFieldNames()); } } aliases.add(pair.left); offset += pair.right.getFieldCount(); } throw new IllegalArgumentException("no relation wtih alias '" + alias - + "'; aliases are: " + aliases); + + "'; aliases are: " + aliases); } /** Returns references to the fields of the top input. 
*/ @@ -421,16 +424,16 @@ public RexNode field(String alias, String fieldName) { for (RelFieldCollation fieldCollation : collation.getFieldCollations()) { RexNode node = field(fieldCollation.getFieldIndex()); switch (fieldCollation.direction) { - case DESCENDING: - node = desc(node); + case DESCENDING: + node = desc(node); } switch (fieldCollation.nullDirection) { - case FIRST: - node = nullsFirst(node); - break; - case LAST: - node = nullsLast(node); - break; + case FIRST: + node = nullsFirst(node); + break; + case LAST: + node = nullsLast(node); + break; } nodes.add(node); } @@ -480,7 +483,7 @@ public RexNode call(SqlOperator operator, RexNode... operands) { final RelDataType type = builder.deriveReturnType(operator, operandList); if (type == null) { throw new IllegalArgumentException("cannot derive type: " + operator - + "; operands: " + Lists.transform(operandList, FN_TYPE)); + + "; operands: " + Lists.transform(operandList, FN_TYPE)); } return builder.makeCall(type, operator, operandList); } @@ -489,7 +492,7 @@ public RexNode call(SqlOperator operator, RexNode... operands) { public RexNode call(SqlOperator operator, Iterable operands) { return cluster.getRexBuilder().makeCall(operator, - ImmutableList.copyOf(operands)); + ImmutableList.copyOf(operands)); } /** Creates an AND. */ @@ -546,7 +549,7 @@ public RexNode cast(RexNode expr, SqlTypeName typeName) { * and precision or length. */ public RexNode cast(RexNode expr, SqlTypeName typeName, int precision) { final RelDataType type = - cluster.getTypeFactory().createSqlType(typeName, precision); + cluster.getTypeFactory().createSqlType(typeName, precision); return cluster.getRexBuilder().makeCast(type, expr); } @@ -555,7 +558,7 @@ public RexNode cast(RexNode expr, SqlTypeName typeName, int precision) { public RexNode cast(RexNode expr, SqlTypeName typeName, int precision, int scale) { final RelDataType type = - cluster.getTypeFactory().createSqlType(typeName, precision, scale); + cluster.getTypeFactory().createSqlType(typeName, precision, scale); return cluster.getRexBuilder().makeCast(type, expr); } @@ -604,7 +607,7 @@ public GroupKey groupKey(Iterable nodes) { public GroupKey groupKey(Iterable nodes, boolean indicator, Iterable> nodeLists) { final ImmutableList.Builder> builder = - ImmutableList.builder(); + ImmutableList.builder(); for (Iterable nodeList : nodeLists) { builder.add(ImmutableList.copyOf(nodeList)); } @@ -636,14 +639,14 @@ public GroupKey groupKey(ImmutableBitSet groupSet, boolean indicator, groupSets = ImmutableList.of(groupSet); } final ImmutableList nodes = - fields(ImmutableIntList.of(groupSet.toArray())); + fields(ImmutableIntList.of(groupSet.toArray())); final List> nodeLists = - Lists.transform(groupSets, - new Function>() { - public ImmutableList apply(ImmutableBitSet input) { - return fields(ImmutableIntList.of(input.toArray())); - } - }); + Lists.transform(groupSets, + new Function>() { + public ImmutableList apply(ImmutableBitSet input) { + return fields(ImmutableIntList.of(input.toArray())); + } + }); return groupKey(nodes, indicator, nodeLists); } @@ -651,7 +654,7 @@ public GroupKey groupKey(ImmutableBitSet groupSet, boolean indicator, public AggCall aggregateCall(SqlAggFunction aggFunction, boolean distinct, RexNode filter, String alias, RexNode... operands) { return aggregateCall(aggFunction, distinct, filter, alias, - ImmutableList.copyOf(operands)); + ImmutableList.copyOf(operands)); } /** Creates a call to an aggregate function. 
*/ @@ -666,13 +669,13 @@ public AggCall aggregateCall(SqlAggFunction aggFunction, boolean distinct, } } return new AggCallImpl(aggFunction, distinct, filter, alias, - ImmutableList.copyOf(operands)); + ImmutableList.copyOf(operands)); } /** Creates a call to the COUNT aggregate function. */ public AggCall count(boolean distinct, String alias, RexNode... operands) { return aggregateCall(SqlStdOperatorTable.COUNT, distinct, null, alias, - operands); + operands); } /** Creates a call to the COUNT(*) aggregate function. */ @@ -683,13 +686,13 @@ public AggCall countStar(String alias) { /** Creates a call to the SUM aggregate function. */ public AggCall sum(boolean distinct, String alias, RexNode operand) { return aggregateCall(SqlStdOperatorTable.SUM, distinct, null, alias, - operand); + operand); } /** Creates a call to the AVG aggregate function. */ public AggCall avg(boolean distinct, String alias, RexNode operand) { return aggregateCall( - SqlStdOperatorTable.AVG, distinct, null, alias, operand); + SqlStdOperatorTable.AVG, distinct, null, alias, operand); } /** Creates a call to the MIN aggregate function. */ @@ -789,7 +792,7 @@ public HiveSubQRemoveRelBuilder project(Iterable nodes) { * @param fieldNames field names for expressions */ public HiveSubQRemoveRelBuilder project(Iterable nodes, - Iterable fieldNames) { + Iterable fieldNames) { return project(nodes, fieldNames, false); } @@ -817,9 +820,9 @@ public HiveSubQRemoveRelBuilder project(Iterable nodes, * @param force create project even if it is identity */ public HiveSubQRemoveRelBuilder project( - Iterable nodes, - Iterable fieldNames, - boolean force) { + Iterable nodes, + Iterable fieldNames, + boolean force) { final List names = new ArrayList<>(); final List exprList = Lists.newArrayList(nodes); final Iterator nameIterator = fieldNames.iterator(); @@ -837,17 +840,17 @@ public HiveSubQRemoveRelBuilder project( // create "virtual" row type for project only rename fields final Frame frame = stack.pop(); final RelDataType rowType = - RexUtil.createStructType(cluster.getTypeFactory(), exprList, - names, SqlValidatorUtil.F_SUGGESTER); + RexUtil.createStructType(cluster.getTypeFactory(), exprList, + names, SqlValidatorUtil.F_SUGGESTER); stack.push( - new Frame(frame.rel, - ImmutableList.of(Pair.of(frame.right.get(0).left, rowType)))); + new Frame(frame.rel, + ImmutableList.of(Pair.of(frame.right.get(0).left, rowType)))); return this; } } final RelNode project = - projectFactory.createProject(build(), ImmutableList.copyOf(exprList), - names); + projectFactory.createProject(build(), ImmutableList.copyOf(exprList), + names); push(project); return this; } @@ -865,24 +868,24 @@ public HiveSubQRemoveRelBuilder project(RexNode... 
nodes) { */ private String inferAlias(List exprList, RexNode expr) { switch (expr.getKind()) { - case INPUT_REF: - final RexInputRef ref = (RexInputRef) expr; - return peek(0).getRowType().getFieldNames().get(ref.getIndex()); - case CAST: - return inferAlias(exprList, ((RexCall) expr).getOperands().get(0)); - case AS: - final RexCall call = (RexCall) expr; - for (;;) { - final int i = exprList.indexOf(expr); - if (i < 0) { - break; - } - exprList.set(i, call.getOperands().get(0)); + case INPUT_REF: + final RexInputRef ref = (RexInputRef) expr; + return peek(0).getRowType().getFieldNames().get(ref.getIndex()); + case CAST: + return inferAlias(exprList, ((RexCall) expr).getOperands().get(0)); + case AS: + final RexCall call = (RexCall) expr; + for (;;) { + final int i = exprList.indexOf(expr); + if (i < 0) { + break; } - return ((NlsString) ((RexLiteral) call.getOperands().get(1)).getValue()) - .getValue(); - default: - return null; + exprList.set(i, call.getOperands().get(0)); + } + return ((NlsString) ((RexLiteral) call.getOperands().get(1)).getValue()) + .getValue(); + default: + return null; } } @@ -905,26 +908,26 @@ public HiveSubQRemoveRelBuilder aggregate(GroupKey groupKey, Iterable a final List extraNodes = projects(inputRowType); final GroupKeyImpl groupKey_ = (GroupKeyImpl) groupKey; final ImmutableBitSet groupSet = - ImmutableBitSet.of(registerExpressions(extraNodes, groupKey_.nodes)); + ImmutableBitSet.of(registerExpressions(extraNodes, groupKey_.nodes)); final ImmutableList groupSets; if (groupKey_.nodeLists != null) { final int sizeBefore = extraNodes.size(); final SortedSet groupSetSet = - new TreeSet<>(ImmutableBitSet.ORDERING); + new TreeSet<>(ImmutableBitSet.ORDERING); for (ImmutableList nodeList : groupKey_.nodeLists) { final ImmutableBitSet groupSet2 = - ImmutableBitSet.of(registerExpressions(extraNodes, nodeList)); + ImmutableBitSet.of(registerExpressions(extraNodes, nodeList)); if (!groupSet.contains(groupSet2)) { throw new IllegalArgumentException("group set element " + nodeList - + " must be a subset of group key"); + + " must be a subset of group key"); } groupSetSet.add(groupSet2); } groupSets = ImmutableList.copyOf(groupSetSet); if (extraNodes.size() > sizeBefore) { throw new IllegalArgumentException( - "group sets contained expressions not in group key: " - + extraNodes.subList(sizeBefore, extraNodes.size())); + "group sets contained expressions not in group key: " + + extraNodes.subList(sizeBefore, extraNodes.size())); } } else { groupSets = ImmutableList.of(groupSet); @@ -949,10 +952,10 @@ public HiveSubQRemoveRelBuilder aggregate(GroupKey groupKey, Iterable a final AggCallImpl aggCall1 = (AggCallImpl) aggCall; final List args = registerExpressions(extraNodes, aggCall1.operands); final int filterArg = aggCall1.filter == null ? 
-1 - : registerExpression(extraNodes, aggCall1.filter); + : registerExpression(extraNodes, aggCall1.filter); aggregateCall = - AggregateCall.create(aggCall1.aggFunction, aggCall1.distinct, args, - filterArg, groupSet.cardinality(), r, null, aggCall1.alias); + AggregateCall.create(aggCall1.aggFunction, aggCall1.distinct, args, + filterArg, groupSet.cardinality(), r, null, aggCall1.alias); } else { aggregateCall = ((AggCallImpl2) aggCall).aggregateCall; } @@ -964,7 +967,7 @@ public HiveSubQRemoveRelBuilder aggregate(GroupKey groupKey, Iterable a assert groupSet.contains(set); } RelNode aggregate = aggregateFactory.createAggregate(r, - groupKey_.indicator, groupSet, groupSets, aggregateCalls); + groupKey_.indicator, groupSet, groupSets, aggregateCalls); push(aggregate); return this; } @@ -1002,22 +1005,22 @@ private HiveSubQRemoveRelBuilder setOp(boolean all, SqlKind kind, int n) { inputs.add(0, build()); } switch (kind) { - case UNION: - case INTERSECT: - case EXCEPT: + case UNION: + case INTERSECT: + case EXCEPT: if (n < 1) { throw new IllegalArgumentException( "bad INTERSECT/UNION/EXCEPT input count"); - } - break; - default: - throw new AssertionError("bad setOp " + kind); + } + break; + default: + throw new AssertionError("bad setOp " + kind); } switch (n) { - case 1: - return push(inputs.get(0)); - default: - return push(setOpFactory.createSetOp(kind, inputs, all)); + case 1: + return push(inputs.get(0)); + default: + return push(setOpFactory.createSetOp(kind, inputs, all)); } } @@ -1079,16 +1082,16 @@ public HiveSubQRemoveRelBuilder minus(boolean all, int n) { /** Creates a {@link org.apache.calcite.rel.core.Join}. */ public HiveSubQRemoveRelBuilder join(JoinRelType joinType, RexNode condition0, - RexNode... conditions) { + RexNode... conditions) { return join(joinType, Lists.asList(condition0, conditions)); } /** Creates a {@link org.apache.calcite.rel.core.Join} with multiple * conditions. */ public HiveSubQRemoveRelBuilder join(JoinRelType joinType, - Iterable conditions) { + Iterable conditions) { return join(joinType, and(conditions), - ImmutableSet.of()); + ImmutableSet.of()); } public HiveSubQRemoveRelBuilder join(JoinRelType joinType, RexNode condition) { @@ -1099,8 +1102,8 @@ public HiveSubQRemoveRelBuilder join(JoinRelType joinType, RexNode condition) { * a Holder. */ public HiveSubQRemoveRelBuilder variable(Holder v) { v.set((RexCorrelVariable) - getRexBuilder().makeCorrel(peek().getRowType(), - cluster.createCorrel())); + getRexBuilder().makeCorrel(peek().getRowType(), + cluster.createCorrel())); return this; } @@ -1125,22 +1128,21 @@ public HiveSubQRemoveRelBuilder join(JoinRelType joinType, RexNode condition, + " must not be used by left input to correlation"); } switch (joinType) { - case LEFT: - // Correlate does not have an ON clause. - // For a LEFT correlate, predicate must be evaluated first. - // For INNER, we can defer. - stack.push(right); - filter(condition.accept(new Shifter(left.rel, id, right.rel))); - right = stack.pop(); - break; - default: - postCondition = condition; + case LEFT: + // Correlate does not have an ON clause. + // For a LEFT correlate, predicate must be evaluated first. + // For INNER, we can defer. 
+ stack.push(right); + filter(condition.accept(new Shifter(left.rel, id, right.rel))); + right = stack.pop(); + break; + default: + postCondition = condition; } if(createSemiJoin) { join = correlateFactory.createCorrelate(left.rel, right.rel, id, requiredColumns, SemiJoinType.SEMI); - } - else { + } else { join = correlateFactory.createCorrelate(left.rel, right.rel, id, requiredColumns, SemiJoinType.of(joinType)); @@ -1160,8 +1162,8 @@ public HiveSubQRemoveRelBuilder join(JoinRelType joinType, RexNode condition, /** Creates a {@link org.apache.calcite.rel.core.Join} with correlating * variables. */ public HiveSubQRemoveRelBuilder join(JoinRelType joinType, RexNode condition, - Set variablesSet) { - return join(joinType, condition, variablesSet, false) ; + Set variablesSet) { + return join(joinType, condition, variablesSet, false); } /** Creates a {@link org.apache.calcite.rel.core.Join} using USING syntax. @@ -1177,9 +1179,9 @@ public HiveSubQRemoveRelBuilder join(JoinRelType joinType, String... fieldNames) final List conditions = new ArrayList<>(); for (String fieldName : fieldNames) { conditions.add( - call(SqlStdOperatorTable.EQUALS, - field(2, 0, fieldName), - field(2, 1, fieldName))); + call(SqlStdOperatorTable.EQUALS, + field(2, 0, fieldName), + field(2, 1, fieldName))); } return join(joinType, conditions); } @@ -1189,7 +1191,7 @@ public HiveSubQRemoveRelBuilder semiJoin(Iterable conditions) final Frame right = stack.pop(); final Frame left = stack.pop(); final RelNode semiJoin = - semiJoinFactory.createSemiJoin(left.rel, right.rel, and(conditions)); + semiJoinFactory.createSemiJoin(left.rel, right.rel, and(conditions)); stack.push(new Frame(semiJoin, left.right)); return this; } @@ -1203,8 +1205,8 @@ public HiveSubQRemoveRelBuilder semiJoin(RexNode... conditions) { public HiveSubQRemoveRelBuilder as(String alias) { final Frame pair = stack.pop(); stack.push( - new Frame(pair.rel, - ImmutableList.of(Pair.of(alias, pair.right.get(0).right)))); + new Frame(pair.rel, + ImmutableList.of(Pair.of(alias, pair.right.get(0).right)))); return this; } @@ -1223,36 +1225,36 @@ public HiveSubQRemoveRelBuilder as(String alias) { */ public HiveSubQRemoveRelBuilder values(String[] fieldNames, Object... values) { if (fieldNames == null - || fieldNames.length == 0 - || values.length % fieldNames.length != 0 - || values.length < fieldNames.length) { + || fieldNames.length == 0 + || values.length % fieldNames.length != 0 + || values.length < fieldNames.length) { throw new IllegalArgumentException( - "Value count must be a positive multiple of field count"); + "Value count must be a positive multiple of field count"); } final int rowCount = values.length / fieldNames.length; for (Ord fieldName : Ord.zip(fieldNames)) { if (allNull(values, fieldName.i, fieldNames.length)) { throw new IllegalArgumentException("All values of field '" + fieldName.e - + "' are null; cannot deduce type"); + + "' are null; cannot deduce type"); } } final ImmutableList> tupleList = - tupleList(fieldNames.length, values); + tupleList(fieldNames.length, values); final RelDataTypeFactory.FieldInfoBuilder rowTypeBuilder = - cluster.getTypeFactory().builder(); + cluster.getTypeFactory().builder(); for (final Ord fieldName : Ord.zip(fieldNames)) { final String name = - fieldName.e != null ? fieldName.e : "expr$" + fieldName.i; + fieldName.e != null ? 
fieldName.e : "expr$" + fieldName.i; final RelDataType type = cluster.getTypeFactory().leastRestrictive( - new AbstractList() { - public RelDataType get(int index) { - return tupleList.get(index).get(fieldName.i).getType(); - } - - public int size() { - return rowCount; - } - }); + new AbstractList() { + public RelDataType get(int index) { + return tupleList.get(index).get(fieldName.i).getType(); + } + + public int size() { + return rowCount; + } + }); rowTypeBuilder.add(name, type); } final RelDataType rowType = rowTypeBuilder.build(); @@ -1262,7 +1264,7 @@ public int size() { private ImmutableList> tupleList(int columnCount, Object[] values) { final ImmutableList.Builder> listBuilder = - ImmutableList.builder(); + ImmutableList.builder(); final List valueList = new ArrayList<>(); for (int i = 0; i < values.length; i++) { Object value = values[i]; @@ -1296,7 +1298,7 @@ private boolean allNull(Object[] values, int column, int columnCount) { public HiveSubQRemoveRelBuilder empty() { final RelNode input = build(); final RelNode sort = HiveRelFactories.HIVE_SORT_FACTORY.createSort( - input, RelCollations.of(), null, literal(0)); + input, RelCollations.of(), null, literal(0)); return this.push(sort); } @@ -1312,9 +1314,9 @@ public HiveSubQRemoveRelBuilder empty() { */ public HiveSubQRemoveRelBuilder values(RelDataType rowType, Object... columnValues) { final ImmutableList> tupleList = - tupleList(rowType.getFieldCount(), columnValues); + tupleList(rowType.getFieldCount(), columnValues); RelNode values = valuesFactory.createValues(cluster, rowType, - ImmutableList.copyOf(tupleList)); + ImmutableList.copyOf(tupleList)); push(values); return this; } @@ -1329,9 +1331,9 @@ public HiveSubQRemoveRelBuilder values(RelDataType rowType, Object... columnValu * @param rowType Row type */ public HiveSubQRemoveRelBuilder values(Iterable> tupleList, - RelDataType rowType) { + RelDataType rowType) { RelNode values = - valuesFactory.createValues(cluster, rowType, copy(tupleList)); + valuesFactory.createValues(cluster, rowType, copy(tupleList)); push(values); return this; } @@ -1347,14 +1349,13 @@ public HiveSubQRemoveRelBuilder values(RelDataType rowType) { /** Converts an iterable of lists into an immutable list of immutable lists * with the same contents. Returns the same object if possible. */ - private static ImmutableList> - copy(Iterable> tupleList) { + private static ImmutableList> copy(Iterable> tupleList) { final ImmutableList.Builder> builder = - ImmutableList.builder(); + ImmutableList.builder(); int changeCount = 0; for (List literals : tupleList) { final ImmutableList literals2 = - ImmutableList.copyOf(literals); + ImmutableList.copyOf(literals); builder.add(literals2); if (literals != literals2) { ++changeCount; @@ -1408,15 +1409,15 @@ public HiveSubQRemoveRelBuilder sortLimit(int offset, int fetch, RexNode... node * @param nodes Sort expressions */ public HiveSubQRemoveRelBuilder sortLimit(int offset, int fetch, - Iterable nodes) { + Iterable nodes) { final List fieldCollations = new ArrayList<>(); final RelDataType inputRowType = peek().getRowType(); final List extraNodes = projects(inputRowType); final List originalExtraNodes = ImmutableList.copyOf(extraNodes); for (RexNode node : nodes) { fieldCollations.add( - collation(node, RelFieldCollation.Direction.ASCENDING, null, - extraNodes)); + collation(node, RelFieldCollation.Direction.ASCENDING, null, + extraNodes)); } final RexNode offsetNode = offset <= 0 ? null : literal(offset); final RexNode fetchNode = fetch < 0 ? 
null : literal(fetch); @@ -1437,8 +1438,8 @@ public HiveSubQRemoveRelBuilder sortLimit(int offset, int fetch, stack.pop(); push(sort2.getInput()); final RelNode sort = - sortFactory.createSort(build(), sort2.collation, - offsetNode, fetchNode); + sortFactory.createSort(build(), sort2.collation, + offsetNode, fetchNode); push(sort); return this; } @@ -1451,8 +1452,8 @@ public HiveSubQRemoveRelBuilder sortLimit(int offset, int fetch, stack.pop(); push(sort2.getInput()); final RelNode sort = - sortFactory.createSort(build(), sort2.collation, - offsetNode, fetchNode); + sortFactory.createSort(build(), sort2.collation, + offsetNode, fetchNode); push(sort); project(project.getProjects()); return this; @@ -1464,8 +1465,8 @@ public HiveSubQRemoveRelBuilder sortLimit(int offset, int fetch, project(extraNodes); } final RelNode sort = - sortFactory.createSort(build(), RelCollations.of(fieldCollations), - offsetNode, fetchNode); + sortFactory.createSort(build(), RelCollations.of(fieldCollations), + offsetNode, fetchNode); push(sort); if (addedFields) { project(originalExtraNodes); @@ -1475,26 +1476,27 @@ public HiveSubQRemoveRelBuilder sortLimit(int offset, int fetch, private static RelFieldCollation collation(RexNode node, RelFieldCollation.Direction direction, - RelFieldCollation.NullDirection nullDirection, List extraNodes) { + RelFieldCollation.NullDirection nullDirection, + List extraNodes) { switch (node.getKind()) { - case INPUT_REF: - return new RelFieldCollation(((RexInputRef) node).getIndex(), direction, - Util.first(nullDirection, direction.defaultNullDirection())); - case DESCENDING: - return collation(((RexCall) node).getOperands().get(0), - RelFieldCollation.Direction.DESCENDING, - nullDirection, extraNodes); - case NULLS_FIRST: - return collation(((RexCall) node).getOperands().get(0), direction, - RelFieldCollation.NullDirection.FIRST, extraNodes); - case NULLS_LAST: - return collation(((RexCall) node).getOperands().get(0), direction, - RelFieldCollation.NullDirection.LAST, extraNodes); - default: - final int fieldIndex = extraNodes.size(); - extraNodes.add(node); - return new RelFieldCollation(fieldIndex, direction, - Util.first(nullDirection, direction.defaultNullDirection())); + case INPUT_REF: + return new RelFieldCollation(((RexInputRef) node).getIndex(), direction, + Util.first(nullDirection, direction.defaultNullDirection())); + case DESCENDING: + return collation(((RexCall) node).getOperands().get(0), + RelFieldCollation.Direction.DESCENDING, + nullDirection, extraNodes); + case NULLS_FIRST: + return collation(((RexCall) node).getOperands().get(0), direction, + RelFieldCollation.NullDirection.FIRST, extraNodes); + case NULLS_LAST: + return collation(((RexCall) node).getOperands().get(0), direction, + RelFieldCollation.NullDirection.LAST, extraNodes); + default: + final int fieldIndex = extraNodes.size(); + extraNodes.add(node); + return new RelFieldCollation(fieldIndex, direction, + Util.first(nullDirection, direction.defaultNullDirection())); } } @@ -1509,7 +1511,7 @@ private static RelFieldCollation collation(RexNode node, public HiveSubQRemoveRelBuilder convert(RelDataType castRowType, boolean rename) { final RelNode r = build(); final RelNode r2 = - RelOptUtil.createCastRel(r, castRowType, rename, projectFactory); + RelOptUtil.createCastRel(r, castRowType, rename, projectFactory); push(r2); return this; } @@ -1528,14 +1530,14 @@ public HiveSubQRemoveRelBuilder permute(Mapping mapping) { } public HiveSubQRemoveRelBuilder aggregate(GroupKey groupKey, - List aggregateCalls) { + 
List aggregateCalls) { return aggregate(groupKey, - Lists.transform( - aggregateCalls, new Function() { - public AggCall apply(AggregateCall input) { - return new AggCallImpl2(input); - } - })); + Lists.transform( + aggregateCalls, new Function() { + public AggCall apply(AggregateCall input) { + return new AggCallImpl2(input); + } + })); } /** Clears the stack. @@ -1548,8 +1550,8 @@ public void clear() { protected String getAlias() { final Frame frame = stack.peek(); return frame.right.size() == 1 - ? frame.right.get(0).left - : null; + ? frame.right.get(0).left + : null; } /** Information necessary to create a call to an aggregate function. @@ -1570,10 +1572,10 @@ protected String getAlias() { /** Implementation of {@link RelBuilder.GroupKey}. */ protected static class GroupKeyImpl implements GroupKey { - final ImmutableList nodes; - final boolean indicator; - final ImmutableList> nodeLists; - final String alias; + private final ImmutableList nodes; + private final boolean indicator; + private final ImmutableList> nodeLists; + private final String alias; GroupKeyImpl(ImmutableList nodes, boolean indicator, ImmutableList> nodeLists, String alias) { @@ -1589,8 +1591,8 @@ protected String getAlias() { public GroupKey alias(String alias) { return Objects.equals(this.alias, alias) - ? this - : new GroupKeyImpl(nodes, indicator, nodeLists, alias); + ? this + : new GroupKeyImpl(nodes, indicator, nodeLists, alias); } } @@ -1626,16 +1628,16 @@ public GroupKey alias(String alias) { * *

Describes a previously created relational expression and * information about how table aliases map into its row type. */ - private static class Frame { + private static final class Frame { static final Function, List> FN = - new Function, List>() { - public List apply(Pair input) { - return input.right.getFieldList(); - } - }; + new Function, List>() { + public List apply(Pair input) { + return input.right.getFieldList(); + } + }; - final RelNode rel; - final ImmutableList> right; + private final RelNode rel; + private final ImmutableList> right; private Frame(RelNode rel, ImmutableList> pairs) { this.rel = rel; diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelDecorrelator.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelDecorrelator.java index 98d140fc8b..c9e02ea6ea 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelDecorrelator.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelDecorrelator.java @@ -136,10 +136,10 @@ /** * NOTE: this whole logic is replicated from Calcite's RelDecorrelator * and is exteneded to make it suitable for HIVE - * TODO: * We should get rid of this and replace it with Calcite's RelDecorrelator * once that works with Join, Project etc instead of LogicalJoin, LogicalProject. - * Also we need to have CALCITE-1511 fixed + * At this point this has differed from Calcite's version significantly so cannot + * get rid of this. * * RelDecorrelator replaces all correlated expressions (corExp) in a relational * expression (RelNode) tree with non-correlated expressions that are produced @@ -156,7 +156,7 @@ * de-correlator * */ -public class HiveRelDecorrelator implements ReflectiveVisitor { +public final class HiveRelDecorrelator implements ReflectiveVisitor { //~ Static fields/initializers --------------------------------------------- protected static final Logger LOG = LoggerFactory.getLogger( @@ -191,7 +191,7 @@ //~ Constructors ----------------------------------------------------------- - private HiveRelDecorrelator ( + private HiveRelDecorrelator( RelOptCluster cluster, CorelMap cm, Context context) { @@ -698,225 +698,223 @@ private static RexLiteral projectedLiteral(RelNode rel, int i) { } public Frame decorrelateRel(HiveAggregate rel) throws SemanticException{ - { - if (rel.getGroupType() != Aggregate.Group.SIMPLE) { - throw new AssertionError(Bug.CALCITE_461_FIXED); - } - // - // Rewrite logic: - // - // 1. Permute the group by keys to the front. - // 2. If the input of an aggregate produces correlated variables, - // add them to the group list. - // 3. Change aggCalls to reference the new project. - // + if (rel.getGroupType() != Aggregate.Group.SIMPLE) { + throw new AssertionError(Bug.CALCITE_461_FIXED); + } + // + // Rewrite logic: + // + // 1. Permute the group by keys to the front. + // 2. If the input of an aggregate produces correlated variables, + // add them to the group list. + // 3. Change aggCalls to reference the new project. + // - // Aggregate itself should not reference cor vars. - assert !cm.mapRefRelToCorRef.containsKey(rel); + // Aggregate itself should not reference cor vars. + assert !cm.mapRefRelToCorRef.containsKey(rel); - final RelNode oldInput = rel.getInput(); - final Frame frame = getInvoke(oldInput, rel); - if (frame == null) { - // If input has not been rewritten, do not rewrite this rel. 
- return null; - } - //assert !frame.corVarOutputPos.isEmpty(); - final RelNode newInput = frame.r; + final RelNode oldInput = rel.getInput(); + final Frame frame = getInvoke(oldInput, rel); + if (frame == null) { + // If input has not been rewritten, do not rewrite this rel. + return null; + } + //assert !frame.corVarOutputPos.isEmpty(); + final RelNode newInput = frame.r; - // map from newInput - Map mapNewInputToProjOutputs = new HashMap<>(); - final int oldGroupKeyCount = rel.getGroupSet().cardinality(); + // map from newInput + Map mapNewInputToProjOutputs = new HashMap<>(); + final int oldGroupKeyCount = rel.getGroupSet().cardinality(); - // Project projects the original expressions, - // plus any correlated variables the input wants to pass along. - final List> projects = Lists.newArrayList(); + // Project projects the original expressions, + // plus any correlated variables the input wants to pass along. + final List> projects = Lists.newArrayList(); - List newInputOutput = - newInput.getRowType().getFieldList(); + List newInputOutput = + newInput.getRowType().getFieldList(); - int newPos = 0; + int newPos = 0; - // oldInput has the original group by keys in the front. - final NavigableMap omittedConstants = new TreeMap<>(); - for (int i = 0; i < oldGroupKeyCount; i++) { - final RexLiteral constant = projectedLiteral(newInput, i); - if (constant != null) { - // Exclude constants. Aggregate({true}) occurs because Aggregate({}) - // would generate 1 row even when applied to an empty table. - omittedConstants.put(i, constant); - continue; - } - int newInputPos = frame.oldToNewOutputs.get(i); - projects.add(RexInputRef.of2(newInputPos, newInputOutput)); - mapNewInputToProjOutputs.put(newInputPos, newPos); - newPos++; + // oldInput has the original group by keys in the front. + final NavigableMap omittedConstants = new TreeMap<>(); + for (int i = 0; i < oldGroupKeyCount; i++) { + final RexLiteral constant = projectedLiteral(newInput, i); + if (constant != null) { + // Exclude constants. Aggregate({true}) occurs because Aggregate({}) + // would generate 1 row even when applied to an empty table. + omittedConstants.put(i, constant); + continue; } + int newInputPos = frame.oldToNewOutputs.get(i); + projects.add(RexInputRef.of2(newInputPos, newInputOutput)); + mapNewInputToProjOutputs.put(newInputPos, newPos); + newPos++; + } - final SortedMap corDefOutputs = new TreeMap<>(); - if (!frame.corDefOutputs.isEmpty()) { - // If input produces correlated variables, move them to the front, - // right after any existing GROUP BY fields. + final SortedMap corDefOutputs = new TreeMap<>(); + if (!frame.corDefOutputs.isEmpty()) { + // If input produces correlated variables, move them to the front, + // right after any existing GROUP BY fields. - // Now add the corVars from the input, starting from - // position oldGroupKeyCount. - for (Map.Entry entry - : frame.corDefOutputs.entrySet()) { - projects.add(RexInputRef.of2(entry.getValue(), newInputOutput)); + // Now add the corVars from the input, starting from + // position oldGroupKeyCount. 
+ for (Map.Entry entry + : frame.corDefOutputs.entrySet()) { + projects.add(RexInputRef.of2(entry.getValue(), newInputOutput)); - corDefOutputs.put(entry.getKey(), newPos); - mapNewInputToProjOutputs.put(entry.getValue(), newPos); - newPos++; - } + corDefOutputs.put(entry.getKey(), newPos); + mapNewInputToProjOutputs.put(entry.getValue(), newPos); + newPos++; } + } - // add the remaining fields - final int newGroupKeyCount = newPos; - for (int i = 0; i < newInputOutput.size(); i++) { - if (!mapNewInputToProjOutputs.containsKey(i)) { - projects.add(RexInputRef.of2(i, newInputOutput)); - mapNewInputToProjOutputs.put(i, newPos); - newPos++; - } + // add the remaining fields + final int newGroupKeyCount = newPos; + for (int i = 0; i < newInputOutput.size(); i++) { + if (!mapNewInputToProjOutputs.containsKey(i)) { + projects.add(RexInputRef.of2(i, newInputOutput)); + mapNewInputToProjOutputs.put(i, newPos); + newPos++; } + } - assert newPos == newInputOutput.size(); + assert newPos == newInputOutput.size(); - // This Project will be what the old input maps to, - // replacing any previous mapping from old input). - RelNode newProject = HiveProject.create(newInput, Pair.left(projects), Pair.right(projects)); + // This Project will be what the old input maps to, + // replacing any previous mapping from old input). + RelNode newProject = HiveProject.create(newInput, Pair.left(projects), Pair.right(projects)); - // update mappings: - // oldInput ----> newInput - // - // newProject - // | - // oldInput ----> newInput - // - // is transformed to - // - // oldInput ----> newProject - // | - // newInput - Map combinedMap = Maps.newHashMap(); + // update mappings: + // oldInput ----> newInput + // + // newProject + // | + // oldInput ----> newInput + // + // is transformed to + // + // oldInput ----> newProject + // | + // newInput + Map combinedMap = Maps.newHashMap(); - for (Integer oldInputPos : frame.oldToNewOutputs.keySet()) { - combinedMap.put(oldInputPos, - mapNewInputToProjOutputs.get( - frame.oldToNewOutputs.get(oldInputPos))); - } + for (Integer oldInputPos : frame.oldToNewOutputs.keySet()) { + combinedMap.put(oldInputPos, + mapNewInputToProjOutputs.get( + frame.oldToNewOutputs.get(oldInputPos))); + } - register(oldInput, newProject, combinedMap, corDefOutputs); + register(oldInput, newProject, combinedMap, corDefOutputs); - // now it's time to rewrite the Aggregate - final ImmutableBitSet newGroupSet = ImmutableBitSet.range(newGroupKeyCount); - List newAggCalls = Lists.newArrayList(); - List oldAggCalls = rel.getAggCallList(); + // now it's time to rewrite the Aggregate + final ImmutableBitSet newGroupSet = ImmutableBitSet.range(newGroupKeyCount); + List newAggCalls = Lists.newArrayList(); + List oldAggCalls = rel.getAggCallList(); - int oldInputOutputFieldCount = rel.getGroupSet().cardinality(); - int newInputOutputFieldCount = newGroupSet.cardinality(); + int oldInputOutputFieldCount = rel.getGroupSet().cardinality(); + int newInputOutputFieldCount = newGroupSet.cardinality(); - int i = -1; - for (AggregateCall oldAggCall : oldAggCalls) { - ++i; - List oldAggArgs = oldAggCall.getArgList(); + int i = -1; + for (AggregateCall oldAggCall : oldAggCalls) { + ++i; + List oldAggArgs = oldAggCall.getArgList(); - List aggArgs = Lists.newArrayList(); + List aggArgs = Lists.newArrayList(); - // Adjust the aggregator argument positions. - // Note aggregator does not change input ordering, so the input - // output position mapping can be used to derive the new positions - // for the argument. 
- for (int oldPos : oldAggArgs) { - aggArgs.add(combinedMap.get(oldPos)); - } - final int filterArg = oldAggCall.filterArg < 0 ? oldAggCall.filterArg - : combinedMap.get(oldAggCall.filterArg); + // Adjust the aggregator argument positions. + // Note aggregator does not change input ordering, so the input + // output position mapping can be used to derive the new positions + // for the argument. + for (int oldPos : oldAggArgs) { + aggArgs.add(combinedMap.get(oldPos)); + } + final int filterArg = oldAggCall.filterArg < 0 ? oldAggCall.filterArg + : combinedMap.get(oldAggCall.filterArg); - newAggCalls.add( - oldAggCall.adaptTo(newProject, aggArgs, filterArg, - oldGroupKeyCount, newGroupKeyCount)); + newAggCalls.add( + oldAggCall.adaptTo(newProject, aggArgs, filterArg, + oldGroupKeyCount, newGroupKeyCount)); - // The old to new output position mapping will be the same as that - // of newProject, plus any aggregates that the oldAgg produces. - combinedMap.put( - oldInputOutputFieldCount + i, - newInputOutputFieldCount + i); - } + // The old to new output position mapping will be the same as that + // of newProject, plus any aggregates that the oldAgg produces. + combinedMap.put( + oldInputOutputFieldCount + i, + newInputOutputFieldCount + i); + } - relBuilder.push( - new HiveAggregate(rel.getCluster(), rel.getTraitSet(), newProject, newGroupSet, null, newAggCalls) ); + relBuilder.push( + new HiveAggregate(rel.getCluster(), rel.getTraitSet(), newProject, + newGroupSet, null, newAggCalls)); - if (!omittedConstants.isEmpty()) { - final List postProjects = new ArrayList<>(relBuilder.fields()); - for (Map.Entry entry - : omittedConstants.descendingMap().entrySet()) { - postProjects.add(entry.getKey() + frame.corDefOutputs.size(), - entry.getValue()); - } - relBuilder.project(postProjects); + if (!omittedConstants.isEmpty()) { + final List postProjects = new ArrayList<>(relBuilder.fields()); + for (Map.Entry entry + : omittedConstants.descendingMap().entrySet()) { + postProjects.add(entry.getKey() + frame.corDefOutputs.size(), + entry.getValue()); } - - // Aggregate does not change input ordering so corVars will be - // located at the same position as the input newProject. - return register(rel, relBuilder.build(), combinedMap, corDefOutputs); + relBuilder.project(postProjects); } + + // Aggregate does not change input ordering so corVars will be + // located at the same position as the input newProject. + return register(rel, relBuilder.build(), combinedMap, corDefOutputs); } public Frame decorrelateRel(HiveProject rel) throws SemanticException{ - { - // - // Rewrite logic: - // - // 1. Pass along any correlated variables coming from the input. - // + // + // Rewrite logic: + // + // 1. Pass along any correlated variables coming from the input. + // - final RelNode oldInput = rel.getInput(); - Frame frame = getInvoke(oldInput, rel); - if (frame == null) { - // If input has not been rewritten, do not rewrite this rel. - return null; - } - final List oldProjects = rel.getProjects(); - final List relOutput = rel.getRowType().getFieldList(); + final RelNode oldInput = rel.getInput(); + Frame frame = getInvoke(oldInput, rel); + if (frame == null) { + // If input has not been rewritten, do not rewrite this rel. + return null; + } + final List oldProjects = rel.getProjects(); + final List relOutput = rel.getRowType().getFieldList(); - // LogicalProject projects the original expressions, - // plus any correlated variables the input wants to pass along. 
- final List> projects = Lists.newArrayList(); + // LogicalProject projects the original expressions, + // plus any correlated variables the input wants to pass along. + final List> projects = Lists.newArrayList(); - // If this LogicalProject has correlated reference, create value generator - // and produce the correlated variables in the new output. - if (cm.mapRefRelToCorRef.containsKey(rel)) { - frame = decorrelateInputWithValueGenerator(rel); - } + // If this LogicalProject has correlated reference, create value generator + // and produce the correlated variables in the new output. + if (cm.mapRefRelToCorRef.containsKey(rel)) { + frame = decorrelateInputWithValueGenerator(rel); + } - // LogicalProject projects the original expressions - final Map mapOldToNewOutputs = new HashMap<>(); - int newPos; - for (newPos = 0; newPos < oldProjects.size(); newPos++) { - projects.add( - newPos, - Pair.of( - decorrelateExpr(oldProjects.get(newPos)), - relOutput.get(newPos).getName())); - mapOldToNewOutputs.put(newPos, newPos); - } + // LogicalProject projects the original expressions + final Map mapOldToNewOutputs = new HashMap<>(); + int newPos; + for (newPos = 0; newPos < oldProjects.size(); newPos++) { + projects.add( + newPos, + Pair.of( + decorrelateExpr(oldProjects.get(newPos)), + relOutput.get(newPos).getName())); + mapOldToNewOutputs.put(newPos, newPos); + } - // Project any correlated variables the input wants to pass along. - final SortedMap corDefOutputs = new TreeMap<>(); - for (Map.Entry entry : frame.corDefOutputs.entrySet()) { - projects.add( - RexInputRef.of2(entry.getValue(), - frame.r.getRowType().getFieldList())); - corDefOutputs.put(entry.getKey(), newPos); - newPos++; - } + // Project any correlated variables the input wants to pass along. + final SortedMap corDefOutputs = new TreeMap<>(); + for (Map.Entry entry : frame.corDefOutputs.entrySet()) { + projects.add( + RexInputRef.of2(entry.getValue(), + frame.r.getRowType().getFieldList())); + corDefOutputs.put(entry.getKey(), newPos); + newPos++; + } - RelNode newProject = HiveProject.create(frame.r, Pair.left(projects), SqlValidatorUtil.uniquify(Pair.right(projects))); + RelNode newProject = HiveProject.create(frame.r, Pair.left(projects), + SqlValidatorUtil.uniquify(Pair.right(projects))); - return register(rel, newProject, mapOldToNewOutputs, - corDefOutputs); - } + return register(rel, newProject, mapOldToNewOutputs, + corDefOutputs); } /** * Rewrite LogicalProject. @@ -1118,10 +1116,10 @@ private Frame decorrelateInputWithValueGenerator(RelNode rel) { // Try to populate correlation variables using local fields. // This means that we do not need a value generator. if (rel instanceof Filter) { - SortedMap map = new TreeMap<>(); + SortedMap coreMap = new TreeMap<>(); for (CorRef correlation : corVarList) { final CorDef def = correlation.def(); - if (corDefOutputs.containsKey(def) || map.containsKey(def)) { + if (corDefOutputs.containsKey(def) || coreMap.containsKey(def)) { continue; } try { @@ -1132,15 +1130,15 @@ private Frame decorrelateInputWithValueGenerator(RelNode rel) { // is generated def.setPredicateKind((SqlOperator) ((Pair)((Pair)e.getNode()).getValue()).getKey()); def.setIsLeft((boolean)((Pair)((Pair) e.getNode()).getValue()).getValue()); - map.put(def, (Integer)((Pair) e.getNode()).getKey()); + coreMap.put(def, (Integer)((Pair) e.getNode()).getKey()); } } // If all correlation variables are now satisfied, skip creating a value // generator. 
- if (map.size() == corVarList.size()) { - map.putAll(frame.corDefOutputs); + if (coreMap.size() == corVarList.size()) { + coreMap.putAll(frame.corDefOutputs); return register(oldInput, frame.r, - frame.oldToNewOutputs, map); + frame.oldToNewOutputs, coreMap); } } @@ -1149,14 +1147,14 @@ private Frame decorrelateInputWithValueGenerator(RelNode rel) { // can directly add positions into corDefOutputs since join // does not change the output ordering from the inputs. - RelNode valueGen = + RelNode valueGenRel = createValueGenerator( corVarList, leftInputOutputCount, corDefOutputs); RelNode join = - LogicalJoin.create(frame.r, valueGen, rexBuilder.makeLiteral(true), + LogicalJoin.create(frame.r, valueGenRel, rexBuilder.makeLiteral(true), ImmutableSet.of(), JoinRelType.INNER); // LogicalJoin or LogicalFilter does not change the old input ordering. All @@ -1208,23 +1206,23 @@ private void findCorrelationEquivalent(CorRef correlation, RexNode e) private boolean references(RexNode e, CorRef correlation) { switch (e.getKind()) { - case CAST: - final RexNode operand = ((RexCall) e).getOperands().get(0); - if (isWidening(e.getType(), operand.getType())) { - return references(operand, correlation); - } - return false; - case FIELD_ACCESS: - final RexFieldAccess f = (RexFieldAccess) e; - if (f.getField().getIndex() == correlation.field - && f.getReferenceExpr() instanceof RexCorrelVariable) { - if (((RexCorrelVariable) f.getReferenceExpr()).id == correlation.corr) { - return true; - } + case CAST: + final RexNode operand = ((RexCall) e).getOperands().get(0); + if (isWidening(e.getType(), operand.getType())) { + return references(operand, correlation); + } + return false; + case FIELD_ACCESS: + final RexFieldAccess f = (RexFieldAccess) e; + if (f.getField().getIndex() == correlation.field + && f.getReferenceExpr() instanceof RexCorrelVariable) { + if (((RexCorrelVariable) f.getReferenceExpr()).id == correlation.corr) { + return true; } - // fall through - default: - return false; + } + // fall through + default: + return false; } } @@ -1241,69 +1239,70 @@ private boolean isWidening(RelDataType type, RelDataType type1) { } public Frame decorrelateRel(HiveFilter rel) throws SemanticException { - { - // - // Rewrite logic: - // - // 1. If a LogicalFilter references a correlated field in its filter - // condition, rewrite the LogicalFilter to be - // LogicalFilter - // LogicalJoin(cross product) - // OriginalFilterInput - // ValueGenerator(produces distinct sets of correlated variables) - // and rewrite the correlated fieldAccess in the filter condition to - // reference the LogicalJoin output. - // - // 2. If LogicalFilter does not reference correlated variables, simply - // rewrite the filter condition using new input. - // + // + // Rewrite logic: + // + // 1. If a LogicalFilter references a correlated field in its filter + // condition, rewrite the LogicalFilter to be + // LogicalFilter + // LogicalJoin(cross product) + // OriginalFilterInput + // ValueGenerator(produces distinct sets of correlated variables) + // and rewrite the correlated fieldAccess in the filter condition to + // reference the LogicalJoin output. + // + // 2. If LogicalFilter does not reference correlated variables, simply + // rewrite the filter condition using new input. + // - final RelNode oldInput = rel.getInput(); - Frame frame = getInvoke(oldInput, rel); - if (frame == null) { - // If input has not been rewritten, do not rewrite this rel. 
- return null; - } + final RelNode oldInput = rel.getInput(); + Frame frame = getInvoke(oldInput, rel); + if (frame == null) { + // If input has not been rewritten, do not rewrite this rel. + return null; + } + + Frame oldInputFrame = frame; + // If this LogicalFilter has correlated reference, create value generator + // and produce the correlated variables in the new output. + if (cm.mapRefRelToCorRef.containsKey(rel)) { + frame = decorrelateInputWithValueGenerator(rel); + } + + boolean valueGenerator = true; + if(frame.r == oldInputFrame.r) { + // this means correated value generator wasn't generated + valueGenerator = false; + } - Frame oldInputFrame = frame; - // If this LogicalFilter has correlated reference, create value generator - // and produce the correlated variables in the new output. - if (cm.mapRefRelToCorRef.containsKey(rel)) { - frame = decorrelateInputWithValueGenerator(rel); - } - - boolean valueGenerator = true; - if(frame.r == oldInputFrame.r) { - // this means correated value generator wasn't generated - valueGenerator = false; - } - - if(oldInput instanceof LogicalCorrelate && ((LogicalCorrelate) oldInput).getJoinType() == SemiJoinType.SEMI - && !cm.mapRefRelToCorRef.containsKey(rel)) { - // this conditions need to be pushed into semi-join since this condition - // corresponds to IN - HiveSemiJoin join = ((HiveSemiJoin)frame.r); - final List conditions = new ArrayList<>(); - RexNode joinCond = join.getCondition(); - conditions.add(joinCond); - conditions.add(decorrelateExpr(rel.getCondition(), valueGenerator)); - final RexNode condition = - RexUtil.composeConjunction(rexBuilder, conditions, false); - RelNode newRel = HiveSemiJoin.getSemiJoin(frame.r.getCluster(), frame.r.getTraitSet(), join.getLeft(), join.getRight(), - condition,join.getLeftKeys(), join.getRightKeys()); - return register(rel, newRel, frame.oldToNewOutputs, frame.corDefOutputs); - } - // Replace the filter expression to reference output of the join - // Map filter to the new filter over join - relBuilder.push(frame.r).filter( - (decorrelateExpr(rel.getCondition(), valueGenerator))); - // Filter does not change the input ordering. - // Filter rel does not permute the input. - // All corvars produced by filter will have the same output positions in the - // input rel. - return register(rel, relBuilder.build(), frame.oldToNewOutputs, - frame.corDefOutputs); + if(oldInput instanceof LogicalCorrelate + && ((LogicalCorrelate) oldInput).getJoinType() == SemiJoinType.SEMI + && !cm.mapRefRelToCorRef.containsKey(rel)) { + // this conditions need to be pushed into semi-join since this condition + // corresponds to IN + HiveSemiJoin join = ((HiveSemiJoin)frame.r); + final List conditions = new ArrayList<>(); + RexNode joinCond = join.getCondition(); + conditions.add(joinCond); + conditions.add(decorrelateExpr(rel.getCondition(), valueGenerator)); + final RexNode condition = + RexUtil.composeConjunction(rexBuilder, conditions, false); + + RelNode newRel = HiveSemiJoin.getSemiJoin(frame.r.getCluster(), frame.r.getTraitSet(), + join.getLeft(), join.getRight(), condition, join.getLeftKeys(), join.getRightKeys()); + + return register(rel, newRel, frame.oldToNewOutputs, frame.corDefOutputs); } + // Replace the filter expression to reference output of the join + // Map filter to the new filter over join + relBuilder.push(frame.r).filter( + (decorrelateExpr(rel.getCondition(), valueGenerator))); + // Filter does not change the input ordering. + // Filter rel does not permute the input. 
+ // All corvars produced by filter will have the same output positions in the + // input rel. + return register(rel, relBuilder.build(), frame.oldToNewOutputs, + frame.corDefOutputs); } /** @@ -1348,7 +1347,8 @@ public Frame decorrelateRel(LogicalFilter rel) { valueGenerator = false; } - if(oldInput instanceof LogicalCorrelate && ((LogicalCorrelate) oldInput).getJoinType() == SemiJoinType.SEMI + if(oldInput instanceof LogicalCorrelate + && ((LogicalCorrelate) oldInput).getJoinType() == SemiJoinType.SEMI && !cm.mapRefRelToCorRef.containsKey(rel)) { // this conditions need to be pushed into semi-join since this condition // corresponds to IN @@ -1359,8 +1359,8 @@ public Frame decorrelateRel(LogicalFilter rel) { conditions.add(decorrelateExpr(rel.getCondition(), valueGenerator)); final RexNode condition = RexUtil.composeConjunction(rexBuilder, conditions, false); - RelNode newRel = HiveSemiJoin.getSemiJoin(frame.r.getCluster(), frame.r.getTraitSet(), join.getLeft(), join.getRight(), - condition,join.getLeftKeys(), join.getRightKeys()); + RelNode newRel = HiveSemiJoin.getSemiJoin(frame.r.getCluster(), frame.r.getTraitSet(), + join.getLeft(), join.getRight(), condition, join.getLeftKeys(), join.getRightKeys()); return register(rel, newRel, frame.oldToNewOutputs, frame.corDefOutputs); } @@ -1443,8 +1443,7 @@ public Frame decorrelateRel(LogicalCorrelate rel) { RexInputRef.of(newLeftPos, newLeftOutput), new RexInputRef(newLeftFieldCount + newRightPos, newRightOutput.get(newRightPos).getType()))); - } - else { + } else { conditions.add( rexBuilder.makeCall(callOp, new RexInputRef(newLeftFieldCount + newRightPos, @@ -1488,13 +1487,12 @@ public Frame decorrelateRel(LogicalCorrelate rel) { final List leftKeys = new ArrayList(); final List rightKeys = new ArrayList(); - RelNode[] inputRels = new RelNode[] { leftFrame.r, rightFrame.r}; - newJoin = HiveSemiJoin.getSemiJoin(rel.getCluster(), rel.getCluster().traitSetOf(HiveRelNode.CONVENTION), - leftFrame.r, rightFrame.r, condition, ImmutableIntList.copyOf(leftKeys), - ImmutableIntList.copyOf(rightKeys)); + RelNode[] inputRels = new RelNode[] {leftFrame.r, rightFrame.r}; + newJoin = HiveSemiJoin.getSemiJoin(rel.getCluster(), + rel.getCluster().traitSetOf(HiveRelNode.CONVENTION), leftFrame.r, rightFrame.r, + condition, ImmutableIntList.copyOf(leftKeys), ImmutableIntList.copyOf(rightKeys)); - } - else { + } else { // Right input positions are shifted by newLeftFieldCount. 
for (int i = 0; i < oldRightFieldCount; i++) { mapOldToNewOutputs.put( @@ -1531,7 +1529,8 @@ public Frame decorrelateRel(HiveJoin rel) throws SemanticException{ return null; } - final RelNode newJoin = HiveJoin.getJoin(rel.getCluster(), leftFrame.r, rightFrame.r, decorrelateExpr(rel.getCondition()), rel.getJoinType() ); + final RelNode newJoin = HiveJoin.getJoin(rel.getCluster(), leftFrame.r, rightFrame.r, + decorrelateExpr(rel.getCondition()), rel.getJoinType()); // Create the mapping between the output of the old correlation rel // and the new join rel @@ -1589,7 +1588,7 @@ public Frame decorrelateRel(LogicalJoin rel) { } final RelNode newJoin = HiveJoin.getJoin(rel.getCluster(), leftFrame.r, - rightFrame.r, decorrelateExpr(rel.getCondition()), rel.getJoinType() ); + rightFrame.r, decorrelateExpr(rel.getCondition()), rel.getJoinType()); // Create the mapping between the output of the old correlation rel // and the new join rel @@ -1838,7 +1837,7 @@ private boolean checkCorVars( } /** - * Remove correlated variables from the tree at root corRel + * Remove correlated variables from the tree at root corRel. * * @param correlate Correlator */ @@ -1949,7 +1948,7 @@ public void setValueGenerator(boolean valueGenerator) { final List newOperands = new ArrayList<>(); newOperands.add(o0); newOperands.add(o1); - boolean[] update = { false }; + boolean[] update = {false}; List clonedOperands = visitList(newOperands, update); return relBuilder.call(call.getOperator(), clonedOperands); @@ -2003,13 +2002,13 @@ private RexNode decorrFieldAccess(RexFieldAccess fieldAccess) { /** Shuttle that removes correlations. */ private class RemoveCorrelationRexShuttle extends RexShuttle { - final RexBuilder rexBuilder; - final RelDataTypeFactory typeFactory; - final boolean projectPulledAboveLeftCorrelator; - final RexInputRef nullIndicator; - final ImmutableSet isCount; + private final RexBuilder rexBuilder; + private final RelDataTypeFactory typeFactory; + private final boolean projectPulledAboveLeftCorrelator; + private final RexInputRef nullIndicator; + private final ImmutableSet isCount; - public RemoveCorrelationRexShuttle( + RemoveCorrelationRexShuttle( RexBuilder rexBuilder, boolean projectPulledAboveLeftCorrelator, RexInputRef nullIndicator, @@ -2204,7 +2203,7 @@ private RexNode createCaseExpression( * AggRel single group */ private final class RemoveSingleAggregateRule extends RelOptRule { - public RemoveSingleAggregateRule() { + RemoveSingleAggregateRule() { super( operand( LogicalAggregate.class, @@ -2257,7 +2256,7 @@ public void onMatch(RelOptRuleCall call) { /** Planner rule that removes correlations for scalar projects. */ private final class RemoveCorrelationForScalarProjectRule extends RelOptRule { - public RemoveCorrelationForScalarProjectRule() { + RemoveCorrelationForScalarProjectRule() { super( operand(LogicalCorrelate.class, operand(RelNode.class, any()), @@ -2456,7 +2455,7 @@ public void onMatch(RelOptRuleCall call) { /** Planner rule that removes correlations for scalar aggregates. */ private final class RemoveCorrelationForScalarAggregateRule extends RelOptRule { - public RemoveCorrelationForScalarAggregateRule() { + RemoveCorrelationForScalarAggregateRule() { super( operand(LogicalCorrelate.class, operand(RelNode.class, any()), @@ -2838,9 +2837,9 @@ public void onMatch(RelOptRuleCall call) { /** Planner rule that adjusts projects when counts are added. 
*/ private final class AdjustProjectForCountAggregateRule extends RelOptRule { - final boolean flavor; + private final boolean flavor; - public AdjustProjectForCountAggregateRule(boolean flavor) { + AdjustProjectForCountAggregateRule(boolean flavor) { super( flavor ? operand(LogicalCorrelate.class, @@ -2976,9 +2975,9 @@ private void onMatch2( * {@link CorRef#uniqueKey}. */ static class CorRef implements Comparable { - public final int uniqueKey; - public final CorrelationId corr; - public final int field; + private final int uniqueKey; + private final CorrelationId corr; + private final int field; CorRef(CorrelationId corr, int field, int uniqueKey) { this.corr = corr; @@ -3021,8 +3020,8 @@ public CorDef def() { /** A correlation and a field. */ static class CorDef implements Comparable { - public final CorrelationId corr; - public final int field; + private final CorrelationId corr; + private final int field; private SqlOperator predicateKind; // this indicates if corr var is left operand of rex call or not @@ -3100,7 +3099,7 @@ public void setIsLeft(boolean isLeft) { * updated. * * */ - private static class CorelMap { + private static final class CorelMap { private final Multimap mapRefRelToCorRef; private final SortedMap mapCorToCorRel; private final Map mapFieldAccessToCorRef; @@ -3155,8 +3154,10 @@ public boolean hasCorrelation() { } private static class findIfValueGenRequired extends HiveRelShuttleImpl { - private boolean mightRequireValueGen ; - findIfValueGenRequired() { this.mightRequireValueGen = true; } + private boolean mightRequireValueGen; + findIfValueGenRequired() { + this.mightRequireValueGen = true; + } private boolean hasRexOver(List projects) { for(RexNode expr : projects) { @@ -3200,8 +3201,7 @@ public RelNode visit(HiveIntersect rel) { if(!(hasRexOver(((HiveProject)rel).getProjects()))) { mightRequireValueGen = false; return super.visit(rel); - } - else { + } else { mightRequireValueGen = true; return rel; } @@ -3210,8 +3210,7 @@ public RelNode visit(HiveIntersect rel) { if(!(hasRexOver(((LogicalProject)rel).getProjects()))) { mightRequireValueGen = false; return super.visit(rel); - } - else { + } else { mightRequireValueGen = true; return rel; } @@ -3219,12 +3218,10 @@ public RelNode visit(HiveIntersect rel) { @Override public RelNode visit(HiveAggregate rel) { // if there are aggregate functions or grouping sets we will need // value generator - if((((HiveAggregate)rel).getAggCallList().isEmpty() == true - && ((HiveAggregate)rel).indicator == false)) { + if(rel.getAggCallList().isEmpty() && !rel.indicator) { this.mightRequireValueGen = false; return super.visit(rel); - } - else { + } else { // need to reset to true in case previous aggregate/project // has set it to false this.mightRequireValueGen = true; @@ -3232,12 +3229,10 @@ public RelNode visit(HiveIntersect rel) { } } @Override public RelNode visit(LogicalAggregate rel) { - if((((LogicalAggregate)rel).getAggCallList().isEmpty() == true - && ((LogicalAggregate)rel).indicator == false)) { + if(rel.getAggCallList().isEmpty() && !rel.indicator) { this.mightRequireValueGen = false; return super.visit(rel); - } - else { + } else { // need to reset to true in case previous aggregate/project // has set it to false this.mightRequireValueGen = true; @@ -3257,10 +3252,10 @@ public boolean traverse(RelNode root) { } /** Builds a {@link org.apache.calcite.sql2rel.RelDecorrelator.CorelMap}. 
*/ private static class CorelMapBuilder extends HiveRelShuttleImpl { - final SortedMap mapCorToCorRel = + private final SortedMap mapCorToCorRel = new TreeMap<>(); - final SortedSetMultimap mapRefRelToCorRef = + private final SortedSetMultimap mapRefRelToCorRef = Multimaps.newSortedSetMultimap( new HashMap>(), new Supplier>() { @@ -3270,12 +3265,12 @@ public boolean traverse(RelNode root) { } }); - final Map mapFieldAccessToCorVar = new HashMap<>(); + private final Map mapFieldAccessToCorVar = new HashMap<>(); - final Holder offset = Holder.of(0); - int corrIdGenerator = 0; + private final Holder offset = Holder.of(0); + private int corrIdGenerator = 0; - final List stack = new ArrayList<>(); + private final List stack = new ArrayList<>(); /** Creates a CorelMap by iterating over a {@link RelNode} tree. */ CorelMap build(RelNode rel) { @@ -3401,9 +3396,9 @@ public RelNode visit(final HiveFilter filter) { * and where to find the output fields and correlation variables * among its output fields. */ static class Frame { - final RelNode r; - final ImmutableSortedMap corDefOutputs; - final ImmutableSortedMap oldToNewOutputs; + private final RelNode r; + private final ImmutableSortedMap corDefOutputs; + private final ImmutableSortedMap oldToNewOutputs; Frame(RelNode oldRel, RelNode r, SortedMap corDefOutputs, Map oldToNewOutputs) { diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveSubQueryRemoveRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveSubQueryRemoveRule.java index 90aab6e2d2..4758a37e1e 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveSubQueryRemoveRule.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveSubQueryRemoveRule.java @@ -19,7 +19,6 @@ import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRuleCall; -import org.apache.calcite.plan.RelOptRuleOperand; import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Aggregate; @@ -42,7 +41,6 @@ import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeName; -import org.apache.calcite.tools.RelBuilderFactory; import org.apache.calcite.util.Pair; import org.apache.calcite.util.Util; @@ -67,7 +65,6 @@ * TODO: * Reason this is replicated instead of using Calcite's is * Calcite creates null literal with null type but hive needs it to be properly typed - * Need fix for Calcite-1493 * *

Sub-queries are represented by {@link RexSubQuery} expressions. * @@ -76,493 +73,491 @@ * the rewrite, and the product of the rewrite will be a {@link Correlate}. * The Correlate can be removed using {@link RelDecorrelator}. */ -public class HiveSubQueryRemoveRule extends RelOptRule{ - - private HiveConf conf; - - public HiveSubQueryRemoveRule(HiveConf conf) { - super(operand(RelNode.class, null, HiveSubQueryFinder.RELNODE_PREDICATE, - any()), - HiveRelFactories.HIVE_BUILDER, "SubQueryRemoveRule:Filter") ; - this.conf = conf; - +public class HiveSubQueryRemoveRule extends RelOptRule { + + private HiveConf conf; + + public HiveSubQueryRemoveRule(HiveConf conf) { + super(operand(RelNode.class, null, HiveSubQueryFinder.RELNODE_PREDICATE, + any()), + HiveRelFactories.HIVE_BUILDER, "SubQueryRemoveRule:Filter"); + this.conf = conf; + } + public void onMatch(RelOptRuleCall call) { + final RelNode relNode = call.rel(0); + final HiveSubQRemoveRelBuilder builder = + new HiveSubQRemoveRelBuilder(null, call.rel(0).getCluster(), null); + + // if subquery is in FILTER + if(relNode instanceof Filter) { + final Filter filter = call.rel(0); + final RexSubQuery e = + RexUtil.SubQueryFinder.find(filter.getCondition()); + assert e != null; + + final RelOptUtil.Logic logic = + LogicVisitor.find(RelOptUtil.Logic.TRUE, + ImmutableList.of(filter.getCondition()), e); + builder.push(filter.getInput()); + final int fieldCount = builder.peek().getRowType().getFieldCount(); + + assert(filter instanceof HiveFilter); + SubqueryConf subqueryConfig = filter.getCluster().getPlanner(). + getContext().unwrap(SubqueryConf.class); + boolean isCorrScalarQuery = subqueryConfig.getCorrScalarRexSQWithAgg().contains(e.rel); + boolean hasNoWindowingAndNoGby = + subqueryConfig.getScalarAggWithoutGbyWindowing().contains(e.rel); + + final RexNode target = apply(e, HiveFilter.getVariablesSet(e), logic, + builder, 1, fieldCount, isCorrScalarQuery, hasNoWindowingAndNoGby); + final RexShuttle shuttle = new ReplaceSubQueryShuttle(e, target); + builder.filter(shuttle.apply(filter.getCondition())); + builder.project(fields(builder, filter.getRowType().getFieldCount())); + call.transformTo(builder.build()); + } else if(relNode instanceof Project) { + // if subquery is in PROJECT + final Project project = call.rel(0); + final RexSubQuery e = + RexUtil.SubQueryFinder.find(project.getProjects()); + assert e != null; + + final RelOptUtil.Logic logic = + LogicVisitor.find(RelOptUtil.Logic.TRUE_FALSE_UNKNOWN, + project.getProjects(), e); + builder.push(project.getInput()); + final int fieldCount = builder.peek().getRowType().getFieldCount(); + + SubqueryConf subqueryConfig = + project.getCluster().getPlanner().getContext().unwrap(SubqueryConf.class); + boolean isCorrScalarQuery = subqueryConfig.getCorrScalarRexSQWithAgg().contains(e.rel); + boolean hasNoWindowingAndNoGby = + subqueryConfig.getScalarAggWithoutGbyWindowing().contains(e.rel); + + final RexNode target = apply(e, HiveFilter.getVariablesSet(e), + logic, builder, 1, fieldCount, isCorrScalarQuery, hasNoWindowingAndNoGby); + final RexShuttle shuttle = new ReplaceSubQueryShuttle(e, target); + builder.project(shuttle.apply(project.getProjects()), + project.getRowType().getFieldNames()); + call.transformTo(builder.build()); } - public void onMatch(RelOptRuleCall call) { - final RelNode relNode = call.rel(0); - //TODO: replace HiveSubQRemoveRelBuilder with calcite's once calcite 1.11.0 is released - final HiveSubQRemoveRelBuilder builder = new HiveSubQRemoveRelBuilder(null, 
call.rel(0).getCluster(), null); - - // if subquery is in FILTER - if(relNode instanceof Filter) { - final Filter filter = call.rel(0); - final RexSubQuery e = - RexUtil.SubQueryFinder.find(filter.getCondition()); - assert e != null; - - final RelOptUtil.Logic logic = - LogicVisitor.find(RelOptUtil.Logic.TRUE, - ImmutableList.of(filter.getCondition()), e); - builder.push(filter.getInput()); - final int fieldCount = builder.peek().getRowType().getFieldCount(); - - assert(filter instanceof HiveFilter); - SubqueryConf subqueryConfig = filter.getCluster().getPlanner().getContext().unwrap(SubqueryConf.class); - boolean isCorrScalarQuery = subqueryConfig.getCorrScalarRexSQWithAgg().contains(e.rel); - boolean hasNoWindowingAndNoGby = subqueryConfig.getScalarAggWithoutGbyWindowing().contains(e.rel); - - final RexNode target = apply(e, HiveFilter.getVariablesSet(e), logic, - builder, 1, fieldCount, isCorrScalarQuery, hasNoWindowingAndNoGby); - final RexShuttle shuttle = new ReplaceSubQueryShuttle(e, target); - builder.filter(shuttle.apply(filter.getCondition())); - builder.project(fields(builder, filter.getRowType().getFieldCount())); - call.transformTo(builder.build()); - } - // if subquery is in PROJECT - else if(relNode instanceof Project) { - final Project project = call.rel(0); - final RexSubQuery e = - RexUtil.SubQueryFinder.find(project.getProjects()); - assert e != null; - - final RelOptUtil.Logic logic = - LogicVisitor.find(RelOptUtil.Logic.TRUE_FALSE_UNKNOWN, - project.getProjects(), e); - builder.push(project.getInput()); - final int fieldCount = builder.peek().getRowType().getFieldCount(); - - SubqueryConf subqueryConfig = project.getCluster().getPlanner().getContext().unwrap(SubqueryConf.class); - boolean isCorrScalarQuery = subqueryConfig.getCorrScalarRexSQWithAgg().contains(e.rel); - boolean hasNoWindowingAndNoGby = subqueryConfig.getScalarAggWithoutGbyWindowing().contains(e.rel); - - final RexNode target = apply(e, HiveFilter.getVariablesSet(e), - logic, builder, 1, fieldCount, isCorrScalarQuery, hasNoWindowingAndNoGby); - final RexShuttle shuttle = new ReplaceSubQueryShuttle(e, target); - builder.project(shuttle.apply(project.getProjects()), - project.getRowType().getFieldNames()); - call.transformTo(builder.build()); - } + } + + // given a subquery it checks to see what is the aggegate function + /// if COUNT returns true since COUNT produces 0 on empty result set + private boolean isAggZeroOnEmpty(RexSubQuery e) { + //as this is corr scalar subquery with agg we expect one aggregate + assert(e.getKind() == SqlKind.SCALAR_QUERY); + assert(e.rel.getInputs().size() == 1); + Aggregate relAgg = (Aggregate)e.rel.getInput(0); + assert(relAgg.getAggCallList().size() == 1); //should only have one aggregate + if(relAgg.getAggCallList().get(0).getAggregation().getKind() == SqlKind.COUNT) { + return true; } - - /*private HiveSubQueryRemoveRule(RelOptRuleOperand operand, - RelBuilderFactory relBuilderFactory, - String description) { - super(operand, relBuilderFactory, description); - } */ - - // given a subquery it checks to see what is the aggegate function - /// if COUNT returns true since COUNT produces 0 on empty result set - private boolean isAggZeroOnEmpty(RexSubQuery e) { - //as this is corr scalar subquery with agg we expect one aggregate - assert(e.getKind() == SqlKind.SCALAR_QUERY); - assert(e.rel.getInputs().size() == 1); - Aggregate relAgg = (Aggregate)e.rel.getInput(0); - assert( relAgg.getAggCallList().size() == 1); //should only have one aggregate - if( 
relAgg.getAggCallList().get(0).getAggregation().getKind() == SqlKind.COUNT ) { - return true; + return false; + } + + private SqlTypeName getAggTypeForScalarSub(RexSubQuery e) { + assert(e.getKind() == SqlKind.SCALAR_QUERY); + assert(e.rel.getInputs().size() == 1); + Aggregate relAgg = (Aggregate)e.rel.getInput(0); + assert(relAgg.getAggCallList().size() == 1); //should only have one aggregate + return relAgg.getAggCallList().get(0).getType().getSqlTypeName(); + } + + protected RexNode apply(RexSubQuery e, Set variablesSet, + RelOptUtil.Logic logic, + HiveSubQRemoveRelBuilder builder, int inputCount, int offset, + boolean isCorrScalarAgg, + boolean hasNoWindowingAndNoGby) { + switch (e.getKind()) { + case SCALAR_QUERY: + // if scalar query has aggregate and no windowing and no gby avoid adding sq_count_check + // since it is guaranteed to produce at most one row + if(!hasNoWindowingAndNoGby) { + final List parentQueryFields = new ArrayList<>(); + if (conf.getBoolVar(ConfVars.HIVE_REMOVE_SQ_COUNT_CHECK)) { + // we want to have project after join since sq_count_check's count() expression wouldn't + // be needed further up + parentQueryFields.addAll(builder.fields()); } - return false; - } - private SqlTypeName getAggTypeForScalarSub(RexSubQuery e) { - assert(e.getKind() == SqlKind.SCALAR_QUERY); - assert(e.rel.getInputs().size() == 1); - Aggregate relAgg = (Aggregate)e.rel.getInput(0); - assert( relAgg.getAggCallList().size() == 1); //should only have one aggregate - return relAgg.getAggCallList().get(0).getType().getSqlTypeName(); - } - protected RexNode apply(RexSubQuery e, Set variablesSet, - RelOptUtil.Logic logic, - HiveSubQRemoveRelBuilder builder, int inputCount, int offset, - boolean isCorrScalarAgg, - boolean hasNoWindowingAndNoGby ) { - switch (e.getKind()) { - case SCALAR_QUERY: - // if scalar query has aggregate and no windowing and no gby avoid adding sq_count_check - // since it is guaranteed to produce at most one row - if(!hasNoWindowingAndNoGby) { - final List parentQueryFields = new ArrayList<>(); - if (conf.getBoolVar(ConfVars.HIVE_REMOVE_SQ_COUNT_CHECK)) { - // we want to have project after join since sq_count_check's count() expression wouldn't - // be needed further up - parentQueryFields.addAll(builder.fields()); - } - - builder.push(e.rel); - // returns single row/column - builder.aggregate(builder.groupKey(), builder.count(false, "cnt")); - - SqlFunction countCheck = - new SqlFunction("sq_count_check", SqlKind.OTHER_FUNCTION, ReturnTypes.BIGINT, - InferTypes.RETURN_TYPE, OperandTypes.NUMERIC, SqlFunctionCategory.USER_DEFINED_FUNCTION); - - // we create FILTER (sq_count_check(count()) <= 1) instead of PROJECT because RelFieldTrimmer - // ends up getting rid of Project since it is not used further up the tree - builder.filter(builder.call(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, - builder.call(countCheck, builder.field("cnt")), builder.literal(1))); - if (!variablesSet.isEmpty()) { - builder.join(JoinRelType.LEFT, builder.literal(true), variablesSet); - } else - builder.join(JoinRelType.INNER, builder.literal(true), variablesSet); - - if (conf.getBoolVar(ConfVars.HIVE_REMOVE_SQ_COUNT_CHECK)) { - builder.project(parentQueryFields); - } - else { - offset++; - } - - } - if(isCorrScalarAgg) { - // Transformation : - // Outer Query Left Join (inner query) on correlated predicate and preserve rows only from left side. 
- builder.push(e.rel); - final List parentQueryFields = new ArrayList<>(); - parentQueryFields.addAll(builder.fields()); - - // id is appended since there could be multiple scalar subqueries and FILTER - // is created using field name - String indicator = "alwaysTrue" + e.rel.getId(); - parentQueryFields.add(builder.alias(builder.literal(true), indicator)); - builder.project(parentQueryFields); - builder.join(JoinRelType.LEFT, builder.literal(true), variablesSet); - - final ImmutableList.Builder operands = ImmutableList.builder(); - RexNode literal; - if(isAggZeroOnEmpty(e)) { - // since count has a return type of BIG INT we need to make a literal of type big int - // relbuilder's literal doesn't allow this - literal = e.rel.getCluster().getRexBuilder().makeBigintLiteral(new BigDecimal(0)); - } - else { - literal = e.rel.getCluster().getRexBuilder().makeNullLiteral(getAggTypeForScalarSub(e)); - } - operands.add((builder.isNull(builder.field(indicator))), literal); - operands.add(field(builder, 1, builder.fields().size()-2)); - return builder.call(SqlStdOperatorTable.CASE, operands.build()); - } - - //Transformation is to left join for correlated predicates and inner join otherwise, - // but do a count on inner side before that to make sure it generates atmost 1 row. - builder.push(e.rel); - builder.join(JoinRelType.LEFT, builder.literal(true), variablesSet); - return field(builder, inputCount, offset); - - case IN: - case EXISTS: - // Most general case, where the left and right keys might have nulls, and - // caller requires 3-valued logic return. - // - // select e.deptno, e.deptno in (select deptno from emp) - // - // becomes - // - // select e.deptno, - // case - // when ct.c = 0 then false - // when dt.i is not null then true - // when e.deptno is null then null - // when ct.ck < ct.c then null - // else false - // end - // from e - // left join ( - // (select count(*) as c, count(deptno) as ck from emp) as ct - // cross join (select distinct deptno, true as i from emp)) as dt - // on e.deptno = dt.deptno - // - // If keys are not null we can remove "ct" and simplify to - // - // select e.deptno, - // case - // when dt.i is not null then true - // else false - // end - // from e - // left join (select distinct deptno, true as i from emp) as dt - // on e.deptno = dt.deptno - // - // We could further simplify to - // - // select e.deptno, - // dt.i is not null - // from e - // left join (select distinct deptno, true as i from emp) as dt - // on e.deptno = dt.deptno - // - // but have not yet. - // - // If the logic is TRUE we can just kill the record if the condition - // evaluates to FALSE or UNKNOWN. 
Thus the query simplifies to an inner - // join: - // - // select e.deptno, - // true - // from e - // inner join (select distinct deptno from emp) as dt - // on e.deptno = dt.deptno - // - - builder.push(e.rel); - final List fields = new ArrayList<>(); - switch (e.getKind()) { - case IN: - fields.addAll(builder.fields()); - // Transformation: sq_count_check(count(*), true) FILTER is generated on top - // of subquery which is then joined (LEFT or INNER) with outer query - // This transformation is done to add run time check using sq_count_check to - // throw an error if subquery is producing zero row, since with aggregate this - // will produce wrong results (because we further rewrite such queries into JOIN) - if(isCorrScalarAgg) { - // returns single row/column - builder.aggregate(builder.groupKey(), - builder.count(false, "cnt_in")); - - if (!variablesSet.isEmpty()) { - builder.join(JoinRelType.LEFT, builder.literal(true), variablesSet); - } else { - builder.join(JoinRelType.INNER, builder.literal(true), variablesSet); - } - - SqlFunction inCountCheck = new SqlFunction("sq_count_check", SqlKind.OTHER_FUNCTION, ReturnTypes.BIGINT, - InferTypes.RETURN_TYPE, OperandTypes.NUMERIC, SqlFunctionCategory.USER_DEFINED_FUNCTION); - - // we create FILTER (sq_count_check(count()) > 0) instead of PROJECT because RelFieldTrimmer - // ends up getting rid of Project since it is not used further up the tree - builder.filter(builder.call(SqlStdOperatorTable.GREATER_THAN, - //true here indicates that sq_count_check is for IN/NOT IN subqueries - builder.call(inCountCheck, builder.field("cnt_in"), builder.literal(true)), - builder.literal(0))); - offset = offset + 1; - builder.push(e.rel); - } - } - - // First, the cross join - switch (logic) { - case TRUE_FALSE_UNKNOWN: - case UNKNOWN_AS_TRUE: - // Since EXISTS/NOT EXISTS are not affected by presence of - // null keys we do not need to generate count(*), count(c) - if (e.getKind() == SqlKind.EXISTS) { - logic = RelOptUtil.Logic.TRUE_FALSE; - break; - } - builder.aggregate(builder.groupKey(), - builder.count(false, "c"), - builder.aggregateCall(SqlStdOperatorTable.COUNT, false, null, "ck", - builder.fields())); - builder.as("ct"); - if( !variablesSet.isEmpty()) - { - //builder.join(JoinRelType.INNER, builder.literal(true), variablesSet); - builder.join(JoinRelType.LEFT, builder.literal(true), variablesSet); - } - else - builder.join(JoinRelType.INNER, builder.literal(true), variablesSet); - - offset += 2; - builder.push(e.rel); - break; - } - - // Now the left join - switch (logic) { - case TRUE: - if (fields.isEmpty()) { - builder.project(builder.alias(builder.literal(true), "i" + e.rel.getId())); - if(!variablesSet.isEmpty() && (e.getKind() == SqlKind.EXISTS || e.getKind() == SqlKind.IN)) { - // avoid adding group by for correlated IN/EXISTS queries - // since this is rewritting into semijoin - break; - } - else { - builder.aggregate(builder.groupKey(0)); - } - } else { - if(!variablesSet.isEmpty() && (e.getKind() == SqlKind.EXISTS || e.getKind() == SqlKind.IN)) { - // avoid adding group by for correlated IN/EXISTS queries - // since this is rewritting into semijoin - break; - } - else { - builder.aggregate(builder.groupKey(fields)); - } - } - break; - default: - fields.add(builder.alias(builder.literal(true), "i" + e.rel.getId())); - builder.project(fields); - builder.distinct(); - } - builder.as("dt"); - final List conditions = new ArrayList<>(); - for (Pair pair - : Pair.zip(e.getOperands(), builder.fields())) { - conditions.add( - 
builder.equals(pair.left, RexUtil.shift(pair.right, offset))); - } - switch (logic) { - case TRUE: - builder.join(JoinRelType.INNER, builder.and(conditions), variablesSet, true); - return builder.literal(true); - } - builder.join(JoinRelType.LEFT, builder.and(conditions), variablesSet); - - final List keyIsNulls = new ArrayList<>(); - for (RexNode operand : e.getOperands()) { - if (operand.getType().isNullable()) { - keyIsNulls.add(builder.isNull(operand)); - } - } - final ImmutableList.Builder operands = ImmutableList.builder(); - switch (logic) { - case TRUE_FALSE_UNKNOWN: - case UNKNOWN_AS_TRUE: - operands.add( - builder.equals(builder.field("ct", "c"), builder.literal(0)), - builder.literal(false)); - //now that we are using LEFT OUTER JOIN to join inner count, count(*) - // with outer table, we wouldn't be able to tell if count is zero - // for inner table since inner join with correlated values will get rid - // of all values where join cond is not true (i.e where actual inner table - // will produce zero result). To handle this case we need to check both - // count is zero or count is null - operands.add((builder.isNull(builder.field("ct", "c"))), builder.literal(false)); - break; - } - operands.add(builder.isNotNull(builder.field("dt", "i" + e.rel.getId())), - builder.literal(true)); - if (!keyIsNulls.isEmpty()) { - //Calcite creates null literal with Null type here but because HIVE doesn't support null type - // it is appropriately typed boolean - operands.add(builder.or(keyIsNulls), e.rel.getCluster().getRexBuilder().makeNullLiteral(SqlTypeName.BOOLEAN)); - // we are creating filter here so should not be returning NULL. Not sure why Calcite return NULL - //operands.add(builder.or(keyIsNulls), builder.literal(false)); - } - RexNode b = builder.literal(true); - switch (logic) { - case TRUE_FALSE_UNKNOWN: - b = e.rel.getCluster().getRexBuilder().makeNullLiteral(SqlTypeName.BOOLEAN); - // fall through - case UNKNOWN_AS_TRUE: - operands.add( - builder.call(SqlStdOperatorTable.LESS_THAN, - builder.field("ct", "ck"), builder.field("ct", "c")), - b); - break; - } - operands.add(builder.literal(false)); - return builder.call(SqlStdOperatorTable.CASE, operands.build()); - - default: - throw new AssertionError(e.getKind()); + builder.push(e.rel); + // returns single row/column + builder.aggregate(builder.groupKey(), builder.count(false, "cnt")); + + SqlFunction countCheck = + new SqlFunction("sq_count_check", SqlKind.OTHER_FUNCTION, ReturnTypes.BIGINT, + InferTypes.RETURN_TYPE, OperandTypes.NUMERIC, + SqlFunctionCategory.USER_DEFINED_FUNCTION); + + //we create FILTER (sq_count_check(count()) <= 1) instead of PROJECT because RelFieldTrimmer + // ends up getting rid of Project since it is not used further up the tree + builder.filter(builder.call(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, + builder.call(countCheck, builder.field("cnt")), builder.literal(1))); + if (!variablesSet.isEmpty()) { + builder.join(JoinRelType.LEFT, builder.literal(true), variablesSet); + } else { + builder.join(JoinRelType.INNER, builder.literal(true), variablesSet); } - } - /** Returns a reference to a particular field, by offset, across several - * inputs on a {@link RelBuilder}'s stack. 
*/ - private RexInputRef field(HiveSubQRemoveRelBuilder builder, int inputCount, int offset) { - for (int inputOrdinal = 0;;) { - final RelNode r = builder.peek(inputCount, inputOrdinal); - if (offset < r.getRowType().getFieldCount()) { - return builder.field(inputCount, inputOrdinal, offset); - } - ++inputOrdinal; - offset -= r.getRowType().getFieldCount(); + if (conf.getBoolVar(ConfVars.HIVE_REMOVE_SQ_COUNT_CHECK)) { + builder.project(parentQueryFields); + } else { + offset++; } - } - - /** Returns a list of expressions that project the first {@code fieldCount} - * fields of the top input on a {@link RelBuilder}'s stack. */ - private static List fields(HiveSubQRemoveRelBuilder builder, int fieldCount) { - final List projects = new ArrayList<>(); - for (int i = 0; i < fieldCount; i++) { - projects.add(builder.field(i)); + } + if(isCorrScalarAgg) { + // Transformation : + // Outer Query Left Join (inner query) on correlated predicate + // and preserve rows only from left side. + builder.push(e.rel); + final List parentQueryFields = new ArrayList<>(); + parentQueryFields.addAll(builder.fields()); + + // id is appended since there could be multiple scalar subqueries and FILTER + // is created using field name + String indicator = "alwaysTrue" + e.rel.getId(); + parentQueryFields.add(builder.alias(builder.literal(true), indicator)); + builder.project(parentQueryFields); + builder.join(JoinRelType.LEFT, builder.literal(true), variablesSet); + + final ImmutableList.Builder operands = ImmutableList.builder(); + RexNode literal; + if(isAggZeroOnEmpty(e)) { + // since count has a return type of BIG INT we need to make a literal of type big int + // relbuilder's literal doesn't allow this + literal = e.rel.getCluster().getRexBuilder().makeBigintLiteral(new BigDecimal(0)); + } else { + literal = e.rel.getCluster().getRexBuilder().makeNullLiteral(getAggTypeForScalarSub(e)); } - return projects; - } - - /** Shuttle that replaces occurrences of a given - * {@link org.apache.calcite.rex.RexSubQuery} with a replacement - * expression. */ - private static class ReplaceSubQueryShuttle extends RexShuttle { - private final RexSubQuery subQuery; - private final RexNode replacement; - - public ReplaceSubQueryShuttle(RexSubQuery subQuery, RexNode replacement) { - this.subQuery = subQuery; - this.replacement = replacement; + operands.add((builder.isNull(builder.field(indicator))), literal); + operands.add(field(builder, 1, builder.fields().size()-2)); + return builder.call(SqlStdOperatorTable.CASE, operands.build()); + } + + //Transformation is to left join for correlated predicates and inner join otherwise, + // but do a count on inner side before that to make sure it generates atmost 1 row. + builder.push(e.rel); + builder.join(JoinRelType.LEFT, builder.literal(true), variablesSet); + return field(builder, inputCount, offset); + + case IN: + case EXISTS: + // Most general case, where the left and right keys might have nulls, and + // caller requires 3-valued logic return. 
+ // + // select e.deptno, e.deptno in (select deptno from emp) + // + // becomes + // + // select e.deptno, + // case + // when ct.c = 0 then false + // when dt.i is not null then true + // when e.deptno is null then null + // when ct.ck < ct.c then null + // else false + // end + // from e + // left join ( + // (select count(*) as c, count(deptno) as ck from emp) as ct + // cross join (select distinct deptno, true as i from emp)) as dt + // on e.deptno = dt.deptno + // + // If keys are not null we can remove "ct" and simplify to + // + // select e.deptno, + // case + // when dt.i is not null then true + // else false + // end + // from e + // left join (select distinct deptno, true as i from emp) as dt + // on e.deptno = dt.deptno + // + // We could further simplify to + // + // select e.deptno, + // dt.i is not null + // from e + // left join (select distinct deptno, true as i from emp) as dt + // on e.deptno = dt.deptno + // + // but have not yet. + // + // If the logic is TRUE we can just kill the record if the condition + // evaluates to FALSE or UNKNOWN. Thus the query simplifies to an inner + // join: + // + // select e.deptno, + // true + // from e + // inner join (select distinct deptno from emp) as dt + // on e.deptno = dt.deptno + // + + builder.push(e.rel); + final List fields = new ArrayList<>(); + switch (e.getKind()) { + case IN: + fields.addAll(builder.fields()); + // Transformation: sq_count_check(count(*), true) FILTER is generated on top + // of subquery which is then joined (LEFT or INNER) with outer query + // This transformation is done to add run time check using sq_count_check to + // throw an error if subquery is producing zero row, since with aggregate this + // will produce wrong results (because we further rewrite such queries into JOIN) + if(isCorrScalarAgg) { + // returns single row/column + builder.aggregate(builder.groupKey(), + builder.count(false, "cnt_in")); + + if (!variablesSet.isEmpty()) { + builder.join(JoinRelType.LEFT, builder.literal(true), variablesSet); + } else { + builder.join(JoinRelType.INNER, builder.literal(true), variablesSet); + } + + SqlFunction inCountCheck = new SqlFunction("sq_count_check", SqlKind.OTHER_FUNCTION, + ReturnTypes.BIGINT, InferTypes.RETURN_TYPE, OperandTypes.NUMERIC, + SqlFunctionCategory.USER_DEFINED_FUNCTION); + + // we create FILTER (sq_count_check(count()) > 0) instead of PROJECT + // because RelFieldTrimmer ends up getting rid of Project + // since it is not used further up the tree + builder.filter(builder.call(SqlStdOperatorTable.GREATER_THAN, + //true here indicates that sq_count_check is for IN/NOT IN subqueries + builder.call(inCountCheck, builder.field("cnt_in"), builder.literal(true)), + builder.literal(0))); + offset = offset + 1; + builder.push(e.rel); } - - @Override public RexNode visitSubQuery(RexSubQuery subQuery) { - return RexUtil.eq(subQuery, this.subQuery) ? replacement : subQuery; + } + + // First, the cross join + switch (logic) { + case TRUE_FALSE_UNKNOWN: + case UNKNOWN_AS_TRUE: + // Since EXISTS/NOT EXISTS are not affected by presence of + // null keys we do not need to generate count(*), count(c) + if (e.getKind() == SqlKind.EXISTS) { + logic = RelOptUtil.Logic.TRUE_FALSE; + break; } - } - - // TODO: - // Following HiveSubQueryFinder has been copied from RexUtil::SubQueryFinder - // since there is BUG in there (CALCITE-1726). 
- // Once CALCITE-1726 is fixed we should get rid of the following code - /** Visitor that throws {@link org.apache.calcite.util.Util.FoundOne} if - * applied to an expression that contains a {@link RexSubQuery}. */ - public static class HiveSubQueryFinder extends RexVisitorImpl { - public static final HiveSubQueryFinder INSTANCE = new HiveSubQueryFinder(); - - /** Returns whether a {@link Project} contains a sub-query. */ - public static final Predicate RELNODE_PREDICATE= - new Predicate() { - public boolean apply(RelNode relNode) { - if (relNode instanceof Project) { - Project project = (Project)relNode; - for (RexNode node : project.getProjects()) { - try { - node.accept(INSTANCE); - } catch (Util.FoundOne e) { - return true; - } - } - return false; - } - else if (relNode instanceof Filter) { - try { - ((Filter)relNode).getCondition().accept(INSTANCE); - return false; - } catch (Util.FoundOne e) { - return true; - } - } - return false; - } - }; - - private HiveSubQueryFinder() { - super(true); + builder.aggregate(builder.groupKey(), + builder.count(false, "c"), + builder.aggregateCall(SqlStdOperatorTable.COUNT, false, null, "ck", + builder.fields())); + builder.as("ct"); + if(!variablesSet.isEmpty()) { + //builder.join(JoinRelType.INNER, builder.literal(true), variablesSet); + builder.join(JoinRelType.LEFT, builder.literal(true), variablesSet); + } else { + builder.join(JoinRelType.INNER, builder.literal(true), variablesSet); } - @Override public Void visitSubQuery(RexSubQuery subQuery) { - throw new Util.FoundOne(subQuery); + offset += 2; + builder.push(e.rel); + break; + } + + // Now the left join + switch (logic) { + case TRUE: + if (fields.isEmpty()) { + builder.project(builder.alias(builder.literal(true), "i" + e.rel.getId())); + if(!variablesSet.isEmpty() + && (e.getKind() == SqlKind.EXISTS || e.getKind() == SqlKind.IN)) { + // avoid adding group by for correlated IN/EXISTS queries + // since this is rewritting into semijoin + break; + } else { + builder.aggregate(builder.groupKey(0)); + } + } else { + if(!variablesSet.isEmpty() + && (e.getKind() == SqlKind.EXISTS || e.getKind() == SqlKind.IN)) { + // avoid adding group by for correlated IN/EXISTS queries + // since this is rewritting into semijoin + break; + } else { + builder.aggregate(builder.groupKey(fields)); + } } + break; + default: + fields.add(builder.alias(builder.literal(true), "i" + e.rel.getId())); + builder.project(fields); + builder.distinct(); + } + builder.as("dt"); + final List conditions = new ArrayList<>(); + for (Pair pair + : Pair.zip(e.getOperands(), builder.fields())) { + conditions.add( + builder.equals(pair.left, RexUtil.shift(pair.right, offset))); + } + switch (logic) { + case TRUE: + builder.join(JoinRelType.INNER, builder.and(conditions), variablesSet, true); + return builder.literal(true); + } + builder.join(JoinRelType.LEFT, builder.and(conditions), variablesSet); + + final List keyIsNulls = new ArrayList<>(); + for (RexNode operand : e.getOperands()) { + if (operand.getType().isNullable()) { + keyIsNulls.add(builder.isNull(operand)); + } + } + final ImmutableList.Builder operands = ImmutableList.builder(); + switch (logic) { + case TRUE_FALSE_UNKNOWN: + case UNKNOWN_AS_TRUE: + operands.add( + builder.equals(builder.field("ct", "c"), builder.literal(0)), + builder.literal(false)); + //now that we are using LEFT OUTER JOIN to join inner count, count(*) + // with outer table, we wouldn't be able to tell if count is zero + // for inner table since inner join with correlated values will get rid + // of 
all values where join cond is not true (i.e where actual inner table + // will produce zero result). To handle this case we need to check both + // count is zero or count is null + operands.add((builder.isNull(builder.field("ct", "c"))), builder.literal(false)); + break; + } + operands.add(builder.isNotNull(builder.field("dt", "i" + e.rel.getId())), + builder.literal(true)); + if (!keyIsNulls.isEmpty()) { + //Calcite creates null literal with Null type here but + // because HIVE doesn't support null type it is appropriately typed boolean + operands.add(builder.or(keyIsNulls), + e.rel.getCluster().getRexBuilder().makeNullLiteral(SqlTypeName.BOOLEAN)); + // we are creating filter here so should not be returning NULL. + // Not sure why Calcite return NULL + } + RexNode b = builder.literal(true); + switch (logic) { + case TRUE_FALSE_UNKNOWN: + b = e.rel.getCluster().getRexBuilder().makeNullLiteral(SqlTypeName.BOOLEAN); + // fall through + case UNKNOWN_AS_TRUE: + operands.add( + builder.call(SqlStdOperatorTable.LESS_THAN, + builder.field("ct", "ck"), builder.field("ct", "c")), + b); + break; + } + operands.add(builder.literal(false)); + return builder.call(SqlStdOperatorTable.CASE, operands.build()); + + default: + throw new AssertionError(e.getKind()); + } + } + + /** Returns a reference to a particular field, by offset, across several + * inputs on a {@link RelBuilder}'s stack. */ + private RexInputRef field(HiveSubQRemoveRelBuilder builder, int inputCount, int offset) { + for (int inputOrdinal = 0;;) { + final RelNode r = builder.peek(inputCount, inputOrdinal); + if (offset < r.getRowType().getFieldCount()) { + return builder.field(inputCount, inputOrdinal, offset); + } + ++inputOrdinal; + offset -= r.getRowType().getFieldCount(); + } + } + + /** Returns a list of expressions that project the first {@code fieldCount} + * fields of the top input on a {@link RelBuilder}'s stack. */ + private static List fields(HiveSubQRemoveRelBuilder builder, int fieldCount) { + final List projects = new ArrayList<>(); + for (int i = 0; i < fieldCount; i++) { + projects.add(builder.field(i)); + } + return projects; + } + + /** Shuttle that replaces occurrences of a given + * {@link org.apache.calcite.rex.RexSubQuery} with a replacement + * expression. */ + private static class ReplaceSubQueryShuttle extends RexShuttle { + private final RexSubQuery subQuery; + private final RexNode replacement; + + ReplaceSubQueryShuttle(RexSubQuery subQuery, RexNode replacement) { + this.subQuery = subQuery; + this.replacement = replacement; + } - public static RexSubQuery find(Iterable nodes) { - for (RexNode node : nodes) { + @Override public RexNode visitSubQuery(RexSubQuery subQuery) { + return RexUtil.eq(subQuery, this.subQuery) ? replacement : subQuery; + } + } + + // TODO: + // Following HiveSubQueryFinder has been copied from RexUtil::SubQueryFinder + // since there is BUG in there (CALCITE-1726). + // Once CALCITE-1726 is fixed we should get rid of the following code + /** Visitor that throws {@link org.apache.calcite.util.Util.FoundOne} if + * applied to an expression that contains a {@link RexSubQuery}. */ + public static final class HiveSubQueryFinder extends RexVisitorImpl { + public static final HiveSubQueryFinder INSTANCE = new HiveSubQueryFinder(); + + /** Returns whether a {@link Project} contains a sub-query. 
*/
+    public static final Predicate<RelNode> RELNODE_PREDICATE =
+        new Predicate<RelNode>() {
+          public boolean apply(RelNode relNode) {
+            if (relNode instanceof Project) {
+              Project project = (Project)relNode;
+              for (RexNode node : project.getProjects()) {
                try {
-                node.accept(INSTANCE);
+                  node.accept(INSTANCE);
                } catch (Util.FoundOne e) {
-                return (RexSubQuery) e.getNode();
+                  return true;
                }
+              }
+              return false;
+            } else if (relNode instanceof Filter) {
+              try {
+                ((Filter)relNode).getCondition().accept(INSTANCE);
+                return false;
+              } catch (Util.FoundOne e) {
+                return true;
+              }
            }
-            return null;
-        }
+            return false;
+          }
+        };

-        public static RexSubQuery find(RexNode node) {
-            try {
-                node.accept(INSTANCE);
-                return null;
-            } catch (Util.FoundOne e) {
-                return (RexSubQuery) e.getNode();
-            }
+    private HiveSubQueryFinder() {
+      super(true);
+    }
+
+    @Override public Void visitSubQuery(RexSubQuery subQuery) {
+      throw new Util.FoundOne(subQuery);
+    }
+
+    public static RexSubQuery find(Iterable<RexNode> nodes) {
+      for (RexNode node : nodes) {
+        try {
+          node.accept(INSTANCE);
+        } catch (Util.FoundOne e) {
+          return (RexSubQuery) e.getNode();
        }
+      }
+      return null;
+    }
+
+    public static RexSubQuery find(RexNode node) {
+      try {
+        node.accept(INSTANCE);
+        return null;
+      } catch (Util.FoundOne e) {
+        return (RexSubQuery) e.getNode();
+      }
    }
+  }
 }
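
For reviewers who want to exercise the rewritten rule on its own, the sketch below shows one common way a rule such as HiveSubQueryRemoveRule is driven through Calcite's HepPlanner. It is illustrative only: the real wiring lives inside Hive's CalcitePlanner, and the class name, the conf instance, and the input calcitePlan are assumptions made for this example, not part of the patch.

import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.plan.hep.HepProgramBuilder;
import org.apache.calcite.rel.RelNode;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveSubQueryRemoveRule;

/** Illustrative sketch only; class name and inputs are assumptions. */
public class SubQueryRemoveSketch {
  /**
   * Runs HiveSubQueryRemoveRule over a logical plan that still contains
   * RexSubQuery expressions. The result contains Correlate nodes, which the
   * decorrelator changes earlier in this patch are then expected to remove.
   */
  static RelNode removeSubQueries(HiveConf conf, RelNode calcitePlan) {
    HepProgram program = new HepProgramBuilder()
        .addRuleInstance(new HiveSubQueryRemoveRule(conf))
        .build();
    HepPlanner planner = new HepPlanner(program);
    planner.setRoot(calcitePlan);
    return planner.findBestExp();
  }
}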