diff --git a/ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezWork.java b/ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezWork.java index 095aaee..af9b7d0 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezWork.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezWork.java @@ -111,7 +111,7 @@ public Object process(Node nd, Stack stack, // will result into a vertex with multiple FS or RS operators. if (context.childToWorkMap.containsKey(operator)) { // if we've seen both root and child, we can bail. - + // clear out the mapjoin set. we don't need it anymore. context.currentMapJoinOperators.clear(); @@ -349,17 +349,20 @@ public Object process(Node nd, Stack stack, } else if (followingWork instanceof UnionWork) { // this can only be possible if there is merge work followed by the union UnionWork unionWork = (UnionWork) followingWork; - int index = getMergeIndex(tezWork, unionWork, rs); - // guaranteed to be instance of MergeJoinWork if index is valid - BaseWork baseWork = tezWork.getChildren(unionWork).get(index); - if (baseWork instanceof MergeJoinWork) { - MergeJoinWork mergeJoinWork = (MergeJoinWork) baseWork; - // disconnect the connection to union work and connect to merge work - followingWork = mergeJoinWork; - rWork = (ReduceWork) mergeJoinWork.getMainWork(); + int index = getFollowingWorkIndex(tezWork, unionWork, rs); + if (index != -1) { + BaseWork baseWork = tezWork.getChildren(unionWork).get(index); + if (baseWork instanceof MergeJoinWork) { + MergeJoinWork mergeJoinWork = (MergeJoinWork) baseWork; + // disconnect the connection to union work and connect to merge work + followingWork = mergeJoinWork; + rWork = (ReduceWork) mergeJoinWork.getMainWork(); + } else { + rWork = (ReduceWork) tezWork.getChildren(unionWork).get(index); + } } else { - throw new SemanticException("Unknown work type found: " - + baseWork.getClass().getCanonicalName()); + throw new SemanticException("Following work not found for the reduce sink: " + + rs.getName()); } } else { rWork = (ReduceWork) followingWork; @@ -403,7 +406,7 @@ public Object process(Node nd, Stack stack, return null; } - private int getMergeIndex(TezWork tezWork, UnionWork unionWork, ReduceSinkOperator rs) { + private int getFollowingWorkIndex(TezWork tezWork, UnionWork unionWork, ReduceSinkOperator rs) { int index = 0; for (BaseWork baseWork : tezWork.getChildren(unionWork)) { if (baseWork instanceof MergeJoinWork) { @@ -414,6 +417,8 @@ private int getMergeIndex(TezWork tezWork, UnionWork unionWork, ReduceSinkOperat } else { index++; } + } else if (baseWork instanceof ReduceWork) { + return index; } else { index++; }