diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveExpandDistinctAggregatesRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveExpandDistinctAggregatesRule.java index 103d5e157e..9013a8f1b7 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveExpandDistinctAggregatesRule.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveExpandDistinctAggregatesRule.java @@ -16,6 +16,8 @@ */ package org.apache.hadoop.hive.ql.optimizer.calcite.rules; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableSet; import java.math.BigDecimal; import java.util.ArrayList; import java.util.Collections; @@ -26,11 +28,13 @@ import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.Aggregate.Group; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.RelFactories; import org.apache.calcite.rel.metadata.RelColumnOrigin; @@ -112,7 +116,7 @@ public HiveExpandDistinctAggregatesRule( public void onMatch(RelOptRuleCall call) { final Aggregate aggregate = call.rel(0); int numCountDistinct = getNumCountDistinctCall(aggregate); - if (numCountDistinct == 0) { + if (numCountDistinct == 0 || aggregate.getGroupType() != Group.SIMPLE) { return; } @@ -121,7 +125,8 @@ public void onMatch(RelOptRuleCall call) { int nonDistinctCount = 0; List> argListList = new ArrayList>(); Set> argListSets = new LinkedHashSet>(); - Set positions = new HashSet<>(); + ImmutableBitSet.Builder newGroupSet = ImmutableBitSet.builder(); + newGroupSet.addAll(aggregate.getGroupSet()); for (AggregateCall aggCall : aggregate.getAggCallList()) { if (!aggCall.isDistinct()) { ++nonDistinctCount; @@ -130,33 +135,27 @@ public void onMatch(RelOptRuleCall call) { ArrayList argList = new ArrayList(); for (Integer arg : aggCall.getArgList()) { argList.add(arg); - positions.add(arg); + newGroupSet.set(arg); } // Aggr checks for sorted argList. argListList.add(argList); argListSets.add(argList); } - Util.permAssert(argListSets.size() > 0, "containsDistinctCall lied"); + Preconditions.checkArgument(argListSets.size() > 0, "containsDistinctCall lied"); - if (numCountDistinct > 1 && numCountDistinct == aggregate.getAggCallList().size() - && aggregate.getGroupSet().isEmpty()) { + if (numCountDistinct > 1 && numCountDistinct == aggregate.getAggCallList().size()) { LOG.debug("Trigger countDistinct rewrite. numCountDistinct is " + numCountDistinct); // now positions contains all the distinct positions, i.e., $5, $4, $6 // we need to first sort them as group by set // and then get their position later, i.e., $4->1, $5->2, $6->3 cluster = aggregate.getCluster(); rexBuilder = cluster.getRexBuilder(); - RelNode converted = null; - List sourceOfForCountDistinct = new ArrayList<>(); - sourceOfForCountDistinct.addAll(positions); - Collections.sort(sourceOfForCountDistinct); try { - converted = convert(aggregate, argListList, sourceOfForCountDistinct); + call.transformTo(convert(aggregate, argListList, newGroupSet.build())); } catch (CalciteSemanticException e) { LOG.debug(e.toString()); throw new RuntimeException(e); } - call.transformTo(converted); return; } @@ -200,19 +199,20 @@ public void onMatch(RelOptRuleCall call) { * (department_id, gender, education_level))subq; * @throws CalciteSemanticException */ - private RelNode convert(Aggregate aggregate, List> argList, List sourceOfForCountDistinct) throws CalciteSemanticException { + private RelNode convert(Aggregate aggregate, List> argList, ImmutableBitSet newGroupSet) + throws CalciteSemanticException { // we use this map to map the position of argList to the position of grouping set Map map = new HashMap<>(); List> cleanArgList = new ArrayList<>(); - final Aggregate groupingSets = createGroupingSets(aggregate, argList, cleanArgList, map, sourceOfForCountDistinct); - return createCount(groupingSets, argList, cleanArgList, map, sourceOfForCountDistinct); + final Aggregate groupingSets = createGroupingSets(aggregate, argList, cleanArgList, map, newGroupSet); + return createCount(groupingSets, argList, cleanArgList, map, aggregate.getGroupSet(), newGroupSet); } - private int getGroupingIdValue(List list, List sourceOfForCountDistinct, + private int getGroupingIdValue(List list, ImmutableBitSet newGroupSet, int groupCount) { int ind = IntMath.pow(2, groupCount) - 1; for (int i : list) { - ind &= ~(1 << groupCount - sourceOfForCountDistinct.indexOf(i) - 1); + ind &= ~(1 << groupCount - newGroupSet.indexOf(i) - 1); } return ind; } @@ -222,28 +222,28 @@ private int getGroupingIdValue(List list, List sourceOfForCoun * @param argList: the original argList in aggregate * @param cleanArgList: the new argList without duplicates * @param map: the mapping from the original argList to the new argList - * @param sourceOfForCountDistinct: the sorted positions of groupset + * @param newGroupSet: the sorted positions of groupset * @return * @throws CalciteSemanticException */ private RelNode createCount(Aggregate aggr, List> argList, List> cleanArgList, Map map, - List sourceOfForCountDistinct) throws CalciteSemanticException { - List originalInputRefs = Lists.transform(aggr.getRowType().getFieldList(), - new Function() { - @Override - public RexNode apply(RelDataTypeField input) { - return new RexInputRef(input.getIndex(), input.getType()); - } - }); + ImmutableBitSet originalGroupSet, ImmutableBitSet newGroupSet) throws CalciteSemanticException { + final List originalInputRefs = aggr.getRowType().getFieldList() + .stream() + .map(input -> new RexInputRef(input.getIndex(), input.getType())) + .collect(Collectors.toList()); final List gbChildProjLst = Lists.newArrayList(); // for singular arg, count should not include null // e.g., count(case when i=1 and department_id is not null then 1 else null end) as c0, // for non-singular args, count can include null, i.e. (,) is counted as 1 for (List list : cleanArgList) { - RexNode condition = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, originalInputRefs - .get(originalInputRefs.size() - 1), rexBuilder.makeExactLiteral(new BigDecimal( - getGroupingIdValue(list, sourceOfForCountDistinct, aggr.getGroupCount())))); + RexNode condition = rexBuilder.makeCall( + SqlStdOperatorTable.EQUALS, + originalInputRefs.get(originalInputRefs.size() - 1), + rexBuilder.makeExactLiteral( + new BigDecimal( + getGroupingIdValue(list, newGroupSet, aggr.getGroupCount())))); if (list.size() == 1) { int pos = list.get(0); RexNode notNull = rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, @@ -257,6 +257,10 @@ public RexNode apply(RelDataTypeField input) { gbChildProjLst.add(when); } + for (int pos : originalGroupSet) { + gbChildProjLst.add(originalInputRefs.get(newGroupSet.indexOf(pos))); + } + // create the project before GB RelNode gbInputRel = HiveProject.create(aggr, gbChildProjLst, null); @@ -269,23 +273,25 @@ public RexNode apply(RelDataTypeField input) { TypeInfoFactory.longTypeInfo, i, aggFnRetType); aggregateCalls.add(aggregateCall); } + ImmutableBitSet groupSet = + ImmutableBitSet.range(cleanArgList.size(), cleanArgList.size() + originalGroupSet.cardinality()); Aggregate aggregate = new HiveAggregate(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION), gbInputRel, - ImmutableBitSet.of(), null, aggregateCalls); + groupSet, null, aggregateCalls); // create the project after GB. For those repeated values, e.g., select // count(distinct x, y), count(distinct y, x), we find the correct mapping. if (map.isEmpty()) { return aggregate; } else { - List originalAggrRefs = Lists.transform(aggregate.getRowType().getFieldList(), - new Function() { - @Override - public RexNode apply(RelDataTypeField input) { - return new RexInputRef(input.getIndex(), input.getType()); - } - }); + final List originalAggrRefs = aggregate.getRowType().getFieldList() + .stream() + .map(input -> new RexInputRef(input.getIndex(), input.getType())) + .collect(Collectors.toList()); final List projLst = Lists.newArrayList(); int index = 0; + for (int i = 0; i < groupSet.cardinality(); i++) { + projLst.add(originalAggrRefs.get(index++)); + } for (int i = 0; i < argList.size(); i++) { if (map.containsKey(i)) { projLst.add(originalAggrRefs.get(map.get(i))); @@ -302,18 +308,18 @@ public RexNode apply(RelDataTypeField input) { * @param argList: the original argList in aggregate * @param cleanArgList: the new argList without duplicates * @param map: the mapping from the original argList to the new argList - * @param sourceOfForCountDistinct: the sorted positions of groupset + * @param groupSet: new group set * @return */ private Aggregate createGroupingSets(Aggregate aggregate, List> argList, List> cleanArgList, Map map, - List sourceOfForCountDistinct) { - final ImmutableBitSet groupSet = ImmutableBitSet.of(sourceOfForCountDistinct); + ImmutableBitSet groupSet) { final List origGroupSets = new ArrayList<>(); for (int i = 0; i < argList.size(); i++) { List list = argList.get(i); - ImmutableBitSet bitSet = ImmutableBitSet.of(list); + ImmutableBitSet bitSet = aggregate.getGroupSet().union( + ImmutableBitSet.of(list)); int prev = origGroupSets.indexOf(bitSet); if (prev == -1) { origGroupSets.add(bitSet); @@ -323,7 +329,7 @@ private Aggregate createGroupingSets(Aggregate aggregate, List> ar } } // Calcite expects the grouping sets sorted and without duplicates - Collections.sort(origGroupSets, ImmutableBitSet.COMPARATOR); + origGroupSets.sort(ImmutableBitSet.COMPARATOR); List aggregateCalls = new ArrayList(); // Create GroupingID column diff --git a/ql/src/test/queries/clientpositive/multigroupbydistinct.q b/ql/src/test/queries/clientpositive/multigroupbydistinct.q new file mode 100644 index 0000000000..045c4378bd --- /dev/null +++ b/ql/src/test/queries/clientpositive/multigroupbydistinct.q @@ -0,0 +1,7 @@ +create table mytable1 (x integer, y integer, z integer, a integer); + +explain cbo +select z, x, count(distinct y), count(distinct a) +from mytable1 +group by z, x; + diff --git a/ql/src/test/results/clientpositive/llap/multigroupbydistinct.q.out b/ql/src/test/results/clientpositive/llap/multigroupbydistinct.q.out new file mode 100644 index 0000000000..2f9fec9ca7 --- /dev/null +++ b/ql/src/test/results/clientpositive/llap/multigroupbydistinct.q.out @@ -0,0 +1,29 @@ +PREHOOK: query: create table mytable1 (x integer, y integer, z integer, a integer) +PREHOOK: type: CREATETABLE +PREHOOK: Output: database:default +PREHOOK: Output: default@mytable1 +POSTHOOK: query: create table mytable1 (x integer, y integer, z integer, a integer) +POSTHOOK: type: CREATETABLE +POSTHOOK: Output: database:default +POSTHOOK: Output: default@mytable1 +PREHOOK: query: explain cbo +select z, x, count(distinct y), count(distinct a) +from mytable1 +group by z, x +PREHOOK: type: QUERY +PREHOOK: Input: default@mytable1 +#### A masked pattern was here #### +POSTHOOK: query: explain cbo +select z, x, count(distinct y), count(distinct a) +from mytable1 +group by z, x +POSTHOOK: type: QUERY +POSTHOOK: Input: default@mytable1 +#### A masked pattern was here #### +CBO PLAN: +HiveAggregate(group=[{2, 3}], agg#0=[count($0)], agg#1=[count($1)]) + HiveProject($f0=[CASE(AND(=($4, 13), IS NOT NULL($2)), 1, null:INTEGER)], $f1=[CASE(AND(=($4, 14), IS NOT NULL($3)), 1, null:INTEGER)], $f2=[$0], $f3=[$1]) + HiveAggregate(group=[{0, 1, 2, 3}], groups=[[{0, 1, 2}, {0, 1, 3}]], GROUPING__ID=[GROUPING__ID()]) + HiveProject($f0=[$2], $f1=[$0], $f2=[$1], $f3=[$3]) + HiveTableScan(table=[[default, mytable1]], table:alias=[mytable1]) +