diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFMin.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFMin.java index bde36e13cd852eca29bd0346f0088da72ea42dfb..85f7ffe466f8ab882d8d1b09e2fa7c9b2aba44bf 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFMin.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFMin.java @@ -109,7 +109,7 @@ public void merge(AggregationBuffer agg, Object partial) throws HiveException { if (partial != null) { MinAgg myagg = (MinAgg) agg; - int r = ObjectInspectorUtils.compare(myagg.o, outputOI, partial, inputOI); + int r = ObjectInspectorUtils.compareForMin(myagg.o, outputOI, partial, inputOI); if (myagg.o == null || r > 0) { myagg.o = ObjectInspectorUtils.copyToStandardObject(partial, inputOI, ObjectInspectorCopyOption.JAVA); diff --git a/ql/src/test/queries/clientpositive/min_structvalue.q b/ql/src/test/queries/clientpositive/min_structvalue.q new file mode 100644 index 0000000000000000000000000000000000000000..4431a0dcab9cda8695b918a1e679872e04af87c1 --- /dev/null +++ b/ql/src/test/queries/clientpositive/min_structvalue.q @@ -0,0 +1,10 @@ +select max(a), min(a) FROM (select named_struct("field",1) as a union all select named_struct("field",2) as a union all select named_struct("field",cast(null as int)) as a) tmp; + +select min(a) FROM (select named_struct("field",1) as a union all select named_struct("field",-2) as a union all select named_struct("field",cast(null as int)) as a) tmp; + +select min(a) FROM (select named_struct("field",1) as a union all select named_struct("field",2) as a union all select named_struct("field",cast(5 as int)) as a) tmp; + +select min(a) FROM (select named_struct("field",1, "secf", cast(null as int) ) as a union all select named_struct("field",2, "secf", 3) as a union all select named_struct("field",cast(5 as int), "secf", 4) as a) tmp; + +select min(a) FROM (select named_struct("field",1, "secf", 2) as a union all select named_struct("field",-2, "secf", 3) as a union all select named_struct("field",cast(null as int), "secf", 1) as a) tmp; + diff --git a/ql/src/test/results/clientpositive/min_structvalue.q.out b/ql/src/test/results/clientpositive/min_structvalue.q.out new file mode 100644 index 0000000000000000000000000000000000000000..35828373dabf94e43c3cdfdde8156e7064120657 --- /dev/null +++ b/ql/src/test/results/clientpositive/min_structvalue.q.out @@ -0,0 +1,45 @@ +PREHOOK: query: select max(a), min(a) FROM (select named_struct("field",1) as a union all select named_struct("field",2) as a union all select named_struct("field",cast(null as int)) as a) tmp +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +POSTHOOK: query: select max(a), min(a) FROM (select named_struct("field",1) as a union all select named_struct("field",2) as a union all select named_struct("field",cast(null as int)) as a) tmp +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +{"field":2} {"field":1} +PREHOOK: query: select min(a) FROM (select named_struct("field",1) as a union all select named_struct("field",-2) as a union all select named_struct("field",cast(null as int)) as a) tmp +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +POSTHOOK: query: select min(a) FROM (select named_struct("field",1) as a union all select named_struct("field",-2) as a union all select named_struct("field",cast(null as int)) as a) tmp +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +{"field":-2} +PREHOOK: query: select min(a) FROM (select named_struct("field",1) as a union all select named_struct("field",2) as a union all select named_struct("field",cast(5 as int)) as a) tmp +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +POSTHOOK: query: select min(a) FROM (select named_struct("field",1) as a union all select named_struct("field",2) as a union all select named_struct("field",cast(5 as int)) as a) tmp +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +{"field":1} +PREHOOK: query: select min(a) FROM (select named_struct("field",1, "secf", cast(null as int) ) as a union all select named_struct("field",2, "secf", 3) as a union all select named_struct("field",cast(5 as int), "secf", 4) as a) tmp +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +POSTHOOK: query: select min(a) FROM (select named_struct("field",1, "secf", cast(null as int) ) as a union all select named_struct("field",2, "secf", 3) as a union all select named_struct("field",cast(5 as int), "secf", 4) as a) tmp +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +{"field":1,"secf":null} +PREHOOK: query: select min(a) FROM (select named_struct("field",1, "secf", 2) as a union all select named_struct("field",-2, "secf", 3) as a union all select named_struct("field",cast(null as int), "secf", 1) as a) tmp +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +POSTHOOK: query: select min(a) FROM (select named_struct("field",1, "secf", 2) as a union all select named_struct("field",-2, "secf", 3) as a union all select named_struct("field",cast(null as int), "secf", 1) as a) tmp +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +{"field":-2,"secf":3} diff --git a/serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java b/serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java index 33e535728000b108cc409196d49d837cf652da13..eb89ac8bb1fcde0273603076b0223b673adc8403 100644 --- a/serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java +++ b/serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java @@ -762,17 +762,38 @@ public static int compare(Object o1, ObjectInspector oi1, Object o2, /** * Compare two objects with their respective ObjectInspectors. + * Treat null as minimum value. */ public static int compare(Object o1, ObjectInspector oi1, Object o2, ObjectInspector oi2, MapEqualComparer mapEqualComparer) { + return compare(o1, oi1, o2, oi2, mapEqualComparer, false); + } + + /** + * Compare two objects with their respective ObjectInspectors. + * This compare method is used for get the value for function min + */ + public static int compareForMin(Object o1, ObjectInspector oi1, Object o2, + ObjectInspector oi2) { + return compare(o1, oi1, o2, oi2, new FullMapEqualComparer(), true); + } + + /** + * Compare two objects with their respective ObjectInspectors. + * if nullAsMax true, treat null as maximum value. + */ + public static int compare(Object o1, ObjectInspector oi1, Object o2, + ObjectInspector oi2, MapEqualComparer mapEqualComparer, boolean nullAsMax) { if (oi1.getCategory() != oi2.getCategory()) { return oi1.getCategory().compareTo(oi2.getCategory()); } + int nullCmpRtn = nullAsMax ? 1 : -1; + if (o1 == null) { - return o2 == null ? 0 : -1; + return o2 == null ? 0 : nullCmpRtn; } else if (o2 == null) { - return 1; + return -nullCmpRtn; } switch (oi1.getCategory()) { @@ -915,7 +936,7 @@ public static int compare(Object o1, ObjectInspector oi1, Object o2, int r = compare(soi1.getStructFieldData(o1, fields1.get(i)), fields1 .get(i).getFieldObjectInspector(), soi2.getStructFieldData(o2, fields2.get(i)), fields2.get(i).getFieldObjectInspector(), - mapEqualComparer); + mapEqualComparer, nullAsMax); if (r != 0) { return r; } @@ -930,7 +951,7 @@ public static int compare(Object o1, ObjectInspector oi1, Object o2, int r = compare(loi1.getListElement(o1, i), loi1 .getListElementObjectInspector(), loi2.getListElement(o2, i), loi2 .getListElementObjectInspector(), - mapEqualComparer); + mapEqualComparer, nullAsMax); if (r != 0) { return r; } @@ -955,7 +976,7 @@ public static int compare(Object o1, ObjectInspector oi1, Object o2, return compare(uoi1.getField(o1), uoi1.getObjectInspectors().get(tag1), uoi2.getField(o2), uoi2.getObjectInspectors().get(tag2), - mapEqualComparer); + mapEqualComparer, nullAsMax); } default: throw new RuntimeException("Compare on unknown type: "