diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveGroupingID.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveGroupingID.java
new file mode 100644
index 0000000..345b64a
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/reloperators/HiveGroupingID.java
@@ -0,0 +1,25 @@
+package org.apache.hadoop.hive.ql.optimizer.calcite.reloperators;
+
+import org.apache.calcite.sql.SqlInternalOperator;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.type.InferTypes;
+import org.apache.calcite.sql.type.OperandTypes;
+import org.apache.calcite.sql.type.ReturnTypes;
+
+public class HiveGroupingID extends SqlInternalOperator {
+
+  public static final SqlInternalOperator GROUPING__ID =
+      new HiveGroupingID();
+
+  private HiveGroupingID() {
+    super("$GROUPING__ID",
+        SqlKind.OTHER,
+        0,
+        false,
+        ReturnTypes.BIGINT,
+        InferTypes.BOOLEAN,
+        OperandTypes.ONE_OR_MORE);
+  }
+
+}
+
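Aside for reviewers: `HiveGroupingID.GROUPING__ID` carries no evaluation logic of its own; it only names the call and declares its operand/return types, which Calcite's `RexBuilder` consults when the call is built. A minimal, self-contained sketch of that interaction (the class name and column positions are hypothetical, not part of the patch):

```java
import java.util.Arrays;

import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.type.SqlTypeName;

public class GroupingIdCallSketch {
  public static void main(String[] args) {
    RelDataTypeFactory typeFactory = new JavaTypeFactoryImpl();
    RexBuilder rexBuilder = new RexBuilder(typeFactory);
    RelDataType boolType = typeFactory.createSqlType(SqlTypeName.BOOLEAN);
    // Two indicator columns of an Aggregate with indicator = true
    // (positions 2 and 3 are illustrative), referenced by index.
    RexNode ind0 = rexBuilder.makeInputRef(boolType, 2);
    RexNode ind1 = rexBuilder.makeInputRef(boolType, 3);
    // ReturnTypes.BIGINT declared in the operator drives the result type.
    RexNode groupingId =
        rexBuilder.makeCall(HiveGroupingID.GROUPING__ID, Arrays.asList(ind0, ind1));
    System.out.println(groupingId); // prints something like $GROUPING__ID($2, $3)
  }
}
```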
diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/ASTBuilder.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/ASTBuilder.java
index e6e6fe3..4b1f5c1 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/ASTBuilder.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/ASTBuilder.java
@@ -21,7 +21,7 @@
 import java.text.SimpleDateFormat;
 import java.util.Calendar;
 
-import org.apache.calcite.avatica.ByteString;
+import org.apache.calcite.avatica.util.ByteString;
 import org.apache.calcite.rel.core.JoinRelType;
 import org.apache.calcite.rel.core.TableScan;
 import org.apache.calcite.rex.RexLiteral;
diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/ASTConverter.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/ASTConverter.java
index c02a65e..1b754a5 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/ASTConverter.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/ASTConverter.java
@@ -49,12 +49,14 @@
 import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.sql.SqlOperator;
 import org.apache.calcite.sql.type.SqlTypeName;
-import org.apache.calcite.util.BitSets;
+import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.hive.metastore.api.FieldSchema;
+import org.apache.hadoop.hive.ql.metadata.VirtualColumn;
 import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException;
 import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable;
+import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveGroupingID;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveSort;
 import org.apache.hadoop.hive.ql.optimizer.calcite.translator.SqlFunctionConverter.HiveToken;
 import org.apache.hadoop.hive.ql.parse.ASTNode;
@@ -93,7 +95,7 @@ public static ASTNode convert(final RelNode relNode, List resultSch
     return c.convert();
   }
 
-  private ASTNode convert() {
+  private ASTNode convert() throws CalciteSemanticException {
    /*
     * 1. Walk RelNode Graph; note from, where, gBy.. nodes.
     */
@@ -118,15 +120,34 @@ private ASTNode convert() {
     * 4. GBy
     */
    if (groupBy != null) {
-      ASTBuilder b = ASTBuilder.construct(HiveParser.TOK_GROUPBY, "TOK_GROUPBY");
-      for (int i : BitSets.toIter(groupBy.getGroupSet())) {
+      ASTBuilder b = groupBy.indicator ?
+          ASTBuilder.construct(HiveParser.TOK_GROUPING_SETS, "TOK_GROUPING_SETS") :
+          ASTBuilder.construct(HiveParser.TOK_GROUPBY, "TOK_GROUPBY");
+
+      for (int i : groupBy.getGroupSet()) {
        RexInputRef iRef = new RexInputRef(i, groupBy.getCluster().getTypeFactory()
            .createSqlType(SqlTypeName.ANY));
        b.add(iRef.accept(new RexVisitor(schema)));
      }
 
-      if (!groupBy.getGroupSet().isEmpty())
+      // Grouping sets expressions
+      if (groupBy.indicator) {
+        for (ImmutableBitSet groupSet : groupBy.getGroupSets()) {
+          ASTBuilder expression = ASTBuilder.construct(
+              HiveParser.TOK_GROUPING_SETS_EXPRESSION, "TOK_GROUPING_SETS_EXPRESSION");
+          for (int i : groupSet) {
+            RexInputRef iRef = new RexInputRef(i, groupBy.getCluster().getTypeFactory()
+                .createSqlType(SqlTypeName.ANY));
+            expression.add(iRef.accept(new RexVisitor(schema)));
+          }
+          b.add(expression);
+        }
+      }
+
+      if (!groupBy.getGroupSet().isEmpty()) {
        hiveAST.groupBy = b.node();
+      }
+
      schema = new Schema(schema, groupBy);
    }
 
@@ -151,9 +172,33 @@ private ASTNode convert() {
      int i = 0;
 
      for (RexNode r : select.getChildExps()) {
-        ASTNode selectExpr = ASTBuilder.selectExpr(r.accept(
-            new RexVisitor(schema, r instanceof RexLiteral)),
-            select.getRowType().getFieldNames().get(i++));
+        // If it is a GroupBy with grouping sets and the grouping__id column
+        // is selected, we reformulate to project that column from
+        // the output of the GroupBy operator
+        boolean reformulate = false;
+        if (groupBy != null && groupBy.indicator) {
+          RexNode expr = select.getChildExps().get(i);
+          if (expr instanceof RexCall) {
+            if (((RexCall) expr).getOperator()
+                .equals(HiveGroupingID.GROUPING__ID)) {
+              reformulate = true;
+            }
+          }
+        }
+        ASTNode expr;
+        if (reformulate) {
+          RexInputRef iRef = new RexInputRef(
+              groupBy.getGroupCount() * 2 + groupBy.getAggCallList().size(),
+              TypeConverter.convert(
+                  VirtualColumn.GROUPINGID.getTypeInfo(),
+                  groupBy.getCluster().getTypeFactory()));
+          expr = iRef.accept(new RexVisitor(schema));
+        }
+        else {
+          expr = r.accept(new RexVisitor(schema, r instanceof RexLiteral));
+        }
+        String alias = select.getRowType().getFieldNames().get(i++);
+        ASTNode selectExpr = ASTBuilder.selectExpr(expr, alias);
        b.add(selectExpr);
      }
    }
@@ -232,7 +277,7 @@ private Schema getRowSchema(String tblAlias) {
    return new Schema(select, tblAlias);
  }
 
-  private QueryBlockInfo convertSource(RelNode r) {
+  private QueryBlockInfo convertSource(RelNode r) throws CalciteSemanticException {
    Schema s;
    ASTNode ast;
 
@@ -554,10 +599,19 @@ public QueryBlockInfo(Schema schema, ASTNode ast) {
    }
 
    Schema(Schema src, Aggregate gBy) {
-      for (int i : BitSets.toIter(gBy.getGroupSet())) {
+      for (int i : gBy.getGroupSet()) {
        ColumnInfo cI = src.get(i);
        add(cI);
      }
+      // If we are using grouping sets, we add the group by
+      // fields again; these correspond to the boolean indicator
+      // columns that Calcite appends. They are not used by Hive.
+      if (gBy.indicator) {
+        for (int i : gBy.getGroupSet()) {
+          ColumnInfo cI = src.get(i);
+          add(cI);
+        }
+      }
      List<AggregateCall> aggs = gBy.getAggCallList();
      for (AggregateCall agg : aggs) {
        int argCount = agg.getArgList().size();
@@ -572,6 +626,9 @@ public QueryBlockInfo(Schema schema, ASTNode ast) {
        }
        add(new ColumnInfo(null, b.node()));
      }
+      if (gBy.indicator) {
+        add(new ColumnInfo(null, VirtualColumn.GROUPINGID.getName()));
+      }
    }
 
    /**
@@ -665,4 +722,5 @@ public static boolean isFlat(RexCall call) {
 
    return flat;
  }
+
 }
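The `Schema(Schema, Aggregate)` constructor above mirrors Calcite's output layout for an `Aggregate` with `indicator = true`, which is also what makes the `groupBy.getGroupCount() * 2 + groupBy.getAggCallList().size()` index in the select branch work. A small runnable sketch of that layout for `GROUP BY key, val WITH CUBE` plus `count(1)` (column names like `$f2` are illustrative, not taken from the patch):

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class GroupingSchemaSketch {
  public static void main(String[] args) {
    List<String> groupCols = Arrays.asList("key", "val");
    List<String> aggCols = Arrays.asList("$f2"); // count(1)
    boolean indicator = true;                    // grouping sets in use

    List<String> schema = new ArrayList<String>(groupCols);
    if (indicator) {
      schema.addAll(groupCols); // stand-ins for Calcite's boolean indicators
    }
    schema.addAll(aggCols);
    if (indicator) {
      schema.add("GROUPING__ID");
    }
    System.out.println(schema);
    // [key, val, key, val, $f2, GROUPING__ID]
    // GROUPING__ID index = groupCount * 2 + aggCount = 5, matching the
    // RexInputRef built in the Select branch above.
  }
}
```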
diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/RexNodeConverter.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/RexNodeConverter.java
index 56cb4e8..29bb48c 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/RexNodeConverter.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/RexNodeConverter.java
@@ -29,7 +29,7 @@
 import java.util.List;
 import java.util.Map;
 
-import org.apache.calcite.avatica.ByteString;
+import org.apache.calcite.avatica.util.ByteString;
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.type.RelDataType;
diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java
index 29be691..6edc163 100644
--- ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java
+++ ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java
@@ -28,6 +28,8 @@
 import java.math.BigDecimal;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.BitSet;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
@@ -187,6 +189,7 @@
 import org.apache.hadoop.hive.ql.optimizer.calcite.cost.HiveVolcanoPlanner;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveFilter;
+import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveGroupingID;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveRelNode;
@@ -13416,7 +13419,7 @@ private AggregateCall convertGBAgg(AggInfo agg, RelNode input, List gbC
  }
 
  private RelNode genGBRelNode(List<ExprNodeDesc> gbExprs, List<AggInfo> aggInfoLst,
-      RelNode srcRel) throws SemanticException {
+      List<Integer> groupSets, RelNode srcRel) throws SemanticException {
    RowResolver gbInputRR = this.relToHiveRR.get(srcRel);
    ImmutableMap<String, Integer> posMap = this.relToHiveColNameCalcitePosMap.get(srcRel);
    RexNodeConverter converter = new RexNodeConverter(this.cluster, srcRel.getRowType(),
@@ -13450,10 +13453,23 @@ private RelNode genGBRelNode(List gbExprs, List aggInfoLs
    }
 
    RelNode gbInputRel = HiveProject.create(srcRel, gbChildProjLst, null);
+    // Grouping sets: we need to transform them into ImmutableBitSet
+    // objects for Calcite
+    List<ImmutableBitSet> transformedGroupSets = null;
+    if (groupSets != null && !groupSets.isEmpty()) {
+      transformedGroupSets = new ArrayList<ImmutableBitSet>(groupSets.size());
+      for (int val : groupSets) {
+        transformedGroupSets.add(convert(val));
+      }
+      // Calcite expects the grouping sets to be sorted
+      Collections.sort(transformedGroupSets, ImmutableBitSet.COMPARATOR);
+    }
+
    HiveRelNode aggregateRel = null;
    try {
      aggregateRel = new HiveAggregate(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION),
-          gbInputRel, false, groupSet, null, aggregateCalls);
+          gbInputRel, transformedGroupSets != null, groupSet,
+          transformedGroupSets, aggregateCalls);
    } catch (InvalidRelException e) {
      throw new SemanticException(e);
    }
@@ -13461,6 +13477,19 @@
    return aggregateRel;
  }
 
+  private ImmutableBitSet convert(int value) {
+    BitSet bits = new BitSet();
+    int index = 0;
+    while (value != 0) {
+      if (value % 2 != 0) {
+        bits.set(index);
+      }
+      ++index;
+      value = value >>> 1;
+    }
+    return ImmutableBitSet.FROM_BIT_SET.apply(bits);
+  }
+
  private void addAlternateGByKeyMappings(ASTNode gByExpr, ColumnInfo colInfo,
      RowResolver gByInputRR, RowResolver gByRR) {
    if (gByExpr.getType() == HiveParser.DOT
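The `convert(int)` helper is the bridge between Hive's integer grouping-set masks and Calcite's `ImmutableBitSet`. A standalone copy with sample values (the class name is hypothetical; the masks follow the encoding this Hive release uses, where bit i set means the i-th group-by column is present, as the cube output later in this patch confirms):

```java
import java.util.BitSet;

import org.apache.calcite.util.ImmutableBitSet;

public class GroupSetConversionSketch {
  static ImmutableBitSet convert(int value) {
    BitSet bits = new BitSet();
    int index = 0;
    while (value != 0) {
      if (value % 2 != 0) {
        bits.set(index);   // least significant bit = first group-by column
      }
      ++index;
      value = value >>> 1; // unsigned shift walks through the mask
    }
    return ImmutableBitSet.FROM_BIT_SET.apply(bits);
  }

  public static void main(String[] args) {
    // For GROUP BY key, val WITH ROLLUP the masks would be
    // 3 (key, val), 1 (key) and 0 (grand total):
    System.out.println(convert(3)); // {0, 1}
    System.out.println(convert(1)); // {0}
    System.out.println(convert(0)); // {}
  }
}
```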
@@ -13593,29 +13622,6 @@ private RelNode genGBLogicalPlan(QB qb, RelNode srcRel) throws SemanticException
      RelNode gbRel = null;
      QBParseInfo qbp = getQBParseInfo(qb);
 
-      // 0. for GSets, Cube, Rollup, bail from Calcite path.
-      if (!qbp.getDestRollups().isEmpty()
-          || !qbp.getDestGroupingSets().isEmpty()
-          || !qbp.getDestCubes().isEmpty()) {
-        String gbyClause = null;
-        HashMap<String, ASTNode> gbysMap = qbp.getDestToGroupBy();
-        if (gbysMap.size() == 1) {
-          ASTNode gbyAST = gbysMap.entrySet().iterator().next().getValue();
-          gbyClause = SemanticAnalyzer.this.ctx.getTokenRewriteStream()
-              .toString(gbyAST.getTokenStartIndex(),
-                  gbyAST.getTokenStopIndex());
-          gbyClause = "in '" + gbyClause + "'.";
-        } else {
-          gbyClause = ".";
-        }
-        String msg = String.format("Encountered Grouping Set/Cube/Rollup%s"
-            + " Currently we don't support Grouping Set/Cube/Rollup"
-            + " clauses in CBO," + " turn off cbo for these queries.",
-            gbyClause);
-        LOG.debug(msg);
-        throw new CalciteSemanticException(msg);
-      }
-
      // 1. Gather GB Expressions (AST) (GB + Aggregations)
      // NOTE: Multi Insert is not supported
      String detsClauseName = qbp.getClauseNames().iterator().next();
@@ -13649,18 +13655,35 @@ private RelNode genGBLogicalPlan(QB qb, RelNode srcRel) throws SemanticException
        }
      }
 
-      // 4. Construct aggregation function Info
+      // 4. GroupingSets, Cube, Rollup
+      int groupingColsSize = gbExprNDescLst.size();
+      List<Integer> groupingSets = null;
+      if (!qbp.getDestRollups().isEmpty()
+          || !qbp.getDestGroupingSets().isEmpty()
+          || !qbp.getDestCubes().isEmpty()) {
+        if (qbp.getDestRollups().contains(detsClauseName)) {
+          groupingSets = getGroupingSetsForRollup(grpByAstExprs.size());
+        } else if (qbp.getDestCubes().contains(detsClauseName)) {
+          groupingSets = getGroupingSetsForCube(grpByAstExprs.size());
+        } else if (qbp.getDestGroupingSets().contains(detsClauseName)) {
+          groupingSets = getGroupingSets(grpByAstExprs, qbp, detsClauseName);
+        }
+
+        groupingColsSize = groupingColsSize * 2;
+      }
+
+      // 5. Construct aggregation function Info
      ArrayList<AggInfo> aggregations = new ArrayList<AggInfo>();
      if (hasAggregationTrees) {
        assert (aggregationTrees != null);
        for (ASTNode value : aggregationTrees.values()) {
-          // 4.1 Determine type of UDAF
+          // 5.1 Determine type of UDAF
          // This is the GenericUDAF name
          String aggName = unescapeIdentifier(value.getChild(0).getText());
          boolean isDistinct = value.getType() == HiveParser.TOK_FUNCTIONDI;
          boolean isAllColumns = value.getType() == HiveParser.TOK_FUNCTIONSTAR;
 
-          // 4.2 Convert UDAF Params to ExprNodeDesc
+          // 5.2 Convert UDAF Params to ExprNodeDesc
          ArrayList<ExprNodeDesc> aggParameters = new ArrayList<ExprNodeDesc>();
          for (int i = 1; i < value.getChildCount(); i++) {
            ASTNode paraExpr = (ASTNode) value.getChild(i);
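For context on step 4: `groupingColsSize` is doubled because, once grouping sets are in play, the GroupBy output will carry each grouping column twice (once as a value, once as Calcite's boolean indicator). The masks consumed above plausibly look as follows; `rollup` and `cube` here are hedged stand-ins for `getGroupingSetsForRollup`/`getGroupingSetsForCube`, whose exact ordering may differ (the patch sorts the converted sets before handing them to Calcite anyway):

```java
import java.util.ArrayList;
import java.util.List;

public class GroupingSetMaskSketch {
  static List<Integer> rollup(int numCols) {
    List<Integer> sets = new ArrayList<Integer>();
    for (int i = numCols; i >= 0; i--) {
      sets.add((1 << i) - 1); // prefixes: (key, val) -> 3, (key) -> 1, () -> 0
    }
    return sets;
  }

  static List<Integer> cube(int numCols) {
    List<Integer> sets = new ArrayList<Integer>();
    for (int mask = (1 << numCols) - 1; mask >= 0; mask--) {
      sets.add(mask);         // every subset of the group-by columns
    }
    return sets;
  }

  public static void main(String[] args) {
    System.out.println(rollup(2)); // [3, 1, 0]
    System.out.println(cube(2));   // [3, 2, 1, 0]
    // These integer masks are what genGBRelNode() converts to
    // ImmutableBitSets before constructing the HiveAggregate.
  }
}
```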
@@ -13675,17 +13698,65 @@ private RelNode genGBLogicalPlan(QB qb, RelNode srcRel) throws SemanticException
          GenericUDAFInfo udaf = getGenericUDAFInfo(genericUDAFEvaluator, amode, aggParameters);
          AggInfo aInfo = new AggInfo(aggParameters, udaf.returnType, aggName, isDistinct);
          aggregations.add(aInfo);
-          String field = getColumnInternalName(gbExprNDescLst.size() + aggregations.size() - 1);
+          String field = getColumnInternalName(groupingColsSize + aggregations.size() - 1);
          outputColumnNames.add(field);
          groupByOutputRowResolver.putExpression(value,
              new ColumnInfo(field, aInfo.m_returnType, "", false));
        }
      }
 
-      gbRel = genGBRelNode(gbExprNDescLst, aggregations, srcRel);
+      gbRel = genGBRelNode(gbExprNDescLst, aggregations, groupingSets, srcRel);
      relToHiveColNameCalcitePosMap.put(gbRel,
          buildHiveToCalciteColumnMap(groupByOutputRowResolver, gbRel));
      this.relToHiveRR.put(gbRel, groupByOutputRowResolver);
+
+      // 6. If GroupingSets, Cube, Rollup were used, we account for grouping__id.
+      // Further, we insert a project operator on top to remove the boolean
+      // grouping indicator associated with each column in Calcite; this will avoid
+      // recalculating all column positions when we go back from Calcite to Hive
+      if (groupingSets != null) {
+
+        RowResolver selectOutputRowResolver = new RowResolver();
+        selectOutputRowResolver.setIsExprResolver(true);
+        RowResolver.add(selectOutputRowResolver, groupByOutputRowResolver);
+        outputColumnNames = new ArrayList<String>(outputColumnNames);
+
+        // 6.1 List of columns to keep from groupBy operator
+        List<RelDataTypeField> gbOutput = gbRel.getRowType().getFieldList();
+        List<RexNode> calciteColLst = new ArrayList<RexNode>();
+        for (RelDataTypeField gbOut : gbOutput) {
+          if (gbOut.getIndex() < gbExprNDescLst.size() ||
+              gbOut.getIndex() >= gbExprNDescLst.size() * 2) {
+            calciteColLst.add(new RexInputRef(gbOut.getIndex(), gbOut.getType()));
+          }
+        }
+
+        // 6.2 Add column for grouping_id function
+        String field = getColumnInternalName(groupingColsSize++);
+        ExprNodeDesc inputExpr = new ExprNodeColumnDesc(TypeInfoFactory.stringTypeInfo,
+            field, null, false);
+        outputColumnNames.add(field);
+        selectOutputRowResolver.put(null, VirtualColumn.GROUPINGID.getName(),
+            new ColumnInfo(
+                field,
+                TypeInfoFactory.stringTypeInfo,
+                null,
+                true));
+
+        // 6.3 Compute column for grouping_id function in Calcite
+        List<RexNode> identifierCols = new ArrayList<RexNode>();
+        for (int i = gbExprNDescLst.size(); i < gbExprNDescLst.size() * 2; i++) {
+          identifierCols.add(new RexInputRef(
+              i, gbOutput.get(i).getType()));
+        }
+        final RexBuilder rexBuilder = cluster.getRexBuilder();
+        RexNode groupingID = rexBuilder.makeCall(HiveGroupingID.GROUPING__ID,
+            identifierCols);
+        calciteColLst.add(groupingID);
+
+        // Create select
+        gbRel = this.genSelectRelNode(calciteColLst, selectOutputRowResolver, gbRel);
+      }
    }
 
    return gbRel;
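Step 6 keeps the first n group columns and the aggregates, drops the n indicator columns, and appends a `$GROUPING__ID` call over the dropped indicators. A toy sketch of that index arithmetic for the cube example (class name hypothetical; only the filter condition is taken from the patch):

```java
public class GroupingProjectSketch {
  public static void main(String[] args) {
    int groupCount = 2;  // key, val
    int aggCount = 1;    // count(1)
    int totalGbOutput = groupCount * 2 + aggCount; // 5 aggregate output columns
    StringBuilder kept = new StringBuilder();
    for (int i = 0; i < totalGbOutput; i++) {
      // Same condition as 6.1: keep group columns and aggregates,
      // skip the indicator columns in positions [groupCount, groupCount * 2).
      if (i < groupCount || i >= groupCount * 2) {
        kept.append('$').append(i).append(' ');
      }
    }
    System.out.println(kept); // $0 $1 $4
    // ...plus $GROUPING__ID($2, $3) appended as the last project column (6.3).
  }
}
```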
diff --git ql/src/test/queries/clientpositive/groupby_cube1.q ql/src/test/queries/clientpositive/groupby_cube1.q
index c12720b..02b41b9 100644
--- ql/src/test/queries/clientpositive/groupby_cube1.q
+++ ql/src/test/queries/clientpositive/groupby_cube1.q
@@ -13,6 +13,11 @@ SELECT key, val, count(1) FROM T1 GROUP BY key, val with cube;
 SELECT key, val, count(1) FROM T1 GROUP BY key, val with cube;
 
 EXPLAIN
+SELECT key, val, GROUPING__ID, count(1) FROM T1 GROUP BY key, val with cube;
+
+SELECT key, val, GROUPING__ID, count(1) FROM T1 GROUP BY key, val with cube;
+
+EXPLAIN
 SELECT key, count(distinct val) FROM T1 GROUP BY key with cube;
 
 SELECT key, count(distinct val) FROM T1 GROUP BY key with cube;
diff --git ql/src/test/results/clientpositive/groupby_cube1.q.out ql/src/test/results/clientpositive/groupby_cube1.q.out
index 7b5d70a..ec6c010 100644
--- ql/src/test/results/clientpositive/groupby_cube1.q.out
+++ ql/src/test/results/clientpositive/groupby_cube1.q.out
@@ -103,6 +103,90 @@ NULL 18 1
 NULL 28 1
 NULL NULL 6
 PREHOOK: query: EXPLAIN
+SELECT key, val, GROUPING__ID, count(1) FROM T1 GROUP BY key, val with cube
+PREHOOK: type: QUERY
+POSTHOOK: query: EXPLAIN
+SELECT key, val, GROUPING__ID, count(1) FROM T1 GROUP BY key, val with cube
+POSTHOOK: type: QUERY
+STAGE DEPENDENCIES:
+  Stage-1 is a root stage
+  Stage-0 depends on stages: Stage-1
+
+STAGE PLANS:
+  Stage: Stage-1
+    Map Reduce
+      Map Operator Tree:
+          TableScan
+            alias: t1
+            Statistics: Num rows: 0 Data size: 30 Basic stats: PARTIAL Column stats: NONE
+            Select Operator
+              expressions: key (type: string), val (type: string)
+              outputColumnNames: key, val
+              Statistics: Num rows: 0 Data size: 30 Basic stats: PARTIAL Column stats: NONE
+              Group By Operator
+                aggregations: count(1)
+                keys: key (type: string), val (type: string), '0' (type: string)
+                mode: hash
+                outputColumnNames: _col0, _col1, _col2, _col3
+                Statistics: Num rows: 0 Data size: 0 Basic stats: NONE Column stats: NONE
+                Reduce Output Operator
+                  key expressions: _col0 (type: string), _col1 (type: string), _col2 (type: string)
+                  sort order: +++
+                  Map-reduce partition columns: _col0 (type: string), _col1 (type: string), _col2 (type: string)
+                  Statistics: Num rows: 0 Data size: 0 Basic stats: NONE Column stats: NONE
+                  value expressions: _col3 (type: bigint)
+      Reduce Operator Tree:
+        Group By Operator
+          aggregations: count(VALUE._col0)
+          keys: KEY._col0 (type: string), KEY._col1 (type: string), KEY._col2 (type: string)
+          mode: mergepartial
+          outputColumnNames: _col0, _col1, _col2, _col3
+          Statistics: Num rows: 0 Data size: 0 Basic stats: NONE Column stats: NONE
+          Select Operator
+            expressions: _col0 (type: string), _col1 (type: string), _col2 (type: string), _col3 (type: bigint)
+            outputColumnNames: _col0, _col1, _col2, _col3
+            Statistics: Num rows: 0 Data size: 0 Basic stats: NONE Column stats: NONE
+            File Output Operator
+              compressed: false
+              Statistics: Num rows: 0 Data size: 0 Basic stats: NONE Column stats: NONE
+              table:
+                  input format: org.apache.hadoop.mapred.TextInputFormat
+                  output format: org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat
+                  serde: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe
+
+  Stage: Stage-0
+    Fetch Operator
+      limit: -1
+      Processor Tree:
+        ListSink
+
+PREHOOK: query: SELECT key, val, GROUPING__ID, count(1) FROM T1 GROUP BY key, val with cube
+PREHOOK: type: QUERY
+PREHOOK: Input: default@t1
+#### A masked pattern was here ####
+POSTHOOK: query: SELECT key, val, GROUPING__ID, count(1) FROM T1 GROUP BY key, val with cube
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@t1
+#### A masked pattern was here ####
+1 11 3 1
+1 NULL 1 1
+2 12 3 1
+2 NULL 1 1
+3 13 3 1
+3 NULL 1 1
+7 17 3 1
+7 NULL 1 1
+8 18 3 1
+8 28 3 1
+8 NULL 1 2
+NULL 11 2 1
+NULL 12 2 1
+NULL 13 2 1
+NULL 17 2 1
+NULL 18 2 1
+NULL 28 2 1
+NULL NULL 0 6
+PREHOOK: query: EXPLAIN
 SELECT key, count(distinct val) FROM T1 GROUP BY key with cube
 PREHOOK: type: QUERY
 POSTHOOK: query: EXPLAIN
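To read the GROUPING__ID column in the rows above: under the convention this Hive release uses, bit i set means the i-th group-by column took part in the grouping set (later Hive releases inverted the encoding to match the SQL standard). A hedged decoding sketch (class and method names hypothetical):

```java
public class GroupingIdDecodeSketch {
  static String decode(int groupingId, String[] gbCols) {
    StringBuilder sb = new StringBuilder("{");
    for (int i = 0; i < gbCols.length; i++) {
      if ((groupingId & (1 << i)) != 0) {       // bit i => i-th column present
        sb.append(sb.length() > 1 ? ", " : "").append(gbCols[i]);
      }
    }
    return sb.append('}').toString();
  }

  public static void main(String[] args) {
    String[] cols = {"key", "val"};
    System.out.println(decode(3, cols)); // {key, val} -> rows like "1 11 3 1"
    System.out.println(decode(1, cols)); // {key}      -> rows like "1 NULL 1 1"
    System.out.println(decode(2, cols)); // {val}      -> rows like "NULL 11 2 1"
    System.out.println(decode(0, cols)); // {}         -> "NULL NULL 0 6"
  }
}
```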