diff --git a/jdbc/src/java/org/apache/hive/jdbc/HiveBaseResultSet.java b/jdbc/src/java/org/apache/hive/jdbc/HiveBaseResultSet.java index a69ea9524a..45de932740 100644 --- a/jdbc/src/java/org/apache/hive/jdbc/HiveBaseResultSet.java +++ b/jdbc/src/java/org/apache/hive/jdbc/HiveBaseResultSet.java @@ -42,6 +42,7 @@ import java.sql.Time; import java.sql.Timestamp; import java.util.Calendar; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -65,6 +66,7 @@ protected List normalizedColumnNames; protected List columnTypes; protected List columnAttributes; + private final Map columnNameIndexCache = new HashMap<>(); private TableSchema schema; @@ -95,19 +97,26 @@ public void deleteRow() throws SQLException { @Override public int findColumn(final String columnName) throws SQLException { - int columnIndex = 0; - if (columnName != null) { - final String lcColumnName = columnName.toLowerCase(); + if (columnName == null) { + throw new SQLException("null column name not supported"); + } + final String lcColumnName = columnName.toLowerCase(); + final Integer result = this.columnNameIndexCache.computeIfAbsent(lcColumnName, cn -> { + int columnIndex = 0; for (final String normalizedColumnName : normalizedColumnNames) { ++columnIndex; final int idx = normalizedColumnName.lastIndexOf('.'); final String name = (idx == -1) ? normalizedColumnName : normalizedColumnName.substring(1 + idx); - if (name.equals(lcColumnName) || normalizedColumnName.equals(lcColumnName)) { + if (name.equals(cn) || normalizedColumnName.equals(cn)) { return columnIndex; } } + return null; + }); + if (result == null) { + throw new SQLException("Could not find " + columnName + " in " + normalizedColumnNames); } - throw new SQLException("Could not find " + columnName + " in " + normalizedColumnNames); + return result.intValue(); } @Override diff --git a/jdbc/src/test/org/apache/hive/jdbc/TestHiveBaseResultSet.java b/jdbc/src/test/org/apache/hive/jdbc/TestHiveBaseResultSet.java index 9d423179ec..bca26f336f 100644 --- a/jdbc/src/test/org/apache/hive/jdbc/TestHiveBaseResultSet.java +++ b/jdbc/src/test/org/apache/hive/jdbc/TestHiveBaseResultSet.java @@ -20,9 +20,11 @@ import static org.mockito.Mockito.when; +import java.lang.reflect.Field; import java.nio.charset.StandardCharsets; import java.sql.SQLException; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import org.apache.hadoop.hive.metastore.api.FieldSchema; @@ -30,6 +32,7 @@ import org.junit.Assert; import org.junit.Test; import org.mockito.Mockito; +import org.mockito.internal.util.reflection.FieldSetter; /** * Test suite for {@link HiveBaseResultSet} class. @@ -237,4 +240,94 @@ public void testGetBooleanStringTrue() throws SQLException { Assert.assertFalse(resultSet.wasNull()); } + @Test + public void testFindColumnUnqualified() throws Exception { + FieldSchema fieldSchema1 = new FieldSchema(); + fieldSchema1.setType("int"); + + FieldSchema fieldSchema2 = new FieldSchema(); + fieldSchema2.setType("int"); + + FieldSchema fieldSchema3 = new FieldSchema(); + fieldSchema3.setType("int"); + + List fieldSchemas = Arrays.asList(fieldSchema1, fieldSchema2, fieldSchema3); + TableSchema schema = new TableSchema(fieldSchemas); + + HiveBaseResultSet resultSet = Mockito.mock(HiveBaseResultSet.class); + resultSet.row = new Object[] { new Integer(1), new Integer(2), new Integer(3) }; + resultSet.normalizedColumnNames = Arrays.asList("one", "two", "three"); + + Field executorField = HiveBaseResultSet.class.getDeclaredField("columnNameIndexCache"); + FieldSetter.setField(resultSet, executorField, new HashMap<>()); + + when(resultSet.getSchema()).thenReturn(schema); + when(resultSet.findColumn("one")).thenCallRealMethod(); + when(resultSet.findColumn("Two")).thenCallRealMethod(); + when(resultSet.findColumn("THREE")).thenCallRealMethod(); + + Assert.assertEquals(1, resultSet.findColumn("one")); + Assert.assertEquals(2, resultSet.findColumn("Two")); + Assert.assertEquals(3, resultSet.findColumn("THREE")); + } + + @Test + public void testFindColumnQualified() throws Exception { + FieldSchema fieldSchema1 = new FieldSchema(); + fieldSchema1.setType("int"); + + FieldSchema fieldSchema2 = new FieldSchema(); + fieldSchema2.setType("int"); + + FieldSchema fieldSchema3 = new FieldSchema(); + fieldSchema3.setType("int"); + + List fieldSchemas = Arrays.asList(fieldSchema1, fieldSchema2, fieldSchema3); + TableSchema schema = new TableSchema(fieldSchemas); + + HiveBaseResultSet resultSet = Mockito.mock(HiveBaseResultSet.class); + resultSet.row = new Object[] { new Integer(1), new Integer(2), new Integer(3) }; + resultSet.normalizedColumnNames = Arrays.asList("table.one", "table.two", "table.three"); + + Field executorField = HiveBaseResultSet.class.getDeclaredField("columnNameIndexCache"); + FieldSetter.setField(resultSet, executorField, new HashMap<>()); + + when(resultSet.getSchema()).thenReturn(schema); + when(resultSet.findColumn("one")).thenCallRealMethod(); + when(resultSet.findColumn("Two")).thenCallRealMethod(); + when(resultSet.findColumn("THREE")).thenCallRealMethod(); + + Assert.assertEquals(1, resultSet.findColumn("one")); + Assert.assertEquals(2, resultSet.findColumn("Two")); + Assert.assertEquals(3, resultSet.findColumn("THREE")); + } + + @Test(expected = SQLException.class) + public void testFindColumnNull() throws Exception { + HiveBaseResultSet resultSet = Mockito.mock(HiveBaseResultSet.class); + when(resultSet.findColumn(null)).thenCallRealMethod(); + Assert.assertEquals(0, resultSet.findColumn(null)); + } + + @Test(expected = SQLException.class) + public void testFindColumnUnknownColumn() throws Exception { + FieldSchema fieldSchema1 = new FieldSchema(); + fieldSchema1.setType("int"); + + List fieldSchemas = Arrays.asList(fieldSchema1); + TableSchema schema = new TableSchema(fieldSchemas); + + HiveBaseResultSet resultSet = Mockito.mock(HiveBaseResultSet.class); + resultSet.row = new Object[] { new Integer(1) }; + resultSet.normalizedColumnNames = Arrays.asList("table.one"); + + Field executorField = HiveBaseResultSet.class.getDeclaredField("columnNameIndexCache"); + FieldSetter.setField(resultSet, executorField, new HashMap<>()); + + when(resultSet.getSchema()).thenReturn(schema); + when(resultSet.findColumn("zero")).thenCallRealMethod(); + + Assert.assertEquals(1, resultSet.findColumn("zero")); + } + }