diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java index f3930a118c..4e7512ffc4 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java @@ -22,28 +22,25 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.NavigableMap; import java.util.Set; +import java.util.TreeMap; -import com.google.common.collect.ImmutableList; import org.apache.calcite.adapter.druid.DruidQuery; import org.apache.calcite.linq4j.Ord; +import org.apache.calcite.plan.RelOptPredicateList; import org.apache.calcite.plan.RelOptTable; import org.apache.calcite.plan.RelOptUtil; -import org.apache.calcite.rel.RelCollation; -import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.CorrelationId; import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.core.TableScan; import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; -import org.apache.calcite.rex.RexBuilder; -import org.apache.calcite.rex.RexCorrelVariable; -import org.apache.calcite.rex.RexFieldAccess; -import org.apache.calcite.rex.RexNode; -import org.apache.calcite.rex.RexPermuteInputsShuttle; -import org.apache.calcite.rex.RexVisitor; +import org.apache.calcite.rex.*; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.sql2rel.CorrelationReferenceFinder; import org.apache.calcite.sql2rel.RelFieldTrimmer; @@ -56,6 +53,7 @@ import org.apache.hadoop.hive.ql.metadata.Table; import 
org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil; import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan; @@ -296,6 +294,74 @@ private static RelNode project(DruidQuery dq, ImmutableBitSet fieldsUsed, return hp; } + private boolean isRexLiteral(final RexNode rexNode) { + if(rexNode instanceof RexLiteral) { + return true; + } + else if(rexNode instanceof RexCall + && ((RexCall)rexNode).getOperator().getKind() == SqlKind.CAST){ + return isRexLiteral(((RexCall)(rexNode)).getOperands().get(0)); + } + else { + return false; + } + } + /** + * Variant of {@link #trimFields(Aggregate, ImmutableBitSet, Set)} for + * {@link org.apache.calcite.rel.logical.LogicalAggregate}. 
+ * This method replaces group by 'constant key' with group by true (boolean) + * if and only if + * group by doesn't have grouping sets + * all keys in group by are constant + * none of the relnode above aggregate refers to these keys + * + * If all of above is true then group by is rewritten and a new project is introduced + * underneath aggregate + * + * This is mainly done so that hive is able to push down queries with + * group by 'constant key with type not supported by druid' into druid + */ + public TrimResult trimFields(Aggregate aggregate, ImmutableBitSet fieldsUsed, + Set extraFields) { + + Aggregate newAggregate = aggregate; + if (!(aggregate.getIndicatorCount() > 0) + && !(aggregate.getGroupSet().isEmpty()) + && !fieldsUsed.contains(aggregate.getGroupSet())) { + final RelNode input = aggregate.getInput(); + final RelDataType rowType = input.getRowType(); + RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder(); + final List newProjects = new ArrayList<>(); + + final List inputExprs = input.getChildExps(); + if(inputExprs == null || inputExprs.isEmpty()) { + return super.trimFields(newAggregate, fieldsUsed, extraFields); + } + + boolean allConstants = true; + for(int key : aggregate.getGroupSet()) { + if(!isRexLiteral(inputExprs.get(key))){ + allConstants = false; + break; + } + } + + if (allConstants) { + for (int i = 0; i < rowType.getFieldCount(); i++) { + if (aggregate.getGroupSet().get(i)) { + newProjects.add(rexBuilder.makeLiteral(true)); + } else { + newProjects.add(rexBuilder.makeInputRef(input, i)); + } + } + relBuilder.push(input); + relBuilder.project(newProjects); + newAggregate = new HiveAggregate(aggregate.getCluster(), aggregate.getTraitSet(), relBuilder.build(), + aggregate.getGroupSet(), null, aggregate.getAggCallList()); + } + } + return super.trimFields(newAggregate, fieldsUsed, extraFields); + } /** * Variant of {@link #trimFields(RelNode, ImmutableBitSet, Set)} for * {@link org.apache.calcite.rel.logical.LogicalProject}. 
diff --git a/ql/src/test/queries/clientpositive/druidmini_expressions.q b/ql/src/test/queries/clientpositive/druidmini_expressions.q index 273c803154..fc921c37e5 100644 --- a/ql/src/test/queries/clientpositive/druidmini_expressions.q +++ b/ql/src/test/queries/clientpositive/druidmini_expressions.q @@ -138,4 +138,21 @@ SELECT DATE_ADD(cast(`__time` as date), CAST((cdouble / 1000) AS INT)) as date_1 EXPLAIN SELECT ctinyint > 2, count(*) from druid_table_n0 GROUP BY ctinyint > 2; +-- group by should be rewritten and pushed into druid +-- simple gby with single constant key +EXPLAIN SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011; +SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011; + +-- gby with multiple constant keys +EXPLAIN SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011, 3.40; +SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011, 3.40; + +-- group by with constant folded key +EXPLAIN SELECT sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat; +SELECT sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat; + +-- group by key is referred in select +EXPLAIN SELECT cfloat, sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat; +SELECT cfloat, sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat; + DROP TABLE druid_table_n0; diff --git a/ql/src/test/results/clientpositive/druid/druidmini_expressions.q.out b/ql/src/test/results/clientpositive/druid/druidmini_expressions.q.out index 9ffcdd86b2..737cfd85b8 100644 --- a/ql/src/test/results/clientpositive/druid/druidmini_expressions.q.out +++ b/ql/src/test/results/clientpositive/druid/druidmini_expressions.q.out @@ -1396,6 +1396,136 @@ STAGE PLANS: outputColumnNames: _col0, _col1 ListSink +PREHOOK: query: EXPLAIN SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011 +PREHOOK: type: QUERY +POSTHOOK: query: EXPLAIN SELECT sum(cfloat) FROM
druid_table_n0 WHERE cstring1 != 'en' group by 1.011 +POSTHOOK: type: QUERY +STAGE DEPENDENCIES: + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + TableScan + alias: druid_table_n0 + properties: + druid.fieldNames vc,$f1 + druid.fieldTypes boolean,double + druid.query.json {"queryType":"groupBy","dataSource":"default.druid_table_n0","granularity":"all","dimensions":[{"type":"default","dimension":"vc","outputName":"vc","outputType":"LONG"}],"virtualColumns":[{"type":"expression","name":"vc","expression":"1","outputType":"LONG"}],"limitSpec":{"type":"default"},"filter":{"type":"not","field":{"type":"selector","dimension":"cstring1","value":"en"}},"aggregations":[{"type":"doubleSum","name":"$f1","fieldName":"cfloat"}],"intervals":["1900-01-01T00:00:00.000Z/3000-01-01T00:00:00.000Z"]} + druid.query.type groupBy + Select Operator + expressions: $f1 (type: double) + outputColumnNames: _col0 + ListSink + +PREHOOK: query: SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011 +PREHOOK: type: QUERY +PREHOOK: Input: default@druid_table_n0 +PREHOOK: Output: hdfs://### HDFS PATH ### +POSTHOOK: query: SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@druid_table_n0 +POSTHOOK: Output: hdfs://### HDFS PATH ### +-39590.24724686146 +PREHOOK: query: EXPLAIN SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011, 3.40 +PREHOOK: type: QUERY +POSTHOOK: query: EXPLAIN SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011, 3.40 +POSTHOOK: type: QUERY +STAGE DEPENDENCIES: + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + TableScan + alias: druid_table_n0 + properties: + druid.fieldNames vc,$f1 + druid.fieldTypes boolean,double + druid.query.json 
{"queryType":"groupBy","dataSource":"default.druid_table_n0","granularity":"all","dimensions":[{"type":"default","dimension":"vc","outputName":"vc","outputType":"LONG"}],"virtualColumns":[{"type":"expression","name":"vc","expression":"1","outputType":"LONG"}],"limitSpec":{"type":"default"},"filter":{"type":"not","field":{"type":"selector","dimension":"cstring1","value":"en"}},"aggregations":[{"type":"doubleSum","name":"$f1","fieldName":"cfloat"}],"intervals":["1900-01-01T00:00:00.000Z/3000-01-01T00:00:00.000Z"]} + druid.query.type groupBy + Select Operator + expressions: $f1 (type: double) + outputColumnNames: _col0 + ListSink + +PREHOOK: query: SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011, 3.40 +PREHOOK: type: QUERY +PREHOOK: Input: default@druid_table_n0 +PREHOOK: Output: hdfs://### HDFS PATH ### +POSTHOOK: query: SELECT sum(cfloat) FROM druid_table_n0 WHERE cstring1 != 'en' group by 1.011, 3.40 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@druid_table_n0 +POSTHOOK: Output: hdfs://### HDFS PATH ### +-39590.24724686146 +PREHOOK: query: EXPLAIN SELECT sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat +PREHOOK: type: QUERY +POSTHOOK: query: EXPLAIN SELECT sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat +POSTHOOK: type: QUERY +STAGE DEPENDENCIES: + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + TableScan + alias: druid_table_n0 + properties: + druid.fieldNames vc,$f1 + druid.fieldTypes boolean,bigint + druid.query.json 
{"queryType":"groupBy","dataSource":"default.druid_table_n0","granularity":"all","dimensions":[{"type":"default","dimension":"vc","outputName":"vc","outputType":"LONG"}],"virtualColumns":[{"type":"expression","name":"vc","expression":"1","outputType":"LONG"}],"limitSpec":{"type":"default"},"filter":{"type":"bound","dimension":"cfloat","lower":"0.011","lowerStrict":false,"upper":"0.011","upperStrict":false,"ordering":"numeric"},"aggregations":[{"type":"longSum","name":"$f1","fieldName":"cint"}],"intervals":["1900-01-01T00:00:00.000Z/3000-01-01T00:00:00.000Z"]} + druid.query.type groupBy + Select Operator + expressions: $f1 (type: bigint) + outputColumnNames: _col0 + ListSink + +PREHOOK: query: SELECT sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat +PREHOOK: type: QUERY +PREHOOK: Input: default@druid_table_n0 +PREHOOK: Output: hdfs://### HDFS PATH ### +POSTHOOK: query: SELECT sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat +POSTHOOK: type: QUERY +POSTHOOK: Input: default@druid_table_n0 +POSTHOOK: Output: hdfs://### HDFS PATH ### +PREHOOK: query: EXPLAIN SELECT cfloat, sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat +PREHOOK: type: QUERY +POSTHOOK: query: EXPLAIN SELECT cfloat, sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat +POSTHOOK: type: QUERY +STAGE DEPENDENCIES: + Stage-0 is a root stage + +STAGE PLANS: + Stage: Stage-0 + Fetch Operator + limit: -1 + Processor Tree: + TableScan + alias: druid_table_n0 + properties: + druid.fieldNames cfloat,$f1 + druid.fieldTypes float,bigint + druid.query.json 
{"queryType":"groupBy","dataSource":"default.druid_table_n0","granularity":"all","dimensions":[{"type":"default","dimension":"vc","outputName":"vc","outputType":"LONG"}],"virtualColumns":[{"type":"expression","name":"vc","expression":"1","outputType":"LONG"}],"limitSpec":{"type":"default"},"filter":{"type":"bound","dimension":"cfloat","lower":"0.011","lowerStrict":false,"upper":"0.011","upperStrict":false,"ordering":"numeric"},"aggregations":[{"type":"longSum","name":"$f1","fieldName":"cint"}],"postAggregations":[{"type":"expression","name":"cfloat","expression":"0.011"}],"intervals":["1900-01-01T00:00:00.000Z/3000-01-01T00:00:00.000Z"]} + druid.query.type groupBy + Select Operator + expressions: cfloat (type: float), $f1 (type: bigint) + outputColumnNames: _col0, _col1 + ListSink + +PREHOOK: query: SELECT cfloat, sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat +PREHOOK: type: QUERY +PREHOOK: Input: default@druid_table_n0 +PREHOOK: Output: hdfs://### HDFS PATH ### +POSTHOOK: query: SELECT cfloat, sum(cint) FROM druid_table_n0 WHERE cfloat= 0.011 group by cfloat +POSTHOOK: type: QUERY +POSTHOOK: Input: default@druid_table_n0 +POSTHOOK: Output: hdfs://### HDFS PATH ### PREHOOK: query: DROP TABLE druid_table_n0 PREHOOK: type: DROPTABLE PREHOOK: Input: default@druid_table_n0