diff --git jdbc/pom.xml jdbc/pom.xml index f87ab59..decf8b2 100644 --- jdbc/pom.xml +++ jdbc/pom.xml @@ -116,6 +116,11 @@ ${junit.version} test + + org.mockito + mockito-all + test + diff --git jdbc/src/java/org/apache/hive/jdbc/HivePreparedStatement.java jdbc/src/java/org/apache/hive/jdbc/HivePreparedStatement.java index 705a32a..69c1587 100644 --- jdbc/src/java/org/apache/hive/jdbc/HivePreparedStatement.java +++ jdbc/src/java/org/apache/hive/jdbc/HivePreparedStatement.java @@ -39,8 +39,10 @@ import java.sql.Timestamp; import java.sql.Types; import java.text.MessageFormat; +import java.util.ArrayList; import java.util.Calendar; import java.util.HashMap; +import java.util.List; import java.util.Scanner; import org.apache.hive.service.rpc.thrift.TCLIService; @@ -126,60 +128,67 @@ public int executeUpdate() throws SQLException { * @param sql * @param parameters * @return updated SQL string - */ - private String updateSql(final String sql, HashMap parameters) { - if (!sql.contains("?")) { - return sql; - } - - StringBuilder newSql = new StringBuilder(sql); - - int paramLoc = 1; - while (getCharIndexFromSqlByParamLocation(sql, '?', paramLoc) > 0) { - // check the user has set the needs parameters - if (parameters.containsKey(paramLoc)) { - int tt = getCharIndexFromSqlByParamLocation(newSql.toString(), '?', 1); - newSql.deleteCharAt(tt); - newSql.insert(tt, parameters.get(paramLoc)); + * @throws SQLException + */ + private String updateSql(final String sql, HashMap parameters) throws SQLException { + List parts=splitSqlStatement(sql); + + StringBuilder newSql = new StringBuilder(parts.get(0)); + for(int i=1;i The -1 will be return, if nothing found - * + * Splits the parametered sql statement at parameter boundaries. + * + * taking into account ' and \ escaping. + * + * output for: 'select 1 from ? where a = ?' + * ['select 1 from ',' where a = ',''] + * * @param sql - * @param cchar - * @param paramLoc * @return */ - private int getCharIndexFromSqlByParamLocation(final String sql, final char cchar, final int paramLoc) { - int signalCount = 0; - int charIndex = -1; - int num = 0; + private List splitSqlStatement(String sql) { + List parts=new ArrayList<>(); + int apCount=0; + int off=0; + boolean skip=false; + for (int i = 0; i < sql.length(); i++) { char c = sql.charAt(i); - if (c == '\'' || c == '\\')// record the count of char "'" and char "\" - { - signalCount++; - } else if (c == cchar && signalCount % 2 == 0) {// check if the ? is really the parameter - num++; - if (num == paramLoc) { - charIndex = i; - break; + if(skip){ + skip=false; + continue; + } + switch (c) { + case '\'': + apCount++; + break; + case '\\': + skip = true; + break; + case '?': + if ((apCount & 1) == 0) { + parts.add(sql.substring(off,i)); + off=i+1; } + break; + default: + break; } } - return charIndex; + parts.add(sql.substring(off,sql.length())); + return parts; } - - /* * (non-Javadoc) * diff --git jdbc/src/test/org/apache/hive/jdbc/TestHivePreparedStatement.java jdbc/src/test/org/apache/hive/jdbc/TestHivePreparedStatement.java new file mode 100644 index 0000000..bc49aeb --- /dev/null +++ jdbc/src/test/org/apache/hive/jdbc/TestHivePreparedStatement.java @@ -0,0 +1,138 @@ +package org.apache.hive.jdbc; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.sql.SQLException; + +import org.apache.hive.service.rpc.thrift.TCLIService.Iface; +import org.apache.hive.service.rpc.thrift.TExecuteStatementReq; +import org.apache.hive.service.rpc.thrift.TExecuteStatementResp; +import org.apache.hive.service.rpc.thrift.TGetOperationStatusReq; +import org.apache.hive.service.rpc.thrift.TGetOperationStatusResp; +import org.apache.hive.service.rpc.thrift.TOperationHandle; +import org.apache.hive.service.rpc.thrift.TOperationState; +import org.apache.hive.service.rpc.thrift.TSessionHandle; +import org.apache.hive.service.rpc.thrift.TStatus; +import org.apache.hive.service.rpc.thrift.TStatusCode; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +public class TestHivePreparedStatement { + + @Mock + private HiveConnection connection; + @Mock + private Iface client; + @Mock + private TSessionHandle sessHandle; + @Mock + TExecuteStatementResp tExecStatementResp; + @Mock + TGetOperationStatusResp tGetOperationStatusResp; + private TStatus tStatus_SUCCESS = new TStatus(TStatusCode.SUCCESS_STATUS); + @Mock + private TOperationHandle tOperationHandle; + + @Before + public void before() throws Exception { + MockitoAnnotations.initMocks(this); + when(tExecStatementResp.getStatus()).thenReturn(tStatus_SUCCESS); + when(tExecStatementResp.getOperationHandle()).thenReturn(tOperationHandle); + + when(tGetOperationStatusResp.getStatus()).thenReturn(tStatus_SUCCESS); + when(tGetOperationStatusResp.getOperationState()).thenReturn(TOperationState.FINISHED_STATE); + when(tGetOperationStatusResp.isSetOperationState()).thenReturn(true); + when(tGetOperationStatusResp.isSetOperationCompleted()).thenReturn(true); + + when(client.GetOperationStatus(any(TGetOperationStatusReq.class))).thenReturn(tGetOperationStatusResp); + when(client.ExecuteStatement(any(TExecuteStatementReq.class))).thenReturn(tExecStatementResp); + } + + @SuppressWarnings("resource") + @Test + public void testNonParameterized() throws Exception { + String sql = "select 1"; + HivePreparedStatement ps = new HivePreparedStatement(connection, client, sessHandle, sql); + ps.execute(); + + ArgumentCaptor argument = ArgumentCaptor.forClass(TExecuteStatementReq.class); + verify(client).ExecuteStatement(argument.capture()); + assertEquals("select 1", argument.getValue().getStatement()); + } + + @SuppressWarnings("resource") + @Test + public void unusedArgument() throws Exception { + String sql = "select 1"; + HivePreparedStatement ps = new HivePreparedStatement(connection, client, sessHandle, sql); + ps.setString(1, "asd"); + ps.execute(); + } + + @SuppressWarnings("resource") + @Test(expected=SQLException.class) + public void unsetArgument() throws Exception { + String sql = "select 1 from x where a=?"; + HivePreparedStatement ps = new HivePreparedStatement(connection, client, sessHandle, sql); + ps.execute(); + } + + @SuppressWarnings("resource") + @Test + public void oneArgument() throws Exception { + String sql = "select 1 from x where a=?"; + HivePreparedStatement ps = new HivePreparedStatement(connection, client, sessHandle, sql); + ps.setString(1, "asd"); + ps.execute(); + + ArgumentCaptor argument = ArgumentCaptor.forClass(TExecuteStatementReq.class); + verify(client).ExecuteStatement(argument.capture()); + assertEquals("select 1 from x where a='asd'", argument.getValue().getStatement()); + } + + @SuppressWarnings("resource") + @Test + public void escapingOfStringArgument() throws Exception { + String sql = "select 1 from x where a=?"; + HivePreparedStatement ps = new HivePreparedStatement(connection, client, sessHandle, sql); + ps.setString(1, "a'\"d"); + ps.execute(); + + ArgumentCaptor argument = ArgumentCaptor.forClass(TExecuteStatementReq.class); + verify(client).ExecuteStatement(argument.capture()); + assertEquals("select 1 from x where a='a\\'\"d'", argument.getValue().getStatement()); + } + + @SuppressWarnings("resource") + @Test + public void pastingIntoQuery() throws Exception { + String sql = "select 1 from x where a='e' || ?"; + HivePreparedStatement ps = new HivePreparedStatement(connection, client, sessHandle, sql); + ps.setString(1, "v"); + ps.execute(); + + ArgumentCaptor argument = ArgumentCaptor.forClass(TExecuteStatementReq.class); + verify(client).ExecuteStatement(argument.capture()); + assertEquals("select 1 from x where a='e' || 'v'", argument.getValue().getStatement()); + } + + // HIVE-13625 + @SuppressWarnings("resource") + @Test + public void pastingIntoEscapedQuery() throws Exception { + String sql = "select 1 from x where a='\\044e' || ?"; + HivePreparedStatement ps = new HivePreparedStatement(connection, client, sessHandle, sql); + ps.setString(1, "v"); + ps.execute(); + + ArgumentCaptor argument = ArgumentCaptor.forClass(TExecuteStatementReq.class); + verify(client).ExecuteStatement(argument.capture()); + assertEquals("select 1 from x where a='\\044e' || 'v'", argument.getValue().getStatement()); + } +}