diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCollectList.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCollectList.java index 156d19b..9521051 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCollectList.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCollectList.java @@ -22,9 +22,10 @@ import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.parse.SemanticException; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator.BufferType; +import org.apache.hadoop.hive.serde.serdeConstants; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; -@Description(name = "collect_list", value = "_FUNC_(x) - Returns a list of objects with duplicates") +@Description(name = "collect_list", value = "_FUNC_(x, y) - Returns a list of objects with duplicates") public class GenericUDAFCollectList extends AbstractGenericUDAFResolver { public GenericUDAFCollectList() { @@ -33,9 +34,9 @@ public GenericUDAFCollectList() { @Override public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException { - if (parameters.length != 1) { + if (parameters.length < 1 || parameters.length > 2) { throw new UDFArgumentTypeException(parameters.length - 1, - "Exactly one argument is expected."); + "Expecting 1 or 2 parameters"); } switch (parameters[0].getCategory()) { @@ -49,6 +50,14 @@ public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) "Only primitive, struct, list or map type arguments are accepted but " + parameters[0].getTypeName() + " was passed as parameter 1."); } + + + if(parameters.length == 2 && !parameters[1].getTypeName().equals(serdeConstants.BOOLEAN_TYPE_NAME)) { + throw new UDFArgumentTypeException(1, + "Only boolean type argument is accepted but " + + parameters[1].getTypeName() + " was passed as parameter 2."); + } + return new GenericUDAFMkCollectionEvaluator(BufferType.LIST); } } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCollectSet.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCollectSet.java index 0c2cf90..8171f5e 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCollectSet.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCollectSet.java @@ -21,12 +21,13 @@ import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.parse.SemanticException; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator.BufferType; +import org.apache.hadoop.hive.serde.serdeConstants; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; /** * GenericUDAFCollectSet */ -@Description(name = "collect_set", value = "_FUNC_(x) - Returns a set of objects with duplicate elements eliminated") +@Description(name = "collect_set", value = "_FUNC_(x, y) - Returns a set of objects with duplicate elements eliminated") public class GenericUDAFCollectSet extends AbstractGenericUDAFResolver { public GenericUDAFCollectSet() { @@ -35,9 +36,9 @@ public GenericUDAFCollectSet() { @Override public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException { - if (parameters.length != 1) { + if (parameters.length < 1 || parameters.length > 2) { throw new UDFArgumentTypeException(parameters.length - 1, - "Exactly one argument is expected."); + "Expecting 1 or 2 parameters"); } switch (parameters[0].getCategory()) { case PRIMITIVE: @@ -50,6 +51,13 @@ public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) "Only primitive, struct, list or map type arguments are accepted but " + parameters[0].getTypeName() + " was passed as parameter 1."); } + + if(parameters.length == 2 && !parameters[1].getTypeName().equals(serdeConstants.BOOLEAN_TYPE_NAME)) { + throw new UDFArgumentTypeException(1, + "Only boolean type argument is accepted but " + + parameters[1].getTypeName() + " was passed as parameter 2."); + } + return new GenericUDAFMkCollectionEvaluator(BufferType.SET); } diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFMkCollectionEvaluator.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFMkCollectionEvaluator.java index 2b5e6dd..9dceee6 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFMkCollectionEvaluator.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFMkCollectionEvaluator.java @@ -30,6 +30,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector; +import org.apache.hadoop.io.BooleanWritable; public class GenericUDAFMkCollectionEvaluator extends GenericUDAFEvaluator implements Serializable { @@ -112,10 +113,15 @@ public AggregationBuffer getNewAggregationBuffer() throws HiveException { @Override public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException { - assert (parameters.length == 1); + assert (parameters.length <= 2); Object p = parameters[0]; + boolean includeNull = false; - if (p != null) { + if(parameters.length == 2) { + includeNull = ((BooleanWritable) parameters[1]).get(); + } + + if (includeNull || p != null) { MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg; putIntoCollection(p, myagg); }