diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/HiveParser.g b/ql/src/java/org/apache/hadoop/hive/ql/parse/HiveParser.g
index 78bc87c..bc95c46 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/parse/HiveParser.g
+++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/HiveParser.g
@@ -2999,8 +2999,8 @@ whenNotMatchedClause
   @init { pushMsg("WHEN NOT MATCHED clause", state); }
   @after { popMsg(state); }
   :
-  KW_WHEN KW_NOT KW_MATCHED (KW_AND expression)? KW_THEN KW_INSERT KW_VALUES valueRowConstructor ->
-    ^(TOK_NOT_MATCHED ^(TOK_INSERT valueRowConstructor) expression?)
+  KW_WHEN KW_NOT KW_MATCHED (KW_AND expression)? KW_THEN KW_INSERT (targetCols=columnParenthesesList)? KW_VALUES valueRowConstructor ->
+    ^(TOK_NOT_MATCHED ^(TOK_INSERT $targetCols? valueRowConstructor) expression?)
   ;
 whenMatchedAndClause
   @init { pushMsg("WHEN MATCHED AND clause", state); }
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/UpdateDeleteSemanticAnalyzer.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/UpdateDeleteSemanticAnalyzer.java
index e8823e1..911b7d5 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/parse/UpdateDeleteSemanticAnalyzer.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/UpdateDeleteSemanticAnalyzer.java
@@ -993,7 +993,6 @@ WHEN NOT MATCHED THEN INSERT VALUES(source.a2, source.b2)
         insClauseIdx < rewrittenTree.getChildCount() - (validating ? 1 : 0/*skip cardinality violation clause*/);
         insClauseIdx++, whenClauseIdx++) {
       //we've added Insert clauses in order or WHEN items in whenClauses
-      ASTNode insertClause = (ASTNode) rewrittenTree.getChild(insClauseIdx);
       switch (getWhenClauseOperation(whenClauses.get(whenClauseIdx)).getType()) {
         case HiveParser.TOK_INSERT:
           rewrittenCtx.addDestNamePrefix(insClauseIdx, Context.DestClausePrefix.INSERT);
@@ -1185,7 +1184,7 @@ private String handleUpdate(ASTNode whenMatchedUpdateClause, StringBuilder rewri
     String targetName = getSimpleTableName(target);
     rewrittenQueryStr.append("INSERT INTO ").append(getFullTableNameForSQL(target));
     addPartitionColsToInsert(targetTable.getPartCols(), rewrittenQueryStr);
-    rewrittenQueryStr.append(" -- update clause\n select ");
+    rewrittenQueryStr.append(" -- update clause\n SELECT ");
     if (hintStr != null) {
       rewrittenQueryStr.append(hintStr);
     }
@@ -1226,7 +1225,7 @@ private String handleUpdate(ASTNode whenMatchedUpdateClause, StringBuilder rewri
     if(deleteExtraPredicate != null) {
       rewrittenQueryStr.append(" AND NOT(").append(deleteExtraPredicate).append(")");
     }
-    rewrittenQueryStr.append("\n sort by ");
+    rewrittenQueryStr.append("\n SORT BY ");
     rewrittenQueryStr.append(targetName).append(".ROW__ID \n");
 
     setUpAccessControlInfoForUpdate(targetTable, setColsExprs);
@@ -1249,7 +1248,7 @@ private String handleDelete(ASTNode whenMatchedDeleteClause, StringBuilder rewri
     rewrittenQueryStr.append("INSERT INTO ").append(getFullTableNameForSQL(target));
     addPartitionColsToInsert(partCols, rewrittenQueryStr);
 
-    rewrittenQueryStr.append(" -- delete clause\n select ");
+    rewrittenQueryStr.append(" -- delete clause\n SELECT ");
     if (hintStr != null) {
       rewrittenQueryStr.append(hintStr);
     }
@@ -1264,7 +1263,7 @@ private String handleDelete(ASTNode whenMatchedDeleteClause, StringBuilder rewri
     if(updateExtraPredicate != null) {
       rewrittenQueryStr.append(" AND NOT(").append(updateExtraPredicate).append(")");
     }
-    rewrittenQueryStr.append("\n sort by ");
+    rewrittenQueryStr.append("\n SORT BY ");
     rewrittenQueryStr.append(targetName).append(".ROW__ID \n");
     return extraPredicate;
   }
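[Reviewer note, not part of the patch] With the optional (targetCols=columnParenthesesList)? added to the rule above, MERGE can now name the insert columns explicitly; the list is captured as TOK_TABCOLNAME under TOK_INSERT. A minimal sketch with illustrative table and column names (the target must be a transactional table):

    MERGE INTO target_tbl AS t
    USING source_tbl AS s
    ON t.id = s.id
    WHEN MATCHED THEN UPDATE SET val = s.val
    WHEN NOT MATCHED THEN INSERT (id, val) VALUES (s.id, s.val);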
@@ -1353,7 +1352,7 @@ private ASTNode getWhenClauseOperation(ASTNode whenClause) {
    */
   private String getWhenClausePredicate(ASTNode whenClause) {
     if(!(whenClause.getType() == HiveParser.TOK_MATCHED || whenClause.getType() == HiveParser.TOK_NOT_MATCHED)) {
-      throw  raiseWrongType("Expected TOK_MATCHED|TOK_NOT_MATCHED", whenClause);
+      throw raiseWrongType("Expected TOK_MATCHED|TOK_NOT_MATCHED", whenClause);
     }
     if(whenClause.getChildCount() == 2) {
       return getMatchedText((ASTNode)whenClause.getChild(1));
@@ -1366,33 +1365,58 @@ private String getWhenClausePredicate(ASTNode whenClause) {
    * @throws SemanticException
    */
   private void handleInsert(ASTNode whenNotMatchedClause, StringBuilder rewrittenQueryStr, ASTNode target,
-                            ASTNode onClause, Table targetTable,
-                            String targetTableNameInSourceQuery, String onClauseAsString, String hintStr) throws SemanticException {
+                            ASTNode onClause, Table targetTable, String targetTableNameInSourceQuery,
+                            String onClauseAsString, String hintStr) throws SemanticException {
+    ASTNode whenClauseOperation = getWhenClauseOperation(whenNotMatchedClause);
     assert whenNotMatchedClause.getType() == HiveParser.TOK_NOT_MATCHED;
-    assert getWhenClauseOperation(whenNotMatchedClause).getType() == HiveParser.TOK_INSERT;
-    List<FieldSchema> partCols = targetTable.getPartCols();
-    String valuesClause = getMatchedText((ASTNode)getWhenClauseOperation(whenNotMatchedClause).getChild(0));
-    valuesClause = valuesClause.substring(1, valuesClause.length() - 1);//strip '(' and ')'
-    valuesClause = SemanticAnalyzer.replaceDefaultKeywordForMerge(valuesClause, targetTable);
+    assert whenClauseOperation.getType() == HiveParser.TOK_INSERT;
+
+    ASTNode valuesNode = null;
+    ASTNode columnListNode = null;
+    for (Node n : whenClauseOperation.getChildren()) {
+      if (n instanceof ASTNode) {
+        ASTNode an = (ASTNode)n;
+        switch (((ASTNode)n).getType()) {
+          case HiveParser.TOK_FUNCTION:
+            valuesNode = an;
+            break;
+          case HiveParser.TOK_TABCOLNAME:
+            columnListNode = an;
+            break;
+          default:
+            break;
+        }
+      }
+    }
 
     rewrittenQueryStr.append("INSERT INTO ").append(getFullTableNameForSQL(target));
-    addPartitionColsToInsert(partCols, rewrittenQueryStr);
+    if (columnListNode != null) {
+      rewrittenQueryStr.append(' ').append(getMatchedText(columnListNode));
+    }
+    addPartitionColsToInsert(targetTable.getPartCols(), rewrittenQueryStr);
 
-    OnClauseAnalyzer oca = new OnClauseAnalyzer(onClause, targetTable, targetTableNameInSourceQuery,
-      conf, onClauseAsString);
-    oca.analyze();
-    rewrittenQueryStr.append(" -- insert clause\n select ");
+    rewrittenQueryStr.append(" -- insert clause\n SELECT ");
     if (hintStr != null) {
       rewrittenQueryStr.append(hintStr);
     }
+
+    OnClauseAnalyzer oca = new OnClauseAnalyzer(onClause, targetTable, targetTableNameInSourceQuery,
+      conf, onClauseAsString);
+    oca.analyze();
+
+    String valuesClause = getMatchedText(valuesNode);
+    valuesClause = valuesClause.substring(1, valuesClause.length() - 1);//strip '(' and ')'
+    valuesClause = SemanticAnalyzer.replaceDefaultKeywordForMerge(valuesClause, targetTable);
     rewrittenQueryStr.append(valuesClause).append("\n WHERE ").append(oca.getPredicate());
+
     String extraPredicate = getWhenClausePredicate(whenNotMatchedClause);
-    if(extraPredicate != null) {
+    if (extraPredicate != null) {
       //we have WHEN NOT MATCHED AND <boolean expression> THEN INSERT
       rewrittenQueryStr.append(" AND ")
         .append(getMatchedText(((ASTNode)whenNotMatchedClause.getChild(1)))).append('\n');
     }
   }
+
   /**
    * Suppose the input Merge statement has ON target.a = source.b and c = d. Assume, that 'c' is from
    * target table and 'd' is from source expression. In order to properly
@@ -1503,7 +1527,7 @@ private String getPredicate() {
       List<String> targetCols = table2column.get(targetTableNameInSourceQuery.toLowerCase());
       if(targetCols == null) {
         /*e.g. ON source.t=1
-        * this is not strictly speaking invlaid but it does ensure that all columns from target
+        * this is not strictly speaking invalid but it does ensure that all columns from target
         * table are all NULL for every row. This would make any WHEN MATCHED clause invalid since
         * we don't have a ROW__ID. The WHEN NOT MATCHED could be meaningful but it's just data from
         * source satisfying source.t=1... not worth the effort to support this*/
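[Reviewer note, not part of the patch] handleInsert() turns the WHEN NOT MATCHED branch into one leg of the multi-insert rewrite, and the change above splices the user's column list between the table name and any partition columns. For the illustrative MERGE shown earlier, the generated leg would be roughly of this shape (join and NULL-check predicate simplified; OnClauseAnalyzer supplies the actual WHERE condition):

    FROM target_tbl t RIGHT OUTER JOIN source_tbl s ON t.id = s.id
    INSERT INTO default.target_tbl (id, val) -- insert clause
    SELECT s.id, s.val
    WHERE t.id IS NULL;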
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFCardinalityViolation.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFCardinalityViolation.java
index b688447..7bb0f0b 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFCardinalityViolation.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFCardinalityViolation.java
@@ -18,29 +18,15 @@
 package org.apache.hadoop.hive.ql.udf.generic;
 
-import java.util.ArrayList;
-
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.VoidObjectInspector;
-import org.apache.logging.log4j.core.layout.StringBuilderEncoder;
 
-/**
- * GenericUDFArray.
- *
- */
 @Description(name = "cardinality_violation",
     value = "_FUNC_(n0, n1...) - raises Cardinality Violation")
 public class GenericUDFCardinalityViolation extends GenericUDF {
-  private transient Converter[] converters;
-  private transient ArrayList<Object> ret = new ArrayList<Object>();
 
   @Override
   public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
@@ -50,8 +36,8 @@ public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen
   @Override
   public Object evaluate(DeferredObject[] arguments) throws HiveException {
     StringBuilder nonUniqueKey = new StringBuilder();
-    for(DeferredObject t : arguments) {
-      if(nonUniqueKey.length() > 0) {nonUniqueKey.append(','); }
+    for (DeferredObject t : arguments) {
+      if (nonUniqueKey.length() > 0) {nonUniqueKey.append(',');}
       nonUniqueKey.append(t.get());
     }
     throw new RuntimeException("Cardinality Violation in Merge statement: " + nonUniqueKey);
diff --git a/ql/src/test/queries/clientpositive/sqlmerge_stats.q b/ql/src/test/queries/clientpositive/sqlmerge_stats.q
index c480eb6..f167a70 100644
--- a/ql/src/test/queries/clientpositive/sqlmerge_stats.q
+++ b/ql/src/test/queries/clientpositive/sqlmerge_stats.q
@@ -29,7 +29,7 @@ desc formatted t;
 
 merge into t as t using upd_t as u ON t.a = u.a
 WHEN MATCHED THEN DELETE
-WHEN NOT MATCHED THEN INSERT VALUES(u.a, u.b);
+WHEN NOT MATCHED THEN INSERT (a, b) VALUES(u.a, u.b);
 
 select assert_true(count(1) = 0) from t group by a>-1;
 
diff --git a/ql/src/test/results/clientpositive/llap/sqlmerge_stats.q.out b/ql/src/test/results/clientpositive/llap/sqlmerge_stats.q.out
index 02aa87a..89edd9a 100644
--- a/ql/src/test/results/clientpositive/llap/sqlmerge_stats.q.out
+++ b/ql/src/test/results/clientpositive/llap/sqlmerge_stats.q.out
@@ -458,7 +458,7 @@ Storage Desc Params:
 	serialization.format	1
 PREHOOK: query: merge into t as t using upd_t as u ON t.a = u.a
 WHEN MATCHED THEN DELETE
-WHEN NOT MATCHED THEN INSERT VALUES(u.a, u.b)
+WHEN NOT MATCHED THEN INSERT (a, b) VALUES(u.a, u.b)
 PREHOOK: type: QUERY
 PREHOOK: Input: default@t
 PREHOOK: Input: default@upd_t
@@ -467,7 +467,7 @@ PREHOOK: Output: default@t
 PREHOOK: Output: default@t
 POSTHOOK: query: merge into t as t using upd_t as u ON t.a = u.a
 WHEN MATCHED THEN DELETE
-WHEN NOT MATCHED THEN INSERT VALUES(u.a, u.b)
+WHEN NOT MATCHED THEN INSERT (a, b) VALUES(u.a, u.b)
 POSTHOOK: type: QUERY
 POSTHOOK: Input: default@t
 POSTHOOK: Input: default@upd_t
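[Reviewer note, not part of the patch] The cardinality_violation() UDF is only tidied up here (dead imports and fields removed); its behavior is unchanged. The merge rewrite plants it when hive.merge.cardinality.check=true (the default), so a merge whose source matches the same target row more than once fails at runtime, as SQL:2011 requires. A self-contained repro with hypothetical table names:

    CREATE TABLE t_demo (a int, b int) STORED AS ORC TBLPROPERTIES ('transactional'='true');
    CREATE TABLE upd_demo (a int, b int);
    INSERT INTO t_demo VALUES (1, 1);
    INSERT INTO upd_demo VALUES (1, 10), (1, 20); -- two source rows for key a=1

    -- fails with "Cardinality Violation in Merge statement: ..." because
    -- both upd_demo rows match the single t_demo row with a=1
    MERGE INTO t_demo AS t USING upd_demo AS u ON t.a = u.a
    WHEN MATCHED THEN UPDATE SET b = u.b
    WHEN NOT MATCHED THEN INSERT (a, b) VALUES (u.a, u.b);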