diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFSortArray.java ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFSortArray.java index 2d6d58c..edc75ec 100644 --- ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFSortArray.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFSortArray.java @@ -25,13 +25,11 @@ import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; -import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde.serdeConstants; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; @@ -60,22 +58,19 @@ public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen GenericUDFUtils.ReturnObjectInspectorResolver returnOIResolver; returnOIResolver = new GenericUDFUtils.ReturnObjectInspectorResolver(true); - if (arguments.length != 1) { - throw new UDFArgumentLengthException( - "The function SORT_ARRAY(array(obj1, obj2,...)) needs one argument."); - } + checkArgsSize(arguments, 1, 1); switch(arguments[0].getCategory()) { case LIST: - if(((ListObjectInspector)(arguments[0])).getListElementObjectInspector() - .getCategory().equals(Category.PRIMITIVE)) { + if(!((ListObjectInspector)(arguments[0])).getListElementObjectInspector() + .getCategory().equals(ObjectInspector.Category.UNION)) { break; } default: throw new UDFArgumentTypeException(0, "Argument 1" - + " of function SORT_ARRAY must be " + serdeConstants.LIST_TYPE_NAME - + "<" + Category.PRIMITIVE + ">, but " + arguments[0].getTypeName() - + " was found."); + + " of function SORT_ARRAY must be " + serdeConstants.LIST_TYPE_NAME + + ", and element type should be either primitive, list, struct, or map, " + + "but " + arguments[0].getTypeName() + " was found."); } ObjectInspector elementObjectInspector = diff --git ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFSortArray.java ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFSortArray.java new file mode 100644 index 0000000..4a0eb7e --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/udf/generic/TestGenericUDFSortArray.java @@ -0,0 +1,161 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.udf.generic; + +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.io.DateWritable; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.junit.Assert; +import org.junit.Test; + +import java.sql.Date; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static java.util.Arrays.asList; + +public class TestGenericUDFSortArray { + private final GenericUDFSortArray udf = new GenericUDFSortArray(); + + @Test + public void testSortPrimitive() throws HiveException { + ObjectInspector[] inputOIs = { + ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.writableIntObjectInspector) + }; + udf.initialize(inputOIs); + + Object i1 = new IntWritable(3); + Object i2 = new IntWritable(4); + Object i3 = new IntWritable(2); + Object i4 = new IntWritable(1); + + runAndVerify(asList(i1,i2,i3,i4), asList(i4,i3,i1,i2)); + } + + @Test + public void testSortList() throws HiveException { + ObjectInspector[] inputOIs = { + ObjectInspectorFactory.getStandardListObjectInspector( + ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.writableStringObjectInspector + ) + ) + }; + udf.initialize(inputOIs); + + Object i1 = asList(new Text("aa"),new Text("dd"),new Text("cc"),new Text("bb")); + Object i2 = asList(new Text("aa"),new Text("cc"),new Text("ba"),new Text("dd")); + Object i3 = asList(new Text("aa"),new Text("cc"),new Text("dd"),new Text("ee"), new Text("bb")); + Object i4 = asList(new Text("aa"),new Text("cc"),new Text("ddd"),new Text("bb")); + + runAndVerify(asList(i1,i2,i3,i4), asList(i2,i3,i4,i1)); + } + + @Test + public void testSortStruct() throws HiveException { + ObjectInspector[] inputOIs = { + ObjectInspectorFactory.getStandardListObjectInspector( + ObjectInspectorFactory.getStandardStructObjectInspector( + asList("f1", "f2", "f3", "f4"), + asList( + PrimitiveObjectInspectorFactory.writableStringObjectInspector, + PrimitiveObjectInspectorFactory.writableDoubleObjectInspector, + PrimitiveObjectInspectorFactory.writableDateObjectInspector, + ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.writableIntObjectInspector + ) + ) + ) + ) + }; + udf.initialize(inputOIs); + + Object i1 = asList(new Text("a"), new DoubleWritable(3.1415), + new DateWritable(new Date(2015, 5, 26)), + asList(new IntWritable(1), new IntWritable(3), + new IntWritable(2), new IntWritable(4))); + + Object i2 = asList(new Text("b"), new DoubleWritable(3.14), + new DateWritable(new Date(2015, 5, 26)), + asList(new IntWritable(1), new IntWritable(3), + new IntWritable(2), new IntWritable(4))); + + Object i3 = asList(new Text("a"), new DoubleWritable(3.1415), + new DateWritable(new Date(2015, 5, 25)), + asList(new IntWritable(1), new IntWritable(3), + new IntWritable(2), new IntWritable(5))); + + Object i4 = asList(new Text("a"), new DoubleWritable(3.1415), + new DateWritable(new Date(2015, 5, 25)), + asList(new IntWritable(1), new IntWritable(3), + new IntWritable(2), new IntWritable(4))); + + runAndVerify(asList(i1,i2,i3,i4), asList(i4,i3,i1,i2)); + } + + @Test + public void testSortMap() throws HiveException { + ObjectInspector[] inputOIs = { + ObjectInspectorFactory.getStandardListObjectInspector( + ObjectInspectorFactory.getStandardMapObjectInspector( + PrimitiveObjectInspectorFactory.writableStringObjectInspector, + PrimitiveObjectInspectorFactory.writableIntObjectInspector + ) + ) + }; + udf.initialize(inputOIs); + + Map m1 = new HashMap(); + m1.put(new Text("a"), new IntWritable(4)); + m1.put(new Text("b"), new IntWritable(3)); + m1.put(new Text("c"), new IntWritable(1)); + m1.put(new Text("d"), new IntWritable(2)); + + Map m2 = new HashMap(); + m2.put(new Text("d"), new IntWritable(4)); + m2.put(new Text("b"), new IntWritable(3)); + m2.put(new Text("a"), new IntWritable(1)); + m2.put(new Text("c"), new IntWritable(2)); + + Map m3 = new HashMap(); + m3.put(new Text("d"), new IntWritable(4)); + m3.put(new Text("b"), new IntWritable(3)); + m3.put(new Text("a"), new IntWritable(1)); + + runAndVerify(asList((Object)m1, m2, m3), asList((Object)m3, m2, m1)); + } + + private void runAndVerify(List actual, List expected) + throws HiveException { + GenericUDF.DeferredJavaObject[] args = { new GenericUDF.DeferredJavaObject(actual) }; + List result = (List) udf.evaluate(args); + + Assert.assertEquals("Check size", expected.size(), result.size()); + Assert.assertArrayEquals("Check content", expected.toArray(), result.toArray()); + } + + +} diff --git ql/src/test/queries/clientnegative/udf_sort_array_wrong3.q ql/src/test/queries/clientnegative/udf_sort_array_wrong3.q index 034de06..49856ae 100644 --- ql/src/test/queries/clientnegative/udf_sort_array_wrong3.q +++ ql/src/test/queries/clientnegative/udf_sort_array_wrong3.q @@ -1,2 +1,2 @@ -- invalid argument type -SELECT sort_array(array(array(10, 20), array(5, 15), array(3, 13))) FROM src LIMIT 1; +SELECT sort_array(array(create_union(0,"a"))) FROM src LIMIT 1; diff --git ql/src/test/queries/clientpositive/udf_sort_array.q ql/src/test/queries/clientpositive/udf_sort_array.q index 313bcf8..997d0c8 100644 --- ql/src/test/queries/clientpositive/udf_sort_array.q +++ ql/src/test/queries/clientpositive/udf_sort_array.q @@ -19,6 +19,16 @@ SELECT sort_array(array(2, 9, 7, 3, 5, 4, 1, 6, 8)) FROM src tablesample (1 rows -- Evaluate function against FLOAT valued keys SELECT sort_array(sort_array(array(2.333, 9, 1.325, 2.003, 0.777, -3.445, 1))) FROM src tablesample (1 rows); +-- Evaluate function against LIST valued keys +SELECT sort_array(array(array(2, 9, 7), array(3, 5, 4), array(1, 6, 8))) FROM src tablesample (1 rows); + +-- Evaluate function against STRUCT valued keys +SELECT sort_array(array(struct(2, 9, 7), struct(3, 5, 4), struct(1, 6, 8))) FROM src tablesample (1 rows); + +-- Evaluate function against MAP valued keys +SELECT sort_array(array(map("f", 2, "a", 9, "g", 7), map("c", 3, "b", 5, "d", 4), map("e", 1, "k", 6, "i", 8))) FROM src tablesample (1 rows); + + -- Test it against data in a table. CREATE TABLE dest1 ( tinyints ARRAY, diff --git ql/src/test/results/clientnegative/udf_sort_array_wrong2.q.out ql/src/test/results/clientnegative/udf_sort_array_wrong2.q.out index c068ecd..2123e2e 100644 --- ql/src/test/results/clientnegative/udf_sort_array_wrong2.q.out +++ ql/src/test/results/clientnegative/udf_sort_array_wrong2.q.out @@ -1 +1 @@ -FAILED: SemanticException [Error 10016]: Line 2:18 Argument type mismatch '"Invalid"': Argument 1 of function SORT_ARRAY must be array, but string was found. +FAILED: SemanticException [Error 10016]: Line 2:18 Argument type mismatch '"Invalid"': Argument 1 of function SORT_ARRAY must be array, and element type should be either primitive, list, struct, or map, but string was found. diff --git ql/src/test/results/clientnegative/udf_sort_array_wrong3.q.out ql/src/test/results/clientnegative/udf_sort_array_wrong3.q.out index abf7124..6745f4f 100644 --- ql/src/test/results/clientnegative/udf_sort_array_wrong3.q.out +++ ql/src/test/results/clientnegative/udf_sort_array_wrong3.q.out @@ -1 +1 @@ -FAILED: SemanticException [Error 10016]: Line 2:18 Argument type mismatch '13': Argument 1 of function SORT_ARRAY must be array, but array> was found. +FAILED: SemanticException [Error 10016]: Line 2:18 Argument type mismatch '"a"': Argument 1 of function SORT_ARRAY must be array, and element type should be either primitive, list, struct, or map, but array> was found. diff --git ql/src/test/results/clientpositive/udf_sort_array.q.out ql/src/test/results/clientpositive/udf_sort_array.q.out index 9631c2d..dee8500 100644 --- ql/src/test/results/clientpositive/udf_sort_array.q.out +++ ql/src/test/results/clientpositive/udf_sort_array.q.out @@ -87,6 +87,39 @@ POSTHOOK: type: QUERY POSTHOOK: Input: default@src #### A masked pattern was here #### [-3.445,0.777,1.0,1.325,2.003,2.333,9.0] +PREHOOK: query: -- Evaluate function against LIST valued keys +SELECT sort_array(array(array(2, 9, 7), array(3, 5, 4), array(1, 6, 8))) FROM src tablesample (1 rows) +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: -- Evaluate function against LIST valued keys +SELECT sort_array(array(array(2, 9, 7), array(3, 5, 4), array(1, 6, 8))) FROM src tablesample (1 rows) +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +[[1,6,8],[2,9,7],[3,5,4]] +PREHOOK: query: -- Evaluate function against STRUCT valued keys +SELECT sort_array(array(struct(2, 9, 7), struct(3, 5, 4), struct(1, 6, 8))) FROM src tablesample (1 rows) +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: -- Evaluate function against STRUCT valued keys +SELECT sort_array(array(struct(2, 9, 7), struct(3, 5, 4), struct(1, 6, 8))) FROM src tablesample (1 rows) +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +[{"col1":1,"col2":6,"col3":8},{"col1":2,"col2":9,"col3":7},{"col1":3,"col2":5,"col3":4}] +PREHOOK: query: -- Evaluate function against MAP valued keys +SELECT sort_array(array(map("f", 2, "a", 9, "g", 7), map("c", 3, "b", 5, "d", 4), map("e", 1, "k", 6, "i", 8))) FROM src tablesample (1 rows) +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: -- Evaluate function against MAP valued keys +SELECT sort_array(array(map("f", 2, "a", 9, "g", 7), map("c", 3, "b", 5, "d", 4), map("e", 1, "k", 6, "i", 8))) FROM src tablesample (1 rows) +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +[{"a":9,"f":2,"g":7},{"b":5,"d":4,"c":3},{"i":8,"k":6,"e":1}] PREHOOK: query: -- Test it against data in a table. CREATE TABLE dest1 ( tinyints ARRAY,