diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/HiveOptiqUtil.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/HiveOptiqUtil.java index fda53c1..e9b258e 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/HiveOptiqUtil.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/HiveOptiqUtil.java @@ -15,6 +15,9 @@ import org.eigenbase.sql.validate.SqlValidatorUtil; import org.eigenbase.util.Pair; +import com.google.common.base.Function; +import com.google.common.collect.Lists; + /** * Generic utility functions needed for Optiq based Hive CBO. */ @@ -41,6 +44,16 @@ return vCols; } + public static List getProjsFromBelowAsInputRef(final RelNode rel) { + List projectList = Lists.transform(rel.getRowType().getFieldList(), + new Function() { + public RexNode apply(RelDataTypeField field) { + return rel.getCluster().getRexBuilder().makeInputRef(field.getType(), field.getIndex()); + } + }); + return projectList; + } + public static List translateBitSetToProjIndx(BitSet projBitSet) { List projIndxLst = new ArrayList(); diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/translator/DerivedTableInjector.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/translator/DerivedTableInjector.java index b51c0d9..dd2bf22 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/translator/DerivedTableInjector.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/optiq/translator/DerivedTableInjector.java @@ -4,6 +4,7 @@ import java.util.List; import org.apache.hadoop.hive.metastore.api.FieldSchema; +import org.apache.hadoop.hive.ql.optimizer.optiq.HiveOptiqUtil; import org.apache.hadoop.hive.ql.optimizer.optiq.reloperators.HiveAggregateRel; import org.apache.hadoop.hive.ql.optimizer.optiq.reloperators.HiveProjectRel; import org.apache.hadoop.hive.ql.optimizer.optiq.reloperators.HiveSortRel; @@ -139,12 +140,7 @@ private static HiveProjectRel introduceTopLevelSelectInResultSchema(final RelNod } private static RelNode introduceDerivedTable(final RelNode rel) { - List projectList = Lists.transform(rel.getRowType().getFieldList(), - new Function() { - public RexNode apply(RelDataTypeField field) { - return rel.getCluster().getRexBuilder().makeInputRef(field.getType(), field.getIndex()); - } - }); + List projectList = HiveOptiqUtil.getProjsFromBelowAsInputRef(rel); HiveProjectRel select = HiveProjectRel.create(rel.getCluster(), rel, projectList, rel.getRowType(), rel.getCollationList()); diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java index 6c293c8..b76a55c 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java @@ -13091,80 +13091,153 @@ RexWindowBound getBound(BoundarySpec bs, RexNodeConverter converter) { return rwb; } - Pair genWindowingProj(QB qb, ASTNode windowProjAst, int wndSpecASTIndx, - int wndProjPos, RelNode srcRel) throws SemanticException { + int getWindowSpecIndx(ASTNode wndAST) { + int wndASTIndx = -1; + int wi = wndAST.getChildCount() - 1; + if (wi <= 0 || (wndAST.getChild(wi).getType() != HiveParser.TOK_WINDOWSPEC)) { + wi = -1; + } + + return wi; + } + + Pair genWindowingProj(QB qb, WindowExpressionSpec wExpSpec, RelNode srcRel) + throws SemanticException { RexNode w = null; TypeInfo wHiveRetType = null; + + if (wExpSpec instanceof WindowFunctionSpec) { + WindowFunctionSpec wFnSpec = (WindowFunctionSpec) wExpSpec; + ASTNode windowProjAst = wFnSpec.getExpression(); + // TODO: do we need to get to child? + int wndSpecASTIndx = getWindowSpecIndx((ASTNode) windowProjAst); + // 2. Get Hive Aggregate Info + AggInfo hiveAggInfo = getHiveAggInfo(windowProjAst, wndSpecASTIndx - 1, + this.m_relToHiveRR.get(srcRel)); + + // 3. Get Optiq Return type for Agg Fn + wHiveRetType = hiveAggInfo.m_returnType; + RelDataType optiqAggFnRetType = TypeConverter.convert(hiveAggInfo.m_returnType, + this.m_cluster.getTypeFactory()); + + // 4. Convert Agg Fn args to Optiq + ImmutableMap posMap = this.m_relToHiveColNameOptiqPosMap.get(srcRel); + RexNodeConverter converter = new RexNodeConverter(this.m_cluster, srcRel.getRowType(), + posMap, 0, false); + Builder optiqAggFnArgsBldr = ImmutableList. builder(); + Builder optiqAggFnArgsTypeBldr = ImmutableList. builder(); + RexNode rexNd = null; + for (int i = 0; i < hiveAggInfo.m_aggParams.size(); i++) { + optiqAggFnArgsBldr.add(converter.convert(hiveAggInfo.m_aggParams.get(i))); + optiqAggFnArgsTypeBldr.add(TypeConverter.convert(hiveAggInfo.m_aggParams.get(i) + .getTypeInfo(), this.m_cluster.getTypeFactory())); + } + ImmutableList optiqAggFnArgs = optiqAggFnArgsBldr.build(); + ImmutableList optiqAggFnArgsType = optiqAggFnArgsTypeBldr.build(); + + // 5. Get Optiq Agg Fn + final SqlAggFunction optiqAggFn = SqlFunctionConverter.getOptiqAggFn(hiveAggInfo.m_udfName, + optiqAggFnArgsType, optiqAggFnRetType); + + // 6. Translate Window spec + RowResolver inputRR = m_relToHiveRR.get(srcRel); + WindowSpec wndSpec = ((WindowFunctionSpec) wExpSpec).getWindowSpec(); + List partitionKeys = getPartitionKeys(wndSpec.getPartition(), converter, inputRR); + List orderKeys = getOrderKeys(wndSpec.getOrder(), converter, inputRR); + RexWindowBound upperBound = getBound(wndSpec.windowFrame.start, converter); + RexWindowBound lowerBound = getBound(wndSpec.windowFrame.end, converter); + boolean isRows = ((wndSpec.windowFrame.start instanceof RangeBoundarySpec) || (wndSpec.windowFrame.end instanceof RangeBoundarySpec)) ? true + : false; + + w = m_cluster.getRexBuilder().makeOver(optiqAggFnRetType, optiqAggFn, optiqAggFnArgs, + partitionKeys, ImmutableList. copyOf(orderKeys), lowerBound, + upperBound, isRows, true, false); + } else { + // TODO: Convert to Semantic Exception + throw new RuntimeException("Unsupported window Spec"); + } + + return new Pair(w, wHiveRetType); + } + + private RelNode genSelectForWindowing(QB qb, RelNode srcRel) throws SemanticException { + RelNode selOpForWindow = null; QBParseInfo qbp = getQBParseInfo(qb); - WindowingSpec wSpec = qb.getAllWindowingSpecs().values().iterator().next(); + WindowingSpec wSpec = (!qb.getAllWindowingSpecs().isEmpty()) ? qb.getAllWindowingSpecs() + .values().iterator().next() : null; if (wSpec != null) { // 1. Get valid Window Function Spec - // NOTE: WindowSpec uses alias "_wcol0","_wcol1"... for - // WindowFunctionSpec wSpec.validateAndMakeEffective(); - WindowExpressionSpec wExpSpec = wSpec.aliasToWdwExpr.get("_wcol" + wndProjPos); - // TODO: Throw exception if wExpSpec is not of type WindowFunctionSpec - if (wExpSpec instanceof WindowFunctionSpec) { - - // 2. Get Hive Aggregate Info - AggInfo hiveAggInfo = getHiveAggInfo(windowProjAst, wndSpecASTIndx - 1, - this.m_relToHiveRR.get(srcRel)); - - // 3. Get Optiq Return type for Agg Fn - wHiveRetType = hiveAggInfo.m_returnType; - RelDataType optiqAggFnRetType = TypeConverter.convert(hiveAggInfo.m_returnType, - this.m_cluster.getTypeFactory()); - - // 4. Convert Agg Fn args to Optiq - ImmutableMap posMap = this.m_relToHiveColNameOptiqPosMap.get(srcRel); - RexNodeConverter converter = new RexNodeConverter(this.m_cluster, srcRel.getRowType(), - posMap, 0, false); - Builder optiqAggFnArgsBldr = ImmutableList. builder(); - Builder optiqAggFnArgsTypeBldr = ImmutableList. builder(); - RexNode rexNd = null; - for (int i = 0; i < hiveAggInfo.m_aggParams.size(); i++) { - optiqAggFnArgsBldr.add(converter.convert(hiveAggInfo.m_aggParams.get(i))); - optiqAggFnArgsTypeBldr.add(TypeConverter.convert(hiveAggInfo.m_aggParams.get(i) - .getTypeInfo(), this.m_cluster.getTypeFactory())); + List windowExpressions = wSpec.getWindowExpressions(); + + if (windowExpressions != null && !windowExpressions.isEmpty()) { + RowResolver inputRR = this.m_relToHiveRR.get(srcRel); + // 2. Get RexNodes for original Projections from below + List projsForWindowSelOp = new ArrayList( + HiveOptiqUtil.getProjsFromBelowAsInputRef(srcRel)); + + // 3. Construct new Row Resolver with everything from below. + RowResolver out_rwsch = new RowResolver(); + RowResolver.add(out_rwsch, inputRR, 0); + + // 4. Walk through Window Expressions & Construct RexNodes for those, + // Update out_rwsch + for (WindowExpressionSpec wExprSpec : windowExpressions) { + if (out_rwsch.getExpression(wExprSpec.getExpression()) == null) { + Pair wtp = genWindowingProj(qb, wExprSpec, srcRel); + projsForWindowSelOp.add(wtp.getFirst()); + + // 6.2.2 Update Output Row Schema + ColumnInfo oColInfo = new ColumnInfo( + getColumnInternalName(projsForWindowSelOp.size()), wtp.getSecond(), null, false); + String colAlias = wExprSpec.getAlias(); + if (false) { + out_rwsch.checkColumn(null, wExprSpec.getAlias()); + out_rwsch.put(null, wExprSpec.getAlias(), oColInfo); + } else { + out_rwsch.putExpression(wExprSpec.getExpression(), oColInfo); + } + } } - ImmutableList optiqAggFnArgs = optiqAggFnArgsBldr.build(); - ImmutableList optiqAggFnArgsType = optiqAggFnArgsTypeBldr.build(); - - // 5. Get Optiq Agg Fn - final SqlAggFunction optiqAggFn = SqlFunctionConverter.getOptiqAggFn( - hiveAggInfo.m_udfName, optiqAggFnArgsType, optiqAggFnRetType); - - // 6. Translate Window spec - RowResolver inputRR = m_relToHiveRR.get(srcRel); - WindowSpec wndSpec = ((WindowFunctionSpec) wExpSpec).getWindowSpec(); - List partitionKeys = getPartitionKeys(wndSpec.getPartition(), converter, inputRR); - List orderKeys = getOrderKeys(wndSpec.getOrder(), converter, inputRR); - RexWindowBound upperBound = getBound(wndSpec.windowFrame.start, converter); - RexWindowBound lowerBound = getBound(wndSpec.windowFrame.end, converter); - boolean isRows = ((wndSpec.windowFrame.start instanceof RangeBoundarySpec) || (wndSpec.windowFrame.end instanceof RangeBoundarySpec)) ? true - : false; - - w = m_cluster.getRexBuilder().makeOver(optiqAggFnRetType, optiqAggFn, optiqAggFnArgs, - partitionKeys, ImmutableList. copyOf(orderKeys), lowerBound, - upperBound, isRows, true, false); - } else { - // TODO: Convert to Semantic Exception - throw new RuntimeException("Unsupported window Spec"); + + selOpForWindow = genSelectRelNode(projsForWindowSelOp, out_rwsch, srcRel); } } - return new Pair(w, wHiveRetType); + return selOpForWindow; } - int getWindowSpecIndx(ASTNode wndAST) { - int wndASTIndx = -1; - int wi = wndAST.getChildCount() - 1; - if (wi <= 0 || (wndAST.getChild(wi).getType() != HiveParser.TOK_WINDOWSPEC)) { - wi = -1; + private RelNode genSelectRelNode(List optiqColLst, RowResolver out_rwsch, + RelNode srcRel) { + // 1. Build Column Names + // TODO: Should this be external names + ArrayList columnNames = new ArrayList(); + for (int i = 0; i < optiqColLst.size(); i++) { + columnNames.add(getColumnInternalName(i)); } - return wi; + // 2. Prepend column names with '_o_' + /* + * Hive treats names that start with '_c' as internalNames; so change the + * names so we don't run into this issue when converting back to Hive AST. + */ + List oFieldNames = Lists.transform(columnNames, new Function() { + @Override + public String apply(String hName) { + return "_o_" + hName; + } + }); + + // 3 Build Optiq Rel Node for project using converted projections & col + // names + HiveRel selRel = HiveProjectRel.create(srcRel, optiqColLst, oFieldNames); + + // 4. Keep track of colname-to-posmap && RR for new select + this.m_relToHiveColNameOptiqPosMap.put(selRel, buildHiveToOptiqColumnMap(out_rwsch, selRel)); + this.m_relToHiveRR.put(selRel, out_rwsch); + + return selRel; } /** @@ -13174,6 +13247,11 @@ int getWindowSpecIndx(ASTNode wndAST) { * @throws SemanticException */ private RelNode genSelectLogicalPlan(QB qb, RelNode srcRel) throws SemanticException { + + // 0. Generate a Select Node for Windowing + RelNode selForWindow = genSelectForWindowing(qb, srcRel); + srcRel = (selForWindow == null) ? srcRel : selForWindow; + boolean subQuery; ArrayList col_list = new ArrayList(); ArrayList> windowingRexNodes = new ArrayList>(); @@ -13233,40 +13311,7 @@ private RelNode genSelectLogicalPlan(QB qb, RelNode srcRel) throws SemanticExcep ASTNode child = (ASTNode) exprList.getChild(i); boolean hasAsClause = (!isInTransform) && (child.getChildCount() == 2); - // 6.2 Handle windowing spec - int wndSpecASTIndx = -1; - // TODO: is the check ((child.getChildCount() == 1) || hasAsClause) - // needed? - boolean isWindowSpec = (((child.getChildCount() == 1) || hasAsClause) && child.getChild(0) - .getType() == HiveParser.TOK_FUNCTION) ? ((wndSpecASTIndx = getWindowSpecIndx((ASTNode) child - .getChild(0))) > 0) : false; - if (isWindowSpec) { - Pair wtp = genWindowingProj(qb, (ASTNode) child.getChild(0), - wndSpecASTIndx, wndProjPos, srcRel); - windowingRexNodes.add(new Pair(pos, wtp.getFirst())); - - // 6.2.1 Check if window expr has alias - String colAlias = null; - ASTNode tabOrColAst = (ASTNode) child.getChild(1); - if (tabOrColAst != null) - colAlias = BaseSemanticAnalyzer.getUnescapedName(tabOrColAst); - - // 6.2.2 Update Output Row Schema - ColumnInfo oColInfo = new ColumnInfo(getColumnInternalName(pos), wtp.getSecond(), null, - false); - if (colAlias != null) { - out_rwsch.checkColumn(null, colAlias); - out_rwsch.put(null, colAlias, oColInfo); - } else { - out_rwsch.putExpression(child, oColInfo); - } - - pos = Integer.valueOf(pos.intValue() + 1); - wndProjPos++; - continue; - } - - // 6.3 EXPR AS (ALIAS,...) parses, but is only allowed for UDTF's + // 6.2 EXPR AS (ALIAS,...) parses, but is only allowed for UDTF's // This check is not needed and invalid when there is a transform b/c // the // AST's are slightly different. @@ -13279,14 +13324,14 @@ private RelNode genSelectLogicalPlan(QB qb, RelNode srcRel) throws SemanticExcep String tabAlias; String colAlias; - // 6.4 Get rid of TOK_SELEXPR + // 6.3 Get rid of TOK_SELEXPR expr = (ASTNode) child.getChild(0); String[] colRef = getColAlias(child, autogenColAliasPrfxLbl, inputRR, autogenColAliasPrfxIncludeFuncName, i); tabAlias = colRef[0]; colAlias = colRef[1]; - // 6.5 Build ExprNode corresponding to colums + // 6.4 Build ExprNode corresponding to colums if (expr.getType() == HiveParser.TOK_ALLCOLREF) { pos = genColListRegex(".*", expr.getChildCount() == 0 ? null : getUnescapedName((ASTNode) expr.getChild(0)) @@ -13346,12 +13391,7 @@ private RelNode genSelectLogicalPlan(QB qb, RelNode srcRel) throws SemanticExcep } selectStar = selectStar && exprList.getChildCount() == posn + 1; - ArrayList columnNames = new ArrayList(); - for (int i = 0; i < col_list.size(); i++) { - columnNames.add(getColumnInternalName(i)); - } - - // 8. Convert Hive projections to Optiq + // 7. Convert Hive projections to Optiq List optiqColLst = new ArrayList(); RexNodeConverter rexNodeConv = new RexNodeConverter(m_cluster, srcRel.getRowType(), buildHiveColNameToInputPosMap(col_list, inputRR), 0, false); @@ -13359,31 +13399,8 @@ private RelNode genSelectLogicalPlan(QB qb, RelNode srcRel) throws SemanticExcep optiqColLst.add(rexNodeConv.convert(colExpr)); } - // 9. Add windowing Proj Names - for (Pair wndPair : windowingRexNodes) { - optiqColLst.add(wndPair.getFirst(), wndPair.getSecond()); - columnNames.add(getColumnInternalName(wndPair.getFirst())); - } - - // 10. Construct Hive Project Rel - // 10.1. Prepend column names with '_o_' - /* - * Hive treats names that start with '_c' as internalNames; so change the - * names so we don't run into this issue when converting back to Hive AST. - */ - List oFieldNames = Lists.transform(columnNames, new Function() { - @Override - public String apply(String hName) { - return "_o_" + hName; - } - }); - // 10.2 Build Optiq Rel Node for project using converted projections & col - // names - HiveRel selRel = HiveProjectRel.create(srcRel, optiqColLst, oFieldNames); - - // 11. Keep track of colname-to-posmap && RR for new select - this.m_relToHiveColNameOptiqPosMap.put(selRel, buildHiveToOptiqColumnMap(out_rwsch, selRel)); - this.m_relToHiveRR.put(selRel, out_rwsch); + // 8. Build Optiq Rel + RelNode selRel = genSelectRelNode(optiqColLst, out_rwsch, srcRel); return selRel; } diff --git a/ql/src/test/queries/clientpositive/cbo_correctness.q b/ql/src/test/queries/clientpositive/cbo_correctness.q index f80cbfd..ddba40a 100644 --- a/ql/src/test/queries/clientpositive/cbo_correctness.q +++ b/ql/src/test/queries/clientpositive/cbo_correctness.q @@ -205,6 +205,8 @@ select count(c_int) over() from t1; select count(c_int) over(), sum(c_float) over(), max(c_int) over(), min(c_int) over(), row_number() over(), rank() over(), dense_rank() over(), percent_rank() over(), lead(c_int, 2, c_int) over(), lag(c_float, 2, c_float) over() from t1; select * from (select count(c_int) over(), sum(c_float) over(), max(c_int) over(), min(c_int) over(), row_number() over(), rank() over(), dense_rank() over(), percent_rank() over(), lead(c_int, 2, c_int) over(), lag(c_float, 2, c_float) over() from t1) t1; select x from (select count(c_int) over() as x, sum(c_float) over() from t1) t1; +select 1+sum(c_int) over() from t1; +select sum(c_int)+sum(sum(c_int)) over() from t1; select * from (select max(c_int) over (partition by key order by value Rows UNBOUNDED PRECEDING), min(c_int) over (partition by key order by value rows current row), count(c_int) over(partition by key order by value ROWS 1 PRECEDING), avg(value) over (partition by key order by value Rows between unbounded preceding and unbounded following), sum(value) over (partition by key order by value rows between unbounded preceding and current row), avg(c_float) over (partition by key order by value Rows between 1 preceding and unbounded following), sum(c_float) over (partition by key order by value rows between 1 preceding and current row), max(c_float) over (partition by key order by value rows between 1 preceding and unbounded following), min(c_float) over (partition by key order by value rows between 1 preceding and 1 following) from t1) t1; select i, a, h, b, c, d, e, f, g, a as x, a +1 as y from (select max(c_int) over (partition by key order by value range UNBOUNDED PRECEDING) a, min(c_int) over (partition by key order by value range current row) b, count(c_int) over(partition by key order by value range 1 PRECEDING) c, avg(value) over (partition by key order by value range between unbounded preceding and unbounded following) d, sum(value) over (partition by key order by value range between unbounded preceding and current row) e, avg(c_float) over (partition by key order by value range between 1 preceding and unbounded following) f, sum(c_float) over (partition by key order by value range between 1 preceding and current row) g, max(c_float) over (partition by key order by value range between 1 preceding and unbounded following) h, min(c_float) over (partition by key order by value range between 1 preceding and 1 following) i from t1) t1; diff --git a/ql/src/test/results/clientpositive/cbo_correctness.q.out b/ql/src/test/results/clientpositive/cbo_correctness.q.out index 8145cc0..1c302e9 100644 --- a/ql/src/test/results/clientpositive/cbo_correctness.q.out +++ b/ql/src/test/results/clientpositive/cbo_correctness.q.out @@ -16380,6 +16380,47 @@ POSTHOOK: Input: default@t1@dt=2014 18 18 18 +PREHOOK: query: select 1+sum(c_int) over() from t1 +PREHOOK: type: QUERY +PREHOOK: Input: default@t1 +PREHOOK: Input: default@t1@dt=2014 +#### A masked pattern was here #### +POSTHOOK: query: select 1+sum(c_int) over() from t1 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@t1 +POSTHOOK: Input: default@t1@dt=2014 +#### A masked pattern was here #### +19 +19 +19 +19 +19 +19 +19 +19 +19 +19 +19 +19 +19 +19 +19 +19 +19 +19 +19 +19 +PREHOOK: query: select sum(c_int)+sum(sum(c_int)) over() from t1 +PREHOOK: type: QUERY +PREHOOK: Input: default@t1 +PREHOOK: Input: default@t1@dt=2014 +#### A masked pattern was here #### +POSTHOOK: query: select sum(c_int)+sum(sum(c_int)) over() from t1 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@t1 +POSTHOOK: Input: default@t1@dt=2014 +#### A masked pattern was here #### +36 PREHOOK: query: select * from (select max(c_int) over (partition by key order by value Rows UNBOUNDED PRECEDING), min(c_int) over (partition by key order by value rows current row), count(c_int) over(partition by key order by value ROWS 1 PRECEDING), avg(value) over (partition by key order by value Rows between unbounded preceding and unbounded following), sum(value) over (partition by key order by value rows between unbounded preceding and current row), avg(c_float) over (partition by key order by value Rows between 1 preceding and unbounded following), sum(c_float) over (partition by key order by value rows between 1 preceding and current row), max(c_float) over (partition by key order by value rows between 1 preceding and unbounded following), min(c_float) over (partition by key order by value rows between 1 preceding and 1 following) from t1) t1 PREHOOK: type: QUERY PREHOOK: Input: default@t1