diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java index 14217e3..1a7e7e3 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java @@ -79,9 +79,13 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFWhen; import org.apache.hadoop.hive.serde.serdeConstants; import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; @@ -703,13 +707,31 @@ public Object process(Node nd, Stack stack, NodeProcessorCtx procCtx, static ExprNodeDesc toExprNodeDesc(ColumnInfo colInfo) { ObjectInspector inspector = colInfo.getObjectInspector(); - if (inspector instanceof ConstantObjectInspector && - inspector instanceof PrimitiveObjectInspector) { - PrimitiveObjectInspector poi = (PrimitiveObjectInspector) inspector; - Object constant = ((ConstantObjectInspector) inspector).getWritableConstantValue(); - ExprNodeConstantDesc constantExpr = new ExprNodeConstantDesc(colInfo.getType(), poi.getPrimitiveJavaObject(constant)); - constantExpr.setFoldedFromCol(colInfo.getInternalName()); - return constantExpr; + if (inspector instanceof ConstantObjectInspector && inspector instanceof PrimitiveObjectInspector) { + return toPrimitiveConstDesc(colInfo, inspector); + } + if (inspector instanceof ConstantObjectInspector && inspector instanceof ListObjectInspector) { + ObjectInspector listElementOI = ((ListObjectInspector)inspector).getListElementObjectInspector(); + if (listElementOI instanceof PrimitiveObjectInspector) { + return toListConstDesc(colInfo, inspector, listElementOI); + } + } + if (inspector instanceof ConstantObjectInspector && inspector instanceof MapObjectInspector) { + ObjectInspector keyOI = ((MapObjectInspector)inspector).getMapKeyObjectInspector(); + ObjectInspector valueOI = ((MapObjectInspector)inspector).getMapValueObjectInspector(); + if (keyOI instanceof PrimitiveObjectInspector && valueOI instanceof PrimitiveObjectInspector) { + return toMapConstDesc(colInfo, inspector, keyOI, valueOI); + } + } + if (inspector instanceof ConstantObjectInspector && inspector instanceof StructObjectInspector) { + boolean allPrimitive = true; + List fields = ((StructObjectInspector)inspector).getAllStructFieldRefs(); + for (StructField field : fields) { + allPrimitive &= field.getFieldObjectInspector() instanceof PrimitiveObjectInspector; + } + if (allPrimitive) { + return toStructConstDesc(colInfo, inspector, fields); + } } // non-constant or non-primitive constants ExprNodeColumnDesc column = new ExprNodeColumnDesc(colInfo); @@ -717,6 +739,59 @@ static ExprNodeDesc toExprNodeDesc(ColumnInfo colInfo) { return column; } + private static ExprNodeConstantDesc toPrimitiveConstDesc(ColumnInfo colInfo, ObjectInspector inspector) { + PrimitiveObjectInspector poi = (PrimitiveObjectInspector) inspector; + Object constant = ((ConstantObjectInspector) inspector).getWritableConstantValue(); + ExprNodeConstantDesc constantExpr = + new ExprNodeConstantDesc(colInfo.getType(), poi.getPrimitiveJavaObject(constant)); + constantExpr.setFoldedFromCol(colInfo.getInternalName()); + return constantExpr; + } + + private static ExprNodeConstantDesc toListConstDesc(ColumnInfo colInfo, ObjectInspector inspector, + ObjectInspector listElementOI) { + PrimitiveObjectInspector poi = (PrimitiveObjectInspector)listElementOI; + List values = (List)((ConstantObjectInspector) inspector).getWritableConstantValue(); + List constant = new ArrayList(); + for (Object o : values) { + constant.add(poi.getPrimitiveJavaObject(o)); + } + + ExprNodeConstantDesc constantExpr = new ExprNodeConstantDesc(colInfo.getType(), constant); + constantExpr.setFoldedFromCol(colInfo.getInternalName()); + return constantExpr; + } + + private static ExprNodeConstantDesc toMapConstDesc(ColumnInfo colInfo, ObjectInspector inspector, + ObjectInspector keyOI, ObjectInspector valueOI) { + PrimitiveObjectInspector keyPoi = (PrimitiveObjectInspector)keyOI; + PrimitiveObjectInspector valuePoi = (PrimitiveObjectInspector)valueOI; + Map values = (Map)((ConstantObjectInspector) inspector).getWritableConstantValue(); + Map constant = new HashMap(); + for (Map.Entry e : values.entrySet()) { + constant.put(keyPoi.getPrimitiveJavaObject(e.getKey()), valuePoi.getPrimitiveJavaObject(e.getValue())); + } + + ExprNodeConstantDesc constantExpr = new ExprNodeConstantDesc(colInfo.getType(), constant); + constantExpr.setFoldedFromCol(colInfo.getInternalName()); + return constantExpr; + } + + private static ExprNodeConstantDesc toStructConstDesc(ColumnInfo colInfo, ObjectInspector inspector, + List fields) { + List values = (List)((ConstantObjectInspector) inspector).getWritableConstantValue(); + List constant = new ArrayList(); + for (int i = 0; i < values.size(); i++) { + Object value = values.get(i); + PrimitiveObjectInspector fieldPoi = (PrimitiveObjectInspector) fields.get(i).getFieldObjectInspector(); + constant.add(fieldPoi.getPrimitiveJavaObject(value)); + } + + ExprNodeConstantDesc constantExpr = new ExprNodeConstantDesc(colInfo.getType(), constant); + constantExpr.setFoldedFromCol(colInfo.getInternalName()); + return constantExpr; + } + /** * Factory method to get ColumnExprProcessor. * diff --git a/ql/src/test/queries/clientpositive/udaf_percentile_approx_23.q b/ql/src/test/queries/clientpositive/udaf_percentile_approx_23.q index 38cf927..70974ba 100644 --- a/ql/src/test/queries/clientpositive/udaf_percentile_approx_23.q +++ b/ql/src/test/queries/clientpositive/udaf_percentile_approx_23.q @@ -99,5 +99,4 @@ select percentile_approx(key, 0.5) from bucket; select percentile_approx(key, 0.5) between 255.0 and 257.0 from bucket; -- test where number of elements is zero -set hive.cbo.enable=false; select percentile_approx(key, array(0.50, 0.70, 0.90, 0.95, 0.99)) from bucket where key > 10000;