diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkReduceSinkMapJoinProc.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkReduceSinkMapJoinProc.java
index 83d54bd..d8e98f8 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkReduceSinkMapJoinProc.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkReduceSinkMapJoinProc.java
@@ -20,12 +20,15 @@
 import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Stack;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.GroupByOperator;
 import org.apache.hadoop.hive.ql.exec.HashTableDummyOperator;
 import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
 import org.apache.hadoop.hive.ql.exec.Operator;
@@ -33,9 +36,15 @@
 import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
 import org.apache.hadoop.hive.ql.exec.RowSchema;
 import org.apache.hadoop.hive.ql.exec.SparkHashTableSinkOperator;
+import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
+import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
+import org.apache.hadoop.hive.ql.lib.Dispatcher;
+import org.apache.hadoop.hive.ql.lib.GraphWalker;
 import org.apache.hadoop.hive.ql.lib.Node;
 import org.apache.hadoop.hive.ql.lib.NodeProcessor;
 import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
+import org.apache.hadoop.hive.ql.lib.Rule;
+import org.apache.hadoop.hive.ql.lib.RuleRegExp;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
 import org.apache.hadoop.hive.ql.parse.spark.GenSparkProcContext;
 import org.apache.hadoop.hive.ql.plan.BaseWork;
@@ -54,8 +63,42 @@
 public class SparkReduceSinkMapJoinProc implements NodeProcessor {
 
+  public static class SparkMapJoinFollowedByGroupByProcessor implements NodeProcessor {
+    private boolean hasGroupBy = false;
+
+    @Override
+    public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
+        Object... nodeOutputs) throws SemanticException {
+      GenSparkProcContext context = (GenSparkProcContext) procCtx;
+      hasGroupBy = true;
+      GroupByOperator op = (GroupByOperator) nd;
+      float groupByMemoryUsage = context.conf.getFloatVar(
+          HiveConf.ConfVars.HIVEMAPJOINFOLLOWEDBYMAPAGGRHASHMEMORY);
+      op.getConf().setGroupByMemoryUsage(groupByMemoryUsage);
+      return null;
+    }
+
+    public boolean getHasGroupBy() {
+      return hasGroupBy;
+    }
+  }
+
   protected transient Log LOG = LogFactory.getLog(this.getClass().getName());
 
+  private boolean hasGroupBy(Operator<? extends OperatorDesc> mapjoinOp,
+      GenSparkProcContext context) throws SemanticException {
+    List<Operator<? extends OperatorDesc>> childOps = mapjoinOp.getChildOperators();
+    Map<Rule, NodeProcessor> rules = new LinkedHashMap<Rule, NodeProcessor>();
+    SparkMapJoinFollowedByGroupByProcessor processor = new SparkMapJoinFollowedByGroupByProcessor();
+    rules.put(new RuleRegExp("GBY", GroupByOperator.getOperatorName() + "%"), processor);
+    Dispatcher disp = new DefaultRuleDispatcher(null, rules, context);
+    GraphWalker ogw = new DefaultGraphWalker(disp);
+    ArrayList<Node> topNodes = new ArrayList<Node>();
+    topNodes.addAll(childOps);
+    ogw.startWalking(topNodes, null);
+    return processor.getHasGroupBy();
+  }
+
   /* (non-Javadoc)
    * This processor addresses the RS-MJ case that occurs in spark on the small/hash
    * table side of things. The work that RS will be a part of must be connected
@@ -66,7 +109,8 @@
    * or reduce work.
    */
   @Override
-  public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procContext, Object... nodeOutputs)
+  public Object process(Node nd, Stack<Node> stack,
+      NodeProcessorCtx procContext, Object... nodeOutputs)
       throws SemanticException {
     GenSparkProcContext context = (GenSparkProcContext) procContext;
     MapJoinOperator mapJoinOp = (MapJoinOperator)nd;
@@ -89,7 +133,7 @@ public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procContext,
       context.mapJoinParentMap.put(mapJoinOp, parents);
     }
 
-    List<BaseWork> mapJoinWork = null;
+    List<BaseWork> mapJoinWork;
 
     /*
      * if there was a pre-existing work generated for the big-table mapjoin side,
@@ -120,8 +164,8 @@ public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procContext,
     LOG.debug("Mapjoin "+mapJoinOp+", pos: "+pos+" --> "+parentWork.getName());
     mapJoinOp.getConf().getParentToInput().put(pos, parentWork.getName());
 
-    int numBuckets = -1;
-/* EdgeType edgeType = EdgeType.BROADCAST_EDGE;
+/* int numBuckets = -1;
+    EdgeType edgeType = EdgeType.BROADCAST_EDGE;
     if (mapJoinOp.getConf().isBucketMapJoin()) {
 
       // disable auto parallelism for bucket map joins
@@ -143,7 +187,7 @@ public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procContext,
     LOG.debug("connecting "+parentWork.getName()+" with "+myWork.getName());
     sparkWork.connect(parentWork, myWork, edgeProp);
 
-    ReduceSinkOperator r = null;
+    ReduceSinkOperator r;
     if (parentRS.getConf().getOutputName() != null) {
       LOG.debug("Cloning reduce sink for multi-child broadcast edge");
       // we've already set this one up. Need to clone for the next work.
@@ -228,21 +272,44 @@ public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procContext,
     }
     context.linkChildOpWithDummyOp.put(mapJoinOp, dummyOperators);
 
-
-    //replace ReduceSinkOp with HashTableSinkOp for the RSops which are parents of MJop
+    // replace ReduceSinkOp with HashTableSinkOp for the RSops which are parents of MJop
     MapJoinDesc mjDesc = mapJoinOp.getConf();
+    HiveConf conf = context.conf;
+
+    float hashtableMemoryUsage;
+    if (hasGroupBy(mapJoinOp, context)) {
+      hashtableMemoryUsage = conf.getFloatVar(
+          HiveConf.ConfVars.HIVEHASHTABLEFOLLOWBYGBYMAXMEMORYUSAGE);
+    } else {
+      hashtableMemoryUsage = conf.getFloatVar(
+          HiveConf.ConfVars.HIVEHASHTABLEMAXMEMORYUSAGE);
+    }
+    mjDesc.setHashTableMemoryUsage(hashtableMemoryUsage);
+
     SparkHashTableSinkDesc hashTableSinkDesc = new SparkHashTableSinkDesc(mjDesc);
     SparkHashTableSinkOperator hashTableSinkOp =
         (SparkHashTableSinkOperator) OperatorFactory.get(hashTableSinkDesc);
 
+    byte tag = (byte) pos;
+    int[] valueIndex = mjDesc.getValueIndex(tag);
+    if (valueIndex != null) {
+      List<ExprNodeDesc> newValues = new ArrayList<ExprNodeDesc>();
+      List<ExprNodeDesc> values = hashTableSinkDesc.getExprs().get(tag);
+      for (int index = 0; index < values.size(); index++) {
+        if (valueIndex[index] < 0) {
+          newValues.add(values.get(index));
+        }
+      }
+      hashTableSinkDesc.getExprs().put(tag, newValues);
+    }
+
     //get all parents of reduce sink
     List<Operator<? extends OperatorDesc>> RSparentOps = parentRS.getParentOperators();
     for (Operator<? extends OperatorDesc> parent : RSparentOps) {
       parent.replaceChild(parentRS, hashTableSinkOp);
    }
     hashTableSinkOp.setParentOperators(RSparentOps);
-    hashTableSinkOp.setTag((byte)pos);
+    hashTableSinkOp.setTag(tag);
     return true;
   }
 }
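
Note: the sketch below is not part of the patch. It is a minimal, self-contained illustration of the rule-driven operator-graph walk that the new hasGroupBy(...) helper relies on, built only from the Hive classes this patch imports (RuleRegExp, DefaultRuleDispatcher, DefaultGraphWalker, NodeProcessor). The names GroupByDetector, Marker, and containsGroupBy are illustrative placeholders, and the generic NodeProcessorCtx argument stands in for the GenSparkProcContext the patch passes through.

import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Stack;

import org.apache.hadoop.hive.ql.exec.GroupByOperator;
import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Dispatcher;
import org.apache.hadoop.hive.ql.lib.GraphWalker;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.parse.SemanticException;

public class GroupByDetector {

  /** Invoked only for nodes matched by the "GBY%" rule; records that a group-by was seen. */
  private static class Marker implements NodeProcessor {
    boolean found = false;

    @Override
    public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
        Object... nodeOutputs) throws SemanticException {
      found = true;
      return null;
    }
  }

  /** Walks the operator DAG from startNodes and reports whether a GroupByOperator is reachable. */
  public static boolean containsGroupBy(Collection<? extends Node> startNodes,
      NodeProcessorCtx procCtx) throws SemanticException {
    Marker marker = new Marker();
    Map<Rule, NodeProcessor> rules = new LinkedHashMap<Rule, NodeProcessor>();
    // A rule on GroupByOperator's operator name ("GBY") plus a trailing "%" wildcard
    // routes group-by nodes encountered during the walk to the marker processor.
    rules.put(new RuleRegExp("GBY", GroupByOperator.getOperatorName() + "%"), marker);
    Dispatcher disp = new DefaultRuleDispatcher(null, rules, procCtx); // no default processor
    GraphWalker walker = new DefaultGraphWalker(disp);
    walker.startWalking(new ArrayList<Node>(startNodes), null);
    return marker.found;
  }
}

A caller would seed startNodes with the map-join operator's child operators, mirroring how hasGroupBy(...) decides between HIVEHASHTABLEFOLLOWBYGBYMAXMEMORYUSAGE and HIVEHASHTABLEMAXMEMORYUSAGE when sizing the hash table sink.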