diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java index f3930a118c..97c97ac6a7 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveRelFieldTrimmer.java @@ -32,9 +32,7 @@ import org.apache.calcite.rel.RelCollation; import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.core.CorrelationId; -import org.apache.calcite.rel.core.Project; -import org.apache.calcite.rel.core.TableScan; +import org.apache.calcite.rel.core.*; import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; @@ -56,6 +54,7 @@ import org.apache.hadoop.hive.ql.metadata.Table; import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil; import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable; +import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject; import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan; @@ -296,6 +295,38 @@ private static RelNode project(DruidQuery dq, ImmutableBitSet fieldsUsed, return hp; } + public TrimResult trimFields(Aggregate aggregate, ImmutableBitSet fieldsUsed, + Set extraFields) { + + Aggregate newAggregate = aggregate; + if(!(aggregate.getGroupSet().cardinality() > 1) + && !(aggregate.getIndicatorCount() > 0) + && !fieldsUsed.contains(aggregate.getGroupSet())){ + // skip if there are more than one group by key + // or there is groupging sets + //if group by key is constant + final RelNode input = aggregate.getInput(); + final RelDataType rowType = input.getRowType(); + RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder(); + final List newProjects = new ArrayList<>(); + + for(int i=0; i