diff --git a/itests/hive-unit/src/test/java/org/apache/hive/jdbc/TestJdbcDriver2.java b/itests/hive-unit/src/test/java/org/apache/hive/jdbc/TestJdbcDriver2.java index f6f64ee..4e8c5bf 100644 --- a/itests/hive-unit/src/test/java/org/apache/hive/jdbc/TestJdbcDriver2.java +++ b/itests/hive-unit/src/test/java/org/apache/hive/jdbc/TestJdbcDriver2.java @@ -46,6 +46,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.ByteArrayInputStream; import java.io.InputStream; import java.lang.Exception; import java.lang.Object; @@ -492,6 +493,25 @@ public void testPrepareStatement() { expectedException); } + @Test + public void testPrepareStatementWithSetBinaryStream() throws SQLException { + PreparedStatement stmt = con.prepareStatement("select under_col from " + tableName + " where value=?"); + stmt.setBinaryStream(1, new ByteArrayInputStream("'val_238' or under_col <> 0".getBytes())); + ResultSet res = stmt.executeQuery(); + assertFalse(res.next()); + } + + @Test + public void testPrepareStatementWithSetString() throws SQLException { + PreparedStatement stmt = con.prepareStatement("select under_col from " + tableName + " where value=?"); + stmt.setString(1, "val_238\\' or under_col <> 0 --"); + ResultSet res = stmt.executeQuery(); + assertFalse(res.next()); + stmt.setString(1, "anyStringHere\\' or 1=1 --"); + res = stmt.executeQuery(); + assertFalse(res.next()); + } + private PreparedStatement createPreapredStatementUsingSetObject(String sql) throws SQLException { PreparedStatement ps = con.prepareStatement(sql); diff --git a/jdbc/src/java/org/apache/hive/jdbc/HivePreparedStatement.java b/jdbc/src/java/org/apache/hive/jdbc/HivePreparedStatement.java index 4bb7398..77a1797 100644 --- a/jdbc/src/java/org/apache/hive/jdbc/HivePreparedStatement.java +++ b/jdbc/src/java/org/apache/hive/jdbc/HivePreparedStatement.java @@ -276,7 +276,7 @@ public void setBigDecimal(int parameterIndex, BigDecimal x) throws SQLException public void setBinaryStream(int parameterIndex, InputStream x) throws SQLException { String str = new Scanner(x, "UTF-8").useDelimiter("\\A").next(); - this.parameters.put(parameterIndex, str); + setString(parameterIndex, str); } /* @@ -696,6 +696,27 @@ public void setShort(int parameterIndex, short x) throws SQLException { this.parameters.put(parameterIndex,""+x); } + private String replaceBackSlashSingleQuote(String x) { + // scrutinize escape pair, specifically, replace \' to ' + StringBuffer newX = new StringBuffer(); + for (int i = 0; i < x.length(); i++) { + char c = x.charAt(i); + if (c == '\\' && i < x.length()-1) { + char c1 = x.charAt(i+1); + if (c1 == '\'') { + newX.append(c1); + } else { + newX.append(c); + newX.append(c1); + } + i++; + } else { + newX.append(c); + } + } + return newX.toString(); + } + /* * (non-Javadoc) * @@ -703,8 +724,9 @@ public void setShort(int parameterIndex, short x) throws SQLException { */ public void setString(int parameterIndex, String x) throws SQLException { - x=x.replace("'", "\\'"); - this.parameters.put(parameterIndex,"'"+x+"'"); + x = replaceBackSlashSingleQuote(x); + x=x.replace("'", "\\'"); + this.parameters.put(parameterIndex, "'"+x+"'"); } /*