Index: jdbc/src/java/org/apache/hadoop/hive/jdbc/HivePreparedStatement.java
===================================================================
--- jdbc/src/java/org/apache/hadoop/hive/jdbc/HivePreparedStatement.java	(revision 1101773)
+++ jdbc/src/java/org/apache/hadoop/hive/jdbc/HivePreparedStatement.java	(working copy)
@@ -40,6 +40,7 @@
 import java.sql.Time;
 import java.sql.Timestamp;
 import java.util.Calendar;
+import java.util.HashMap;
 
 import org.apache.hadoop.hive.service.HiveInterface;
 import org.apache.hadoop.hive.service.HiveServerException;
@@ -49,9 +50,14 @@
  *
  */
 public class HivePreparedStatement implements PreparedStatement {
-  private String sql;
+  private final String sql;
   private HiveInterface client;
   /**
+   * SQL-literal values set via the setXXX methods, keyed by 1-based parameter index.
+   */
+  private final HashMap<Integer, String> parameters = new HashMap<Integer, String>();
+
+  /**
    * We need to keep a reference to the result set to support the following:
    * <code>
    * statement.execute(String sql);
@@ -62,12 +68,12 @@
   /**
    * The maximum number of rows this statement should return (0 => all rows).
    */
-  private final int maxRows = 0;
+  private int maxRows = 0;
 
   /**
    * Add SQLWarnings to the warningChain if needed.
    */
-  private final SQLWarning warningChain = null;
+  private SQLWarning warningChain = null;
 
   /**
    * Keep state so we can fail certain calls made after close().
@@ -75,6 +81,11 @@
   private boolean isClosed = false;
 
   /**
+   * Update count returned by executeUpdate() and getUpdateCount(); always 0 since no row count is obtained.
+   */
+  private final int updateCount = 0;
+
+  /**
    *
    */
   public HivePreparedStatement(HiveInterface client,
@@ -101,8 +112,7 @@
    */
 
   public void clearParameters() throws SQLException {
-    // TODO Auto-generated method stub
-    // throw new SQLException("Method not supported");
+    this.parameters.clear();
   }
 
   /**
@@ -137,8 +147,8 @@
    */
 
   public int executeUpdate() throws SQLException {
-    // TODO Auto-generated method stub
-    throw new SQLException("Method not supported");
+    executeImmediate(sql);
+    return updateCount;
   }
 
   /**
@@ -159,6 +169,9 @@
     try {
       clearWarnings();
       resultSet = null;
+      if (sql.contains("?")) {
+        sql = updateSql(sql, parameters);
+      }
       client.execute(sql);
     } catch (HiveServerException e) {
       throw new SQLException(e.getMessage(), e.getSQLState(), e.getErrorCode());
@@ -169,8 +182,51 @@
     }
     return resultSet;
   }
 
+  /**
+   * Substitutes the values set via the setXXX methods for the '?' placeholders in the SQL.
+   * <p>
+   * A '?' counts as a placeholder only outside single- or double-quoted string literals;
+   * inside a literal, a backslash escapes the character that follows it (e.g. \').
+   *
+   * @param sql the SQL text containing '?' placeholders
+   * @param parameters values keyed by 1-based placeholder index, already rendered as SQL
+   *          literals by the setXXX methods
+   * @return the SQL text with every placeholder replaced by its value
+   * @throws SQLException if a placeholder has no value set
+   */
+  private String updateSql(final String sql, HashMap<Integer, String> parameters)
+      throws SQLException {
+    StringBuilder newSql = new StringBuilder(sql.length());
+    char quote = 0; // quote char of the enclosing string literal, or 0 when outside one
+    int paramLoc = 0;
+    for (int i = 0; i < sql.length(); i++) {
+      char c = sql.charAt(i);
+      if (quote != 0) {
+        newSql.append(c);
+        if (c == '\\' && i + 1 < sql.length()) {
+          // copy the escaped character verbatim so an escaped quote does not end the literal
+          newSql.append(sql.charAt(++i));
+        } else if (c == quote) {
+          quote = 0;
+        }
+      } else if (c == '\'' || c == '"') {
+        quote = c;
+        newSql.append(c);
+      } else if (c == '?') {
+        paramLoc++;
+        String value = parameters.get(paramLoc);
+        if (value == null) {
+          throw new SQLException("Parameter #" + paramLoc + " is unset");
+        }
+        newSql.append(value);
+      } else {
+        newSql.append(c);
+      }
+    }
+    return newSql.toString();
+  }
 
   /*
    * (non-Javadoc)
    *
@@ -326,8 +382,7 @@
    */
 
   public void setBoolean(int parameterIndex, boolean x) throws SQLException {
-    // TODO Auto-generated method stub
-    throw new SQLException("Method not supported");
+    this.parameters.put(parameterIndex, String.valueOf(x));
   }
 
   /*
@@ -337,8 +392,7 @@
    */
 
   public void setByte(int parameterIndex, byte x) throws SQLException {
-    // TODO Auto-generated method stub
-    throw new SQLException("Method not supported");
+    this.parameters.put(parameterIndex, String.valueOf(x));
   }
 
   /*
@@ -452,8 +506,7 @@
    */
 
   public void setDouble(int parameterIndex, double x) throws SQLException {
-    // TODO Auto-generated method stub
-    throw new SQLException("Method not supported");
+    this.parameters.put(parameterIndex, String.valueOf(x));
   }
 
   /*
@@ -463,8 +516,7 @@
    */
 
   public void setFloat(int parameterIndex, float x) throws SQLException {
-    // TODO Auto-generated method stub
-    throw new SQLException("Method not supported");
+    this.parameters.put(parameterIndex, String.valueOf(x));
   }
 
   /*
@@ -474,8 +526,7 @@
    */
 
   public void setInt(int parameterIndex, int x) throws SQLException {
-    // TODO Auto-generated method stub
-    throw new SQLException("Method not supported");
+    this.parameters.put(parameterIndex, String.valueOf(x));
   }
 
   /*
@@ -485,8 +536,7 @@
    */
 
   public void setLong(int parameterIndex, long x) throws SQLException {
-    // TODO Auto-generated method stub
-    throw new SQLException("Method not supported");
+    this.parameters.put(parameterIndex, String.valueOf(x));
   }
 
   /*
@@ -654,8 +704,7 @@
    */
 
   public void setShort(int parameterIndex, short x) throws SQLException {
-    // TODO Auto-generated method stub
-    throw new SQLException("Method not supported");
+    this.parameters.put(parameterIndex, String.valueOf(x));
   }
 
   /*
@@ -665,8 +714,13 @@
    */
 
   public void setString(int parameterIndex, String x) throws SQLException {
-    // TODO Auto-generated method stub
-    throw new SQLException("Method not supported");
+    if (x == null) {
+      this.parameters.put(parameterIndex, "NULL");
+    } else {
+      // Escape backslashes before quotes so the escapes added for quotes are not doubled.
+      String escaped = x.replace("\\", "\\\\").replace("'", "\\'");
+      this.parameters.put(parameterIndex, "'" + escaped + "'");
+    }
   }
 
   /*
@@ -780,8 +834,7 @@
    */
 
   public void clearWarnings() throws SQLException {
-    // TODO Auto-generated method stub
-    throw new SQLException("Method not supported");
+    warningChain = null;
   }
 
   /**
@@ -971,8 +1024,7 @@
    */
 
   public int getMaxRows() throws SQLException {
-    // TODO Auto-generated method stub
-    throw new SQLException("Method not supported");
+    return this.maxRows;
   }
 
   /*
@@ -1015,8 +1067,7 @@
    */
 
   public ResultSet getResultSet() throws SQLException {
-    // TODO Auto-generated method stub
-    throw new SQLException("Method not supported");
+    return this.resultSet;
   }
 
   /*
@@ -1059,8 +1110,7 @@
    */
 
   public int getUpdateCount() throws SQLException {
-    // TODO Auto-generated method stub
-    throw new SQLException("Method not supported");
+    return updateCount;
   }
 
   /*
@@ -1070,8 +1120,7 @@
    */
 
   public SQLWarning getWarnings() throws SQLException {
-    // TODO Auto-generated method stub
-    throw new SQLException("Method not supported");
+    return warningChain;
   }
 
   /*
@@ -1081,8 +1130,7 @@
    */
 
   public boolean isClosed() throws SQLException {
-    // TODO Auto-generated method stub
-    throw new SQLException("Method not supported");
+    return isClosed;
   }
 
   /*
@@ -1158,8 +1206,10 @@
    */
 
   public void setMaxRows(int max) throws SQLException {
-    // TODO Auto-generated method stub
-    // throw new SQLException("Method not supported");
+    if (max < 0) {
+      throw new SQLException("max must be >= 0");
+    }
+    this.maxRows = max;
   }
 
   /*
Index: jdbc/src/java/org/apache/hadoop/hive/jdbc/HiveStatement.java
===================================================================
--- jdbc/src/java/org/apache/hadoop/hive/jdbc/HiveStatement.java	(revision 1101773)
+++ jdbc/src/java/org/apache/hadoop/hive/jdbc/HiveStatement.java	(working copy)
@@ -55,6 +55,11 @@
    * Keep state so we can fail certain calls made after close().
    */
   private boolean isClosed = false;
+
+  /**
+   * Update count returned by executeUpdate() and getUpdateCount(); always 0 since no row count is obtained.
+   */
+  private final int updateCount = 0;
 
   /**
    *
@@ -204,7 +209,7 @@
     } catch (Exception ex) {
       throw new SQLException(ex.toString());
     }
-    throw new SQLException("Method not supported");
+    return updateCount;
   }
 
   /*
@@ -374,7 +379,7 @@
    */
 
   public int getUpdateCount() throws SQLException {
-    return 0;
+    return updateCount;
   }
 
   /*
Index: jdbc/src/test/org/apache/hadoop/hive/jdbc/TestJdbcDriver.java
===================================================================
--- jdbc/src/test/org/apache/hadoop/hive/jdbc/TestJdbcDriver.java	(revision 1101774)
+++ jdbc/src/test/org/apache/hadoop/hive/jdbc/TestJdbcDriver.java	(working copy)
@@ -18,14 +18,11 @@
 
 package org.apache.hadoop.hive.jdbc;
 
-import junit.framework.TestCase;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.hive.conf.HiveConf;
-
 import java.sql.Connection;
 import java.sql.DatabaseMetaData;
 import java.sql.DriverManager;
 import java.sql.DriverPropertyInfo;
+import java.sql.PreparedStatement;
 import java.sql.ResultSet;
 import java.sql.ResultSetMetaData;
 import java.sql.SQLException;
@@ -37,6 +34,11 @@
 import java.util.Map;
 import java.util.Set;
 
+import junit.framework.TestCase;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.hive.conf.HiveConf;
+
 /**
  * TestJdbcDriver.
  *
@@ -68,6 +70,7 @@
         .getProperty("test.service.standalone.server"));
   }
 
+  @Override
   protected void setUp() throws Exception {
     super.setUp();
     Class.forName(driverName);
@@ -160,6 +163,7 @@
     assertFalse(res.next());
   }
 
+  @Override
   protected void tearDown() throws Exception {
     super.tearDown();
 
@@ -188,6 +192,87 @@
         expectedException);
   }
 
+  public void testPrepareStatement() throws Exception {
+
+    String sql = "from (select count(1) from "
+        + tableName
+        + " where 'not?param?not?param' <> 'not_param??not_param' and ?=? "
+        + " and 1=? and 2=? and 3.0=? and 4.0=? and 'test\\'string\"'=? and 5=? and ?=? "
+        + " ) t select '2011-03-25' ddate,'China',true bv, 10 num limit 10";
+
+    // Every placeholder set: the query must run and return the constant row.
+    PreparedStatement ps = con.prepareStatement(sql);
+    ps.setBoolean(1, true);
+    ps.setBoolean(2, true);
+    ps.setShort(3, Short.valueOf("1"));
+    ps.setInt(4, 2);
+    ps.setFloat(5, 3f);
+    ps.setDouble(6, Double.valueOf(4));
+    ps.setString(7, "test'string\"");
+    ps.setLong(8, 5L);
+    ps.setByte(9, (byte) 1);
+    ps.setByte(10, (byte) 1);
+    ps.setMaxRows(2);
+
+    ResultSet res = ps.executeQuery();
+    assertNotNull(res);
+    while (res.next()) {
+      assertEquals("2011-03-25", res.getString("ddate"));
+      assertEquals("10", res.getString("num"));
+      assertEquals((byte) 10, res.getByte("num"));
+      assertEquals("2011-03-25", res.getDate("ddate").toString());
+      assertEquals(Double.valueOf(10).doubleValue(), res.getDouble("num"), 0.1);
+      assertEquals(10, res.getInt("num"));
+      assertEquals(Short.valueOf("10").shortValue(), res.getShort("num"));
+      assertEquals(10L, res.getLong("num"));
+      assertEquals(true, res.getBoolean("bv"));
+      assertNotNull(res.getObject("ddate"));
+      assertNotNull(res.getObject("num"));
+    }
+    res.close();
+    ps.close();
+
+    // No parameter set: execution must fail on the first placeholder.
+    Exception expectedException = null;
+    try {
+      con.prepareStatement(sql).executeQuery();
+    } catch (Exception e) {
+      expectedException = e;
+    }
+    assertNotNull("Executing with no parameters set should throw an exception",
+        expectedException);
+    assertTrue(expectedException.getMessage().contains("Parameter #1 is unset"));
+
+    // Only some parameters set: execution must fail on the first unset placeholder.
+    expectedException = null;
+    try {
+      ps = con.prepareStatement(sql);
+      ps.setBoolean(1, true);
+      ps.setBoolean(2, true);
+      ps.executeQuery();
+    } catch (Exception e) {
+      expectedException = e;
+    }
+    assertNotNull("Executing with only some parameters set should throw an exception",
+        expectedException);
+    assertTrue(expectedException.getMessage().contains("Parameter #3 is unset"));
+
+    // Wrong parameter type (and remaining parameters unset): execution must fail.
+    expectedException = null;
+    try {
+      ps = con.prepareStatement(sql);
+      ps.setString(1, "wrong");
+      ResultSet wrongRes = ps.executeQuery();
+      if (!wrongRes.next()) {
+        throw new Exception("there must be a empty result set");
+      }
+    } catch (Exception e) {
+      expectedException = e;
+    }
+    assertNotNull("Executing with an invalid parameter set should throw an exception",
+        expectedException);
+  }
+
   public final void testSelectAll() throws Exception {
     doTestSelectAll(tableName, -1); // tests not setting maxRows (return all)
     doTestSelectAll(tableName, 0); // tests setting maxRows to 0 (return all)