diff --git common/src/java/org/apache/hadoop/hive/conf/HiveConf.java common/src/java/org/apache/hadoop/hive/conf/HiveConf.java index 72336abfb0..b97b01daab 100644 --- common/src/java/org/apache/hadoop/hive/conf/HiveConf.java +++ common/src/java/org/apache/hadoop/hive/conf/HiveConf.java @@ -397,6 +397,7 @@ private static void populateLlapDaemonVarsSet(Set llapDaemonVarsSetLocal llapDaemonVarsSetLocal.add(ConfVars.LLAP_VALIDATE_ACLS.varname); llapDaemonVarsSetLocal.add(ConfVars.LLAP_DAEMON_LOGGER.varname); llapDaemonVarsSetLocal.add(ConfVars.LLAP_DAEMON_AM_USE_FQDN.varname); + llapDaemonVarsSetLocal.add(ConfVars.LLAP_OUTPUT_FORMAT_ARROW.varname); } /** @@ -2623,6 +2624,11 @@ private static void populateLlapDaemonVarsSet(Set llapDaemonVarsSetLocal "Set to true to ensure that each SQL Merge statement ensures that for each row in the target\n" + "table there is at most 1 matching row in the source table per SQL Specification."), + // For Arrow SerDe + HIVE_ARROW_ROOT_ALLOCATOR_LIMIT("hive.arrow.root.allocator.limit", Long.MAX_VALUE, + "Arrow root allocator memory size limitation in bytes."), + HIVE_ARROW_BATCH_SIZE("hive.arrow.batch.size", 1000, "The number of rows sent in one Arrow batch."), + // For Druid storage handler HIVE_DRUID_INDEXING_GRANULARITY("hive.druid.indexer.segments.granularity", "DAY", new PatternSet("YEAR", "MONTH", "WEEK", "DAY", "HOUR", "MINUTE", "SECOND"), @@ -4151,6 +4157,8 @@ private static void populateLlapDaemonVarsSet(Set llapDaemonVarsSetLocal Constants.LLAP_LOGGER_NAME_RFA, Constants.LLAP_LOGGER_NAME_CONSOLE), "logger used for llap-daemons."), + LLAP_OUTPUT_FORMAT_ARROW("hive.llap.output.format.arrow", true, + "Whether LLapOutputFormatService should output arrow batches"), HIVE_TRIGGER_VALIDATION_INTERVAL("hive.trigger.validation.interval", "500ms", new TimeValidator(TimeUnit.MILLISECONDS), diff --git itests/hive-jmh/pom.xml itests/hive-jmh/pom.xml index 5eb30267dc..e045ace6f6 100644 --- itests/hive-jmh/pom.xml +++ itests/hive-jmh/pom.xml @@ -66,7 +66,7 @@ org.apache.hive hive-storage-api - 2.7.0-SNAPSHOT + ${storage-api.version} org.apache.hive diff --git itests/hive-unit/src/test/java/org/apache/hive/jdbc/AbstractJdbcTriggersTest.java itests/hive-unit/src/test/java/org/apache/hive/jdbc/AbstractJdbcTriggersTest.java index 17e44bb37f..7d5172b421 100644 --- itests/hive-unit/src/test/java/org/apache/hive/jdbc/AbstractJdbcTriggersTest.java +++ itests/hive-unit/src/test/java/org/apache/hive/jdbc/AbstractJdbcTriggersTest.java @@ -90,7 +90,7 @@ public static void beforeTest() throws Exception { @Before public void setUp() throws Exception { - hs2Conn = TestJdbcWithMiniLlap.getConnection(miniHS2.getJdbcURL(), System.getProperty("user.name"), "bar"); + hs2Conn = BaseJdbcWithMiniLlap.getConnection(miniHS2.getJdbcURL(), System.getProperty("user.name"), "bar"); } @After @@ -124,7 +124,7 @@ void runQueryWithTrigger(final String query, final List setCmds, throws Exception { Connection con = hs2Conn; - TestJdbcWithMiniLlap.createTestTable(con, null, tableName, kvDataFilePath.toString()); + BaseJdbcWithMiniLlap.createTestTable(con, null, tableName, kvDataFilePath.toString()); createSleepUDF(); final ByteArrayOutputStream baos = new ByteArrayOutputStream(); diff --git itests/hive-unit/src/test/java/org/apache/hive/jdbc/TestJdbcWithMiniLlap.java itests/hive-unit/src/test/java/org/apache/hive/jdbc/BaseJdbcWithMiniLlap.java similarity index 92% rename from itests/hive-unit/src/test/java/org/apache/hive/jdbc/TestJdbcWithMiniLlap.java rename to 
itests/hive-unit/src/test/java/org/apache/hive/jdbc/BaseJdbcWithMiniLlap.java index 68a8e21307..7a891ef360 100644 --- itests/hive-unit/src/test/java/org/apache/hive/jdbc/TestJdbcWithMiniLlap.java +++ itests/hive-unit/src/test/java/org/apache/hive/jdbc/BaseJdbcWithMiniLlap.java @@ -77,7 +77,6 @@ import org.apache.hive.jdbc.miniHS2.MiniHS2; import org.apache.hive.jdbc.miniHS2.MiniHS2.MiniClusterType; import org.apache.hadoop.hive.llap.LlapBaseInputFormat; -import org.apache.hadoop.hive.llap.LlapRowInputFormat; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; @@ -93,20 +92,25 @@ import org.junit.AfterClass; import org.junit.Assert; import org.junit.Before; -import org.junit.BeforeClass; import org.junit.Test; +import org.apache.hadoop.mapred.InputFormat; -public class TestJdbcWithMiniLlap { +/** + * Specialize this base class for different serde's/formats + * {@link #beforeTest(boolean) beforeTest} should be called + * by sub-classes in a {@link org.junit.BeforeClass} initializer + */ +public abstract class BaseJdbcWithMiniLlap { private static MiniHS2 miniHS2 = null; private static String dataFileDir; private static Path kvDataFilePath; private static Path dataTypesFilePath; private static HiveConf conf = null; - private Connection hs2Conn = null; + private static Connection hs2Conn = null; - @BeforeClass - public static void beforeTest() throws Exception { + // This method should be called by sub-classes in a @BeforeClass initializer + public static void beforeTest(boolean useArrow) throws Exception { Class.forName(MiniHS2.getJdbcDriverName()); String confDir = "../../data/conf/llap/"; @@ -118,6 +122,11 @@ public static void beforeTest() throws Exception { conf = new HiveConf(); conf.setBoolVar(ConfVars.HIVE_SUPPORT_CONCURRENCY, false); conf.setBoolVar(ConfVars.HIVE_SERVER2_ENABLE_DOAS, false); + if(useArrow) { + conf.setBoolVar(ConfVars.LLAP_OUTPUT_FORMAT_ARROW, true); + } else { + conf.setBoolVar(ConfVars.LLAP_OUTPUT_FORMAT_ARROW, false); + } conf.addResource(new URL("file://" + new File(confDir).toURI().getPath() + "/tez-site.xml")); @@ -184,7 +193,7 @@ public static void createTestTable(Connection connection, String database, Strin stmt.close(); } - private void createDataTypesTable(String tableName) throws Exception { + protected void createDataTypesTable(String tableName) throws Exception { Statement stmt = hs2Conn.createStatement(); // create table @@ -235,12 +244,12 @@ public void testLlapInputFormatEndToEnd() throws Exception { @Test(timeout = 60000) public void testNonAsciiStrings() throws Exception { - createTestTable(hs2Conn, "nonascii", "testtab_nonascii", kvDataFilePath.toString()); + createTestTable("testtab_nonascii"); RowCollector rowCollector = new RowCollector(); String nonAscii = "À côté du garçon"; String query = "select value, '" + nonAscii + "' from testtab_nonascii where under_col=0"; - int rowCount = processQuery("nonascii", query, 1, rowCollector); + int rowCount = processQuery(query, 1, rowCollector); assertEquals(3, rowCount); assertArrayEquals(new String[] {"val_0", nonAscii}, rowCollector.rows.get(0)); @@ -439,11 +448,24 @@ public void testDataTypes() throws Exception { assertArrayEquals("X'01FF'".getBytes("UTF-8"), (byte[]) rowValues[22]); } + + @Test(timeout = 60000) + public void testComplexQuery() throws Exception { + createTestTable("testtab1"); + + RowCollector 
rowCollector = new RowCollector(); + String query = "select value, count(*) from testtab1 where under_col=0 group by value"; + int rowCount = processQuery(query, 1, rowCollector); + assertEquals(1, rowCount); + + assertArrayEquals(new String[] {"val_0", "3"}, rowCollector.rows.get(0)); + } + private interface RowProcessor { void process(Row row); } - private static class RowCollector implements RowProcessor { + protected static class RowCollector implements RowProcessor { ArrayList rows = new ArrayList(); Schema schema = null; int numColumns = 0; @@ -464,7 +486,7 @@ public void process(Row row) { } // Save the actual values from each row as opposed to the String representation. - private static class RowCollector2 implements RowProcessor { + protected static class RowCollector2 implements RowProcessor { ArrayList rows = new ArrayList(); Schema schema = null; int numColumns = 0; @@ -483,17 +505,19 @@ public void process(Row row) { } } - private int processQuery(String query, int numSplits, RowProcessor rowProcessor) throws Exception { + protected int processQuery(String query, int numSplits, RowProcessor rowProcessor) throws Exception { return processQuery(null, query, numSplits, rowProcessor); } + protected abstract InputFormat getInputFormat(); + private int processQuery(String currentDatabase, String query, int numSplits, RowProcessor rowProcessor) throws Exception { String url = miniHS2.getJdbcURL(); String user = System.getProperty("user.name"); String pwd = user; String handleId = UUID.randomUUID().toString(); - LlapRowInputFormat inputFormat = new LlapRowInputFormat(); + InputFormat inputFormat = getInputFormat(); // Get splits JobConf job = new JobConf(conf); @@ -600,4 +624,5 @@ public void run() { private static class ExceptionHolder { Throwable throwable; } -} \ No newline at end of file +} + diff --git itests/hive-unit/src/test/java/org/apache/hive/jdbc/TestJdbcWithMiniLlapArrow.java itests/hive-unit/src/test/java/org/apache/hive/jdbc/TestJdbcWithMiniLlapArrow.java new file mode 100644 index 0000000000..afb9837ce4 --- /dev/null +++ itests/hive-unit/src/test/java/org/apache/hive/jdbc/TestJdbcWithMiniLlapArrow.java @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hive.jdbc; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertArrayEquals; +import java.math.BigDecimal; +import java.sql.Date; +import java.sql.Timestamp; +import java.util.List; +import org.apache.hadoop.hive.llap.FieldDesc; +import org.apache.hadoop.hive.llap.Row; +import org.apache.hadoop.io.NullWritable; +import org.junit.BeforeClass; + +import org.apache.hadoop.mapred.InputFormat; +import org.apache.hadoop.hive.llap.LlapArrowRowInputFormat; + +/** + * TestJdbcWithMiniLlap for Arrow format + */ +public class TestJdbcWithMiniLlapArrow extends BaseJdbcWithMiniLlap { + + + @BeforeClass + public static void beforeTest() throws Exception { + BaseJdbcWithMiniLlap.beforeTest(true); + } + + @Override + protected InputFormat getInputFormat() { + //For unit testing, no harm in hard-coding allocator ceiling to LONG.MAX_VALUE + return new LlapArrowRowInputFormat(Long.MAX_VALUE); + } + + // Currently MAP type is not supported. Add it back when Arrow 1.0 is released. + // See: SPARK-21187 + @Override + public void testDataTypes() throws Exception { + createDataTypesTable("datatypes"); + RowCollector2 rowCollector = new RowCollector2(); + String query = "select * from datatypes"; + int rowCount = processQuery(query, 1, rowCollector); + assertEquals(3, rowCount); + + // Verify schema + String[][] colNameTypes = new String[][] { + {"datatypes.c1", "int"}, + {"datatypes.c2", "boolean"}, + {"datatypes.c3", "double"}, + {"datatypes.c4", "string"}, + {"datatypes.c5", "array"}, + {"datatypes.c6", "map"}, + {"datatypes.c7", "map"}, + {"datatypes.c8", "struct"}, + {"datatypes.c9", "tinyint"}, + {"datatypes.c10", "smallint"}, + {"datatypes.c11", "float"}, + {"datatypes.c12", "bigint"}, + {"datatypes.c13", "array>"}, + {"datatypes.c14", "map>"}, + {"datatypes.c15", "struct>"}, + {"datatypes.c16", "array,n:int>>"}, + {"datatypes.c17", "timestamp"}, + {"datatypes.c18", "decimal(16,7)"}, + {"datatypes.c19", "binary"}, + {"datatypes.c20", "date"}, + {"datatypes.c21", "varchar(20)"}, + {"datatypes.c22", "char(15)"}, + {"datatypes.c23", "binary"}, + }; + FieldDesc fieldDesc; + assertEquals(23, rowCollector.numColumns); + for (int idx = 0; idx < rowCollector.numColumns; ++idx) { + fieldDesc = rowCollector.schema.getColumns().get(idx); + assertEquals("ColName idx=" + idx, colNameTypes[idx][0], fieldDesc.getName()); + assertEquals("ColType idx=" + idx, colNameTypes[idx][1], fieldDesc.getTypeInfo().getTypeName()); + } + + // First row is all nulls + Object[] rowValues = rowCollector.rows.get(0); + for (int idx = 0; idx < rowCollector.numColumns; ++idx) { + assertEquals("idx=" + idx, null, rowValues[idx]); + } + + // Second Row + rowValues = rowCollector.rows.get(1); + assertEquals(Integer.valueOf(-1), rowValues[0]); + assertEquals(Boolean.FALSE, rowValues[1]); + assertEquals(Double.valueOf(-1.1d), rowValues[2]); + assertEquals("", rowValues[3]); + + List c5Value = (List) rowValues[4]; + assertEquals(0, c5Value.size()); + + //Map c6Value = (Map) rowValues[5]; + //assertEquals(0, c6Value.size()); + + //Map c7Value = (Map) rowValues[6]; + //assertEquals(0, c7Value.size()); + + List c8Value = (List) rowValues[7]; + assertEquals(null, c8Value.get(0)); + assertEquals(null, c8Value.get(1)); + assertEquals(null, c8Value.get(2)); + + assertEquals(Byte.valueOf((byte) -1), rowValues[8]); + assertEquals(Short.valueOf((short) -1), rowValues[9]); + assertEquals(Float.valueOf(-1.0f), rowValues[10]); + assertEquals(Long.valueOf(-1l), rowValues[11]); + + List c13Value = 
(List) rowValues[12]; + assertEquals(0, c13Value.size()); + + //Map c14Value = (Map) rowValues[13]; + //assertEquals(0, c14Value.size()); + + List c15Value = (List) rowValues[14]; + assertEquals(null, c15Value.get(0)); + assertEquals(null, c15Value.get(1)); + + //List c16Value = (List) rowValues[15]; + //assertEquals(0, c16Value.size()); + + assertEquals(null, rowValues[16]); + assertEquals(null, rowValues[17]); + assertEquals(null, rowValues[18]); + assertEquals(null, rowValues[19]); + assertEquals(null, rowValues[20]); + assertEquals(null, rowValues[21]); + assertEquals(null, rowValues[22]); + + // Third row + rowValues = rowCollector.rows.get(2); + assertEquals(Integer.valueOf(1), rowValues[0]); + assertEquals(Boolean.TRUE, rowValues[1]); + assertEquals(Double.valueOf(1.1d), rowValues[2]); + assertEquals("1", rowValues[3]); + + c5Value = (List) rowValues[4]; + assertEquals(2, c5Value.size()); + assertEquals(Integer.valueOf(1), c5Value.get(0)); + assertEquals(Integer.valueOf(2), c5Value.get(1)); + + //c6Value = (Map) rowValues[5]; + //assertEquals(2, c6Value.size()); + //assertEquals("x", c6Value.get(Integer.valueOf(1))); + //assertEquals("y", c6Value.get(Integer.valueOf(2))); + + //c7Value = (Map) rowValues[6]; + //assertEquals(1, c7Value.size()); + //assertEquals("v", c7Value.get("k")); + + c8Value = (List) rowValues[7]; + assertEquals("a", c8Value.get(0)); + assertEquals(Integer.valueOf(9), c8Value.get(1)); + assertEquals(Double.valueOf(2.2d), c8Value.get(2)); + + assertEquals(Byte.valueOf((byte) 1), rowValues[8]); + assertEquals(Short.valueOf((short) 1), rowValues[9]); + assertEquals(Float.valueOf(1.0f), rowValues[10]); + assertEquals(Long.valueOf(1l), rowValues[11]); + + c13Value = (List) rowValues[12]; + assertEquals(2, c13Value.size()); + List listVal = (List) c13Value.get(0); + assertEquals("a", listVal.get(0)); + assertEquals("b", listVal.get(1)); + listVal = (List) c13Value.get(1); + assertEquals("c", listVal.get(0)); + assertEquals("d", listVal.get(1)); + + //c14Value = (Map) rowValues[13]; + //assertEquals(2, c14Value.size()); + //Map mapVal = (Map) c14Value.get(Integer.valueOf(1)); + //assertEquals(2, mapVal.size()); + //assertEquals(Integer.valueOf(12), mapVal.get(Integer.valueOf(11))); + //assertEquals(Integer.valueOf(14), mapVal.get(Integer.valueOf(13))); + //mapVal = (Map) c14Value.get(Integer.valueOf(2)); + //assertEquals(1, mapVal.size()); + //assertEquals(Integer.valueOf(22), mapVal.get(Integer.valueOf(21))); + + c15Value = (List) rowValues[14]; + assertEquals(Integer.valueOf(1), c15Value.get(0)); + listVal = (List) c15Value.get(1); + assertEquals(2, listVal.size()); + assertEquals(Integer.valueOf(2), listVal.get(0)); + assertEquals("x", listVal.get(1)); + + //c16Value = (List) rowValues[15]; + //assertEquals(2, c16Value.size()); + //listVal = (List) c16Value.get(0); + //assertEquals(2, listVal.size()); + //mapVal = (Map) listVal.get(0); + //assertEquals(0, mapVal.size()); + //assertEquals(Integer.valueOf(1), listVal.get(1)); + //listVal = (List) c16Value.get(1); + //mapVal = (Map) listVal.get(0); + //assertEquals(2, mapVal.size()); + //assertEquals("b", mapVal.get("a")); + //assertEquals("d", mapVal.get("c")); + //assertEquals(Integer.valueOf(2), listVal.get(1)); + + assertEquals(Timestamp.valueOf("2012-04-22 09:00:00.123456789"), rowValues[16]); + assertEquals(new BigDecimal("123456789.123456"), rowValues[17]); + assertArrayEquals("abcd".getBytes("UTF-8"), (byte[]) rowValues[18]); + assertEquals(Date.valueOf("2013-01-01"), rowValues[19]); + assertEquals("abc123", 
rowValues[20]); + assertEquals("abc123 ", rowValues[21]); + assertArrayEquals("X'01FF'".getBytes("UTF-8"), (byte[]) rowValues[22]); + } + +} + diff --git itests/hive-unit/src/test/java/org/apache/hive/jdbc/TestJdbcWithMiniLlapRow.java itests/hive-unit/src/test/java/org/apache/hive/jdbc/TestJdbcWithMiniLlapRow.java new file mode 100644 index 0000000000..809068fe3e --- /dev/null +++ itests/hive-unit/src/test/java/org/apache/hive/jdbc/TestJdbcWithMiniLlapRow.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hive.jdbc; + +import org.apache.hadoop.hive.llap.Row; +import org.apache.hadoop.io.NullWritable; +import org.apache.hadoop.hive.llap.LlapRowInputFormat; +import org.junit.BeforeClass; +import org.junit.Before; +import org.junit.After; +import org.apache.hadoop.mapred.InputFormat; + +/** + * TestJdbcWithMiniLlap for llap Row format. + */ +public class TestJdbcWithMiniLlapRow extends BaseJdbcWithMiniLlap { + + @BeforeClass + public static void beforeTest() throws Exception { + BaseJdbcWithMiniLlap.beforeTest(false); + } + + @Override + protected InputFormat getInputFormat() { + return new LlapRowInputFormat(); + } + +} + diff --git llap-client/src/java/org/apache/hadoop/hive/llap/LlapBaseRecordReader.java llap-client/src/java/org/apache/hadoop/hive/llap/LlapBaseRecordReader.java index a9ed3d200f..5316aa7489 100644 --- llap-client/src/java/org/apache/hadoop/hive/llap/LlapBaseRecordReader.java +++ llap-client/src/java/org/apache/hadoop/hive/llap/LlapBaseRecordReader.java @@ -22,25 +22,15 @@ import java.io.BufferedInputStream; import java.io.Closeable; -import java.io.EOFException; import java.io.IOException; import java.io.InputStream; import java.io.DataInputStream; import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.TimeUnit; - -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.hive.conf.HiveConf; -import org.apache.hadoop.hive.llap.Schema; import org.apache.hadoop.hive.llap.io.ChunkedInputStream; -import org.apache.hadoop.io.Writable; import org.apache.hadoop.io.WritableComparable; import org.apache.hadoop.io.NullWritable; -import org.apache.hadoop.util.ReflectionUtils; import org.apache.hadoop.mapred.RecordReader; -import org.apache.hadoop.mapred.Reporter; import org.apache.hadoop.mapred.JobConf; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -149,52 +139,61 @@ public boolean next(NullWritable key, V value) throws IOException { throw new IOException("Hit end of input, but did not find expected end of data indicator"); } - // There should be a reader event available, or coming soon, so okay to be blocking call. 
- ReaderEvent event = getReaderEvent(); - switch (event.getEventType()) { - case DONE: - break; - default: - throw new IOException("Expected reader event with done status, but got " - + event.getEventType() + " with message " + event.getMessage()); - } + processReaderEvent(); return false; } } catch (IOException io) { - try { - if (Thread.interrupted()) { - // Either we were interrupted by one of: - // 1. handleEvent(), in which case there is a reader (error) event waiting for us in the queue - // 2. Some other unrelated cause which interrupted us, in which case there may not be a reader event coming. - // Either way we should not try to block trying to read the reader events queue. - if (readerEvents.isEmpty()) { - // Case 2. - throw io; - } else { - // Case 1. Fail the reader, sending back the error we received from the reader event. - ReaderEvent event = getReaderEvent(); - switch (event.getEventType()) { - case ERROR: - throw new IOException("Received reader event error: " + event.getMessage(), io); - default: - throw new IOException("Got reader event type " + event.getEventType() - + ", expected error event", io); - } - } - } else { - // If we weren't interrupted, just propagate the error + failOnInterruption(io); + return false; + } + } + + protected void processReaderEvent() throws IOException { + // There should be a reader event available, or coming soon, so okay to be blocking call. + ReaderEvent event = getReaderEvent(); + switch (event.getEventType()) { + case DONE: + break; + default: + throw new IOException("Expected reader event with done status, but got " + + event.getEventType() + " with message " + event.getMessage()); + } + } + + protected void failOnInterruption(IOException io) throws IOException { + try { + if (Thread.interrupted()) { + // Either we were interrupted by one of: + // 1. handleEvent(), in which case there is a reader (error) event waiting for us in the queue + // 2. Some other unrelated cause which interrupted us, in which case there may not be a reader event coming. + // Either way we should not try to block trying to read the reader events queue. + if (readerEvents.isEmpty()) { + // Case 2. throw io; + } else { + // Case 1. Fail the reader, sending back the error we received from the reader event. + ReaderEvent event = getReaderEvent(); + switch (event.getEventType()) { + case ERROR: + throw new IOException("Received reader event error: " + event.getMessage(), io); + default: + throw new IOException("Got reader event type " + event.getEventType() + + ", expected error event", io); + } } - } finally { - // The external client handling umbilical responses and the connection to read the incoming - // data are not coupled. Calling close() here to make sure an error in one will cause the - // other to be closed as well. - try { - close(); - } catch (Exception err) { - // Don't propagate errors from close() since this will lose the original error above. - LOG.error("Closing RecordReader due to error and hit another error during close()", err); - } + } else { + // If we weren't interrupted, just propagate the error + throw io; + } + } finally { + // The external client handling umbilical responses and the connection to read the incoming + // data are not coupled. Calling close() here to make sure an error in one will cause the + // other to be closed as well. + try { + close(); + } catch (Exception err) { + // Don't propagate errors from close() since this will lose the original error above. 
+ LOG.error("Closing RecordReader due to error and hit another error during close()", err); } } } diff --git llap-client/src/java/org/apache/hadoop/hive/llap/LlapRowRecordReader.java llap-client/src/java/org/apache/hadoop/hive/llap/LlapRowRecordReader.java index 1cfbf3a86e..6cc1d1792b 100644 --- llap-client/src/java/org/apache/hadoop/hive/llap/LlapRowRecordReader.java +++ llap-client/src/java/org/apache/hadoop/hive/llap/LlapRowRecordReader.java @@ -29,7 +29,6 @@ import java.util.Properties; import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.io.BytesWritable; import org.apache.hadoop.io.NullWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.Writable; @@ -70,20 +69,20 @@ private static final Logger LOG = LoggerFactory.getLogger(LlapRowRecordReader.class); protected final Configuration conf; - protected final RecordReader reader; + protected final RecordReader reader; protected final Schema schema; protected final AbstractSerDe serde; - protected final BytesWritable data; + protected final Writable data; public LlapRowRecordReader(Configuration conf, Schema schema, - RecordReader reader) throws IOException { + RecordReader reader) throws IOException { this.conf = conf; this.schema = schema; this.reader = reader; - this.data = new BytesWritable(); + this.data = reader.createValue(); try { - serde = initSerDe(conf); + this.serde = initSerDe(conf); } catch (SerDeException err) { throw new IOException(err); } @@ -118,7 +117,7 @@ public float getProgress() throws IOException { public boolean next(NullWritable key, Row value) throws IOException { Preconditions.checkArgument(value != null); - boolean hasNext = reader.next(key, data); + boolean hasNext = reader.next(key, data); if (hasNext) { // Deserialize data to column values, and populate the row record Object rowObj; @@ -216,7 +215,7 @@ static Object convertValue(Object val, ObjectInspector oi) { return convertedVal; } - static void setRowFromStruct(Row row, Object structVal, StructObjectInspector soi) { + protected static void setRowFromStruct(Row row, Object structVal, StructObjectInspector soi) { Schema structSchema = row.getSchema(); // Add struct field data to the Row List structFields = soi.getAllStructFieldRefs(); @@ -230,6 +229,11 @@ static void setRowFromStruct(Row row, Object structVal, StructObjectInspector so } } + //Factory method for serDe + protected AbstractSerDe createSerDe() throws SerDeException { + return new LazyBinarySerDe(); + } + protected AbstractSerDe initSerDe(Configuration conf) throws SerDeException { Properties props = new Properties(); StringBuilder columnsBuffer = new StringBuilder(); @@ -249,9 +253,9 @@ protected AbstractSerDe initSerDe(Configuration conf) throws SerDeException { props.put(serdeConstants.LIST_COLUMNS, columns); props.put(serdeConstants.LIST_COLUMN_TYPES, types); props.put(serdeConstants.ESCAPE_CHAR, "\\"); - AbstractSerDe serde = new LazyBinarySerDe(); - serde.initialize(conf, props); + AbstractSerDe createdSerDe = createSerDe(); + createdSerDe.initialize(conf, props); - return serde; + return createdSerDe; } } diff --git llap-ext-client/pom.xml llap-ext-client/pom.xml index ed4704b4cd..295d3e6319 100644 --- llap-ext-client/pom.xml +++ llap-ext-client/pom.xml @@ -41,6 +41,11 @@ org.apache.hive + hive-exec + ${project.version} + + + org.apache.hive hive-llap-client ${project.version} diff --git llap-ext-client/src/java/org/apache/hadoop/hive/llap/LlapArrowBatchRecordReader.java 
llap-ext-client/src/java/org/apache/hadoop/hive/llap/LlapArrowBatchRecordReader.java new file mode 100644 index 0000000000..d9c5666bc4 --- /dev/null +++ llap-ext-client/src/java/org/apache/hadoop/hive/llap/LlapArrowBatchRecordReader.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.llap; + +import com.google.common.base.Preconditions; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.hadoop.hive.ql.io.arrow.ArrowWrapperWritable; +import org.apache.hadoop.hive.ql.io.arrow.RootAllocatorFactory; +import org.apache.hadoop.io.NullWritable; +import org.apache.hadoop.mapred.JobConf; +import java.io.Closeable; +import java.io.IOException; +import java.io.InputStream; +import java.net.Socket; + +/* + * Read from Arrow stream batch-by-batch + */ +public class LlapArrowBatchRecordReader extends LlapBaseRecordReader { + + private BufferAllocator allocator; + private ArrowStreamReader arrowStreamReader; + + public LlapArrowBatchRecordReader(InputStream in, Schema schema, Class clazz, + JobConf job, Closeable client, Socket socket, long arrowAllocatorLimit) throws IOException { + super(in, schema, clazz, job, client, socket); + allocator = RootAllocatorFactory.INSTANCE.getOrCreateRootAllocator(arrowAllocatorLimit); + this.arrowStreamReader = new ArrowStreamReader(socket.getInputStream(), allocator); + } + + @Override + public boolean next(NullWritable key, ArrowWrapperWritable value) throws IOException { + try { + // Need a way to know what thread to interrupt, since this is a blocking thread. 
+ setReaderThread(Thread.currentThread()); + + boolean hasInput = arrowStreamReader.loadNextBatch(); + if (hasInput) { + VectorSchemaRoot vectorSchemaRoot = arrowStreamReader.getVectorSchemaRoot(); + //There must be at least one column vector + Preconditions.checkState(vectorSchemaRoot.getFieldVectors().size() > 0); + if(vectorSchemaRoot.getFieldVectors().get(0).getValueCount() == 0) { + //An empty batch will appear at the end of the stream + return false; + } + value.setVectorSchemaRoot(arrowStreamReader.getVectorSchemaRoot()); + return true; + } else { + processReaderEvent(); + return false; + } + } catch (IOException io) { + failOnInterruption(io); + return false; + } + } + + @Override + public void close() throws IOException { + arrowStreamReader.close(); + } + +} + diff --git llap-ext-client/src/java/org/apache/hadoop/hive/llap/LlapArrowRowInputFormat.java llap-ext-client/src/java/org/apache/hadoop/hive/llap/LlapArrowRowInputFormat.java new file mode 100644 index 0000000000..fafbdee210 --- /dev/null +++ llap-ext-client/src/java/org/apache/hadoop/hive/llap/LlapArrowRowInputFormat.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.hive.llap; + +import org.apache.hadoop.io.NullWritable; +import org.apache.hadoop.mapred.InputFormat; +import org.apache.hadoop.mapred.InputSplit; +import org.apache.hadoop.mapred.JobConf; +import org.apache.hadoop.mapred.RecordReader; +import org.apache.hadoop.mapred.Reporter; +import java.io.IOException; + +/* + * Adapts an Arrow batch reader to a row reader + */ +public class LlapArrowRowInputFormat implements InputFormat { + + private LlapBaseInputFormat baseInputFormat; + + public LlapArrowRowInputFormat(long arrowAllocatorLimit) { + baseInputFormat = new LlapBaseInputFormat(true, arrowAllocatorLimit); + } + + @Override + public InputSplit[] getSplits(JobConf job, int numSplits) throws IOException { + return baseInputFormat.getSplits(job, numSplits); + } + + @Override + public RecordReader getRecordReader(InputSplit split, JobConf job, Reporter reporter) + throws IOException { + LlapInputSplit llapSplit = (LlapInputSplit) split; + LlapArrowBatchRecordReader reader = + (LlapArrowBatchRecordReader) baseInputFormat.getRecordReader(llapSplit, job, reporter); + return new LlapArrowRowRecordReader(job, reader.getSchema(), reader); + } +} diff --git llap-ext-client/src/java/org/apache/hadoop/hive/llap/LlapArrowRowRecordReader.java llap-ext-client/src/java/org/apache/hadoop/hive/llap/LlapArrowRowRecordReader.java new file mode 100644 index 0000000000..d4179d5202 --- /dev/null +++ llap-ext-client/src/java/org/apache/hadoop/hive/llap/LlapArrowRowRecordReader.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.llap; + +import com.google.common.base.Preconditions; +import org.apache.arrow.vector.FieldVector; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe; +import org.apache.hadoop.hive.ql.io.arrow.ArrowWrapperWritable; +import org.apache.hadoop.hive.serde2.AbstractSerDe; +import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.io.NullWritable; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapred.RecordReader; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.List; + +/** + * Buffers a batch for reading one row at a time. 
+ */ +public class LlapArrowRowRecordReader extends LlapRowRecordReader { + + private static final Logger LOG = LoggerFactory.getLogger(LlapArrowRowRecordReader.class); + private int rowIndex = 0; + private int batchSize = 0; + + //Buffer one batch at a time, for row retrieval + private Object[][] currentBatch; + + public LlapArrowRowRecordReader(Configuration conf, Schema schema, + RecordReader reader) throws IOException { + super(conf, schema, reader); + } + + @Override + public boolean next(NullWritable key, Row value) throws IOException { + Preconditions.checkArgument(value != null); + boolean hasNext = false; + ArrowWrapperWritable batchData = (ArrowWrapperWritable) data; + if((batchSize == 0) || (rowIndex == batchSize)) { + //This is either the first batch or we've used up the current batch buffer + batchSize = 0; + rowIndex = 0; + hasNext = reader.next(key, data); + if(hasNext) { + //There is another batch to buffer + try { + List vectors = batchData.getVectorSchemaRoot().getFieldVectors(); + //hasNext implies there is some column in the batch + Preconditions.checkState(vectors.size() > 0); + //All the vectors have the same length, + //we can get the number of rows from the first vector + batchSize = vectors.get(0).getValueCount(); + ArrowWrapperWritable wrapper = new ArrowWrapperWritable(batchData.getVectorSchemaRoot()); + currentBatch = (Object[][]) serde.deserialize(wrapper); + StructObjectInspector rowOI = (StructObjectInspector) serde.getObjectInspector(); + setRowFromStruct(value, currentBatch[rowIndex], rowOI); + } catch (Exception e) { + LOG.error("Failed to fetch Arrow batch", e); + throw new RuntimeException(e); + } + } + //There were no more batches AND + //this is either the first batch or we've used up the current batch buffer. + //goto return false + } else if(rowIndex < batchSize) { + //Take a row from the current buffered batch + hasNext = true; + StructObjectInspector rowOI = null; + try { + rowOI = (StructObjectInspector) serde.getObjectInspector(); + } catch (SerDeException e) { + throw new RuntimeException(e); + } + setRowFromStruct(value, currentBatch[rowIndex], rowOI); + } + //Always inc the batch buffer index + //If we return false, it is just a noop + rowIndex++; + return hasNext; + } + + protected AbstractSerDe createSerDe() throws SerDeException { + return new ArrowColumnarBatchSerDe(); + } + +} diff --git llap-ext-client/src/java/org/apache/hadoop/hive/llap/LlapBaseInputFormat.java llap-ext-client/src/java/org/apache/hadoop/hive/llap/LlapBaseInputFormat.java index f4c7fa4b30..ef03be660e 100644 --- llap-ext-client/src/java/org/apache/hadoop/hive/llap/LlapBaseInputFormat.java +++ llap-ext-client/src/java/org/apache/hadoop/hive/llap/LlapBaseInputFormat.java @@ -49,15 +49,15 @@ import org.apache.hadoop.hive.llap.ext.LlapTaskUmbilicalExternalClient; import org.apache.hadoop.hive.llap.ext.LlapTaskUmbilicalExternalClient.LlapTaskUmbilicalExternalResponder; import org.apache.hadoop.hive.llap.registry.LlapServiceInstance; -import org.apache.hadoop.hive.llap.registry.LlapServiceInstanceSet; import org.apache.hadoop.hive.llap.registry.impl.LlapRegistryService; import org.apache.hadoop.hive.llap.security.LlapTokenIdentifier; import org.apache.hadoop.hive.llap.tez.Converters; +import org.apache.hadoop.hive.ql.io.arrow.ArrowWrapperWritable; import org.apache.hadoop.hive.registry.ServiceInstanceSet; +import org.apache.hadoop.io.BytesWritable; import org.apache.hadoop.io.DataInputBuffer; import org.apache.hadoop.io.DataOutputBuffer; import 
org.apache.hadoop.io.NullWritable; -import org.apache.hadoop.io.Text; import org.apache.hadoop.io.WritableComparable; import org.apache.hadoop.mapred.InputFormat; import org.apache.hadoop.mapred.InputSplit; @@ -104,6 +104,8 @@ private String user; // "hive", private String pwd; // "" private String query; + private boolean useArrow; + private long arrowAllocatorLimit; private final Random rand = new Random(); public static final String URL_KEY = "llap.if.hs2.connection"; @@ -123,7 +125,14 @@ public LlapBaseInputFormat(String url, String user, String pwd, String query) { this.query = query; } - public LlapBaseInputFormat() {} + public LlapBaseInputFormat(boolean useArrow, long arrowAllocatorLimit) { + this.useArrow = useArrow; + this.arrowAllocatorLimit = arrowAllocatorLimit; + } + + public LlapBaseInputFormat() { + this.useArrow = false; + } @SuppressWarnings("unchecked") @@ -195,8 +204,16 @@ public LlapBaseInputFormat() {} LOG.info("Registered id: " + fragmentId); @SuppressWarnings("rawtypes") - LlapBaseRecordReader recordReader = new LlapBaseRecordReader(socket.getInputStream(), - llapSplit.getSchema(), Text.class, job, llapClient, (java.io.Closeable)socket); + LlapBaseRecordReader recordReader; + if(useArrow) { + recordReader = new LlapArrowBatchRecordReader( + socket.getInputStream(), llapSplit.getSchema(), + ArrowWrapperWritable.class, job, llapClient, socket, + arrowAllocatorLimit); + } else { + recordReader = new LlapBaseRecordReader(socket.getInputStream(), + llapSplit.getSchema(), BytesWritable.class, job, llapClient, (java.io.Closeable)socket); + } umbilicalResponder.setRecordReader(recordReader); return recordReader; } diff --git pom.xml pom.xml index f318e190d2..83f609a5f5 100644 --- pom.xml +++ pom.xml @@ -119,6 +119,7 @@ 3.5.2 1.5.6 0.1 + 0.8.0 1.11.0 1.7.7 diff --git ql/src/java/org/apache/hadoop/hive/llap/LlapArrowRecordWriter.java ql/src/java/org/apache/hadoop/hive/llap/LlapArrowRecordWriter.java new file mode 100644 index 0000000000..1b3a3ebb26 --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/llap/LlapArrowRecordWriter.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.llap; + +import java.io.IOException; + +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.hadoop.hive.ql.io.arrow.ArrowWrapperWritable; +import org.apache.hadoop.io.Writable; +import java.nio.channels.WritableByteChannel; +import org.apache.hadoop.mapred.RecordWriter; +import org.apache.hadoop.mapred.Reporter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Writes Arrow batches to an {@link org.apache.arrow.vector.ipc.ArrowStreamWriter}. 
+ * The byte stream will be formatted according to the Arrow Streaming format. + * Because ArrowStreamWriter is bound to a {@link org.apache.arrow.vector.VectorSchemaRoot} + * when it is created, + * calls to the {@link #write(Writable, Writable)} method only serve as a signal that + * a new batch has been loaded to the associated VectorSchemaRoot. + * Payload data for writing is indirectly made available by reference: + * ArrowStreamWriter -> VectorSchemaRoot -> List + * i.e. both they key and value are ignored once a reference to the VectorSchemaRoot + * is obtained. + */ +public class LlapArrowRecordWriter + implements RecordWriter { + public static final Logger LOG = LoggerFactory.getLogger(LlapArrowRecordWriter.class); + + ArrowStreamWriter arrowStreamWriter; + WritableByteChannel out; + + public LlapArrowRecordWriter(WritableByteChannel out) { + this.out = out; + } + + @Override + public void close(Reporter reporter) throws IOException { + arrowStreamWriter.close(); + } + + @Override + public void write(K key, V value) throws IOException { + ArrowWrapperWritable arrowWrapperWritable = (ArrowWrapperWritable) value; + if (arrowStreamWriter == null) { + VectorSchemaRoot vectorSchemaRoot = arrowWrapperWritable.getVectorSchemaRoot(); + arrowStreamWriter = new ArrowStreamWriter(vectorSchemaRoot, null, out); + } + arrowStreamWriter.writeBatch(); + } +} diff --git ql/src/java/org/apache/hadoop/hive/llap/LlapOutputFormatService.java ql/src/java/org/apache/hadoop/hive/llap/LlapOutputFormatService.java index 30d5eb5eab..c71c637c71 100644 --- ql/src/java/org/apache/hadoop/hive/llap/LlapOutputFormatService.java +++ ql/src/java/org/apache/hadoop/hive/llap/LlapOutputFormatService.java @@ -198,11 +198,16 @@ private void registerReader(ChannelHandlerContext ctx, String id, byte[] tokenBy LOG.debug("registering socket for: " + id); int maxPendingWrites = HiveConf.getIntVar(conf, HiveConf.ConfVars.LLAP_DAEMON_OUTPUT_SERVICE_MAX_PENDING_WRITES); + boolean useArrow = HiveConf.getBoolVar(conf, HiveConf.ConfVars.LLAP_OUTPUT_FORMAT_ARROW); @SuppressWarnings("rawtypes") - LlapRecordWriter writer = new LlapRecordWriter(id, + RecordWriter writer = null; + if(useArrow) { + writer = new LlapArrowRecordWriter(new WritableByteChannelAdapter(ctx, maxPendingWrites, id)); + } else { + writer = new LlapRecordWriter(id, new ChunkedOutputStream( - new ChannelOutputStream(ctx, id, sendBufferSize, maxPendingWrites), - sendBufferSize, id)); + new ChannelOutputStream(ctx, id, sendBufferSize, maxPendingWrites), sendBufferSize, id)); + } boolean isFailed = true; synchronized (lock) { if (!writers.containsKey(id)) { diff --git ql/src/java/org/apache/hadoop/hive/llap/WritableByteChannelAdapter.java ql/src/java/org/apache/hadoop/hive/llap/WritableByteChannelAdapter.java new file mode 100644 index 0000000000..57da1d9f6d --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/llap/WritableByteChannelAdapter.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.llap; + +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.concurrent.Semaphore; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; + +/** + * Provides an adapter between {@link java.nio.channels.WritableByteChannel} + * and {@link io.netty.channel.ChannelHandlerContext}. + * Additionally provides a form of flow-control by limiting the number of + * queued async writes. + */ +public class WritableByteChannelAdapter implements WritableByteChannel { + + private static final Logger LOG = LoggerFactory.getLogger(WritableByteChannelAdapter.class); + private ChannelHandlerContext chc; + private final int maxPendingWrites; + // This semaphore provides two functions: + // 1. Forces a cap on the number of outstanding async writes to channel + // 2. Ensures that channel isn't closed if there are any outstanding async writes + private final Semaphore writeResources; + private boolean closed = false; + private final String id; + + private ChannelFutureListener writeListener = new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + //Asynch write completed + //Up the semaphore + writeResources.release(); + + if (future.isCancelled()) { + LOG.error("Write cancelled on ID " + id); + } else if (!future.isSuccess()) { + LOG.error("Write error on ID " + id, future.cause()); + } + } + }; + + private ChannelFutureListener closeListener = new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + if (future.isCancelled()) { + LOG.error("Close cancelled on ID " + id); + } else if (!future.isSuccess()) { + LOG.error("Close failed on ID " + id, future.cause()); + } + } + }; + + public WritableByteChannelAdapter(ChannelHandlerContext chc, int maxPendingWrites, String id) { + this.chc = chc; + this.maxPendingWrites = maxPendingWrites; + this.writeResources = new Semaphore(maxPendingWrites); + this.id = id; + } + + @Override + public int write(ByteBuffer src) throws IOException { + int size = src.remaining(); + //Down the semaphore or block until available + takeWriteResources(1); + chc.writeAndFlush(Unpooled.wrappedBuffer(src)).addListener(writeListener); + return size; + } + + @Override + public boolean isOpen() { + return chc.channel().isOpen(); + } + + @Override + public void close() throws IOException { + if (closed) { + throw new IOException("Already closed: " + id); + } + + closed = true; + //Block until all semaphore resources are released + //by outstanding async writes + takeWriteResources(maxPendingWrites); + + try { + chc.close().addListener(closeListener); + } finally { + chc = null; + closed = true; + } + } + + private void takeWriteResources(int numResources) throws IOException { + try { + writeResources.acquire(numResources); + } catch (InterruptedException ie) { + throw new IOException("Interrupted while waiting for 
write resources for " + id); + } + } +} diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/FileSinkOperator.java ql/src/java/org/apache/hadoop/hive/ql/exec/FileSinkOperator.java index 01a5b4c9c3..9c57eff2e8 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/FileSinkOperator.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/FileSinkOperator.java @@ -56,6 +56,7 @@ import org.apache.hadoop.hive.ql.io.RecordUpdater; import org.apache.hadoop.hive.ql.io.StatsProvidingRecordWriter; import org.apache.hadoop.hive.ql.io.StreamingOutputFormat; +import org.apache.hadoop.hive.ql.io.arrow.ArrowWrapperWritable; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.metadata.HiveFatalException; import org.apache.hadoop.hive.ql.plan.DynamicPartitionCtx; @@ -1251,16 +1252,25 @@ public void closeOp(boolean abort) throws HiveException { // If serializer is ThriftJDBCBinarySerDe, then it buffers rows to a certain limit (hive.server2.thrift.resultset.max.fetch.size) // and serializes the whole batch when the buffer is full. The serialize returns null if the buffer is not full // (the size of buffer is kept track of in the ThriftJDBCBinarySerDe). - if (conf.isUsingThriftJDBCBinarySerDe()) { - try { - recordValue = serializer.serialize(null, inputObjInspectors[0]); - if ( null != fpaths ) { - rowOutWriters = fpaths.outWriters; - rowOutWriters[0].write(recordValue); + if (conf.isUsingBatchingSerDe()) { + try { + recordValue = serializer.serialize(null, inputObjInspectors[0]); + if (null != fpaths) { + rowOutWriters = fpaths.outWriters; + rowOutWriters[0].write(recordValue); + } else if(recordValue instanceof ArrowWrapperWritable) { + //Because LLAP arrow output depends on the ThriftJDBCBinarySerDe code path + //this is required for 0 row outputs + //i.e. we need to write a 0 size batch to signal EOS to the consumer + for (FSPaths fsPaths : valToPaths.values()) { + for(RecordWriter writer : fsPaths.outWriters) { + writer.write(recordValue); + } } - } catch (SerDeException | IOException e) { - throw new HiveException(e); } + } catch (SerDeException | IOException e) { + throw new HiveException(e); + } } List commitPaths = new ArrayList<>(); for (FSPaths fsp : valToPaths.values()) { diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowColumnarBatchSerDe.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowColumnarBatchSerDe.java new file mode 100644 index 0000000000..b093ebbd27 --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowColumnarBatchSerDe.java @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.hive.ql.io.arrow; + +import com.google.common.collect.Lists; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.apache.arrow.vector.complex.writer.BaseWriter; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.ListColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.MapColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.StructColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.VectorAssignRow; +import org.apache.hadoop.hive.serde.serdeConstants; +import org.apache.hadoop.hive.serde2.AbstractSerDe; +import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.hive.serde2.SerDeStats; +import org.apache.hadoop.hive.serde2.SerDeUtils; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TimestampLocalTZTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; +import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo; +import org.apache.hadoop.io.Writable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.DataInput; +import java.io.DataOutput; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Properties; + +import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo; + +/** + * ArrowColumnarBatchSerDe converts Apache Hive rows to Apache Arrow columns. Its serialized + * class is {@link ArrowWrapperWritable}, which doesn't support {@link + * Writable#readFields(DataInput)} and {@link Writable#write(DataOutput)}. + * + * Followings are known issues of current implementation. + * + * A list column cannot have a decimal column. {@link UnionListWriter} doesn't have an + * implementation for {@link BaseWriter.ListWriter#decimal()}. + * + * A union column can have only one of string, char, varchar fields at a same time. Apache Arrow + * doesn't have string and char, so {@link ArrowColumnarBatchSerDe} uses varchar to simulate + * string and char. They will be considered as a same data type in + * {@link org.apache.arrow.vector.complex.UnionVector}. + * + * Timestamp with local timezone is not supported. {@link VectorAssignRow} doesn't support it. 
+ */ +public class ArrowColumnarBatchSerDe extends AbstractSerDe { + public static final Logger LOG = LoggerFactory.getLogger(ArrowColumnarBatchSerDe.class.getName()); + private static final String DEFAULT_ARROW_FIELD_NAME = "[DEFAULT]"; + + static final int MS_PER_SECOND = 1_000; + static final int NS_PER_SECOND = 1_000_000_000; + static final int NS_PER_MS = 1_000_000; + static final int SECOND_PER_DAY = 24 * 60 * 60; + + BufferAllocator rootAllocator; + StructTypeInfo rowTypeInfo; + StructObjectInspector rowObjectInspector; + Configuration conf; + + private Serializer serializer; + private Deserializer deserializer; + + @Override + public void initialize(Configuration conf, Properties tbl) throws SerDeException { + this.conf = conf; + + rootAllocator = RootAllocatorFactory.INSTANCE.getRootAllocator(conf); + + final String columnNameProperty = tbl.getProperty(serdeConstants.LIST_COLUMNS); + final String columnTypeProperty = tbl.getProperty(serdeConstants.LIST_COLUMN_TYPES); + final String columnNameDelimiter = tbl.containsKey(serdeConstants.COLUMN_NAME_DELIMITER) ? tbl + .getProperty(serdeConstants.COLUMN_NAME_DELIMITER) : String.valueOf(SerDeUtils.COMMA); + + // Create an object inspector + final List columnNames; + if (columnNameProperty.length() == 0) { + columnNames = new ArrayList<>(); + } else { + columnNames = Arrays.asList(columnNameProperty.split(columnNameDelimiter)); + } + final List columnTypes; + if (columnTypeProperty.length() == 0) { + columnTypes = new ArrayList<>(); + } else { + columnTypes = TypeInfoUtils.getTypeInfosFromTypeString(columnTypeProperty); + } + rowTypeInfo = (StructTypeInfo) TypeInfoFactory.getStructTypeInfo(columnNames, columnTypes); + rowObjectInspector = + (StructObjectInspector) getStandardWritableObjectInspectorFromTypeInfo(rowTypeInfo); + + final List fields = new ArrayList<>(); + final int size = columnNames.size(); + for (int i = 0; i < size; i++) { + fields.add(toField(columnNames.get(i), columnTypes.get(i))); + } + + serializer = new Serializer(this); + deserializer = new Deserializer(this); + } + + private static Field toField(String name, TypeInfo typeInfo) { + switch (typeInfo.getCategory()) { + case PRIMITIVE: + final PrimitiveTypeInfo primitiveTypeInfo = (PrimitiveTypeInfo) typeInfo; + switch (primitiveTypeInfo.getPrimitiveCategory()) { + case BOOLEAN: + return Field.nullable(name, MinorType.BIT.getType()); + case BYTE: + return Field.nullable(name, MinorType.TINYINT.getType()); + case SHORT: + return Field.nullable(name, MinorType.SMALLINT.getType()); + case INT: + return Field.nullable(name, MinorType.INT.getType()); + case LONG: + return Field.nullable(name, MinorType.BIGINT.getType()); + case FLOAT: + return Field.nullable(name, MinorType.FLOAT4.getType()); + case DOUBLE: + return Field.nullable(name, MinorType.FLOAT8.getType()); + case STRING: + case VARCHAR: + case CHAR: + return Field.nullable(name, MinorType.VARCHAR.getType()); + case DATE: + return Field.nullable(name, MinorType.DATEDAY.getType()); + case TIMESTAMP: + return Field.nullable(name, MinorType.TIMESTAMPMILLI.getType()); + case TIMESTAMPLOCALTZ: + final TimestampLocalTZTypeInfo timestampLocalTZTypeInfo = + (TimestampLocalTZTypeInfo) typeInfo; + final String timeZone = timestampLocalTZTypeInfo.getTimeZone().toString(); + return Field.nullable(name, new ArrowType.Timestamp(TimeUnit.MILLISECOND, timeZone)); + case BINARY: + return Field.nullable(name, MinorType.VARBINARY.getType()); + case DECIMAL: + final DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) typeInfo; + final 
int precision = decimalTypeInfo.precision(); + final int scale = decimalTypeInfo.scale(); + return Field.nullable(name, new ArrowType.Decimal(precision, scale)); + case INTERVAL_YEAR_MONTH: + return Field.nullable(name, MinorType.INTERVALYEAR.getType()); + case INTERVAL_DAY_TIME: + return Field.nullable(name, MinorType.INTERVALDAY.getType()); + default: + throw new IllegalArgumentException(); + } + case LIST: + final ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo; + final TypeInfo elementTypeInfo = listTypeInfo.getListElementTypeInfo(); + return new Field(name, FieldType.nullable(MinorType.LIST.getType()), + Lists.newArrayList(toField(DEFAULT_ARROW_FIELD_NAME, elementTypeInfo))); + case STRUCT: + final StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; + final List fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos(); + final List fieldNames = structTypeInfo.getAllStructFieldNames(); + final List structFields = Lists.newArrayList(); + final int structSize = fieldNames.size(); + for (int i = 0; i < structSize; i++) { + structFields.add(toField(fieldNames.get(i), fieldTypeInfos.get(i))); + } + return new Field(name, FieldType.nullable(MinorType.MAP.getType()), structFields); + case UNION: + final UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; + final List objectTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos(); + final List unionFields = Lists.newArrayList(); + final int unionSize = unionFields.size(); + for (int i = 0; i < unionSize; i++) { + unionFields.add(toField(DEFAULT_ARROW_FIELD_NAME, objectTypeInfos.get(i))); + } + return new Field(name, FieldType.nullable(MinorType.UNION.getType()), unionFields); + case MAP: + final MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo; + final TypeInfo keyTypeInfo = mapTypeInfo.getMapKeyTypeInfo(); + final TypeInfo valueTypeInfo = mapTypeInfo.getMapValueTypeInfo(); + final StructTypeInfo mapStructTypeInfo = new StructTypeInfo(); + mapStructTypeInfo.setAllStructFieldNames(Lists.newArrayList("keys", "values")); + mapStructTypeInfo.setAllStructFieldTypeInfos( + Lists.newArrayList(keyTypeInfo, valueTypeInfo)); + final ListTypeInfo mapListStructTypeInfo = new ListTypeInfo(); + mapListStructTypeInfo.setListElementTypeInfo(mapStructTypeInfo); + + return toField(name, mapListStructTypeInfo); + default: + throw new IllegalArgumentException(); + } + } + + static ListTypeInfo toStructListTypeInfo(MapTypeInfo mapTypeInfo) { + final StructTypeInfo structTypeInfo = new StructTypeInfo(); + structTypeInfo.setAllStructFieldNames(Lists.newArrayList("keys", "values")); + structTypeInfo.setAllStructFieldTypeInfos(Lists.newArrayList( + mapTypeInfo.getMapKeyTypeInfo(), mapTypeInfo.getMapValueTypeInfo())); + final ListTypeInfo structListTypeInfo = new ListTypeInfo(); + structListTypeInfo.setListElementTypeInfo(structTypeInfo); + return structListTypeInfo; + } + + static ListColumnVector toStructListVector(MapColumnVector mapVector) { + final StructColumnVector structVector; + final ListColumnVector structListVector; + structVector = new StructColumnVector(); + structVector.fields = new ColumnVector[] {mapVector.keys, mapVector.values}; + structListVector = new ListColumnVector(); + structListVector.child = structVector; + structListVector.childCount = mapVector.childCount; + structListVector.isRepeating = mapVector.isRepeating; + structListVector.noNulls = mapVector.noNulls; + System.arraycopy(mapVector.offsets, 0, structListVector.offsets, 0, mapVector.childCount); + System.arraycopy(mapVector.lengths, 0, structListVector.lengths, 0, 
mapVector.childCount); + return structListVector; + } + + @Override + public Class getSerializedClass() { + return ArrowWrapperWritable.class; + } + + @Override + public ArrowWrapperWritable serialize(Object obj, ObjectInspector objInspector) { + return serializer.serialize(obj, objInspector); + } + + @Override + public SerDeStats getSerDeStats() { + return null; + } + + @Override + public Object deserialize(Writable writable) { + return deserializer.deserialize(writable); + } + + @Override + public ObjectInspector getObjectInspector() { + return rowObjectInspector; + } +} diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowWrapperWritable.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowWrapperWritable.java new file mode 100644 index 0000000000..32bcbbcca2 --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/ArrowWrapperWritable.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.io.arrow; + +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.hadoop.io.WritableComparable; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +/** + * Allows a VectorSchemaRoot to be passed to/from a RecordWriter/RecordReader + */ +public class ArrowWrapperWritable implements WritableComparable { + private VectorSchemaRoot vectorSchemaRoot; + + public ArrowWrapperWritable(VectorSchemaRoot vectorSchemaRoot) { + this.vectorSchemaRoot = vectorSchemaRoot; + } + public ArrowWrapperWritable() {} + + public VectorSchemaRoot getVectorSchemaRoot() { + return vectorSchemaRoot; + } + + public void setVectorSchemaRoot(VectorSchemaRoot vectorSchemaRoot) { + this.vectorSchemaRoot = vectorSchemaRoot; + } + + @Override + public void write(DataOutput dataOutput) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public void readFields(DataInput dataInput) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override public int compareTo(Object o) { + return 0; + } + + @Override public boolean equals(Object o) { + if (o instanceof ArrowWrapperWritable) { + ArrowWrapperWritable other = (ArrowWrapperWritable) o; + return other.vectorSchemaRoot.equals(vectorSchemaRoot); + } else { + return false; + } + } + + @Override + public int hashCode() { + return vectorSchemaRoot.hashCode(); + } +} diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java new file mode 100644 index 0000000000..fb5800b140 --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java @@ -0,0 +1,423 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.io.arrow; + +import io.netty.buffer.ArrowBuf; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.IntervalDayVector; +import org.apache.arrow.vector.IntervalYearVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeStampNanoVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.holders.NullableIntervalDayHolder; +import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.apache.hadoop.hive.common.type.HiveIntervalDayTime; +import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.IntervalDayTimeColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.ListColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.MapColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.StructColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.TimestampColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.UnionColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.VectorExtractRow; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedBatchUtil; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo; +import org.apache.hadoop.io.Writable; + +import java.util.List; + +import static org.apache.hadoop.hive.ql.exec.vector.VectorizedBatchUtil.createColumnVector; +import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.MS_PER_SECOND; +import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.NS_PER_MS; 
+import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.NS_PER_SECOND; +import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.SECOND_PER_DAY; +import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.toStructListTypeInfo; +import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.toStructListVector; + +class Deserializer { + private final ArrowColumnarBatchSerDe serDe; + private final VectorExtractRow vectorExtractRow; + private final VectorizedRowBatch vectorizedRowBatch; + private Object[][] rows; + + Deserializer(ArrowColumnarBatchSerDe serDe) throws SerDeException { + this.serDe = serDe; + vectorExtractRow = new VectorExtractRow(); + final List fieldTypeInfoList = serDe.rowTypeInfo.getAllStructFieldTypeInfos(); + final int fieldCount = fieldTypeInfoList.size(); + final TypeInfo[] typeInfos = fieldTypeInfoList.toArray(new TypeInfo[fieldCount]); + try { + vectorExtractRow.init(typeInfos); + } catch (HiveException e) { + throw new SerDeException(e); + } + + vectorizedRowBatch = new VectorizedRowBatch(fieldCount); + for (int fieldIndex = 0; fieldIndex < fieldCount; fieldIndex++) { + final ColumnVector columnVector = createColumnVector(typeInfos[fieldIndex]); + columnVector.init(); + vectorizedRowBatch.cols[fieldIndex] = columnVector; + } + } + + public Object deserialize(Writable writable) { + final ArrowWrapperWritable arrowWrapperWritable = (ArrowWrapperWritable) writable; + final VectorSchemaRoot vectorSchemaRoot = arrowWrapperWritable.getVectorSchemaRoot(); + final List fieldVectors = vectorSchemaRoot.getFieldVectors(); + final int fieldCount = fieldVectors.size(); + final int rowCount = vectorSchemaRoot.getRowCount(); + vectorizedRowBatch.ensureSize(rowCount); + + if (rows == null || rows.length < rowCount ) { + rows = new Object[rowCount][]; + for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { + rows[rowIndex] = new Object[fieldCount]; + } + } + + for (int fieldIndex = 0; fieldIndex < fieldCount; fieldIndex++) { + final FieldVector fieldVector = fieldVectors.get(fieldIndex); + final int projectedCol = vectorizedRowBatch.projectedColumns[fieldIndex]; + final ColumnVector columnVector = vectorizedRowBatch.cols[projectedCol]; + final TypeInfo typeInfo = serDe.rowTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex); + read(fieldVector, columnVector, typeInfo); + } + for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { + vectorExtractRow.extractRow(vectorizedRowBatch, rowIndex, rows[rowIndex]); + } + vectorizedRowBatch.reset(); + return rows; + } + + private void read(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo) { + switch (typeInfo.getCategory()) { + case PRIMITIVE: + readPrimitive(arrowVector, hiveVector, typeInfo); + break; + case LIST: + readList(arrowVector, (ListColumnVector) hiveVector, (ListTypeInfo) typeInfo); + break; + case MAP: + readMap(arrowVector, (MapColumnVector) hiveVector, (MapTypeInfo) typeInfo); + break; + case STRUCT: + readStruct(arrowVector, (StructColumnVector) hiveVector, (StructTypeInfo) typeInfo); + break; + case UNION: + readUnion(arrowVector, (UnionColumnVector) hiveVector, (UnionTypeInfo) typeInfo); + break; + default: + throw new IllegalArgumentException(); + } + } + + private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo) { + final PrimitiveObjectInspector.PrimitiveCategory primitiveCategory = + ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory(); + + final int size = arrowVector.getValueCount(); + 
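+    // Values are copied element by element: nulls are propagated via
+    // VectorizedBatchUtil.setNullColIsNullValue, and non-null values go into the Hive
+    // vector subtype backing the category (LongColumnVector for boolean, integral, date
+    // and interval_year_month values; DoubleColumnVector for float/double;
+    // BytesColumnVector for string and binary; DecimalColumnVector,
+    // TimestampColumnVector and IntervalDayTimeColumnVector for the rest).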
hiveVector.ensureSize(size, false); + + switch (primitiveCategory) { + case BOOLEAN: + { + for (int i = 0; i < size; i++) { + if (arrowVector.isNull(i)) { + VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i); + } else { + hiveVector.isNull[i] = false; + ((LongColumnVector) hiveVector).vector[i] = ((BitVector) arrowVector).get(i); + } + } + } + break; + case BYTE: + { + for (int i = 0; i < size; i++) { + if (arrowVector.isNull(i)) { + VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i); + } else { + hiveVector.isNull[i] = false; + ((LongColumnVector) hiveVector).vector[i] = ((TinyIntVector) arrowVector).get(i); + } + } + } + break; + case SHORT: + { + for (int i = 0; i < size; i++) { + if (arrowVector.isNull(i)) { + VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i); + } else { + hiveVector.isNull[i] = false; + ((LongColumnVector) hiveVector).vector[i] = ((SmallIntVector) arrowVector).get(i); + } + } + } + break; + case INT: + { + for (int i = 0; i < size; i++) { + if (arrowVector.isNull(i)) { + VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i); + } else { + hiveVector.isNull[i] = false; + ((LongColumnVector) hiveVector).vector[i] = ((IntVector) arrowVector).get(i); + } + } + } + break; + case LONG: + { + for (int i = 0; i < size; i++) { + if (arrowVector.isNull(i)) { + VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i); + } else { + hiveVector.isNull[i] = false; + ((LongColumnVector) hiveVector).vector[i] = ((BigIntVector) arrowVector).get(i); + } + } + } + break; + case FLOAT: + { + for (int i = 0; i < size; i++) { + if (arrowVector.isNull(i)) { + VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i); + } else { + hiveVector.isNull[i] = false; + ((DoubleColumnVector) hiveVector).vector[i] = ((Float4Vector) arrowVector).get(i); + } + } + } + break; + case DOUBLE: + { + for (int i = 0; i < size; i++) { + if (arrowVector.isNull(i)) { + VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i); + } else { + hiveVector.isNull[i] = false; + ((DoubleColumnVector) hiveVector).vector[i] = ((Float8Vector) arrowVector).get(i); + } + } + } + break; + case STRING: + case VARCHAR: + case CHAR: + { + for (int i = 0; i < size; i++) { + if (arrowVector.isNull(i)) { + VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i); + } else { + hiveVector.isNull[i] = false; + ((BytesColumnVector) hiveVector).setVal(i, ((VarCharVector) arrowVector).get(i)); + } + } + } + break; + case DATE: + { + for (int i = 0; i < size; i++) { + if (arrowVector.isNull(i)) { + VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i); + } else { + hiveVector.isNull[i] = false; + ((LongColumnVector) hiveVector).vector[i] = ((DateDayVector) arrowVector).get(i); + } + } + } + break; + case TIMESTAMP: + { + for (int i = 0; i < size; i++) { + if (arrowVector.isNull(i)) { + VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i); + } else { + hiveVector.isNull[i] = false; + + // Time = second + sub-second + final long timeInNanos = ((TimeStampNanoVector) arrowVector).get(i); + final TimestampColumnVector timestampColumnVector = (TimestampColumnVector) hiveVector; + int subSecondInNanos = (int) (timeInNanos % NS_PER_SECOND); + long second = timeInNanos / NS_PER_SECOND; + + // A nanosecond value should not be negative + if (subSecondInNanos < 0) { + + // So add one second to the negative nanosecond value to make it positive + subSecondInNanos += NS_PER_SECOND; + + // Subtract one second from the second value because we added one second, + // then subtract one more second because of the ceiling in the 
division. + second -= 2; + } + timestampColumnVector.time[i] = second * MS_PER_SECOND; + timestampColumnVector.nanos[i] = subSecondInNanos; + } + } + } + break; + case BINARY: + { + for (int i = 0; i < size; i++) { + if (arrowVector.isNull(i)) { + VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i); + } else { + hiveVector.isNull[i] = false; + ((BytesColumnVector) hiveVector).setVal(i, ((VarBinaryVector) arrowVector).get(i)); + } + } + } + break; + case DECIMAL: + { + for (int i = 0; i < size; i++) { + if (arrowVector.isNull(i)) { + VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i); + } else { + hiveVector.isNull[i] = false; + ((DecimalColumnVector) hiveVector).set(i, + HiveDecimal.create(((DecimalVector) arrowVector).getObject(i))); + } + } + } + break; + case INTERVAL_YEAR_MONTH: + { + for (int i = 0; i < size; i++) { + if (arrowVector.isNull(i)) { + VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i); + } else { + hiveVector.isNull[i] = false; + ((LongColumnVector) hiveVector).vector[i] = ((IntervalYearVector) arrowVector).get(i); + } + } + } + break; + case INTERVAL_DAY_TIME: + { + final IntervalDayVector intervalDayVector = (IntervalDayVector) arrowVector; + final NullableIntervalDayHolder intervalDayHolder = new NullableIntervalDayHolder(); + final HiveIntervalDayTime intervalDayTime = new HiveIntervalDayTime(); + for (int i = 0; i < size; i++) { + if (arrowVector.isNull(i)) { + VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i); + } else { + hiveVector.isNull[i] = false; + intervalDayVector.get(i, intervalDayHolder); + final long seconds = intervalDayHolder.days * SECOND_PER_DAY + + intervalDayHolder.milliseconds / MS_PER_SECOND; + final int nanos = (intervalDayHolder.milliseconds % 1_000) * NS_PER_MS; + intervalDayTime.set(seconds, nanos); + ((IntervalDayTimeColumnVector) hiveVector).set(i, intervalDayTime); + } + } + } + break; + case VOID: + case TIMESTAMPLOCALTZ: + case UNKNOWN: + default: + break; + } + } + + private void readList(FieldVector arrowVector, ListColumnVector hiveVector, ListTypeInfo typeInfo) { + final int size = arrowVector.getValueCount(); + final ArrowBuf offsets = arrowVector.getOffsetBuffer(); + final int OFFSET_WIDTH = 4; + + read(arrowVector.getChildrenFromFields().get(0), + hiveVector.child, + typeInfo.getListElementTypeInfo()); + + for (int i = 0; i < size; i++) { + if (arrowVector.isNull(i)) { + VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i); + } else { + hiveVector.isNull[i] = false; + final int offset = offsets.getInt(i * OFFSET_WIDTH); + hiveVector.offsets[i] = offset; + hiveVector.lengths[i] = offsets.getInt((i + 1) * OFFSET_WIDTH) - offset; + } + } + } + + private void readMap(FieldVector arrowVector, MapColumnVector hiveVector, MapTypeInfo typeInfo) { + final int size = arrowVector.getValueCount(); + final ListTypeInfo mapStructListTypeInfo = toStructListTypeInfo(typeInfo); + final ListColumnVector mapStructListVector = toStructListVector(hiveVector); + final StructColumnVector mapStructVector = (StructColumnVector) mapStructListVector.child; + + read(arrowVector, mapStructListVector, mapStructListTypeInfo); + + hiveVector.isRepeating = mapStructListVector.isRepeating; + hiveVector.childCount = mapStructListVector.childCount; + hiveVector.noNulls = mapStructListVector.noNulls; + hiveVector.keys = mapStructVector.fields[0]; + hiveVector.values = mapStructVector.fields[1]; + System.arraycopy(mapStructListVector.offsets, 0, hiveVector.offsets, 0, size); + System.arraycopy(mapStructListVector.lengths, 0, 
hiveVector.lengths, 0, size); + System.arraycopy(mapStructListVector.isNull, 0, hiveVector.isNull, 0, size); + } + + private void readStruct(FieldVector arrowVector, StructColumnVector hiveVector, StructTypeInfo typeInfo) { + final int size = arrowVector.getValueCount(); + final List fieldTypeInfos = typeInfo.getAllStructFieldTypeInfos(); + final int fieldSize = arrowVector.getChildrenFromFields().size(); + for (int i = 0; i < fieldSize; i++) { + read(arrowVector.getChildrenFromFields().get(i), hiveVector.fields[i], fieldTypeInfos.get(i)); + } + + for (int i = 0; i < size; i++) { + if (arrowVector.isNull(i)) { + VectorizedBatchUtil.setNullColIsNullValue(hiveVector, i); + } else { + hiveVector.isNull[i] = false; + } + } + } + + private void readUnion(FieldVector arrowVector, UnionColumnVector hiveVector, UnionTypeInfo typeInfo) { + } +} diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/RootAllocatorFactory.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/RootAllocatorFactory.java new file mode 100644 index 0000000000..7aa732bd5c --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/RootAllocatorFactory.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.io.arrow; + +import org.apache.arrow.memory.RootAllocator; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.conf.HiveConf; + +import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.HIVE_ARROW_ROOT_ALLOCATOR_LIMIT; + +/** + * Thread-safe singleton factory for RootAllocator + */ +public enum RootAllocatorFactory { + INSTANCE; + + private RootAllocator rootAllocator; + + RootAllocatorFactory() { + } + + public synchronized RootAllocator getRootAllocator(Configuration conf) { + if (rootAllocator == null) { + final long limit = HiveConf.getLongVar(conf, HIVE_ARROW_ROOT_ALLOCATOR_LIMIT); + rootAllocator = new RootAllocator(limit); + } + return rootAllocator; + } + + //arrowAllocatorLimit is ignored if an allocator was previously created + public synchronized RootAllocator getOrCreateRootAllocator(long arrowAllocatorLimit) { + if (rootAllocator == null) { + rootAllocator = new RootAllocator(arrowAllocatorLimit); + } + return rootAllocator; + } +} diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java new file mode 100644 index 0000000000..bd23011c93 --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Serializer.java @@ -0,0 +1,537 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.io.arrow; + +import io.netty.buffer.ArrowBuf; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.BitVectorHelper; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.IntervalDayVector; +import org.apache.arrow.vector.IntervalYearVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeStampNanoVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.NullableMapVector; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.IntervalDayTimeColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.ListColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.MapColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.StructColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.TimestampColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.UnionColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.VectorAssignRow; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.UnionTypeInfo; + +import 
java.util.ArrayList; +import java.util.List; + +import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.HIVE_ARROW_BATCH_SIZE; +import static org.apache.hadoop.hive.ql.exec.vector.VectorizedBatchUtil.createColumnVector; +import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.MS_PER_SECOND; +import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.NS_PER_MS; +import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.SECOND_PER_DAY; +import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.toStructListTypeInfo; +import static org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe.toStructListVector; +import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption.WRITABLE; +import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils.getTypeInfoFromObjectInspector; + +class Serializer { + private final int MAX_BUFFERED_ROWS; + + // Schema + private final StructTypeInfo structTypeInfo; + private final int fieldSize; + + // Hive columns + private final VectorizedRowBatch vectorizedRowBatch; + private final VectorAssignRow vectorAssignRow; + private int batchSize; + + private final NullableMapVector rootVector; + + Serializer(ArrowColumnarBatchSerDe serDe) throws SerDeException { + MAX_BUFFERED_ROWS = HiveConf.getIntVar(serDe.conf, HIVE_ARROW_BATCH_SIZE); + ArrowColumnarBatchSerDe.LOG.info("ArrowColumnarBatchSerDe max number of buffered columns: " + MAX_BUFFERED_ROWS); + + // Schema + structTypeInfo = (StructTypeInfo) getTypeInfoFromObjectInspector(serDe.rowObjectInspector); + List fieldTypeInfos = structTypeInfo.getAllStructFieldTypeInfos(); + fieldSize = fieldTypeInfos.size(); + + // Init Arrow stuffs + rootVector = NullableMapVector.empty(null, serDe.rootAllocator); + + // Init Hive stuffs + vectorizedRowBatch = new VectorizedRowBatch(fieldSize); + for (int fieldIndex = 0; fieldIndex < fieldSize; fieldIndex++) { + final ColumnVector columnVector = createColumnVector(fieldTypeInfos.get(fieldIndex)); + vectorizedRowBatch.cols[fieldIndex] = columnVector; + columnVector.init(); + } + vectorizedRowBatch.ensureSize(MAX_BUFFERED_ROWS); + vectorAssignRow = new VectorAssignRow(); + try { + vectorAssignRow.init(serDe.rowObjectInspector); + } catch (HiveException e) { + throw new SerDeException(e); + } + } + + private ArrowWrapperWritable serializeBatch() { + rootVector.setValueCount(0); + + for (int fieldIndex = 0; fieldIndex < vectorizedRowBatch.projectionSize; fieldIndex++) { + final int projectedColumn = vectorizedRowBatch.projectedColumns[fieldIndex]; + final ColumnVector hiveVector = vectorizedRowBatch.cols[projectedColumn]; + final TypeInfo fieldTypeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex); + final String fieldName = structTypeInfo.getAllStructFieldNames().get(fieldIndex); + final FieldType fieldType = toFieldType(fieldTypeInfo); + final FieldVector arrowVector = rootVector.addOrGet(fieldName, fieldType, FieldVector.class); + arrowVector.setInitialCapacity(batchSize); + arrowVector.allocateNew(); + write(arrowVector, hiveVector, fieldTypeInfo, batchSize); + } + vectorizedRowBatch.reset(); + rootVector.setValueCount(batchSize); + + batchSize = 0; + VectorSchemaRoot vectorSchemaRoot = new VectorSchemaRoot(rootVector); + return new ArrowWrapperWritable(vectorSchemaRoot); + } + + private FieldType toFieldType(TypeInfo typeInfo) { + return new FieldType(true, toArrowType(typeInfo), null); + } + + private ArrowType toArrowType(TypeInfo typeInfo) { + switch 
(typeInfo.getCategory()) { + case PRIMITIVE: + switch (((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory()) { + case BOOLEAN: + return Types.MinorType.BIT.getType(); + case BYTE: + return Types.MinorType.TINYINT.getType(); + case SHORT: + return Types.MinorType.SMALLINT.getType(); + case INT: + return Types.MinorType.INT.getType(); + case LONG: + return Types.MinorType.BIGINT.getType(); + case FLOAT: + return Types.MinorType.FLOAT4.getType(); + case DOUBLE: + return Types.MinorType.FLOAT8.getType(); + case STRING: + case VARCHAR: + case CHAR: + return Types.MinorType.VARCHAR.getType(); + case DATE: + return Types.MinorType.DATEDAY.getType(); + case TIMESTAMP: + return Types.MinorType.TIMESTAMPNANO.getType(); + case BINARY: + return Types.MinorType.VARBINARY.getType(); + case DECIMAL: + final DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) typeInfo; + return new ArrowType.Decimal(decimalTypeInfo.precision(), decimalTypeInfo.scale()); + case INTERVAL_YEAR_MONTH: + return Types.MinorType.INTERVALYEAR.getType(); + case INTERVAL_DAY_TIME: + return Types.MinorType.INTERVALDAY.getType(); + case VOID: + case TIMESTAMPLOCALTZ: + case UNKNOWN: + default: + throw new IllegalArgumentException(); + } + case LIST: + return ArrowType.List.INSTANCE; + case STRUCT: + return ArrowType.Struct.INSTANCE; + case MAP: + return ArrowType.List.INSTANCE; + case UNION: + default: + throw new IllegalArgumentException(); + } + } + + private void write(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo, int size) { + switch (typeInfo.getCategory()) { + case PRIMITIVE: + writePrimitive(arrowVector, hiveVector, typeInfo, size); + break; + case LIST: + writeList((ListVector) arrowVector, (ListColumnVector) hiveVector, (ListTypeInfo) typeInfo, size); + break; + case STRUCT: + writeStruct((MapVector) arrowVector, (StructColumnVector) hiveVector, (StructTypeInfo) typeInfo, size); + break; + case UNION: + writeUnion(arrowVector, hiveVector, typeInfo, size); + break; + case MAP: + writeMap((ListVector) arrowVector, (MapColumnVector) hiveVector, (MapTypeInfo) typeInfo, size); + break; + default: + throw new IllegalArgumentException(); + } + } + + private void writeMap(ListVector arrowVector, MapColumnVector hiveVector, MapTypeInfo typeInfo, + int size) { + final ListTypeInfo structListTypeInfo = toStructListTypeInfo(typeInfo); + final ListColumnVector structListVector = toStructListVector(hiveVector); + + write(arrowVector, structListVector, structListTypeInfo, size); + + final ArrowBuf validityBuffer = arrowVector.getValidityBuffer(); + for (int rowIndex = 0; rowIndex < size; rowIndex++) { + if (hiveVector.isNull[rowIndex]) { + BitVectorHelper.setValidityBit(validityBuffer, rowIndex, 0); + } else { + BitVectorHelper.setValidityBitToOne(validityBuffer, rowIndex); + } + } + } + + private void writeUnion(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo, + int size) { + final UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; + final List objectTypeInfos = unionTypeInfo.getAllUnionObjectTypeInfos(); + final UnionColumnVector hiveUnionVector = (UnionColumnVector) hiveVector; + final ColumnVector[] hiveObjectVectors = hiveUnionVector.fields; + + final int tag = hiveUnionVector.tags[0]; + final ColumnVector hiveObjectVector = hiveObjectVectors[tag]; + final TypeInfo objectTypeInfo = objectTypeInfos.get(tag); + + write(arrowVector, hiveObjectVector, objectTypeInfo, size); + } + + private void writeStruct(MapVector arrowVector, StructColumnVector hiveVector, + StructTypeInfo typeInfo, int 
size) { + final List fieldNames = typeInfo.getAllStructFieldNames(); + final List fieldTypeInfos = typeInfo.getAllStructFieldTypeInfos(); + final ColumnVector[] hiveFieldVectors = hiveVector.fields; + final int fieldSize = fieldTypeInfos.size(); + + for (int fieldIndex = 0; fieldIndex < fieldSize; fieldIndex++) { + final TypeInfo fieldTypeInfo = fieldTypeInfos.get(fieldIndex); + final ColumnVector hiveFieldVector = hiveFieldVectors[fieldIndex]; + final String fieldName = fieldNames.get(fieldIndex); + final FieldVector arrowFieldVector = + arrowVector.addOrGet(fieldName, + toFieldType(fieldTypeInfos.get(fieldIndex)), FieldVector.class); + arrowFieldVector.setInitialCapacity(size); + arrowFieldVector.allocateNew(); + write(arrowFieldVector, hiveFieldVector, fieldTypeInfo, size); + } + + final ArrowBuf validityBuffer = arrowVector.getValidityBuffer(); + for (int rowIndex = 0; rowIndex < size; rowIndex++) { + if (hiveVector.isNull[rowIndex]) { + BitVectorHelper.setValidityBit(validityBuffer, rowIndex, 0); + } else { + BitVectorHelper.setValidityBitToOne(validityBuffer, rowIndex); + } + } + } + + private void writeList(ListVector arrowVector, ListColumnVector hiveVector, ListTypeInfo typeInfo, + int size) { + final int OFFSET_WIDTH = 4; + final TypeInfo elementTypeInfo = typeInfo.getListElementTypeInfo(); + final ColumnVector hiveElementVector = hiveVector.child; + final FieldVector arrowElementVector = + (FieldVector) arrowVector.addOrGetVector(toFieldType(elementTypeInfo)).getVector(); + arrowElementVector.setInitialCapacity(hiveVector.childCount); + arrowElementVector.allocateNew(); + + write(arrowElementVector, hiveElementVector, elementTypeInfo, hiveVector.childCount); + + final ArrowBuf offsetBuffer = arrowVector.getOffsetBuffer(); + int nextOffset = 0; + + for (int rowIndex = 0; rowIndex < size; rowIndex++) { + if (hiveVector.isNull[rowIndex]) { + offsetBuffer.setInt(rowIndex * OFFSET_WIDTH, nextOffset); + } else { + offsetBuffer.setInt(rowIndex * OFFSET_WIDTH, nextOffset); + nextOffset += (int) hiveVector.lengths[rowIndex]; + arrowVector.setNotNull(rowIndex); + } + } + offsetBuffer.setInt(size * OFFSET_WIDTH, nextOffset); + } + + private void writePrimitive(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo, + int size) { + final PrimitiveObjectInspector.PrimitiveCategory primitiveCategory = + ((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory(); + switch (primitiveCategory) { + case BOOLEAN: + { + final BitVector bitVector = (BitVector) arrowVector; + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + bitVector.setNull(i); + } else { + bitVector.set(i, (int) ((LongColumnVector) hiveVector).vector[i]); + } + } + } + break; + case BYTE: + { + final TinyIntVector tinyIntVector = (TinyIntVector) arrowVector; + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + tinyIntVector.setNull(i); + } else { + tinyIntVector.set(i, (byte) ((LongColumnVector) hiveVector).vector[i]); + } + } + } + break; + case SHORT: + { + final SmallIntVector smallIntVector = (SmallIntVector) arrowVector; + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + smallIntVector.setNull(i); + } else { + smallIntVector.set(i, (short) ((LongColumnVector) hiveVector).vector[i]); + } + } + } + break; + case INT: + { + final IntVector intVector = (IntVector) arrowVector; + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + intVector.setNull(i); + } else { + intVector.set(i, (int) ((LongColumnVector) hiveVector).vector[i]); + } + } + } + break; + case LONG: + 
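+        // BOOLEAN through INT above narrow the value held in the shared
+        // LongColumnVector before writing it; BIGINT below copies the long as-is.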
{ + final BigIntVector bigIntVector = (BigIntVector) arrowVector; + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + bigIntVector.setNull(i); + } else { + bigIntVector.set(i, ((LongColumnVector) hiveVector).vector[i]); + } + } + } + break; + case FLOAT: + { + final Float4Vector float4Vector = (Float4Vector) arrowVector; + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + float4Vector.setNull(i); + } else { + float4Vector.set(i, (float) ((DoubleColumnVector) hiveVector).vector[i]); + } + } + } + break; + case DOUBLE: + { + final Float8Vector float8Vector = (Float8Vector) arrowVector; + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + float8Vector.setNull(i); + } else { + float8Vector.set(i, ((DoubleColumnVector) hiveVector).vector[i]); + } + } + } + break; + case STRING: + case VARCHAR: + case CHAR: + { + final VarCharVector varCharVector = (VarCharVector) arrowVector; + final BytesColumnVector bytesVector = (BytesColumnVector) hiveVector; + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + varCharVector.setNull(i); + } else { + varCharVector.setSafe(i, bytesVector.vector[i], bytesVector.start[i], bytesVector.length[i]); + } + } + } + break; + case DATE: + { + final DateDayVector dateDayVector = (DateDayVector) arrowVector; + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + dateDayVector.setNull(i); + } else { + dateDayVector.set(i, (int) ((LongColumnVector) hiveVector).vector[i]); + } + } + } + break; + case TIMESTAMP: + { + final TimeStampNanoVector timeStampNanoVector = (TimeStampNanoVector) arrowVector; + final TimestampColumnVector timestampColumnVector = (TimestampColumnVector) hiveVector; + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + timeStampNanoVector.setNull(i); + } else { + // Time = second + sub-second + final long secondInMillis = timestampColumnVector.getTime(i); + final long secondInNanos = (secondInMillis - secondInMillis % 1000) * NS_PER_MS; // second + final long subSecondInNanos = timestampColumnVector.getNanos(i); // sub-second + + if ((secondInMillis > 0 && secondInNanos < 0) || (secondInMillis < 0 && secondInNanos > 0)) { + // If the timestamp cannot be represented in long nanosecond, set it as a null value + timeStampNanoVector.setNull(i); + } else { + timeStampNanoVector.set(i, secondInNanos + subSecondInNanos); + } + } + } + } + break; + case BINARY: + { + final VarBinaryVector varBinaryVector = (VarBinaryVector) arrowVector; + final BytesColumnVector bytesVector = (BytesColumnVector) hiveVector; + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + varBinaryVector.setNull(i); + } else { + varBinaryVector.setSafe(i, bytesVector.vector[i], bytesVector.start[i], bytesVector.length[i]); + } + } + } + break; + case DECIMAL: + { + final DecimalVector decimalVector = (DecimalVector) arrowVector; + final int scale = decimalVector.getScale(); + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + decimalVector.setNull(i); + } else { + decimalVector.set(i, + ((DecimalColumnVector) hiveVector).vector[i].getHiveDecimal().bigDecimalValue().setScale(scale)); + } + } + } + break; + case INTERVAL_YEAR_MONTH: + { + final IntervalYearVector intervalYearVector = (IntervalYearVector) arrowVector; + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + intervalYearVector.setNull(i); + } else { + intervalYearVector.set(i, (int) ((LongColumnVector) hiveVector).vector[i]); + } + } + } + break; + case INTERVAL_DAY_TIME: + { + final IntervalDayVector 
intervalDayVector = (IntervalDayVector) arrowVector; + final IntervalDayTimeColumnVector intervalDayTimeColumnVector = + (IntervalDayTimeColumnVector) hiveVector; + for (int i = 0; i < size; i++) { + if (hiveVector.isNull[i]) { + intervalDayVector.setNull(i); + } else { + final long totalSeconds = intervalDayTimeColumnVector.getTotalSeconds(i); + final long days = totalSeconds / SECOND_PER_DAY; + final long millis = + (totalSeconds - days * SECOND_PER_DAY) * MS_PER_SECOND + + intervalDayTimeColumnVector.getNanos(i) / NS_PER_MS; + intervalDayVector.set(i, (int) days, (int) millis); + } + } + } + break; + case VOID: + case UNKNOWN: + case TIMESTAMPLOCALTZ: + default: + throw new IllegalArgumentException(); + } + } + + ArrowWrapperWritable serialize(Object obj, ObjectInspector objInspector) { + // if row is null, it means there are no more rows (closeOp()). + // another case can be that the buffer is full. + if (obj == null) { + return serializeBatch(); + } + List standardObjects = new ArrayList(); + ObjectInspectorUtils.copyToStandardObject(standardObjects, obj, + ((StructObjectInspector) objInspector), WRITABLE); + + vectorAssignRow.assignRow(vectorizedRowBatch, batchSize, standardObjects, fieldSize); + batchSize++; + if (batchSize == MAX_BUFFERED_ROWS) { + return serializeBatch(); + } + return null; + } +} diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java index 263de3a5b1..e23e4033ff 100644 --- ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java +++ ql/src/java/org/apache/hadoop/hive/ql/parse/SemanticAnalyzer.java @@ -110,6 +110,7 @@ import org.apache.hadoop.hive.ql.io.AcidOutputFormat; import org.apache.hadoop.hive.ql.io.AcidUtils; import org.apache.hadoop.hive.ql.io.AcidUtils.Operation; +import org.apache.hadoop.hive.ql.io.arrow.ArrowColumnarBatchSerDe; import org.apache.hadoop.hive.ql.io.CombineHiveInputFormat; import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat; import org.apache.hadoop.hive.ql.io.HiveOutputFormat; @@ -7498,7 +7499,12 @@ protected Operator genFileSinkPlan(String dest, QB qb, Operator input) fileFormat = HiveConf.getVar(conf, HiveConf.ConfVars.HIVEQUERYRESULTFILEFORMAT); Class serdeClass = LazySimpleSerDe.class; if (fileFormat.equals(PlanUtils.LLAP_OUTPUT_FORMAT_KEY)) { - serdeClass = LazyBinarySerDe2.class; + boolean useArrow = HiveConf.getBoolVar(conf, HiveConf.ConfVars.LLAP_OUTPUT_FORMAT_ARROW); + if(useArrow) { + serdeClass = ArrowColumnarBatchSerDe.class; + } else { + serdeClass = LazyBinarySerDe2.class; + } } table_desc = PlanUtils.getDefaultQueryOutputTableDesc(cols, colTypes, fileFormat, @@ -7579,13 +7585,10 @@ protected Operator genFileSinkPlan(String dest, QB qb, Operator input) ltd.setInsertOverwrite(true); } } - if (SessionState.get().isHiveServerQuery() && - null != table_desc && - table_desc.getSerdeClassName().equalsIgnoreCase(ThriftJDBCBinarySerDe.class.getName()) && - HiveConf.getBoolVar(conf,HiveConf.ConfVars.HIVE_SERVER2_THRIFT_RESULTSET_SERIALIZE_IN_TASKS)) { - fileSinkDesc.setIsUsingThriftJDBCBinarySerDe(true); + if (null != table_desc && useBatchingSerializer(table_desc.getSerdeClassName())) { + fileSinkDesc.setIsUsingBatchingSerDe(true); } else { - fileSinkDesc.setIsUsingThriftJDBCBinarySerDe(false); + fileSinkDesc.setIsUsingBatchingSerDe(false); } Operator output = putOpInsertMap(OperatorFactory.getAndMakeChild( @@ -7620,6 +7623,17 @@ protected Operator genFileSinkPlan(String dest, QB qb, Operator input) return output; } + 
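+  // A result SerDe counts as "batching" when it buffers rows before handing them to the
+  // output format: ThriftJDBCBinarySerDe with
+  // HIVE_SERVER2_THRIFT_RESULTSET_SERIALIZE_IN_TASKS enabled, or ArrowColumnarBatchSerDe.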
private boolean useBatchingSerializer(String serdeClassName) { + return SessionState.get().isHiveServerQuery() && + hasSetBatchSerializer(serdeClassName); + } + + private boolean hasSetBatchSerializer(String serdeClassName) { + return (serdeClassName.equalsIgnoreCase(ThriftJDBCBinarySerDe.class.getName()) && + HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_SERVER2_THRIFT_RESULTSET_SERIALIZE_IN_TASKS)) || + serdeClassName.equalsIgnoreCase(ArrowColumnarBatchSerDe.class.getName()); + } + private ColsAndTypes deriveFileSinkColTypes( RowResolver inputRR, List field_schemas) throws SemanticException { ColsAndTypes result = new ColsAndTypes("", ""); diff --git ql/src/java/org/apache/hadoop/hive/ql/plan/FileSinkDesc.java ql/src/java/org/apache/hadoop/hive/ql/plan/FileSinkDesc.java index fcb6de7d08..1d054688ee 100644 --- ql/src/java/org/apache/hadoop/hive/ql/plan/FileSinkDesc.java +++ ql/src/java/org/apache/hadoop/hive/ql/plan/FileSinkDesc.java @@ -103,9 +103,9 @@ /** * Whether is a HiveServer query, and the destination table is - * indeed written using ThriftJDBCBinarySerDe + * indeed written using a row batching SerDe */ - private boolean isUsingThriftJDBCBinarySerDe = false; + private boolean isUsingBatchingSerDe = false; private boolean isInsertOverwrite = false; @@ -183,12 +183,12 @@ public void setHiveServerQuery(boolean isHiveServerQuery) { this.isHiveServerQuery = isHiveServerQuery; } - public boolean isUsingThriftJDBCBinarySerDe() { - return this.isUsingThriftJDBCBinarySerDe; + public boolean isUsingBatchingSerDe() { + return this.isUsingBatchingSerDe; } - public void setIsUsingThriftJDBCBinarySerDe(boolean isUsingThriftJDBCBinarySerDe) { - this.isUsingThriftJDBCBinarySerDe = isUsingThriftJDBCBinarySerDe; + public void setIsUsingBatchingSerDe(boolean isUsingBatchingSerDe) { + this.isUsingBatchingSerDe = isUsingBatchingSerDe; } @Explain(displayName = "directory", explainLevels = { Level.EXTENDED }) diff --git ql/src/test/org/apache/hadoop/hive/llap/TestLlapOutputFormat.java ql/src/test/org/apache/hadoop/hive/llap/TestLlapOutputFormat.java index 13a3070ef6..f27cdf4969 100644 --- ql/src/test/org/apache/hadoop/hive/llap/TestLlapOutputFormat.java +++ ql/src/test/org/apache/hadoop/hive/llap/TestLlapOutputFormat.java @@ -54,6 +54,7 @@ public static void setUp() throws Exception { Configuration conf = new Configuration(); // Pick random avail port HiveConf.setIntVar(conf, HiveConf.ConfVars.LLAP_DAEMON_OUTPUT_SERVICE_PORT, 0); + HiveConf.setBoolVar(conf, HiveConf.ConfVars.LLAP_OUTPUT_FORMAT_ARROW, false); LlapOutputFormatService.initializeAndStart(conf, null); service = LlapOutputFormatService.get(); LlapProxy.setDaemon(true); diff --git ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java new file mode 100644 index 0000000000..74f6624597 --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/io/arrow/TestArrowColumnarBatchSerDe.java @@ -0,0 +1,777 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.io.arrow; + +import com.google.common.base.Joiner; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.common.type.HiveChar; +import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.apache.hadoop.hive.common.type.HiveIntervalDayTime; +import org.apache.hadoop.hive.common.type.HiveIntervalYearMonth; +import org.apache.hadoop.hive.common.type.HiveVarchar; +import org.apache.hadoop.hive.serde.serdeConstants; +import org.apache.hadoop.hive.serde2.AbstractSerDe; +import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.hive.serde2.SerDeUtils; +import org.apache.hadoop.hive.serde2.io.ByteWritable; +import org.apache.hadoop.hive.serde2.io.DateWritable; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.io.HiveCharWritable; +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; +import org.apache.hadoop.hive.serde2.io.HiveIntervalDayTimeWritable; +import org.apache.hadoop.hive.serde2.io.HiveIntervalYearMonthWritable; +import org.apache.hadoop.hive.serde2.io.HiveVarcharWritable; +import org.apache.hadoop.hive.serde2.io.ShortWritable; +import org.apache.hadoop.hive.serde2.io.TimestampWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; +import org.apache.hadoop.io.BooleanWritable; +import org.apache.hadoop.io.BytesWritable; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; +import org.junit.Before; +import org.junit.Test; + +import java.sql.Timestamp; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Properties; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +public class TestArrowColumnarBatchSerDe { + private Configuration conf; + + private final static Object[][] INTEGER_ROWS = { + {byteW(0), shortW(0), intW(0), longW(0)}, + {byteW(1), shortW(1), intW(1), longW(1)}, + {byteW(-1), shortW(-1), intW(-1), longW(-1)}, + {byteW(Byte.MIN_VALUE), shortW(Short.MIN_VALUE), intW(Integer.MIN_VALUE), + longW(Long.MIN_VALUE)}, + {byteW(Byte.MAX_VALUE), shortW(Short.MAX_VALUE), intW(Integer.MAX_VALUE), + longW(Long.MAX_VALUE)}, + {null, null, null, null}, + }; + + private final static Object[][] FLOAT_ROWS = { + {floatW(0f), doubleW(0d)}, + {floatW(1f), doubleW(1d)}, + 
{floatW(-1f), doubleW(-1d)}, + {floatW(Float.MIN_VALUE), doubleW(Double.MIN_VALUE)}, + {floatW(-Float.MIN_VALUE), doubleW(-Double.MIN_VALUE)}, + {floatW(Float.MAX_VALUE), doubleW(Double.MAX_VALUE)}, + {floatW(-Float.MAX_VALUE), doubleW(-Double.MAX_VALUE)}, + {floatW(Float.POSITIVE_INFINITY), doubleW(Double.POSITIVE_INFINITY)}, + {floatW(Float.NEGATIVE_INFINITY), doubleW(Double.NEGATIVE_INFINITY)}, + {null, null}, + }; + + private final static Object[][] STRING_ROWS = { + {text(""), charW("", 10), varcharW("", 10)}, + {text("Hello"), charW("Hello", 10), varcharW("Hello", 10)}, + {text("world!"), charW("world!", 10), varcharW("world!", 10)}, + {null, null, null}, + }; + + private final static long TIME_IN_MS = TimeUnit.DAYS.toMillis(365 + 31 + 3); + private final static long NEGATIVE_TIME_IN_MS = TimeUnit.DAYS.toMillis(-9 * 365 + 31 + 3); + private final static Timestamp TIMESTAMP; + private final static Timestamp NEGATIVE_TIMESTAMP_WITHOUT_NANOS; + private final static Timestamp NEGATIVE_TIMESTAMP_WITH_NANOS; + + static { + TIMESTAMP = new Timestamp(TIME_IN_MS); + TIMESTAMP.setNanos(123456789); + NEGATIVE_TIMESTAMP_WITHOUT_NANOS = new Timestamp(NEGATIVE_TIME_IN_MS); + NEGATIVE_TIMESTAMP_WITH_NANOS = new Timestamp(NEGATIVE_TIME_IN_MS); + NEGATIVE_TIMESTAMP_WITH_NANOS.setNanos(123456789); + } + + private final static Object[][] DTI_ROWS = { + { + new DateWritable(DateWritable.millisToDays(TIME_IN_MS)), + new TimestampWritable(TIMESTAMP), + new HiveIntervalYearMonthWritable(new HiveIntervalYearMonth(1, 2)), + new HiveIntervalDayTimeWritable(new HiveIntervalDayTime(1, 2, 3, 4, 5_000_000)) + }, + { + new DateWritable(DateWritable.millisToDays(NEGATIVE_TIME_IN_MS)), + new TimestampWritable(NEGATIVE_TIMESTAMP_WITHOUT_NANOS), + null, + null + }, + { + null, + new TimestampWritable(NEGATIVE_TIMESTAMP_WITH_NANOS), + null, + null + }, + {null, null, null, null}, + }; + + private final static Object[][] DECIMAL_ROWS = { + {decimalW(HiveDecimal.ZERO)}, + {decimalW(HiveDecimal.ONE)}, + {decimalW(HiveDecimal.ONE.negate())}, + {decimalW(HiveDecimal.create("0.000001"))}, + {decimalW(HiveDecimal.create("100000"))}, + {null}, + }; + + private final static Object[][] BOOLEAN_ROWS = { + {new BooleanWritable(true)}, + {new BooleanWritable(false)}, + {null}, + }; + + private final static Object[][] BINARY_ROWS = { + {new BytesWritable("".getBytes())}, + {new BytesWritable("Hello".getBytes())}, + {new BytesWritable("world!".getBytes())}, + {null}, + }; + + @Before + public void setUp() { + conf = new Configuration(); + } + + private static ByteWritable byteW(int value) { + return new ByteWritable((byte) value); + } + + private static ShortWritable shortW(int value) { + return new ShortWritable((short) value); + } + + private static IntWritable intW(int value) { + return new IntWritable(value); + } + + private static LongWritable longW(long value) { + return new LongWritable(value); + } + + private static FloatWritable floatW(float value) { + return new FloatWritable(value); + } + + private static DoubleWritable doubleW(double value) { + return new DoubleWritable(value); + } + + private static Text text(String value) { + return new Text(value); + } + + private static HiveCharWritable charW(String value, int length) { + return new HiveCharWritable(new HiveChar(value, length)); + } + + private static HiveVarcharWritable varcharW(String value, int length) { + return new HiveVarcharWritable(new HiveVarchar(value, length)); + } + + private static HiveDecimalWritable decimalW(HiveDecimal value) { + return new 
HiveDecimalWritable(value); + } + + private void initAndSerializeAndDeserialize(String[][] schema, Object[][] rows) throws SerDeException { + ArrowColumnarBatchSerDe serDe = new ArrowColumnarBatchSerDe(); + StructObjectInspector rowOI = initSerDe(serDe, schema); + serializeAndDeserialize(serDe, rows, rowOI); + } + + private StructObjectInspector initSerDe(AbstractSerDe serDe, String[][] schema) + throws SerDeException { + List fieldNameList = newArrayList(); + List fieldTypeList = newArrayList(); + List typeInfoList = newArrayList(); + + for (String[] nameAndType : schema) { + String name = nameAndType[0]; + String type = nameAndType[1]; + fieldNameList.add(name); + fieldTypeList.add(type); + typeInfoList.add(TypeInfoUtils.getTypeInfoFromTypeString(type)); + } + + String fieldNames = Joiner.on(',').join(fieldNameList); + String fieldTypes = Joiner.on(',').join(fieldTypeList); + + Properties schemaProperties = new Properties(); + schemaProperties.setProperty(serdeConstants.LIST_COLUMNS, fieldNames); + schemaProperties.setProperty(serdeConstants.LIST_COLUMN_TYPES, fieldTypes); + SerDeUtils.initializeSerDe(serDe, conf, schemaProperties, null); + return (StructObjectInspector) TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo( + TypeInfoFactory.getStructTypeInfo(fieldNameList, typeInfoList)); + } + + private void serializeAndDeserialize(ArrowColumnarBatchSerDe serDe, Object[][] rows, + StructObjectInspector rowOI) { + ArrowWrapperWritable serialized = null; + for (Object[] row : rows) { + serialized = serDe.serialize(row, rowOI); + } + // Pass null to complete a batch + if (serialized == null) { + serialized = serDe.serialize(null, rowOI); + } + String s = serialized.getVectorSchemaRoot().contentToTSVString(); + final Object[][] deserializedRows = (Object[][]) serDe.deserialize(serialized); + + for (int rowIndex = 0; rowIndex < Math.min(deserializedRows.length, rows.length); rowIndex++) { + final Object[] row = rows[rowIndex]; + final Object[] deserializedRow = deserializedRows[rowIndex]; + assertEquals(row.length, deserializedRow.length); + + final List fields = rowOI.getAllStructFieldRefs(); + for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) { + final StructField field = fields.get(fieldIndex); + final ObjectInspector fieldObjInspector = field.getFieldObjectInspector(); + switch (fieldObjInspector.getCategory()) { + case PRIMITIVE: + final PrimitiveObjectInspector primitiveObjInspector = + (PrimitiveObjectInspector) fieldObjInspector; + switch (primitiveObjInspector.getPrimitiveCategory()) { + case STRING: + case VARCHAR: + case CHAR: + assertEquals(Objects.toString(row[fieldIndex]), + Objects.toString(deserializedRow[fieldIndex])); + break; + default: + assertEquals(row[fieldIndex], deserializedRow[fieldIndex]); + break; + } + break; + case STRUCT: + final Object[] rowStruct = (Object[]) row[fieldIndex]; + final List deserializedRowStruct = (List) deserializedRow[fieldIndex]; + if (rowStruct == null) { + assertNull(deserializedRowStruct); + } else { + assertArrayEquals(rowStruct, deserializedRowStruct.toArray()); + } + break; + case LIST: + case UNION: + assertEquals(row[fieldIndex], deserializedRow[fieldIndex]); + break; + case MAP: + final Map rowMap = (Map) row[fieldIndex]; + final Map deserializedRowMap = (Map) deserializedRow[fieldIndex]; + if (rowMap == null) { + assertNull(deserializedRowMap); + } else { + final Set rowMapKeySet = rowMap.keySet(); + final Set deserializedRowMapKeySet = deserializedRowMap.keySet(); + assertEquals(rowMapKeySet, 
+              for (Object key : rowMapKeySet) {
+                assertEquals(rowMap.get(key), deserializedRowMap.get(key));
+              }
+            }
+            break;
+        }
+      }
+    }
+  }
+
+  @Test
+  public void testComprehensive() throws SerDeException {
+    String[][] schema = {
+        {"datatypes.c1", "int"},
+        {"datatypes.c2", "boolean"},
+        {"datatypes.c3", "double"},
+        {"datatypes.c4", "string"},
+        {"datatypes.c5", "array<int>"},
+        {"datatypes.c6", "map<int,string>"},
+        {"datatypes.c7", "map<string,string>"},
+        {"datatypes.c8", "struct<r:string,s:int,t:double>"},
+        {"datatypes.c9", "tinyint"},
+        {"datatypes.c10", "smallint"},
+        {"datatypes.c11", "float"},
+        {"datatypes.c12", "bigint"},
+        {"datatypes.c13", "array<array<string>>"},
+        {"datatypes.c14", "map<int,map<int,int>>"},
+        {"datatypes.c15", "struct<r:int,s:struct<a:int,b:string>>"},
+        {"datatypes.c16", "array<struct<m:map<string,string>,n:int>>"},
+        {"datatypes.c17", "timestamp"},
+        {"datatypes.c18", "decimal(16,7)"},
+        {"datatypes.c19", "binary"},
+        {"datatypes.c20", "date"},
+        {"datatypes.c21", "varchar(20)"},
+        {"datatypes.c22", "char(15)"},
+        {"datatypes.c23", "binary"},
+    };
+
+    Object[][] comprehensiveRows = {
+        {
+            intW(0), // c1:int
+            new BooleanWritable(false), // c2:boolean
+            doubleW(0), // c3:double
+            text("Hello"), // c4:string
+            newArrayList(intW(0), intW(1), intW(2)), // c5:array<int>
+            Maps.toMap(
+                newArrayList(intW(0), intW(1), intW(2)),
+                input -> text("Number " + input)), // c6:map<int,string>
+            Maps.toMap(
+                newArrayList(text("apple"), text("banana"), text("carrot")),
+                input -> text(input.toString().toUpperCase())), // c7:map<string,string>
+            new Object[] {text("0"), intW(1), doubleW(2)}, // c8:struct<r:string,s:int,t:double>
+            byteW(0), // c9:tinyint
+            shortW(0), // c10:smallint
+            floatW(0), // c11:float
+            longW(0), // c12:bigint
+            newArrayList(
+                newArrayList(text("a"), text("b"), text("c")),
+                newArrayList(text("A"), text("B"), text("C"))), // c13:array<array<string>>
+            Maps.toMap(
+                newArrayList(intW(0), intW(1), intW(2)),
+                x -> Maps.toMap(
+                    newArrayList(x, intW(x.get() * 2)),
+                    y -> y)), // c14:map<int,map<int,int>>
+            new Object[] {
+                intW(0),
+                newArrayList(
+                    intW(1),
+                    text("Hello"))}, // c15:struct<r:int,s:struct<a:int,b:string>>
+            Collections.singletonList(
+                newArrayList(
+                    Maps.toMap(
+                        newArrayList(text("hello")),
+                        input -> text(input.toString().toUpperCase())),
+                    intW(0))), // c16:array<struct<m:map<string,string>,n:int>>
+            new TimestampWritable(TIMESTAMP), // c17:timestamp
+            decimalW(HiveDecimal.create(0, 0)), // c18:decimal(16,7)
+            new BytesWritable("Hello".getBytes()), // c19:binary
+            new DateWritable(123), // c20:date
+            varcharW("x", 20), // c21:varchar(20)
+            charW("y", 15), // c22:char(15)
+            new BytesWritable("world!".getBytes()), // c23:binary
+        }, {
+            null, null, null, null, null, null, null, null, null, null, // c1-c10
+            null, null, null, null, null, null, null, null, null, null, // c11-c20
+            null, null, null, // c21-c23
+        }
+    };
+
+    initAndSerializeAndDeserialize(schema, comprehensiveRows);
+  }
+
+  private <E> List<E> newArrayList(E ... elements) {
+    return Lists.newArrayList(elements);
+  }
+
+  @Test
+  public void testPrimitiveInteger() throws SerDeException {
+    String[][] schema = {
+        {"tinyint1", "tinyint"},
+        {"smallint1", "smallint"},
+        {"int1", "int"},
+        {"bigint1", "bigint"}
+    };
+
+    initAndSerializeAndDeserialize(schema, INTEGER_ROWS);
+  }
+
+  @Test
+  public void testPrimitiveBigInt10000() throws SerDeException {
+    String[][] schema = {
+        {"bigint1", "bigint"}
+    };
+
+    final int batchSize = 1000;
+    final Object[][] integerRows = new Object[batchSize][];
+    final ArrowColumnarBatchSerDe serDe = new ArrowColumnarBatchSerDe();
+    StructObjectInspector rowOI = initSerDe(serDe, schema);
+
+    for (int j = 0; j < 10; j++) {
+      for (int i = 0; i < batchSize; i++) {
+        integerRows[i] = new Object[] {longW(i + j * batchSize)};
+      }
+
+      serializeAndDeserialize(serDe, integerRows, rowOI);
+    }
+  }
+
+  @Test
+  public void testPrimitiveBigIntRandom() {
+    try {
+      String[][] schema = {
+          {"bigint1", "bigint"}
+      };
+
+      final ArrowColumnarBatchSerDe serDe = new ArrowColumnarBatchSerDe();
+      StructObjectInspector rowOI = initSerDe(serDe, schema);
+
+      final Random random = new Random();
+      for (int j = 0; j < 1000; j++) {
+        final int batchSize = random.nextInt(1000);
+        final Object[][] integerRows = new Object[batchSize][];
+        for (int i = 0; i < batchSize; i++) {
+          integerRows[i] = new Object[] {longW(random.nextLong())};
+        }
+
+        serializeAndDeserialize(serDe, integerRows, rowOI);
+      }
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  @Test
+  public void testPrimitiveFloat() throws SerDeException {
+    String[][] schema = {
+        {"float1", "float"},
+        {"double1", "double"},
+    };
+
+    initAndSerializeAndDeserialize(schema, FLOAT_ROWS);
+  }
+
+  @Test(expected = AssertionError.class)
+  public void testPrimitiveFloatNaN() throws SerDeException {
+    String[][] schema = {
+        {"float1", "float"},
+    };
+
+    Object[][] rows = {{new FloatWritable(Float.NaN)}};
+
+    initAndSerializeAndDeserialize(schema, rows);
+  }
+
+  @Test(expected = AssertionError.class)
+  public void testPrimitiveDoubleNaN() throws SerDeException {
+    String[][] schema = {
+        {"double1", "double"},
+    };
+
+    Object[][] rows = {{new DoubleWritable(Double.NaN)}};
+
+    initAndSerializeAndDeserialize(schema, rows);
+  }
+
+  @Test
+  public void testPrimitiveString() throws SerDeException {
+    String[][] schema = {
+        {"string1", "string"},
+        {"char1", "char(10)"},
+        {"varchar1", "varchar(10)"},
+    };
+
+    initAndSerializeAndDeserialize(schema, STRING_ROWS);
+  }
+
+  @Test
+  public void testPrimitiveDTI() throws SerDeException {
+    String[][] schema = {
+        {"date1", "date"},
+        {"timestamp1", "timestamp"},
+        {"interval_year_month1", "interval_year_month"},
+        {"interval_day_time1", "interval_day_time"},
+    };
+
+    initAndSerializeAndDeserialize(schema, DTI_ROWS);
+  }
+
+  @Test
+  public void testPrimitiveDecimal() throws SerDeException {
+    String[][] schema = {
+        {"decimal1", "decimal(38,10)"},
+    };
+
+    initAndSerializeAndDeserialize(schema, DECIMAL_ROWS);
+  }
+
+  @Test
+  public void testPrimitiveBoolean() throws SerDeException {
+    String[][] schema = {
+        {"boolean1", "boolean"},
+    };
+
+    initAndSerializeAndDeserialize(schema, BOOLEAN_ROWS);
+  }
+
+  @Test
+  public void testPrimitiveBinary() throws SerDeException {
+    String[][] schema = {
+        {"binary1", "binary"},
+    };
+
+    initAndSerializeAndDeserialize(schema, BINARY_ROWS);
+  }
+
+  private List[][] toList(Object[][] rows) {
+    List[][] array = new List[rows.length][];
+    for (int rowIndex = 0; rowIndex < rows.length; rowIndex++) {
+      Object[] row = rows[rowIndex];
+      array[rowIndex] = new List[row.length];
+      for (int fieldIndex = 0; fieldIndex < row.length; fieldIndex++) {
+        array[rowIndex][fieldIndex] = newArrayList(row[fieldIndex]);
+      }
+    }
+    return array;
+  }
+
+  @Test
+  public void testListInteger() throws SerDeException {
+    String[][] schema = {
+        {"tinyint_list", "array<tinyint>"},
+        {"smallint_list", "array<smallint>"},
+        {"int_list", "array<int>"},
+        {"bigint_list", "array<bigint>"},
+    };
+
+    initAndSerializeAndDeserialize(schema, toList(INTEGER_ROWS));
+  }
+
+  @Test
+  public void testListFloat() throws SerDeException {
+    String[][] schema = {
+        {"float_list", "array<float>"},
+        {"double_list", "array<double>"},
+    };
+
+    initAndSerializeAndDeserialize(schema, toList(FLOAT_ROWS));
+  }
+
+  @Test
+  public void testListString() throws SerDeException {
+    String[][] schema = {
+        {"string_list", "array<string>"},
+        {"char_list", "array<char(10)>"},
+        {"varchar_list", "array<varchar(10)>"},
+    };
+
+    initAndSerializeAndDeserialize(schema, toList(STRING_ROWS));
+  }
+
+  @Test
+  public void testListDTI() throws SerDeException {
+    String[][] schema = {
+        {"date_list", "array<date>"},
+        {"timestamp_list", "array<timestamp>"},
+        {"interval_year_month_list", "array<interval_year_month>"},
+        {"interval_day_time_list", "array<interval_day_time>"},
+    };
+
+    initAndSerializeAndDeserialize(schema, toList(DTI_ROWS));
+  }
+
+  @Test
+  public void testListBoolean() throws SerDeException {
+    String[][] schema = {
+        {"boolean_list", "array<boolean>"},
+    };
+
+    initAndSerializeAndDeserialize(schema, toList(BOOLEAN_ROWS));
+  }
+
+  @Test
+  public void testListBinary() throws SerDeException {
+    String[][] schema = {
+        {"binary_list", "array<binary>"},
+    };
+
+    initAndSerializeAndDeserialize(schema, toList(BINARY_ROWS));
+  }
+
+  private Object[][][] toStruct(Object[][] rows) {
+    Object[][][] struct = new Object[rows.length][][];
+    for (int rowIndex = 0; rowIndex < rows.length; rowIndex++) {
+      Object[] row = rows[rowIndex];
+      struct[rowIndex] = new Object[][] {row};
+    }
+    return struct;
+  }
+
+  @Test
+  public void testStructInteger() throws SerDeException {
+    String[][] schema = {
+        {"int_struct", "struct<tinyint1:tinyint,smallint1:smallint,int1:int,bigint1:bigint>"},
+    };
+
+    initAndSerializeAndDeserialize(schema, toStruct(INTEGER_ROWS));
+  }
+
+  @Test
+  public void testStructFloat() throws SerDeException {
+    String[][] schema = {
+        {"float_struct", "struct<float1:float,double1:double>"},
+    };
+
+    initAndSerializeAndDeserialize(schema, toStruct(FLOAT_ROWS));
+  }
+
+  @Test
+  public void testStructString() throws SerDeException {
+    String[][] schema = {
+        {"string_struct", "struct<string1:string,char1:char(10),varchar1:varchar(10)>"},
+    };
+
+    initAndSerializeAndDeserialize(schema, toStruct(STRING_ROWS));
+  }
+
+  @Test
+  public void testStructDTI() throws SerDeException {
+    String[][] schema = {
+        {"date_struct", "struct<date1:date,timestamp1:timestamp,interval_year_month1:interval_year_month,interval_day_time1:interval_day_time>"},
+    };
+
+    initAndSerializeAndDeserialize(schema, toStruct(DTI_ROWS));
+  }
+
+  @Test
+  public void testStructDecimal() throws SerDeException {
+    String[][] schema = {
+        {"decimal_struct", "struct<decimal1:decimal(38,10)>"},
+    };
+
+    initAndSerializeAndDeserialize(schema, toStruct(DECIMAL_ROWS));
+  }
+
+  @Test
+  public void testStructBoolean() throws SerDeException {
+    String[][] schema = {
+        {"boolean_struct", "struct<boolean1:boolean>"},
+    };
+
+    initAndSerializeAndDeserialize(schema, toStruct(BOOLEAN_ROWS));
+  }
+
+  @Test
+  public void testStructBinary() throws SerDeException {
+    String[][] schema = {
+        {"binary_struct", "struct<binary1:binary>"},
+    };
+
+    initAndSerializeAndDeserialize(schema, toStruct(BINARY_ROWS));
+  }
+
+  private Object[][] toMap(Object[][] rows) {
+    Map[][] array = new Map[rows.length][];
+    for (int rowIndex = 0; rowIndex < rows.length; rowIndex++) {
+      Object[] row = rows[rowIndex];
+      array[rowIndex] = new Map[row.length];
+      for (int fieldIndex = 0; fieldIndex < row.length; fieldIndex++) {
+        Map map = Maps.newHashMap();
+        map.put(new Text(String.valueOf(row[fieldIndex])), row[fieldIndex]);
+        array[rowIndex][fieldIndex] = map;
+      }
+    }
+    return array;
+  }
+
+  @Test
+  public void testMapInteger() throws SerDeException {
+    String[][] schema = {
+        {"tinyint_map", "map<string,tinyint>"},
+        {"smallint_map", "map<string,smallint>"},
+        {"int_map", "map<string,int>"},
+        {"bigint_map", "map<string,bigint>"},
+    };
+
+    initAndSerializeAndDeserialize(schema, toMap(INTEGER_ROWS));
+  }
+
+  @Test
+  public void testMapFloat() throws SerDeException {
+    String[][] schema = {
+        {"float_map", "map<string,float>"},
+        {"double_map", "map<string,double>"},
+    };
+
+    initAndSerializeAndDeserialize(schema, toMap(FLOAT_ROWS));
+  }
+
+  @Test
+  public void testMapString() throws SerDeException {
+    String[][] schema = {
+        {"string_map", "map<string,string>"},
+        {"char_map", "map<string,char(10)>"},
+        {"varchar_map", "map<string,varchar(10)>"},
+    };
+
+    initAndSerializeAndDeserialize(schema, toMap(STRING_ROWS));
+  }
+
+  @Test
+  public void testMapDTI() throws SerDeException {
+    String[][] schema = {
+        {"date_map", "map<string,date>"},
+        {"timestamp_map", "map<string,timestamp>"},
+        {"interval_year_month_map", "map<string,interval_year_month>"},
+        {"interval_day_time_map", "map<string,interval_day_time>"},
+    };
+
+    initAndSerializeAndDeserialize(schema, toMap(DTI_ROWS));
+  }
+
+  @Test
+  public void testMapBoolean() throws SerDeException {
+    String[][] schema = {
+        {"boolean_map", "map<string,boolean>"},
+    };
+
+    initAndSerializeAndDeserialize(schema, toMap(BOOLEAN_ROWS));
+  }
+
+  @Test
+  public void testMapBinary() throws SerDeException {
+    String[][] schema = {
+        {"binary_map", "map<string,binary>"},
+    };
+
+    initAndSerializeAndDeserialize(schema, toMap(BINARY_ROWS));
+  }
+
+  public void testMapDecimal() throws SerDeException {
+    String[][] schema = {
+        {"decimal_map", "map<string,decimal(38,10)>"},
+    };
+
+    initAndSerializeAndDeserialize(schema, toMap(DECIMAL_ROWS));
+  }
+
+  public void testListDecimal() throws SerDeException {
+    String[][] schema = {
+        {"decimal_list", "array<decimal(38,10)>"},
+    };
+
+    initAndSerializeAndDeserialize(schema, toList(DECIMAL_ROWS));
+  }
+
+}
diff --git serde/pom.xml serde/pom.xml
index e005585e4b..eca34af32d 100644
--- serde/pom.xml
+++ serde/pom.xml
@@ -71,6 +71,11 @@
       <version>${arrow.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.arrow</groupId>
+      <artifactId>arrow-vector</artifactId>
+      <version>${arrow.version}</version>
+    </dependency>
    <dependency>
      <groupId>com.carrotsearch</groupId>
      <artifactId>hppc</artifactId>
      <version>${hppc.version}</version>
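
The tests above pin down the round-trip contract of the new SerDe: serialize() buffers rows and only hands back a completed ArrowWrapperWritable once a batch fills or a null row is passed to flush it, and deserialize() returns the batch as Object[][] rows of writables. The following is a minimal standalone sketch of that same flow outside the JUnit harness; it is illustrative only and not part of this patch. The two-column schema, the values, and the wrapper class with its main method are invented for the example, and the imports for ArrowColumnarBatchSerDe and ArrowWrapperWritable must point at whichever package this patch defines them in.

    // Hypothetical example class -- not part of this patch.
    import java.util.Arrays;
    import java.util.List;
    import java.util.Properties;

    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.hive.serde.serdeConstants;
    import org.apache.hadoop.hive.serde2.SerDeUtils;
    import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
    import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
    import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
    import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
    import org.apache.hadoop.io.IntWritable;
    import org.apache.hadoop.io.Text;
    // plus imports for ArrowColumnarBatchSerDe and ArrowWrapperWritable from the package added by this patch

    public class ArrowSerDeRoundTripSketch {
      public static void main(String[] args) throws Exception {
        // Declare a two-column row "id int, name string", the same way initSerDe() does it above.
        Properties props = new Properties();
        props.setProperty(serdeConstants.LIST_COLUMNS, "id,name");
        props.setProperty(serdeConstants.LIST_COLUMN_TYPES, "int,string");

        ArrowColumnarBatchSerDe serDe = new ArrowColumnarBatchSerDe();
        SerDeUtils.initializeSerDe(serDe, new Configuration(), props, null);

        // Build the standard writable row object inspector for the same schema.
        List<String> names = Arrays.asList("id", "name");
        List<TypeInfo> types = TypeInfoUtils.getTypeInfosFromTypeString("int,string");
        StructObjectInspector rowOI = (StructObjectInspector)
            TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo(
                TypeInfoFactory.getStructTypeInfo(names, types));

        // serialize() buffers rows and returns null until a batch is complete.
        ArrowWrapperWritable batch = null;
        Object[][] rows = {
            {new IntWritable(1), new Text("a")},
            {new IntWritable(2), new Text("b")},
        };
        for (Object[] row : rows) {
          batch = serDe.serialize(row, rowOI);
        }
        if (batch == null) {
          // As in serializeAndDeserialize() above: a null row flushes the pending batch.
          batch = serDe.serialize(null, rowOI);
        }

        // deserialize() hands the Arrow batch back as Object[][] of writables.
        Object[][] readBack = (Object[][]) serDe.deserialize(batch);
        System.out.println("rows read back: " + readBack.length);
      }
    }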