diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveExpandDistinctAggregatesRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveExpandDistinctAggregatesRule.java
index 73c7cac..9fa71fd 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveExpandDistinctAggregatesRule.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveExpandDistinctAggregatesRule.java
@@ -16,18 +16,25 @@
  */
 package org.apache.hadoop.hive.ql.optimizer.calcite.rules;
 
+import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptUtil;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.Aggregate;
 import org.apache.calcite.rel.core.AggregateCall;
 import org.apache.calcite.rel.core.RelFactories;
+import org.apache.calcite.rel.core.RelFactories.AggregateFactory;
 import org.apache.calcite.rel.metadata.RelColumnOrigin;
 import org.apache.calcite.rel.metadata.RelMetadataQuery;
 import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.ImmutableIntList;
 import org.apache.calcite.util.Pair;
 import org.apache.calcite.util.Util;
 import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable;
@@ -37,12 +44,15 @@
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
 
+import java.math.BigDecimal;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.TreeSet;
 
 /**
  * Planner rule that expands distinct aggregates
@@ -70,16 +80,20 @@
   /** The default instance of the rule; operates only on logical expressions. */
   public static final HiveExpandDistinctAggregatesRule INSTANCE =
       new HiveExpandDistinctAggregatesRule(HiveAggregate.class,
-          HiveProject.DEFAULT_PROJECT_FACTORY);
+          HiveProject.DEFAULT_PROJECT_FACTORY, HiveAggregate.HIVE_AGGR_REL_FACTORY);
 
   private static RelFactories.ProjectFactory projFactory;
+  private static RelFactories.AggregateFactory aggregateFactory;
+  public static final BigDecimal TWO = BigDecimal.valueOf(2L);
 
   //~ Constructors -----------------------------------------------------------
 
   public HiveExpandDistinctAggregatesRule(
-      Class<? extends Aggregate> clazz,RelFactories.ProjectFactory projectFactory) {
+      Class<? extends Aggregate> clazz,RelFactories.ProjectFactory projectFactory,
+      AggregateFactory aggFactory) {
     super(operand(clazz, any()));
     projFactory = projectFactory;
+    aggregateFactory = aggFactory;
   }
 
   //~ Methods ----------------------------------------------------------------
@@ -88,9 +102,16 @@ public HiveExpandDistinctAggregatesRule(
   public void onMatch(RelOptRuleCall call) {
     final Aggregate aggregate = call.rel(0);
     if (!aggregate.containsDistinctCall()) {
+      // If there are no distinct aggregates, there is nothing to optimize.
       return;
     }
 
+    if (aggregate.getGroupCount() > 0) {
+      // We lengthen the pipeline to avoid ending up with a single reducer
+      // processing lots of data. That is a risk only when the query has no
+      // group-by keys, so bail out if it has any.
+      return;
+    }
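+    // Past this point the aggregate has at least one DISTINCT call and no
+    // group-by keys.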
 
     // Find all of the agg expressions. We use a LinkedHashSet to ensure
     // determinism.
     int nonDistinctCount = 0;
@@ -108,28 +129,146 @@ public void onMatch(RelOptRuleCall call) {
     }
     Util.permAssert(argListSets.size() > 0, "containsDistinctCall lied");
 
+    if (allPartCols(argListSets, aggregate)) {
+      // All columns are partitioning columns; this will be better handled by the MetadataOnly optimizer.
+      return;
+    }
+
     // If all of the agg expressions are distinct and have the same
     // arguments then we can use a more efficient form.
     if ((nonDistinctCount == 0) && (argListSets.size() == 1)) {
-      for (Integer arg : argListSets.iterator().next()) {
+      RelNode converted = convertMonopole(aggregate, argListSets.iterator().next());
+      call.transformTo(converted);
+      return;
+    } else {
+      rewriteUsingGroupingSets(call, aggregate, argListSets);
+      return;
+    }
+  }
+
+  private boolean allPartCols(Set<List<Integer>> argListSets, Aggregate aggregate) {
+    for (List<Integer> argList : argListSets) {
+      for (Integer arg : argList) {
         Set<RelColumnOrigin> colOrigs = RelMetadataQuery.getColumnOrigins(aggregate, arg);
         if (null != colOrigs) {
           for (RelColumnOrigin colOrig : colOrigs) {
             RelOptHiveTable hiveTbl = (RelOptHiveTable)colOrig.getOriginTable();
-            if(hiveTbl.getPartColInfoMap().containsKey(colOrig.getOriginColumnOrdinal())) {
-              // Encountered partitioning column, this will be better handled by MetadataOnly optimizer.
-              return;
+            if(!hiveTbl.getPartColInfoMap().containsKey(colOrig.getOriginColumnOrdinal())) {
+              return false;
             }
           }
         }
       }
-      RelNode converted =
-          convertMonopole(
-              aggregate,
-              argListSets.iterator().next());
-      call.transformTo(converted);
-      return;
     }
+    return true;
+  }
+
+  private void rewriteUsingGroupingSets(RelOptRuleCall ruleCall,
+      Aggregate aggregate, Set<List<Integer>> argLists) {
+    final Set<ImmutableBitSet> groupSetTreeSet =
+        new TreeSet<>(ImmutableBitSet.ORDERING);
+    groupSetTreeSet.add(aggregate.getGroupSet());
+    for (List<Integer> argList : argLists) {
+      groupSetTreeSet.add(ImmutableBitSet.of(argList).union(aggregate.getGroupSet()));
+    }
+
+    final ImmutableList<ImmutableBitSet> groupSets =
+        ImmutableList.copyOf(groupSetTreeSet);
+    final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);
+
+    final List<AggregateCall> distinctAggCalls = new ArrayList<>();
+    for (Pair<AggregateCall, String> call
+        : Pair.zip(aggregate.getAggCallList(),
+            Util.skip(aggregate.getRowType().getFieldNames(),
+                aggregate.getGroupCount() + aggregate.getIndicatorCount()))) {
+      if (!call.left.isDistinct()) {
+        distinctAggCalls.add(call.left.rename(call.right));
+      }
+    }
+
+    final RelNode distinct =
+        aggregateFactory.createAggregate(aggregate.getInput(),
+            groupSets.size() > 1, fullGroupSet, groupSets, distinctAggCalls);
+    final int groupCount = fullGroupSet.cardinality();
+    final int indicatorCount = groupSets.size() > 1 ?
+        groupCount : 0;
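+    // At this point `distinct` is the bottom aggregate: it groups the input
+    // by fullGroupSet (one grouping set per DISTINCT argument list), so each
+    // distinct value reaches the top-level aggregate built below exactly
+    // once; the non-DISTINCT calls were computed here and are reduced with
+    // MIN below.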
+
+    final RelOptCluster cluster = aggregate.getCluster();
+    final RexBuilder rexBuilder = cluster.getRexBuilder();
+    final List<Pair<RexNode, String>> predicates = new ArrayList<>();
+    RelNode r = distinct;
+    if (!predicates.isEmpty()) {
+      List<Pair<RexNode, String>> nodes = new ArrayList<>();
+      for (RelDataTypeField f : r.getRowType().getFieldList()) {
+        final RexNode node = rexBuilder.makeInputRef(f.getType(), f.getIndex());
+        nodes.add(Pair.of(node, f.getName()));
+      }
+      nodes.addAll(predicates);
+      r = RelOptUtil.createProject(r, nodes, false);
+    }
+
+    int x = groupCount + indicatorCount;
+    final List<AggregateCall> newCalls = new ArrayList<>();
+    for (AggregateCall call : aggregate.getAggCallList()) {
+      final List<Integer> newArgList;
+      final SqlAggFunction aggregation;
+      if (!call.isDistinct()) {
+        aggregation = SqlStdOperatorTable.MIN;
+        newArgList = ImmutableIntList.of(x++);
+      } else {
+        aggregation = call.getAggregation();
+        newArgList = remap(fullGroupSet, call.getArgList());
+      }
+      final AggregateCall newCall =
+          AggregateCall.create(aggregation, false, newArgList,
+              aggregate.getGroupCount(), distinct, null, call.name);
+      newCalls.add(newCall);
+    }
+
+    final RelNode newAggregate =
+        aggregateFactory.createAggregate(r, aggregate.indicator,
+            remap(fullGroupSet, aggregate.getGroupSet()),
+            remap(fullGroupSet, aggregate.getGroupSets()), newCalls);
+    ruleCall.transformTo(
+        RelOptUtil.createCastRel(newAggregate, aggregate.getRowType(), true,
+            projFactory));
+  }
+
+  private static ImmutableBitSet remap(ImmutableBitSet groupSet,
+      ImmutableBitSet bitSet) {
+    final ImmutableBitSet.Builder builder = ImmutableBitSet.builder();
+    for (Integer bit : bitSet) {
+      builder.set(remap(groupSet, bit));
+    }
+    return builder.build();
+  }
+
+  private static ImmutableList<ImmutableBitSet> remap(ImmutableBitSet groupSet,
+      Iterable<ImmutableBitSet> bitSets) {
+    final ImmutableList.Builder<ImmutableBitSet> builder =
+        ImmutableList.builder();
+    for (ImmutableBitSet bitSet : bitSets) {
+      builder.add(remap(groupSet, bitSet));
+    }
+    return builder.build();
+  }
+
+  private static List<Integer> remap(ImmutableBitSet groupSet,
+      List<Integer> argList) {
+    ImmutableIntList list = ImmutableIntList.of();
+    for (int arg : argList) {
+      int remapped = remap(groupSet, arg);
+      if (list.isEmpty()) {
+        list = ImmutableIntList.of(remapped);
+      } else {
+        final int[] newInts = Arrays.copyOf(list.toIntArray(), list.size() + 1);
+        newInts[list.size()] = remapped;
+        list = ImmutableIntList.of(newInts);
+      }
+    }
+    return list;
+  }
+
+  private static int remap(ImmutableBitSet groupSet, int arg) {
+    return arg < 0 ?
+        -1 : groupSet.indexOf(arg);
+  }
 
   /**
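
A quick illustration of the remap() helpers (not part of the patch): they
translate a column's ordinal in the original input into its ordinal in the
bottom aggregate's output, which is simply the column's position within
fullGroupSet. A minimal standalone sketch, assuming only Calcite's
ImmutableBitSet on the classpath:

    import org.apache.calcite.util.ImmutableBitSet;

    public class RemapSketch {
      public static void main(String[] args) {
        // fullGroupSet {2, 4, 7}: input column 2 -> 0, 4 -> 1, 7 -> 2.
        ImmutableBitSet fullGroupSet = ImmutableBitSet.of(2, 4, 7);
        System.out.println(fullGroupSet.indexOf(4));  // prints 1
        System.out.println(fullGroupSet.indexOf(3));  // prints -1 (absent)
      }
    }
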
diff --git a/ql/src/test/results/clientpositive/tez/limit_pushdown.q.out b/ql/src/test/results/clientpositive/tez/limit_pushdown.q.out
index 2a41aae..7038b4d 100644
--- a/ql/src/test/results/clientpositive/tez/limit_pushdown.q.out
+++ b/ql/src/test/results/clientpositive/tez/limit_pushdown.q.out
@@ -476,38 +476,35 @@ STAGE PLANS:
                     outputColumnNames: _col0, _col1
                     Statistics: Num rows: 12288 Data size: 2641964 Basic stats: COMPLETE Column stats: NONE
                     Group By Operator
+                      aggregations: count(DISTINCT _col1)
                       keys: _col0 (type: tinyint), _col1 (type: double)
                       mode: hash
-                      outputColumnNames: _col0, _col1
+                      outputColumnNames: _col0, _col1, _col2
                       Statistics: Num rows: 12288 Data size: 2641964 Basic stats: COMPLETE Column stats: NONE
                       Reduce Output Operator
                         key expressions: _col0 (type: tinyint), _col1 (type: double)
                         sort order: ++
                         Map-reduce partition columns: _col0 (type: tinyint)
                         Statistics: Num rows: 12288 Data size: 2641964 Basic stats: COMPLETE Column stats: NONE
+                        TopN Hash Memory Usage: 0.3
         Reducer 2
             Reduce Operator Tree:
               Group By Operator
-                keys: KEY._col0 (type: tinyint), KEY._col1 (type: double)
+                aggregations: count(DISTINCT KEY._col1:0._col0)
+                keys: KEY._col0 (type: tinyint)
                 mode: mergepartial
                 outputColumnNames: _col0, _col1
                 Statistics: Num rows: 6144 Data size: 1320982 Basic stats: COMPLETE Column stats: NONE
-                Group By Operator
-                  aggregations: count(_col1)
-                  keys: _col0 (type: tinyint)
-                  mode: complete
-                  outputColumnNames: _col0, _col1
-                  Statistics: Num rows: 3072 Data size: 660491 Basic stats: COMPLETE Column stats: NONE
-                  Limit
-                    Number of rows: 20
+                Limit
+                  Number of rows: 20
+                  Statistics: Num rows: 20 Data size: 4300 Basic stats: COMPLETE Column stats: NONE
+                  File Output Operator
+                    compressed: false
                     Statistics: Num rows: 20 Data size: 4300 Basic stats: COMPLETE Column stats: NONE
-                    File Output Operator
-                      compressed: false
-                      Statistics: Num rows: 20 Data size: 4300 Basic stats: COMPLETE Column stats: NONE
-                      table:
-                          input format: org.apache.hadoop.mapred.TextInputFormat
-                          output format: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat
-                          serde: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe
+                    table:
+                        input format: org.apache.hadoop.mapred.TextInputFormat
+                        output format: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat
+                        serde: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe
 
   Stage: Stage-0
     Fetch Operator
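
The golden-plan changes in these .q.out files all follow from the new
group-count guard: each of these plans groups by a key, so the rule no longer
fires for them and Hive keeps a single count(DISTINCT) aggregation (hash plus
mergepartial) instead of the old two-level group-by expansion. When the rule
does fire (no GROUP BY), it produces roughly the two-level shape sketched
below with Calcite's RelBuilder. This is a hypothetical illustration, not
patch code; it assumes a FrameworkConfig `config` whose schema contains a
table "T" with a column "D":

    import org.apache.calcite.rel.RelNode;
    import org.apache.calcite.tools.FrameworkConfig;
    import org.apache.calcite.tools.RelBuilder;

    public class DistinctRewriteSketch {
      // Two-level plan for: SELECT COUNT(DISTINCT d) FROM t
      public static RelNode sketch(FrameworkConfig config) {
        RelBuilder b = RelBuilder.create(config);
        return b.scan("T")
            // bottom aggregate: GROUP BY d deduplicates the distinct
            // argument and spreads the work across reducers
            .aggregate(b.groupKey("D"))
            // top aggregate: a plain COUNT over the already-distinct values
            .aggregate(b.groupKey(), b.count(false, "CNT", b.field("D")))
            .build();
      }
    }
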
diff --git a/ql/src/test/results/clientpositive/tez/mrr.q.out b/ql/src/test/results/clientpositive/tez/mrr.q.out
index d42f9b0..39a830b 100644
--- a/ql/src/test/results/clientpositive/tez/mrr.q.out
+++ b/ql/src/test/results/clientpositive/tez/mrr.q.out
@@ -456,9 +456,10 @@ STAGE PLANS:
                     outputColumnNames: _col0, _col1
                     Statistics: Num rows: 275 Data size: 2921 Basic stats: COMPLETE Column stats: NONE
                     Group By Operator
+                      aggregations: count(DISTINCT _col1)
                       keys: _col0 (type: string), _col1 (type: string)
                       mode: hash
-                      outputColumnNames: _col0, _col1
+                      outputColumnNames: _col0, _col1, _col2
                       Statistics: Num rows: 275 Data size: 2921 Basic stats: COMPLETE Column stats: NONE
                       Reduce Output Operator
                         key expressions: _col0 (type: string), _col1 (type: string)
@@ -468,30 +469,25 @@ STAGE PLANS:
         Reducer 3
             Reduce Operator Tree:
              Group By Operator
-                keys: KEY._col0 (type: string), KEY._col1 (type: string)
+                aggregations: count(DISTINCT KEY._col1:0._col0)
+                keys: KEY._col0 (type: string)
                 mode: mergepartial
                 outputColumnNames: _col0, _col1
                 Statistics: Num rows: 137 Data size: 1455 Basic stats: COMPLETE Column stats: NONE
-                Group By Operator
-                  aggregations: count(_col1)
-                  keys: _col0 (type: string)
-                  mode: complete
-                  outputColumnNames: _col0, _col1
-                  Statistics: Num rows: 68 Data size: 722 Basic stats: COMPLETE Column stats: NONE
-                  Reduce Output Operator
-                    key expressions: _col1 (type: bigint)
-                    sort order: +
-                    Statistics: Num rows: 68 Data size: 722 Basic stats: COMPLETE Column stats: NONE
-                    value expressions: _col0 (type: string)
+                Reduce Output Operator
+                  key expressions: _col1 (type: bigint)
+                  sort order: +
+                  Statistics: Num rows: 137 Data size: 1455 Basic stats: COMPLETE Column stats: NONE
+                  value expressions: _col0 (type: string)
         Reducer 4
             Reduce Operator Tree:
               Select Operator
                 expressions: VALUE._col0 (type: string), KEY.reducesinkkey0 (type: bigint)
                 outputColumnNames: _col0, _col1
-                Statistics: Num rows: 68 Data size: 722 Basic stats: COMPLETE Column stats: NONE
+                Statistics: Num rows: 137 Data size: 1455 Basic stats: COMPLETE Column stats: NONE
                 File Output Operator
                   compressed: false
-                  Statistics: Num rows: 68 Data size: 722 Basic stats: COMPLETE Column stats: NONE
+                  Statistics: Num rows: 137 Data size: 1455 Basic stats: COMPLETE Column stats: NONE
                   table:
                       input format: org.apache.hadoop.mapred.TextInputFormat
                       output format: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat
@@ -863,9 +859,10 @@ STAGE PLANS:
                       Statistics: Num rows: 275 Data size: 2921 Basic stats: COMPLETE Column stats: NONE
                       HybridGraceHashJoin: true
                       Group By Operator
+                        aggregations: count(DISTINCT _col1)
                         keys: _col0 (type: string), _col1 (type: string)
                         mode: hash
-                        outputColumnNames: _col0, _col1
+                        outputColumnNames: _col0, _col1, _col2
                         Statistics: Num rows: 275 Data size: 2921 Basic stats: COMPLETE Column stats: NONE
                         Reduce Output Operator
                           key expressions: _col0 (type: string), _col1 (type: string)
@@ -892,30 +889,25 @@ STAGE PLANS:
         Reducer 2
             Reduce Operator Tree:
              Group By Operator
-                keys: KEY._col0 (type: string), KEY._col1 (type: string)
+                aggregations: count(DISTINCT KEY._col1:0._col0)
+                keys: KEY._col0 (type: string)
                 mode: mergepartial
                 outputColumnNames: _col0, _col1
                 Statistics: Num rows: 137 Data size: 1455 Basic stats: COMPLETE Column stats: NONE
-                Group By Operator
-                  aggregations: count(_col1)
-                  keys: _col0 (type: string)
-                  mode: complete
-                  outputColumnNames: _col0, _col1
-                  Statistics: Num rows: 68 Data size: 722 Basic stats: COMPLETE Column stats: NONE
-                  Reduce Output Operator
-                    key expressions: _col1 (type: bigint)
-                    sort order: +
-                    Statistics: Num rows: 68 Data size: 722 Basic stats: COMPLETE Column stats: NONE
-                    value expressions: _col0 (type: string)
+                Reduce Output Operator
+                  key expressions: _col1 (type: bigint)
+                  sort order: +
+                  Statistics: Num rows: 137 Data size: 1455 Basic stats: COMPLETE Column stats: NONE
+                  value expressions: _col0 (type: string)
         Reducer 3
             Reduce Operator Tree:
               Select Operator
                 expressions: VALUE._col0 (type: string), KEY.reducesinkkey0 (type: bigint)
                 outputColumnNames: _col0, _col1
-                Statistics: Num rows: 68 Data size: 722 Basic stats: COMPLETE Column stats: NONE
+                Statistics: Num rows: 137 Data size: 1455 Basic stats: COMPLETE Column stats: NONE
                 File Output Operator
                   compressed: false
-                  Statistics: Num rows: 68 Data size: 722 Basic stats: COMPLETE Column stats: NONE
+                  Statistics: Num rows: 137 Data size: 1455 Basic stats: COMPLETE Column stats: NONE
                   table:
                       input format: org.apache.hadoop.mapred.TextInputFormat
                       output format: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat
diff --git a/ql/src/test/results/clientpositive/tez/vectorization_limit.q.out b/ql/src/test/results/clientpositive/tez/vectorization_limit.q.out
index 1c5b51f..d815938 100644
--- a/ql/src/test/results/clientpositive/tez/vectorization_limit.q.out
+++ b/ql/src/test/results/clientpositive/tez/vectorization_limit.q.out
@@ -345,39 +345,36 @@ STAGE PLANS:
                     outputColumnNames: _col0, _col1
                     Statistics: Num rows: 12288 Data size: 2641964 Basic stats: COMPLETE Column stats: NONE
                     Group By Operator
+                      aggregations: count(DISTINCT _col1)
                       keys: _col0 (type: tinyint), _col1 (type: double)
                       mode: hash
-                      outputColumnNames: _col0, _col1
+                      outputColumnNames: _col0, _col1, _col2
                       Statistics: Num rows: 12288 Data size: 2641964 Basic stats: COMPLETE Column stats: NONE
                       Reduce Output Operator
                         key expressions: _col0 (type: tinyint), _col1 (type: double)
                         sort order: ++
                         Map-reduce partition columns: _col0 (type: tinyint)
                         Statistics: Num rows: 12288 Data size: 2641964 Basic stats: COMPLETE Column stats: NONE
+                        TopN Hash Memory Usage: 0.3
             Execution mode: vectorized
         Reducer 2
             Reduce Operator Tree:
               Group By Operator
-                keys: KEY._col0 (type: tinyint), KEY._col1 (type: double)
+                aggregations: count(DISTINCT KEY._col1:0._col0)
+                keys: KEY._col0 (type: tinyint)
                 mode: mergepartial
                 outputColumnNames: _col0, _col1
                 Statistics: Num rows: 6144 Data size: 1320982 Basic stats: COMPLETE Column stats: NONE
-                Group By Operator
-                  aggregations: count(_col1)
-                  keys: _col0 (type: tinyint)
-                  mode: complete
-                  outputColumnNames: _col0, _col1
-                  Statistics: Num rows: 3072 Data size: 660491 Basic stats: COMPLETE Column stats: NONE
-                  Limit
-                    Number of rows: 20
+                Limit
+                  Number of rows: 20
+                  Statistics: Num rows: 20 Data size: 4300 Basic stats: COMPLETE Column stats: NONE
+                  File Output Operator
+                    compressed: false
                     Statistics: Num rows: 20 Data size: 4300 Basic stats: COMPLETE Column stats: NONE
-                    File Output Operator
-                      compressed: false
-                      Statistics: Num rows: 20 Data size: 4300 Basic stats: COMPLETE Column stats: NONE
-                      table:
-                          input format: org.apache.hadoop.mapred.TextInputFormat
-                          output format: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat
-                          serde: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe
+                    table:
+                        input format: org.apache.hadoop.mapred.TextInputFormat
+                        output format: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat
+                        serde: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe
 
   Stage: Stage-0
     Fetch Operator