diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java index 3759ed6c50..c647f44971 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java @@ -473,7 +473,7 @@ private ImmutableBitSet generateNewGroupset(Aggregate aggregate, ImmutableBitSet * */ private Aggregate rewriteGBConstantKeys(Aggregate aggregate, ImmutableBitSet fieldsUsed, - Set extraFields) { + ImmutableBitSet aggCallFields ) { if ((aggregate.getIndicatorCount() > 0) || (aggregate.getGroupSet().isEmpty()) || fieldsUsed.contains(aggregate.getGroupSet())) { @@ -503,7 +503,7 @@ private Aggregate rewriteGBConstantKeys(Aggregate aggregate, ImmutableBitSet fie if (allConstants) { for (int i = 0; i < rowType.getFieldCount(); i++) { - if (aggregate.getGroupSet().get(i)) { + if (aggregate.getGroupSet().get(i) && !aggCallFields.get(i)) { newProjects.add(rexBuilder.makeLiteral(true)); } else { newProjects.add(rexBuilder.makeInputRef(input, i)); @@ -516,7 +516,7 @@ private Aggregate rewriteGBConstantKeys(Aggregate aggregate, ImmutableBitSet fie return newAggregate; } return aggregate; - } + } @Override public TrimResult trimFields(Aggregate aggregate, ImmutableBitSet fieldsUsed, Set extraFields) { @@ -533,24 +533,33 @@ public TrimResult trimFields(Aggregate aggregate, ImmutableBitSet fieldsUsed, Se // // But group and indicator fields stay, even if they are not used. - aggregate = rewriteGBConstantKeys(aggregate, fieldsUsed, extraFields); + // Compute which input fields are used. - final RelDataType rowType = aggregate.getRowType(); - // Compute which input fields are used. - // 1. group fields are always used - final ImmutableBitSet.Builder inputFieldsUsed = - aggregate.getGroupSet().rebuild(); - // 2. agg functions + // agg functions + // agg functions are added first (before group sets) because rewriteGBConstantsKeys + // needs it + final ImmutableBitSet.Builder aggCallFieldsUsedBuilder = ImmutableBitSet.builder(); for (AggregateCall aggCall : aggregate.getAggCallList()) { for (int i : aggCall.getArgList()) { - inputFieldsUsed.set(i); + aggCallFieldsUsedBuilder.set(i); } if (aggCall.filterArg >= 0) { - inputFieldsUsed.set(aggCall.filterArg); + aggCallFieldsUsedBuilder.set(aggCall.filterArg); } } + // transform if group by contain constant keys + ImmutableBitSet aggCallFieldsUsed = aggCallFieldsUsedBuilder.build(); + aggregate = rewriteGBConstantKeys(aggregate, fieldsUsed, aggCallFieldsUsed); + + // add group fields + final ImmutableBitSet.Builder inputFieldsUsed = aggregate.getGroupSet().rebuild(); + inputFieldsUsed.addAll(aggCallFieldsUsed); + + + final RelDataType rowType = aggregate.getRowType(); + // Create input with trimmed columns. final RelNode input = aggregate.getInput(); final Set inputExtraFields = Collections.emptySet(); diff --git a/ql/src/test/queries/clientpositive/groupby13.q b/ql/src/test/queries/clientpositive/groupby13.q index 53feaedd3a..900f557489 100644 --- a/ql/src/test/queries/clientpositive/groupby13.q +++ b/ql/src/test/queries/clientpositive/groupby13.q @@ -14,3 +14,9 @@ int_col_7, int_col_7, LEAST(COALESCE(int_col_5, -279), COALESCE(int_col_7, 476)); + +create table aGBY (i int, j string); +insert into aGBY values ( 1, 'a'),(2,'b'); +explain cbo select min(j) from aGBY where j='a' group by j; +select min(j) from aGBY where j='a' group by j; +drop table aGBY; diff --git a/ql/src/test/results/clientpositive/groupby13.q.out b/ql/src/test/results/clientpositive/groupby13.q.out index d7fcc6806c..0af8530fb3 100644 --- a/ql/src/test/results/clientpositive/groupby13.q.out +++ b/ql/src/test/results/clientpositive/groupby13.q.out @@ -90,3 +90,53 @@ STAGE PLANS: Processor Tree: ListSink +PREHOOK: query: create table aGBY (i int, j string) +PREHOOK: type: CREATETABLE +PREHOOK: Output: database:default +PREHOOK: Output: default@aGBY +POSTHOOK: query: create table aGBY (i int, j string) +POSTHOOK: type: CREATETABLE +POSTHOOK: Output: database:default +POSTHOOK: Output: default@aGBY +PREHOOK: query: insert into aGBY values ( 1, 'a'),(2,'b') +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +PREHOOK: Output: default@agby +POSTHOOK: query: insert into aGBY values ( 1, 'a'),(2,'b') +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +POSTHOOK: Output: default@agby +POSTHOOK: Lineage: agby.i SCRIPT [] +POSTHOOK: Lineage: agby.j SCRIPT [] +PREHOOK: query: explain cbo select min(j) from aGBY where j='a' group by j +PREHOOK: type: QUERY +PREHOOK: Input: default@agby +#### A masked pattern was here #### +POSTHOOK: query: explain cbo select min(j) from aGBY where j='a' group by j +POSTHOOK: type: QUERY +POSTHOOK: Input: default@agby +#### A masked pattern was here #### +CBO PLAN: +HiveProject(_o__c0=[$1]) + HiveAggregate(group=[{0}], agg#0=[min($0)]) + HiveProject($f0=[CAST(_UTF-16LE'a':VARCHAR(2147483647) CHARACTER SET "UTF-16LE"):VARCHAR(2147483647) CHARACTER SET "UTF-16LE"]) + HiveFilter(condition=[=($1, _UTF-16LE'a')]) + HiveTableScan(table=[[default, agby]], table:alias=[agby]) + +PREHOOK: query: select min(j) from aGBY where j='a' group by j +PREHOOK: type: QUERY +PREHOOK: Input: default@agby +#### A masked pattern was here #### +POSTHOOK: query: select min(j) from aGBY where j='a' group by j +POSTHOOK: type: QUERY +POSTHOOK: Input: default@agby +#### A masked pattern was here #### +a +PREHOOK: query: drop table aGBY +PREHOOK: type: DROPTABLE +PREHOOK: Input: default@agby +PREHOOK: Output: default@agby +POSTHOOK: query: drop table aGBY +POSTHOOK: type: DROPTABLE +POSTHOOK: Input: default@agby +POSTHOOK: Output: default@agby