diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SetSparkReducerParallelism.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SetSparkReducerParallelism.java
index e808a4f..ea444c4 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SetSparkReducerParallelism.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SetSparkReducerParallelism.java
@@ -18,10 +18,13 @@
 package org.apache.hadoop.hive.ql.optimizer.spark;
 
+import java.util.Collection;
+import java.util.EnumSet;
 import java.util.List;
 import java.util.Set;
 import java.util.Stack;
 
+import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.apache.hadoop.hive.common.ObjectPair;
@@ -50,6 +53,8 @@
 import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
 import org.apache.hadoop.hive.ql.stats.StatsUtils;
 
+import static org.apache.hadoop.hive.ql.plan.ReduceSinkDesc.ReducerTraits.UNIFORM;
+
 /**
  * SetSparkReducerParallelism determines how many reducers should
  * be run for a given reduce sink, clone from SetReducerParallelism.
@@ -120,41 +125,64 @@ public Object process(Node nd, Stack<Node> stack,
           }
         }
 
-        long numberOfBytes = 0;
-
-        if (useOpStats) {
-          // we need to add up all the estimates from the siblings of this reduce sink
-          for (Operator<? extends OperatorDesc> sibling
-            : sink.getChildOperators().get(0).getParentOperators()) {
-            if (sibling.getStatistics() != null) {
-              numberOfBytes = StatsUtils.safeAdd(numberOfBytes, sibling.getStatistics().getDataSize());
-              if (LOG.isDebugEnabled()) {
-                LOG.debug("Sibling " + sibling + " has stats: " + sibling.getStatistics());
-              }
-            } else {
-              LOG.warn("No stats available from: " + sibling);
-            }
-          }
-        } else if (parentSinks.isEmpty()) {
-          // Not using OP stats and this is the first sink in the path, meaning that
-          // we should use TS stats to infer parallelism
-          for (Operator<? extends OperatorDesc> sibling
-            : sink.getChildOperators().get(0).getParentOperators()) {
-            Set<TableScanOperator> sources =
-              OperatorUtils.findOperatorsUpstream(sibling, TableScanOperator.class);
-            for (TableScanOperator source : sources) {
-              if (source.getStatistics() != null) {
-                numberOfBytes = StatsUtils.safeAdd(numberOfBytes, source.getStatistics().getDataSize());
+        if (useOpStats || parentSinks.isEmpty()) {
+          long numberOfBytes = 0;
+          if (useOpStats) {
+            // we need to add up all the estimates from the siblings of this reduce sink
+            for (Operator<? extends OperatorDesc> sibling
+              : sink.getChildOperators().get(0).getParentOperators()) {
+              if (sibling.getStatistics() != null) {
+                numberOfBytes = StatsUtils.safeAdd(numberOfBytes, sibling.getStatistics().getDataSize());
                 if (LOG.isDebugEnabled()) {
-                  LOG.debug("Table source " + source + " has stats: " + source.getStatistics());
+                  LOG.debug("Sibling " + sibling + " has stats: " + sibling.getStatistics());
                 }
               } else {
-                LOG.warn("No stats available from table source: " + source);
+                LOG.warn("No stats available from: " + sibling);
               }
             }
+          } else {
+            // Not using OP stats and this is the first sink in the path, meaning that
+            // we should use TS stats to infer parallelism
+            for (Operator<? extends OperatorDesc> sibling
+              : sink.getChildOperators().get(0).getParentOperators()) {
+              Set<TableScanOperator> sources =
+                OperatorUtils.findOperatorsUpstream(sibling, TableScanOperator.class);
+              for (TableScanOperator source : sources) {
+                if (source.getStatistics() != null) {
+                  numberOfBytes = StatsUtils.safeAdd(numberOfBytes, source.getStatistics().getDataSize());
+                  if (LOG.isDebugEnabled()) {
+                    LOG.debug("Table source " + source + " has stats: " + source.getStatistics());
+                  }
+                } else {
+                  LOG.warn("No stats available from table source: " + source);
+                }
+              }
+            }
+            LOG.debug("Gathered stats for sink " + sink + ". Total size is "
+              + numberOfBytes + " bytes.");
+          }
+
+          // Divide it by 2 so that we can have more reducers
+          long bytesPerReducer = context.getConf().getLongVar(HiveConf.ConfVars.BYTESPERREDUCER) / 2;
+          int numReducers = Utilities.estimateReducers(numberOfBytes, bytesPerReducer,
+            maxReducers, false);
+
+          getSparkMemoryAndCores(context);
+          if (sparkMemoryAndCores != null &&
+            sparkMemoryAndCores.getFirst() > 0 && sparkMemoryAndCores.getSecond() > 0) {
+            // warn the user if bytes per reducer is much larger than memory per task
+            if ((double) sparkMemoryAndCores.getFirst() / bytesPerReducer < 0.5) {
+              LOG.warn("Average load of a reducer is much larger than its available memory. " +
+                "Consider decreasing hive.exec.reducers.bytes.per.reducer");
+            }
+
+            // If there are more cores, use the number of cores
+            numReducers = Math.max(numReducers, sparkMemoryAndCores.getSecond());
           }
-          LOG.debug("Gathered stats for sink " + sink + ". Total size is "
-            + numberOfBytes + " bytes.");
+          numReducers = Math.min(numReducers, maxReducers);
+          LOG.info("Set parallelism for reduce sink " + sink + " to: " + numReducers +
+            " (calculated)");
+          desc.setNumReducers(numReducers);
         } else {
           // Use the maximum parallelism from all parent reduce sinks
           int numberOfReducers = 0;
@@ -164,30 +192,14 @@ public Object process(Node nd, Stack<Node> stack,
           desc.setNumReducers(numberOfReducers);
           LOG.debug("Set parallelism for sink " + sink + " to " + numberOfReducers
             + " based on its parents");
-          return false;
-        }
-
-        // Divide it by 2 so that we can have more reducers
-        long bytesPerReducer = context.getConf().getLongVar(HiveConf.ConfVars.BYTESPERREDUCER) / 2;
-        int numReducers = Utilities.estimateReducers(numberOfBytes, bytesPerReducer,
-          maxReducers, false);
-
-        getSparkMemoryAndCores(context);
-        if (sparkMemoryAndCores != null &&
-          sparkMemoryAndCores.getFirst() > 0 && sparkMemoryAndCores.getSecond() > 0) {
-          // warn the user if bytes per reducer is much larger than memory per task
-          if ((double) sparkMemoryAndCores.getFirst() / bytesPerReducer < 0.5) {
-            LOG.warn("Average load of a reducer is much larger than its available memory. " +
-              "Consider decreasing hive.exec.reducers.bytes.per.reducer");
-          }
-
-          // If there are more cores, use the number of cores
-          numReducers = Math.max(numReducers, sparkMemoryAndCores.getSecond());
         }
-        numReducers = Math.min(numReducers, maxReducers);
-        LOG.info("Set parallelism for reduce sink " + sink + " to: " + numReducers +
-          " (calculated)");
-        desc.setNumReducers(numReducers);
+      }
+      final Collection<ExprNodeDesc.ExprNodeDescEqualityWrapper> keyCols =
+        ExprNodeDesc.ExprNodeDescEqualityWrapper.transform(desc.getKeyCols());
+      final Collection<ExprNodeDesc.ExprNodeDescEqualityWrapper> partCols =
+        ExprNodeDesc.ExprNodeDescEqualityWrapper.transform(desc.getPartitionCols());
+      if (keyCols != null && keyCols.equals(partCols)) {
+        desc.setReducerTraits(EnumSet.of(UNIFORM));
       }
     } else {
       LOG.info("Number of reducers for sink " + sink + " was already determined to be: " + desc.getNumReducers());
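Note (not part of the patch): the sizing arithmetic in the new "calculated" branch can be illustrated with the standalone sketch below. The ceiling division stands in for Utilities.estimateReducers(numberOfBytes, bytesPerReducer, maxReducers, false), whose real implementation has additional behavior that is ignored here; the class name, parameter names, and sample numbers are hypothetical.

public class ReducerParallelismSketch {

  static int estimateParallelism(long totalInputBytes,
                                 long bytesPerReducerConf,  // hive.exec.reducers.bytes.per.reducer
                                 int maxReducers,           // hive.exec.reducers.max
                                 long memoryPerTaskBytes,   // assumed: Spark memory per task
                                 int totalCores) {          // assumed: total Spark executor cores
    // Divide it by 2 so that we can have more reducers, as the patch does.
    long bytesPerReducer = bytesPerReducerConf / 2;

    // Stand-in for Utilities.estimateReducers(): roughly one reducer per
    // bytesPerReducer of input, at least 1, at most maxReducers.
    int numReducers = (int) Math.min(maxReducers,
        Math.max(1, (totalInputBytes + bytesPerReducer - 1) / bytesPerReducer));

    if (memoryPerTaskBytes > 0 && totalCores > 0) {
      // Warn if the average reducer load is much larger than the memory available per task.
      if ((double) memoryPerTaskBytes / bytesPerReducer < 0.5) {
        System.out.println("WARN: consider decreasing hive.exec.reducers.bytes.per.reducer");
      }
      // If there are more cores than estimated reducers, use the number of cores ...
      numReducers = Math.max(numReducers, totalCores);
    }
    // ... but never exceed the configured maximum.
    return Math.min(numReducers, maxReducers);
  }

  public static void main(String[] args) {
    // Hypothetical numbers: 10 GB of input, 256 MB per reducer, 999 max reducers,
    // 4 GB of memory per task, 40 cores.
    long tenGb = 10L * 1024 * 1024 * 1024;
    long perReducer = 256L * 1024 * 1024;
    System.out.println(estimateParallelism(tenGb, perReducer, 999, 4L * 1024 * 1024 * 1024, 40));
  }
}

With these hypothetical inputs, the effective target is 128 MB per reducer, so 10 GB of input yields 80 reducers; 80 already exceeds the 40 cores and stays under the cap of 999, so 80 is used.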