diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveAggregate.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveAggregate.java index bea5eec..903cc19 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveAggregate.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveAggregate.java @@ -18,7 +18,9 @@ package org.apache.hadoop.hive.ql.optimizer.calcite.reloperators; import java.util.List; +import java.util.Set; +import org.apache.calcite.linq4j.Ord; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptCost; import org.apache.calcite.plan.RelOptPlanner; @@ -29,10 +31,16 @@ import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.RelFactories.AggregateFactory; import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.IntList; import org.apache.hadoop.hive.ql.optimizer.calcite.TraitsUtil; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Sets; public class HiveAggregate extends Aggregate implements HiveRelNode { @@ -81,6 +89,56 @@ public boolean isBucketedInput() { containsAll(groupSet.asList()); } + @Override + protected RelDataType deriveRowType() { + return deriveRowType(getCluster().getTypeFactory(), getInput().getRowType(), + indicator, groupSet, groupSets, aggCalls); + } + + public static RelDataType deriveRowType(RelDataTypeFactory typeFactory, + final RelDataType inputRowType, boolean indicator, + ImmutableBitSet groupSet, List groupSets, + final List aggCalls) { + final IntList groupList = groupSet.toList(); + assert groupList.size() == groupSet.cardinality(); + final RelDataTypeFactory.FieldInfoBuilder builder = typeFactory.builder(); + final List fieldList = inputRowType.getFieldList(); + final Set containedNames = Sets.newHashSet(); + for (int groupKey : groupList) { + containedNames.add(fieldList.get(groupKey).getName()); + builder.add(fieldList.get(groupKey)); + } + if (indicator) { + for (int groupKey : groupList) { + final RelDataType booleanType = + typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.BOOLEAN), false); + String name = "i$" + fieldList.get(groupKey).getName(); + int i = 0; + while (containedNames.contains(name)) { + name += "_" + i++; + } + containedNames.add(name); + builder.add(name, booleanType); + } + } + for (Ord aggCall : Ord.zip(aggCalls)) { + String name; + if (aggCall.e.name != null) { + name = aggCall.e.name; + } else { + name = "$f" + (groupList.size() + aggCall.i); + } + int i = 0; + while (containedNames.contains(name)) { + name += "_" + i++; + } + containedNames.add(name); + builder.add(name, aggCall.e.type); + } + return builder.build(); + } + private static class HiveAggRelFactory implements AggregateFactory { @Override