diff --git llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/TaskCommunicator.java llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/LlapDaemonProtocolClientProxy.java
similarity index 98%
rename from llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/TaskCommunicator.java
rename to llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/LlapDaemonProtocolClientProxy.java
index f9ca677..2884e40 100644
--- llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/TaskCommunicator.java
+++ llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/LlapDaemonProtocolClientProxy.java
@@ -16,8 +16,6 @@
import javax.net.SocketFactory;
-import java.io.ByteArrayInputStream;
-import java.io.DataInputStream;
import java.io.IOException;
import java.security.PrivilegedAction;
import java.util.HashSet;
@@ -49,7 +47,6 @@
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.conf.HiveConf.ConfVars;
import org.apache.hadoop.hive.llap.LlapNodeId;
-import org.apache.hadoop.hive.llap.configuration.LlapConfiguration;
import org.apache.hadoop.hive.llap.daemon.LlapDaemonProtocolBlockingPB;
import org.apache.hadoop.hive.llap.daemon.impl.LlapDaemonProtocolClientImpl;
import org.apache.hadoop.hive.llap.daemon.rpc.LlapDaemonProtocolProtos.QueryCompleteRequestProto;
@@ -71,9 +68,9 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-public class TaskCommunicator extends AbstractService {
+public class LlapDaemonProtocolClientProxy extends AbstractService {
- private static final Logger LOG = LoggerFactory.getLogger(TaskCommunicator.class);
+ private static final Logger LOG = LoggerFactory.getLogger(LlapDaemonProtocolClientProxy.class);
private final ConcurrentMap hostProxies;
@@ -85,9 +82,9 @@
private volatile ListenableFuture requestManagerFuture;
private final Token llapToken;
- public TaskCommunicator(
+ public LlapDaemonProtocolClientProxy(
int numThreads, Configuration conf, Token llapToken) {
- super(TaskCommunicator.class.getSimpleName());
+ super(LlapDaemonProtocolClientProxy.class.getSimpleName());
this.hostProxies = new ConcurrentHashMap<>();
this.socketFactory = NetUtils.getDefaultSocketFactory(conf);
this.llapToken = llapToken;
diff --git llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/LlapTaskCommunicator.java llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/LlapTaskCommunicator.java
index ce248e9..c0c02c9 100644
--- llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/LlapTaskCommunicator.java
+++ llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/LlapTaskCommunicator.java
@@ -23,8 +23,10 @@
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
@@ -81,6 +83,9 @@
private static final Logger LOG = LoggerFactory.getLogger(LlapTaskCommunicator.class);
+ private static final boolean isInfoEnabled = LOG.isInfoEnabled();
+ private static final boolean isDebugEnabled = LOG.isDebugEnabled();
+
private final SubmitWorkRequestProto BASE_SUBMIT_WORK_REQUEST;
private final ConcurrentMap credentialMap;
@@ -90,11 +95,17 @@
private final SourceStateTracker sourceStateTracker;
private final Set<LlapNodeId> nodesForQuery = new HashSet<>();
- private TaskCommunicator communicator;
+ private LlapDaemonProtocolClientProxy communicator;
private long deleteDelayOnDagComplete;
private final LlapTaskUmbilicalProtocol umbilical;
private final Token token;
+ // These two structures track the list of known nodes, and the list of nodes which are sending in keep-alive heartbeats.
+ // Primarily for debugging purposes at the moment, since there are some unexplained TASK_TIMEOUTs currently being observed.
+ private final ConcurrentMap<LlapNodeId, Long> knownNodeMap = new ConcurrentHashMap<>();
+ private final ConcurrentMap<LlapNodeId, PingingNodeInfo> pingedNodeMap = new ConcurrentHashMap<>();
+
+
private volatile String currentDagName;
public LlapTaskCommunicator(
@@ -133,7 +144,7 @@ public void initialize() throws Exception {
super.initialize();
Configuration conf = getConf();
int numThreads = HiveConf.getIntVar(conf, ConfVars.LLAP_DAEMON_COMMUNICATOR_NUM_THREADS);
- this.communicator = new TaskCommunicator(numThreads, conf, token);
+ this.communicator = new LlapDaemonProtocolClientProxy(numThreads, conf, token);
this.deleteDelayOnDagComplete = HiveConf.getTimeVar(
conf, ConfVars.LLAP_FILE_CLEANUP_DELAY_SECONDS, TimeUnit.SECONDS);
LOG.info("Running LlapTaskCommunicator with "
@@ -237,6 +248,7 @@ public void registerRunningTaskAttempt(final ContainerId containerId, final Task
}
LlapNodeId nodeId = LlapNodeId.getInstance(host, port);
+ registerKnownNode(nodeId);
entityTracker.registerTaskAttempt(containerId, taskSpec.getTaskAttemptID(), host, port);
nodesForQuery.add(nodeId);
@@ -256,7 +268,7 @@ public void registerRunningTaskAttempt(final ContainerId containerId, final Task
getContext()
.taskStartedRemotely(taskSpec.getTaskAttemptID(), containerId);
communicator.sendSubmitWork(requestProto, host, port,
- new TaskCommunicator.ExecuteRequestCallback<SubmitWorkResponseProto>() {
+ new LlapDaemonProtocolClientProxy.ExecuteRequestCallback<SubmitWorkResponseProto>() {
@Override
public void setResponse(SubmitWorkResponseProto response) {
LOG.info("Successfully launched task: " + taskSpec.getTaskAttemptID());
@@ -330,14 +342,14 @@ private void sendTaskTerminated(final TezTaskAttemptID taskAttemptId,
LOG.info(
"DBG: Attempting to send terminateRequest for fragment {} due to internal preemption invoked by {}",
taskAttemptId.toString(), invokedByContainerEnd ? "containerEnd" : "taskEnd");
- LlapNodeId nodeId = entityTracker.getNodeIfForTaskAttempt(taskAttemptId);
+ LlapNodeId nodeId = entityTracker.getNodeIdForTaskAttempt(taskAttemptId);
// NodeId can be null if the task gets unregistered due to failure / being killed by the daemon itself
if (nodeId != null) {
TerminateFragmentRequestProto request =
TerminateFragmentRequestProto.newBuilder().setDagName(currentDagName)
.setFragmentIdentifierString(taskAttemptId.toString()).build();
communicator.sendTerminateFragment(request, nodeId.getHostname(), nodeId.getPort(),
- new TaskCommunicator.ExecuteRequestCallback<TerminateFragmentResponseProto>() {
+ new LlapDaemonProtocolClientProxy.ExecuteRequestCallback<TerminateFragmentResponseProto>() {
@Override
public void setResponse(TerminateFragmentResponseProto response) {
}
@@ -362,7 +374,7 @@ public void dagComplete(final String dagName) {
for (final LlapNodeId llapNodeId : nodesForQuery) {
LOG.info("Sending dagComplete message for {}, to {}", dagName, llapNodeId);
communicator.sendQueryComplete(request, llapNodeId.getHostname(), llapNodeId.getPort(),
- new TaskCommunicator.ExecuteRequestCallback<LlapDaemonProtocolProtos.QueryCompleteResponseProto>() {
+ new LlapDaemonProtocolClientProxy.ExecuteRequestCallback<LlapDaemonProtocolProtos.QueryCompleteResponseProto>() {
@Override
public void setResponse(LlapDaemonProtocolProtos.QueryCompleteResponseProto response) {
}
@@ -388,7 +400,7 @@ public void onVertexStateUpdated(VertexStateUpdate vertexStateUpdate) {
public void sendStateUpdate(final String host, final int port,
final SourceStateUpdatedRequestProto request) {
communicator.sendSourceStateUpdate(request, host, port,
- new TaskCommunicator.ExecuteRequestCallback<SourceStateUpdatedResponseProto>() {
+ new LlapDaemonProtocolClientProxy.ExecuteRequestCallback<SourceStateUpdatedResponseProto>() {
@Override
public void setResponse(SourceStateUpdatedResponseProto response) {
}
@@ -406,6 +418,79 @@ public void indicateError(Throwable t) {
}
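+ // Per-node bookkeeping for keep-alive pings: the number of pings received from the node,
+ // and the timestamp of the last warning logged for it (used to rate-limit the warnings below).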
+ private static class PingingNodeInfo {
+ final AtomicLong logTimestamp;
+ final AtomicInteger pingCount;
+
+ PingingNodeInfo(long currentTs) {
+ logTimestamp = new AtomicLong(currentTs);
+ pingCount = new AtomicInteger(1);
+ }
+ }
+
+ public void registerKnownNode(LlapNodeId nodeId) {
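+ // Timestamps are derived from System.nanoTime() (converted to milliseconds) so that this
+ // bookkeeping is monotonic and unaffected by wall-clock adjustments.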
+ Long old = knownNodeMap.putIfAbsent(nodeId,
+ TimeUnit.MILLISECONDS.convert(System.nanoTime(), TimeUnit.NANOSECONDS));
+ if (old == null) {
+ if (isInfoEnabled) {
+ LOG.info("Added new known node: {}", nodeId);
+ }
+ }
+ }
+
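+ // Record a keep-alive ping from a node, and warn (rate-limited) if the node was never
+ // registered via registerKnownNode.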
+ public void registerPingingNode(LlapNodeId nodeId) {
+ long currentTs = TimeUnit.MILLISECONDS.convert(System.nanoTime(), TimeUnit.NANOSECONDS);
+ PingingNodeInfo ni = new PingingNodeInfo(currentTs);
+ PingingNodeInfo old = pingedNodeMap.put(nodeId, ni);
+ if (old == null) {
+ if (isInfoEnabled) {
+ LOG.info("Added new pinging node: [{}]", nodeId);
+ }
+ } else {
+ old.pingCount.incrementAndGet();
+ }
+ // The node should always be known by this point. Log occasionally if it is not known.
+ if (!knownNodeMap.containsKey(nodeId)) {
+ if (old == null) {
+ // First time this is seen. Log it.
+ LOG.warn("Received ping from unknownNode: [{}], count={}", nodeId, ni.pingCount.get());
+ } else {
+ // Pinged before. Log only occasionally.
+ if (currentTs > old.logTimestamp.get() + 5000L) { // 5 seconds elapsed. Log again.
+ LOG.warn("Received ping from unknownNode: [{}], count={}", nodeId, old.pingCount.get());
+ old.logTimestamp.set(currentTs);
+ }
+ }
+
+ }
+ }
+
+
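+ // Timestamp of the last "ping from node without registered tasks or containers" warning,
+ // used to limit that warning to roughly once every five seconds.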
+ private final AtomicLong nodeNotFoundLogTime = new AtomicLong(0);
+
+ void nodePinged(String hostname, int port) {
+ LlapNodeId nodeId = LlapNodeId.getInstance(hostname, port);
+ registerPingingNode(nodeId);
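+ // Report every task attempt and container currently registered for this node as alive.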
+ BiMap<ContainerId, TezTaskAttemptID> biMap =
+ entityTracker.getContainerAttemptMapForNode(nodeId);
+ if (biMap != null) {
+ synchronized (biMap) {
+ for (Map.Entry<ContainerId, TezTaskAttemptID> entry : biMap.entrySet()) {
+ getContext().taskAlive(entry.getValue());
+ getContext().containerAlive(entry.getKey());
+ }
+ }
+ } else {
+ long currentTs = TimeUnit.MILLISECONDS.convert(System.nanoTime(), TimeUnit.NANOSECONDS);
+ if (currentTs > nodeNotFoundLogTime.get() + 5000L) {
+ LOG.warn("Received ping from node without any registered tasks or containers: " + hostname +
+ ":" + port +
+ ". Could be caused by pre-emption by the AM," +
+ " or a mismatched hostname. Enable debug logging for mismatched host names");
+ nodeNotFoundLogTime.set(currentTs);
+ }
+ }
+ }
private void resetCurrentDag(String newDagName) {
// Working on the assumption that a single DAG runs at a time per AM.
@@ -451,6 +536,8 @@ private ByteBuffer serializeCredentials(Credentials credentials) throws IOExcept
return ByteBuffer.wrap(containerTokens_dob.getData(), 0, containerTokens_dob.getLength());
}
+
+
protected class LlapTaskUmbilicalProtocolImpl implements LlapTaskUmbilicalProtocol {
private final TezTaskUmbilicalProtocol tezUmbilical;
@@ -472,7 +559,7 @@ public TezHeartbeatResponse heartbeat(TezHeartbeatRequest request) throws IOExce
@Override
public void nodeHeartbeat(Text hostname, int port) throws IOException {
- entityTracker.nodePinged(hostname.toString(), port);
+ nodePinged(hostname.toString(), port);
if (LOG.isDebugEnabled()) {
LOG.debug("Received heartbeat from [" + hostname + ":" + port +"]");
}
@@ -499,10 +586,17 @@ public ProtocolSignature getProtocolSignature(String protocol, long clientVersio
}
}
- private final class EntityTracker {
- private final ConcurrentMap<TezTaskAttemptID, LlapNodeId> attemptToNodeMap = new ConcurrentHashMap<>();
- private final ConcurrentMap<ContainerId, LlapNodeId> containerToNodeMap = new ConcurrentHashMap<>();
- private final ConcurrentMap<LlapNodeId, BiMap<ContainerId, TezTaskAttemptID>> nodeMap = new ConcurrentHashMap<>();
+ /**
+ * Track the association between known containers and taskAttempts, along with the nodes they are assigned to.
+ */
+ @VisibleForTesting
+ static final class EntityTracker {
+ @VisibleForTesting
+ final ConcurrentMap<TezTaskAttemptID, LlapNodeId> attemptToNodeMap = new ConcurrentHashMap<>();
+ @VisibleForTesting
+ final ConcurrentMap<ContainerId, LlapNodeId> containerToNodeMap = new ConcurrentHashMap<>();
+ @VisibleForTesting
+ final ConcurrentMap<LlapNodeId, BiMap<ContainerId, TezTaskAttemptID>> nodeMap = new ConcurrentHashMap<>();
void registerTaskAttempt(ContainerId containerId, TezTaskAttemptID taskAttemptId, String host, int port) {
if (LOG.isDebugEnabled()) {
@@ -510,6 +604,10 @@ void registerTaskAttempt(ContainerId containerId, TezTaskAttemptID taskAttemptId
}
LlapNodeId llapNodeId = LlapNodeId.getInstance(host, port);
attemptToNodeMap.putIfAbsent(taskAttemptId, llapNodeId);
+
+ registerContainer(containerId, host, port);
+
+ // nodeMap registration.
BiMap<ContainerId, TezTaskAttemptID> tmpMap = HashBiMap.create();
BiMap<ContainerId, TezTaskAttemptID> old = nodeMap.putIfAbsent(llapNodeId, tmpMap);
BiMap<ContainerId, TezTaskAttemptID> usedInstance;
@@ -535,10 +633,9 @@ void unregisterTaskAttempt(TezTaskAttemptID attemptId) {
synchronized(bMap) {
matched = bMap.inverse().remove(attemptId);
}
- }
- // Removing here. Registration into the map has to make sure to put
- if (bMap.isEmpty()) {
- nodeMap.remove(llapNodeId);
+ if (bMap.isEmpty()) {
+ nodeMap.remove(llapNodeId);
+ }
}
// Remove the container mapping
@@ -549,23 +646,29 @@ void unregisterTaskAttempt(TezTaskAttemptID attemptId) {
}
void registerContainer(ContainerId containerId, String hostname, int port) {
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Registering " + containerId + " for node: " + hostname + ":" + port);
+ }
containerToNodeMap.putIfAbsent(containerId, LlapNodeId.getInstance(hostname, port));
+ // nodeMap registration is not required, since there's no taskId association.
}
LlapNodeId getNodeIdForContainer(ContainerId containerId) {
return containerToNodeMap.get(containerId);
}
- LlapNodeId getNodeIfForTaskAttempt(TezTaskAttemptID taskAttemptId) {
+ LlapNodeId getNodeIdForTaskAttempt(TezTaskAttemptID taskAttemptId) {
return attemptToNodeMap.get(taskAttemptId);
}
ContainerId getContainerIdForAttempt(TezTaskAttemptID taskAttemptId) {
- LlapNodeId llapNodeId = getNodeIfForTaskAttempt(taskAttemptId);
+ LlapNodeId llapNodeId = getNodeIdForTaskAttempt(taskAttemptId);
if (llapNodeId != null) {
BiMap<TezTaskAttemptID, ContainerId> bMap = nodeMap.get(llapNodeId).inverse();
if (bMap != null) {
- return bMap.get(taskAttemptId);
+ synchronized (bMap) {
+ return bMap.get(taskAttemptId);
+ }
} else {
return null;
}
@@ -579,7 +682,9 @@ TezTaskAttemptID getTaskAttemptIdForContainer(ContainerId containerId) {
if (llapNodeId != null) {
BiMap<ContainerId, TezTaskAttemptID> bMap = nodeMap.get(llapNodeId);
if (bMap != null) {
- return bMap.get(containerId);
+ synchronized (bMap) {
+ return bMap.get(containerId);
+ }
} else {
return null;
}
@@ -601,10 +706,9 @@ void unregisterContainer(ContainerId containerId) {
synchronized(bMap) {
matched = bMap.remove(containerId);
}
- }
- // Removing here. Registration into the map has to make sure to put
- if (bMap.isEmpty()) {
- nodeMap.remove(llapNodeId);
+ if (bMap.isEmpty()) {
+ nodeMap.remove(llapNodeId);
+ }
}
// Remove the container mapping
@@ -613,25 +717,20 @@ void unregisterContainer(ContainerId containerId) {
}
}
- private final AtomicLong nodeNotFoundLogTime = new AtomicLong(0);
- void nodePinged(String hostname, int port) {
- LlapNodeId nodeId = LlapNodeId.getInstance(hostname, port);
- BiMap<ContainerId, TezTaskAttemptID> biMap = nodeMap.get(nodeId);
- if (biMap != null) {
- synchronized(biMap) {
- for (Map.Entry<ContainerId, TezTaskAttemptID> entry : biMap.entrySet()) {
- getContext().taskAlive(entry.getValue());
- getContext().containerAlive(entry.getKey());
- }
- }
- } else {
- if (System.currentTimeMillis() > nodeNotFoundLogTime.get() + 5000l) {
- LOG.warn("Received ping from unknown node: " + hostname + ":" + port +
- ". Could be caused by pre-emption by the AM," +
- " or a mismatched hostname. Enable debug logging for mismatched host names");
- nodeNotFoundLogTime.set(System.currentTimeMillis());
- }
- }
+ /**
+ * Return a {@link BiMap} containing the container->taskAttemptId mapping for the specified host.
+ * <p>
+ * This method returns the internal structure used by the EntityTracker. Users must synchronize
+ * on the structure to ensure correct usage.
+ *
+ * @param llapNodeId the node whose container->taskAttempt mapping is required
+ * @return the mapping for the node, or null if no entities are registered for the node
+ */
+ BiMap<ContainerId, TezTaskAttemptID> getContainerAttemptMapForNode(LlapNodeId llapNodeId) {
+ BiMap<ContainerId, TezTaskAttemptID> biMap = nodeMap.get(llapNodeId);
+ return biMap;
}
+
}
-}
\ No newline at end of file
+}
diff --git llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestTaskCommunicator.java llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestLlapDaemonProtocolClientProxy.java
similarity index 86%
rename from llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestTaskCommunicator.java
rename to llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestLlapDaemonProtocolClientProxy.java
index 2aef4ed..a6af8c2 100644
--- llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestTaskCommunicator.java
+++ llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestLlapDaemonProtocolClientProxy.java
@@ -28,7 +28,7 @@
import org.apache.hadoop.hive.llap.LlapNodeId;
import org.junit.Test;
-public class TestTaskCommunicator {
+public class TestLlapDaemonProtocolClientProxy {
@Test (timeout = 5000)
public void testMultipleNodes() {
@@ -38,8 +38,8 @@ public void testMultipleNodes() {
LlapNodeId nodeId2 = LlapNodeId.getInstance("host2", 1025);
Message mockMessage = mock(Message.class);
- TaskCommunicator.ExecuteRequestCallback mockExecuteRequestCallback = mock(
- TaskCommunicator.ExecuteRequestCallback.class);
+ LlapDaemonProtocolClientProxy.ExecuteRequestCallback mockExecuteRequestCallback = mock(
+ LlapDaemonProtocolClientProxy.ExecuteRequestCallback.class);
// Request two messages
requestManager.queueRequest(
@@ -66,8 +66,8 @@ public void testSingleInvocationPerNode() {
LlapNodeId nodeId1 = LlapNodeId.getInstance("host1", 1025);
Message mockMessage = mock(Message.class);
- TaskCommunicator.ExecuteRequestCallback mockExecuteRequestCallback = mock(
- TaskCommunicator.ExecuteRequestCallback.class);
+ LlapDaemonProtocolClientProxy.ExecuteRequestCallback mockExecuteRequestCallback = mock(
+ LlapDaemonProtocolClientProxy.ExecuteRequestCallback.class);
// First request for host.
requestManager.queueRequest(
@@ -101,7 +101,7 @@ public void testSingleInvocationPerNode() {
}
- static class RequestManagerForTest extends TaskCommunicator.RequestManager {
+ static class RequestManagerForTest extends LlapDaemonProtocolClientProxy.RequestManager {
int numSubmissionsCounters = 0;
private Map<LlapNodeId, MutableInt> numInvocationsPerNode = new HashMap<>();
@@ -110,7 +110,7 @@ public RequestManagerForTest(int numThreads) {
super(numThreads);
}
- protected void submitToExecutor(TaskCommunicator.CallableRequest request, LlapNodeId nodeId) {
+ protected void submitToExecutor(LlapDaemonProtocolClientProxy.CallableRequest request, LlapNodeId nodeId) {
numSubmissionsCounters++;
MutableInt nodeCount = numInvocationsPerNode.get(nodeId);
if (nodeCount == null) {
@@ -127,10 +127,10 @@ void reset() {
}
- static class CallableRequestForTest extends TaskCommunicator.CallableRequest {
+ static class CallableRequestForTest extends LlapDaemonProtocolClientProxy.CallableRequest {
protected CallableRequestForTest(LlapNodeId nodeId, Message message,
- TaskCommunicator.ExecuteRequestCallback callback) {
+ LlapDaemonProtocolClientProxy.ExecuteRequestCallback callback) {
super(nodeId, message, callback);
}
diff --git llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestLlapTaskCommunicator.java llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestLlapTaskCommunicator.java
new file mode 100644
index 0000000..f02a3d7
--- /dev/null
+++ llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestLlapTaskCommunicator.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.llap.tezplugins;
+
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+
+import org.apache.hadoop.hive.llap.LlapNodeId;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.junit.Test;
+
+public class TestLlapTaskCommunicator {
+
+ @Test (timeout = 5000)
+ public void testEntityTracker1() {
+ LlapTaskCommunicator.EntityTracker entityTracker = new LlapTaskCommunicator.EntityTracker();
+
+ String host1 = "host1";
+ String host2 = "host2";
+ String host3 = "host3";
+ int port = 1451;
+
+
+ // Simple container registration and un-registration without any task attempt being involved.
+ ContainerId containerId101 = constructContainerId(101);
+ entityTracker.registerContainer(containerId101, host1, port);
+ assertEquals(LlapNodeId.getInstance(host1, port), entityTracker.getNodeIdForContainer(containerId101));
+
+ entityTracker.unregisterContainer(containerId101);
+ assertNull(entityTracker.getContainerAttemptMapForNode(LlapNodeId.getInstance(host1, port)));
+ assertNull(entityTracker.getNodeIdForContainer(containerId101));
+ assertEquals(0, entityTracker.nodeMap.size());
+ assertEquals(0, entityTracker.attemptToNodeMap.size());
+ assertEquals(0, entityTracker.containerToNodeMap.size());
+
+
+ // Simple task registration and un-registration.
+ ContainerId containerId1 = constructContainerId(1);
+ TezTaskAttemptID taskAttemptId1 = constructTaskAttemptId(1);
+ entityTracker.registerTaskAttempt(containerId1, taskAttemptId1, host1, port);
+ assertEquals(LlapNodeId.getInstance(host1, port), entityTracker.getNodeIdForContainer(containerId1));
+ assertEquals(LlapNodeId.getInstance(host1, port), entityTracker.getNodeIdForTaskAttempt(taskAttemptId1));
+
+ entityTracker.unregisterTaskAttempt(taskAttemptId1);
+ assertNull(entityTracker.getContainerAttemptMapForNode(LlapNodeId.getInstance(host1, port)));
+ assertNull(entityTracker.getNodeIdForContainer(containerId1));
+ assertNull(entityTracker.getNodeIdForTaskAttempt(taskAttemptId1));
+ assertEquals(0, entityTracker.nodeMap.size());
+ assertEquals(0, entityTracker.attemptToNodeMap.size());
+ assertEquals(0, entityTracker.containerToNodeMap.size());
+
+ // Register taskAttempt, unregister container. TaskAttempt should also be unregistered
+ ContainerId containerId201 = constructContainerId(201);
+ TezTaskAttemptID taskAttemptId201 = constructTaskAttemptId(201);
+ entityTracker.registerTaskAttempt(containerId201, taskAttemptId201, host1, port);
+ assertEquals(LlapNodeId.getInstance(host1, port), entityTracker.getNodeIdForContainer(containerId201));
+ assertEquals(LlapNodeId.getInstance(host1, port), entityTracker.getNodeIdForTaskAttempt(taskAttemptId201));
+
+ entityTracker.unregisterContainer(containerId201);
+ assertNull(entityTracker.getContainerAttemptMapForNode(LlapNodeId.getInstance(host1, port)));
+ assertNull(entityTracker.getNodeIdForContainer(containerId201));
+ assertNull(entityTracker.getNodeIdForTaskAttempt(taskAttemptId201));
+ assertEquals(0, entityTracker.nodeMap.size());
+ assertEquals(0, entityTracker.attemptToNodeMap.size());
+ assertEquals(0, entityTracker.containerToNodeMap.size());
+
+ entityTracker.unregisterTaskAttempt(taskAttemptId201); // No errors
+ }
+
+
+ private ContainerId constructContainerId(int id) {
+ ContainerId containerId = mock(ContainerId.class);
+ doReturn(id).when(containerId).getId();
+ doReturn((long)id).when(containerId).getContainerId();
+ return containerId;
+ }
+
+ private TezTaskAttemptID constructTaskAttemptId(int id) {
+ TezTaskAttemptID taskAttemptId = mock(TezTaskAttemptID.class);
+ doReturn(id).when(taskAttemptId).getId();
+ return taskAttemptId;
+ }
+
+}