diff --git llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/TaskCommunicator.java llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/LlapDaemonProtocolClientProxy.java similarity index 98% rename from llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/TaskCommunicator.java rename to llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/LlapDaemonProtocolClientProxy.java index f9ca677..2884e40 100644 --- llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/TaskCommunicator.java +++ llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/LlapDaemonProtocolClientProxy.java @@ -16,8 +16,6 @@ import javax.net.SocketFactory; -import java.io.ByteArrayInputStream; -import java.io.DataInputStream; import java.io.IOException; import java.security.PrivilegedAction; import java.util.HashSet; @@ -49,7 +47,6 @@ import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.conf.HiveConf.ConfVars; import org.apache.hadoop.hive.llap.LlapNodeId; -import org.apache.hadoop.hive.llap.configuration.LlapConfiguration; import org.apache.hadoop.hive.llap.daemon.LlapDaemonProtocolBlockingPB; import org.apache.hadoop.hive.llap.daemon.impl.LlapDaemonProtocolClientImpl; import org.apache.hadoop.hive.llap.daemon.rpc.LlapDaemonProtocolProtos.QueryCompleteRequestProto; @@ -71,9 +68,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class TaskCommunicator extends AbstractService { +public class LlapDaemonProtocolClientProxy extends AbstractService { - private static final Logger LOG = LoggerFactory.getLogger(TaskCommunicator.class); + private static final Logger LOG = LoggerFactory.getLogger(LlapDaemonProtocolClientProxy.class); private final ConcurrentMap hostProxies; @@ -85,9 +82,9 @@ private volatile ListenableFuture requestManagerFuture; private final Token llapToken; - public TaskCommunicator( + public LlapDaemonProtocolClientProxy( int numThreads, Configuration conf, Token llapToken) { - 
super(TaskCommunicator.class.getSimpleName()); + super(LlapDaemonProtocolClientProxy.class.getSimpleName()); this.hostProxies = new ConcurrentHashMap<>(); this.socketFactory = NetUtils.getDefaultSocketFactory(conf); this.llapToken = llapToken; diff --git llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/LlapTaskCommunicator.java llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/LlapTaskCommunicator.java index ce248e9..69aa1f5 100644 --- llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/LlapTaskCommunicator.java +++ llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/LlapTaskCommunicator.java @@ -25,6 +25,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.collect.BiMap; import com.google.common.collect.HashBiMap; @@ -90,7 +91,7 @@ private final SourceStateTracker sourceStateTracker; private final Set nodesForQuery = new HashSet<>(); - private TaskCommunicator communicator; + private LlapDaemonProtocolClientProxy communicator; private long deleteDelayOnDagComplete; private final LlapTaskUmbilicalProtocol umbilical; private final Token token; @@ -133,7 +134,7 @@ public void initialize() throws Exception { super.initialize(); Configuration conf = getConf(); int numThreads = HiveConf.getIntVar(conf, ConfVars.LLAP_DAEMON_COMMUNICATOR_NUM_THREADS); - this.communicator = new TaskCommunicator(numThreads, conf, token); + this.communicator = new LlapDaemonProtocolClientProxy(numThreads, conf, token); this.deleteDelayOnDagComplete = HiveConf.getTimeVar( conf, ConfVars.LLAP_FILE_CLEANUP_DELAY_SECONDS, TimeUnit.SECONDS); LOG.info("Running LlapTaskCommunicator with " @@ -256,7 +257,7 @@ public void registerRunningTaskAttempt(final ContainerId containerId, final Task getContext() .taskStartedRemotely(taskSpec.getTaskAttemptID(), containerId); 
communicator.sendSubmitWork(requestProto, host, port, - new TaskCommunicator.ExecuteRequestCallback() { + new LlapDaemonProtocolClientProxy.ExecuteRequestCallback() { @Override public void setResponse(SubmitWorkResponseProto response) { LOG.info("Successfully launched task: " + taskSpec.getTaskAttemptID()); @@ -330,14 +331,14 @@ private void sendTaskTerminated(final TezTaskAttemptID taskAttemptId, LOG.info( "DBG: Attempting to send terminateRequest for fragment {} due to internal preemption invoked by {}", taskAttemptId.toString(), invokedByContainerEnd ? "containerEnd" : "taskEnd"); - LlapNodeId nodeId = entityTracker.getNodeIfForTaskAttempt(taskAttemptId); + LlapNodeId nodeId = entityTracker.getNodeIdForTaskAttempt(taskAttemptId); // NodeId can be null if the task gets unregistered due to failure / being killed by the daemon itself if (nodeId != null) { TerminateFragmentRequestProto request = TerminateFragmentRequestProto.newBuilder().setDagName(currentDagName) .setFragmentIdentifierString(taskAttemptId.toString()).build(); communicator.sendTerminateFragment(request, nodeId.getHostname(), nodeId.getPort(), - new TaskCommunicator.ExecuteRequestCallback() { + new LlapDaemonProtocolClientProxy.ExecuteRequestCallback() { @Override public void setResponse(TerminateFragmentResponseProto response) { } @@ -362,7 +363,7 @@ public void dagComplete(final String dagName) { for (final LlapNodeId llapNodeId : nodesForQuery) { LOG.info("Sending dagComplete message for {}, to {}", dagName, llapNodeId); communicator.sendQueryComplete(request, llapNodeId.getHostname(), llapNodeId.getPort(), - new TaskCommunicator.ExecuteRequestCallback() { + new LlapDaemonProtocolClientProxy.ExecuteRequestCallback() { @Override public void setResponse(LlapDaemonProtocolProtos.QueryCompleteResponseProto response) { } @@ -388,7 +389,7 @@ public void onVertexStateUpdated(VertexStateUpdate vertexStateUpdate) { public void sendStateUpdate(final String host, final int port, final 
SourceStateUpdatedRequestProto request) { communicator.sendSourceStateUpdate(request, host, port, - new TaskCommunicator.ExecuteRequestCallback() { + new LlapDaemonProtocolClientProxy.ExecuteRequestCallback() { @Override public void setResponse(SourceStateUpdatedResponseProto response) { } @@ -405,7 +406,26 @@ public void indicateError(Throwable t) { }); } - + private final AtomicLong nodeNotFoundLogTime = new AtomicLong(0); + void nodePinged(String hostname, int port) { + LlapNodeId nodeId = LlapNodeId.getInstance(hostname, port); + BiMap biMap = entityTracker.getContainerAttemptMapForNode(hostname, port); + if (biMap != null) { + synchronized(biMap) { + for (Map.Entry entry : biMap.entrySet()) { + getContext().taskAlive(entry.getValue()); + getContext().containerAlive(entry.getKey()); + } + } + } else { + if (System.currentTimeMillis() > nodeNotFoundLogTime.get() + 5000l) { + LOG.warn("Received ping from unknown node: " + hostname + ":" + port + + ". Could be caused by pre-emption by the AM," + + " or a mismatched hostname. Enable debug logging for mismatched host names"); + nodeNotFoundLogTime.set(System.currentTimeMillis()); + } + } + } private void resetCurrentDag(String newDagName) { // Working on the assumption that a single DAG runs at a time per AM. 
@@ -451,6 +471,8 @@ private ByteBuffer serializeCredentials(Credentials credentials) throws IOExcept return ByteBuffer.wrap(containerTokens_dob.getData(), 0, containerTokens_dob.getLength()); } + + protected class LlapTaskUmbilicalProtocolImpl implements LlapTaskUmbilicalProtocol { private final TezTaskUmbilicalProtocol tezUmbilical; @@ -472,7 +494,7 @@ public TezHeartbeatResponse heartbeat(TezHeartbeatRequest request) throws IOExce @Override public void nodeHeartbeat(Text hostname, int port) throws IOException { - entityTracker.nodePinged(hostname.toString(), port); + nodePinged(hostname.toString(), port); if (LOG.isDebugEnabled()) { LOG.debug("Received heartbeat from [" + hostname + ":" + port +"]"); } @@ -499,10 +521,17 @@ public ProtocolSignature getProtocolSignature(String protocol, long clientVersio } } - private final class EntityTracker { - private final ConcurrentMap attemptToNodeMap = new ConcurrentHashMap<>(); - private final ConcurrentMap containerToNodeMap = new ConcurrentHashMap<>(); - private final ConcurrentMap> nodeMap = new ConcurrentHashMap<>(); + /** + * Track the association between known containers and taskAttempts, along with the nodes they are assigned to. + */ + @VisibleForTesting + static final class EntityTracker { + @VisibleForTesting + final ConcurrentMap attemptToNodeMap = new ConcurrentHashMap<>(); + @VisibleForTesting + final ConcurrentMap containerToNodeMap = new ConcurrentHashMap<>(); + @VisibleForTesting + final ConcurrentMap> nodeMap = new ConcurrentHashMap<>(); void registerTaskAttempt(ContainerId containerId, TezTaskAttemptID taskAttemptId, String host, int port) { if (LOG.isDebugEnabled()) { @@ -510,6 +539,10 @@ void registerTaskAttempt(ContainerId containerId, TezTaskAttemptID taskAttemptId } LlapNodeId llapNodeId = LlapNodeId.getInstance(host, port); attemptToNodeMap.putIfAbsent(taskAttemptId, llapNodeId); + + registerContainer(containerId, host, port); + + // nodeMap registration. 
BiMap tmpMap = HashBiMap.create(); BiMap old = nodeMap.putIfAbsent(llapNodeId, tmpMap); BiMap usedInstance; @@ -535,10 +568,9 @@ void unregisterTaskAttempt(TezTaskAttemptID attemptId) { synchronized(bMap) { matched = bMap.inverse().remove(attemptId); } - } - // Removing here. Registration into the map has to make sure to put - if (bMap.isEmpty()) { - nodeMap.remove(llapNodeId); + if (bMap.isEmpty()) { + nodeMap.remove(llapNodeId); + } } // Remove the container mapping @@ -549,23 +581,29 @@ void unregisterTaskAttempt(TezTaskAttemptID attemptId) { } void registerContainer(ContainerId containerId, String hostname, int port) { + if (LOG.isDebugEnabled()) { + LOG.debug("Registering " + containerId + " for node: " + hostname + ":" + port); + } containerToNodeMap.putIfAbsent(containerId, LlapNodeId.getInstance(hostname, port)); + // nodeMap registration is not required, since there's no taskId association. } LlapNodeId getNodeIdForContainer(ContainerId containerId) { return containerToNodeMap.get(containerId); } - LlapNodeId getNodeIfForTaskAttempt(TezTaskAttemptID taskAttemptId) { + LlapNodeId getNodeIdForTaskAttempt(TezTaskAttemptID taskAttemptId) { return attemptToNodeMap.get(taskAttemptId); } ContainerId getContainerIdForAttempt(TezTaskAttemptID taskAttemptId) { - LlapNodeId llapNodeId = getNodeIfForTaskAttempt(taskAttemptId); + LlapNodeId llapNodeId = getNodeIdForTaskAttempt(taskAttemptId); if (llapNodeId != null) { BiMap bMap = nodeMap.get(llapNodeId).inverse(); if (bMap != null) { - return bMap.get(taskAttemptId); + synchronized (bMap) { + return bMap.get(taskAttemptId); + } } else { return null; } @@ -579,7 +617,9 @@ TezTaskAttemptID getTaskAttemptIdForContainer(ContainerId containerId) { if (llapNodeId != null) { BiMap bMap = nodeMap.get(llapNodeId); if (bMap != null) { - return bMap.get(containerId); + synchronized (bMap) { + return bMap.get(containerId); + } } else { return null; } @@ -601,10 +641,9 @@ void unregisterContainer(ContainerId containerId) { 
synchronized(bMap) { matched = bMap.remove(containerId); } - } - // Removing here. Registration into the map has to make sure to put - if (bMap.isEmpty()) { - nodeMap.remove(llapNodeId); + if (bMap.isEmpty()) { + nodeMap.remove(llapNodeId); + } } // Remove the container mapping @@ -613,25 +652,20 @@ void unregisterContainer(ContainerId containerId) { } } - private final AtomicLong nodeNotFoundLogTime = new AtomicLong(0); - void nodePinged(String hostname, int port) { - LlapNodeId nodeId = LlapNodeId.getInstance(hostname, port); - BiMap biMap = nodeMap.get(nodeId); - if (biMap != null) { - synchronized(biMap) { - for (Map.Entry entry : biMap.entrySet()) { - getContext().taskAlive(entry.getValue()); - getContext().containerAlive(entry.getKey()); - } - } - } else { - if (System.currentTimeMillis() > nodeNotFoundLogTime.get() + 5000l) { - LOG.warn("Received ping from unknown node: " + hostname + ":" + port + - ". Could be caused by pre-emption by the AM," + - " or a mismatched hostname. Enable debug logging for mismatched host names"); - nodeNotFoundLogTime.set(System.currentTimeMillis()); - } - } + /** + * Return a {@link BiMap} containing container->taskAttemptId mapping for the host specified.

+ * + * This method returns the internal structure used by the EntityTracker. Users must synchronize + * on the structure to ensure correct usage. + * + * @param hostname hostname of the node + * @param port port of the node + * @return the container to taskAttemptId map for the node, or null if the node is not tracked + */ + BiMap getContainerAttemptMapForNode(String hostname, int port) { + LlapNodeId llapNodeId = LlapNodeId.getInstance(hostname, port); + BiMap biMap = nodeMap.get(llapNodeId); + return biMap; } } -} \ No newline at end of file +} diff --git llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestTaskCommunicator.java llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestLlapDaemonProtocolClientProxy.java similarity index 86% rename from llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestTaskCommunicator.java rename to llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestLlapDaemonProtocolClientProxy.java index 2aef4ed..a6af8c2 100644 --- llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestTaskCommunicator.java +++ llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestLlapDaemonProtocolClientProxy.java @@ -28,7 +28,7 @@ import org.apache.hadoop.hive.llap.LlapNodeId; import org.junit.Test; -public class TestTaskCommunicator { +public class TestLlapDaemonProtocolClientProxy { @Test (timeout = 5000) public void testMultipleNodes() { @@ -38,8 +38,8 @@ public void testMultipleNodes() { LlapNodeId nodeId2 = LlapNodeId.getInstance("host2", 1025); Message mockMessage = mock(Message.class); - TaskCommunicator.ExecuteRequestCallback mockExecuteRequestCallback = mock( - TaskCommunicator.ExecuteRequestCallback.class); + LlapDaemonProtocolClientProxy.ExecuteRequestCallback mockExecuteRequestCallback = mock( + LlapDaemonProtocolClientProxy.ExecuteRequestCallback.class); // Request two messages requestManager.queueRequest( @@ -66,8 +66,8 @@ public void testSingleInvocationPerNode() { LlapNodeId nodeId1 = LlapNodeId.getInstance("host1", 1025); Message mockMessage = mock(Message.class); - 
TaskCommunicator.ExecuteRequestCallback mockExecuteRequestCallback = mock( - TaskCommunicator.ExecuteRequestCallback.class); + LlapDaemonProtocolClientProxy.ExecuteRequestCallback mockExecuteRequestCallback = mock( + LlapDaemonProtocolClientProxy.ExecuteRequestCallback.class); // First request for host. requestManager.queueRequest( @@ -101,7 +101,7 @@ public void testSingleInvocationPerNode() { } - static class RequestManagerForTest extends TaskCommunicator.RequestManager { + static class RequestManagerForTest extends LlapDaemonProtocolClientProxy.RequestManager { int numSubmissionsCounters = 0; private Map numInvocationsPerNode = new HashMap<>(); @@ -110,7 +110,7 @@ public RequestManagerForTest(int numThreads) { super(numThreads); } - protected void submitToExecutor(TaskCommunicator.CallableRequest request, LlapNodeId nodeId) { + protected void submitToExecutor(LlapDaemonProtocolClientProxy.CallableRequest request, LlapNodeId nodeId) { numSubmissionsCounters++; MutableInt nodeCount = numInvocationsPerNode.get(nodeId); if (nodeCount == null) { @@ -127,10 +127,10 @@ void reset() { } - static class CallableRequestForTest extends TaskCommunicator.CallableRequest { + static class CallableRequestForTest extends LlapDaemonProtocolClientProxy.CallableRequest { protected CallableRequestForTest(LlapNodeId nodeId, Message message, - TaskCommunicator.ExecuteRequestCallback callback) { + LlapDaemonProtocolClientProxy.ExecuteRequestCallback callback) { super(nodeId, message, callback); } diff --git llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestLlapTaskCommunicator.java llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestLlapTaskCommunicator.java new file mode 100644 index 0000000..f02a3d7 --- /dev/null +++ llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestLlapTaskCommunicator.java @@ -0,0 +1,100 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.llap.tezplugins; + + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; + +import org.apache.hadoop.hive.llap.LlapNodeId; +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.tez.dag.records.TezTaskAttemptID; +import org.junit.Test; + +public class TestLlapTaskCommunicator { + + @Test (timeout = 5000) + public void testEntityTracker1() { + LlapTaskCommunicator.EntityTracker entityTracker = new LlapTaskCommunicator.EntityTracker(); + + String host1 = "host1"; + String host2 = "host2"; + String host3 = "host3"; + int port = 1451; + + + // Simple container registration and un-registration without any task attempt being involved. + ContainerId containerId101 = constructContainerId(101); + entityTracker.registerContainer(containerId101, host1, port); + assertEquals(LlapNodeId.getInstance(host1, port), entityTracker.getNodeIdForContainer(containerId101)); + + entityTracker.unregisterContainer(containerId101); + assertNull(entityTracker.getContainerAttemptMapForNode(host1, port)); + assertNull(entityTracker.getNodeIdForContainer(containerId101)); + assertEquals(0, entityTracker.nodeMap.size()); + assertEquals(0, entityTracker.attemptToNodeMap.size()); + assertEquals(0, entityTracker.containerToNodeMap.size()); + + + // Simple task registration and un-registration. 
+ ContainerId containerId1 = constructContainerId(1); + TezTaskAttemptID taskAttemptId1 = constructTaskAttemptId(1); + entityTracker.registerTaskAttempt(containerId1, taskAttemptId1, host1, port); + assertEquals(LlapNodeId.getInstance(host1, port), entityTracker.getNodeIdForContainer(containerId1)); + assertEquals(LlapNodeId.getInstance(host1, port), entityTracker.getNodeIdForTaskAttempt(taskAttemptId1)); + + entityTracker.unregisterTaskAttempt(taskAttemptId1); + assertNull(entityTracker.getContainerAttemptMapForNode(host1, port)); + assertNull(entityTracker.getNodeIdForContainer(containerId1)); + assertNull(entityTracker.getNodeIdForTaskAttempt(taskAttemptId1)); + assertEquals(0, entityTracker.nodeMap.size()); + assertEquals(0, entityTracker.attemptToNodeMap.size()); + assertEquals(0, entityTracker.containerToNodeMap.size()); + + // Register taskAttempt, unregister container. TaskAttempt should also be unregistered + ContainerId containerId201 = constructContainerId(201); + TezTaskAttemptID taskAttemptId201 = constructTaskAttemptId(201); + entityTracker.registerTaskAttempt(containerId201, taskAttemptId201, host1, port); + assertEquals(LlapNodeId.getInstance(host1, port), entityTracker.getNodeIdForContainer(containerId201)); + assertEquals(LlapNodeId.getInstance(host1, port), entityTracker.getNodeIdForTaskAttempt(taskAttemptId201)); + + entityTracker.unregisterContainer(containerId201); + assertNull(entityTracker.getContainerAttemptMapForNode(host1, port)); + assertNull(entityTracker.getNodeIdForContainer(containerId201)); + assertNull(entityTracker.getNodeIdForTaskAttempt(taskAttemptId201)); + assertEquals(0, entityTracker.nodeMap.size()); + assertEquals(0, entityTracker.attemptToNodeMap.size()); + assertEquals(0, entityTracker.containerToNodeMap.size()); + + entityTracker.unregisterTaskAttempt(taskAttemptId201); // No errors + } + + + private ContainerId constructContainerId(int id) { + ContainerId containerId = mock(ContainerId.class); + 
doReturn(id).when(containerId).getId(); + doReturn((long)id).when(containerId).getContainerId(); + return containerId; + } + + private TezTaskAttemptID constructTaskAttemptId(int id) { + TezTaskAttemptID taskAttemptId = mock(TezTaskAttemptID.class); + doReturn(id).when(taskAttemptId).getId(); + return taskAttemptId; + } + +}