diff --git ql/src/java/org/apache/hadoop/hive/ql/plan/TezWork.java ql/src/java/org/apache/hadoop/hive/ql/plan/TezWork.java
new file mode 100644
index 0000000..3e867ad
--- /dev/null
+++ ql/src/java/org/apache/hadoop/hive/ql/plan/TezWork.java
@@ -0,0 +1,240 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.plan;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+/**
+ * TezWork. This class encapsulates all the work objects that can be executed
+ * in a single tez job. Currently it's basically a tree with MapWork at the
+ * leaves and ReduceWork in all other nodes.
+ *
+ * The work objects are stored as a directed graph: connect(a, b) adds an
+ * edge from parent a to child b. Roots are the nodes without parents and
+ * leaves are the nodes without children; both sets are kept up to date by
+ * every mutating method.
+ */
+@SuppressWarnings("serial")
+@Explain(displayName = "Tez")
+public class TezWork extends AbstractOperatorDesc {
+
+  private static final Log LOG = LogFactory.getLog(TezWork.class);
+
+  private final Set<BaseWork> roots = new HashSet<BaseWork>();
+  private final Set<BaseWork> leaves = new HashSet<BaseWork>();
+  // node -> its children (outgoing edges)
+  private final Map<BaseWork, List<BaseWork>> workGraph = new HashMap<BaseWork, List<BaseWork>>();
+  // node -> its parents (incoming edges); mirror image of workGraph
+  private final Map<BaseWork, List<BaseWork>> invertedWorkGraph = new HashMap<BaseWork, List<BaseWork>>();
+
+  /**
+   * getAllWork returns a topologically sorted list of BaseWork: every node
+   * appears after all of its parents. The relative order of unrelated nodes
+   * depends on the iteration order of the internal HashSet of leaves.
+   *
+   * @return a new list; contains every node as long as the graph is acyclic
+   */
+  @Explain(skipHeader = true, displayName = "Tez Work")
+  public List<BaseWork> getAllWork() {
+
+    List<BaseWork> result = new LinkedList<BaseWork>();
+    Set<BaseWork> seen = new HashSet<BaseWork>();
+
+    for (BaseWork leaf: leaves) {
+      // in an acyclic graph every node is a leaf or an ancestor of one, so
+      // walking up the parent edges from all leaves reaches every node
+      visit(leaf, seen, result);
+    }
+
+    return result;
+  }
+
+  /**
+   * visit appends all not yet seen ancestors of child to result, followed
+   * by child itself (depth-first, post-order over the parent edges).
+   */
+  private void visit(BaseWork child, Set<BaseWork> seen, List<BaseWork> result) {
+
+    if (!seen.add(child)) {
+      // already visited via another path
+      return;
+    }
+
+    for (BaseWork parent: getParents(child)) {
+      visit(parent, seen, result);
+    }
+
+    result.add(child);
+  }
+
+  /**
+   * add creates a new node in the graph without any connections. Adding a
+   * node that is already part of the graph has no effect.
+   *
+   * @param w the work to add
+   */
+  public void add(BaseWork w) {
+    if (workGraph.containsKey(w)) {
+      return;
+    }
+    workGraph.put(w, new LinkedList<BaseWork>());
+    invertedWorkGraph.put(w, new LinkedList<BaseWork>());
+    roots.add(w);
+    leaves.add(w);
+  }
+
+  /**
+   * connect adds an edge between a and b, making a a parent of b. Both
+   * nodes have to be added prior to calling connect. Connecting the same
+   * pair twice records the edge twice.
+   *
+   * @param a the parent
+   * @param b the child
+   * @throws IllegalArgumentException if a or b has not been added
+   */
+  public void connect(BaseWork a, BaseWork b) {
+    // validate both ends before mutating anything, so a bad call cannot
+    // leave a half-added edge behind
+    checkContains(a);
+    checkContains(b);
+    workGraph.get(a).add(b);
+    invertedWorkGraph.get(b).add(a);
+    roots.remove(b);
+    leaves.remove(a);
+  }
+
+  /**
+   * disconnect removes an edge between a and b. Both a and
+   * b have to be in the graph. If there is no matching edge
+   * no change happens; if the edge was added more than once,
+   * only one occurrence is removed.
+   *
+   * @param a the parent
+   * @param b the child
+   * @throws IllegalArgumentException if a or b has not been added
+   */
+  public void disconnect(BaseWork a, BaseWork b) {
+    checkContains(a);
+    checkContains(b);
+    List<BaseWork> children = workGraph.get(a);
+    List<BaseWork> parents = invertedWorkGraph.get(b);
+    children.remove(b);
+    parents.remove(a);
+    if (parents.isEmpty()) {
+      roots.add(b);
+    }
+    if (children.isEmpty()) {
+      leaves.add(a);
+    }
+  }
+
+  /**
+   * getRoots returns all nodes that do not have a parent.
+   *
+   * @return a new set; changes to it do not affect the graph
+   */
+  public Set<BaseWork> getRoots() {
+    return new HashSet<BaseWork>(roots);
+  }
+
+  /**
+   * getLeaves returns all nodes that do not have a child.
+   *
+   * @return a new set; changes to it do not affect the graph
+   */
+  public Set<BaseWork> getLeaves() {
+    return new HashSet<BaseWork>(leaves);
+  }
+
+  /**
+   * getParents returns all the nodes with edges leading into work.
+   *
+   * @return a new list; changes to it do not affect the graph
+   */
+  public List<BaseWork> getParents(BaseWork work) {
+    assert invertedWorkGraph.containsKey(work)
+      && invertedWorkGraph.get(work) != null;
+    return new LinkedList<BaseWork>(invertedWorkGraph.get(work));
+  }
+
+  /**
+   * getChildren returns all the nodes with edges leading out of work.
+   *
+   * @return a new list; changes to it do not affect the graph
+   */
+  public List<BaseWork> getChildren(BaseWork work) {
+    assert workGraph.containsKey(work)
+      && workGraph.get(work) != null;
+    return new LinkedList<BaseWork>(workGraph.get(work));
+  }
+
+  /**
+   * remove removes a node from the graph and removes all edges with
+   * work as start or end point. No change to the graph if the node
+   * doesn't exist.
+   *
+   * @param work the work to remove
+   */
+  public void remove(BaseWork work) {
+
+    if (!workGraph.containsKey(work)) {
+      return;
+    }
+
+    List<BaseWork> children = getChildren(work);
+    List<BaseWork> parents = getParents(work);
+
+    for (BaseWork w: children) {
+      invertedWorkGraph.get(w).remove(work);
+      if (invertedWorkGraph.get(w).isEmpty()) {
+        roots.add(w);
+      }
+    }
+
+    for (BaseWork w: parents) {
+      workGraph.get(w).remove(work);
+      if (workGraph.get(w).isEmpty()) {
+        leaves.add(w);
+      }
+    }
+
+    roots.remove(work);
+    leaves.remove(work);
+
+    workGraph.remove(work);
+    invertedWorkGraph.remove(work);
+  }
+
+  /**
+   * checkContains throws if work has not been added to the graph.
+   */
+  private void checkContains(BaseWork work) {
+    if (!workGraph.containsKey(work)) {
+      throw new IllegalArgumentException("Work has not been added to the graph: " + work);
+    }
+  }
+}
diff --git ql/src/test/org/apache/hadoop/hive/ql/plan/TestTezWork.java ql/src/test/org/apache/hadoop/hive/ql/plan/TestTezWork.java
new file mode 100644
index 0000000..918b0ff
--- /dev/null
+++ ql/src/test/org/apache/hadoop/hive/ql/plan/TestTezWork.java
@@ -0,0 +1,125 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hive.ql.plan;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import java.util.List;
+import java.util.LinkedList;
+
+/** Unit tests for the {@link TezWork} graph of BaseWork nodes. */
+public class TestTezWork {
+
+  private List<BaseWork> nodes;
+  private TezWork work;
+
+  @Before
+  public void setup() throws Exception {
+    nodes = new LinkedList<BaseWork>();
+    work = new TezWork();
+    addWork(5); // five unconnected nodes
+  }
+
+  private void addWork(int n) {
+    for (int i = 0; i < n; ++i) {
+      BaseWork w = new MapWork();
+      nodes.add(w);
+      work.add(w);
+    }
+  }
+
+  @Test
+  public void testAdd() throws Exception {
+    Assert.assertEquals(nodes.size(), work.getAllWork().size());
+    Assert.assertEquals(nodes.size(), work.getRoots().size());
+    Assert.assertEquals(nodes.size(), work.getLeaves().size());
+    for (BaseWork w: nodes) {
+      Assert.assertEquals(0, work.getParents(w).size());
+      Assert.assertEquals(0, work.getChildren(w).size());
+    }
+  }
+
+  @Test
+  public void testConnect() throws Exception {
+    BaseWork parent = nodes.get(0);
+    BaseWork child = nodes.get(1);
+
+    work.connect(parent, child);
+
+    Assert.assertEquals(1, work.getParents(child).size());
+    Assert.assertEquals(1, work.getChildren(parent).size());
+    Assert.assertEquals(child, work.getChildren(parent).get(0));
+    Assert.assertEquals(parent, work.getParents(child).get(0));
+    Assert.assertTrue(work.getRoots().contains(parent) && !work.getRoots().contains(child));
+    Assert.assertTrue(!work.getLeaves().contains(parent) && work.getLeaves().contains(child));
+    for (BaseWork w: nodes) {
+      if (w != parent && w != child) { // all other nodes stay unconnected
+        Assert.assertEquals(0, work.getParents(w).size());
+        Assert.assertEquals(0, work.getChildren(w).size());
+      }
+    }
+  }
+
+  @Test
+  public void testDisconnect() throws Exception {
+    BaseWork parent = nodes.get(0);
+    BaseWork[] children = {nodes.get(1), nodes.get(2)};
+
+    work.connect(parent, children[0]);
+    work.connect(parent, children[1]);
+
+    work.disconnect(parent, children[0]);
+
+    Assert.assertTrue(work.getChildren(parent).contains(children[1]));
+    Assert.assertFalse(work.getChildren(parent).contains(children[0]));
+    Assert.assertTrue(work.getRoots().contains(parent) && work.getRoots().contains(children[0])
+        && !work.getRoots().contains(children[1]));
+    Assert.assertTrue(!work.getLeaves().contains(parent) && work.getLeaves().contains(children[0])
+        && work.getLeaves().contains(children[1]));
+  }
+
+  @Test
+  public void testRemove() throws Exception {
+    BaseWork parent = nodes.get(0);
+    BaseWork[] children = {nodes.get(1), nodes.get(2)};
+
+    work.connect(parent, children[0]);
+    work.connect(parent, children[1]);
+
+    work.remove(parent);
+
+    Assert.assertEquals(0, work.getParents(children[0]).size());
+    Assert.assertEquals(0, work.getParents(children[1]).size());
+    Assert.assertEquals(nodes.size() - 1, work.getAllWork().size());
+    Assert.assertEquals(nodes.size() - 1, work.getRoots().size());
+    Assert.assertEquals(nodes.size() - 1, work.getLeaves().size());
+  }
+
+  @Test
+  public void testGetAllWork() throws Exception {
+    for (int i = nodes.size() - 1; i > 0; --i) { // chain: n-1 -> ... -> 1 -> 0
+      work.connect(nodes.get(i), nodes.get(i - 1));
+    }
+
+    List<BaseWork> sorted = work.getAllWork();
+    for (int i = 0; i < nodes.size(); ++i) { // parents must come first
+      Assert.assertEquals(nodes.get(nodes.size() - 1 - i), sorted.get(i));
+    }
+  }
+}