diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java index f3930a118c..6ac69c8004 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java @@ -22,16 +22,17 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.NavigableMap; import java.util.Set; +import java.util.TreeMap; -import com.google.common.collect.ImmutableList; import org.apache.calcite.adapter.druid.DruidQuery; import org.apache.calcite.linq4j.Ord; +import org.apache.calcite.plan.RelOptPredicateList; import org.apache.calcite.plan.RelOptTable; import org.apache.calcite.plan.RelOptUtil; -import org.apache.calcite.rel.RelCollation; -import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.CorrelationId; import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.core.TableScan; @@ -41,6 +42,7 @@ import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexCorrelVariable; import org.apache.calcite.rex.RexFieldAccess; +import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexPermuteInputsShuttle; import org.apache.calcite.rex.RexVisitor; @@ -56,6 +58,7 @@ import org.apache.hadoop.hive.ql.metadata.Table; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil; import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan; @@ -296,6 +299,63 @@ private static RelNode project(DruidQuery dq, ImmutableBitSet fieldsUsed, return hp; } + /** + * Variant of {@link #trimFields(RelNode, ImmutableBitSet, Set)} for + * {@link org.apache.calcite.rel.logical.LogicalAggregate}.
+ * This method replaces a group by on constant keys with a group by on a boolean literal (true) if and only if: + * - the group by has no grouping sets, + * - all group by keys are constant, and + * - no operator above the Aggregate references these keys. + * + * If all of the above hold, the group by is rewritten and a new Project is introduced underneath the Aggregate. + * + * This is mainly done so that Hive can push queries that group by a constant key whose type is not supported by Druid down into Druid. + */ + public TrimResult trimFields(Aggregate aggregate, ImmutableBitSet fieldsUsed, + Set<RelDataTypeField> extraFields) { + + Aggregate newAggregate = aggregate; + if (!(aggregate.getIndicatorCount() > 0) + && !fieldsUsed.contains(aggregate.getGroupSet())) { + final RelNode input = aggregate.getInput(); + final RelDataType rowType = input.getRowType(); + RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder(); + final List<RexNode> newProjects = new ArrayList<>(); + + final RelMetadataQuery mq = aggregate.getCluster().getMetadataQuery(); + final RelOptPredicateList predicates = + mq.getPulledUpPredicates(input); + if (predicates != null) { + final NavigableMap<Integer, RexNode> map = new TreeMap<>(); + for (int key : aggregate.getGroupSet()) { + final RexInputRef ref = + rexBuilder.makeInputRef(aggregate.getInput(), key); + if (predicates.constantMap.containsKey(ref)) { + map.put(key, predicates.constantMap.get(ref)); + } + } + + // rewrite only if all group by keys are constant + if (aggregate.getGroupCount() == map.size()) { + for (int i = 0; i < rowType.getFieldCount(); i++) { + if (aggregate.getGroupSet().get(i)) { + newProjects.add(rexBuilder.makeLiteral(true)); + } else { + newProjects.add(rexBuilder.makeInputRef(input, i)); + } + } + relBuilder.push(input); + relBuilder.project(newProjects); + newAggregate = new HiveAggregate(aggregate.getCluster(), aggregate.getTraitSet(), relBuilder.build(), + aggregate.getGroupSet(), null, aggregate.getAggCallList()); + } + } + } + return super.trimFields(newAggregate, fieldsUsed, extraFields); + } /** * Variant of {@link #trimFields(RelNode, ImmutableBitSet, Set)} for * {@link org.apache.calcite.rel.logical.LogicalProject}.
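For reference, below is a minimal standalone sketch of the constant-group-key detection that the new trimFields(Aggregate, ...) override above relies on, assuming Calcite is on the classpath; the class and method names (ConstantGroupKeySketch, constantGroupKeys) are hypothetical and not part of the patch. It asks the metadata provider for the predicates pulled up to the Aggregate's input and collects every group key whose input reference is bound to a constant.

// Illustrative sketch only, not part of the patch.
import java.util.NavigableMap;
import java.util.TreeMap;

import org.apache.calcite.plan.RelOptPredicateList;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;

final class ConstantGroupKeySketch {
  private ConstantGroupKeySketch() {
  }

  /**
   * Returns the group keys of the given Aggregate that the metadata provider can prove
   * constant, keyed by input ordinal. The rewrite above fires only when this map covers
   * the whole group set (map.size() == aggregate.getGroupCount()).
   */
  static NavigableMap<Integer, RexNode> constantGroupKeys(Aggregate aggregate) {
    final RelMetadataQuery mq = aggregate.getCluster().getMetadataQuery();
    final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
    final NavigableMap<Integer, RexNode> constants = new TreeMap<>();
    // Predicates pulled up to the Aggregate's input; constantMap binds input refs to literals.
    final RelOptPredicateList predicates = mq.getPulledUpPredicates(aggregate.getInput());
    if (predicates == null) {
      return constants;
    }
    for (int key : aggregate.getGroupSet()) {
      final RexInputRef ref = rexBuilder.makeInputRef(aggregate.getInput(), key);
      if (predicates.constantMap.containsKey(ref)) {
        constants.put(key, predicates.constantMap.get(ref));
      }
    }
    return constants;
  }
}

In the test queries added below (for example GROUP BY 1.011, or GROUP BY cfloat under a cfloat = 0.011 filter), every group key is provably constant, so the Aggregate ends up grouping on a Project column that is literally true; this shows up in the plans as the vc virtual column in the generated Druid groupBy query.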
diff --git a/ql/src/test/queries/clientpositive/druidmini_expressions.q b/ql/src/test/queries/clientpositive/druidmini_expressions.q index 273c803154..fc921c37e5 100644 --- a/ql/src/test/queries/clientpositive/druidmini_expressions.q +++ b/ql/src/test/queries/clientpositive/druidmini_expressions.q @@ -138,4 +138,21 @@ SELECT DATE_ADD(cast(`__time` as date), CAST((cdouble / 1000) AS INT)) as date_1 EXPLAIN SELECT ctinyint > 2, count(*) from druid_table_n0 GROUP BY ctinyint > 2; +-- group by should be rewritten and pushed into druid +-- simple gby with single constant key +EXPLAIN SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011; +SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011; + +-- gby with multiple constant keys +EXPLAIN SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011, 3.40; +SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011, 3.40; + +-- group by with constant-folded key +EXPLAIN SELECT sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat; +SELECT sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat; + +-- group by key is referenced in select +EXPLAIN SELECT cfloat, sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat; +SELECT cfloat, sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat; + DROP TABLE druid_table_n0; diff --git a/ql/src/test/results/clientpositive/druid/druidmini_expressions.q.out b/ql/src/test/results/clientpositive/druid/druidmini_expressions.q.out index 9ffcdd86b2..737cfd85b8 100644 --- a/ql/src/test/results/clientpositive/druid/druidmini_expressions.q.out +++ b/ql/src/test/results/clientpositive/druid/druidmini_expressions.q.out @@ -1396,6 +1396,136 @@ STAGE PLANS: outputColumnNames: _col0, _col1 ListSink +PREHOOK: query: EXPLAIN SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011 +PREHOOK: type: QUERY +POSTHOOK: query: EXPLAIN SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011 +POSTHOOK: type: QUERY +STAGE DEPENDENCIES: + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + TableScan + alias: druid_table_n0 + properties: + druid.fieldNames vc,$f1 + druid.fieldTypes boolean,double + druid.query.json {"queryType":"groupBy","dataSource":"default.druid_table_n0","granularity":"all","dimensions":[{"type":"default","dimension":"vc","outputName":"vc","outputType":"LONG"}],"virtualColumns":[{"type":"expression","name":"vc","expression":"1","outputType":"LONG"}],"limitSpec":{"type":"default"},"filter":{"type":"not","field":{"type":"selector","dimension":"cstring1","value":"en"}},"aggregations":[{"type":"doubleSum","name":"$f1","fieldName":"cfloat"}],"intervals":["1900-01-01T00:00:00.000Z/3000-01-01T00:00:00.000Z"]} + druid.query.type groupBy + Select Operator + expressions: $f1 (type: double) + outputColumnNames: _col0 + ListSink + +PREHOOK: query: SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011 +PREHOOK: type: QUERY +PREHOOK: Input: default@druid_table_n0 +PREHOOK: Output: hdfs://### HDFS PATH ### +POSTHOOK: query: SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@druid_table_n0 +POSTHOOK: Output: hdfs://### HDFS PATH ### +-39590.24724686146 +PREHOOK: query: EXPLAIN SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011, 3.40 +PREHOOK: type: QUERY +POSTHOOK: query: EXPLAIN SELECT
sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011, 3.40 +POSTHOOK: type: QUERY +STAGE DEPENDENCIES: + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + TableScan + alias: druid_table_n0 + properties: + druid.fieldNames vc,$f1 + druid.fieldTypes boolean,double + druid.query.json {"queryType":"groupBy","dataSource":"default.druid_table_n0","granularity":"all","dimensions":[{"type":"default","dimension":"vc","outputName":"vc","outputType":"LONG"}],"virtualColumns":[{"type":"expression","name":"vc","expression":"1","outputType":"LONG"}],"limitSpec":{"type":"default"},"filter":{"type":"not","field":{"type":"selector","dimension":"cstring1","value":"en"}},"aggregations":[{"type":"doubleSum","name":"$f1","fieldName":"cfloat"}],"intervals":["1900-01-01T00:00:00.000Z/3000-01-01T00:00:00.000Z"]} + druid.query.type groupBy + Select Operator + expressions: $f1 (type: double) + outputColumnNames: _col0 + ListSink + +PREHOOK: query: SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011, 3.40 +PREHOOK: type: QUERY +PREHOOK: Input: default@druid_table_n0 +PREHOOK: Output: hdfs://### HDFS PATH ### +POSTHOOK: query: SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011, 3.40 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@druid_table_n0 +POSTHOOK: Output: hdfs://### HDFS PATH ### +-39590.24724686146 +PREHOOK: query: EXPLAIN SELECT sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat +PREHOOK: type: QUERY +POSTHOOK: query: EXPLAIN SELECT sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat +POSTHOOK: type: QUERY +STAGE DEPENDENCIES: + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + TableScan + alias: druid_table_n0 + properties: + druid.fieldNames vc,$f1 + druid.fieldTypes boolean,bigint + druid.query.json {"queryType":"groupBy","dataSource":"default.druid_table_n0","granularity":"all","dimensions":[{"type":"default","dimension":"vc","outputName":"vc","outputType":"LONG"}],"virtualColumns":[{"type":"expression","name":"vc","expression":"1","outputType":"LONG"}],"limitSpec":{"type":"default"},"filter":{"type":"bound","dimension":"cfloat","lower":"0.011","lowerStrict":false,"upper":"0.011","upperStrict":false,"ordering":"numeric"},"aggregations":[{"type":"longSum","name":"$f1","fieldName":"cint"}],"intervals":["1900-01-01T00:00:00.000Z/3000-01-01T00:00:00.000Z"]} + druid.query.type groupBy + Select Operator + expressions: $f1 (type: bigint) + outputColumnNames: _col0 + ListSink + +PREHOOK: query: SELECT sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat +PREHOOK: type: QUERY +PREHOOK: Input: default@druid_table_n0 +PREHOOK: Output: hdfs://### HDFS PATH ### +POSTHOOK: query: SELECT sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat +POSTHOOK: type: QUERY +POSTHOOK: Input: default@druid_table_n0 +POSTHOOK: Output: hdfs://### HDFS PATH ### +PREHOOK: query: EXPLAIN SELECT cfloat, sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat +PREHOOK: type: QUERY +POSTHOOK: query: EXPLAIN SELECT cfloat, sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat +POSTHOOK: type: QUERY +STAGE DEPENDENCIES: + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + TableScan + alias: druid_table_n0 + properties: + druid.fieldNames cfloat,$f1 + druid.fieldTypes float,bigint + druid.query.json 
{"queryType":"groupBy","dataSource":"default.druid_table_n0","granularity":"all","dimensions":[{"type":"default","dimension":"vc","outputName":"vc","outputType":"LONG"}],"virtualColumns":[{"type":"expression","name":"vc","expression":"1","outputType":"LONG"}],"limitSpec":{"type":"default"},"filter":{"type":"bound","dimension":"cfloat","lower":"0.011","lowerStrict":false,"upper":"0.011","upperStrict":false,"ordering":"numeric"},"aggregations":[{"type":"longSum","name":"$f1","fieldName":"cint"}],"postAggregations":[{"type":"expression","name":"cfloat","expression":"0.011"}],"intervals":["1900-01-01T00:00:00.000Z/3000-01-01T00:00:00.000Z"]} + druid.query.type groupBy + Select Operator + expressions: cfloat (type: float), $f1 (type: bigint) + outputColumnNames: _col0, _col1 + ListSink + +PREHOOK: query: SELECT cfloat, sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat +PREHOOK: type: QUERY +PREHOOK: Input: default@druid_table_n0 +PREHOOK: Output: hdfs://### HDFS PATH ### +POSTHOOK: query: SELECT cfloat, sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat +POSTHOOK: type: QUERY +POSTHOOK: Input: default@druid_table_n0 +POSTHOOK: Output: hdfs://### HDFS PATH ### PREHOOK: query: DROP TABLE druid_table_n0 PREHOOK: type: DROPTABLE PREHOOK: Input: default@druid_table_n0