diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkReduceSinkMapJoinProc.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkReduceSinkMapJoinProc.java
index 83d54bd..86be4fa 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkReduceSinkMapJoinProc.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/spark/SparkReduceSinkMapJoinProc.java
@@ -20,12 +20,15 @@
 import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Stack;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.GroupByOperator;
 import org.apache.hadoop.hive.ql.exec.HashTableDummyOperator;
 import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
 import org.apache.hadoop.hive.ql.exec.Operator;
@@ -33,9 +36,15 @@
 import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
 import org.apache.hadoop.hive.ql.exec.RowSchema;
 import org.apache.hadoop.hive.ql.exec.SparkHashTableSinkOperator;
+import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
+import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
+import org.apache.hadoop.hive.ql.lib.Dispatcher;
+import org.apache.hadoop.hive.ql.lib.GraphWalker;
 import org.apache.hadoop.hive.ql.lib.Node;
 import org.apache.hadoop.hive.ql.lib.NodeProcessor;
 import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
+import org.apache.hadoop.hive.ql.lib.Rule;
+import org.apache.hadoop.hive.ql.lib.RuleRegExp;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
 import org.apache.hadoop.hive.ql.parse.spark.GenSparkProcContext;
 import org.apache.hadoop.hive.ql.plan.BaseWork;
@@ -54,8 +63,38 @@
 public class SparkReduceSinkMapJoinProc implements NodeProcessor {
+
+  public static class SparkMapJoinFollowedByGroupByProcessor implements NodeProcessor {
+    @Override
+    public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
+        Object... nodeOutputs) throws SemanticException {
+      GenSparkProcContext context = (GenSparkProcContext) procCtx;
+      if (nd.getName().equals("GBY")) {
+        context.followedByGroupBy = true;
+        GroupByOperator op = (GroupByOperator) nd;
+        float groupByMemoryUsage = context.conf.getFloatVar(
+            HiveConf.ConfVars.HIVEMAPJOINFOLLOWEDBYMAPAGGRHASHMEMORY);
+        op.getConf().setGroupByMemoryUsage(groupByMemoryUsage);
+      }
+      return null;
+    }
+  }
+
   protected transient Log LOG = LogFactory.getLog(this.getClass().getName());
 
+  private void checkGroupBy(Operator<? extends OperatorDesc> mapjoinOp,
+      GenSparkProcContext context) throws SemanticException {
+    List<Operator<? extends OperatorDesc>> childOps = mapjoinOp.getChildOperators();
+    Map<Rule, NodeProcessor> rules = new LinkedHashMap<Rule, NodeProcessor>();
+    rules.put(new RuleRegExp("GBY", GroupByOperator.getOperatorName() + "%"),
+        new SparkMapJoinFollowedByGroupByProcessor());
+    Dispatcher disp = new DefaultRuleDispatcher(null, rules, context);
+    GraphWalker ogw = new DefaultGraphWalker(disp);
+    ArrayList<Node> topNodes = new ArrayList<Node>();
+    topNodes.addAll(childOps);
+    ogw.startWalking(topNodes, null);
+  }
+
   /* (non-Javadoc)
    * This processor addresses the RS-MJ case that occurs in spark on the small/hash
    * table side of things.  The work that RS will be a part of must be connected
@@ -79,6 +118,7 @@ public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procContext,
     context.preceedingWork = null;
     context.currentRootOperator = null;
+
     ReduceSinkOperator parentRS = (ReduceSinkOperator)stack.get(stack.size() - 2);
     // remove the tag for in-memory side of mapjoin
     parentRS.getConf().setSkipTag(true);
 
@@ -228,21 +268,48 @@ public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procContext,
     }
     context.linkChildOpWithDummyOp.put(mapJoinOp, dummyOperators);
-
-    //replace ReduceSinkOp with HashTableSinkOp for the RSops which are parents of MJop
+    // replace ReduceSinkOp with HashTableSinkOp for the RSops which are parents of MJop
     MapJoinDesc mjDesc = mapJoinOp.getConf();
+    HiveConf conf = context.conf;
+
+    // check whether this MJ operator is followed by any GBY operator
+    context.followedByGroupBy = false;
+    checkGroupBy(mapJoinOp, context);
+
+    float hashtableMemoryUsage;
+    if (context.followedByGroupBy) {
+      hashtableMemoryUsage = conf.getFloatVar(
+          HiveConf.ConfVars.HIVEHASHTABLEFOLLOWBYGBYMAXMEMORYUSAGE);
+    } else {
+      hashtableMemoryUsage = conf.getFloatVar(
+          HiveConf.ConfVars.HIVEHASHTABLEMAXMEMORYUSAGE);
+    }
+    mjDesc.setHashTableMemoryUsage(hashtableMemoryUsage);
 
     SparkHashTableSinkDesc hashTableSinkDesc = new SparkHashTableSinkDesc(mjDesc);
     SparkHashTableSinkOperator hashTableSinkOp =
         (SparkHashTableSinkOperator) OperatorFactory.get(hashTableSinkDesc);
 
+    byte tag = (byte) pos;
+    int[] valueIndex = mjDesc.getValueIndex(tag);
+    if (valueIndex != null) {
+      List<ExprNodeDesc> newValues = new ArrayList<ExprNodeDesc>();
+      List<ExprNodeDesc> values = hashTableSinkDesc.getExprs().get(tag);
+      for (int index = 0; index < values.size(); index++) {
+        if (valueIndex[index] < 0) {
+          newValues.add(values.get(index));
+        }
+      }
+      hashTableSinkDesc.getExprs().put(tag, newValues);
+    }
+
     //get all parents of reduce sink
     List<Operator<? extends OperatorDesc>> RSparentOps = parentRS.getParentOperators();
     for (Operator<? extends OperatorDesc> parent : RSparentOps) {
       parent.replaceChild(parentRS, hashTableSinkOp);
     }
     hashTableSinkOp.setParentOperators(RSparentOps);
-    hashTableSinkOp.setTag((byte)pos);
+    hashTableSinkOp.setTag(tag);
     return true;
   }
 }
diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkProcContext.java ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkProcContext.java
index 8290568..9b57a1d 100644
--- ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkProcContext.java
+++ ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkProcContext.java
@@ -138,6 +138,10 @@
   // This is necessary as sometimes semantic analyzer's mapping is different than operator's own alias.
   public final Map<String, Operator<? extends OperatorDesc>> topOps;
 
+  // Used in MapJoin, to check whether a MapJoinOperator is followed by GroupByOperator
+  // This is used to determine hash table memory usage.
+  public boolean followedByGroupBy;
+
   @SuppressWarnings("unchecked")
   public GenSparkProcContext(HiveConf conf, ParseContext parseContext,
       List<Task<MoveWork>> moveTask, List<Task<? extends Serializable>> rootTasks,
@@ -170,5 +174,6 @@ public GenSparkProcContext(HiveConf conf, ParseContext parseContext,
     this.clonedReduceSinks = new LinkedHashSet<ReduceSinkOperator>();
     this.fileSinkSet = new LinkedHashSet<FileSinkOperator>();
     this.connectedReduceSinks = new LinkedHashSet<ReduceSinkOperator>();
+    this.followedByGroupBy = false;
   }
 }
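
Note on the valueIndex pruning added in the last SparkReduceSinkMapJoinProc hunk: as I read the patch, a non-negative entry in mjDesc.getValueIndex(tag) marks a value expression that is already carried by the join key, so only expressions with a negative index are kept for the SparkHashTableSinkOperator. Below is a minimal standalone sketch of that filtering loop; it is an illustration, not Hive code, and plain strings stand in for Hive's ExprNodeDesc objects.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class ValueIndexFilterSketch {
  public static void main(String[] args) {
    // Hypothetical value expressions for one tag; strings stand in for ExprNodeDesc.
    List<String> values = Arrays.asList("key_col", "val_a", "key_col2", "val_b");
    // Assumed semantics: index >= 0 means the value is backed by a key column, -1 otherwise.
    int[] valueIndex = {0, -1, 1, -1};

    // Same loop shape as the patch: keep only entries with a negative valueIndex.
    List<String> newValues = new ArrayList<String>();
    for (int index = 0; index < values.size(); index++) {
      if (valueIndex[index] < 0) {
        newValues.add(values.get(index));
      }
    }
    System.out.println(newValues);  // prints [val_a, val_b]
  }
}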