diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HashTableLoader.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HashTableLoader.java
index 536b92c..562867e 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HashTableLoader.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HashTableLoader.java
@@ -19,6 +19,7 @@
 import java.io.IOException;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.Map;
 
 import org.apache.commons.logging.Log;
@@ -87,29 +88,33 @@ public void load(MapJoinTableContainer[] mapJoinTables,
     HybridHashTableConf nwayConf = null;
     long totalSize = 0;
     int biggest = 0;                                // position of the biggest small table
+    Map<Integer, Long> tableMemorySizes = null;
     if (useHybridGraceHashJoin && mapJoinTables.length > 2) {
       // Create a Conf for n-way HybridHashTableContainers
       nwayConf = new HybridHashTableConf();
 
       // Find the biggest small table; also calculate total data size of all small tables
-      long maxSize = 0; // the size of the biggest small table
+      long maxSize = Long.MIN_VALUE; // the size of the biggest small table
       for (int pos = 0; pos < mapJoinTables.length; pos++) {
         if (pos == desc.getPosBigTable()) {
           continue;
         }
-        totalSize += desc.getParentDataSizes().get(pos);
-        biggest = desc.getParentDataSizes().get(pos) > maxSize ? pos : biggest;
-        maxSize = desc.getParentDataSizes().get(pos) > maxSize ? desc.getParentDataSizes().get(pos)
-            : maxSize;
+        long smallTableSize = desc.getParentDataSizes().get(pos);
+        totalSize += smallTableSize;
+        if (maxSize < smallTableSize) {
+          maxSize = smallTableSize;
+          biggest = pos;
+        }
       }
 
+      tableMemorySizes = divideHybridHashTableMemory(mapJoinTables, desc,
+          totalSize, noConditionalTaskThreshold);
       // Using biggest small table, calculate number of partitions to create for each small table
-      float percentage = (float) maxSize / totalSize;
-      long memory = (long) (noConditionalTaskThreshold * percentage);
+      long memory = tableMemorySizes.get(biggest);
       int numPartitions = 0;
       try {
         numPartitions = HybridHashTableContainer.calcNumPartitions(memory,
-            desc.getParentDataSizes().get(biggest),
+            maxSize,
             HiveConf.getIntVar(hconf, HiveConf.ConfVars.HIVEHYBRIDGRACEHASHJOINMINNUMPARTITIONS),
             HiveConf.getIntVar(hconf, HiveConf.ConfVars.HIVEHYBRIDGRACEHASHJOINMINWBSIZE),
             nwayConf);
@@ -159,9 +164,7 @@ public void load(MapJoinTableContainer[] mapJoinTables,
         long memory = 0;
         if (useHybridGraceHashJoin) {
           if (mapJoinTables.length > 2) {
-            // Allocate n-way join memory proportionally
-            float percentage = (float) desc.getParentDataSizes().get(pos) / totalSize;
-            memory = (long) (noConditionalTaskThreshold * percentage);
+            memory = tableMemorySizes.get(pos);
           } else {  // binary join
             memory = noConditionalTaskThreshold;
           }
@@ -186,6 +189,51 @@ public void load(MapJoinTableContainer[] mapJoinTables,
       }
     }
   }
 
+  /**
+   * Divides the total hash table memory among the small tables, proportionally to
+   * their estimated data sizes. Falls back to an equal split when the total or any
+   * per-table size estimate is missing or non-positive.
+   */
+  private static Map<Integer, Long> divideHybridHashTableMemory(
+      MapJoinTableContainer[] mapJoinTables, MapJoinDesc desc,
+      long totalSize, long totalHashTableMemory) {
+    int smallTableCount = Math.max(mapJoinTables.length - 1, 1);
+    Map<Integer, Long> tableMemorySizes = new HashMap<Integer, Long>();
+    // If any table has bad size estimate, we need to fall back to sizing each table equally
+    boolean fallbackToEqualProportions = totalSize <= 0;
+
+    if (!fallbackToEqualProportions) {
+      for (Map.Entry<Integer, Long> tableSizeEntry : desc.getParentDataSizes().entrySet()) {
+        if (tableSizeEntry.getKey() == desc.getPosBigTable()) {
+          continue;
+        }
+
+        long tableSize = tableSizeEntry.getValue();
+        if (tableSize <= 0) {
+          fallbackToEqualProportions = true;
+          break;
+        }
+        float percentage = (float) tableSize / totalSize;
+        long tableMemory = (long) (totalHashTableMemory * percentage);
+        tableMemorySizes.put(tableSizeEntry.getKey(), tableMemory);
+      }
+    }
+
+    if (fallbackToEqualProportions) {
+      // Just give each table the same amount of memory.
+      long equalPortion = totalHashTableMemory / smallTableCount;
+      // Skip (do not break at) the big table position so every small table gets a share.
+      for (Integer pos : desc.getParentDataSizes().keySet()) {
+        if (pos == desc.getPosBigTable()) {
+          continue;
+        }
+        tableMemorySizes.put(pos, equalPortion);
+      }
+    }
+
+    return tableMemorySizes;
+  }
+
   private String describeOi(String desc, ObjectInspector keyOi) {
     for (StructField field : ((StructObjectInspector)keyOi).getAllStructFieldRefs()) {
       ObjectInspector oi = field.getFieldObjectInspector();