diff --git a/itests/hive-unit/src/test/java/org/apache/hive/jdbc/TestJdbcWithMiniLlap.java b/itests/hive-unit/src/test/java/org/apache/hive/jdbc/TestJdbcWithMiniLlap.java index 68d2ddc..28fa7a5 100644 --- a/itests/hive-unit/src/test/java/org/apache/hive/jdbc/TestJdbcWithMiniLlap.java +++ b/itests/hive-unit/src/test/java/org/apache/hive/jdbc/TestJdbcWithMiniLlap.java @@ -156,8 +156,17 @@ public static void afterTest() throws Exception { } private void createTestTable(String tableName) throws Exception { + createTestTable(null, tableName); + } + + private void createTestTable(String database, String tableName) throws Exception { Statement stmt = hs2Conn.createStatement(); + if (database != null) { + stmt.execute("CREATE DATABASE IF NOT EXISTS " + database); + stmt.execute("USE " + database); + } + // create table stmt.execute("DROP TABLE IF EXISTS " + tableName); stmt.execute("CREATE TABLE " + tableName @@ -225,12 +234,12 @@ public void testLlapInputFormatEndToEnd() throws Exception { @Test(timeout = 60000) public void testNonAsciiStrings() throws Exception { - createTestTable("testtab1"); + createTestTable("nonascii", "testtab_nonascii"); RowCollector rowCollector = new RowCollector(); String nonAscii = "À côté du garçon"; - String query = "select value, '" + nonAscii + "' from testtab1 where under_col=0"; - int rowCount = processQuery(query, 1, rowCollector); + String query = "select value, '" + nonAscii + "' from testtab_nonascii where under_col=0"; + int rowCount = processQuery("nonascii", query, 1, rowCollector); assertEquals(3, rowCount); assertArrayEquals(new String[] {"val_0", nonAscii}, rowCollector.rows.get(0)); @@ -474,6 +483,10 @@ public void process(Row row) { } private int processQuery(String query, int numSplits, RowProcessor rowProcessor) throws Exception { + return processQuery(null, query, numSplits, rowProcessor); + } + + private int processQuery(String currentDatabase, String query, int numSplits, RowProcessor rowProcessor) throws Exception { String url = miniHS2.getJdbcURL(); String user = System.getProperty("user.name"); String pwd = user; @@ -488,6 +501,9 @@ private int processQuery(String query, int numSplits, RowProcessor rowProcessor) job.set(LlapBaseInputFormat.PWD_KEY, pwd); job.set(LlapBaseInputFormat.QUERY_KEY, query); job.set(LlapBaseInputFormat.HANDLE_ID, handleId); + if (currentDatabase != null) { + job.set(LlapBaseInputFormat.DB_KEY, currentDatabase); + } InputSplit[] splits = inputFormat.getSplits(job, numSplits); assertTrue(splits.length > 0); diff --git a/llap-ext-client/src/java/org/apache/hadoop/hive/llap/LlapBaseInputFormat.java b/llap-ext-client/src/java/org/apache/hadoop/hive/llap/LlapBaseInputFormat.java index de9a031..c53e6cb 100644 --- a/llap-ext-client/src/java/org/apache/hadoop/hive/llap/LlapBaseInputFormat.java +++ b/llap-ext-client/src/java/org/apache/hadoop/hive/llap/LlapBaseInputFormat.java @@ -109,6 +109,7 @@ public static final String USER_KEY = "llap.if.user"; public static final String PWD_KEY = "llap.if.pwd"; public static final String HANDLE_ID = "llap.if.handleid"; + public static final String DB_KEY = "llap.if.database"; public final String SPLIT_QUERY = "select get_splits(\"%s\",%d)"; public static final LlapServiceInstance[] serviceInstanceArray = new LlapServiceInstance[0]; @@ -206,6 +207,7 @@ public LlapBaseInputFormat() {} if (query == null) query = job.get(QUERY_KEY); if (user == null) user = job.get(USER_KEY); if (pwd == null) pwd = job.get(PWD_KEY); + String database = job.get(DB_KEY); if (url == null || query == null) { throw new IllegalStateException(); @@ -235,6 +237,9 @@ public LlapBaseInputFormat() {} try ( Statement stmt = conn.createStatement(); ) { + if (database != null && !database.isEmpty()) { + stmt.execute("USE " + database); + } ResultSet res = stmt.executeQuery(sql); while (res.next()) { // deserialize split