diff --git llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/AMReporter.java llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/AMReporter.java index 93237e6..a30f8b9 100644 --- llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/AMReporter.java +++ llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/AMReporter.java @@ -119,7 +119,8 @@ private final DaemonId daemonId; public AMReporter(int numExecutors, int maxThreads, AtomicReference - localAddress, QueryFailedHandler queryFailedHandler, Configuration conf, DaemonId daemonId) { + localAddress, QueryFailedHandler queryFailedHandler, Configuration conf, DaemonId daemonId, + SocketFactory socketFactory) { super(AMReporter.class.getName()); this.localAddress = localAddress; this.queryFailedHandler = queryFailedHandler; @@ -151,7 +152,7 @@ public AMReporter(int numExecutors, int maxThreads, AtomicReference 0, "Invalid number of executors: " + numExecutors + ". Must be > 0"); @@ -122,6 +126,7 @@ public ContainerRunnerImpl(Configuration conf, int numExecutors, int waitQueueSi this.signer = UserGroupInformation.isSecurityEnabled() ? new LlapSignerImpl(conf, daemonId.getClusterString()) : null; this.fsUgiFactory = fsUgiFactory; + this.socketFactory = socketFactory; this.clusterId = daemonId.getClusterString(); this.daemonId = daemonId; @@ -239,7 +244,8 @@ public SubmitWorkResponseProto submitWork(SubmitWorkRequestProto request) throws queryIdentifier, qIdProto.getApplicationIdString(), dagId, vertex.getDagName(), vertex.getHiveQueryId(), dagIdentifier, vertex.getVertexName(), request.getFragmentNumber(), request.getAttemptNumber(), - vertex.getUser(), vertex, jobToken, fragmentIdString, tokenInfo); + vertex.getUser(), vertex, jobToken, fragmentIdString, tokenInfo, request.getAmHost(), + request.getAmPort()); String[] localDirs = fragmentInfo.getLocalDirs(); Preconditions.checkNotNull(localDirs); @@ -250,12 +256,12 @@ public SubmitWorkResponseProto submitWork(SubmitWorkRequestProto request) throws // Used for re-localization, to add the user specified configuration (conf_pb_binary_stream) Configuration callableConf = new Configuration(getConfig()); - UserGroupInformation taskUgi = fsUgiFactory == null ? null : fsUgiFactory.createUgi(); + UserGroupInformation fsTaskUgi = fsUgiFactory == null ? null : fsUgiFactory.createUgi(); TaskRunnerCallable callable = new TaskRunnerCallable(request, fragmentInfo, callableConf, new ExecutionContextImpl(localAddress.get().getHostName()), env, credentials, memoryPerExecutor, amReporter, confParams, metrics, killedTaskHandler, - this, tezHadoopShim, attemptId, vertex, initialEvent, taskUgi, - completionListener); + this, tezHadoopShim, attemptId, vertex, initialEvent, fsTaskUgi, + completionListener, socketFactory); submissionState = executorService.schedule(callable); if (LOG.isInfoEnabled()) { diff --git llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/LlapDaemon.java llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/LlapDaemon.java index fc9f530..eb05f4c 100644 --- llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/LlapDaemon.java +++ llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/LlapDaemon.java @@ -29,6 +29,7 @@ import java.util.concurrent.atomic.AtomicReference; import javax.management.ObjectName; +import javax.net.SocketFactory; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hive.common.JvmPauseMonitor; @@ -64,6 +65,7 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBridge; import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBridge.UdfWhitelistChecker; import org.apache.hadoop.metrics2.util.MBeans; +import org.apache.hadoop.net.NetUtils; import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.service.CompositeService; import org.apache.hadoop.util.ExitUtil; @@ -105,6 +107,7 @@ private final long maxJvmMemory; private final String[] localDirs; private final DaemonId daemonId; + private final SocketFactory socketFactory; // TODO Not the best way to share the address private final AtomicReference srvAddress = new AtomicReference<>(), @@ -255,8 +258,9 @@ public LlapDaemon(Configuration daemonConf, int numExecutors, long executorMemor " sessionId: " + sessionId); int maxAmReporterThreads = HiveConf.getIntVar(daemonConf, ConfVars.LLAP_DAEMON_AM_REPORTER_MAX_THREADS); + this.socketFactory = NetUtils.getDefaultSocketFactory(daemonConf); this.amReporter = new AMReporter(numExecutors, maxAmReporterThreads, srvAddress, - new QueryFailedHandlerProxy(), daemonConf, daemonId); + new QueryFailedHandlerProxy(), daemonConf, daemonId, socketFactory); SecretManager sm = null; if (UserGroupInformation.isSecurityEnabled()) { @@ -274,7 +278,7 @@ public LlapDaemon(Configuration daemonConf, int numExecutors, long executorMemor } this.containerRunner = new ContainerRunnerImpl(daemonConf, numExecutors, waitQueueSize, enablePreemption, localDirs, this.shufflePort, srvAddress, executorMemoryPerInstance, metrics, - amReporter, executorClassLoader, daemonId, fsUgiFactory); + amReporter, executorClassLoader, daemonId, fsUgiFactory, socketFactory); addIfService(containerRunner); // Not adding the registry as a service, since we need to control when it is initialized - conf used to pickup properties. diff --git llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/QueryInfo.java llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/QueryInfo.java index 1080d3e..eaa3e7e 100644 --- llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/QueryInfo.java +++ llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/QueryInfo.java @@ -16,6 +16,7 @@ import java.io.File; import java.io.IOException; +import java.net.InetSocketAddress; import java.util.Collection; import java.util.Collections; import java.util.HashMap; @@ -25,6 +26,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.ReentrantLock; import com.google.common.base.Preconditions; @@ -36,6 +38,11 @@ import org.apache.hadoop.hive.llap.daemon.FinishableStateUpdateHandler; import org.apache.hadoop.hive.llap.daemon.rpc.LlapDaemonProtocolProtos.SignableVertexSpec; import org.apache.hadoop.hive.llap.daemon.rpc.LlapDaemonProtocolProtos.SourceStateProto; +import org.apache.hadoop.net.NetUtils; +import org.apache.hadoop.security.SecurityUtil; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.security.token.Token; +import org.apache.tez.common.security.JobTokenIdentifier; public class QueryInfo { private final QueryIdentifier queryIdentifier; @@ -57,6 +64,7 @@ private final FinishableStateTracker finishableStateTracker = new FinishableStateTracker(); private final String tokenUserName, appId; + private final AtomicReference umbilicalUgi; public QueryInfo(QueryIdentifier queryIdentifier, String appIdString, String dagIdString, String dagName, String hiveQueryIdString, @@ -76,6 +84,7 @@ public QueryInfo(QueryIdentifier queryIdentifier, String appIdString, String dag this.localFs = localFs; this.tokenUserName = tokenUserName; this.appId = tokenAppId; + this.umbilicalUgi = new AtomicReference<>(); } public QueryIdentifier getQueryIdentifier() { @@ -297,4 +306,24 @@ public String getTokenUserName() { public String getTokenAppId() { return appId; } + + public void setupUmbilicalUgi(String umbilicalUser, Token appToken, String amHost, int amPort) { + synchronized (umbilicalUgi) { + if (umbilicalUgi.get() == null) { + UserGroupInformation taskOwner = + UserGroupInformation.createRemoteUser(umbilicalUser); + final InetSocketAddress address = + NetUtils.createSocketAddrForHost(amHost, amPort); + SecurityUtil.setTokenService(appToken, address); + taskOwner.addToken(appToken); + umbilicalUgi.set(taskOwner); + } + } + } + + public UserGroupInformation getUmbilicalUgi() { + synchronized (umbilicalUgi) { + return umbilicalUgi.get(); + } + } } diff --git llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/QueryTracker.java llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/QueryTracker.java index 9eaddd2..5cf3a38 100644 --- llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/QueryTracker.java +++ llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/QueryTracker.java @@ -139,7 +139,7 @@ public QueryTracker(Configuration conf, String[] localDirsBase, String clusterId QueryFragmentInfo registerFragment(QueryIdentifier queryIdentifier, String appIdString, String dagIdString, String dagName, String hiveQueryIdString, int dagIdentifier, String vertexName, int fragmentNumber, int attemptNumber, String user, SignableVertexSpec vertex, Token appToken, - String fragmentIdString, LlapTokenInfo tokenInfo) throws IOException { + String fragmentIdString, LlapTokenInfo tokenInfo, String amHost, int amPort) throws IOException { ReadWriteLock dagLock = getDagLock(queryIdentifier); // Note: This is a readLock to prevent a race with queryComplete. Operations @@ -174,6 +174,8 @@ QueryFragmentInfo registerFragment(QueryIdentifier queryIdentifier, String appId if (old != null) { queryInfo = old; } else { + // Ensure the UGI is setup once. + queryInfo.setupUmbilicalUgi(vertex.getTokenIdentifier(), appToken, amHost, amPort); isExistingQueryInfo = false; } } diff --git llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/TaskRunnerCallable.java llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/TaskRunnerCallable.java index 4b677aa..8fce546 100644 --- llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/TaskRunnerCallable.java +++ llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/TaskRunnerCallable.java @@ -41,7 +41,6 @@ import org.apache.hadoop.ipc.RPC; import org.apache.hadoop.net.NetUtils; import org.apache.hadoop.security.Credentials; -import org.apache.hadoop.security.SecurityUtil; import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.token.Token; import org.apache.log4j.MDC; @@ -65,6 +64,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.net.SocketFactory; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.security.PrivilegedExceptionAction; @@ -116,7 +116,8 @@ private final SignableVertexSpec vertex; private final TezEvent initialEvent; private final SchedulerFragmentCompletingListener completionListener; - private UserGroupInformation taskUgi; + private UserGroupInformation fsTaskUgi; + private final SocketFactory socketFactory; @VisibleForTesting public TaskRunnerCallable(SubmitWorkRequestProto request, QueryFragmentInfo fragmentInfo, @@ -125,7 +126,8 @@ public TaskRunnerCallable(SubmitWorkRequestProto request, QueryFragmentInfo frag LlapDaemonExecutorMetrics metrics, KilledTaskHandler killedTaskHandler, FragmentCompletionHandler fragmentCompleteHandler, HadoopShim tezHadoopShim, TezTaskAttemptID attemptId, SignableVertexSpec vertex, TezEvent initialEvent, - UserGroupInformation taskUgi, SchedulerFragmentCompletingListener completionListener) { + UserGroupInformation fsTaskUgi, SchedulerFragmentCompletingListener completionListener, + SocketFactory socketFactory) { this.request = request; this.fragmentInfo = fragmentInfo; this.conf = conf; @@ -153,8 +155,9 @@ public TaskRunnerCallable(SubmitWorkRequestProto request, QueryFragmentInfo frag this.fragmentCompletionHanler = fragmentCompleteHandler; this.tezHadoopShim = tezHadoopShim; this.initialEvent = initialEvent; - this.taskUgi = taskUgi; + this.fsTaskUgi = fsTaskUgi; this.completionListener = completionListener; + this.socketFactory = socketFactory; } public long getStartTime() { @@ -196,27 +199,27 @@ protected TaskRunner2Result callInternal() throws Exception { // TODO Consolidate this code with TezChild. runtimeWatch.start(); - if (taskUgi == null) { - taskUgi = UserGroupInformation.createRemoteUser(vertex.getUser()); + if (fsTaskUgi == null) { + fsTaskUgi = UserGroupInformation.createRemoteUser(vertex.getUser()); } - taskUgi.addCredentials(credentials); + fsTaskUgi.addCredentials(credentials); Map serviceConsumerMetadata = new HashMap<>(); serviceConsumerMetadata.put(TezConstants.TEZ_SHUFFLE_HANDLER_SERVICE_ID, TezCommonUtils.convertJobTokenToBytes(jobToken)); Multimap startedInputsMap = createStartedInputMap(vertex); - UserGroupInformation taskOwner = - UserGroupInformation.createRemoteUser(vertex.getTokenIdentifier()); + final UserGroupInformation taskOwner = fragmentInfo.getQueryInfo().getUmbilicalUgi(); + if (LOG.isDebugEnabled()) { + LOG.debug("taskOwner hashCode:" + taskOwner.hashCode()); + } final InetSocketAddress address = NetUtils.createSocketAddrForHost(request.getAmHost(), request.getAmPort()); - SecurityUtil.setTokenService(jobToken, address); - taskOwner.addToken(jobToken); umbilical = taskOwner.doAs(new PrivilegedExceptionAction() { @Override public LlapTaskUmbilicalProtocol run() throws Exception { return RPC.getProxy(LlapTaskUmbilicalProtocol.class, - LlapTaskUmbilicalProtocol.versionID, address, conf); + LlapTaskUmbilicalProtocol.versionID, address, taskOwner, conf, socketFactory); } }); @@ -238,7 +241,7 @@ public LlapTaskUmbilicalProtocol run() throws Exception { try { synchronized (this) { if (shouldRunTask) { - taskRunner = new TezTaskRunner2(conf, taskUgi, fragmentInfo.getLocalDirs(), + taskRunner = new TezTaskRunner2(conf, fsTaskUgi, fragmentInfo.getLocalDirs(), taskSpec, vertex.getQueryIdentifier().getAppAttemptNumber(), serviceConsumerMetadata, envMap, startedInputsMap, taskReporter, executor, @@ -260,7 +263,7 @@ public LlapTaskUmbilicalProtocol run() throws Exception { isCompleted.set(true); return result; } finally { - FileSystem.closeAllForUGI(taskUgi); + FileSystem.closeAllForUGI(fsTaskUgi); LOG.info("ExecutionTime for Container: " + request.getContainerIdString() + "=" + runtimeWatch.stop().elapsedMillis()); if (LOG.isDebugEnabled()) { diff --git llap-server/src/test/org/apache/hadoop/hive/llap/daemon/impl/TaskExecutorTestHelpers.java llap-server/src/test/org/apache/hadoop/hive/llap/daemon/impl/TaskExecutorTestHelpers.java index 5dc1be5..ae3328a 100644 --- llap-server/src/test/org/apache/hadoop/hive/llap/daemon/impl/TaskExecutorTestHelpers.java +++ llap-server/src/test/org/apache/hadoop/hive/llap/daemon/impl/TaskExecutorTestHelpers.java @@ -44,6 +44,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.net.SocketFactory; + public class TaskExecutorTestHelpers { private static final Logger LOG = LoggerFactory.getLogger(TestTaskExecutorService.class); @@ -184,7 +186,7 @@ public MockRequest(SubmitWorkRequestProto requestProto, QueryFragmentInfo fragme mock(KilledTaskHandler.class), mock( FragmentCompletionHandler.class), new DefaultHadoopShim(), null, requestProto.getWorkSpec().getVertex(), initialEvent, null, mock( - SchedulerFragmentCompletingListener.class)); + SchedulerFragmentCompletingListener.class), mock(SocketFactory.class)); this.workTime = workTime; this.canFinish = canFinish; }