diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/QBSubQuery.java ql/src/java/org/apache/hadoop/hive/ql/parse/QBSubQuery.java index b9c7e6f..92ccbea 100644 --- ql/src/java/org/apache/hadoop/hive/ql/parse/QBSubQuery.java +++ ql/src/java/org/apache/hadoop/hive/ql/parse/QBSubQuery.java @@ -2,6 +2,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Set; import java.util.Stack; import org.apache.hadoop.hive.common.ObjectPair; @@ -375,12 +376,14 @@ void addCorrExpr(ASTNode corrExpr) { } public ASTNode getSubQueryAST() { - return SubQueryUtils.buildNotInNullCheckQuery( + ASTNode ast = SubQueryUtils.buildNotInNullCheckQuery( QBSubQuery.this.getSubQueryAST(), QBSubQuery.this.getAlias(), CNT_ALIAS, subQryCorrExprs, sqRR); + SubQueryUtils.setOriginDeep(ast, QBSubQuery.this.originalSQASTOrigin); + return ast; } public String getAlias() { @@ -392,8 +395,10 @@ public JoinType getJoinType() { } public ASTNode getJoinConditionAST() { - return + ASTNode ast = SubQueryUtils.buildNotInNullJoinCond(getAlias(), CNT_ALIAS); + SubQueryUtils.setOriginDeep(ast, QBSubQuery.this.originalSQASTOrigin); + return ast; } public QBSubQuery getSubQuery() { @@ -475,7 +480,8 @@ public SubQueryTypeDef getOperator() { void validateAndRewriteAST(RowResolver outerQueryRR, boolean forHavingClause, - String outerQueryAlias) throws SemanticException { + String outerQueryAlias, + Set<String> outerQryAliases) throws SemanticException { ASTNode selectClause = (ASTNode) subQueryAST.getChild(1).getChild(1); @@ -483,6 +489,44 @@ void validateAndRewriteAST(RowResolver outerQueryRR, if ( selectClause.getChild(0).getType() == HiveParser.TOK_HINTLIST ) { selectExprStart = 1; } + + /* + * Restriction.16.s :: Correlated Expression in Outer Query must not contain + * unqualified column references.
+ */ + if ( parentQueryExpression != null && !forHavingClause ) { + ASTNode u = SubQueryUtils.hasUnQualifiedColumnReferences(parentQueryExpression); + if ( u != null ) { + subQueryAST.setOrigin(originalSQASTOrigin); + throw new SemanticException(ErrorMsg.UNSUPPORTED_SUBQUERY_EXPRESSION.getMsg( + u, "Correlating expression cannot contain unqualified column references.")); + } + } + + /* + * Restriction 17.s :: SubQuery cannot use the same table alias as one used in + * the Outer Query. + */ + List<String> sqAliases = SubQueryUtils.getTableAliasesInSubQuery(this); + String sharedAlias = null; + for(String s : sqAliases ) { + if ( outerQryAliases.contains(s) ) { + sharedAlias = s; + } + } + if ( sharedAlias != null) { + ASTNode whereClause = SubQueryUtils.subQueryWhere(subQueryAST); + + if ( whereClause != null ) { + ASTNode u = SubQueryUtils.hasUnQualifiedColumnReferences(whereClause); + if ( u != null ) { + subQueryAST.setOrigin(originalSQASTOrigin); + throw new SemanticException(ErrorMsg.UNSUPPORTED_SUBQUERY_EXPRESSION.getMsg( + u, "SubQuery cannot use the table alias: " + sharedAlias + "; " + + "this is also an alias in the Outer Query and SubQuery contains a unqualified column reference")); + } + } + } /* * Check.5.h :: For In and Not In the SubQuery must implicitly or @@ -491,6 +535,7 @@ void validateAndRewriteAST(RowResolver outerQueryRR, if ( operator.getType() != SubQueryType.EXISTS && operator.getType() != SubQueryType.NOT_EXISTS && selectClause.getChildCount() - selectExprStart > 1 ) { + subQueryAST.setOrigin(originalSQASTOrigin); throw new SemanticException(ErrorMsg.INVALID_SUBQUERY_EXPRESSION.getMsg( subQueryAST, "SubQuery can contain only 1 item in Select List.")); } @@ -675,11 +720,7 @@ private void rewrite(RowResolver parentQueryRR, boolean forHavingClause, String outerQueryAlias) throws SemanticException { ASTNode selectClause = (ASTNode) subQueryAST.getChild(1).getChild(1); - ASTNode whereClause = null; - if ( subQueryAST.getChild(1).getChildCount() > 2 && -
subQueryAST.getChild(1).getChild(2).getType() == HiveParser.TOK_WHERE ) { - whereClause = (ASTNode) subQueryAST.getChild(1).getChild(2); - } + ASTNode whereClause = SubQueryUtils.subQueryWhere(subQueryAST); if ( whereClause == null ) { return; diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java index b9cd65c..bccacdc 100644 --- ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java +++ ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java @@ -2021,7 +2021,7 @@ private Operator genFilterPlan(ASTNode searchCond, QB qb, Operator input, aliasToOpInfo.put(havingInputAlias, input); } - subQuery.validateAndRewriteAST(inputRR, forHavingClause, havingInputAlias); + subQuery.validateAndRewriteAST(inputRR, forHavingClause, havingInputAlias, aliasToOpInfo.keySet()); QB qbSQ = new QB(subQuery.getOuterQueryId(), subQuery.getAlias(), true); Operator sqPlanTopOp = genPlanForSubQueryPredicate(qbSQ, subQuery); diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/SubQueryUtils.java ql/src/java/org/apache/hadoop/hive/ql/parse/SubQueryUtils.java index 8c03c7d..8ffbe07 100644 --- ql/src/java/org/apache/hadoop/hive/ql/parse/SubQueryUtils.java +++ ql/src/java/org/apache/hadoop/hive/ql/parse/SubQueryUtils.java @@ -280,6 +280,32 @@ private static void getTableAliasesInSubQuery(ASTNode joinNode, List<String> ali getTableAliasesInSubQuery(right, aliases); } } + + static ASTNode hasUnQualifiedColumnReferences(ASTNode ast) { + int type = ast.getType(); + if ( type == HiveParser.DOT ) { + return null; + } + else if ( type == HiveParser.TOK_TABLE_OR_COL ) { + return ast; + } + + for(int i=0; i < ast.getChildCount(); i++ ) { + ASTNode c = hasUnQualifiedColumnReferences((ASTNode) ast.getChild(i)); + if ( c != null ) { + return c; + } + } + return null; + } + + static ASTNode subQueryWhere(ASTNode subQueryAST) { + if ( subQueryAST.getChild(1).getChildCount() > 2 && +
subQueryAST.getChild(1).getChild(2).getType() == HiveParser.TOK_WHERE ) { + return (ASTNode) subQueryAST.getChild(1).getChild(2); + } + return null; + } /* * construct the ASTNode for the SQ column that will join with the OuterQuery Expression. diff --git ql/src/test/org/apache/hadoop/hive/ql/parse/TestQBSubQuery.java ql/src/test/org/apache/hadoop/hive/ql/parse/TestQBSubQuery.java new file mode 100644 index 0000000..7e57471 --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/parse/TestQBSubQuery.java @@ -0,0 +1,117 @@ +package org.apache.hadoop.hive.ql.parse; + +import java.util.ArrayList; +import java.util.List; + +import junit.framework.Assert; + +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.ql.session.SessionState; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestQBSubQuery { + static HiveConf conf; + + private static String IN_QUERY = " select * " + + "from src " + + "where src.key in (select key from src s1 where s1.key > '9' and s1.value > '9') "; + + private static String IN_QUERY2 = " select * " + + "from src " + + "where src.key in (select key from src s1 where s1.key > '9' and s1.value > '9') and value > '9'"; + + private static String QUERY3 = "select p_mfgr, min(p_size), rank() over(partition by p_mfgr) as r from part group by p_mfgr"; + + ParseDriver pd; + SemanticAnalyzer sA; + + @BeforeClass + public static void initialize() { + conf = new HiveConf(SemanticAnalyzer.class); + SessionState.start(conf); + } + + @Before + public void setup() throws SemanticException { + pd = new ParseDriver(); + sA = new SemanticAnalyzer(conf); + } + + ASTNode parse(String query) throws ParseException { + ASTNode nd = pd.parse(query); + return (ASTNode) nd.getChild(0); + } + + @Test + public void testExtractSubQueries() throws Exception { + ASTNode ast = parse(IN_QUERY); + ASTNode where = where(ast); + List<ASTNode> sqs = SubQueryUtils.findSubQueries((ASTNode) where.getChild(0)); +
Assert.assertEquals(sqs.size(), 1); + + ASTNode sq = sqs.get(0); + Assert.assertEquals(sq.toStringTree(), + "(TOK_SUBQUERY_EXPR (TOK_SUBQUERY_OP in) (TOK_QUERY (TOK_FROM (TOK_TABREF (TOK_TABNAME src) s1)) (TOK_INSERT (TOK_DESTINATION (TOK_DIR TOK_TMP_FILE)) (TOK_SELECT (TOK_SELEXPR (TOK_TABLE_OR_COL key))) (TOK_WHERE (and (> (. (TOK_TABLE_OR_COL s1) key) '9') (> (. (TOK_TABLE_OR_COL s1) value) '9'))))) (. (TOK_TABLE_OR_COL src) key))" + ); + } + + @Test + public void testExtractConjuncts() throws Exception { + ASTNode ast = parse(IN_QUERY); + ASTNode where = where(ast); + List<ASTNode> sqs = SubQueryUtils.findSubQueries((ASTNode) where.getChild(0)); + ASTNode sq = sqs.get(0); + + ASTNode sqWhere = where((ASTNode) sq.getChild(1)); + + List<ASTNode> conjuncts = new ArrayList<ASTNode>(); + SubQueryUtils.extractConjuncts((ASTNode) sqWhere.getChild(0), conjuncts); + Assert.assertEquals(conjuncts.size(), 2); + + Assert.assertEquals(conjuncts.get(0).toStringTree(), "(> (. (TOK_TABLE_OR_COL s1) key) '9')"); + Assert.assertEquals(conjuncts.get(1).toStringTree(), "(> (.
(TOK_TABLE_OR_COL s1) value) '9')"); + } + + @Test + public void testRewriteOuterQueryWhere() throws Exception { + ASTNode ast = parse(IN_QUERY); + ASTNode where = where(ast); + List<ASTNode> sqs = SubQueryUtils.findSubQueries((ASTNode) where.getChild(0)); + ASTNode sq = sqs.get(0); + + ASTNode newWhere = SubQueryUtils.rewriteParentQueryWhere((ASTNode) where.getChild(0), sq); + Assert.assertEquals(newWhere.toStringTree(), "(= 1 1)"); + } + + @Test + public void testRewriteOuterQueryWhere2() throws Exception { + ASTNode ast = parse(IN_QUERY2); + ASTNode where = where(ast); + List<ASTNode> sqs = SubQueryUtils.findSubQueries((ASTNode) where.getChild(0)); + ASTNode sq = sqs.get(0); + + ASTNode newWhere = SubQueryUtils.rewriteParentQueryWhere((ASTNode) where.getChild(0), sq); + Assert.assertEquals(newWhere.toStringTree(), "(> (TOK_TABLE_OR_COL value) '9')"); + } + + @Test + public void testCheckAggOrWindowing() throws Exception { + ASTNode ast = parse(QUERY3); + ASTNode select = select(ast); + + Assert.assertEquals(SubQueryUtils.checkAggOrWindowing((ASTNode) select.getChild(0)), 0); + Assert.assertEquals(SubQueryUtils.checkAggOrWindowing((ASTNode) select.getChild(1)), 1); + Assert.assertEquals(SubQueryUtils.checkAggOrWindowing((ASTNode) select.getChild(2)), 2); + } + + private ASTNode where(ASTNode qry) { + return (ASTNode) qry.getChild(1).getChild(2); + } + + private ASTNode select(ASTNode qry) { + return (ASTNode) qry.getChild(1).getChild(1); + } + +} diff --git ql/src/test/queries/clientnegative/subquery_shared_alias.q ql/src/test/queries/clientnegative/subquery_shared_alias.q new file mode 100644 index 0000000..d442f07 --- /dev/null +++ ql/src/test/queries/clientnegative/subquery_shared_alias.q @@ -0,0 +1,6 @@ + + +select * +from src +where src.key in (select key from src where key > '9') +; \ No newline at end of file diff --git ql/src/test/queries/clientnegative/subquery_unqual_corr_expr.q ql/src/test/queries/clientnegative/subquery_unqual_corr_expr.q new file mode 100644 index
0000000..99ff9ca --- /dev/null +++ ql/src/test/queries/clientnegative/subquery_unqual_corr_expr.q @@ -0,0 +1,6 @@ + + +select * +from src +where key in (select key from src) +; \ No newline at end of file diff --git ql/src/test/queries/clientpositive/subquery_views.q ql/src/test/queries/clientpositive/subquery_views.q new file mode 100644 index 0000000..e1e5373 --- /dev/null +++ ql/src/test/queries/clientpositive/subquery_views.q @@ -0,0 +1,48 @@ + + +-- exists test +create view cv1 as +select * +from src b +where exists + (select a.key + from src a + where b.value = a.value and a.key = b.key and a.value > 'val_9') +; + +select * +from cv1 where cv1.key in (select key from cv1 c where c.key > '95'); +; + + +-- not in test +create view cv2 as +select * +from src b +where b.key not in + (select a.key + from src a + where b.value = a.value and a.key = b.key and a.value > 'val_11' + ) +; + +select * +from cv2 where cv2.key in (select key from cv2 c where c.key < '11'); +; + +-- in where + having +create view cv3 as +select key, value, count(*) +from src b +where b.key in (select key from src where src.key > '8') +group by key, value +having count(*) in (select count(*) from src s1 where s1.key > '9' group by s1.key ) +; + +select * from cv3; + + +-- join of subquery views +select * +from cv3 +where cv3.key in (select key from cv1); \ No newline at end of file diff --git ql/src/test/results/clientnegative/subquery_shared_alias.q.out ql/src/test/results/clientnegative/subquery_shared_alias.q.out new file mode 100644 index 0000000..2d94bef --- /dev/null +++ ql/src/test/results/clientnegative/subquery_shared_alias.q.out @@ -0,0 +1 @@ +FAILED: SemanticException [Error 10249]: Line 5:44 Unsupported SubQuery Expression 'key': SubQuery cannot use the table alias: src; this is also an alias in the Outer Query and SubQuery contains a unqualified column reference diff --git ql/src/test/results/clientnegative/subquery_unqual_corr_expr.q.out 
ql/src/test/results/clientnegative/subquery_unqual_corr_expr.q.out new file mode 100644 index 0000000..f69a538 --- /dev/null +++ ql/src/test/results/clientnegative/subquery_unqual_corr_expr.q.out @@ -0,0 +1 @@ +FAILED: SemanticException [Error 10249]: Line 5:6 Unsupported SubQuery Expression 'key': Correlating expression cannot contain unqualified column references. diff --git ql/src/test/results/clientpositive/subquery_views.q.out ql/src/test/results/clientpositive/subquery_views.q.out new file mode 100644 index 0000000..37d29a7 --- /dev/null +++ ql/src/test/results/clientpositive/subquery_views.q.out @@ -0,0 +1,145 @@ +PREHOOK: query: -- exists test +create view cv1 as +select * +from src b +where exists + (select a.key + from src a + where b.value = a.value and a.key = b.key and a.value > 'val_9') +PREHOOK: type: CREATEVIEW +POSTHOOK: query: -- exists test +create view cv1 as +select * +from src b +where exists + (select a.key + from src a + where b.value = a.value and a.key = b.key and a.value > 'val_9') +POSTHOOK: type: CREATEVIEW +POSTHOOK: Output: default@cv1 +PREHOOK: query: select * +from cv1 where cv1.key in (select key from cv1 c where c.key > '95') +PREHOOK: type: QUERY +PREHOOK: Input: default@cv1 +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: select * +from cv1 where cv1.key in (select key from cv1 c where c.key > '95') +POSTHOOK: type: QUERY +POSTHOOK: Input: default@cv1 +POSTHOOK: Input: default@src +#### A masked pattern was here #### +96 val_96 +97 val_97 +97 val_97 +98 val_98 +98 val_98 +PREHOOK: query: -- not in test +create view cv2 as +select * +from src b +where b.key not in + (select a.key + from src a + where b.value = a.value and a.key = b.key and a.value > 'val_11' + ) +PREHOOK: type: CREATEVIEW +POSTHOOK: query: -- not in test +create view cv2 as +select * +from src b +where b.key not in + (select a.key + from src a + where b.value = a.value and a.key = b.key and a.value > 'val_11' + ) +POSTHOOK: 
type: CREATEVIEW +POSTHOOK: Output: default@cv2 +PREHOOK: query: select * +from cv2 where cv2.key in (select key from cv2 c where c.key < '11') +PREHOOK: type: QUERY +PREHOOK: Input: default@cv2 +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: select * +from cv2 where cv2.key in (select key from cv2 c where c.key < '11') +POSTHOOK: type: QUERY +POSTHOOK: Input: default@cv2 +POSTHOOK: Input: default@src +#### A masked pattern was here #### +0 val_0 +0 val_0 +0 val_0 +10 val_10 +100 val_100 +100 val_100 +103 val_103 +103 val_103 +104 val_104 +104 val_104 +105 val_105 +PREHOOK: query: -- in where + having +create view cv3 as +select key, value, count(*) +from src b +where b.key in (select key from src where src.key > '8') +group by key, value +having count(*) in (select count(*) from src s1 where s1.key > '9' group by s1.key ) +PREHOOK: type: CREATEVIEW +POSTHOOK: query: -- in where + having +create view cv3 as +select key, value, count(*) +from src b +where b.key in (select key from src where src.key > '8') +group by key, value +having count(*) in (select count(*) from src s1 where s1.key > '9' group by s1.key ) +POSTHOOK: type: CREATEVIEW +POSTHOOK: Output: default@cv3 +PREHOOK: query: select * from cv3 +PREHOOK: type: QUERY +PREHOOK: Input: default@cv3 +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: select * from cv3 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@cv3 +POSTHOOK: Input: default@src +#### A masked pattern was here #### +80 val_80 1 +82 val_82 1 +85 val_85 1 +86 val_86 1 +87 val_87 1 +9 val_9 1 +92 val_92 1 +96 val_96 1 +83 val_83 2 +84 val_84 2 +97 val_97 2 +95 val_95 2 +98 val_98 2 +90 val_90 3 +PREHOOK: query: -- join of subquery views +select * +from cv3 +where cv3.key in (select key from cv1) +PREHOOK: type: QUERY +PREHOOK: Input: default@cv1 +PREHOOK: Input: default@cv3 +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: -- join of subquery 
views +select * +from cv3 +where cv3.key in (select key from cv1) +POSTHOOK: type: QUERY +POSTHOOK: Input: default@cv1 +POSTHOOK: Input: default@cv3 +POSTHOOK: Input: default@src +#### A masked pattern was here #### +90 val_90 3 +92 val_92 1 +95 val_95 2 +96 val_96 1 +97 val_97 2 +98 val_98 2