diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveTypeSystemImpl.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveTypeSystemImpl.java index e83ffe1217..511d19b971 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveTypeSystemImpl.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveTypeSystemImpl.java @@ -17,6 +17,8 @@ */ package org.apache.hadoop.hive.ql.optimizer.calcite; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeSystemImpl; import org.apache.calcite.sql.type.SqlTypeName; @@ -37,6 +39,14 @@ private static final int MAX_BINARY_PRECISION = Integer.MAX_VALUE; private static final int MAX_TIMESTAMP_PRECISION = 9; private static final int MAX_TIMESTAMP_WITH_LOCAL_TIME_ZONE_PRECISION = 15; // Up to nanos + private static final int DEFAULT_BOOLEAN_PRECISION = 1; + private static final int DEFAULT_TINYINT_PRECISION = 3; + private static final int DEFAULT_SMALLINT_PRECISION = 5; + private static final int DEFAULT_INTEGER_PRECISION = 10; + private static final int DEFAULT_BIGINT_PRECISION = 19; + private static final int DEFAULT_FLOAT_PRECISION = 7; + private static final int DEFAULT_DOUBLE_PRECISION = 15; + @Override public int getMaxScale(SqlTypeName typeName) { @@ -93,6 +103,20 @@ public int getDefaultPrecision(SqlTypeName typeName) { case INTERVAL_MINUTE_SECOND: case INTERVAL_SECOND: return SqlTypeName.DEFAULT_INTERVAL_START_PRECISION; + case BOOLEAN: + return DEFAULT_BOOLEAN_PRECISION; + case TINYINT: + return DEFAULT_TINYINT_PRECISION; + case SMALLINT: + return DEFAULT_SMALLINT_PRECISION; + case INTEGER: + return DEFAULT_INTEGER_PRECISION; + case BIGINT: + return DEFAULT_BIGINT_PRECISION; + case FLOAT: + return DEFAULT_FLOAT_PRECISION; + case DOUBLE: + return DEFAULT_DOUBLE_PRECISION; default: return -1; } @@ -129,7 +153,7 @@ public int getMaxPrecision(SqlTypeName typeName) { case 
INTERVAL_SECOND: return SqlTypeName.MAX_INTERVAL_START_PRECISION; default: - return -1; + return getDefaultPrecision(typeName); } } @@ -148,4 +172,26 @@ public boolean isSchemaCaseSensitive() { return false; } + /** + * Derives the result type of SUM over {@code argumentType}. DECIMAL arguments get + * 10 extra digits of precision (capped at MAX_DECIMAL_PRECISION) with the same scale + * and nullability; every other type is returned unchanged. + */ + @Override + public RelDataType deriveSumType(RelDataTypeFactory typeFactory, + RelDataType argumentType) { + switch (argumentType.getSqlTypeName()) { + case DECIMAL: + /* createSqlType builds a NOT NULL type, so re-apply the argument's nullability. */ + return typeFactory.createTypeWithNullability( + typeFactory.createSqlType( + SqlTypeName.DECIMAL, + Math.min(MAX_DECIMAL_PRECISION, argumentType.getPrecision() + 10), + argumentType.getScale()), + argumentType.isNullable()); + default: + return argumentType; + } + } + } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveAggregateReduceFunctionsRule.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveAggregateReduceFunctionsRule.java index 4b7139a8f7..759a43232a 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveAggregateReduceFunctionsRule.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveAggregateReduceFunctionsRule.java @@ -292,7 +292,7 @@ private RexNode reduceSum0( final RelDataType sum0InputType = typeFactory.createTypeWithNullability( getFieldType(oldAggRel.getInput(), iAvgInput), true); final RelDataType sumReturnType = getSumReturnType( - rexBuilder.getTypeFactory(), sum0InputType, oldCall.getType()); + rexBuilder.getTypeFactory(), sum0InputType); final AggregateCall sumCall = AggregateCall.create( new HiveSqlSumAggFunction( @@ -336,7 +336,7 @@ private RexNode reduceAvg( final RelDataType avgInputType = typeFactory.createTypeWithNullability( getFieldType(oldAggRel.getInput(), iAvgInput), true); final RelDataType sumReturnType = getSumReturnType( - rexBuilder.getTypeFactory(), avgInputType, oldCall.getType()); + rexBuilder.getTypeFactory(), avgInputType); final AggregateCall sumCall = AggregateCall.create( new HiveSqlSumAggFunction( @@ -427,13 +427,13 @@ private RexNode reduceStddev( rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), false); final int 
argRefOrdinal = lookupOrAdd(inputExprs, argRef); final RelDataType sumReturnType = getSumReturnType( - rexBuilder.getTypeFactory(), argRef.getType(), oldCall.getType()); + rexBuilder.getTypeFactory(), argRef.getType()); final RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef); final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared); final RelDataType sumSquaredReturnType = getSumReturnType( - rexBuilder.getTypeFactory(), argSquared.getType(), oldCall.getType()); + rexBuilder.getTypeFactory(), argSquared.getType()); final AggregateCall sumArgSquaredAggCall = createAggregateCallWithBinding(typeFactory, @@ -593,7 +593,7 @@ private RelDataType getFieldType(RelNode relNode, int i) { } private RelDataType getSumReturnType(RelDataTypeFactory typeFactory, - RelDataType inputType, RelDataType originalReturnType) { + RelDataType inputType) { switch (inputType.getSqlTypeName()) { case TINYINT: case SMALLINT: @@ -607,8 +607,7 @@ private RelDataType getSumReturnType(RelDataTypeFactory typeFactory, case CHAR: return TypeConverter.convert(TypeInfoFactory.doubleTypeInfo, typeFactory); case DECIMAL: - // We keep precision and scale - return originalReturnType; + return typeFactory.getTypeSystem().deriveSumType(typeFactory, inputType); } return null; }