diff --git jdbc/pom.xml jdbc/pom.xml
index f87ab59..decf8b2 100644
--- jdbc/pom.xml
+++ jdbc/pom.xml
@@ -116,6 +116,11 @@
${junit.version}
test
+
+ org.mockito
+ mockito-all
+ test
+
diff --git jdbc/src/java/org/apache/hive/jdbc/HivePreparedStatement.java jdbc/src/java/org/apache/hive/jdbc/HivePreparedStatement.java
index 705a32a..69c1587 100644
--- jdbc/src/java/org/apache/hive/jdbc/HivePreparedStatement.java
+++ jdbc/src/java/org/apache/hive/jdbc/HivePreparedStatement.java
@@ -39,8 +39,10 @@
import java.sql.Timestamp;
import java.sql.Types;
import java.text.MessageFormat;
+import java.util.ArrayList;
import java.util.Calendar;
import java.util.HashMap;
+import java.util.List;
import java.util.Scanner;
import org.apache.hive.service.rpc.thrift.TCLIService;
@@ -126,60 +128,67 @@ public int executeUpdate() throws SQLException {
* @param sql
* @param parameters
* @return updated SQL string
- */
- private String updateSql(final String sql, HashMap parameters) {
- if (!sql.contains("?")) {
- return sql;
- }
-
- StringBuilder newSql = new StringBuilder(sql);
-
- int paramLoc = 1;
- while (getCharIndexFromSqlByParamLocation(sql, '?', paramLoc) > 0) {
- // check the user has set the needs parameters
- if (parameters.containsKey(paramLoc)) {
- int tt = getCharIndexFromSqlByParamLocation(newSql.toString(), '?', 1);
- newSql.deleteCharAt(tt);
- newSql.insert(tt, parameters.get(paramLoc));
+ * @throws SQLException
+ */
+ private String updateSql(final String sql, HashMap parameters) throws SQLException {
+ List parts=splitSqlStatement(sql);
+
+ StringBuilder newSql = new StringBuilder(parts.get(0));
+ for(int i=1;i The -1 will be return, if nothing found
- *
+ * Splits the parametered sql statement at parameter boundaries.
+ *
+ * taking into account ' and \ escaping.
+ *
+ * output for: 'select 1 from ? where a = ?'
+ * ['select 1 from ',' where a = ','']
+ *
* @param sql
- * @param cchar
- * @param paramLoc
* @return
*/
- private int getCharIndexFromSqlByParamLocation(final String sql, final char cchar, final int paramLoc) {
- int signalCount = 0;
- int charIndex = -1;
- int num = 0;
+ private List splitSqlStatement(String sql) {
+ List parts=new ArrayList<>();
+ int apCount=0;
+ int off=0;
+ boolean skip=false;
+
for (int i = 0; i < sql.length(); i++) {
char c = sql.charAt(i);
- if (c == '\'' || c == '\\')// record the count of char "'" and char "\"
- {
- signalCount++;
- } else if (c == cchar && signalCount % 2 == 0) {// check if the ? is really the parameter
- num++;
- if (num == paramLoc) {
- charIndex = i;
- break;
+ if(skip){
+ skip=false;
+ continue;
+ }
+ switch (c) {
+ case '\'':
+ apCount++;
+ break;
+ case '\\':
+ skip = true;
+ break;
+ case '?':
+ if ((apCount & 1) == 0) {
+ parts.add(sql.substring(off,i));
+ off=i+1;
}
+ break;
+ default:
+ break;
}
}
- return charIndex;
+ parts.add(sql.substring(off,sql.length()));
+ return parts;
}
-
-
/*
* (non-Javadoc)
*
diff --git jdbc/src/test/org/apache/hive/jdbc/TestHivePreparedStatement.java jdbc/src/test/org/apache/hive/jdbc/TestHivePreparedStatement.java
new file mode 100644
index 0000000..bc49aeb
--- /dev/null
+++ jdbc/src/test/org/apache/hive/jdbc/TestHivePreparedStatement.java
@@ -0,0 +1,138 @@
+package org.apache.hive.jdbc;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import java.sql.SQLException;
+
+import org.apache.hive.service.rpc.thrift.TCLIService.Iface;
+import org.apache.hive.service.rpc.thrift.TExecuteStatementReq;
+import org.apache.hive.service.rpc.thrift.TExecuteStatementResp;
+import org.apache.hive.service.rpc.thrift.TGetOperationStatusReq;
+import org.apache.hive.service.rpc.thrift.TGetOperationStatusResp;
+import org.apache.hive.service.rpc.thrift.TOperationHandle;
+import org.apache.hive.service.rpc.thrift.TOperationState;
+import org.apache.hive.service.rpc.thrift.TSessionHandle;
+import org.apache.hive.service.rpc.thrift.TStatus;
+import org.apache.hive.service.rpc.thrift.TStatusCode;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+public class TestHivePreparedStatement {
+
+ @Mock
+ private HiveConnection connection;
+ @Mock
+ private Iface client;
+ @Mock
+ private TSessionHandle sessHandle;
+ @Mock
+ TExecuteStatementResp tExecStatementResp;
+ @Mock
+ TGetOperationStatusResp tGetOperationStatusResp;
+ private TStatus tStatus_SUCCESS = new TStatus(TStatusCode.SUCCESS_STATUS);
+ @Mock
+ private TOperationHandle tOperationHandle;
+
+ @Before
+ public void before() throws Exception {
+ MockitoAnnotations.initMocks(this);
+ when(tExecStatementResp.getStatus()).thenReturn(tStatus_SUCCESS);
+ when(tExecStatementResp.getOperationHandle()).thenReturn(tOperationHandle);
+
+ when(tGetOperationStatusResp.getStatus()).thenReturn(tStatus_SUCCESS);
+ when(tGetOperationStatusResp.getOperationState()).thenReturn(TOperationState.FINISHED_STATE);
+ when(tGetOperationStatusResp.isSetOperationState()).thenReturn(true);
+ when(tGetOperationStatusResp.isSetOperationCompleted()).thenReturn(true);
+
+ when(client.GetOperationStatus(any(TGetOperationStatusReq.class))).thenReturn(tGetOperationStatusResp);
+ when(client.ExecuteStatement(any(TExecuteStatementReq.class))).thenReturn(tExecStatementResp);
+ }
+
+ @SuppressWarnings("resource")
+ @Test
+ public void testNonParameterized() throws Exception {
+ String sql = "select 1";
+ HivePreparedStatement ps = new HivePreparedStatement(connection, client, sessHandle, sql);
+ ps.execute();
+
+ ArgumentCaptor argument = ArgumentCaptor.forClass(TExecuteStatementReq.class);
+ verify(client).ExecuteStatement(argument.capture());
+ assertEquals("select 1", argument.getValue().getStatement());
+ }
+
+ @SuppressWarnings("resource")
+ @Test
+ public void unusedArgument() throws Exception {
+ String sql = "select 1";
+ HivePreparedStatement ps = new HivePreparedStatement(connection, client, sessHandle, sql);
+ ps.setString(1, "asd");
+ ps.execute();
+ }
+
+ @SuppressWarnings("resource")
+ @Test(expected=SQLException.class)
+ public void unsetArgument() throws Exception {
+ String sql = "select 1 from x where a=?";
+ HivePreparedStatement ps = new HivePreparedStatement(connection, client, sessHandle, sql);
+ ps.execute();
+ }
+
+ @SuppressWarnings("resource")
+ @Test
+ public void oneArgument() throws Exception {
+ String sql = "select 1 from x where a=?";
+ HivePreparedStatement ps = new HivePreparedStatement(connection, client, sessHandle, sql);
+ ps.setString(1, "asd");
+ ps.execute();
+
+ ArgumentCaptor argument = ArgumentCaptor.forClass(TExecuteStatementReq.class);
+ verify(client).ExecuteStatement(argument.capture());
+ assertEquals("select 1 from x where a='asd'", argument.getValue().getStatement());
+ }
+
+ @SuppressWarnings("resource")
+ @Test
+ public void escapingOfStringArgument() throws Exception {
+ String sql = "select 1 from x where a=?";
+ HivePreparedStatement ps = new HivePreparedStatement(connection, client, sessHandle, sql);
+ ps.setString(1, "a'\"d");
+ ps.execute();
+
+ ArgumentCaptor argument = ArgumentCaptor.forClass(TExecuteStatementReq.class);
+ verify(client).ExecuteStatement(argument.capture());
+ assertEquals("select 1 from x where a='a\\'\"d'", argument.getValue().getStatement());
+ }
+
+ @SuppressWarnings("resource")
+ @Test
+ public void pastingIntoQuery() throws Exception {
+ String sql = "select 1 from x where a='e' || ?";
+ HivePreparedStatement ps = new HivePreparedStatement(connection, client, sessHandle, sql);
+ ps.setString(1, "v");
+ ps.execute();
+
+ ArgumentCaptor argument = ArgumentCaptor.forClass(TExecuteStatementReq.class);
+ verify(client).ExecuteStatement(argument.capture());
+ assertEquals("select 1 from x where a='e' || 'v'", argument.getValue().getStatement());
+ }
+
+ // HIVE-13625
+ @SuppressWarnings("resource")
+ @Test
+ public void pastingIntoEscapedQuery() throws Exception {
+ String sql = "select 1 from x where a='\\044e' || ?";
+ HivePreparedStatement ps = new HivePreparedStatement(connection, client, sessHandle, sql);
+ ps.setString(1, "v");
+ ps.execute();
+
+ ArgumentCaptor argument = ArgumentCaptor.forClass(TExecuteStatementReq.class);
+ verify(client).ExecuteStatement(argument.capture());
+ assertEquals("select 1 from x where a='\\044e' || 'v'", argument.getValue().getStatement());
+ }
+}