diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/UpdateDeleteSemanticAnalyzer.java ql/src/java/org/apache/hadoop/hive/ql/parse/UpdateDeleteSemanticAnalyzer.java index 55a3735..1c742b2 100644 --- ql/src/java/org/apache/hadoop/hive/ql/parse/UpdateDeleteSemanticAnalyzer.java +++ ql/src/java/org/apache/hadoop/hive/ql/parse/UpdateDeleteSemanticAnalyzer.java @@ -27,14 +27,13 @@ import java.util.Map; import java.util.Set; +import org.apache.commons.lang.StringEscapeUtils; import org.apache.hadoop.hive.conf.HiveConf; -import org.apache.hadoop.hive.conf.HiveConfUtil; import org.apache.hadoop.hive.metastore.TableType; import org.apache.hadoop.hive.metastore.api.FieldSchema; import org.apache.hadoop.hive.ql.Context; import org.apache.hadoop.hive.ql.ErrorMsg; import org.apache.hadoop.hive.ql.QueryState; -import org.apache.hadoop.hive.ql.exec.Utilities; import org.apache.hadoop.hive.ql.hooks.Entity; import org.apache.hadoop.hive.ql.hooks.ReadEntity; import org.apache.hadoop.hive.ql.hooks.WriteEntity; @@ -979,14 +978,34 @@ private void handleInsert(ASTNode whenNotMatchedClause, StringBuilder rewrittenQ assert whenNotMatchedClause.getType() == HiveParser.TOK_NOT_MATCHED; assert getWhenClauseOperation(whenNotMatchedClause).getType() == HiveParser.TOK_INSERT; List partCols = targetTable.getPartCols(); - + //getMatchedText() looses `` but retains ''. Any \\t turns to \t String valuesClause = ((ASTNode)getWhenClauseOperation(whenNotMatchedClause).getChild(0)) .getMatchedText(); valuesClause = valuesClause.substring(1, valuesClause.length() - 1); + //valuesClause = escapeSQLString(valuesClause); + valuesClause = StringEscapeUtils.escapeJava(valuesClause); + //StringEscapeUtils.escapeSql(valuesClause); + rewrittenQueryStr.append("INSERT INTO ").append(getFullTableNameForSQL(target)); addPartitionColsToInsert(partCols, rewrittenQueryStr); - OnClauseAnalyzer oca = new OnClauseAnalyzer(onClause, targetTable, targetTableNameInSourceQuery); + /* + TODO: either scan original command (ctx.getCmd() for `` and make a dictionary and then + process the getMatchedText to add `` based on dictionary or walk the ASTNode lookig for + HiveParser.Identifier and build the dictionary and then insert `` + HiveParser.StringLiteral HiveParser.QuotedIdentifier + BaseSemanticAnalyzer.escapeSQLString()/unescapeSQLString()/unescapeIdentifier() + HiveUtils.unparseIdentifier() + SemanticAnalyzer.unparseExprForValuesClause() + UnparseTranslator.addIdentifierTranslation()/SubQueryDiagnostic.java + + If we build a dictionary of StringLiteral and QuotedIdentifier objects and replace + every QuotedIdentifier with `QuotedIdentifier` unless the position of the matched + string is inside a StringLiteral. Who's to say that a StringLiteral cannot be inside a + QuotedIdentifier? Presumably ASTNode has some markers... + */ + OnClauseAnalyzer oca = + new OnClauseAnalyzer(onClause, targetTable, targetTableNameInSourceQuery, conf); oca.analyze(); rewrittenQueryStr.append("\n select ") .append(valuesClause).append("\n WHERE ").append(oca.getPredicate()); @@ -1008,7 +1027,7 @@ private void handleInsert(ASTNode whenNotMatchedClause, StringBuilder rewrittenQ * we know that target is always a table (as opposed to some derived table). * The job of this class is to generate this predicate. * - * Note that is thi predicate cannot simply be NOT(on-clause-expr). IF on-clause-expr evaluates + * Note that is this predicate cannot simply be NOT(on-clause-expr). IF on-clause-expr evaluates * to Unknown, it will be treated as False in the WHEN MATCHED Inserts but NOT(Unknown) = Unknown, * and so it will be False for WHEN NOT MATCHED Insert... */ @@ -1019,14 +1038,16 @@ private void handleInsert(ASTNode whenNotMatchedClause, StringBuilder rewrittenQ private final List allTargetTableColumns = new ArrayList<>(); private final Set tableNamesFound = new HashSet<>(); private final String targetTableNameInSourceQuery; + private final HiveConf conf; /** * @param targetTableNameInSourceQuery alias or simple name */ - OnClauseAnalyzer(ASTNode onClause, Table targetTable, String targetTableNameInSourceQuery) { + OnClauseAnalyzer(ASTNode onClause, Table targetTable, String targetTableNameInSourceQuery, HiveConf conf) { this.onClause = onClause; allTargetTableColumns.addAll(targetTable.getCols()); allTargetTableColumns.addAll(targetTable.getPartCols()); this.targetTableNameInSourceQuery = unescapeIdentifier(targetTableNameInSourceQuery); + this.conf = conf; } /** * finds all columns and groups by table ref (if there is one) @@ -1056,7 +1077,6 @@ private void visit(ASTNode n) { } private void analyze() { visit(onClause); - int numTableRefs = tableNamesFound.size(); if(tableNamesFound.size() > 2) { throw new IllegalArgumentException("Found > 2 table refs in ON clause. Found " + tableNamesFound + " in " + onClause.getMatchedText()); @@ -1107,7 +1127,8 @@ private String getPredicate() { sb.append(" AND "); } //but preserve table name in SQL - sb.append(targetTableNameInSourceQuery).append(".").append(col).append(" IS NULL"); + sb.append(HiveUtils.unparseIdentifier(targetTableNameInSourceQuery, conf)).append(".") + .append(HiveUtils.unparseIdentifier(col, conf)).append(" IS NULL"); } return sb.toString(); } diff --git ql/src/test/org/apache/hadoop/hive/ql/TestTxnCommands.java ql/src/test/org/apache/hadoop/hive/ql/TestTxnCommands.java index 68af15a..cb8aba9 100644 --- ql/src/test/org/apache/hadoop/hive/ql/TestTxnCommands.java +++ ql/src/test/org/apache/hadoop/hive/ql/TestTxnCommands.java @@ -599,7 +599,14 @@ public void testMergeNegative2() throws Exception { "\nWHEN MATCHED THEN UPDATE set b=a"); Assert.assertEquals(ErrorMsg.MERGE_TOO_MANY_UPDATE, ((HiveException)cpr.getException()).getCanonicalErrorMsg()); } - @Ignore + + /** + * ok, so `1` means 1 is a column name and '1' means 1 is a string literal + * HiveConf.HIVE_QUOTEDID_SUPPORT + * HiveConf.HIVE_SUPPORT_SPECICAL_CHARACTERS_IN_TABLE_NAMES + * @throws Exception + * todo: try Unicode chars + */ @Test public void testSpecialChar() throws Exception { String target = "`aci/d_u/ami`"; @@ -609,10 +616,33 @@ public void testSpecialChar() throws Exception { runStatementOnDriver("create table " + target + "(i int," + "`d?*de e` decimal(5,2)," + "vc varchar(128)) clustered by (i) into 2 buckets stored as orc TBLPROPERTIES ('transactional'='true')"); + runStatementOnDriver("create table " + src + "(gh int, j decimal(5,2), k varchar(128))"); + runStatementOnDriver("merge into " + target + " as `d/8` using " + src + " as `a/b` on i=gh " + + "\nwhen matched and i > 5 then delete " + + "\nwhen matched then update set vc='blah' " + + "\nwhen not matched then insert values(1,2.1,'baz')"); + runStatementOnDriver("merge into " + target + " as `d/8` using " + src + " as `a/b` on i=gh " + + "\nwhen matched and i > 5 then delete " + + "\nwhen matched then update set vc='blah' " + + "\nwhen not matched then insert values(1,2.1,'a\\b')"); + runStatementOnDriver("merge into " + target + " as `d/8` using " + src + " as `a/b` on i=gh " + + "\nwhen matched and i > 5 then delete " + + "\nwhen matched then update set vc='∆∋'" + + "\nwhen not matched then insert values(`a/b`.gh,`a/b`.j,'c\\t')"); + } + @Test + public void testSpecialChar2() throws Exception { + String target = "`aci/d_u/ami`"; + String src = "`src/name`"; + runStatementOnDriver("drop table if exists " + target); + runStatementOnDriver("drop table if exists " + src); + runStatementOnDriver("create table " + target + "(i int," + + "`d?*de e` decimal(5,2)," + + "vc varchar(128)) clustered by (i) into 2 buckets stored as orc TBLPROPERTIES ('transactional'='true')"); runStatementOnDriver("create table " + src + "(`g/h` int, j decimal(5,2), k varchar(128))"); runStatementOnDriver("merge into " + target + " as `d/8` using " + src + " as `a/b` on i=`g/h` " + "\nwhen matched and i > 5 then delete " + - "\nwhen matched then update set vc=`∆∋` " + + "\nwhen matched then update set vc='∆∋' " +// the quoted literal looses quotes "\nwhen not matched then insert values(`a/b`.`g/h`,`a/b`.j,`a/b`.k)"); } }