diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorTopNKeyOperator.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorTopNKeyOperator.java index c80bc804a2..6390f6348c 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorTopNKeyOperator.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/VectorTopNKeyOperator.java @@ -17,6 +17,14 @@ */ package org.apache.hadoop.hive.ql.exec.vector; +import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.NullValueOption.MAXVALUE; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.PriorityQueue; + import com.google.common.annotations.VisibleForTesting; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hive.ql.CompilationOpContext; @@ -28,15 +36,75 @@ import org.apache.hadoop.hive.ql.plan.TopNKeyDesc; import org.apache.hadoop.hive.ql.plan.VectorDesc; import org.apache.hadoop.hive.ql.plan.VectorTopNKeyDesc; +import org.apache.hadoop.hive.ql.plan.api.OperatorType; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectComparator; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; /** * VectorTopNKeyOperator passes rows that contains top N keys only. */ -public class VectorTopNKeyOperator extends TopNKeyOperator implements VectorizationOperator { +public class VectorTopNKeyOperator extends Operator implements VectorizationOperator { private static final long serialVersionUID = 1L; + private static class KeyWrapper { + private final Object[] keys; + + public KeyWrapper(Object[] keys) { + this.keys = keys; + } + + public KeyWrapper(KeyWrapper other) { + keys = Arrays.copyOf(other.keys, other.keys.length); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + KeyWrapper that = (KeyWrapper) o; + return Arrays.equals(keys, that.keys); + } + + @Override + public int hashCode() { + return Arrays.hashCode(keys); + } + + public Object[] getKeyArray() { + return keys; + } + } + + public static class RowWrapperComparator implements Comparator { + private final List> comparators; + + RowWrapperComparator(List> comparators) { + this.comparators = comparators; + } + + @Override + public int compare(KeyWrapper key1, KeyWrapper key2) { + Object[] keyArray1 = key1.getKeyArray(); + Object[] keyArray2 = key2.getKeyArray(); + for (int i = 0; i < comparators.size(); ++i) { + int c = comparators.get(i).compare(keyArray1[i], keyArray2[i]); + if (c != 0) { + return c; + } + } + return 0; + } + } + private VectorTopNKeyDesc vectorDesc; private VectorizationContext vContext; @@ -46,6 +114,8 @@ // Batch processing private transient int[] temporarySelected; + private transient int topN; + private transient PriorityQueue priorityQueue; public VectorTopNKeyOperator(CompilationOpContext ctx, OperatorDesc conf, VectorizationContext vContext, VectorDesc vectorDesc) { @@ -70,17 +140,42 @@ public VectorTopNKeyOperator(CompilationOpContext ctx) { protected void initializeOp(Configuration hconf) throws HiveException { super.initializeOp(hconf); + this.topN = conf.getTopN(); + VectorExpression.doTransientInit(vectorDesc.getKeyExpressions(), hconf); for (VectorExpression keyExpression : vectorDesc.getKeyExpressions()) { keyExpression.init(hconf); } + StructObjectInspector inputStructObjInspector = (StructObjectInspector) inputObjInspectors[0]; + temporarySelected = new int [VectorizedRowBatch.DEFAULT_SIZE]; + + String columnSortOrder = conf.getColumnSortOrder(); + int numKeys = conf.getKeyColumns().size(); + List> comparators = new ArrayList<>(numKeys); + TypeInfo[] typeInfoArray = new TypeInfo[numKeys]; + int[] projectedColumns = new int[numKeys]; + for (int i = 0; i < numKeys; i++) { + StructField field = inputStructObjInspector.getStructFieldRef(conf.getKeyColumnNames().get(i)); + ObjectInspector fieldObjectInspector = field.getFieldObjectInspector(); + typeInfoArray[i] = TypeInfoUtils.getTypeInfoFromTypeString(fieldObjectInspector.getTypeName()); + projectedColumns[i] = vContext.getProjectedColumns().get(field.getFieldID()); + ObjectComparator comparator = new ObjectComparator( + fieldObjectInspector, + fieldObjectInspector, + MAXVALUE); + if (columnSortOrder.charAt(i) == '-') { + comparators.add(comparator); + } else { + comparators.add(comparator.reversed()); + } + } + + this.priorityQueue = new PriorityQueue<>(this.topN + 1, new RowWrapperComparator(comparators)); + vectorExtractRow = new VectorExtractRow(); - vectorExtractRow.init((StructObjectInspector) inputObjInspectors[0], - vContext.getProjectedColumns()); + vectorExtractRow.init(typeInfoArray, projectedColumns); extractedRow = new Object[vectorExtractRow.getCount()]; - - temporarySelected = new int [VectorizedRowBatch.DEFAULT_SIZE]; } @Override @@ -137,6 +232,19 @@ public void process(Object data, int tag) throws HiveException { batch.selectedInUse = selectedInUseBackup; } + private boolean canProcess(Object row, int tag) { + KeyWrapper keyWrapper = new KeyWrapper((Object[])row); + + if (!priorityQueue.contains(keyWrapper)) { + priorityQueue.offer(new KeyWrapper(keyWrapper)); + } + if (priorityQueue.size() > topN) { + priorityQueue.poll(); + } + + return priorityQueue.contains(keyWrapper); + } + @Override public VectorizationContext getInputVectorizationContext() { return vContext; @@ -154,4 +262,14 @@ public void setNextVectorBatchGroupStatus(boolean isLastGroupBatch) throws HiveE op.setNextVectorBatchGroupStatus(isLastGroupBatch); } } + + @Override + public String getName() { + return TopNKeyOperator.getOperatorName(); + } + + @Override + public OperatorType getType() { + return OperatorType.TOPNKEY; + } } diff --git ql/src/test/queries/clientpositive/vector_topnkey.q ql/src/test/queries/clientpositive/vector_topnkey.q index e1b7d26afe..017130e9da 100644 --- ql/src/test/queries/clientpositive/vector_topnkey.q +++ ql/src/test/queries/clientpositive/vector_topnkey.q @@ -23,6 +23,9 @@ explain vectorization detail SELECT key FROM src GROUP BY key ORDER BY key LIMIT 5; SELECT key FROM src GROUP BY key ORDER BY key LIMIT 5; +SELECT cast(key as int) FROM src GROUP BY cast(key as int) ORDER BY cast(key as int) LIMIT 5; +SELECT key FROM src GROUP BY key ORDER BY key desc LIMIT 5; +SELECT cast(key as int) FROM src GROUP BY cast(key as int) ORDER BY cast(key as int) desc LIMIT 5; explain vectorization detail SELECT src1.key, src2.value FROM src src1 JOIN src src2 ON (src1.key = src2.key) ORDER BY src1.key LIMIT 5; diff --git ql/src/test/results/clientpositive/llap/vector_topnkey.q.out ql/src/test/results/clientpositive/llap/vector_topnkey.q.out index d859270ff0..f3d9e751cb 100644 --- ql/src/test/results/clientpositive/llap/vector_topnkey.q.out +++ ql/src/test/results/clientpositive/llap/vector_topnkey.q.out @@ -389,6 +389,45 @@ POSTHOOK: Input: default@src 100 103 104 +PREHOOK: query: SELECT cast(key as int) FROM src GROUP BY cast(key as int) ORDER BY cast(key as int) LIMIT 5 +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: SELECT cast(key as int) FROM src GROUP BY cast(key as int) ORDER BY cast(key as int) LIMIT 5 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +0 +2 +4 +5 +8 +PREHOOK: query: SELECT key FROM src GROUP BY key ORDER BY key desc LIMIT 5 +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: SELECT key FROM src GROUP BY key ORDER BY key desc LIMIT 5 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +98 +97 +96 +95 +92 +PREHOOK: query: SELECT cast(key as int) FROM src GROUP BY cast(key as int) ORDER BY cast(key as int) desc LIMIT 5 +PREHOOK: type: QUERY +PREHOOK: Input: default@src +#### A masked pattern was here #### +POSTHOOK: query: SELECT cast(key as int) FROM src GROUP BY cast(key as int) ORDER BY cast(key as int) desc LIMIT 5 +POSTHOOK: type: QUERY +POSTHOOK: Input: default@src +#### A masked pattern was here #### +498 +497 +496 +495 +494 PREHOOK: query: explain vectorization detail SELECT src1.key, src2.value FROM src src1 JOIN src src2 ON (src1.key = src2.key) ORDER BY src1.key LIMIT 5 PREHOOK: type: QUERY diff --git serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectComparator.java serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectComparator.java new file mode 100644 index 0000000000..9fb7787118 --- /dev/null +++ serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/ObjectComparator.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.serde2.objectinspector; + +import java.util.Comparator; + +import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.NullValueOption; + +/** + * This class wraps the ObjectInspectorUtils.compare method and implements java.util.Comparator. + */ +public class ObjectComparator implements Comparator { + + private final ObjectInspector objectInspector1; + private final ObjectInspector objectInspector2; + private final NullValueOption nullSortOrder; + private final MapEqualComparer mapEqualComparer = new FullMapEqualComparer(); + + public ObjectComparator(ObjectInspector objectInspector1, ObjectInspector objectInspector2, + NullValueOption nullSortOrder) { + this.objectInspector1 = objectInspector1; + this.objectInspector2 = objectInspector2; + this.nullSortOrder = nullSortOrder; + } + + @Override + public int compare(Object o1, Object o2) { + return ObjectInspectorUtils.compare(o1, objectInspector1, o2, objectInspector2, mapEqualComparer, nullSortOrder); + } +}