diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFMin.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFMin.java index bde36e13cd852eca29bd0346f0088da72ea42dfb..70e0db13ac6f99edc86473b7ea34b70cdf06ca0b 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFMin.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFMin.java @@ -28,7 +28,9 @@ import org.apache.hadoop.hive.ql.udf.UDFType; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax.MaxStreamingFixedWindow; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.FullMapEqualComparer; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.NullValueOption; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; @@ -109,7 +111,7 @@ public void merge(AggregationBuffer agg, Object partial) throws HiveException { if (partial != null) { MinAgg myagg = (MinAgg) agg; - int r = ObjectInspectorUtils.compare(myagg.o, outputOI, partial, inputOI); + int r = ObjectInspectorUtils.compare(myagg.o, outputOI, partial, inputOI, new FullMapEqualComparer(), NullValueOption.MAXVALUE); if (myagg.o == null || r > 0) { myagg.o = ObjectInspectorUtils.copyToStandardObject(partial, inputOI, ObjectInspectorCopyOption.JAVA); diff --git a/ql/src/test/queries/clientpositive/min_structvalue.q b/ql/src/test/queries/clientpositive/min_structvalue.q new file mode 100644 index 0000000000000000000000000000000000000000..4431a0dcab9cda8695b918a1e679872e04af87c1 --- /dev/null +++ b/ql/src/test/queries/clientpositive/min_structvalue.q @@ -0,0 +1,10 @@ +select max(a), min(a) FROM (select named_struct("field",1) as a union all select named_struct("field",2) as a union all select named_struct("field",cast(null as int)) as a) tmp; + +select min(a) FROM (select named_struct("field",1) as a union all select named_struct("field",-2) as a union all select named_struct("field",cast(null as int)) as a) tmp; + +select min(a) FROM (select named_struct("field",1) as a union all select named_struct("field",2) as a union all select named_struct("field",cast(5 as int)) as a) tmp; + +select min(a) FROM (select named_struct("field",1, "secf", cast(null as int) ) as a union all select named_struct("field",2, "secf", 3) as a union all select named_struct("field",cast(5 as int), "secf", 4) as a) tmp; + +select min(a) FROM (select named_struct("field",1, "secf", 2) as a union all select named_struct("field",-2, "secf", 3) as a union all select named_struct("field",cast(null as int), "secf", 1) as a) tmp; + diff --git a/ql/src/test/results/clientpositive/min_structvalue.q.out b/ql/src/test/results/clientpositive/min_structvalue.q.out new file mode 100644 index 0000000000000000000000000000000000000000..35828373dabf94e43c3cdfdde8156e7064120657 --- /dev/null +++ b/ql/src/test/results/clientpositive/min_structvalue.q.out @@ -0,0 +1,45 @@ +PREHOOK: query: select max(a), min(a) FROM (select named_struct("field",1) as a union all select named_struct("field",2) as a union all select named_struct("field",cast(null as int)) as a) tmp +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +POSTHOOK: query: select max(a), min(a) FROM (select named_struct("field",1) as a union all select named_struct("field",2) as a union all select named_struct("field",cast(null as int)) as a) tmp +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +{"field":2} {"field":1} +PREHOOK: query: select min(a) FROM (select named_struct("field",1) as a union all select named_struct("field",-2) as a union all select named_struct("field",cast(null as int)) as a) tmp +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +POSTHOOK: query: select min(a) FROM (select named_struct("field",1) as a union all select named_struct("field",-2) as a union all select named_struct("field",cast(null as int)) as a) tmp +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +{"field":-2} +PREHOOK: query: select min(a) FROM (select named_struct("field",1) as a union all select named_struct("field",2) as a union all select named_struct("field",cast(5 as int)) as a) tmp +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +POSTHOOK: query: select min(a) FROM (select named_struct("field",1) as a union all select named_struct("field",2) as a union all select named_struct("field",cast(5 as int)) as a) tmp +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +{"field":1} +PREHOOK: query: select min(a) FROM (select named_struct("field",1, "secf", cast(null as int) ) as a union all select named_struct("field",2, "secf", 3) as a union all select named_struct("field",cast(5 as int), "secf", 4) as a) tmp +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +POSTHOOK: query: select min(a) FROM (select named_struct("field",1, "secf", cast(null as int) ) as a union all select named_struct("field",2, "secf", 3) as a union all select named_struct("field",cast(5 as int), "secf", 4) as a) tmp +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +{"field":1,"secf":null} +PREHOOK: query: select min(a) FROM (select named_struct("field",1, "secf", 2) as a union all select named_struct("field",-2, "secf", 3) as a union all select named_struct("field",cast(null as int), "secf", 1) as a) tmp +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +POSTHOOK: query: select min(a) FROM (select named_struct("field",1, "secf", 2) as a union all select named_struct("field",-2, "secf", 3) as a union all select named_struct("field",cast(null as int), "secf", 1) as a) tmp +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +#### A masked pattern was here #### +{"field":-2,"secf":3} diff --git a/serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java b/serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java index 33e535728000b108cc409196d49d837cf652da13..c58e8ed05453c78cbe2e4daf0b7afa51adbc0ce9 100644 --- a/serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java +++ b/serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectInspectorUtils.java @@ -106,6 +106,17 @@ } /** + * This enum controls how we interpret null value when compare two objects. + * + * MINVALUE means treating null value as the minimum value. + * MAXVALUE means treating null value as the maximum value. + * + */ + public enum NullValueOption { + MINVALUE, MAXVALUE + } + + /** * Calculates the hash code for array of Objects that contains writables. This is used * to work around the buggy Hadoop DoubleWritable hashCode implementation. This should * only be used for process-local hash codes; don't replace stored hash codes like bucketing. @@ -762,17 +773,38 @@ public static int compare(Object o1, ObjectInspector oi1, Object o2, /** * Compare two objects with their respective ObjectInspectors. + * Treat null as minimum value. */ public static int compare(Object o1, ObjectInspector oi1, Object o2, ObjectInspector oi2, MapEqualComparer mapEqualComparer) { + return compare(o1, oi1, o2, oi2, mapEqualComparer, NullValueOption.MINVALUE); + } + + /** + * Compare two objects with their respective ObjectInspectors. + * if nullValueOpt is MAXVALUE, treat null as maximum value. + * if nullValueOpt is MINVALUE, treat null as minimum value. + */ + public static int compare(Object o1, ObjectInspector oi1, Object o2, + ObjectInspector oi2, MapEqualComparer mapEqualComparer, NullValueOption nullValueOpt) { if (oi1.getCategory() != oi2.getCategory()) { return oi1.getCategory().compareTo(oi2.getCategory()); } + int nullCmpRtn = -1; + switch (nullValueOpt) { + case MAXVALUE: + nullCmpRtn = 1; + break; + case MINVALUE: + nullCmpRtn = -1; + break; + } + if (o1 == null) { - return o2 == null ? 0 : -1; + return o2 == null ? 0 : nullCmpRtn; } else if (o2 == null) { - return 1; + return -nullCmpRtn; } switch (oi1.getCategory()) { @@ -915,7 +947,7 @@ public static int compare(Object o1, ObjectInspector oi1, Object o2, int r = compare(soi1.getStructFieldData(o1, fields1.get(i)), fields1 .get(i).getFieldObjectInspector(), soi2.getStructFieldData(o2, fields2.get(i)), fields2.get(i).getFieldObjectInspector(), - mapEqualComparer); + mapEqualComparer, nullValueOpt); if (r != 0) { return r; } @@ -930,7 +962,7 @@ public static int compare(Object o1, ObjectInspector oi1, Object o2, int r = compare(loi1.getListElement(o1, i), loi1 .getListElementObjectInspector(), loi2.getListElement(o2, i), loi2 .getListElementObjectInspector(), - mapEqualComparer); + mapEqualComparer, nullValueOpt); if (r != 0) { return r; } @@ -955,7 +987,7 @@ public static int compare(Object o1, ObjectInspector oi1, Object o2, return compare(uoi1.getField(o1), uoi1.getObjectInspectors().get(tag1), uoi2.getField(o2), uoi2.getObjectInspectors().get(tag2), - mapEqualComparer); + mapEqualComparer, nullValueOpt); } default: throw new RuntimeException("Compare on unknown type: "