diff --git itests/hive-unit/src/test/java/org/apache/hive/jdbc/BaseJdbcWithMiniLlap.java itests/hive-unit/src/test/java/org/apache/hive/jdbc/BaseJdbcWithMiniLlap.java index 3c0532c9d0..4c46db9e3b 100644 --- itests/hive-unit/src/test/java/org/apache/hive/jdbc/BaseJdbcWithMiniLlap.java +++ itests/hive-unit/src/test/java/org/apache/hive/jdbc/BaseJdbcWithMiniLlap.java @@ -223,6 +223,116 @@ public void testLlapInputFormatEndToEnd() throws Exception { assertEquals(0, rowCount); } + @Test(timeout = 300000) + public void testMultipleBatchesOfComplexTypes() throws Exception { + final String tableName = "testMultipleBatchesOfComplexTypes"; + try (Statement stmt = hs2Conn.createStatement()) { + String createQuery = + "create table " + tableName + "(c1 array>, " + + "c2 int, " + + "c3 array>, " + + "c4 array>>) STORED AS ORC"; + + // create table + stmt.execute("DROP TABLE IF EXISTS " + tableName); + stmt.execute(createQuery); + // load data + stmt.execute("INSERT INTO " + tableName + " VALUES " + // value 1 + + "(ARRAY(NAMED_STRUCT('f1','a1', 'f2','a2'), NAMED_STRUCT('f1','a3', 'f2','a4')), " + + "1, ARRAY(ARRAY(1)), ARRAY(NAMED_STRUCT('f1',ARRAY('aa1')))), " + // value 2 + + "(ARRAY(NAMED_STRUCT('f1','b1', 'f2','b2'), NAMED_STRUCT('f1','b3', 'f2','b4')), 2, " + + "ARRAY(ARRAY(2,2), ARRAY(2,2)), " + + "ARRAY(NAMED_STRUCT('f1',ARRAY('aa2','aa2')), NAMED_STRUCT('f1',ARRAY('aa2','aa2')))), " + // value 3 + + "(ARRAY(NAMED_STRUCT('f1','c1', 'f2','c2'), NAMED_STRUCT('f1','c3', 'f2','c4'), " + + "NAMED_STRUCT('f1','c5', 'f2','c6')), 3, " + "ARRAY(ARRAY(3,3,3), ARRAY(3,3,3), ARRAY(3,3,3)), " + + "ARRAY(NAMED_STRUCT('f1',ARRAY('aa3','aa3','aa3')), " + + "NAMED_STRUCT('f1',ARRAY('aa3','aa3', 'aa3')), NAMED_STRUCT('f1',ARRAY('aa3','aa3', 'aa3')))), " + // value 4 + + "(ARRAY(NAMED_STRUCT('f1','d1', 'f2','d2'), NAMED_STRUCT('f1','d3', 'f2','d4')," + + " NAMED_STRUCT('f1','d5', 'f2','d6'), NAMED_STRUCT('f1','d7', 'f2','d8')), 4, " + + "ARRAY(ARRAY(4,4,4,4),ARRAY(4,4,4,4),ARRAY(4,4,4,4),ARRAY(4,4,4,4)), " + + "ARRAY(NAMED_STRUCT('f1',ARRAY('aa4','aa4','aa4', 'aa4')), " + + "NAMED_STRUCT('f1',ARRAY('aa4','aa4','aa4', 'aa4')), NAMED_STRUCT('f1',ARRAY('aa4','aa4','aa4', 'aa4'))," + + " NAMED_STRUCT('f1',ARRAY('aa4','aa4','aa4', 'aa4'))))"); + + // generate 4096 rows from above records + for (int i = 0; i < 10; i++) { + stmt.execute(String.format("insert into %s select * from %s", tableName, tableName)); + } + // validate test table + ResultSet res = stmt.executeQuery("SELECT count(*) FROM " + tableName); + assertTrue(res.next()); + assertEquals(4096, res.getInt(1)); + res.close(); + } + + RowCollector rowCollector = new RowCollector(); + String query = "select * from " + tableName; + int rowCount = processQuery(query, 1, rowCollector); + assertEquals(4096, rowCount); + + /* + * + * validate different rows + * [[[a1, a2], [a3, a4]], 1, [[1]], [[[aa1]]]] + * [[[b1, b2], [b3, b4]], 2, [[2, 2], [2, 2]], [[[aa2, aa2]], [[aa2, aa2]]]] + * [[[c1, c2], [c3, c4], [c5, c6]], 3, [[3, 3, 3], [3, 3, 3], [3, 3, 3]], [[[aa3, aa3, aa3]], [[aa3, aa3, aa3]], [[aa3, aa3, aa3]]]] + * [[[d1, d2], [d3, d4], [d5, d6], [d7, d8]], 4, [[4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4]], [[[aa4, aa4, aa4, aa4]], [[aa4, aa4, aa4, aa4]], [[aa4, aa4, aa4, aa4]], [[aa4, aa4, aa4, aa4]]]] + * + */ + rowCollector.rows.clear(); + query = "select * from " + tableName + " where c2=1 limit 1"; + rowCount = processQuery(query, 1, rowCollector); + assertEquals(1, rowCount); + final String[] expected1 = + { "[[a1, a2], [a3, a4]]", + "1", + "[[1]]", + "[[[aa1]]]" + }; + assertArrayEquals(expected1, rowCollector.rows.get(0)); + + rowCollector.rows.clear(); + query = "select * from " + tableName + " where c2=2 limit 1"; + rowCount = processQuery(query, 1, rowCollector); + assertEquals(1, rowCount); + final String[] expected2 = + { "[[b1, b2], [b3, b4]]", + "2", + "[[2, 2], [2, 2]]", + "[[[aa2, aa2]], [[aa2, aa2]]]" + }; + assertArrayEquals(expected2, rowCollector.rows.get(0)); + + rowCollector.rows.clear(); + query = "select * from " + tableName + " where c2=3 limit 1"; + rowCount = processQuery(query, 1, rowCollector); + assertEquals(1, rowCount); + final String[] expected3 = + { "[[c1, c2], [c3, c4], [c5, c6]]", + "3", + "[[3, 3, 3], [3, 3, 3], [3, 3, 3]]", + "[[[aa3, aa3, aa3]], [[aa3, aa3, aa3]], [[aa3, aa3, aa3]]]" + }; + assertArrayEquals(expected3, rowCollector.rows.get(0)); + + rowCollector.rows.clear(); + query = "select * from " + tableName + " where c2=4 limit 1"; + rowCount = processQuery(query, 1, rowCollector); + assertEquals(1, rowCount); + final String[] expected4 = + { "[[d1, d2], [d3, d4], [d5, d6], [d7, d8]]", + "4", + "[[4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4]]", + "[[[aa4, aa4, aa4, aa4]], [[aa4, aa4, aa4, aa4]], [[aa4, aa4, aa4, aa4]], [[aa4, aa4, aa4, aa4]]]" + }; + assertArrayEquals(expected4, rowCollector.rows.get(0)); + + } + @Test(timeout = 300000) public void testLlapInputFormatEndToEndWithMultipleBatches() throws Exception { String tableName = "over10k_table"; diff --git ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java index edc4b39922..ac4d237de6 100644 --- ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java +++ ql/src/java/org/apache/hadoop/hive/ql/io/arrow/Deserializer.java @@ -104,7 +104,7 @@ public Object deserialize(Writable writable) { final VectorSchemaRoot vectorSchemaRoot = arrowWrapperWritable.getVectorSchemaRoot(); final List fieldVectors = vectorSchemaRoot.getFieldVectors(); final int fieldCount = fieldVectors.size(); - final int rowCount = vectorSchemaRoot.getRowCount(); + final int rowCount = vectorSchemaRoot.getFieldVectors().get(0).getValueCount(); vectorizedRowBatch.ensureSize(rowCount); if (rows == null || rows.length < rowCount ) { @@ -129,6 +129,10 @@ public Object deserialize(Writable writable) { } private void read(FieldVector arrowVector, ColumnVector hiveVector, TypeInfo typeInfo) { + // make sure that hiveVector is as big as arrowVector + final int size = arrowVector.getValueCount(); + hiveVector.ensureSize(size, false); + switch (typeInfo.getCategory()) { case PRIMITIVE: readPrimitive(arrowVector, hiveVector); @@ -154,7 +158,6 @@ private void readPrimitive(FieldVector arrowVector, ColumnVector hiveVector) { final Types.MinorType minorType = arrowVector.getMinorType(); final int size = arrowVector.getValueCount(); - hiveVector.ensureSize(size, false); switch (minorType) { case BIT: