diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java index 4968d16876..8ab29ee5f0 100644 --- ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java +++ ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java @@ -30,7 +30,6 @@ import org.apache.calcite.rel.RelNode; import org.apache.commons.lang.StringUtils; -import org.apache.commons.lang3.math.NumberUtils; import org.apache.hadoop.hive.common.type.Date; import org.apache.hadoop.hive.common.type.HiveChar; import org.apache.hadoop.hive.common.type.HiveDecimal; @@ -105,6 +104,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Lists; @@ -1258,37 +1258,36 @@ protected ExprNodeDesc getXpathOrFuncExprNodeDesc(ASTNode expr, return desc; } - private ExprNodeDesc interpretNodeAs(PrimitiveTypeInfo colTypeInfo, ExprNodeDesc constChild) { + @VisibleForTesting + protected ExprNodeDesc interpretNodeAs(PrimitiveTypeInfo colTypeInfo, ExprNodeDesc constChild) { if (constChild instanceof ExprNodeConstantDesc) { // Try to narrow type of constant Object constVal = ((ExprNodeConstantDesc) constChild).getValue(); - String constType = constChild.getTypeString().toLowerCase(); if (constVal instanceof Number || constVal instanceof String) { try { PrimitiveTypeEntry primitiveTypeEntry = colTypeInfo.getPrimitiveTypeEntry(); if (PrimitiveObjectInspectorUtils.intTypeEntry.equals(primitiveTypeEntry)) { - return new ExprNodeConstantDesc(new Integer(constVal.toString())); + return new ExprNodeConstantDesc(new BigDecimal(constVal.toString()).intValueExact()); } else if (PrimitiveObjectInspectorUtils.longTypeEntry.equals(primitiveTypeEntry)) { - return new ExprNodeConstantDesc(new Long(constVal.toString())); + return new ExprNodeConstantDesc(new BigDecimal(constVal.toString()).longValueExact()); } else if (PrimitiveObjectInspectorUtils.doubleTypeEntry.equals(primitiveTypeEntry)) { - return new ExprNodeConstantDesc(new Double(constVal.toString())); + return new ExprNodeConstantDesc(Double.valueOf(constVal.toString())); } else if (PrimitiveObjectInspectorUtils.floatTypeEntry.equals(primitiveTypeEntry)) { - return new ExprNodeConstantDesc(new Float(constVal.toString())); + return new ExprNodeConstantDesc(Float.valueOf(constVal.toString())); } else if (PrimitiveObjectInspectorUtils.byteTypeEntry.equals(primitiveTypeEntry)) { - return new ExprNodeConstantDesc(new Byte(constVal.toString())); + return new ExprNodeConstantDesc(new BigDecimal(constVal.toString()).byteValueExact()); } else if (PrimitiveObjectInspectorUtils.shortTypeEntry.equals(primitiveTypeEntry)) { - return new ExprNodeConstantDesc(new Short(constVal.toString())); + return new ExprNodeConstantDesc(new BigDecimal(constVal.toString()).shortValueExact()); } else if (PrimitiveObjectInspectorUtils.decimalTypeEntry.equals(primitiveTypeEntry)) { return NumExprProcessor.createDecimal(constVal.toString(), false); } - } catch (NumberFormatException nfe) { + } catch (NumberFormatException | ArithmeticException nfe) { LOG.trace("Failed to narrow type of constant", nfe); - if (!NumberUtils.isNumber(constVal.toString())) { - return null; - } + return null; } } + String constType = constChild.getTypeString().toLowerCase(); // if column type is char and constant type is string, then convert the constant to char // type with padded spaces. if (constType.equalsIgnoreCase(serdeConstants.STRING_TYPE_NAME) && colTypeInfo instanceof CharTypeInfo) { diff --git ql/src/test/org/apache/hadoop/hive/ql/parse/TestTypeCheckProcFactory.java ql/src/test/org/apache/hadoop/hive/ql/parse/TestTypeCheckProcFactory.java new file mode 100644 index 0000000000..f1ff616078 --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/parse/TestTypeCheckProcFactory.java @@ -0,0 +1,144 @@ +package org.apache.hadoop.hive.ql.parse; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.mockito.Mockito.when; + +import org.apache.hadoop.hive.ql.parse.TypeCheckProcFactory.DefaultExprProcessor; +import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils.PrimitiveTypeEntry; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + + +public class TestTypeCheckProcFactory { + private static final String NON_ZERO_FRACTION = "100.1"; + @Mock + private PrimitiveTypeInfo typeInfo; + @Mock + private ExprNodeConstantDesc nodeDesc; + + private DefaultExprProcessor testSubject; + + @Before + public void init() { + MockitoAnnotations.initMocks(this); + testSubject = new DefaultExprProcessor(); + } + + public void testOneCase(Object constValue, PrimitiveTypeEntry constType, Object expectedValue, String expectedType) { + when(nodeDesc.getValue()).thenReturn(constValue); + when(typeInfo.getPrimitiveTypeEntry()).thenReturn(constType); + + ExprNodeConstantDesc result = (ExprNodeConstantDesc) testSubject.interpretNodeAs(typeInfo, nodeDesc); + + assertNotNull(result); + assertEquals(expectedType, result.getTypeString()); + assertEquals(expectedValue, result.getValue()); + } + + public void testNullCase(Object constValue, PrimitiveTypeEntry constType) { + when(nodeDesc.getValue()).thenReturn(constValue); + when(typeInfo.getPrimitiveTypeEntry()).thenReturn(constType); + + ExprNodeConstantDesc result = (ExprNodeConstantDesc) testSubject.interpretNodeAs(typeInfo, nodeDesc); + + assertNull(result); + } + + @Test + public void testLongHappyPath() { + testOneCase("9223372036854775807", PrimitiveObjectInspectorUtils.longTypeEntry, 9223372036854775807L, "bigint"); + } + + @Test + public void testLongWithZeroFraction() { + testOneCase("100.0", PrimitiveObjectInspectorUtils.longTypeEntry, 100L, "bigint"); + } + + @Test + public void testLongOverflow() { + testNullCase("9223372036854775808", PrimitiveObjectInspectorUtils.longTypeEntry); + } + + @Test + public void testLongWithNonZeroFraction() { + testNullCase(NON_ZERO_FRACTION, PrimitiveObjectInspectorUtils.longTypeEntry); + } + + @Test + public void testIntHappyPath() { + testOneCase("2147483647", PrimitiveObjectInspectorUtils.intTypeEntry, 2147483647, "int"); + } + + @Test + public void testIntWithZeroFraction() { + testOneCase("123.0", PrimitiveObjectInspectorUtils.intTypeEntry, 123, "int"); + } + + @Test + public void testIntOverflow() { + testNullCase("2147483648", PrimitiveObjectInspectorUtils.intTypeEntry); + } + + @Test + public void testIntWithNonZeroFraction() { + testNullCase(NON_ZERO_FRACTION, PrimitiveObjectInspectorUtils.intTypeEntry); + } + + @Test + public void testShortHappyPath() { + testOneCase("32767", PrimitiveObjectInspectorUtils.shortTypeEntry, (short) 32767, "smallint"); + } + + @Test + public void testShortWithZeroFraction() { + testOneCase("32767.0", PrimitiveObjectInspectorUtils.shortTypeEntry, (short) 32767, "smallint"); + } + + @Test + public void testShortOverflow() { + testNullCase("32768", PrimitiveObjectInspectorUtils.shortTypeEntry); + } + + @Test + public void testShortWithNonZeroFraction() { + testNullCase(NON_ZERO_FRACTION, PrimitiveObjectInspectorUtils.shortTypeEntry); + } + + @Test + public void testByteHappyPath() { + testOneCase("100", PrimitiveObjectInspectorUtils.byteTypeEntry, (byte) 100, "tinyint"); + } + + @Test + public void testByteWithZeroFraction() { + testOneCase("100.0", PrimitiveObjectInspectorUtils.byteTypeEntry, (byte) 100, "tinyint"); + } + + @Test + public void testByteOverflow() { + testNullCase("128", PrimitiveObjectInspectorUtils.byteTypeEntry); + } + + @Test + public void testByteWithNonZeroFraction() { + testNullCase(NON_ZERO_FRACTION, PrimitiveObjectInspectorUtils.byteTypeEntry); + } + + @Test + public void testFloatHappyPath() { + testOneCase("111.1", PrimitiveObjectInspectorUtils.floatTypeEntry, 111.1f, "float"); + } + + @Test + public void testDoubleHappyPath() { + testOneCase("222.2", PrimitiveObjectInspectorUtils.doubleTypeEntry, 222.2, "double"); + } + +} diff --git ql/src/test/results/clientpositive/infer_const_type.q.out ql/src/test/results/clientpositive/infer_const_type.q.out index 4129bd0c71..c7c5a104ba 100644 --- ql/src/test/results/clientpositive/infer_const_type.q.out +++ ql/src/test/results/clientpositive/infer_const_type.q.out @@ -104,7 +104,6 @@ POSTHOOK: type: QUERY POSTHOOK: Input: default@infertypes #### A masked pattern was here #### 127 32767 12345 -12345 906.0 -307.0 1234 -WARNING: Comparing a bigint and a string may result in a loss of precision. PREHOOK: query: EXPLAIN SELECT * FROM infertypes WHERE ti = '128' OR si = 32768 OR @@ -131,10 +130,9 @@ STAGE PLANS: Map Operator Tree: TableScan alias: infertypes - filterExpr: ((UDFToDouble(ti) = 128.0D) or (UDFToInteger(si) = 32768) or (UDFToDouble(i) = 2.147483648E9D) or (UDFToDouble(bi) = 9.223372036854776E18D)) (type: boolean) Statistics: Num rows: 1 Data size: 1170 Basic stats: COMPLETE Column stats: NONE Filter Operator - predicate: ((UDFToDouble(bi) = 9.223372036854776E18D) or (UDFToDouble(i) = 2.147483648E9D) or (UDFToDouble(ti) = 128.0D) or (UDFToInteger(si) = 32768)) (type: boolean) + predicate: false (type: boolean) Statistics: Num rows: 1 Data size: 1170 Basic stats: COMPLETE Column stats: NONE Select Operator expressions: ti (type: tinyint), si (type: smallint), i (type: int), bi (type: bigint), fl (type: float), db (type: double), str (type: string) @@ -155,7 +153,6 @@ STAGE PLANS: Processor Tree: ListSink -WARNING: Comparing a bigint and a string may result in a loss of precision. PREHOOK: query: SELECT * FROM infertypes WHERE ti = '128' OR si = 32768 OR @@ -196,10 +193,10 @@ STAGE PLANS: Map Operator Tree: TableScan alias: infertypes - filterExpr: ((UDFToDouble(ti) = 127.0D) or (CAST( si AS decimal(5,0)) = 327) or (UDFToDouble(i) = -100.0D)) (type: boolean) + filterExpr: ((ti = 127Y) or (CAST( si AS decimal(5,0)) = 327) or (i = -100)) (type: boolean) Statistics: Num rows: 1 Data size: 1170 Basic stats: COMPLETE Column stats: NONE Filter Operator - predicate: ((CAST( si AS decimal(5,0)) = 327) or (UDFToDouble(i) = -100.0D) or (UDFToDouble(ti) = 127.0D)) (type: boolean) + predicate: ((CAST( si AS decimal(5,0)) = 327) or (i = -100) or (ti = 127Y)) (type: boolean) Statistics: Num rows: 1 Data size: 1170 Basic stats: COMPLETE Column stats: NONE Select Operator expressions: ti (type: tinyint), si (type: smallint), i (type: int), bi (type: bigint), fl (type: float), db (type: double), str (type: string) @@ -255,10 +252,10 @@ STAGE PLANS: Map Operator Tree: TableScan alias: infertypes - filterExpr: ((UDFToDouble(ti) < 127.0D) and (UDFToDouble(i) > 100.0D) and (UDFToDouble(str) = 1.57D)) (type: boolean) + filterExpr: ((ti < 127Y) and (i > 100) and (UDFToDouble(str) = 1.57D)) (type: boolean) Statistics: Num rows: 1 Data size: 1170 Basic stats: COMPLETE Column stats: NONE Filter Operator - predicate: ((UDFToDouble(i) > 100.0D) and (UDFToDouble(str) = 1.57D) and (UDFToDouble(ti) < 127.0D)) (type: boolean) + predicate: ((UDFToDouble(str) = 1.57D) and (i > 100) and (ti < 127Y)) (type: boolean) Statistics: Num rows: 1 Data size: 1170 Basic stats: COMPLETE Column stats: NONE Select Operator expressions: ti (type: tinyint), si (type: smallint), i (type: int), bi (type: bigint), fl (type: float), db (type: double), str (type: string) diff --git ql/src/test/results/clientpositive/parquet_vectorization_0.q.out ql/src/test/results/clientpositive/parquet_vectorization_0.q.out index 4156c5d921..a3ff8a9eee 100644 --- ql/src/test/results/clientpositive/parquet_vectorization_0.q.out +++ ql/src/test/results/clientpositive/parquet_vectorization_0.q.out @@ -1531,7 +1531,7 @@ STAGE PLANS: Map Operator Tree: TableScan alias: alltypesparquet - filterExpr: ((cstring2 like '%b%') or (CAST( cint AS decimal(13,3)) <> 79.553) or (UDFToDouble(cbigint) < cdouble) or ((UDFToShort(ctinyint) >= csmallint) and (cboolean2 = 1) and (UDFToInteger(ctinyint) = 3569))) (type: boolean) + filterExpr: ((cstring2 like '%b%') or (CAST( cint AS decimal(13,3)) <> 79.553) or (UDFToDouble(cbigint) < cdouble)) (type: boolean) Statistics: Num rows: 12288 Data size: 147456 Basic stats: COMPLETE Column stats: NONE TableScan Vectorization: native: true @@ -1540,8 +1540,8 @@ STAGE PLANS: Filter Vectorization: className: VectorFilterOperator native: true - predicateExpression: FilterExprOrExpr(children: FilterStringColLikeStringScalar(col 7:string, pattern %b%), FilterDecimalColNotEqualDecimalScalar(col 13:decimal(13,3), val 79.553)(children: CastLongToDecimal(col 2:int) -> 13:decimal(13,3)), FilterDoubleColLessDoubleColumn(col 14:double, col 5:double)(children: CastLongToDouble(col 3:bigint) -> 14:double), FilterExprAndExpr(children: FilterLongColGreaterEqualLongColumn(col 0:smallint, col 1:smallint)(children: col 0:tinyint), FilterLongColEqualLongScalar(col 11:boolean, val 1), FilterLongColEqualLongScalar(col 0:int, val 3569)(children: col 0:tinyint))) - predicate: (((UDFToShort(ctinyint) >= csmallint) and (cboolean2 = 1) and (UDFToInteger(ctinyint) = 3569)) or (CAST( cint AS decimal(13,3)) <> 79.553) or (UDFToDouble(cbigint) < cdouble) or (cstring2 like '%b%')) (type: boolean) + predicateExpression: FilterExprOrExpr(children: FilterStringColLikeStringScalar(col 7:string, pattern %b%), FilterDecimalColNotEqualDecimalScalar(col 13:decimal(13,3), val 79.553)(children: CastLongToDecimal(col 2:int) -> 13:decimal(13,3)), FilterDoubleColLessDoubleColumn(col 14:double, col 5:double)(children: CastLongToDouble(col 3:bigint) -> 14:double)) + predicate: ((CAST( cint AS decimal(13,3)) <> 79.553) or (UDFToDouble(cbigint) < cdouble) or (cstring2 like '%b%')) (type: boolean) Statistics: Num rows: 12288 Data size: 147456 Basic stats: COMPLETE Column stats: NONE Select Operator expressions: cbigint (type: bigint), cfloat (type: float), ctinyint (type: tinyint), UDFToDouble(cbigint) (type: double), (UDFToDouble(cbigint) * UDFToDouble(cbigint)) (type: double) @@ -1585,7 +1585,7 @@ STAGE PLANS: vectorized: true rowBatchContext: dataColumnCount: 12 - includeColumns: [0, 1, 2, 3, 4, 5, 7, 11] + includeColumns: [0, 2, 3, 4, 5, 7] dataColumns: ctinyint:tinyint, csmallint:smallint, cint:int, cbigint:bigint, cfloat:float, cdouble:double, cstring1:string, cstring2:string, ctimestamp1:timestamp, ctimestamp2:timestamp, cboolean1:boolean, cboolean2:boolean partitionColumnCount: 0 scratchColumnTypeNames: [decimal(13,3), double, double, double, double]