diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DagUtils.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DagUtils.java
index 0e75f6e..de0abd1 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DagUtils.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DagUtils.java
@@ -18,12 +18,13 @@
 package org.apache.hadoop.hive.ql.exec.tez;
 
 import java.util.Collection;
-
 import java.util.concurrent.ConcurrentHashMap;
+
 import com.google.common.base.Function;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.Iterators;
 import com.google.common.collect.Lists;
+
 import javax.security.auth.login.LoginException;
 
 import java.io.File;
@@ -37,9 +38,11 @@
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.Stack;
 import java.util.concurrent.TimeUnit;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
@@ -65,6 +68,7 @@
 import org.apache.hadoop.hive.ql.Context;
 import org.apache.hadoop.hive.ql.ErrorMsg;
 import org.apache.hadoop.hive.ql.QueryPlan;
+import org.apache.hadoop.hive.ql.exec.FileSinkOperator;
 import org.apache.hadoop.hive.ql.exec.Operator;
 import org.apache.hadoop.hive.ql.exec.Utilities;
 import org.apache.hadoop.hive.ql.exec.mr.ExecMapper;
@@ -79,10 +83,22 @@
 import org.apache.hadoop.hive.ql.io.merge.MergeFileMapper;
 import org.apache.hadoop.hive.ql.io.merge.MergeFileOutputFormat;
 import org.apache.hadoop.hive.ql.io.merge.MergeFileWork;
+import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
+import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
+import org.apache.hadoop.hive.ql.lib.Dispatcher;
+import org.apache.hadoop.hive.ql.lib.GraphWalker;
+import org.apache.hadoop.hive.ql.lib.Node;
+import org.apache.hadoop.hive.ql.lib.NodeProcessor;
+import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
+import org.apache.hadoop.hive.ql.lib.Rule;
+import org.apache.hadoop.hive.ql.lib.RuleRegExp;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
 import org.apache.hadoop.hive.ql.plan.BaseWork;
+import org.apache.hadoop.hive.ql.plan.FileSinkDesc;
 import org.apache.hadoop.hive.ql.plan.MapWork;
 import org.apache.hadoop.hive.ql.plan.MergeJoinWork;
+import org.apache.hadoop.hive.ql.plan.OperatorDesc;
 import org.apache.hadoop.hive.ql.plan.ReduceWork;
 import org.apache.hadoop.hive.ql.plan.TezEdgeProperty;
 import org.apache.hadoop.hive.ql.plan.TezEdgeProperty.EdgeType;
@@ -169,6 +185,57 @@
    */
   private final ConcurrentHashMap<String, Object> copyNotifiers = new ConcurrentHashMap<>();
 
+  class CollectFileSinkUrisNodeProcessor implements NodeProcessor {
+
+    private final Set<URI> uris;
+
+    public CollectFileSinkUrisNodeProcessor(Set<URI> uris) {
+      this.uris = uris;
+    }
+
+    @Override
+    public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
+        Object... nodeOutputs) throws SemanticException {
+      for (Node n : stack) {
+        Operator<? extends OperatorDesc> op = (Operator<? extends OperatorDesc>) n;
+        OperatorDesc desc = op.getConf();
+        if (desc instanceof FileSinkDesc) {
+          FileSinkDesc fileSinkDesc = (FileSinkDesc) desc;
+          Path dirName = fileSinkDesc.getDirName();
+          if (dirName != null) {
+            uris.add(dirName.toUri());
+          }
+          Path destPath = fileSinkDesc.getDestPath();
+          if (destPath != null) {
+            uris.add(destPath.toUri());
+          }
+        }
+      }
+      return null;
+    }
+  }
+
+  private void addCollectFileSinkUrisRules(Map<Rule, NodeProcessor> opRules, NodeProcessor np) {
+    opRules.put(new RuleRegExp("R1", FileSinkOperator.getOperatorName() + ".*"), np);
+  }
+
+  private void collectFileSinkUris(List<Node> topNodes, Set<URI> uris) {
+
+    CollectFileSinkUrisNodeProcessor np = new CollectFileSinkUrisNodeProcessor(uris);
+
+    Map<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
+    addCollectFileSinkUrisRules(opRules, np);
+
+    Dispatcher disp = new DefaultRuleDispatcher(np, opRules, null);
+    GraphWalker ogw = new DefaultGraphWalker(disp);
+
+    try {
+      ogw.startWalking(topNodes, null);
+    } catch (SemanticException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
   private void addCredentials(MapWork mapWork, DAG dag) {
     Set<Path> paths = mapWork.getPathToAliases().keySet();
     if (!paths.isEmpty()) {
@@ -184,15 +251,43 @@ public URI apply(Path path) {
 
       if (LOG.isDebugEnabled()) {
         for (URI uri: uris) {
-          LOG.debug("Marking URI as needing credentials: "+uri);
+          LOG.debug("Marking MapWork input URI as needing credentials: " + uri);
         }
       }
       dag.addURIsForCredentials(uris);
     }
+
+    Set<URI> fileSinkUris = new HashSet<URI>();
+
+    List<Node> topNodes = new ArrayList<Node>();
+    LinkedHashMap<String, Operator<? extends OperatorDesc>> aliasToWork = mapWork.getAliasToWork();
+    for (Operator<? extends OperatorDesc> operator : aliasToWork.values()) {
+      topNodes.add(operator);
+    }
+    collectFileSinkUris(topNodes, fileSinkUris);
+
+    if (LOG.isDebugEnabled()) {
+      for (URI fileSinkUri: fileSinkUris) {
+        LOG.debug("Marking MapWork output URI as needing credentials: " + fileSinkUri);
+      }
+    }
+    dag.addURIsForCredentials(fileSinkUris);
   }
 
   private void addCredentials(ReduceWork reduceWork, DAG dag) {
-    // nothing at the moment
+
+    Set<URI> fileSinkUris = new HashSet<URI>();
+
+    List<Node> topNodes = new ArrayList<Node>();
+    topNodes.add(reduceWork.getReducer());
+    collectFileSinkUris(topNodes, fileSinkUris);
+
+    if (LOG.isDebugEnabled()) {
+      for (URI fileSinkUri: fileSinkUris) {
+        LOG.debug("Marking ReduceWork output URI as needing credentials: " + fileSinkUri);
+      }
+    }
+    dag.addURIsForCredentials(fileSinkUris);
  }
 
   /*
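
Note: rather than hand-rolling a traversal, the patch reuses Hive's rule-driven operator-graph walker (`DefaultRuleDispatcher` plus `DefaultGraphWalker`): a `RuleRegExp` keyed on `FileSinkOperator.getOperatorName()` fires the node processor whenever the walk reaches a file sink, and the processor records the sink's `dirName`/`destPath` URIs so the Tez DAG can request delegation tokens for its output locations, not just its inputs. The sketch below is a minimal, self-contained model of the same collect-sink-URIs idea; `Op`, `collect`, and the sample URIs are illustrative names invented here (they are not Hive classes), and the plain depth-first walk stands in for the dispatcher's regex rule matching.

```java
import java.net.URI;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Hypothetical stand-in for Hive's Operator tree; only what the sketch needs.
class Op {
  final String name;       // e.g. "FS" for a file sink, "TS" for a table scan
  final URI outputUri;     // non-null only for sink operators
  final List<Op> children = new ArrayList<>();

  Op(String name, URI outputUri) {
    this.name = name;
    this.outputUri = outputUri;
  }
}

public class CollectSinkUrisSketch {

  // Depth-first walk over the operator graph; when a node matches the
  // "file sink" rule (here a simple name check, in Hive a RuleRegExp),
  // record its output URI so credentials can be obtained for it.
  static void collect(Op root, Set<URI> uris) {
    Deque<Op> pending = new ArrayDeque<>();
    pending.push(root);
    while (!pending.isEmpty()) {
      Op op = pending.pop();
      if ("FS".equals(op.name) && op.outputUri != null) {
        uris.add(op.outputUri);
      }
      op.children.forEach(pending::push);
    }
  }

  public static void main(String[] args) {
    // A tiny plan: a scan feeding a file sink that writes to a staging dir.
    Op sink = new Op("FS", URI.create("hdfs://nn:8020/warehouse/t/.staging"));
    Op scan = new Op("TS", null);
    scan.children.add(sink);

    Set<URI> uris = new HashSet<>();
    collect(scan, uris);
    // In the patch these URIs go to dag.addURIsForCredentials(uris).
    System.out.println(uris);
  }
}
```

The design point the patch leans on is that `collectFileSinkUris` is work-type agnostic: `addCredentials(MapWork, ...)` seeds the walk with every operator in `aliasToWork`, while `addCredentials(ReduceWork, ...)` seeds it with the single reducer root, and both reuse the same processor and rules.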