diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/StatsOptimizer.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/StatsOptimizer.java index 6961d7f..c7a22fa 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/StatsOptimizer.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/StatsOptimizer.java @@ -34,7 +34,6 @@ import org.apache.hadoop.hive.metastore.api.DoubleColumnStatsData; import org.apache.hadoop.hive.metastore.api.LongColumnStatsData; import org.apache.hadoop.hive.ql.exec.ColumnInfo; -import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.FetchTask; import org.apache.hadoop.hive.ql.exec.FileSinkOperator; import org.apache.hadoop.hive.ql.exec.FunctionRegistry; @@ -74,8 +73,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; -import org.apache.hive.common.util.AnnotationUtils; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; import org.apache.thrift.TException; import com.google.common.collect.Lists; @@ -143,6 +141,22 @@ public MetaDataProcessor (ParseContext pctx) { Unsupported } + enum LongSubType { + BIGINT { Object cast(long longValue) { return longValue; } }, + INT { Object cast(long longValue) { return (int)longValue; } }, + SMALLINT { Object cast(long longValue) { return (short)longValue; } }, + TINYINT { Object cast(long longValue) { return (byte)longValue; } }; + + abstract Object cast(long longValue); + } + + enum DoubleSubType { + DOUBLE { Object cast(double doubleValue) { return doubleValue; } }, + FLOAT { Object cast(double doubleValue) { return (float) doubleValue; } }; + + abstract Object cast(double doubleValue); + } + private StatType getType(String origType) { if (serdeConstants.IntegralTypes.contains(origType)) { return StatType.Integeral; @@ -236,7 +250,6 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, Table tbl = tsOp.getConf().getTableMetadata(); List oneRow = new ArrayList(); - List ois = new ArrayList(); Hive hive = Hive.get(pctx.getConf()); @@ -249,7 +262,12 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, GenericUDAFResolver udaf = FunctionRegistry.getGenericUDAFResolver(aggr.getGenericUDAFName()); if (udaf instanceof GenericUDAFSum) { + // long/double/decimal ExprNodeDesc desc = aggr.getParameters().get(0); + PrimitiveCategory category = GenericUDAFSum.getReturnType(desc.getTypeInfo()); + if (category == null) { + return null; + } String constant; if (desc instanceof ExprNodeConstantDesc) { constant = ((ExprNodeConstantDesc) desc).getValue().toString(); @@ -262,11 +280,22 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, if(rowCnt == null) { return null; } - oneRow.add(HiveDecimal.create(constant).multiply(HiveDecimal.create(rowCnt))); - ois.add(PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( - PrimitiveCategory.DECIMAL)); + switch (category) { + case LONG: + oneRow.add(Long.valueOf(constant) * rowCnt); + break; + case DOUBLE: + oneRow.add(Double.valueOf(constant) * rowCnt); + break; + case DECIMAL: + oneRow.add(HiveDecimal.create(constant).multiply(HiveDecimal.create(rowCnt))); + break; + default: + throw new IllegalStateException("never"); + } } else if (udaf instanceof GenericUDAFCount) { + // always long Long rowCnt = 0L; if (aggr.getParameters().isEmpty() || aggr.getParameters().get(0) instanceof ExprNodeConstantDesc || ((aggr.getParameters().get(0) instanceof ExprNodeColumnDesc) && @@ -341,8 +370,6 @@ else if (udaf instanceof GenericUDAFCount) { } } oneRow.add(rowCnt); - ois.add(PrimitiveObjectInspectorFactory. - getPrimitiveJavaObjectInspector(PrimitiveCategory.LONG)); } else if (udaf instanceof GenericUDAFMax) { ExprNodeColumnDesc colDesc = (ExprNodeColumnDesc)exprMap.get(((ExprNodeColumnDesc)aggr.getParameters().get(0)).getColumn()); String colName = colDesc.getColumn(); @@ -359,19 +386,28 @@ else if (udaf instanceof GenericUDAFCount) { return null; } ColumnStatisticsData statData = stats.get(0).getStatsData(); + String name = colDesc.getTypeString().toUpperCase(); switch (type) { - case Integeral: + case Integeral: { + LongSubType subType = LongSubType.valueOf(name); LongColumnStatsData lstats = statData.getLongStats(); - oneRow.add(lstats.isSetHighValue() ? lstats.getHighValue() : null); - ois.add(PrimitiveObjectInspectorFactory. - getPrimitiveJavaObjectInspector(PrimitiveCategory.LONG)); + if (lstats.isSetHighValue()) { + oneRow.add(subType.cast(lstats.getHighValue())); + } else { + oneRow.add(null); + } break; - case Double: + } + case Double: { + DoubleSubType subType = DoubleSubType.valueOf(name); DoubleColumnStatsData dstats = statData.getDoubleStats(); - oneRow.add(dstats.isSetHighValue() ? dstats.getHighValue() : null); - ois.add(PrimitiveObjectInspectorFactory. - getPrimitiveJavaObjectInspector(PrimitiveCategory.DOUBLE)); + if (dstats.isSetHighValue()) { + oneRow.add(subType.cast(dstats.getHighValue())); + } else { + oneRow.add(null); + } break; + } default: // unsupported type Log.debug("Unsupported type: " + colDesc.getTypeString() + " encountered in " + @@ -381,8 +417,11 @@ else if (udaf instanceof GenericUDAFCount) { } else { Set parts = pctx.getPrunedPartitions( tsOp.getConf().getAlias(), tsOp).getPartitions(); + String name = colDesc.getTypeString().toUpperCase(); switch (type) { case Integeral: { + LongSubType subType = LongSubType.valueOf(name); + Long maxVal = null; Collection> result = verifyAndGetPartStats(hive, tbl, colName, parts); @@ -399,12 +438,16 @@ else if (udaf instanceof GenericUDAFCount) { long curVal = lstats.getHighValue(); maxVal = maxVal == null ? curVal : Math.max(maxVal, curVal); } - oneRow.add(maxVal); - ois.add(PrimitiveObjectInspectorFactory. - getPrimitiveJavaObjectInspector(PrimitiveCategory.LONG)); + if (maxVal != null) { + oneRow.add(subType.cast(maxVal)); + } else { + oneRow.add(maxVal); + } break; } case Double: { + DoubleSubType subType = DoubleSubType.valueOf(name); + Double maxVal = null; Collection> result = verifyAndGetPartStats(hive, tbl, colName, parts); @@ -421,9 +464,11 @@ else if (udaf instanceof GenericUDAFCount) { double curVal = statData.getDoubleStats().getHighValue(); maxVal = maxVal == null ? curVal : Math.max(maxVal, curVal); } - oneRow.add(maxVal); - ois.add(PrimitiveObjectInspectorFactory. - getPrimitiveJavaObjectInspector(PrimitiveCategory.DOUBLE)); + if (maxVal != null) { + oneRow.add(subType.cast(maxVal)); + } else { + oneRow.add(null); + } break; } default: @@ -444,19 +489,28 @@ else if (udaf instanceof GenericUDAFCount) { ColumnStatisticsData statData = hive.getMSC().getTableColumnStatistics( tbl.getDbName(), tbl.getTableName(), Lists.newArrayList(colName)) .get(0).getStatsData(); + String name = colDesc.getTypeString().toUpperCase(); switch (type) { - case Integeral: + case Integeral: { + LongSubType subType = LongSubType.valueOf(name); LongColumnStatsData lstats = statData.getLongStats(); - oneRow.add(lstats.isSetLowValue() ? lstats.getLowValue() : null); - ois.add(PrimitiveObjectInspectorFactory. - getPrimitiveJavaObjectInspector(PrimitiveCategory.LONG)); + if (lstats.isSetLowValue()) { + oneRow.add(subType.cast(lstats.getLowValue())); + } else { + oneRow.add(null); + } break; - case Double: + } + case Double: { + DoubleSubType subType = DoubleSubType.valueOf(name); DoubleColumnStatsData dstats = statData.getDoubleStats(); - oneRow.add(dstats.isSetLowValue() ? dstats.getLowValue() : null); - ois.add(PrimitiveObjectInspectorFactory. - getPrimitiveJavaObjectInspector(PrimitiveCategory.DOUBLE)); + if (dstats.isSetLowValue()) { + oneRow.add(subType.cast(dstats.getLowValue())); + } else { + oneRow.add(null); + } break; + } default: // unsupported type Log.debug("Unsupported type: " + colDesc.getTypeString() + " encountered in " + "metadata optimizer for column : " + colName); @@ -464,8 +518,11 @@ else if (udaf instanceof GenericUDAFCount) { } } else { Set parts = pctx.getPrunedPartitions(tsOp.getConf().getAlias(), tsOp).getPartitions(); + String name = colDesc.getTypeString().toUpperCase(); switch(type) { case Integeral: { + LongSubType subType = LongSubType.valueOf(name); + Long minVal = null; Collection> result = verifyAndGetPartStats(hive, tbl, colName, parts); @@ -482,12 +539,16 @@ else if (udaf instanceof GenericUDAFCount) { long curVal = lstats.getLowValue(); minVal = minVal == null ? curVal : Math.min(minVal, curVal); } - oneRow.add(minVal); - ois.add(PrimitiveObjectInspectorFactory. - getPrimitiveJavaObjectInspector(PrimitiveCategory.LONG)); + if (minVal != null) { + oneRow.add(subType.cast(minVal)); + } else { + oneRow.add(minVal); + } break; } case Double: { + DoubleSubType subType = DoubleSubType.valueOf(name); + Double minVal = null; Collection> result = verifyAndGetPartStats(hive, tbl, colName, parts); @@ -504,9 +565,11 @@ else if (udaf instanceof GenericUDAFCount) { double curVal = statData.getDoubleStats().getLowValue(); minVal = minVal == null ? curVal : Math.min(minVal, curVal); } - oneRow.add(minVal); - ois.add(PrimitiveObjectInspectorFactory. - getPrimitiveJavaObjectInspector(PrimitiveCategory.DOUBLE)); + if (minVal != null) { + oneRow.add(subType.cast(minVal)); + } else { + oneRow.add(minVal); + } break; } default: // unsupported type @@ -528,8 +591,10 @@ else if (udaf instanceof GenericUDAFCount) { allRows.add(oneRow); List colNames = new ArrayList(); + List ois = new ArrayList(); for (ColumnInfo colInfo: gbyOp.getSchema().getSignature()) { colNames.add(colInfo.getInternalName()); + ois.add(TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(colInfo.getType())); } StandardStructObjectInspector sOI = ObjectInspectorFactory. getStandardStructObjectInspector(colNames, ois); diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java index d1118f1..ffb7093 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java @@ -87,6 +87,29 @@ public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) } } + public static PrimitiveObjectInspector.PrimitiveCategory getReturnType(TypeInfo type) { + if (type.getCategory() != ObjectInspector.Category.PRIMITIVE) { + return null; + } + switch (((PrimitiveTypeInfo) type).getPrimitiveCategory()) { + case BYTE: + case SHORT: + case INT: + case LONG: + return PrimitiveObjectInspector.PrimitiveCategory.LONG; + case TIMESTAMP: + case FLOAT: + case DOUBLE: + case STRING: + case VARCHAR: + case CHAR: + return PrimitiveObjectInspector.PrimitiveCategory.DOUBLE; + case DECIMAL: + return PrimitiveObjectInspector.PrimitiveCategory.DECIMAL; + } + return null; + } + /** * GenericUDAFSumHiveDecimal. * diff --git ql/src/test/results/clientpositive/metadata_only_queries.q.out ql/src/test/results/clientpositive/metadata_only_queries.q.out index 90c76ed..5907f4a 100644 --- ql/src/test/results/clientpositive/metadata_only_queries.q.out +++ ql/src/test/results/clientpositive/metadata_only_queries.q.out @@ -338,7 +338,7 @@ POSTHOOK: query: select count(*), sum(1), sum(0.2), count(1), count(s), count(bo POSTHOOK: type: QUERY POSTHOOK: Input: default@stats_tbl #### A masked pattern was here #### -9999 9999 1999.8 9999 9999 9999 9999 9999 +9999 9999 1999.8000000000002 9999 9999 9999 9999 9999 PREHOOK: query: explain select min(i), max(i), min(b), max(b), min(f), max(f), min(d), max(d) from stats_tbl PREHOOK: type: QUERY @@ -363,7 +363,7 @@ POSTHOOK: query: select min(i), max(i), min(b), max(b), min(f), max(f), min(d), POSTHOOK: type: QUERY POSTHOOK: Input: default@stats_tbl #### A masked pattern was here #### -65536 65791 4294967296 4294967551 0.009999999776482582 99.9800033569336 0.01 50.0 +65536 65791 4294967296 4294967551 0.01 99.98 0.01 50.0 PREHOOK: query: explain select count(*), sum(1), sum(0.2), count(1), count(s), count(bo), count(bin), count(si) from stats_tbl_part PREHOOK: type: QUERY @@ -388,7 +388,7 @@ POSTHOOK: query: select count(*), sum(1), sum(0.2), count(1), count(s), count(bo POSTHOOK: type: QUERY POSTHOOK: Input: default@stats_tbl_part #### A masked pattern was here #### -9489 9489 1897.8 9489 9489 9489 9489 9489 +9489 9489 1897.8000000000002 9489 9489 9489 9489 9489 PREHOOK: query: explain select min(i), max(i), min(b), max(b), min(f), max(f), min(d), max(d) from stats_tbl_part PREHOOK: type: QUERY @@ -413,7 +413,7 @@ POSTHOOK: query: select min(i), max(i), min(b), max(b), min(f), max(f), min(d), POSTHOOK: type: QUERY POSTHOOK: Input: default@stats_tbl_part #### A masked pattern was here #### -65536 65791 4294967296 4294967551 0.009999999776482582 99.9800033569336 0.01 50.0 +65536 65791 4294967296 4294967551 0.01 99.98 0.01 50.0 PREHOOK: query: explain select count(ts) from stats_tbl_part PREHOOK: type: QUERY POSTHOOK: query: explain select count(ts) from stats_tbl_part diff --git ql/src/test/results/clientpositive/metadata_only_queries_with_filters.q.out ql/src/test/results/clientpositive/metadata_only_queries_with_filters.q.out index 5be958f..6dea3e0 100644 --- ql/src/test/results/clientpositive/metadata_only_queries_with_filters.q.out +++ ql/src/test/results/clientpositive/metadata_only_queries_with_filters.q.out @@ -166,7 +166,7 @@ POSTHOOK: query: select count(*), count(1), sum(1), count(s), count(bo), count(b POSTHOOK: type: QUERY POSTHOOK: Input: default@stats_tbl_part #### A masked pattern was here #### -2322 2322 2322 2322 2322 2322 2322 65791 4294967296 99.9800033569336 0.03 +2322 2322 2322 2322 2322 2322 2322 65791 4294967296 99.98 0.03 PREHOOK: query: explain select count(*), count(1), sum(1), sum(2), count(s), count(bo), count(bin), count(si), max(i), min(b), max(f), min(d) from stats_tbl_part where dt > 2010 PREHOOK: type: QUERY @@ -191,7 +191,7 @@ POSTHOOK: query: select count(*), count(1), sum(1), sum(2), count(s), count(bo), POSTHOOK: type: QUERY POSTHOOK: Input: default@stats_tbl_part #### A masked pattern was here #### -2219 2219 2219 4438 2219 2219 2219 2219 65791 4294967296 99.95999908447266 0.04 +2219 2219 2219 4438 2219 2219 2219 2219 65791 4294967296 99.96 0.04 PREHOOK: query: select count(*) from stats_tbl_part PREHOOK: type: QUERY PREHOOK: Input: default@stats_tbl_part