diff --git a/jdbc-handler/src/main/java/org/apache/hive/storage/jdbc/DBRecordWritable.java b/jdbc-handler/src/main/java/org/apache/hive/storage/jdbc/DBRecordWritable.java
new file mode 100644
index 0000000000..5af921d22c
--- /dev/null
+++ b/jdbc-handler/src/main/java/org/apache/hive/storage/jdbc/DBRecordWritable.java
@@ -0,0 +1,71 @@
+/*
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hive.storage.jdbc;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.Arrays;
+import org.apache.hadoop.io.Writable;
+
+public class DBRecordWritable implements Writable,
+    org.apache.hadoop.mapreduce.lib.db.DBWritable {
+
+  private Object[] columnValues;
+
+  public DBRecordWritable() {
+  }
+
+  public DBRecordWritable(int numColumns) {
+    this.columnValues = new Object[numColumns];
+  }
+
+  public void clear() {
+    Arrays.fill(columnValues, null);
+  }
+
+  public void set(int i, Object columnObject) {
+    columnValues[i] = columnObject;
+  }
+
+  @Override
+  public void readFields(ResultSet rs) throws SQLException {
+    // do nothing
+  }
+
+  @Override
+  public void write(PreparedStatement statement) throws SQLException {
+    if (columnValues == null) {
+      throw new SQLException("No data available to be written");
+    }
+    for (int i = 0; i < columnValues.length; i++) {
+      statement.setObject(i + 1, columnValues[i]);
+    }
+  }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    // do nothing
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    // do nothing
+  }
+
+}
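Note (annotation, not part of the patch): DBRecordWritable is just a reusable row buffer. serialize() fills it once per row, and write(PreparedStatement) binds column i to JDBC parameter i + 1. Below is a minimal standalone sketch of that binding; the class name, the in-memory Derby URL and table T are purely illustrative, and in the real write path the PreparedStatement comes from GenericJdbcDatabaseAccessor.getRecordWriter() rather than being created by hand.

    import java.sql.Connection;
    import java.sql.DriverManager;
    import java.sql.PreparedStatement;

    import org.apache.hive.storage.jdbc.DBRecordWritable;

    public class DBRecordWritableSketch {
      public static void main(String[] args) throws Exception {
        // Illustrative in-memory Derby target table.
        try (Connection conn =
            DriverManager.getConnection("jdbc:derby:memory:sketch;create=true")) {
          conn.createStatement().execute("CREATE TABLE T (KKEY INT, NAME VARCHAR(20))");
          try (PreparedStatement ps = conn.prepareStatement("INSERT INTO T VALUES (?, ?)")) {
            DBRecordWritable record = new DBRecordWritable(2);
            record.set(0, 100);        // column 0 -> JDBC parameter 1
            record.set(1, "example");  // column 1 -> JDBC parameter 2
            record.write(ps);          // binds each value via ps.setObject(i + 1, columnValues[i])
            ps.executeUpdate();
          }
        }
      }
    }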
diff --git a/jdbc-handler/src/main/java/org/apache/hive/storage/jdbc/JdbcOutputFormat.java b/jdbc-handler/src/main/java/org/apache/hive/storage/jdbc/JdbcOutputFormat.java
index 26fb3cdd09..fe247503c2 100644
--- a/jdbc-handler/src/main/java/org/apache/hive/storage/jdbc/JdbcOutputFormat.java
+++ b/jdbc-handler/src/main/java/org/apache/hive/storage/jdbc/JdbcOutputFormat.java
@@ -18,12 +18,15 @@
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.hive.ql.exec.FileSinkOperator.RecordWriter;
 import org.apache.hadoop.hive.ql.io.HiveOutputFormat;
+import org.apache.hadoop.hive.shims.ShimLoader;
 import org.apache.hadoop.io.MapWritable;
 import org.apache.hadoop.io.NullWritable;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.mapred.JobConf;
 import org.apache.hadoop.mapred.OutputFormat;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
 import org.apache.hadoop.util.Progressable;
+import org.apache.hive.storage.jdbc.dao.GenericJdbcDatabaseAccessor;
 
 import java.io.IOException;
 import java.util.Properties;
@@ -31,6 +34,9 @@
 public class JdbcOutputFormat implements OutputFormat,
     HiveOutputFormat {
 
+  private org.apache.hadoop.mapreduce.RecordWriter recordWriter;
+  private TaskAttemptContext taskContext;
+
   /**
    * {@inheritDoc}
    */
@@ -41,7 +47,10 @@ public RecordWriter getHiveRecordWriter(JobConf jc, boolean isCompressed,
       Properties tableProperties, Progressable progress) throws IOException {
-    throw new UnsupportedOperationException("Write operations are not allowed.");
+    taskContext = ShimLoader.getHadoopShims().newTaskAttemptContext(jc, null);
+    recordWriter = new GenericJdbcDatabaseAccessor().getRecordWriter(taskContext);
+    // Wrapping DBRecordWriter in JdbcRecordWriter
+    return new JdbcRecordWriter((org.apache.hadoop.mapreduce.lib.db.DBOutputFormat.DBRecordWriter) recordWriter);
   }
diff --git a/jdbc-handler/src/main/java/org/apache/hive/storage/jdbc/JdbcRecordWriter.java b/jdbc-handler/src/main/java/org/apache/hive/storage/jdbc/JdbcRecordWriter.java
new file mode 100644
index 0000000000..15d6067b39
--- /dev/null
+++ b/jdbc-handler/src/main/java/org/apache/hive/storage/jdbc/JdbcRecordWriter.java
@@ -0,0 +1,64 @@
+/*
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hive.storage.jdbc;
+
+import java.io.IOException;
+import java.sql.Connection;
+import java.sql.SQLException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.ql.exec.FileSinkOperator.RecordWriter;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.lib.db.DBOutputFormat.DBRecordWriter;
+import org.apache.hadoop.util.StringUtils;
+
+public class JdbcRecordWriter implements RecordWriter {
+  private static final Log LOG = LogFactory.getLog(JdbcRecordWriter.class);
+
+  @SuppressWarnings("rawtypes")
+  private final DBRecordWriter dbRecordWriter;
+
+  @SuppressWarnings("rawtypes")
+  public JdbcRecordWriter(DBRecordWriter writer) {
+    this.dbRecordWriter = writer;
+  }
+
+  @SuppressWarnings("unchecked")
+  @Override
+  public void write(Writable w) throws IOException {
+    dbRecordWriter.write((DBRecordWritable) w, null);
+  }
+
+  @Override
+  public void close(boolean abort) throws IOException {
+    if (abort) {
+      Connection conn = dbRecordWriter.getConnection();
+      try {
+        conn.rollback();
+      } catch (SQLException ex) {
+        LOG.warn(StringUtils.stringifyException(ex));
+      } finally {
+        try {
+          conn.close();
+        } catch (SQLException ex) {
+          throw new IOException(ex.getMessage());
+        }
+      }
+    } else {
+      dbRecordWriter.close(null);
+    }
+  }
+
+}
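Note (annotation, not part of the patch): getHiveRecordWriter() now asks GenericJdbcDatabaseAccessor for a Hadoop DBOutputFormat.DBRecordWriter (which owns the connection and the prepared INSERT) and wraps it in JdbcRecordWriter so it satisfies Hive's FileSinkOperator.RecordWriter contract, including rollback when the task aborts. The sketch below wires the same pieces together by hand outside of Hive, just to show the flow; the class name, Derby URL and table T are illustrative, not part of the patch.

    import java.sql.Connection;
    import java.sql.DriverManager;
    import java.sql.PreparedStatement;

    import org.apache.hadoop.io.NullWritable;
    import org.apache.hadoop.mapreduce.lib.db.DBOutputFormat;
    import org.apache.hive.storage.jdbc.DBRecordWritable;
    import org.apache.hive.storage.jdbc.JdbcRecordWriter;

    public class JdbcWritePathSketch {
      public static void main(String[] args) throws Exception {
        // Illustrative in-memory Derby target; Hive reaches the equivalent state through
        // JdbcOutputFormat.getHiveRecordWriter() -> GenericJdbcDatabaseAccessor.getRecordWriter().
        Connection conn = DriverManager.getConnection("jdbc:derby:memory:sketch;create=true");
        conn.setAutoCommit(false);
        conn.createStatement().execute("CREATE TABLE T (KKEY INT)");
        PreparedStatement ps = conn.prepareStatement("INSERT INTO T VALUES (?)");

        // Hadoop's DBRecordWriter batches rows against the prepared INSERT ...
        DBOutputFormat<DBRecordWritable, NullWritable>.DBRecordWriter dbWriter =
            new DBOutputFormat<DBRecordWritable, NullWritable>().new DBRecordWriter(conn, ps);
        // ... and JdbcRecordWriter adapts it to Hive's FileSinkOperator.RecordWriter.
        JdbcRecordWriter hiveWriter = new JdbcRecordWriter(dbWriter);

        DBRecordWritable row = new DBRecordWritable(1);
        row.set(0, 100);
        hiveWriter.write(row);    // addBatch() on the prepared statement
        hiveWriter.close(false);  // executeBatch(), commit and close; close(true) would roll back
      }
    }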
diff --git a/jdbc-handler/src/main/java/org/apache/hive/storage/jdbc/JdbcSerDe.java b/jdbc-handler/src/main/java/org/apache/hive/storage/jdbc/JdbcSerDe.java
index add1a1919b..139f0115b1 100644
--- a/jdbc-handler/src/main/java/org/apache/hive/storage/jdbc/JdbcSerDe.java
+++ b/jdbc-handler/src/main/java/org/apache/hive/storage/jdbc/JdbcSerDe.java
@@ -27,6 +27,9 @@
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
@@ -57,9 +60,11 @@
   private static final Logger LOGGER = LoggerFactory.getLogger(JdbcSerDe.class);
 
   private String[] hiveColumnNames;
+  private int numColumns;
   private PrimitiveTypeInfo[] hiveColumnTypes;
   private ObjectInspector inspector;
   private List row;
+  private DBRecordWritable dbRecordWritable;
 
 
   /*
@@ -109,6 +114,9 @@ public void initialize(Configuration conf, Properties properties) throws SerDeEx
       throw new SerDeException("Received an empty Hive column type definition");
     }
 
+    numColumns = hiveColumnNames.length;
+    dbRecordWritable = new DBRecordWritable(numColumns);
+
     // Populate column types and inspector
     hiveColumnTypes = new PrimitiveTypeInfo[hiveColumnTypesList.size()];
     List fieldInspectors = new ArrayList<>(hiveColumnNames.length);
@@ -138,6 +146,41 @@ public void initialize(Configuration conf, Properties properties) throws SerDeEx
     }
   }
 
+  /*
+   * This method takes an object representing a row of data from Hive, and
+   * uses the ObjectInspector to get the data for each column and serialize.
+   */
+  @Override
+  public DBRecordWritable serialize(Object row, ObjectInspector inspector)
+      throws SerDeException {
+    LOGGER.trace("Serializing from SerDe");
+    final StructObjectInspector structObjectInspector = (StructObjectInspector) inspector;
+    final List fields = structObjectInspector
+        .getAllStructFieldRefs();
+    if (fields.size() != numColumns) {
+      throw new SerDeException(String.format(
+          "Required %d columns, received %d.", numColumns,
+          fields.size()));
+    }
+
+    dbRecordWritable.clear();
+    for (int i = 0; i < numColumns; i++) {
+      StructField structField = fields.get(i);
+      if (structField != null) {
+        Object field = structObjectInspector.getStructFieldData(row, structField);
+        ObjectInspector fieldObjectInspector = structField.getFieldObjectInspector();
+        if (fieldObjectInspector.getCategory() == Category.PRIMITIVE) {
+          PrimitiveObjectInspector primitiveObjectInspector = (PrimitiveObjectInspector) fieldObjectInspector;
+          dbRecordWritable.set(i, primitiveObjectInspector.getPrimitiveJavaObject(field));
+        } else {
+          throw new SerDeException("Unsupported type " + fieldObjectInspector);
+        }
+      }
+    }
+    return dbRecordWritable;
+  }
+
+
   @Override
   public Object deserialize(Writable blob) throws SerDeException {
     LOGGER.trace("Deserializing from SerDe");
@@ -254,12 +297,6 @@ public ObjectInspector getObjectInspector() throws SerDeException {
   }
 
 
-  @Override
-  public Writable serialize(Object obj, ObjectInspector objInspector) throws SerDeException {
-    throw new UnsupportedOperationException("Writes are not allowed");
-  }
-
-
   @Override
   public SerDeStats getSerDeStats() {
     return null;
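Note (annotation, not part of the patch): serialize() walks the struct row with the supplied ObjectInspector, converts each primitive field to its plain Java object, and stores it positionally in the reusable DBRecordWritable. The standalone sketch below replays that inspector walk on a hypothetical two-column row <kkey:int, name:string>; only the field-extraction pattern mirrors the patch, the class and column names are illustrative.

    import java.util.Arrays;
    import java.util.List;

    import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
    import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
    import org.apache.hadoop.hive.serde2.objectinspector.StructField;
    import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
    import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

    public class SerializeWalkSketch {
      public static void main(String[] args) {
        // Hypothetical schema <kkey:int, name:string>, mirroring the (row, inspector)
        // pair Hive hands to JdbcSerDe.serialize().
        List<ObjectInspector> columnInspectors = Arrays.asList(
            PrimitiveObjectInspectorFactory.javaIntObjectInspector,
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        StructObjectInspector soi = ObjectInspectorFactory.getStandardStructObjectInspector(
            Arrays.asList("kkey", "name"), columnInspectors);
        Object row = Arrays.asList((Object) 100, "example");

        // The same field-by-field walk serialize() performs before filling DBRecordWritable.
        List<? extends StructField> fields = soi.getAllStructFieldRefs();
        Object[] columnValues = new Object[fields.size()];
        for (int i = 0; i < fields.size(); i++) {
          Object fieldData = soi.getStructFieldData(row, fields.get(i));
          PrimitiveObjectInspector poi =
              (PrimitiveObjectInspector) fields.get(i).getFieldObjectInspector();
          columnValues[i] = poi.getPrimitiveJavaObject(fieldData);
        }
        System.out.println(Arrays.toString(columnValues)); // prints [100, example]
      }
    }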
diff --git a/jdbc-handler/src/main/java/org/apache/hive/storage/jdbc/dao/GenericJdbcDatabaseAccessor.java b/jdbc-handler/src/main/java/org/apache/hive/storage/jdbc/dao/GenericJdbcDatabaseAccessor.java
index c2e7473e40..59a363d2a8 100644
--- a/jdbc-handler/src/main/java/org/apache/hive/storage/jdbc/dao/GenericJdbcDatabaseAccessor.java
+++ b/jdbc-handler/src/main/java/org/apache/hive/storage/jdbc/dao/GenericJdbcDatabaseAccessor.java
@@ -21,7 +21,11 @@
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.hive.conf.Constants;
 import org.apache.hadoop.hive.ql.exec.Utilities;
+import org.apache.hadoop.hive.serde.serdeConstants;
+import org.apache.hadoop.mapreduce.RecordWriter;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
+import org.apache.hive.storage.jdbc.conf.DatabaseType;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -31,6 +35,7 @@
 
 import javax.sql.DataSource;
 
+import java.io.IOException;
 import java.sql.Connection;
 import java.sql.PreparedStatement;
 import java.sql.ResultSet;
@@ -172,6 +177,67 @@ public int getTotalNumberOfRecords(Configuration conf) throws HiveJdbcDatabaseAc
     }
   }
 
+  public RecordWriter getRecordWriter(TaskAttemptContext context)
+      throws IOException {
+    Configuration conf = context.getConfiguration();
+    Connection conn = null;
+    PreparedStatement ps = null;
+    String dbProductName = conf.get(JdbcStorageConfig.DATABASE_TYPE.getPropertyName()).toUpperCase();
+    String tableName = conf.get(JdbcStorageConfig.TABLE.getPropertyName());
+
+    if (tableName == null || tableName.isEmpty()) {
+      throw new IllegalArgumentException("Table name should be defined");
+    }
+
+    String[] columnNames = conf.get(serdeConstants.LIST_COLUMNS).split(",");
+
+    try {
+      initializeDatabaseConnection(conf);
+      conn = dbcpDataSource.getConnection();
+      ps = conn.prepareStatement(
+          constructQuery(tableName, columnNames, dbProductName));
+      return new org.apache.hadoop.mapreduce.lib.db.DBOutputFormat()
+          .new DBRecordWriter(conn, ps);
+    } catch (Exception e) {
+      cleanupResources(conn, ps, null);
+      throw new IOException(e.getMessage());
+    }
+  }
+
+  /**
+   * Constructs the query used as the prepared statement to insert data.
+   *
+   * @param table
+   *          the table to insert into
+   * @param columnNames
+   *          the columns to insert into.
+   * @param dbProductName
+   *          type of database
+   *
+   */
+  public String constructQuery(String table, String[] columnNames, String dbProductName) {
+    if (columnNames == null) {
+      throw new IllegalArgumentException("Column names may not be null");
+    }
+
+    StringBuilder query = new StringBuilder();
+    query.append("INSERT INTO ").append(table).append(" VALUES (");
+
+    for (int i = 0; i < columnNames.length; i++) {
+      query.append("?");
+      if (i != columnNames.length - 1) {
+        query.append(",");
+      }
+    }
+
+    if (!dbProductName.equals(DatabaseType.DERBY.toString()) && !dbProductName.equals(DatabaseType.ORACLE.toString())
+        && !dbProductName.equals(DatabaseType.DB2.toString())) {
+      query.append(");");
+    } else {
+      query.append(")");
+    }
+    return query.toString();
+  }
 
   /**
    * Uses generic JDBC escape functions to add a limit and offset clause to a query string
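Note (annotation, not part of the patch): constructQuery() emits a positional INSERT with one "?" per column and drops the trailing semicolon for Derby, Oracle and DB2, which reject it inside PreparedStatement text. A small sketch of the expected output; ext_simple_derby_table comes from the test below, while the class name and the "person" table/columns are illustrative.

    import org.apache.hive.storage.jdbc.dao.GenericJdbcDatabaseAccessor;

    public class ConstructQuerySketch {
      public static void main(String[] args) {
        GenericJdbcDatabaseAccessor accessor = new GenericJdbcDatabaseAccessor();

        // Derby branch: no trailing semicolon.
        System.out.println(accessor.constructQuery(
            "ext_simple_derby_table", new String[] {"kkey"}, "DERBY"));
        // -> INSERT INTO ext_simple_derby_table VALUES (?)

        // Generic branch: trailing semicolon is kept.
        System.out.println(accessor.constructQuery(
            "person", new String[] {"id", "name"}, "MYSQL"));
        // -> INSERT INTO person VALUES (?,?);
      }
    }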
diff --git a/ql/src/test/queries/clientpositive/jdbc_handler.q b/ql/src/test/queries/clientpositive/jdbc_handler.q
index d086735594..cb9f6e2692 100644
--- a/ql/src/test/queries/clientpositive/jdbc_handler.q
+++ b/ql/src/test/queries/clientpositive/jdbc_handler.q
@@ -40,6 +40,9 @@ explain select * from ext_simple_derby_table where 100 < ext_simple_derby_table.
 
 select * from ext_simple_derby_table where 100 < ext_simple_derby_table.kkey;
 
+insert into ext_simple_derby_table values(100);
+select * from ext_simple_derby_table;
+
 CREATE EXTERNAL TABLE tables
 (
 id bigint,
diff --git a/ql/src/test/results/clientpositive/llap/jdbc_handler.q.out b/ql/src/test/results/clientpositive/llap/jdbc_handler.q.out
index 3c043f544f..9bb591478e 100644
--- a/ql/src/test/results/clientpositive/llap/jdbc_handler.q.out
+++ b/ql/src/test/results/clientpositive/llap/jdbc_handler.q.out
@@ -118,6 +118,25 @@
 POSTHOOK: type: QUERY
 POSTHOOK: Input: default@ext_simple_derby_table
 #### A masked pattern was here ####
 200
+PREHOOK: query: insert into ext_simple_derby_table values(100)
+PREHOOK: type: QUERY
+PREHOOK: Input: _dummy_database@_dummy_table
+PREHOOK: Output: default@ext_simple_derby_table
+POSTHOOK: query: insert into ext_simple_derby_table values(100)
+POSTHOOK: type: QUERY
+POSTHOOK: Input: _dummy_database@_dummy_table
+POSTHOOK: Output: default@ext_simple_derby_table
+PREHOOK: query: select * from ext_simple_derby_table
+PREHOOK: type: QUERY
+PREHOOK: Input: default@ext_simple_derby_table
+#### A masked pattern was here ####
+POSTHOOK: query: select * from ext_simple_derby_table
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@ext_simple_derby_table
+#### A masked pattern was here ####
+20
+200
+100
 PREHOOK: query: CREATE EXTERNAL TABLE tables
 (
 id bigint,