diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/CommonMergeJoinOperator.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/CommonMergeJoinOperator.java index 22fb7f1..68013c5 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/exec/CommonMergeJoinOperator.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/CommonMergeJoinOperator.java @@ -21,6 +21,7 @@ import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -35,6 +36,7 @@ import org.apache.hadoop.hive.ql.exec.tez.TezContext; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.plan.CommonMergeJoinDesc; +import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; import org.apache.hadoop.hive.ql.plan.OperatorDesc; import org.apache.hadoop.hive.ql.plan.api.OperatorType; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; @@ -78,6 +80,7 @@ transient List otherKey = null; transient List values = null; transient RecordSource[] sources; + transient Map comparators; transient List> originalParents = new ArrayList>(); @@ -106,6 +109,11 @@ public CommonMergeJoinOperator() { fetchDone = new boolean[maxAlias]; foundNextKeyGroup = new boolean[maxAlias]; + comparators = new HashMap(); + for (Entry> entry : conf.getKeys().entrySet()) { + comparators.put(entry.getKey(), new WritableComparator[entry.getValue().size()]); + } + int bucketSize; int oldVar = HiveConf.getIntVar(hconf, HiveConf.ConfVars.HIVEMAPJOINBUCKETCACHESIZE); @@ -279,7 +287,7 @@ private void putDummyOrEmpty(Byte i) { result[pos] = -1; continue; } - result[pos] = compareKeys(key, smallestOne); + result[pos] = compareKeys(pos, key, smallestOne); if (result[pos] < 0) { smallestOne = key; } @@ -418,7 +426,7 @@ private boolean processKey(byte alias, List key) throws HiveException { keyWritables[alias] = key; return false; } else { - int cmp = compareKeys(key, keyWritable); + int cmp = compareKeys(alias, key, keyWritable); if (cmp != 0) { nextKeyWritables[alias] = key; return true; @@ -428,8 +436,9 @@ private boolean processKey(byte alias, List key) throws HiveException { } @SuppressWarnings("rawtypes") - private int compareKeys(List k1, List k2) { + private int compareKeys(byte alias, List k1, List k2) { int ret = 0; + WritableComparator[] comps = comparators.get(alias); // join keys have difference sizes? ret = k1.size() - k2.size(); @@ -451,7 +460,12 @@ private int compareKeys(List k1, List k2) { } else if (key_2 == null) { return 1; } - ret = WritableComparator.get(key_1.getClass()).compare(key_1, key_2); + + if (comps[i] == null) { + comps[i] = WritableComparator.get(key_1.getClass()); + } + + ret = comps[i].compare(key_1, key_2); if (ret != 0) { return ret; }