diff --git a/llap-client/src/java/org/apache/hadoop/hive/llap/configuration/LlapConfiguration.java b/llap-client/src/java/org/apache/hadoop/hive/llap/configuration/LlapConfiguration.java index 62695b7e2230fe96fa40587d8674d742efc5ffb1..8c5c3e4540cbe7cb9bfa5f59e5c72cb731c65cac 100644 --- a/llap-client/src/java/org/apache/hadoop/hive/llap/configuration/LlapConfiguration.java +++ b/llap-client/src/java/org/apache/hadoop/hive/llap/configuration/LlapConfiguration.java @@ -81,5 +81,11 @@ public LlapConfiguration() { LLAP_DAEMON_PREFIX + "task.scheduler.node.re-enable.timeout.ms"; public static final long LLAP_DAEMON_TASK_SCHEDULER_NODE_REENABLE_TIMEOUT_MILLIS_DEFAULT = 2000l; + public static final String LLAP_DAEMON_TASK_SCHEDULER_WAIT_QUEUE_SIZE = + LLAP_DAEMON_PREFIX + "task.scheduler.wait.queue.size"; + public static final int LLAP_DAEMON_TASK_SCHEDULER_WAIT_QUEUE_SIZE_DEFAULT = 10; + public static final String LLAP_DAEMON_TASK_SCHEDULER_ENABLE_PREEMPTION = + LLAP_DAEMON_PREFIX + "task.scheduler.enable.preemption"; + public static final boolean LLAP_DAEMON_TASK_SCHEDULER_ENABLE_PREEMPTION_DEFAULT = false; } diff --git a/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/BoundedPriorityBlockingQueue.java b/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/BoundedPriorityBlockingQueue.java new file mode 100644 index 0000000000000000000000000000000000000000..78d3c6cacfdd0dee96477baf6fd9dcdb9f016545 --- /dev/null +++ b/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/BoundedPriorityBlockingQueue.java @@ -0,0 +1,57 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.llap.daemon.impl; + +import java.util.Comparator; +import java.util.concurrent.PriorityBlockingQueue; + +/** + * Priority blocking queue of bounded size. The entries that are added already added will be + * ordered based on the specified comparator. If the queue is full, offer() will return false and + * add() will throw IllegalStateException. + */ +public class BoundedPriorityBlockingQueue extends PriorityBlockingQueue { + private int maxSize; + + public BoundedPriorityBlockingQueue(int maxSize) { + this.maxSize = maxSize; + } + + public BoundedPriorityBlockingQueue(Comparator comparator, int maxSize) { + super(maxSize, comparator); + this.maxSize = maxSize; + } + + @Override + public boolean add(E e) { + if (size() >= maxSize) { + throw new IllegalStateException("BoundedPriorityBlockingQueue is full"); + } else { + return super.add(e); + } + } + + @Override + public boolean offer(E e) { + if (size() >= maxSize) { + return false; + } else { + return super.offer(e); + } + } +} diff --git a/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/ContainerRunnerImpl.java b/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/ContainerRunnerImpl.java index c142982c030530aa8645a0bb13a3e28c4f0371c9..3a750b3b41abff1c3d2eba2694717ff3e3ed6796 100644 --- a/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/ContainerRunnerImpl.java +++ b/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/ContainerRunnerImpl.java @@ -18,24 +18,16 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; -import java.security.PrivilegedExceptionAction; import java.util.Arrays; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; -import com.google.common.collect.HashMultimap; -import com.google.common.collect.Multimap; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; -import org.apache.hadoop.hive.common.CallableWithNdc; import org.apache.hadoop.hive.llap.LlapNodeId; import org.apache.hadoop.hive.llap.daemon.ContainerRunner; import org.apache.hadoop.hive.llap.daemon.HistoryLogger; @@ -46,45 +38,22 @@ import org.apache.hadoop.hive.llap.daemon.rpc.LlapDaemonProtocolProtos.SourceStateUpdatedRequestProto; import org.apache.hadoop.hive.llap.daemon.rpc.LlapDaemonProtocolProtos.SubmitWorkRequestProto; import org.apache.hadoop.hive.llap.metrics.LlapDaemonExecutorMetrics; -import org.apache.hadoop.hive.llap.protocol.LlapTaskUmbilicalProtocol; import org.apache.hadoop.hive.llap.shufflehandler.ShuffleHandler; -import org.apache.hadoop.hive.llap.tezplugins.Converters; import org.apache.hadoop.io.DataInputBuffer; -import org.apache.hadoop.ipc.RPC; -import org.apache.hadoop.net.NetUtils; import org.apache.hadoop.security.Credentials; -import org.apache.hadoop.security.SecurityUtil; -import org.apache.hadoop.security.UserGroupInformation; import org.apache.hadoop.security.token.Token; import org.apache.hadoop.service.AbstractService; import org.apache.hadoop.yarn.api.ApplicationConstants; import org.apache.hadoop.yarn.util.AuxiliaryServiceHelper; import org.apache.log4j.Logger; import org.apache.log4j.NDC; -import org.apache.tez.common.TezCommonUtils; import org.apache.tez.common.security.JobTokenIdentifier; import org.apache.tez.common.security.TokenCache; import org.apache.tez.dag.api.TezConfiguration; import org.apache.tez.dag.api.TezConstants; -import org.apache.tez.dag.api.TezException; -import org.apache.tez.mapreduce.input.MRInputLegacy; -import org.apache.tez.runtime.api.ExecutionContext; import org.apache.tez.runtime.api.impl.ExecutionContextImpl; -import org.apache.tez.runtime.api.impl.InputSpec; -import org.apache.tez.runtime.api.impl.TaskSpec; -import org.apache.tez.runtime.common.objectregistry.ObjectRegistryImpl; -import org.apache.tez.runtime.internals.api.TaskReporterInterface; -import org.apache.tez.runtime.task.TezChild.ContainerExecutionResult; import com.google.common.base.Preconditions; -import com.google.common.base.Stopwatch; -import com.google.common.util.concurrent.FutureCallback; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.ListeningExecutorService; -import com.google.common.util.concurrent.MoreExecutors; -import com.google.common.util.concurrent.ThreadFactoryBuilder; -import org.apache.tez.runtime.task.TezTaskRunner; public class ContainerRunnerImpl extends AbstractService implements ContainerRunner { @@ -93,7 +62,7 @@ private static final Logger LOG = Logger.getLogger(ContainerRunnerImpl.class); private volatile AMReporter amReporter; - private final ListeningExecutorService executorService; + private final Scheduler executorService; private final AtomicReference localAddress; private final String[] localDirsBase; private final Map localEnv = new HashMap<>(); @@ -101,13 +70,14 @@ private final long memoryPerExecutor; private final LlapDaemonExecutorMetrics metrics; private final Configuration conf; - private final ConfParams confParams; + private final TaskRunnerCallable.ConfParams confParams; // Map of dagId to vertices and associated state. private final ConcurrentMap> sourceCompletionMap = new ConcurrentHashMap<>(); // TODO Support for removing queued containers, interrupting / killing specific containers - public ContainerRunnerImpl(Configuration conf, int numExecutors, String[] localDirsBase, int localShufflePort, + public ContainerRunnerImpl(Configuration conf, int numExecutors, int waitQueueSize, + boolean enablePreemption, String[] localDirsBase, int localShufflePort, AtomicReference localAddress, long totalMemoryAvailableBytes, LlapDaemonExecutorMetrics metrics) { super("ContainerRunnerImpl"); @@ -117,9 +87,7 @@ public ContainerRunnerImpl(Configuration conf, int numExecutors, String[] localD this.localDirsBase = localDirsBase; this.localAddress = localAddress; - ExecutorService raw = Executors.newFixedThreadPool(numExecutors, - new ThreadFactoryBuilder().setNameFormat(THREAD_NAME_FORMAT).build()); - this.executorService = MoreExecutors.listeningDecorator(raw); + this.executorService = new TaskExecutorService(numExecutors, waitQueueSize, enablePreemption); AuxiliaryServiceHelper.setServiceDataIntoEnv( TezConstants.TEZ_SHUFFLE_HANDLER_SERVICE_ID, ByteBuffer.allocate(4).putInt(localShufflePort), localEnv); @@ -134,7 +102,7 @@ public ContainerRunnerImpl(Configuration conf, int numExecutors, String[] localD } catch (IOException e) { throw new RuntimeException("Failed to setup local filesystem instance", e); } - confParams = new ConfParams( + confParams = new TaskRunnerCallable.ConfParams( conf.getInt(TezConfiguration.TEZ_TASK_AM_HEARTBEAT_INTERVAL_MS, TezConfiguration.TEZ_TASK_AM_HEARTBEAT_INTERVAL_MS_DEFAULT), conf.getLong( @@ -230,9 +198,8 @@ public void submitWork(SubmitWorkRequestProto request) throws IOException { ConcurrentMap sourceCompletionMap = getSourceCompletionMap(request.getFragmentSpec().getDagName()); TaskRunnerCallable callable = new TaskRunnerCallable(request, new Configuration(getConfig()), new ExecutionContextImpl(localAddress.get().getHostName()), env, localDirs, - credentials, memoryPerExecutor, amReporter, sourceCompletionMap, confParams); - ListenableFuture future = executorService.submit(callable); - Futures.addCallback(future, new TaskRunnerCallback(request, callable)); + credentials, memoryPerExecutor, amReporter, sourceCompletionMap, confParams, metrics); + executorService.schedule(callable); metrics.incrExecutorTotalRequestsHandled(); metrics.incrExecutorNumQueuedRequests(); } finally { @@ -240,6 +207,10 @@ public void submitWork(SubmitWorkRequestProto request) throws IOException { } } + private void notifyAMOfRejection(TaskRunnerCallable callable) { + LOG.error("Notifying AM of request rejection is not implemented yet!"); + } + @Override public void sourceStateUpdated(SourceStateUpdatedRequestProto request) { LOG.info("Processing state update: " + stringifySourceStateUpdateRequest(request)); @@ -247,273 +218,6 @@ public void sourceStateUpdated(SourceStateUpdatedRequestProto request) { dagMap.put(request.getSrcName(), request.getState()); } - static class TaskRunnerCallable extends CallableWithNdc { - - private final SubmitWorkRequestProto request; - private final Configuration conf; - private final String[] localDirs; - private final Map envMap; - private final String pid = null; - private final ObjectRegistryImpl objectRegistry; - private final ExecutionContext executionContext; - private final Credentials credentials; - private final long memoryAvailable; - private final ConfParams confParams; - private final Token jobToken; - private final AMReporter amReporter; - private final ConcurrentMap sourceCompletionMap; - private final TaskSpec taskSpec; - private volatile TezTaskRunner taskRunner; - private volatile TaskReporterInterface taskReporter; - private volatile ListeningExecutorService executor; - private LlapTaskUmbilicalProtocol umbilical; - private volatile long startTime; - private volatile String threadName; - private volatile boolean cancelled = false; - - - - TaskRunnerCallable(SubmitWorkRequestProto request, Configuration conf, - ExecutionContext executionContext, Map envMap, - String[] localDirs, Credentials credentials, - long memoryAvailable, AMReporter amReporter, - ConcurrentMap sourceCompletionMap, ConfParams confParams) { - this.request = request; - this.conf = conf; - this.executionContext = executionContext; - this.envMap = envMap; - this.localDirs = localDirs; - this.objectRegistry = new ObjectRegistryImpl(); - this.sourceCompletionMap = sourceCompletionMap; - this.credentials = credentials; - this.memoryAvailable = memoryAvailable; - this.confParams = confParams; - this.jobToken = TokenCache.getSessionToken(credentials); - this.taskSpec = Converters.getTaskSpecfromProto(request.getFragmentSpec()); - this.amReporter = amReporter; - // Register with the AMReporter when the callable is setup. Unregister once it starts running. - this.amReporter.registerTask(request.getAmHost(), request.getAmPort(), request.getUser(), jobToken); - } - - @Override - protected ContainerExecutionResult callInternal() throws Exception { - this.startTime = System.currentTimeMillis(); - this.threadName = Thread.currentThread().getName(); - if (LOG.isDebugEnabled()) { - LOG.debug("canFinish: " + taskSpec.getTaskAttemptID() + ": " + canFinish()); - } - - - // Unregister from the AMReporter, since the task is now running. - this.amReporter.unregisterTask(request.getAmHost(), request.getAmPort()); - - // TODO This executor seems unnecessary. Here and TezChild - ExecutorService executorReal = Executors.newFixedThreadPool(1, - new ThreadFactoryBuilder() - .setDaemon(true) - .setNameFormat( - "TezTaskRunner_" + request.getFragmentSpec().getTaskAttemptIdString()) - .build()); - executor = MoreExecutors.listeningDecorator(executorReal); - - // TODO Consolidate this code with TezChild. - Stopwatch sw = new Stopwatch().start(); - UserGroupInformation taskUgi = UserGroupInformation.createRemoteUser(request.getUser()); - taskUgi.addCredentials(credentials); - - Map serviceConsumerMetadata = new HashMap<>(); - serviceConsumerMetadata.put(TezConstants.TEZ_SHUFFLE_HANDLER_SERVICE_ID, - TezCommonUtils.convertJobTokenToBytes(jobToken)); - Multimap startedInputsMap = HashMultimap.create(); - - UserGroupInformation taskOwner = - UserGroupInformation.createRemoteUser(request.getTokenIdentifier()); - final InetSocketAddress address = - NetUtils.createSocketAddrForHost(request.getAmHost(), request.getAmPort()); - SecurityUtil.setTokenService(jobToken, address); - taskOwner.addToken(jobToken); - umbilical = taskOwner.doAs(new PrivilegedExceptionAction() { - @Override - public LlapTaskUmbilicalProtocol run() throws Exception { - return RPC.getProxy(LlapTaskUmbilicalProtocol.class, - LlapTaskUmbilicalProtocol.versionID, address, conf); - } - }); - - taskReporter = new LlapTaskReporter( - umbilical, - confParams.amHeartbeatIntervalMsMax, - confParams.amCounterHeartbeatInterval, - confParams.amMaxEventsPerHeartbeat, - new AtomicLong(0), - request.getContainerIdString()); - - taskRunner = new TezTaskRunner(conf, taskUgi, localDirs, - taskSpec, - request.getAppAttemptNumber(), - serviceConsumerMetadata, envMap, startedInputsMap, taskReporter, executor, objectRegistry, - pid, - executionContext, memoryAvailable); - - boolean shouldDie; - try { - shouldDie = !taskRunner.run(); - if (shouldDie) { - LOG.info("Got a shouldDie notification via heartbeats. Shutting down"); - return new ContainerExecutionResult(ContainerExecutionResult.ExitStatus.SUCCESS, null, - "Asked to die by the AM"); - } - } catch (IOException e) { - return new ContainerExecutionResult(ContainerExecutionResult.ExitStatus.EXECUTION_FAILURE, - e, "TaskExecutionFailure: " + e.getMessage()); - } catch (TezException e) { - return new ContainerExecutionResult(ContainerExecutionResult.ExitStatus.EXECUTION_FAILURE, - e, "TaskExecutionFailure: " + e.getMessage()); - } finally { - // TODO Fix UGI and FS Handling. Closing UGI here causes some errors right now. -// FileSystem.closeAllForUGI(taskUgi); - } - LOG.info("ExecutionTime for Container: " + request.getContainerIdString() + "=" + - sw.stop().elapsedMillis()); - if (LOG.isDebugEnabled()) { - LOG.debug("canFinish post completion: " + taskSpec.getTaskAttemptID() + ": " + canFinish()); - } - - return new ContainerExecutionResult(ContainerExecutionResult.ExitStatus.SUCCESS, null, - null); - } - - /** - * Check whether a task can run to completion or may end up blocking on it's sources. - * This currently happens via looking up source state. - * TODO: Eventually, this should lookup the Hive Processor to figure out whether - * it's reached a state where it can finish - especially in cases of failures - * after data has been fetched. - * @return - */ - public boolean canFinish() { - List inputSpecList = taskSpec.getInputs(); - boolean canFinish = true; - if (inputSpecList != null && !inputSpecList.isEmpty()) { - for (InputSpec inputSpec : inputSpecList) { - if (isSourceOfInterest(inputSpec)) { - // Lookup the state in the map. - SourceStateProto state = sourceCompletionMap.get(inputSpec.getSourceVertexName()); - if (state != null && state == SourceStateProto.S_SUCCEEDED) { - continue; - } else { - if (LOG.isDebugEnabled()) { - LOG.debug("Cannot finish due to source: " + inputSpec.getSourceVertexName()); - } - canFinish = false; - break; - } - } - } - } - return canFinish; - } - - private boolean isSourceOfInterest(InputSpec inputSpec) { - String inputClassName = inputSpec.getInputDescriptor().getClassName(); - // MRInput is not of interest since it'll always be ready. - return !inputClassName.equals(MRInputLegacy.class.getName()); - } - - public void shutdown() { - executor.shutdownNow(); - if (taskReporter != null) { - taskReporter.shutdown(); - } - if (umbilical != null) { - RPC.stopProxy(umbilical); - } - } - } - - final class TaskRunnerCallback implements FutureCallback { - - private final SubmitWorkRequestProto request; - private final TaskRunnerCallable taskRunnerCallable; - - TaskRunnerCallback(SubmitWorkRequestProto request, - TaskRunnerCallable taskRunnerCallable) { - this.request = request; - this.taskRunnerCallable = taskRunnerCallable; - } - - // TODO Slightly more useful error handling - @Override - public void onSuccess(ContainerExecutionResult result) { - switch (result.getExitStatus()) { - case SUCCESS: - LOG.info("Successfully finished: " + getTaskIdentifierString(request)); - metrics.incrExecutorTotalSuccess(); - break; - case EXECUTION_FAILURE: - LOG.info("Failed to run: " + getTaskIdentifierString(request)); - metrics.incrExecutorTotalExecutionFailed(); - break; - case INTERRUPTED: - LOG.info("Interrupted while running: " + getTaskIdentifierString(request)); - metrics.incrExecutorTotalInterrupted(); - break; - case ASKED_TO_DIE: - LOG.info("Asked to die while running: " + getTaskIdentifierString(request)); - metrics.incrExecutorTotalAskedToDie(); - break; - } - taskRunnerCallable.shutdown(); - HistoryLogger - .logFragmentEnd(request.getApplicationIdString(), request.getContainerIdString(), - localAddress.get().getHostName(), request.getFragmentSpec().getDagName(), - request.getFragmentSpec().getVertexName(), - request.getFragmentSpec().getFragmentNumber(), - request.getFragmentSpec().getAttemptNumber(), taskRunnerCallable.threadName, - taskRunnerCallable.startTime, true); - metrics.decrExecutorNumQueuedRequests(); - } - - @Override - public void onFailure(Throwable t) { - LOG.error("TezTaskRunner execution failed for : " + getTaskIdentifierString(request), t); - // TODO HIVE-10236 Report a fatal error over the umbilical - taskRunnerCallable.shutdown(); - HistoryLogger - .logFragmentEnd(request.getApplicationIdString(), request.getContainerIdString(), - localAddress.get().getHostName(), request.getFragmentSpec().getDagName(), - request.getFragmentSpec().getVertexName(), - request.getFragmentSpec().getFragmentNumber(), - request.getFragmentSpec().getAttemptNumber(), taskRunnerCallable.threadName, - taskRunnerCallable.startTime, false); - metrics.decrExecutorNumQueuedRequests(); - } - - private String getTaskIdentifierString(SubmitWorkRequestProto request) { - StringBuilder sb = new StringBuilder(); - sb.append("AppId=").append(request.getApplicationIdString()) - .append(", containerId=").append(request.getContainerIdString()) - .append(", Dag=").append(request.getFragmentSpec().getDagName()) - .append(", Vertex=").append(request.getFragmentSpec().getVertexName()) - .append(", FragmentNum=").append(request.getFragmentSpec().getFragmentNumber()) - .append(", Attempt=").append(request.getFragmentSpec().getAttemptNumber()); - return sb.toString(); - } - } - - private static class ConfParams { - final int amHeartbeatIntervalMsMax; - final long amCounterHeartbeatInterval; - final int amMaxEventsPerHeartbeat; - - public ConfParams(int amHeartbeatIntervalMsMax, long amCounterHeartbeatInterval, - int amMaxEventsPerHeartbeat) { - this.amHeartbeatIntervalMsMax = amHeartbeatIntervalMsMax; - this.amCounterHeartbeatInterval = amCounterHeartbeatInterval; - this.amMaxEventsPerHeartbeat = amMaxEventsPerHeartbeat; - } - } - private String stringifySourceStateUpdateRequest(SourceStateUpdatedRequestProto request) { StringBuilder sb = new StringBuilder(); sb.append("dagName=").append(request.getDagName()) @@ -522,7 +226,7 @@ private String stringifySourceStateUpdateRequest(SourceStateUpdatedRequestProto return sb.toString(); } - private String stringifySubmitRequest(SubmitWorkRequestProto request) { + public static String stringifySubmitRequest(SubmitWorkRequestProto request) { StringBuilder sb = new StringBuilder(); sb.append("am_details=").append(request.getAmHost()).append(":").append(request.getAmPort()); sb.append(", user=").append(request.getUser()); diff --git a/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/LlapDaemon.java b/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/LlapDaemon.java index eb8d64bd087129637b182349a677c778e8f49d24..86b1f5caf87e5cf181fb10244a48e7ba10e95237 100644 --- a/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/LlapDaemon.java +++ b/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/LlapDaemon.java @@ -68,8 +68,8 @@ private final AtomicReference address = new AtomicReference(); public LlapDaemon(Configuration daemonConf, int numExecutors, long executorMemoryBytes, - boolean ioEnabled, long ioMemoryBytes, String[] localDirs, int rpcPort, - int shufflePort) { + boolean ioEnabled, long ioMemoryBytes, String[] localDirs, int rpcPort, + int shufflePort) { super("LlapDaemon"); printAsciiArt(); @@ -89,6 +89,12 @@ public LlapDaemon(Configuration daemonConf, int numExecutors, long executorMemor this.numExecutors = numExecutors; this.localDirs = localDirs; + int waitQueueSize = daemonConf.getInt( + LlapConfiguration.LLAP_DAEMON_TASK_SCHEDULER_WAIT_QUEUE_SIZE, + LlapConfiguration.LLAP_DAEMON_TASK_SCHEDULER_WAIT_QUEUE_SIZE_DEFAULT); + boolean enablePreemption = daemonConf.getBoolean( + LlapConfiguration.LLAP_DAEMON_TASK_SCHEDULER_ENABLE_PREEMPTION, + LlapConfiguration.LLAP_DAEMON_TASK_SCHEDULER_ENABLE_PREEMPTION_DEFAULT); LOG.info("Attempting to start LlapDaemonConf with the following configuration: " + "numExecutors=" + numExecutors + ", rpcListenerPort=" + rpcPort + @@ -97,7 +103,9 @@ public LlapDaemon(Configuration daemonConf, int numExecutors, long executorMemor ", executorMemory=" + executorMemoryBytes + ", llapIoEnabled=" + ioEnabled + ", llapIoCacheSize=" + ioMemoryBytes + - ", jvmAvailableMemory=" + maxJvmMemory); + ", jvmAvailableMemory=" + maxJvmMemory + + ", waitQueueSize= " + waitQueueSize + + ", enablePreemption= " + enablePreemption); long memRequired = executorMemoryBytes + (ioEnabled ? ioMemoryBytes : 0); Preconditions.checkState(maxJvmMemory >= memRequired, @@ -131,9 +139,16 @@ public LlapDaemon(Configuration daemonConf, int numExecutors, long executorMemor LOG.info("Started LlapMetricsSystem with displayName: " + displayName + " sessionId: " + sessionId); - this.containerRunner = new ContainerRunnerImpl(daemonConf, numExecutors, localDirs, shufflePort, address, - executorMemoryBytes, metrics); - + this.containerRunner = new ContainerRunnerImpl(daemonConf, + numExecutors, + waitQueueSize, + enablePreemption, + localDirs, + shufflePort, + address, + executorMemoryBytes, + metrics); + this.registry = new LlapRegistryService(); } @@ -202,24 +217,25 @@ public static void main(String[] args) throws Exception { // Cache settings will need to be setup in llap-daemon-site.xml - since the daemons don't read hive-site.xml // Ideally, these properties should be part of LlapDameonConf rather than HiveConf LlapConfiguration daemonConf = new LlapConfiguration(); - int numExecutors = daemonConf.getInt(LlapConfiguration.LLAP_DAEMON_NUM_EXECUTORS, - LlapConfiguration.LLAP_DAEMON_NUM_EXECUTORS_DEFAULT); - String[] localDirs = - daemonConf.getTrimmedStrings(LlapConfiguration.LLAP_DAEMON_WORK_DIRS); - int rpcPort = daemonConf.getInt(LlapConfiguration.LLAP_DAEMON_RPC_PORT, - LlapConfiguration.LLAP_DAEMON_RPC_PORT_DEFAULT); - int shufflePort = daemonConf - .getInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, ShuffleHandler.DEFAULT_SHUFFLE_PORT); - long executorMemoryBytes = daemonConf - .getInt(LlapConfiguration.LLAP_DAEMON_MEMORY_PER_INSTANCE_MB, - LlapConfiguration.LLAP_DAEMON_MEMORY_PER_INSTANCE_MB_DEFAULT) * 1024l * 1024l; - long cacheMemoryBytes = - HiveConf.getLongVar(daemonConf, HiveConf.ConfVars.LLAP_ORC_CACHE_MAX_SIZE); - boolean llapIoEnabled = HiveConf.getBoolVar(daemonConf, HiveConf.ConfVars.LLAP_IO_ENABLED); - llapDaemon = - new LlapDaemon(daemonConf, numExecutors, executorMemoryBytes, llapIoEnabled, - cacheMemoryBytes, localDirs, - rpcPort, shufflePort); + int numExecutors = daemonConf.getInt(LlapConfiguration.LLAP_DAEMON_NUM_EXECUTORS, + LlapConfiguration.LLAP_DAEMON_NUM_EXECUTORS_DEFAULT); + + String[] localDirs = + daemonConf.getTrimmedStrings(LlapConfiguration.LLAP_DAEMON_WORK_DIRS); + int rpcPort = daemonConf.getInt(LlapConfiguration.LLAP_DAEMON_RPC_PORT, + LlapConfiguration.LLAP_DAEMON_RPC_PORT_DEFAULT); + int shufflePort = daemonConf + .getInt(ShuffleHandler.SHUFFLE_PORT_CONFIG_KEY, ShuffleHandler.DEFAULT_SHUFFLE_PORT); + long executorMemoryBytes = daemonConf + .getInt(LlapConfiguration.LLAP_DAEMON_MEMORY_PER_INSTANCE_MB, + LlapConfiguration.LLAP_DAEMON_MEMORY_PER_INSTANCE_MB_DEFAULT) * 1024l * 1024l; + long cacheMemoryBytes = + HiveConf.getLongVar(daemonConf, HiveConf.ConfVars.LLAP_ORC_CACHE_MAX_SIZE); + boolean llapIoEnabled = HiveConf.getBoolVar(daemonConf, HiveConf.ConfVars.LLAP_IO_ENABLED); + llapDaemon = + new LlapDaemon(daemonConf, numExecutors, executorMemoryBytes, llapIoEnabled, + cacheMemoryBytes, localDirs, + rpcPort, shufflePort); llapDaemon.init(daemonConf); llapDaemon.start(); diff --git a/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/Scheduler.java b/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/Scheduler.java new file mode 100644 index 0000000000000000000000000000000000000000..c3102f98b73f9f6d77666d9291f1ff725e36ff04 --- /dev/null +++ b/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/Scheduler.java @@ -0,0 +1,33 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.llap.daemon.impl; + +import java.util.concurrent.RejectedExecutionException; + +/** + * Task scheduler interface + */ +public interface Scheduler { + + /** + * Schedule the task or throw RejectedExecutionException if queues are full + * @param t - task to schedule + * @throws RejectedExecutionException + */ + void schedule(T t) throws RejectedExecutionException; +} diff --git a/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/TaskExecutorService.java b/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/TaskExecutorService.java new file mode 100644 index 0000000000000000000000000000000000000000..1393028cf193d72897b16d08ba85bcceff9892cb --- /dev/null +++ b/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/TaskExecutorService.java @@ -0,0 +1,421 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.llap.daemon.impl; + +import java.util.Comparator; +import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.PriorityBlockingQueue; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.log4j.Logger; +import org.apache.tez.runtime.task.TezChild; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.common.util.concurrent.ThreadFactoryBuilder; + +/** + * Task executor service provides method for scheduling tasks. Tasks submitted to executor service + * are submitted to wait queue for scheduling. Wait queue tasks are ordered based on the priority + * of the task. The internal wait queue scheduler moves tasks from wait queue when executor slots + * are available or when a higher priority task arrives and will schedule it for execution. + * When pre-emption is enabled, the tasks from wait queue can replace(pre-empt) a running task. + * The pre-empted task is reported back to the Application Master(AM) for it to be rescheduled. + *

+ * Because of the concurrent nature of task submission, the position of the task in wait queue is + * held as long the scheduling of the task from wait queue (with or without pre-emption) is complete. + * The order of pre-emption is based on the ordering in the pre-emption queue. All tasks that cannot + * run to completion immediately (canFinish = false) are added to pre-emption queue. + *

+ * When all the executor threads are occupied and wait queue is full, the task scheduler will + * throw RejectedExecutionException. + *

+ * Task executor service can be shut down which will terminated all running tasks and reject all + * new tasks. Shutting down of the task executor service can be done gracefully or immediately. + */ +public class TaskExecutorService implements Scheduler { + private static final Logger LOG = Logger.getLogger(TaskExecutorService.class); + private static final boolean isInfoEnabled = LOG.isInfoEnabled(); + private static final boolean isDebugEnabled = LOG.isDebugEnabled(); + private static final boolean isTraceEnabled = LOG.isTraceEnabled(); + private static final String TASK_EXECUTOR_THREAD_NAME_FORMAT = "Task-Executor-%d"; + private static final String WAIT_QUEUE_SCHEDULER_THREAD_NAME_FORMAT = "Wait-Queue-Scheduler-%d"; + + // some object to lock upon. Used by task scheduler to notify wait queue scheduler of new items + // to wait queue + private final Object waitLock; + private final ListeningExecutorService executorService; + private final BlockingQueue waitQueue; + private final ListeningExecutorService waitQueueExecutorService; + private final Map idToTaskMap; + private final Map> preemptionMap; + private final BlockingQueue preemptionQueue; + private final boolean enablePreemption; + private final ThreadPoolExecutor threadPoolExecutor; + private final AtomicInteger numSlotsAvailable; + + public TaskExecutorService(int numExecutors, int waitQueueSize, boolean enablePreemption) { + this.waitLock = new Object(); + this.waitQueue = new BoundedPriorityBlockingQueue<>(new WaitQueueComparator(), waitQueueSize); + this.threadPoolExecutor = new ThreadPoolExecutor(numExecutors, // core pool size + numExecutors, // max pool size + 1, TimeUnit.MINUTES, + new SynchronousQueue(), // direct hand-off + new ThreadFactoryBuilder().setNameFormat(TASK_EXECUTOR_THREAD_NAME_FORMAT).build()); + this.executorService = MoreExecutors.listeningDecorator(threadPoolExecutor); + this.idToTaskMap = new ConcurrentHashMap<>(); + this.preemptionMap = new ConcurrentHashMap<>(); + this.preemptionQueue = new PriorityBlockingQueue<>(numExecutors, + new PreemptionQueueComparator()); + this.enablePreemption = enablePreemption; + this.numSlotsAvailable = new AtomicInteger(numExecutors); + + // single threaded scheduler for tasks from wait queue to executor threads + ExecutorService wes = Executors.newFixedThreadPool(1, new ThreadFactoryBuilder() + .setNameFormat(WAIT_QUEUE_SCHEDULER_THREAD_NAME_FORMAT).build()); + this.waitQueueExecutorService = MoreExecutors.listeningDecorator(wes); + ListenableFuture future = waitQueueExecutorService.submit(new WaitQueueWorker()); + Futures.addCallback(future, new WaitQueueWorkerCallback()); + } + + /** + * Worker that takes tasks from wait queue and schedule it for execution. + */ + private final class WaitQueueWorker implements Runnable { + TaskRunnerCallable task; + + @Override + public void run() { + try { + if (waitQueue.isEmpty()) { + synchronized (waitLock) { + waitLock.wait(); + } + } + + // Since schedule() can be called from multiple threads, we peek the wait queue, + // try scheduling the task and then remove the task if scheduling is successful. + // This will make sure the task's place in the wait queue is held until it gets scheduled. + while ((task = waitQueue.peek()) != null) { + + // if the task cannot finish and if no slots are available then don't schedule it. + // TODO: Event notifications that change canFinish state should notify waitLock + if (!task.canFinish() && numSlotsAvailable.get() == 0) { + synchronized (waitLock) { + waitLock.wait(); + } + } + + boolean scheduled = trySchedule(task); + if (scheduled) { + // wait queue could have been re-ordered in the mean time because of concurrent task + // submission. So remove the specific task instead of the head task. + waitQueue.remove(task); + } + + if (waitQueue.isEmpty()) { + synchronized (waitLock) { + waitLock.wait(); + } + } + } + + } catch (InterruptedException e) { + // Executor service will create new thread if the current thread gets interrupted. We don't + // need to do anything with the exception. + LOG.info(WAIT_QUEUE_SCHEDULER_THREAD_NAME_FORMAT + " thread has been interrupted."); + } + } + } + + private class WaitQueueWorkerCallback implements FutureCallback { + + @Override + public void onSuccess(Object result) { + LOG.error("Wait queue scheduler worker exited with success!"); + } + + @Override + public void onFailure(Throwable t) { + LOG.error("Wait queue scheduler worker exited with failure!"); + } + } + + @Override + public void schedule(TaskRunnerCallable task) throws RejectedExecutionException { + if (waitQueue.offer(task)) { + if (isDebugEnabled) { + LOG.debug(task.getRequestId() + " added to wait queue."); + } + + synchronized (waitLock) { + waitLock.notify(); + } + } else { + throw new RejectedExecutionException("Queues are full. Rejecting request."); + } + } + + private boolean trySchedule(TaskRunnerCallable task) { + + boolean scheduled = false; + try { + ListenableFuture future = executorService.submit(task); + FutureCallback wrappedCallback = + new InternalCompletionListener(task.getCallback()); + Futures.addCallback(future, wrappedCallback); + + if (isInfoEnabled) { + LOG.info(task.getRequestId() + " scheduled for execution."); + } + + // only tasks that cannot finish immediately are pre-emptable. In other words, if all inputs + // to the tasks are not ready yet, the task is eligible for pre-emptable. + if (enablePreemption && !task.canFinish()) { + if (isDebugEnabled) { + LOG.debug(task.getRequestId() + " is not finishable and pre-emption is enabled." + + "Adding it to pre-emption queue."); + } + addTaskToPreemptionList(task, future); + } + + numSlotsAvailable.decrementAndGet(); + scheduled = true; + } catch (RejectedExecutionException e) { + + if (enablePreemption && task.canFinish() && !preemptionQueue.isEmpty()) { + + if (isTraceEnabled) { + LOG.trace("idToTaskMap: " + idToTaskMap.keySet()); + LOG.trace("preemptionMap: " + preemptionMap.keySet()); + LOG.trace("preemptionQueue: " + preemptionQueue); + } + + TaskRunnerCallable pRequest = preemptionQueue.remove(); + + // if some task completes, it will remove itself from pre-emptions lists make this null. + // if it happens bail out and schedule it again as a free slot will be available. + if (pRequest != null) { + + if (isDebugEnabled) { + LOG.debug(pRequest.getRequestId() + " is chosen for pre-emption."); + } + + ListenableFuture pFuture = preemptionMap.get(pRequest); + + // if pFuture is null, then it must have been completed and be removed from preemption map + if (pFuture != null) { + if (isDebugEnabled) { + LOG.debug("Pre-emption invoked for " + pRequest.getRequestId() + + " by interrupting the thread."); + } + pFuture.cancel(true); + removeTaskFromPreemptionList(pRequest, pRequest.getRequestId()); + + // future is cancelled or completed normally, in which case schedule the new request + if (pFuture.isDone() && pFuture.isCancelled()) { + if (isDebugEnabled) { + LOG.debug(pRequest.getRequestId() + " request preempted by " + task.getRequestId()); + } + + notifyAM(pRequest); + } + } + + // try to submit the task from wait queue to executor service. If it gets rejected the + // task from wait queue will hold on to its position for next try. + try { + ListenableFuture future = executorService + .submit(task); + Futures.addCallback(future, task.getCallback()); + numSlotsAvailable.decrementAndGet(); + scheduled = true; + if (isDebugEnabled) { + LOG.debug("Request " + task.getRequestId() + " from wait queue submitted" + + " to executor service."); + } + } catch (RejectedExecutionException e1) { + + // This should not happen as we just freed a slot from executor service by pre-emption, + // which cannot be claimed by other tasks as trySchedule() is serially executed. + scheduled = false; + if (isDebugEnabled) { + LOG.debug("Request " + task.getRequestId() + " from wait queue rejected by" + + " executor service."); + } + } + } + } + } + + return scheduled; + } + + private synchronized void removeTaskFromPreemptionList(TaskRunnerCallable pRequest, + String requestId) { + idToTaskMap.remove(requestId); + preemptionMap.remove(pRequest); + preemptionQueue.remove(pRequest); + } + + private synchronized void addTaskToPreemptionList(TaskRunnerCallable task, + ListenableFuture future) { + idToTaskMap.put(task.getRequestId(), task); + preemptionMap.put(task, future); + preemptionQueue.add(task); + } + + private final class InternalCompletionListener implements + FutureCallback { + private TaskRunnerCallable.TaskRunnerCallback wrappedCallback; + + public InternalCompletionListener(TaskRunnerCallable.TaskRunnerCallback wrappedCallback) { + this.wrappedCallback = wrappedCallback; + } + + @Override + public void onSuccess(TezChild.ContainerExecutionResult result) { + wrappedCallback.onSuccess(result); + updatePreemptionListAndNotify(true); + } + + @Override + public void onFailure(Throwable t) { + wrappedCallback.onFailure(t); + updatePreemptionListAndNotify(false); + } + + private void updatePreemptionListAndNotify(boolean success) { + // if this task was added to pre-emption list, remove it + String taskId = wrappedCallback.getRequestId(); + TaskRunnerCallable task = idToTaskMap.get(taskId); + String state = success ? "succeeded" : "failed"; + if (enablePreemption && task != null) { + removeTaskFromPreemptionList(task, taskId); + if (isDebugEnabled) { + LOG.debug(task.getRequestId() + " request " + state + "! Removed from preemption list."); + } + } + + numSlotsAvailable.incrementAndGet(); + if (!waitQueue.isEmpty()) { + synchronized (waitLock) { + waitLock.notify(); + } + } + } + + } + + private void notifyAM(TaskRunnerCallable request) { + // TODO: Report to AM of pre-emption and rejection + LOG.info("Notifying to AM of preemption is not implemented yet!"); + } + + // TODO: llap daemon should call this to gracefully shutdown the task executor service + public void shutDown(boolean awaitTermination) { + if (awaitTermination) { + if (isDebugEnabled) { + LOG.debug("awaitTermination: " + awaitTermination + " shutting down task executor" + + " service gracefully"); + } + executorService.shutdown(); + try { + if (!executorService.awaitTermination(1, TimeUnit.MINUTES)) { + executorService.shutdownNow(); + } + } catch (InterruptedException e) { + executorService.shutdownNow(); + } + + waitQueueExecutorService.shutdown(); + try { + if (!waitQueueExecutorService.awaitTermination(1, TimeUnit.MINUTES)) { + waitQueueExecutorService.shutdownNow(); + } + } catch (InterruptedException e) { + waitQueueExecutorService.shutdownNow(); + } + } else { + if (isDebugEnabled) { + LOG.debug("awaitTermination: " + awaitTermination + " shutting down task executor" + + " service immediately"); + } + executorService.shutdownNow(); + waitQueueExecutorService.shutdownNow(); + } + } + + @VisibleForTesting + public int getPreemptionListSize() { + return preemptionMap.size(); + } + + @VisibleForTesting + public TaskRunnerCallable getPreemptionTask() { + return preemptionQueue.peek(); + } + + @VisibleForTesting + public static class WaitQueueComparator implements Comparator { + + @Override + public int compare(TaskRunnerCallable o1, TaskRunnerCallable o2) { + boolean newCanFinish = o1.canFinish(); + boolean oldCanFinish = o2.canFinish(); + if (newCanFinish == true && oldCanFinish == false) { + return -1; + } else if (newCanFinish == false && oldCanFinish == true) { + return 1; + } + + if (o1.getVertexParallelism() > o2.getVertexParallelism()) { + return 1; + } else if (o1.getVertexParallelism() < o2.getVertexParallelism()) { + return -1; + } + return 0; + } + } + + @VisibleForTesting + public static class PreemptionQueueComparator implements Comparator { + + @Override + public int compare(TaskRunnerCallable o1, TaskRunnerCallable o2) { + if (o1.getVertexParallelism() > o2.getVertexParallelism()) { + return 1; + } else if (o1.getVertexParallelism() < o2.getVertexParallelism()) { + return -1; + } + return 0; + } + } +} diff --git a/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/TaskRunnerCallable.java b/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/TaskRunnerCallable.java new file mode 100644 index 0000000000000000000000000000000000000000..2bdfee5842bf81af5bb657227fa567f3c3a4ae50 --- /dev/null +++ b/llap-server/src/java/org/apache/hadoop/hive/llap/daemon/impl/TaskRunnerCallable.java @@ -0,0 +1,392 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.llap.daemon.impl; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.security.PrivilegedExceptionAction; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicLong; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.common.CallableWithNdc; +import org.apache.hadoop.hive.llap.daemon.HistoryLogger; +import org.apache.hadoop.hive.llap.daemon.rpc.LlapDaemonProtocolProtos; +import org.apache.hadoop.hive.llap.daemon.rpc.LlapDaemonProtocolProtos.SubmitWorkRequestProto; +import org.apache.hadoop.hive.llap.metrics.LlapDaemonExecutorMetrics; +import org.apache.hadoop.hive.llap.protocol.LlapTaskUmbilicalProtocol; +import org.apache.hadoop.hive.llap.tezplugins.Converters; +import org.apache.hadoop.ipc.RPC; +import org.apache.hadoop.net.NetUtils; +import org.apache.hadoop.security.Credentials; +import org.apache.hadoop.security.SecurityUtil; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.security.token.Token; +import org.apache.log4j.Logger; +import org.apache.tez.common.TezCommonUtils; +import org.apache.tez.common.security.JobTokenIdentifier; +import org.apache.tez.common.security.TokenCache; +import org.apache.tez.dag.api.TezConstants; +import org.apache.tez.dag.api.TezException; +import org.apache.tez.mapreduce.input.MRInputLegacy; +import org.apache.tez.runtime.api.ExecutionContext; +import org.apache.tez.runtime.api.impl.InputSpec; +import org.apache.tez.runtime.api.impl.TaskSpec; +import org.apache.tez.runtime.common.objectregistry.ObjectRegistryImpl; +import org.apache.tez.runtime.internals.api.TaskReporterInterface; +import org.apache.tez.runtime.task.TezChild; +import org.apache.tez.runtime.task.TezTaskRunner; + +import com.google.common.base.Stopwatch; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.Multimap; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.common.util.concurrent.ThreadFactoryBuilder; + +/** + * + */ +public class TaskRunnerCallable extends CallableWithNdc { + private static final Logger LOG = Logger.getLogger(TaskRunnerCallable.class); + private final LlapDaemonProtocolProtos.SubmitWorkRequestProto request; + private final Configuration conf; + private final String[] localDirs; + private final Map envMap; + private final String pid = null; + private final ObjectRegistryImpl objectRegistry; + private final ExecutionContext executionContext; + private final Credentials credentials; + private final long memoryAvailable; + private final ConfParams confParams; + private final Token jobToken; + private final AMReporter amReporter; + private final ConcurrentMap sourceCompletionMap; + private final TaskSpec taskSpec; + private volatile TezTaskRunner taskRunner; + private volatile TaskReporterInterface taskReporter; + private volatile ListeningExecutorService executor; + private LlapTaskUmbilicalProtocol umbilical; + private volatile long startTime; + private volatile String threadName; + private LlapDaemonExecutorMetrics metrics; + protected String requestId; + + TaskRunnerCallable(LlapDaemonProtocolProtos.SubmitWorkRequestProto request, Configuration conf, + ExecutionContext executionContext, Map envMap, + String[] localDirs, Credentials credentials, + long memoryAvailable, AMReporter amReporter, + ConcurrentMap sourceCompletionMap, + ConfParams confParams, LlapDaemonExecutorMetrics metrics) { + this.request = request; + this.conf = conf; + this.executionContext = executionContext; + this.envMap = envMap; + this.localDirs = localDirs; + this.objectRegistry = new ObjectRegistryImpl(); + this.sourceCompletionMap = sourceCompletionMap; + this.credentials = credentials; + this.memoryAvailable = memoryAvailable; + this.confParams = confParams; + this.jobToken = TokenCache.getSessionToken(credentials); + this.taskSpec = Converters.getTaskSpecfromProto(request.getFragmentSpec()); + this.amReporter = amReporter; + // Register with the AMReporter when the callable is setup. Unregister once it starts running. + if (jobToken != null) { + this.amReporter.registerTask(request.getAmHost(), request.getAmPort(), + request.getUser(), jobToken); + } + this.metrics = metrics; + this.requestId = getTaskAttemptId(request); + } + + @Override + protected TezChild.ContainerExecutionResult callInternal() throws Exception { + this.startTime = System.currentTimeMillis(); + this.threadName = Thread.currentThread().getName(); + if (LOG.isDebugEnabled()) { + LOG.debug("canFinish: " + taskSpec.getTaskAttemptID() + ": " + canFinish()); + } + + // Unregister from the AMReporter, since the task is now running. + this.amReporter.unregisterTask(request.getAmHost(), request.getAmPort()); + + // TODO This executor seems unnecessary. Here and TezChild + ExecutorService executorReal = Executors.newFixedThreadPool(1, + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat( + "TezTaskRunner_" + request.getFragmentSpec().getTaskAttemptIdString()) + .build()); + executor = MoreExecutors.listeningDecorator(executorReal); + + // TODO Consolidate this code with TezChild. + Stopwatch sw = new Stopwatch().start(); + UserGroupInformation taskUgi = UserGroupInformation.createRemoteUser(request.getUser()); + taskUgi.addCredentials(credentials); + + Map serviceConsumerMetadata = new HashMap<>(); + serviceConsumerMetadata.put(TezConstants.TEZ_SHUFFLE_HANDLER_SERVICE_ID, + TezCommonUtils.convertJobTokenToBytes(jobToken)); + Multimap startedInputsMap = HashMultimap.create(); + + UserGroupInformation taskOwner = + UserGroupInformation.createRemoteUser(request.getTokenIdentifier()); + final InetSocketAddress address = + NetUtils.createSocketAddrForHost(request.getAmHost(), request.getAmPort()); + SecurityUtil.setTokenService(jobToken, address); + taskOwner.addToken(jobToken); + umbilical = taskOwner.doAs(new PrivilegedExceptionAction() { + @Override + public LlapTaskUmbilicalProtocol run() throws Exception { + return RPC.getProxy(LlapTaskUmbilicalProtocol.class, + LlapTaskUmbilicalProtocol.versionID, address, conf); + } + }); + + taskReporter = new LlapTaskReporter( + umbilical, + confParams.amHeartbeatIntervalMsMax, + confParams.amCounterHeartbeatInterval, + confParams.amMaxEventsPerHeartbeat, + new AtomicLong(0), + request.getContainerIdString()); + + taskRunner = new TezTaskRunner(conf, taskUgi, localDirs, + taskSpec, + request.getAppAttemptNumber(), + serviceConsumerMetadata, envMap, startedInputsMap, taskReporter, executor, objectRegistry, + pid, + executionContext, memoryAvailable); + + boolean shouldDie; + try { + shouldDie = !taskRunner.run(); + if (shouldDie) { + LOG.info("Got a shouldDie notification via heartbeats. Shutting down"); + return new TezChild.ContainerExecutionResult( + TezChild.ContainerExecutionResult.ExitStatus.SUCCESS, null, + "Asked to die by the AM"); + } + } catch (IOException e) { + return new TezChild.ContainerExecutionResult( + TezChild.ContainerExecutionResult.ExitStatus.EXECUTION_FAILURE, + e, "TaskExecutionFailure: " + e.getMessage()); + } catch (TezException e) { + return new TezChild.ContainerExecutionResult( + TezChild.ContainerExecutionResult.ExitStatus.EXECUTION_FAILURE, + e, "TaskExecutionFailure: " + e.getMessage()); + } finally { + // TODO Fix UGI and FS Handling. Closing UGI here causes some errors right now. + // FileSystem.closeAllForUGI(taskUgi); + } + LOG.info("ExecutionTime for Container: " + request.getContainerIdString() + "=" + + sw.stop().elapsedMillis()); + if (LOG.isDebugEnabled()) { + LOG.debug("canFinish post completion: " + taskSpec.getTaskAttemptID() + ": " + canFinish()); + } + + return new TezChild.ContainerExecutionResult( + TezChild.ContainerExecutionResult.ExitStatus.SUCCESS, null, + null); + } + + /** + * Check whether a task can run to completion or may end up blocking on it's sources. + * This currently happens via looking up source state. + * TODO: Eventually, this should lookup the Hive Processor to figure out whether + * it's reached a state where it can finish - especially in cases of failures + * after data has been fetched. + * + * @return + */ + public boolean canFinish() { + List inputSpecList = taskSpec.getInputs(); + boolean canFinish = true; + if (inputSpecList != null && !inputSpecList.isEmpty()) { + for (InputSpec inputSpec : inputSpecList) { + if (isSourceOfInterest(inputSpec)) { + // Lookup the state in the map. + LlapDaemonProtocolProtos.SourceStateProto state = sourceCompletionMap + .get(inputSpec.getSourceVertexName()); + if (state != null && state == LlapDaemonProtocolProtos.SourceStateProto.S_SUCCEEDED) { + continue; + } else { + if (LOG.isDebugEnabled()) { + LOG.debug("Cannot finish due to source: " + inputSpec.getSourceVertexName()); + } + canFinish = false; + break; + } + } + } + } + return canFinish; + } + + private boolean isSourceOfInterest(InputSpec inputSpec) { + String inputClassName = inputSpec.getInputDescriptor().getClassName(); + // MRInput is not of interest since it'll always be ready. + return !inputClassName.equals(MRInputLegacy.class.getName()); + } + + public void shutdown() { + if (executor != null) { + executor.shutdownNow(); + } + if (taskReporter != null) { + taskReporter.shutdown(); + } + if (umbilical != null) { + RPC.stopProxy(umbilical); + } + } + + @Override + public String toString() { + return requestId; + } + + @Override + public int hashCode() { + return requestId.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof TaskRunnerCallable)) { + return false; + } + return requestId.equals(((TaskRunnerCallable) obj).getRequestId()); + } + + public int getVertexParallelism() { + return request.getFragmentSpec().getVertexParallelism(); + } + + public String getRequestId() { + return requestId; + } + + public TaskRunnerCallback getCallback() { + return new TaskRunnerCallback(request, this); + } + + final class TaskRunnerCallback implements FutureCallback { + + private final LlapDaemonProtocolProtos.SubmitWorkRequestProto request; + private final TaskRunnerCallable taskRunnerCallable; + private final String requestId; + + TaskRunnerCallback(LlapDaemonProtocolProtos.SubmitWorkRequestProto request, + TaskRunnerCallable taskRunnerCallable) { + this.request = request; + this.taskRunnerCallable = taskRunnerCallable; + this.requestId = getTaskIdentifierString(request); + } + + public String getRequestId() { + return requestId; + } + + // TODO Slightly more useful error handling + @Override + public void onSuccess(TezChild.ContainerExecutionResult result) { + switch (result.getExitStatus()) { + case SUCCESS: + LOG.info("Successfully finished: " + requestId); + metrics.incrExecutorTotalSuccess(); + break; + case EXECUTION_FAILURE: + LOG.info("Failed to run: " + requestId); + metrics.incrExecutorTotalExecutionFailed(); + break; + case INTERRUPTED: + LOG.info("Interrupted while running: " + requestId); + metrics.incrExecutorTotalInterrupted(); + break; + case ASKED_TO_DIE: + LOG.info("Asked to die while running: " + requestId); + metrics.incrExecutorTotalAskedToDie(); + break; + } + taskRunnerCallable.shutdown(); + HistoryLogger + .logFragmentEnd(request.getApplicationIdString(), request.getContainerIdString(), + executionContext.getHostName(), request.getFragmentSpec().getDagName(), + request.getFragmentSpec().getVertexName(), + request.getFragmentSpec().getFragmentNumber(), + request.getFragmentSpec().getAttemptNumber(), taskRunnerCallable.threadName, + taskRunnerCallable.startTime, true); + metrics.decrExecutorNumQueuedRequests(); + } + + @Override + public void onFailure(Throwable t) { + LOG.error("TezTaskRunner execution failed for : " + getTaskIdentifierString(request), t); + // TODO HIVE-10236 Report a fatal error over the umbilical + taskRunnerCallable.shutdown(); + HistoryLogger + .logFragmentEnd(request.getApplicationIdString(), request.getContainerIdString(), + executionContext.getHostName(), request.getFragmentSpec().getDagName(), + request.getFragmentSpec().getVertexName(), + request.getFragmentSpec().getFragmentNumber(), + request.getFragmentSpec().getAttemptNumber(), taskRunnerCallable.threadName, + taskRunnerCallable.startTime, false); + if (metrics != null) { + metrics.decrExecutorNumQueuedRequests(); + } + } + } + + public static class ConfParams { + final int amHeartbeatIntervalMsMax; + final long amCounterHeartbeatInterval; + final int amMaxEventsPerHeartbeat; + + public ConfParams(int amHeartbeatIntervalMsMax, long amCounterHeartbeatInterval, + int amMaxEventsPerHeartbeat) { + this.amHeartbeatIntervalMsMax = amHeartbeatIntervalMsMax; + this.amCounterHeartbeatInterval = amCounterHeartbeatInterval; + this.amMaxEventsPerHeartbeat = amMaxEventsPerHeartbeat; + } + } + + public static String getTaskIdentifierString( + LlapDaemonProtocolProtos.SubmitWorkRequestProto request) { + StringBuilder sb = new StringBuilder(); + sb.append("AppId=").append(request.getApplicationIdString()) + .append(", containerId=").append(request.getContainerIdString()) + .append(", Dag=").append(request.getFragmentSpec().getDagName()) + .append(", Vertex=").append(request.getFragmentSpec().getVertexName()) + .append(", FragmentNum=").append(request.getFragmentSpec().getFragmentNumber()) + .append(", Attempt=").append(request.getFragmentSpec().getAttemptNumber()); + return sb.toString(); + } + + private String getTaskAttemptId(SubmitWorkRequestProto request) { + return request.getFragmentSpec().getTaskAttemptIdString(); + } +} diff --git a/llap-server/src/test/org/apache/hadoop/hive/llap/daemon/impl/TestTaskExecutorService.java b/llap-server/src/test/org/apache/hadoop/hive/llap/daemon/impl/TestTaskExecutorService.java new file mode 100644 index 0000000000000000000000000000000000000000..44a4633eb89feeca403b8d98d80b78e72cf739eb --- /dev/null +++ b/llap-server/src/test/org/apache/hadoop/hive/llap/daemon/impl/TestTaskExecutorService.java @@ -0,0 +1,306 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.llap.daemon.impl; + +import static org.junit.Assert.assertEquals; + +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.PriorityBlockingQueue; +import java.util.concurrent.RejectedExecutionException; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.llap.daemon.rpc.LlapDaemonProtocolProtos; +import org.apache.hadoop.hive.llap.daemon.rpc.LlapDaemonProtocolProtos.EntityDescriptorProto; +import org.apache.hadoop.hive.llap.daemon.rpc.LlapDaemonProtocolProtos.FragmentSpecProto; +import org.apache.hadoop.hive.llap.daemon.rpc.LlapDaemonProtocolProtos.SubmitWorkRequestProto; +import org.apache.hadoop.security.Credentials; +import org.apache.hadoop.yarn.api.records.ApplicationId; +import org.apache.tez.dag.records.TezDAGID; +import org.apache.tez.dag.records.TezTaskAttemptID; +import org.apache.tez.dag.records.TezTaskID; +import org.apache.tez.dag.records.TezVertexID; +import org.apache.tez.runtime.api.impl.ExecutionContextImpl; +import org.apache.tez.runtime.task.TezChild; +import org.apache.tez.runtime.task.TezChild.ContainerExecutionResult; +import org.apache.tez.runtime.task.TezChild.ContainerExecutionResult.ExitStatus; +import org.junit.Before; +import org.junit.Test; + +public class TestTaskExecutorService { + private static Configuration conf; + private static Credentials cred = new Credentials(); + + private static class MockRequest extends TaskRunnerCallable { + private int workTime; + private boolean canFinish; + + public MockRequest(LlapDaemonProtocolProtos.SubmitWorkRequestProto requestProto, + boolean canFinish, int workTime) { + super(requestProto, conf, new ExecutionContextImpl("localhost"), null, null, cred, 0, null, + null, null, null); + this.workTime = workTime; + this.canFinish = canFinish; + } + + @Override + protected TezChild.ContainerExecutionResult callInternal() throws Exception { + System.out.println(requestId + " is executing.."); + Thread.sleep(workTime); + return new ContainerExecutionResult(ExitStatus.SUCCESS, null, null); + } + + @Override + public boolean canFinish() { + return canFinish; + } + } + + @Before + public void setup() { + conf = new Configuration(); + } + + private SubmitWorkRequestProto createRequest(int fragmentNumber, int parallelism) { + ApplicationId appId = ApplicationId.newInstance(9999, 72); + TezDAGID dagId = TezDAGID.getInstance(appId, 1); + TezVertexID vId = TezVertexID.getInstance(dagId, 35); + TezTaskID tId = TezTaskID.getInstance(vId, 389); + TezTaskAttemptID taId = TezTaskAttemptID.getInstance(tId, fragmentNumber); + return SubmitWorkRequestProto + .newBuilder() + .setFragmentSpec( + FragmentSpecProto + .newBuilder() + .setAttemptNumber(0) + .setDagName("MockDag") + .setFragmentNumber(fragmentNumber) + .setVertexName("MockVertex") + .setVertexParallelism(parallelism) + .setProcessorDescriptor( + EntityDescriptorProto.newBuilder().setClassName("MockProcessor").build()) + .setTaskAttemptIdString(taId.toString()).build()).setAmHost("localhost") + .setAmPort(12345).setAppAttemptNumber(0).setApplicationIdString("MockApp_1") + .setContainerIdString("MockContainer_1").setUser("MockUser") + .setTokenIdentifier("MockToken_1").build(); + } + + + @Test(expected = RejectedExecutionException.class) + public void testThreadPoolRejection() throws InterruptedException { + TaskExecutorService scheduler = new TaskExecutorService(2, 2, false); + scheduler.schedule(new MockRequest(createRequest(1, 4), true, 1000)); + Thread.sleep(100); + scheduler.schedule(new MockRequest(createRequest(2, 4), true, 1000)); + Thread.sleep(100); + assertEquals(0, scheduler.getPreemptionListSize()); + scheduler.schedule(new MockRequest(createRequest(3, 4), true, 1000)); + Thread.sleep(100); + scheduler.schedule(new MockRequest(createRequest(4, 4), true, 1000)); + Thread.sleep(100); + assertEquals(0, scheduler.getPreemptionListSize()); + // this request should be rejected + scheduler.schedule(new MockRequest(createRequest(5, 8), true, 1000)); + } + + @Test + public void testPreemption() throws InterruptedException { + TaskExecutorService scheduler = new TaskExecutorService(2, 2, true); + scheduler.schedule(new MockRequest(createRequest(1, 4), false, 100000)); + Thread.sleep(100); + scheduler.schedule(new MockRequest(createRequest(2, 4), false, 100000)); + Thread.sleep(100); + assertEquals(2, scheduler.getPreemptionListSize()); + // these should invoke preemption + scheduler.schedule(new MockRequest(createRequest(3, 8), true, 1000)); + Thread.sleep(100); + scheduler.schedule(new MockRequest(createRequest(4, 8), true, 1000)); + Thread.sleep(100); + assertEquals(0, scheduler.getPreemptionListSize()); + } + + @Test + public void testPreemptionOrder() throws InterruptedException { + TaskExecutorService scheduler = new TaskExecutorService(2, 2, true); + MockRequest r1 = new MockRequest(createRequest(1, 4), false, 100000); + scheduler.schedule(r1); + Thread.sleep(100); + MockRequest r2 = new MockRequest(createRequest(2, 4), false, 100000); + scheduler.schedule(r2); + Thread.sleep(100); + assertEquals(r1, scheduler.getPreemptionTask()); + // these should invoke preemption + scheduler.schedule(new MockRequest(createRequest(3, 8), true, 1000)); + // wait till pre-emption to kick-in and complete + Thread.sleep(100); + assertEquals(r2, scheduler.getPreemptionTask()); + scheduler.schedule(new MockRequest(createRequest(4, 8), true, 1000)); + // wait till pre-emption to kick-in and complete + Thread.sleep(100); + assertEquals(0, scheduler.getPreemptionListSize()); + } + + @Test + public void testWaitQueueComparator() throws InterruptedException { + MockRequest r1 = new MockRequest(createRequest(1, 2), false, 100000); + MockRequest r2 = new MockRequest(createRequest(2, 4), false, 100000); + MockRequest r3 = new MockRequest(createRequest(3, 6), false, 1000000); + MockRequest r4 = new MockRequest(createRequest(4, 8), false, 1000000); + MockRequest r5 = new MockRequest(createRequest(5, 10), false, 1000000); + BlockingQueue queue = new BoundedPriorityBlockingQueue( + new TaskExecutorService.WaitQueueComparator(), 4); + queue.offer(r1); + assertEquals(r1, queue.peek()); + queue.offer(r2); + assertEquals(r1, queue.peek()); + queue.offer(r3); + assertEquals(r1, queue.peek()); + queue.offer(r4); + assertEquals(r1, queue.peek()); + assertEquals(false, queue.offer(r5)); + assertEquals(r1, queue.take()); + assertEquals(r2, queue.take()); + assertEquals(r3, queue.take()); + assertEquals(r4, queue.take()); + + r1 = new MockRequest(createRequest(1, 2), true, 100000); + r2 = new MockRequest(createRequest(2, 4), true, 100000); + r3 = new MockRequest(createRequest(3, 6), true, 1000000); + r4 = new MockRequest(createRequest(4, 8), true, 1000000); + r5 = new MockRequest(createRequest(5, 10), true, 1000000); + queue = new BoundedPriorityBlockingQueue( + new TaskExecutorService.WaitQueueComparator(), 4); + queue.offer(r1); + assertEquals(r1, queue.peek()); + queue.offer(r2); + assertEquals(r1, queue.peek()); + queue.offer(r3); + assertEquals(r1, queue.peek()); + queue.offer(r4); + assertEquals(r1, queue.peek()); + assertEquals(false, queue.offer(r5)); + assertEquals(r1, queue.take()); + assertEquals(r2, queue.take()); + assertEquals(r3, queue.take()); + assertEquals(r4, queue.take()); + + r1 = new MockRequest(createRequest(1, 1), true, 100000); + r2 = new MockRequest(createRequest(2, 1), false, 100000); + r3 = new MockRequest(createRequest(3, 1), true, 1000000); + r4 = new MockRequest(createRequest(4, 1), false, 1000000); + r5 = new MockRequest(createRequest(5, 10), true, 1000000); + queue = new BoundedPriorityBlockingQueue( + new TaskExecutorService.WaitQueueComparator(), 4); + queue.offer(r1); + assertEquals(r1, queue.peek()); + queue.offer(r2); + assertEquals(r1, queue.peek()); + queue.offer(r3); + assertEquals(r1, queue.peek()); + queue.offer(r4); + assertEquals(r1, queue.peek()); + assertEquals(false, queue.offer(r5)); + assertEquals(r1, queue.take()); + assertEquals(r3, queue.take()); + assertEquals(r4, queue.take()); + assertEquals(r2, queue.take()); + + r1 = new MockRequest(createRequest(1, 2), true, 100000); + r2 = new MockRequest(createRequest(2, 4), false, 100000); + r3 = new MockRequest(createRequest(3, 6), true, 1000000); + r4 = new MockRequest(createRequest(4, 8), false, 1000000); + r5 = new MockRequest(createRequest(5, 10), true, 1000000); + queue = new BoundedPriorityBlockingQueue( + new TaskExecutorService.WaitQueueComparator(), 4); + queue.offer(r1); + assertEquals(r1, queue.peek()); + queue.offer(r2); + assertEquals(r1, queue.peek()); + queue.offer(r3); + assertEquals(r1, queue.peek()); + queue.offer(r4); + assertEquals(r1, queue.peek()); + assertEquals(false, queue.offer(r5)); + assertEquals(r1, queue.take()); + assertEquals(r3, queue.take()); + assertEquals(r2, queue.take()); + assertEquals(r4, queue.take()); + + r1 = new MockRequest(createRequest(1, 2), true, 100000); + r2 = new MockRequest(createRequest(2, 4), false, 100000); + r3 = new MockRequest(createRequest(3, 6), false, 1000000); + r4 = new MockRequest(createRequest(4, 8), false, 1000000); + r5 = new MockRequest(createRequest(5, 10), true, 1000000); + queue = new BoundedPriorityBlockingQueue( + new TaskExecutorService.WaitQueueComparator(), 4); + queue.offer(r1); + assertEquals(r1, queue.peek()); + queue.offer(r2); + assertEquals(r1, queue.peek()); + queue.offer(r3); + assertEquals(r1, queue.peek()); + queue.offer(r4); + assertEquals(r1, queue.peek()); + assertEquals(false, queue.offer(r5)); + assertEquals(r1, queue.take()); + assertEquals(r2, queue.take()); + assertEquals(r3, queue.take()); + assertEquals(r4, queue.take()); + + r1 = new MockRequest(createRequest(1, 2), false, 100000); + r2 = new MockRequest(createRequest(2, 4), true, 100000); + r3 = new MockRequest(createRequest(3, 6), true, 1000000); + r4 = new MockRequest(createRequest(4, 8), true, 1000000); + r5 = new MockRequest(createRequest(5, 10), true, 1000000); + queue = new BoundedPriorityBlockingQueue( + new TaskExecutorService.WaitQueueComparator(), 4); + queue.offer(r1); + assertEquals(r1, queue.peek()); + queue.offer(r2); + assertEquals(r2, queue.peek()); + queue.offer(r3); + assertEquals(r2, queue.peek()); + queue.offer(r4); + assertEquals(r2, queue.peek()); + assertEquals(false, queue.offer(r5)); + assertEquals(r2, queue.take()); + assertEquals(r3, queue.take()); + assertEquals(r4, queue.take()); + assertEquals(r1, queue.take()); + } + + @Test + public void testPreemptionQueueComparator() throws InterruptedException { + MockRequest r1 = new MockRequest(createRequest(1, 2), false, 100000); + MockRequest r2 = new MockRequest(createRequest(2, 4), false, 100000); + MockRequest r3 = new MockRequest(createRequest(3, 6), false, 1000000); + MockRequest r4 = new MockRequest(createRequest(4, 8), false, 1000000); + BlockingQueue queue = new PriorityBlockingQueue(4, + new TaskExecutorService.PreemptionQueueComparator()); + queue.offer(r1); + assertEquals(r1, queue.peek()); + queue.offer(r2); + assertEquals(r1, queue.peek()); + queue.offer(r3); + assertEquals(r1, queue.peek()); + queue.offer(r4); + assertEquals(r1, queue.take()); + assertEquals(r2, queue.take()); + assertEquals(r3, queue.take()); + assertEquals(r4, queue.take()); + } +}