diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/CommonMergeJoinOperator.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/CommonMergeJoinOperator.java index cf0b9f0..b3eac2d 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/CommonMergeJoinOperator.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/CommonMergeJoinOperator.java @@ -35,6 +35,7 @@ import org.apache.hadoop.hive.ql.exec.tez.TezContext; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.plan.CommonMergeJoinDesc; +import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; import org.apache.hadoop.hive.ql.plan.OperatorDesc; import org.apache.hadoop.hive.ql.plan.api.OperatorType; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; @@ -78,6 +79,8 @@ transient List otherKey = null; transient List values = null; transient RecordSource[] sources; + transient WritableComparator[][] keyComparators; + transient List> originalParents = new ArrayList>(); @@ -105,6 +108,11 @@ public CommonMergeJoinOperator() { nextKeyWritables = new ArrayList[maxAlias]; fetchDone = new boolean[maxAlias]; foundNextKeyGroup = new boolean[maxAlias]; + keyComparators = new WritableComparator[maxAlias][]; + + for (Entry> entry : conf.getKeys().entrySet()) { + keyComparators[entry.getKey().intValue()] = new WritableComparator[entry.getValue().size()]; + } int bucketSize; @@ -279,7 +287,7 @@ private void putDummyOrEmpty(Byte i) { result[pos] = -1; continue; } - result[pos] = compareKeys(key, smallestOne); + result[pos] = compareKeys(pos, key, smallestOne); if (result[pos] < 0) { smallestOne = key; } @@ -411,14 +419,16 @@ private void promoteNextGroupToCandidate(Byte t) throws HiveException { this.nextGroupStorage[t] = oldRowContainer; } + @SuppressWarnings("rawtypes") private boolean processKey(byte alias, List key) throws HiveException { List keyWritable = keyWritables[alias]; if (keyWritable == null) { // the first group. keyWritables[alias] = key; + keyComparators[alias] = new WritableComparator[key.size()]; return false; } else { - int cmp = compareKeys(key, keyWritable); + int cmp = compareKeys(alias, key, keyWritable); if (cmp != 0) { nextKeyWritables[alias] = key; return true; @@ -428,30 +438,42 @@ private boolean processKey(byte alias, List key) throws HiveException { } @SuppressWarnings("rawtypes") - private int compareKeys(List k1, List k2) { - int ret = 0; + private int compareKeys(byte alias, List k1, List k2) { + final WritableComparator[] comparators = keyComparators[alias]; // join keys have difference sizes? - ret = k1.size() - k2.size(); - if (ret != 0) { - return ret; + if (k1.size() != k2.size()) { + return k1.size() - k2.size(); + } + + if (comparators.length == 0) { + // cross-product - no keys really + return 0; } - for (int i = 0; i < k1.size(); i++) { + if (comparators.length > 1) { + // rare case + return compareKeysMany(comparators, k1, k2); + } else { + return compareKey(comparators, 0, + (WritableComparable) k1.get(0), + (WritableComparable) k2.get(0), + nullsafes != null ? nullsafes[0]: false); + } + } + + @SuppressWarnings("rawtypes") + private int compareKeysMany(WritableComparator[] comparators, + final List k1, + final List k2) { + // invariant: k1.size == k2.size + int ret = 0; + final int size = k1.size(); + for (int i = 0; i < size; i++) { WritableComparable key_1 = (WritableComparable) k1.get(i); WritableComparable key_2 = (WritableComparable) k2.get(i); - if (key_1 == null && key_2 == null) { - if (nullsafes != null && nullsafes[i]) { - continue; - } else { - return -1; - } - } else if (key_1 == null) { - return -1; - } else if (key_2 == null) { - return 1; - } - ret = WritableComparator.get(key_1.getClass()).compare(key_1, key_2); + ret = compareKey(comparators, i, key_1, key_2, + nullsafes != null ? nullsafes[i] : false); if (ret != 0) { return ret; } @@ -459,6 +481,30 @@ private int compareKeys(List k1, List k2) { return ret; } + @SuppressWarnings("rawtypes") + private int compareKey(final WritableComparator comparators[], final int pos, + final WritableComparable key_1, + final WritableComparable key_2, + final boolean nullsafe) { + + if (key_1 == null && key_2 == null) { + if (nullsafe) { + return 0; + } else { + return -1; + } + } else if (key_1 == null) { + return -1; + } else if (key_2 == null) { + return 1; + } + + if (comparators[pos] == null) { + comparators[pos] = WritableComparator.get(key_1.getClass()); + } + return comparators[pos].compare(key_1, key_2); + } + @SuppressWarnings("unchecked") private List mergeJoinComputeKeys(Object row, Byte alias) throws HiveException { if ((joinKeysObjectInspectors != null) && (joinKeysObjectInspectors[alias] != null)) {