diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveGroupingID.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveGroupingID.java index 345b64a..e3cdb48 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveGroupingID.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveGroupingID.java @@ -1,25 +1,25 @@ package org.apache.hadoop.hive.ql.optimizer.calcite.reloperators; -import org.apache.calcite.sql.SqlInternalOperator; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.InferTypes; import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.hadoop.hive.ql.metadata.VirtualColumn; -public class HiveGroupingID extends SqlInternalOperator { +public class HiveGroupingID extends SqlAggFunction { - public static final SqlInternalOperator GROUPING__ID = + public static final SqlAggFunction INSTANCE = new HiveGroupingID(); private HiveGroupingID() { - super("$GROUPING__ID", + super(VirtualColumn.GROUPINGID.getName(), SqlKind.OTHER, - 0, - false, - ReturnTypes.BIGINT, + ReturnTypes.INTEGER, InferTypes.BOOLEAN, - OperandTypes.ONE_OR_MORE); + OperandTypes.NILADIC, + SqlFunctionCategory.USER_DEFINED_FUNCTION); } } - diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/ASTConverter.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/ASTConverter.java index ea59181..ae74e55 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/ASTConverter.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/ASTConverter.java @@ -189,31 +189,7 @@ else if (aggregateType == Group.CUBE) { int i = 0; for (RexNode r : select.getChildExps()) { - // If it is a GroupBy with grouping sets and grouping__id column - // is 
selected, we reformulate to project that column from - // the output of the GroupBy operator - boolean reformulate = false; - if (groupBy != null && groupBy.indicator) { - RexNode expr = select.getChildExps().get(i); - if (expr instanceof RexCall) { - if ( ((RexCall) expr).getOperator(). - equals(HiveGroupingID.GROUPING__ID)) { - reformulate = true; - } - } - } - ASTNode expr; - if(reformulate) { - RexInputRef iRef = new RexInputRef( - groupBy.getGroupCount() * 2 + groupBy.getAggCallList().size(), - TypeConverter.convert( - VirtualColumn.GROUPINGID.getTypeInfo(), - groupBy.getCluster().getTypeFactory())); - expr = iRef.accept(new RexVisitor(schema)); - } - else { - expr = r.accept(new RexVisitor(schema, r instanceof RexLiteral)); - } + ASTNode expr = r.accept(new RexVisitor(schema, r instanceof RexLiteral)); String alias = select.getRowType().getFieldNames().get(i++); ASTNode selectExpr = ASTBuilder.selectExpr(expr, alias); b.add(selectExpr); @@ -631,6 +607,10 @@ public QueryBlockInfo(Schema schema, ASTNode ast) { } List aggs = gBy.getAggCallList(); for (AggregateCall agg : aggs) { + if (agg.getAggregation() == HiveGroupingID.INSTANCE) { + add(new ColumnInfo(null,VirtualColumn.GROUPINGID.getName())); + continue; + } int argCount = agg.getArgList().size(); ASTBuilder b = agg.isDistinct() ? ASTBuilder.construct(HiveParser.TOK_FUNCTIONDI, "TOK_FUNCTIONDI") : argCount == 0 ? 
ASTBuilder.construct(HiveParser.TOK_FUNCTIONSTAR, @@ -643,9 +623,6 @@ public QueryBlockInfo(Schema schema, ASTNode ast) { } add(new ColumnInfo(null, b.node())); } - if(gBy.indicator) { - add(new ColumnInfo(null,VirtualColumn.GROUPINGID.getName())); - } } /** diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java index d00a988..9596269 100644 --- ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java +++ ql/src/java/org/apache/hadoop/hive/ql/parse/CalcitePlanner.java @@ -1520,6 +1520,7 @@ private RelNode genGBRelNode(List gbExprs, List aggInfoLs RexNodeConverter converter = new RexNodeConverter(this.cluster, srcRel.getRowType(), posMap, 0, false); + final boolean hasGroupSets = groupSets != null && !groupSets.isEmpty(); final List gbChildProjLst = Lists.newArrayList(); final HashMap rexNodeToPosMap = new HashMap(); final List groupSetPositions = Lists.newArrayList(); @@ -1534,23 +1535,10 @@ private RelNode genGBRelNode(List gbExprs, List aggInfoLs } final ImmutableBitSet groupSet = ImmutableBitSet.of(groupSetPositions); - List aggregateCalls = Lists.newArrayList(); - for (AggInfo agg : aggInfoLst) { - aggregateCalls.add(convertGBAgg(agg, srcRel, gbChildProjLst, converter, rexNodeToPosMap, - gbChildProjLst.size())); - } - - if (gbChildProjLst.isEmpty()) { - // This will happen for count(*), in such cases we arbitarily pick - // first element from srcRel - gbChildProjLst.add(this.cluster.getRexBuilder().makeInputRef(srcRel, 0)); - } - RelNode gbInputRel = HiveProject.create(srcRel, gbChildProjLst, null); - // Grouping sets: we need to transform them into ImmutableBitSet // objects for Calcite List transformedGroupSets = null; - if(groupSets != null && !groupSets.isEmpty()) { + if(hasGroupSets) { Set setTransformedGroupSets = new HashSet(groupSets.size()); for(int val: groupSets) { @@ -1561,6 +1549,27 @@ private RelNode genGBRelNode(List gbExprs, List aggInfoLs 
Collections.sort(transformedGroupSets, ImmutableBitSet.COMPARATOR); } + List aggregateCalls = Lists.newArrayList(); + for (AggInfo agg : aggInfoLst) { + aggregateCalls.add(convertGBAgg(agg, srcRel, gbChildProjLst, converter, rexNodeToPosMap, + gbChildProjLst.size())); + } + if (hasGroupSets) { + // Create GroupingID column + AggregateCall aggCall = new AggregateCall(HiveGroupingID.INSTANCE, + false, new ImmutableList.Builder().build(), + this.cluster.getTypeFactory().createSqlType(SqlTypeName.INTEGER), + HiveGroupingID.INSTANCE.getName()); + aggregateCalls.add(aggCall); + } + + if (gbChildProjLst.isEmpty()) { + // This will happen for count(*), in such cases we arbitrarily pick + first element from srcRel + gbChildProjLst.add(this.cluster.getRexBuilder().makeInputRef(srcRel, 0)); + } + RelNode gbInputRel = HiveProject.create(srcRel, gbChildProjLst, null); + HiveRelNode aggregateRel = null; try { aggregateRel = new HiveAggregate(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION), @@ -1783,7 +1792,19 @@ private RelNode genGBLogicalPlan(QB qb, RelNode srcRel) throws SemanticException } else if (qbp.getDestGroupingSets().contains(detsClauseName)) { groupingSets = getGroupingSets(grpByAstExprs, qbp, detsClauseName); } - groupingColsSize = groupingColsSize * 2; + + final int limit = groupingColsSize * 2; + while (groupingColsSize < limit) { + String field = getColumnInternalName(groupingColsSize); + outputColumnNames.add(field); + groupByOutputRowResolver.put(null, field, + new ColumnInfo( + field, + TypeInfoFactory.booleanTypeInfo, + null, + false)); + groupingColsSize++; + } } // 5. 
Construct aggregation function Info @@ -1821,55 +1842,23 @@ private RelNode genGBLogicalPlan(QB qb, RelNode srcRel) throws SemanticException } } - gbRel = genGBRelNode(gbExprNDescLst, aggregations, groupingSets, srcRel); - relToHiveColNameCalcitePosMap.put(gbRel, - buildHiveToCalciteColumnMap(groupByOutputRowResolver, gbRel)); - this.relToHiveRR.put(gbRel, groupByOutputRowResolver); - - // 6. If GroupingSets, Cube, Rollup were used, we account grouping__id. - // Further, we insert a project operator on top to remove the grouping - // boolean associated to each column in Calcite; this will avoid - // recalculating all column positions when we go back from Calcite to Hive + // 6. If GroupingSets, Cube, Rollup were used, we account grouping__id if(groupingSets != null && !groupingSets.isEmpty()) { - RowResolver selectOutputRowResolver = new RowResolver(); - selectOutputRowResolver.setIsExprResolver(true); - RowResolver.add(selectOutputRowResolver, groupByOutputRowResolver); - outputColumnNames = new ArrayList(outputColumnNames); - - // 6.1 List of columns to keep from groupBy operator - List gbOutput = gbRel.getRowType().getFieldList(); - List calciteColLst = new ArrayList(); - for(RelDataTypeField gbOut: gbOutput) { - if(gbOut.getIndex() < gbExprNDescLst.size() || - gbOut.getIndex() >= gbExprNDescLst.size() * 2) { - calciteColLst.add(new RexInputRef(gbOut.getIndex(), gbOut.getType())); - } - } - - // 6.2 Add column for grouping_id function String field = getColumnInternalName(groupingColsSize + aggregations.size()); outputColumnNames.add(field); - selectOutputRowResolver.put(null, VirtualColumn.GROUPINGID.getName(), + groupByOutputRowResolver.put(null, VirtualColumn.GROUPINGID.getName(), new ColumnInfo( field, - TypeInfoFactory.stringTypeInfo, + TypeInfoFactory.intTypeInfo, null, true)); - - // 6.3 Compute column for grouping_id function in Calcite - List identifierCols = new ArrayList(); - for(int i = gbExprNDescLst.size(); i < gbExprNDescLst.size() * 2; i++) { - 
identifierCols.add(new RexInputRef( - i, gbOutput.get(i).getType())); - } - final RexBuilder rexBuilder = cluster.getRexBuilder(); - RexNode groupingID = rexBuilder.makeCall(HiveGroupingID.GROUPING__ID, - identifierCols); - calciteColLst.add(groupingID); - - // Create select - gbRel = this.genSelectRelNode(calciteColLst, selectOutputRowResolver, gbRel); } + + // 7. We create the group_by operator + gbRel = genGBRelNode(gbExprNDescLst, aggregations, groupingSets, srcRel); + relToHiveColNameCalcitePosMap.put(gbRel, + buildHiveToCalciteColumnMap(groupByOutputRowResolver, gbRel)); + this.relToHiveRR.put(gbRel, groupByOutputRowResolver); } return gbRel; diff --git ql/src/test/results/clientpositive/groupby_cube1.q.out ql/src/test/results/clientpositive/groupby_cube1.q.out index 0dc0159..659ac52 100644 --- ql/src/test/results/clientpositive/groupby_cube1.q.out +++ ql/src/test/results/clientpositive/groupby_cube1.q.out @@ -206,11 +206,11 @@ STAGE PLANS: Statistics: Num rows: 0 Data size: 30 Basic stats: PARTIAL Column stats: NONE Select Operator expressions: key (type: string), val (type: string) - outputColumnNames: key, val + outputColumnNames: _col0, _col1 Statistics: Num rows: 0 Data size: 30 Basic stats: PARTIAL Column stats: NONE Group By Operator - aggregations: count(DISTINCT val) - keys: key (type: string), '0' (type: string), val (type: string) + aggregations: count(DISTINCT _col1) + keys: _col0 (type: string), '0' (type: string), _col1 (type: string) mode: hash outputColumnNames: _col0, _col1, _col2, _col3 Statistics: Num rows: 0 Data size: 0 Basic stats: NONE Column stats: NONE @@ -388,11 +388,11 @@ STAGE PLANS: Statistics: Num rows: 0 Data size: 30 Basic stats: PARTIAL Column stats: NONE Select Operator expressions: key (type: string), val (type: string) - outputColumnNames: key, val + outputColumnNames: _col0, _col1 Statistics: Num rows: 0 Data size: 30 Basic stats: PARTIAL Column stats: NONE Group By Operator - aggregations: count(DISTINCT val) - keys: key 
(type: string), '0' (type: string), val (type: string) + aggregations: count(DISTINCT _col1) + keys: _col0 (type: string), '0' (type: string), _col1 (type: string) mode: hash outputColumnNames: _col0, _col1, _col2, _col3 Statistics: Num rows: 0 Data size: 0 Basic stats: NONE Column stats: NONE diff --git ql/src/test/results/clientpositive/groupby_rollup1.q.out ql/src/test/results/clientpositive/groupby_rollup1.q.out index cb39bc1..faa6583 100644 --- ql/src/test/results/clientpositive/groupby_rollup1.q.out +++ ql/src/test/results/clientpositive/groupby_rollup1.q.out @@ -116,11 +116,11 @@ STAGE PLANS: Statistics: Num rows: 0 Data size: 30 Basic stats: PARTIAL Column stats: NONE Select Operator expressions: key (type: string), val (type: string) - outputColumnNames: key, val + outputColumnNames: _col0, _col1 Statistics: Num rows: 0 Data size: 30 Basic stats: PARTIAL Column stats: NONE Group By Operator - aggregations: count(DISTINCT val) - keys: key (type: string), '0' (type: string), val (type: string) + aggregations: count(DISTINCT _col1) + keys: _col0 (type: string), '0' (type: string), _col1 (type: string) mode: hash outputColumnNames: _col0, _col1, _col2, _col3 Statistics: Num rows: 0 Data size: 0 Basic stats: NONE Column stats: NONE @@ -292,11 +292,11 @@ STAGE PLANS: Statistics: Num rows: 0 Data size: 30 Basic stats: PARTIAL Column stats: NONE Select Operator expressions: key (type: string), val (type: string) - outputColumnNames: key, val + outputColumnNames: _col0, _col1 Statistics: Num rows: 0 Data size: 30 Basic stats: PARTIAL Column stats: NONE Group By Operator - aggregations: count(DISTINCT val) - keys: key (type: string), '0' (type: string), val (type: string) + aggregations: count(DISTINCT _col1) + keys: _col0 (type: string), '0' (type: string), _col1 (type: string) mode: hash outputColumnNames: _col0, _col1, _col2, _col3 Statistics: Num rows: 0 Data size: 0 Basic stats: NONE Column stats: NONE