diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DagUtils.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DagUtils.java
index dd3533c..0d13cc8 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DagUtils.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DagUtils.java
@@ -46,6 +46,7 @@
 import org.apache.hadoop.hive.ql.plan.BaseWork;
 import org.apache.hadoop.hive.ql.plan.MapWork;
 import org.apache.hadoop.hive.ql.plan.ReduceWork;
+import org.apache.hadoop.hive.ql.plan.TezWork.EdgeType;
 import org.apache.hadoop.hive.shims.Hadoop20Shims.NullOutputCommitter;
 import org.apache.hadoop.hive.shims.ShimLoader;
 import org.apache.hadoop.io.BytesWritable;
@@ -147,7 +148,8 @@ private static JobConf initializeVertexConf(JobConf baseConf, MapWork mapWork) {
    * @param w The second vertex (sink)
    * @return
    */
-  public static Edge createEdge(JobConf vConf, Vertex v, JobConf wConf, Vertex w)
+  public static Edge createEdge(JobConf vConf, Vertex v, JobConf wConf, Vertex w,
+      EdgeType edgeType)
     throws IOException {
 
     // Tez needs to setup output subsequent input pairs correctly
@@ -157,9 +159,23 @@ public static Edge createEdge(JobConf vConf, Vertex v, JobConf wConf, Vertex w)
     v.getProcessorDescriptor().setUserPayload(MRHelpers.createUserPayloadFromConf(vConf));
     w.getProcessorDescriptor().setUserPayload(MRHelpers.createUserPayloadFromConf(wConf));
 
-    // all edges are of the same type right now
+    DataMovementType dataMovementType;
+    switch (edgeType) {
+    case BROADCAST_EDGE:
+      dataMovementType = DataMovementType.BROADCAST;
+      break;
+
+    case SIMPLE_EDGE:
+      dataMovementType = DataMovementType.SCATTER_GATHER;
+      break;
+
+    default:
+      dataMovementType = DataMovementType.SCATTER_GATHER;
+      break;
+    }
+
     EdgeProperty edgeProperty =
-        new EdgeProperty(DataMovementType.SCATTER_GATHER,
+        new EdgeProperty(dataMovementType,
           DataSourceType.PERSISTED,
           SchedulingType.SEQUENTIAL,
           new OutputDescriptor(OnFileSortedOutput.class.getName()),
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezTask.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezTask.java
index 5cfe755..ed78589 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezTask.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezTask.java
@@ -35,6 +35,7 @@
 import org.apache.hadoop.hive.ql.exec.Utilities;
 import org.apache.hadoop.hive.ql.plan.BaseWork;
 import org.apache.hadoop.hive.ql.plan.TezWork;
+import org.apache.hadoop.hive.ql.plan.TezWork.EdgeType;
 import org.apache.hadoop.hive.ql.plan.api.StageType;
 import org.apache.hadoop.hive.ql.session.SessionState;
 import org.apache.hadoop.mapred.JobConf;
@@ -172,7 +173,13 @@ private DAG build(JobConf conf, TezWork work, Path scratchDir,
       // add all dependencies (i.e.: edges) to the graph
       for (BaseWork v: work.getChildren(w)) {
         assert workToVertex.containsKey(v);
-        Edge e = DagUtils.createEdge(wxConf, wx, workToConf.get(v), workToVertex.get(v));
+        Edge e = null;
+        EdgeType edgeType = EdgeType.SIMPLE_EDGE;
+        if (work.isBroadCastEdge(w, v)) {
+          edgeType = EdgeType.BROADCAST_EDGE;
+        }
+
+        e = DagUtils.createEdge(wxConf, wx, workToConf.get(v), workToVertex.get(v), edgeType);
         dag.addEdge(e);
       }
     }
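For reference, the mapping createEdge now performs can be sketched standalone. This is a minimal sketch with local stand-in enums, not the real org.apache.tez.dag.api types; only the switch logic mirrors the patch:

  public class EdgeMappingSketch {
    enum EdgeType { SIMPLE_EDGE, BROADCAST_EDGE }
    enum DataMovementType { SCATTER_GATHER, BROADCAST }

    // BROADCAST_EDGE ships the small table to every consumer task;
    // everything else falls back to a regular scatter/gather shuffle.
    static DataMovementType toDataMovement(EdgeType edgeType) {
      switch (edgeType) {
      case BROADCAST_EDGE:
        return DataMovementType.BROADCAST;
      case SIMPLE_EDGE:
      default:
        return DataMovementType.SCATTER_GATHER;
      }
    }

    public static void main(String[] args) {
      System.out.println(toDataMovement(EdgeType.BROADCAST_EDGE)); // BROADCAST
      System.out.println(toDataMovement(EdgeType.SIMPLE_EDGE));    // SCATTER_GATHER
    }
  }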
diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java
index dc2188f..5f4a024 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java
@@ -32,7 +32,6 @@
 import org.apache.hadoop.hive.ql.lib.NodeProcessor;
 import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.optimizer.MapJoinProcessor;
 import org.apache.hadoop.hive.ql.parse.OptimizeTezProcContext;
 import org.apache.hadoop.hive.ql.parse.ParseContext;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/ReduceSinkMapJoinProc.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/ReduceSinkMapJoinProc.java
new file mode 100644
index 0000000..6a83d37
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/ReduceSinkMapJoinProc.java
@@ -0,0 +1,85 @@
+package org.apache.hadoop.hive.ql.optimizer;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Stack;
+
+import org.apache.hadoop.hive.ql.exec.FileSinkOperator;
+import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
+import org.apache.hadoop.hive.ql.exec.Operator;
+import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
+import org.apache.hadoop.hive.ql.lib.Node;
+import org.apache.hadoop.hive.ql.lib.NodeProcessor;
+import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
+import org.apache.hadoop.hive.ql.parse.GenTezProcContext;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.plan.BaseWork;
+import org.apache.hadoop.hive.ql.plan.OperatorDesc;
+import org.apache.hadoop.hive.ql.plan.TezWork;
+import org.apache.hadoop.hive.ql.plan.TezWork.EdgeType;
+
+public class ReduceSinkMapJoinProc implements NodeProcessor {
+
+  /* (non-Javadoc)
+   * This processor addresses the RS-MJ case that occurs in tez on the small/hash
+   * table side of things. The work that the RS will be a part of must be
+   * connected to the MJ work via a broadcast edge.
+   * We should not walk down the tree when we encounter this pattern because
+   * the type of work (map work or reduce work) needs to be determined
+   * on the basis of the big-table side: it may be map work (no need for a
+   * shuffle) or reduce work.
+   */
+  @Override
+  public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procContext, Object... nodeOutputs)
+      throws SemanticException {
+    GenTezProcContext context = (GenTezProcContext) procContext;
+    context.preceedingWork = null;
+    context.currentRootOperator = null;
+
+    MapJoinOperator mapJoinOp = (MapJoinOperator)nd;
+
+    Operator<? extends OperatorDesc> childOp = mapJoinOp.getChildOperators().get(0);
+    ReduceSinkOperator parentRS = (ReduceSinkOperator)stack.get(stack.size() - 2);
+    while (childOp != null) {
+      if ((childOp instanceof ReduceSinkOperator) || (childOp instanceof FileSinkOperator)) {
+        /*
+         * If work was already generated for the big-table side of the map join,
+         * we need to hook the work generated for the RS (associated with the
+         * RS-MJ pattern) into that pre-existing work.
+         *
+         * Otherwise, we need to remember that the reduce sink/file sink down
+         * the MJ path is to be linked to the RS work (associated with the
+         * RS-MJ pattern) once its work is generated.
+         */
+
+        BaseWork myWork = context.operatorWorkMap.get(childOp);
+        BaseWork parentWork = context.operatorWorkMap.get(parentRS);
+        if (myWork != null) {
+          // link the work with the work associated with the reduce sink that triggered this rule
+          TezWork tezWork = context.currentTask.getWork();
+          tezWork.connect(parentWork, myWork, EdgeType.BROADCAST_EDGE);
+        } else {
+          List<BaseWork> linkWorkList = context.linkOpWithWorkMap.get(childOp);
+          if (linkWorkList == null) {
+            linkWorkList = new ArrayList<BaseWork>();
+          }
+          linkWorkList.add(parentWork);
+          context.linkOpWithWorkMap.put(childOp, linkWorkList);
+        }
+
+        break;
+      }
+
+      if ((childOp.getChildOperators() != null) && (childOp.getChildOperators().size() >= 1)) {
+        childOp = childOp.getChildOperators().get(0);
+      } else {
+        break;
+      }
+    }
+
+    // cut the operator tree so as to not retain connections from the parent RS downstream
+    parentRS.removeChild(mapJoinOp);
+    return true;
+  }
+
+}
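The while loop in process() above only ever follows each operator's first child. A standalone sketch of that descent, using a hypothetical minimal Op type rather than Hive's operator classes, shows the intent: stop at the first reduce sink or file sink on the primary path below the map join.

  import java.util.Collections;
  import java.util.List;

  public class FirstSinkWalkSketch {
    static class Op {
      final String name;
      final List<Op> children;
      Op(String name, List<Op> children) { this.name = name; this.children = children; }
      boolean isSink() { return name.equals("RS") || name.equals("FS"); }
    }

    // mirrors the processor's loop: descend through first children until a sink is found
    static Op firstSinkBelow(Op op) {
      while (op != null) {
        if (op.isSink()) {
          return op;
        }
        op = op.children.isEmpty() ? null : op.children.get(0);
      }
      return null;
    }

    public static void main(String[] args) {
      Op fs = new Op("FS", Collections.<Op>emptyList());
      Op sel = new Op("SEL", Collections.singletonList(fs));
      Op mjChild = new Op("GBY", Collections.singletonList(sel));
      System.out.println(firstSinkBelow(mjChild).name); // FS
    }
  }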
diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezProcContext.java ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezProcContext.java
index 827637a..a53bd5a 100644
--- ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezProcContext.java
+++ ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezProcContext.java
@@ -82,6 +82,13 @@
   // that follows it. This is used for connecting them later.
   public final Map<Operator<?>, BaseWork> leafOperatorToFollowingWork;
 
+  // a map that keeps track of work that needs to be linked while
+  // traversing an operator tree
+  public final Map<Operator<?>, List<BaseWork>> linkOpWithWorkMap;
+
+  // a map that maintains operator (file-sink or reduce-sink) to work mapping
+  public final Map<Operator<?>, BaseWork> operatorWorkMap;
+
   @SuppressWarnings("unchecked")
   public GenTezProcContext(HiveConf conf, ParseContext parseContext,
@@ -97,5 +104,7 @@ public GenTezProcContext(HiveConf conf, ParseContext parseContext,
     this.currentTask = (TezTask) TaskFactory.get(new TezWork(), conf);
     this.leafOperatorToFollowingWork = new HashMap<Operator<?>, BaseWork>();
     this.rootOperators = rootOperators;
+    this.linkOpWithWorkMap = new HashMap<Operator<?>, List<BaseWork>>();
+    this.operatorWorkMap = new HashMap<Operator<?>, BaseWork>();
   }
 }
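Taken together, the two new maps implement a publish/subscribe idiom between ReduceSinkMapJoinProc (which may arrive before work exists for the sink it found) and GenTezWork (which creates that work later). A self-contained sketch with simplified stand-in types (String stands in for operators and work items; linkOrDefer and workCreated are hypothetical names for the two halves):

  import java.util.ArrayList;
  import java.util.HashMap;
  import java.util.List;
  import java.util.Map;

  public class DeferredLinkSketch {
    static Map<String, String> operatorWorkMap = new HashMap<String, String>();
    static Map<String, List<String>> linkOpWithWorkMap = new HashMap<String, List<String>>();
    static List<String[]> edges = new ArrayList<String[]>();

    // rule side: connect now if the child's work exists, otherwise defer
    static void linkOrDefer(String childOp, String parentWork) {
      String myWork = operatorWorkMap.get(childOp);
      if (myWork != null) {
        edges.add(new String[] { parentWork, myWork }); // connect immediately
      } else {
        List<String> pending = linkOpWithWorkMap.get(childOp);
        if (pending == null) {
          pending = new ArrayList<String>();
        }
        pending.add(parentWork);
        linkOpWithWorkMap.put(childOp, pending); // connect once work is created
      }
    }

    // GenTezWork side: publish the new work and drain any pending links
    static void workCreated(String op, String work) {
      operatorWorkMap.put(op, work);
      List<String> pending = linkOpWithWorkMap.get(op);
      if (pending != null) {
        for (String parentWork : pending) {
          edges.add(new String[] { parentWork, work });
        }
      }
    }

    public static void main(String[] args) {
      linkOrDefer("FS_1", "RS-work");   // work not generated yet: deferred
      workCreated("FS_1", "MJ-work");   // edge RS-work -> MJ-work added now
      System.out.println(edges.size()); // 1
    }
  }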
diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezWork.java ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezWork.java
index ff8b17b..59ae774 100644
--- ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezWork.java
+++ ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezWork.java
@@ -19,6 +19,7 @@
 package org.apache.hadoop.hive.ql.parse;
 
 import java.util.ArrayList;
+import java.util.List;
 import java.util.Stack;
 
 import org.apache.commons.logging.Log;
@@ -34,6 +35,7 @@
 import org.apache.hadoop.hive.ql.plan.MapWork;
 import org.apache.hadoop.hive.ql.plan.ReduceWork;
 import org.apache.hadoop.hive.ql.plan.TezWork;
+import org.apache.hadoop.hive.ql.plan.TezWork.EdgeType;
 
 /**
  * GenTezWork separates the operator tree into tez tasks.
@@ -101,9 +103,9 @@ public Object process(Node nd, Stack<Node> stack,
       reduceWork.setReducer(root);
       reduceWork.setNeedsTagging(GenMapRedUtils.needsTagging(reduceWork));
 
-      // All parents should be reduce sinks. We pick the one we just walked
-      // to choose the number of reducers. In the join/union case they will
-      // all be -1. In sort/order case where it matters there will be only
+      // All parents should be reduce sinks. We pick the one we just walked
+      // to choose the number of reducers. In the join/union case they will
+      // all be -1. In sort/order case where it matters there will be only
       // one parent.
       assert context.parentOfRoot instanceof ReduceSinkOperator;
       ReduceSinkOperator reduceSink = (ReduceSinkOperator) context.parentOfRoot;
@@ -121,7 +123,8 @@ public Object process(Node nd, Stack<Node> stack,
       tezWork.add(reduceWork);
       tezWork.connect(
           context.preceedingWork,
-          reduceWork);
+          reduceWork, EdgeType.SIMPLE_EDGE);
+
       work = reduceWork;
     }
@@ -142,21 +145,20 @@ public Object process(Node nd, Stack<Node> stack,
       BaseWork followingWork = context.leafOperatorToFollowingWork.get(operator);
 
       // need to add this branch to the key + value info
-      assert operator instanceof ReduceSinkOperator
+      assert operator instanceof ReduceSinkOperator
         && followingWork instanceof ReduceWork;
       ReduceSinkOperator rs = (ReduceSinkOperator) operator;
       ReduceWork rWork = (ReduceWork) followingWork;
       GenMapRedUtils.setKeyAndValueDesc(rWork, rs);
 
       // add dependency between the two work items
-      tezWork.connect(work, context.leafOperatorToFollowingWork.get(operator));
+      tezWork.connect(work, context.leafOperatorToFollowingWork.get(operator),
+          EdgeType.SIMPLE_EDGE);
     }
 
     // This is where we cut the tree as described above. We also remember that
     // we might have to connect parent work with this work later.
     for (Operator<?> parent: new ArrayList<Operator<?>>(root.getParentOperators())) {
-      assert !context.leafOperatorToFollowingWork.containsKey(parent);
-      assert !(work instanceof MapWork);
       context.leafOperatorToFollowingWork.put(parent, work);
       LOG.debug("Removing " + parent + " as parent from " + root);
       root.removeParent(parent);
@@ -175,6 +177,30 @@ public Object process(Node nd, Stack<Node> stack,
       context.preceedingWork = null;
     }
 
+    /*
+     * This happens in the case of map join operations.
+     * The tree looks like this:
+     *
+     *   RS <--- we are here perhaps
+     *   |
+     *   MapJoin
+     *   /  \
+     *  RS  TS
+     *  /
+     * TS
+     *
+     * If we are at the RS pointed out above, we may have already visited the
+     * RS following the TS, in which case work has already been generated for
+     * the TS-RS branch. We need to hook the current work up to that
+     * generated work.
+     */
+    context.operatorWorkMap.put(operator, work);
+    List<BaseWork> linkWorkList = context.linkOpWithWorkMap.get(operator);
+    if (linkWorkList != null) {
+      for (BaseWork parentWork : linkWorkList) {
+        tezWork.connect(parentWork, work, EdgeType.BROADCAST_EDGE);
+      }
+    }
+
     return null;
   }
diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java
index 42cc024..d230ff6 100644
--- ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java
+++ ql/src/java/org/apache/hadoop/hive/ql/parse/TezCompiler.java
@@ -33,6 +33,7 @@
 import org.apache.hadoop.hive.ql.exec.ConditionalTask;
 import org.apache.hadoop.hive.ql.exec.FileSinkOperator;
 import org.apache.hadoop.hive.ql.exec.JoinOperator;
+import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
 import org.apache.hadoop.hive.ql.exec.Operator;
 import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
 import org.apache.hadoop.hive.ql.exec.Task;
@@ -47,6 +48,7 @@
 import org.apache.hadoop.hive.ql.lib.Rule;
 import org.apache.hadoop.hive.ql.lib.RuleRegExp;
 import org.apache.hadoop.hive.ql.optimizer.ConvertJoinMapJoin;
+import org.apache.hadoop.hive.ql.optimizer.ReduceSinkMapJoinProc;
 import org.apache.hadoop.hive.ql.plan.BaseWork;
 import org.apache.hadoop.hive.ql.plan.MapWork;
 import org.apache.hadoop.hive.ql.plan.MoveWork;
@@ -122,6 +124,9 @@ protected void generateTaskTree(List<Task<? extends Serializable>> rootTasks, Pa
     opRules.put(new RuleRegExp(new String("Split Work - ReduceSink"),
         ReduceSinkOperator.getOperatorName() + "%"), new GenTezWork());
 
+    opRules.put(new RuleRegExp(new String("No more walking on ReduceSink-MapJoin"),
+        ReduceSinkOperator.getOperatorName() + "%" +
+        MapJoinOperator.getOperatorName() + "%"), new ReduceSinkMapJoinProc());
     opRules.put(new RuleRegExp(new String("Split Work - FileSink"),
         FileSinkOperator.getOperatorName() + "%"), new GenTezWork());
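Read loosely, the rule key concatenates operator names, so the new entry fires when the walker reaches a MapJoinOperator directly through a ReduceSinkOperator, which is exactly the small-table pattern. A rough standalone analogy of the matching (not Hive's actual RuleRegExp implementation, and the name strings "RS" and "MAPJOIN" are assumptions here):

  public class RulePatternSketch {
    public static void main(String[] args) {
      // analogous to ReduceSinkOperator.getOperatorName() + "%" + MapJoinOperator.getOperatorName() + "%"
      String rulePattern = "RS" + "%" + "MAPJOIN" + "%";

      // paths of operator names the walker has descended through
      String smallTablePath = "TS%RS%MAPJOIN%";
      String bigTablePath = "TS%MAPJOIN%";

      // loose analogy: does the walked path end with the rule's pattern?
      System.out.println(smallTablePath.endsWith(rulePattern)); // true -> ReduceSinkMapJoinProc fires
      System.out.println(bigTablePath.endsWith(rulePattern));   // false -> default walk continues
    }
  }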
diff --git ql/src/java/org/apache/hadoop/hive/ql/plan/TezWork.java ql/src/java/org/apache/hadoop/hive/ql/plan/TezWork.java
index 3e867ad..6db0099 100644
--- ql/src/java/org/apache/hadoop/hive/ql/plan/TezWork.java
+++ ql/src/java/org/apache/hadoop/hive/ql/plan/TezWork.java
@@ -29,8 +29,8 @@
 import org.apache.commons.logging.LogFactory;
 
 /**
- * TezWork. This class encapsulates all the work objects that can be executed
- * in a single tez job. Currently it's basically a tree with MapWork at the
+ * TezWork. This class encapsulates all the work objects that can be executed
+ * in a single tez job. Currently it's basically a tree with MapWork at the
  * leaves and and ReduceWork in all other nodes.
  *
  */
@@ -38,12 +38,18 @@
 @Explain(displayName = "Tez")
 public class TezWork extends AbstractOperatorDesc {
 
+  public enum EdgeType {
+    SIMPLE_EDGE,
+    BROADCAST_EDGE
+  }
+
   private static transient final Log LOG = LogFactory.getLog(TezWork.class);
 
   private final Set<BaseWork> roots = new HashSet<BaseWork>();
   private final Set<BaseWork> leaves = new HashSet<BaseWork>();
   private final Map<BaseWork, List<BaseWork>> workGraph = new HashMap<BaseWork, List<BaseWork>>();
   private final Map<BaseWork, List<BaseWork>> invertedWorkGraph = new HashMap<BaseWork, List<BaseWork>>();
+  private final Map<BaseWork, List<BaseWork>> broadcastEdge = new HashMap<BaseWork, List<BaseWork>>();
 
   /**
    * getAllWork returns a topologically sorted list of BaseWork
@@ -53,12 +59,12 @@
     List<BaseWork> result = new LinkedList<BaseWork>();
     Set<BaseWork> seen = new HashSet<BaseWork>();
-
+
     for (BaseWork leaf: leaves) {
       // make sure all leaves are visited at least once
       visit(leaf, seen, result);
     }
-
+
     return result;
   }
@@ -76,7 +82,7 @@ private void visit(BaseWork child, Set<BaseWork> seen, List<BaseWork> result) {
         visit(parent, seen, result);
       }
     }
-
+
     result.add(child);
   }
@@ -89,23 +95,31 @@ public void add(BaseWork w) {
     }
     workGraph.put(w, new LinkedList<BaseWork>());
     invertedWorkGraph.put(w, new LinkedList<BaseWork>());
+    broadcastEdge.put(w, new LinkedList<BaseWork>());
     roots.add(w);
     leaves.add(w);
   }
 
   /**
-   * connect adds an edge between a and b. Both nodes have
+   * connect adds an edge between a and b. Both nodes have
    * to be added prior to calling connect.
    */
-  public void connect(BaseWork a, BaseWork b) {
+  public void connect(BaseWork a, BaseWork b, EdgeType edgeType) {
     workGraph.get(a).add(b);
     invertedWorkGraph.get(b).add(a);
     roots.remove(b);
     leaves.remove(a);
+    switch (edgeType) {
+    case BROADCAST_EDGE:
+      broadcastEdge.get(a).add(b);
+      break;
+    default:
+      break;
+    }
   }
 
   /**
-   * disconnect removes an edge between a and b. Both a and
+   * disconnect removes an edge between a and b. Both a and
    * b have to be in the graph. If there is no matching edge
    * no change happens.
    */
@@ -138,7 +152,7 @@ public void disconnect(BaseWork a, BaseWork b) {
    * getParents returns all the nodes with edges leading into work
    */
   public List<BaseWork> getParents(BaseWork work) {
-    assert invertedWorkGraph.containsKey(work)
+    assert invertedWorkGraph.containsKey(work)
       && invertedWorkGraph.get(work) != null;
     return new LinkedList<BaseWork>(invertedWorkGraph.get(work));
   }
@@ -147,7 +161,7 @@ public void disconnect(BaseWork a, BaseWork b) {
    * getChildren returns all the nodes with edges leading out of work
    */
   public List<BaseWork> getChildren(BaseWork work) {
-    assert workGraph.containsKey(work)
+    assert workGraph.containsKey(work)
       && workGraph.get(work) != null;
     return new LinkedList<BaseWork>(workGraph.get(work));
   }
@@ -162,7 +176,7 @@ public void remove(BaseWork work) {
     if (!workGraph.containsKey(work)) {
       return;
     }
-
+
     List<BaseWork> children = getChildren(work);
     List<BaseWork> parents = getParents(work);
 
@@ -186,4 +200,12 @@ public void remove(BaseWork work) {
     workGraph.remove(work);
     invertedWorkGraph.remove(work);
   }
+
+  // checks if a and b need a broadcast edge between them
+  public boolean isBroadCastEdge(BaseWork a, BaseWork b) {
+    if ((broadcastEdge.get(a).contains(b)) || (broadcastEdge.get(b).contains(a))) {
+      return true;
+    }
+    return false;
+  }
 }
diff --git ql/src/test/org/apache/hadoop/hive/ql/plan/TestTezWork.java ql/src/test/org/apache/hadoop/hive/ql/plan/TestTezWork.java
index 918b0ff..c4c3ddf 100644
--- ql/src/test/org/apache/hadoop/hive/ql/plan/TestTezWork.java
+++ ql/src/test/org/apache/hadoop/hive/ql/plan/TestTezWork.java
@@ -17,11 +17,14 @@
  */
 package org.apache.hadoop.hive.ql.plan;
 
+import java.util.LinkedList;
+import java.util.List;
+
 import junit.framework.Assert;
+
+import org.apache.hadoop.hive.ql.plan.TezWork.EdgeType;
 import org.junit.Before;
 import org.junit.Test;
-import java.util.List;
-import java.util.LinkedList;
 
 public class TestTezWork {
 
@@ -58,9 +61,9 @@ public void testAdd() throws Exception {
   public void testConnect() throws Exception {
     BaseWork parent = nodes.get(0);
     BaseWork child = nodes.get(1);
-
-    work.connect(parent, child);
-
+
+    work.connect(parent, child, EdgeType.SIMPLE_EDGE);
+
     Assert.assertEquals(work.getParents(child).size(), 1);
     Assert.assertEquals(work.getChildren(parent).size(), 1);
     Assert.assertEquals(work.getChildren(parent).get(0), child);
@@ -76,13 +79,35 @@ public void testConnect() throws Exception {
     }
   }
 
-  @Test
+  @Test
+  public void testBroadcastConnect() throws Exception {
+    BaseWork parent = nodes.get(0);
+    BaseWork child = nodes.get(1);
+
+    work.connect(parent, child, EdgeType.BROADCAST_EDGE);
+
+    Assert.assertEquals(work.getParents(child).size(), 1);
+    Assert.assertEquals(work.getChildren(parent).size(), 1);
+    Assert.assertEquals(work.getChildren(parent).get(0), child);
+    Assert.assertEquals(work.getParents(child).get(0), parent);
+    Assert.assertTrue(work.getRoots().contains(parent) && !work.getRoots().contains(child));
+    Assert.assertTrue(!work.getLeaves().contains(parent) && work.getLeaves().contains(child));
+    for (BaseWork w: nodes) {
+      if (w == parent || w == child) {
+        continue;
+      }
+      Assert.assertEquals(work.getParents(w).size(), 0);
+      Assert.assertEquals(work.getChildren(w).size(), 0);
+    }
+  }
+
+  @Test
   public void testDisconnect() throws Exception {
     BaseWork parent = nodes.get(0);
     BaseWork children[] = {nodes.get(1), nodes.get(2)};
-
-    work.connect(parent, children[0]);
-    work.connect(parent, children[1]);
+
+    work.connect(parent, children[0], EdgeType.SIMPLE_EDGE);
+    work.connect(parent, children[1], EdgeType.SIMPLE_EDGE);
 
     work.disconnect(parent, children[0]);
 
@@ -94,14 +119,14 @@ public void testDisconnect() throws Exception {
       && work.getLeaves().contains(children[1]));
   }
 
-  @Test
+  @Test
   public void testRemove() throws Exception {
     BaseWork parent = nodes.get(0);
     BaseWork children[] = {nodes.get(1), nodes.get(2)};
-
-    work.connect(parent, children[0]);
-    work.connect(parent, children[1]);
-
+
+    work.connect(parent, children[0], EdgeType.SIMPLE_EDGE);
+    work.connect(parent, children[1], EdgeType.SIMPLE_EDGE);
+
     work.remove(parent);
 
     Assert.assertEquals(work.getParents(children[0]).size(), 0);
@@ -114,7 +139,7 @@ public void testRemove() throws Exception {
   @Test
   public void testGetAllWork() throws Exception {
     for (int i = 4; i > 0; --i) {
-      work.connect(nodes.get(i), nodes.get(i-1));
+      work.connect(nodes.get(i), nodes.get(i-1), EdgeType.SIMPLE_EDGE);
     }
 
     List<BaseWork> sorted = work.getAllWork();
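One more test could pin down the new broadcast bookkeeping directly. A sketch, not part of the patch, that would slot into TestTezWork and reuse the work and nodes fixtures its @Before method sets up:

  @Test
  public void testIsBroadCastEdge() throws Exception {
    BaseWork parent = nodes.get(0);
    BaseWork child = nodes.get(1);

    work.connect(parent, child, EdgeType.BROADCAST_EDGE);

    // the lookup in isBroadCastEdge is deliberately symmetric
    Assert.assertTrue(work.isBroadCastEdge(parent, child));
    Assert.assertTrue(work.isBroadCastEdge(child, parent));
    // no broadcast edge recorded between unrelated nodes
    Assert.assertFalse(work.isBroadCastEdge(parent, nodes.get(2)));
  }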