diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/CommonMergeJoinOperator.java ql/src/java/org/apache/hadoop/hive/ql/exec/CommonMergeJoinOperator.java index cf0b9f0..5384c28 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/CommonMergeJoinOperator.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/CommonMergeJoinOperator.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -35,6 +36,7 @@ import org.apache.hadoop.hive.ql.exec.tez.TezContext; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.plan.CommonMergeJoinDesc; +import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; import org.apache.hadoop.hive.ql.plan.OperatorDesc; import org.apache.hadoop.hive.ql.plan.api.OperatorType; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; @@ -78,6 +80,8 @@ transient List otherKey = null; transient List values = null; transient RecordSource[] sources; + transient WritableComparator[][] keyComparators; + transient List> originalParents = new ArrayList>(); @@ -105,6 +109,11 @@ public CommonMergeJoinOperator() { nextKeyWritables = new ArrayList[maxAlias]; fetchDone = new boolean[maxAlias]; foundNextKeyGroup = new boolean[maxAlias]; + keyComparators = new WritableComparator[maxAlias][]; + + for (Entry> entry : conf.getKeys().entrySet()) { + keyComparators[entry.getKey().intValue()] = new WritableComparator[entry.getValue().size()]; + } int bucketSize; @@ -279,7 +288,7 @@ private void putDummyOrEmpty(Byte i) { result[pos] = -1; continue; } - result[pos] = compareKeys(key, smallestOne); + result[pos] = compareKeys(pos, key, smallestOne); if (result[pos] < 0) { smallestOne = key; } @@ -411,14 +420,21 @@ private void promoteNextGroupToCandidate(Byte t) throws HiveException { this.nextGroupStorage[t] = oldRowContainer; } + @SuppressWarnings("rawtypes") private boolean processKey(byte alias, List key) throws HiveException { List keyWritable = keyWritables[alias]; if (keyWritable == null) { // the first group. keyWritables[alias] = key; + keyComparators[alias] = new WritableComparator[key.size()]; + for (int i = 0; i < key.size(); i++) { + WritableComparable k = (WritableComparable) key.get(i); + keyComparators[alias][i] = WritableComparator.get(k.getClass()); + } return false; } else { - int cmp = compareKeys(key, keyWritable); + // TODO: equality can be checked faster than comparisons by sorting comparators in order of speed (int > string) + int cmp = compareKeys(alias, key, keyWritable); if (cmp != 0) { nextKeyWritables[alias] = key; return true; @@ -428,30 +444,42 @@ private boolean processKey(byte alias, List key) throws HiveException { } @SuppressWarnings("rawtypes") - private int compareKeys(List k1, List k2) { - int ret = 0; + private int compareKeys(byte alias, List k1, List k2) { + final WritableComparator[] comparators = keyComparators[alias]; // join keys have difference sizes? - ret = k1.size() - k2.size(); - if (ret != 0) { - return ret; + if (k1.size() != k2.size()) { + return k1.size() - k2.size(); } - for (int i = 0; i < k1.size(); i++) { + if (comparators.length == 0) { + // cross-product - no keys really + return 0; + } + + if (comparators.length > 1) { + // rare case + return compareKeysMany(comparators, k1, k2); + } else { + return compareKey(comparators[0], + (WritableComparable) k1.get(0), + (WritableComparable) k2.get(0), + nullsafes != null ? nullsafes[0]: false); + } + } + + @SuppressWarnings("rawtypes") + private int compareKeysMany(WritableComparator[] comparators, + final List k1, + final List k2) { + // invariant: k1.size == k2.size + int ret = 0; + final int size = k1.size(); + for (int i = 0; i < size; i++) { WritableComparable key_1 = (WritableComparable) k1.get(i); WritableComparable key_2 = (WritableComparable) k2.get(i); - if (key_1 == null && key_2 == null) { - if (nullsafes != null && nullsafes[i]) { - continue; - } else { - return -1; - } - } else if (key_1 == null) { - return -1; - } else if (key_2 == null) { - return 1; - } - ret = WritableComparator.get(key_1.getClass()).compare(key_1, key_2); + ret = compareKey(comparators[i], key_1, key_2, + nullsafes != null ? nullsafes[i] : false); if (ret != 0) { return ret; } @@ -459,6 +487,27 @@ private int compareKeys(List k1, List k2) { return ret; } + @SuppressWarnings("rawtypes") + private int compareKey(WritableComparator comparator, + final WritableComparable key_1, + final WritableComparable key_2, + final boolean nullsafe) { + + if (key_1 == null && key_2 == null) { + if (nullsafe) { + return 0; + } else { + return -1; + } + } else if (key_1 == null) { + return -1; + } else if (key_2 == null) { + return 1; + } + + return comparator.compare(key_1, key_2); + } + @SuppressWarnings("unchecked") private List mergeJoinComputeKeys(Object row, Byte alias) throws HiveException { if ((joinKeysObjectInspectors != null) && (joinKeysObjectInspectors[alias] != null)) {