diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/StatsOptimizer.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/StatsOptimizer.java
index cf8534a..ae2c0cc 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/StatsOptimizer.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/StatsOptimizer.java
@@ -34,7 +34,6 @@
 import org.apache.hadoop.hive.metastore.api.DoubleColumnStatsData;
 import org.apache.hadoop.hive.metastore.api.LongColumnStatsData;
 import org.apache.hadoop.hive.ql.exec.ColumnInfo;
-import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.FetchTask;
 import org.apache.hadoop.hive.ql.exec.FileSinkOperator;
 import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
@@ -74,7 +73,6 @@
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
 import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hive.common.util.AnnotationUtils;
 import org.apache.thrift.TException;
 
 import com.google.common.collect.Lists;
@@ -140,6 +138,29 @@ public MetaDataProcessor (ParseContext pctx) {
       Unsupported
     }
 
+    /**
+     * Integral column sub-types, named after the upper-cased Hive type string so that
+     * {@code valueOf} can resolve them. Each constant narrows the {@code long} held in
+     * the column statistics to the Java type matching its object inspector.
+     */
+    enum IntegerSubType {
+      BIGINT,
+      INT {
+        Object cast(long longValue) { return (int) longValue; }
+        ObjectInspector getTypeOI() { return PrimitiveObjectInspectorFactory.javaIntObjectInspector; }
+      },
+      SMALLINT {
+        Object cast(long longValue) { return (short) longValue; }
+        ObjectInspector getTypeOI() { return PrimitiveObjectInspectorFactory.javaShortObjectInspector; }
+      },
+      TINYINT {
+        Object cast(long longValue) { return (byte) longValue; }
+        ObjectInspector getTypeOI() { return PrimitiveObjectInspectorFactory.javaByteObjectInspector; }
+      };
+      Object cast(long longValue) { return longValue; }
+      ObjectInspector getTypeOI() { return PrimitiveObjectInspectorFactory.javaLongObjectInspector; }
+    }
+
     private StatType getType(String origType) {
       if (serdeConstants.IntegralTypes.contains(origType)) {
         return StatType.Integeral;
@@ -243,7 +264,12 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx,
           GenericUDAFResolver udaf =
               FunctionRegistry.getGenericUDAFResolver(aggr.getGenericUDAFName());
           if (udaf instanceof GenericUDAFSum) {
+            // long/double/decimal
             ExprNodeDesc desc = aggr.getParameters().get(0);
+            PrimitiveCategory category = GenericUDAFSum.getReturnType(desc.getTypeInfo());
+            if (category == null) {
+              return null;
+            }
             String constant;
             if (desc instanceof ExprNodeConstantDesc) {
               constant = ((ExprNodeConstantDesc) desc).getValue().toString();
@@ -256,11 +282,25 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx,
             if(rowCnt == null) {
               return null;
             }
-            oneRow.add(HiveDecimal.create(constant).multiply(HiveDecimal.create(rowCnt)));
-            ois.add(PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
-                PrimitiveCategory.DECIMAL));
+            switch (category) {
+              case LONG:
+                oneRow.add(Long.valueOf(constant) * rowCnt);
+                ois.add(PrimitiveObjectInspectorFactory.javaLongObjectInspector);
+                break;
+              case DOUBLE:
+                oneRow.add(Double.valueOf(constant) * rowCnt);
+                ois.add(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);
+                break;
+              case DECIMAL:
+                oneRow.add(HiveDecimal.create(constant).multiply(HiveDecimal.create(rowCnt)));
+                ois.add(PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector);
+                break;
+              default:
+                throw new IllegalStateException("never");
+            }
           }
           else if (udaf instanceof GenericUDAFCount) {
+            // always long
             Long rowCnt = 0L;
             if (aggr.getParameters().isEmpty() || aggr.getParameters().get(0) instanceof
                 ExprNodeConstantDesc || ((aggr.getParameters().get(0) instanceof ExprNodeColumnDesc) &&
@@ -355,10 +395,15 @@ else if (udaf instanceof GenericUDAFCount) {
               ColumnStatisticsData statData = stats.get(0).getStatsData();
               switch (type) {
                 case Integeral:
+                  String name = colDesc.getTypeString().toUpperCase(java.util.Locale.ROOT);
+                  IntegerSubType subType = IntegerSubType.valueOf(name);
                   LongColumnStatsData lstats = statData.getLongStats();
-                  oneRow.add(lstats.isSetHighValue() ? lstats.getHighValue() : null);
-                  ois.add(PrimitiveObjectInspectorFactory.
-                      getPrimitiveJavaObjectInspector(PrimitiveCategory.LONG));
+                  if (lstats.isSetHighValue()) {
+                    oneRow.add(subType.cast(lstats.getHighValue()));
+                  } else {
+                    oneRow.add(null);
+                  }
+                  ois.add(subType.getTypeOI());
                   break;
                 case Double:
                   DoubleColumnStatsData dstats = statData.getDoubleStats();
@@ -377,6 +422,9 @@ else if (udaf instanceof GenericUDAFCount) {
                   tsOp.getConf().getAlias(), tsOp).getPartitions();
               switch (type) {
                 case Integeral: {
+                  String name = colDesc.getTypeString().toUpperCase(java.util.Locale.ROOT);
+                  IntegerSubType subType = IntegerSubType.valueOf(name);
+
                   Long maxVal = null;
                   Collection<List<ColumnStatisticsObj>> result =
                       verifyAndGetPartStats(hive, tbl, colName, parts);
@@ -393,9 +441,12 @@ else if (udaf instanceof GenericUDAFCount) {
                     long curVal = lstats.getHighValue();
                     maxVal = maxVal == null ? curVal : Math.max(maxVal, curVal);
                   }
-                  oneRow.add(maxVal);
-                  ois.add(PrimitiveObjectInspectorFactory.
-                      getPrimitiveJavaObjectInspector(PrimitiveCategory.LONG));
+                  if (maxVal != null) {
+                    oneRow.add(subType.cast(maxVal));
+                  } else {
+                    oneRow.add(null);
+                  }
+                  ois.add(subType.getTypeOI());
                   break;
                 }
                 case Double: {
@@ -440,10 +491,15 @@ else if (udaf instanceof GenericUDAFCount) {
                   .get(0).getStatsData();
               switch (type) {
                 case Integeral:
+                  String name = colDesc.getTypeString().toUpperCase(java.util.Locale.ROOT);
+                  IntegerSubType subType = IntegerSubType.valueOf(name);
                   LongColumnStatsData lstats = statData.getLongStats();
-                  oneRow.add(lstats.isSetLowValue() ? lstats.getLowValue() : null);
-                  ois.add(PrimitiveObjectInspectorFactory.
-                      getPrimitiveJavaObjectInspector(PrimitiveCategory.LONG));
+                  if (lstats.isSetLowValue()) {
+                    oneRow.add(subType.cast(lstats.getLowValue()));
+                  } else {
+                    oneRow.add(null);
+                  }
+                  ois.add(subType.getTypeOI());
                   break;
                 case Double:
                   DoubleColumnStatsData dstats = statData.getDoubleStats();
@@ -460,6 +516,9 @@ else if (udaf instanceof GenericUDAFCount) {
               Set<Partition> parts = pctx.getPrunedPartitions(tsOp.getConf().getAlias(), tsOp).getPartitions();
               switch(type) {
                 case Integeral: {
+                  String name = colDesc.getTypeString().toUpperCase(java.util.Locale.ROOT);
+                  IntegerSubType subType = IntegerSubType.valueOf(name);
+
                   Long minVal = null;
                   Collection<List<ColumnStatisticsObj>> result =
                       verifyAndGetPartStats(hive, tbl, colName, parts);
@@ -476,9 +535,12 @@ else if (udaf instanceof GenericUDAFCount) {
                     long curVal = lstats.getLowValue();
                     minVal = minVal == null ? curVal : Math.min(minVal, curVal);
                   }
-                  oneRow.add(minVal);
-                  ois.add(PrimitiveObjectInspectorFactory.
-                      getPrimitiveJavaObjectInspector(PrimitiveCategory.LONG));
+                  if (minVal != null) {
+                    oneRow.add(subType.cast(minVal));
+                  } else {
+                    oneRow.add(null);
+                  }
+                  ois.add(subType.getTypeOI());
                   break;
                 }
                 case Double: {
diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java
index d1118f1..ffb7093 100644
--- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java
+++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java
@@ -87,6 +87,36 @@ public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
     }
   }
 
+  /**
+   * Maps a SUM argument type to the primitive category of the result.
+   *
+   * @param type type of the SUM argument
+   * @return LONG for BYTE, SHORT, INT and LONG; DOUBLE for TIMESTAMP, FLOAT, DOUBLE,
+   *         STRING, VARCHAR and CHAR; DECIMAL for DECIMAL; {@code null} for any other type
+   */
+  public static PrimitiveObjectInspector.PrimitiveCategory getReturnType(TypeInfo type) {
+    if (type.getCategory() != ObjectInspector.Category.PRIMITIVE) {
+      return null;
+    }
+    switch (((PrimitiveTypeInfo) type).getPrimitiveCategory()) {
+      case BYTE:
+      case SHORT:
+      case INT:
+      case LONG:
+        return PrimitiveObjectInspector.PrimitiveCategory.LONG;
+      case TIMESTAMP:
+      case FLOAT:
+      case DOUBLE:
+      case STRING:
+      case VARCHAR:
+      case CHAR:
+        return PrimitiveObjectInspector.PrimitiveCategory.DOUBLE;
+      case DECIMAL:
+        return PrimitiveObjectInspector.PrimitiveCategory.DECIMAL;
+    }
+    return null;
+  }
+
   /**
    * GenericUDAFSumHiveDecimal.
    *