commit 37cc1b0fa58e22b544f35f8c89311c952c2b7021
Author: Sahil Takiar
Date:   Thu Aug 2 21:58:02 2018 +0200

    HIVE-20273: Spark jobs aren't cancelled if getSparkJobInfo or getSparkStagesInfo

diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkTask.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkTask.java
index 92775107bc..604fe8c90c 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkTask.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkTask.java
@@ -160,7 +160,14 @@ public int execute(DriverContext driverContext) {
 
         // Get the final state of the Spark job and parse its job info
         SparkJobStatus sparkJobStatus = jobRef.getSparkJobStatus();
-        getSparkJobInfo(sparkJobStatus);
+
+        // Do not try to fetch the Spark job info if there was an error while monitoring the
+        // remote Spark job; fetching the Spark job info requires invoking the monitoring
+        // connection, which already threw an exception, so we shouldn't retry
+        if (rc != 1) {
+          getSparkJobInfo(sparkJobStatus);
+        }
+
         setSparkException(sparkJobStatus, rc);
 
         if (rc == 0) {
@@ -195,6 +202,10 @@ public int execute(DriverContext driverContext) {
       LOG.info("The Spark job or one stage of it has too many tasks" +
           ". Cancelling Spark job " + sparkJobID + " with application ID " + jobID);
       killJob();
+    } else if (rc == 5) {
+      LOG.info("The Spark job was cancelled. Cancelling Spark job " + sparkJobID + " with " +
+          "application ID " + jobID);
+      killJob();
     }
 
     if (this.jobID == null) {
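
The two hunks above change how SparkTask.execute() reacts to the monitor's return code: the job-info fetch is skipped when monitoring itself failed (rc == 1), and the new rc == 5 (cancelled, introduced in the RemoteSparkJobMonitor change below) now triggers killJob(). A condensed, illustrative sketch of the resulting flow — names follow the diff, but this is not the literal method body:

    // Illustrative condensation of SparkTask.execute() after this patch;
    // not the literal method body.
    int rc = jobRef.monitorJob();               // RemoteSparkJobMonitor.startMonitor()
    SparkJobStatus sparkJobStatus = jobRef.getSparkJobStatus();
    if (rc != 1) {                              // 1 == monitor error: its connection
      getSparkJobInfo(sparkJobStatus);          // already threw, so don't reuse it
    }
    setSparkException(sparkJobStatus, rc);
    if (rc == 5) {                              // new: remote job was cancelled, so
      killJob();                                // kill it, just like the too-many-tasks
    }                                           // branch shown in the diff context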
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/status/RemoteSparkJobMonitor.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/status/RemoteSparkJobMonitor.java
index 87b69cbae4..beee71e6c9 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/status/RemoteSparkJobMonitor.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/status/RemoteSparkJobMonitor.java
@@ -155,7 +155,7 @@ public int startMonitor() {
             console.printInfo("Spark job[" + sparkJobStatus.getJobId() + " was cancelled");
             running = false;
             done = true;
-            rc = 3;
+            rc = 5;
             break;
           }
 
@@ -164,8 +164,7 @@ public int startMonitor() {
         }
       } catch (Exception e) {
         Exception finalException = e;
-        if (e instanceof InterruptedException ||
-            (e instanceof HiveException && e.getCause() instanceof InterruptedException)) {
+        if (e instanceof InterruptedException) {
           finalException = new HiveException(e, ErrorMsg.SPARK_JOB_INTERRUPTED);
           LOG.warn("Interrupted while monitoring the Hive on Spark application, exiting");
         } else {

diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/status/impl/RemoteSparkJobStatus.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/status/impl/RemoteSparkJobStatus.java
index 3d414430ab..906a3f21e1 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/status/impl/RemoteSparkJobStatus.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/status/impl/RemoteSparkJobStatus.java
@@ -18,6 +18,7 @@
 
 package org.apache.hadoop.hive.ql.exec.spark.status.impl;
 
+import com.google.common.annotations.VisibleForTesting;
 import org.apache.hadoop.hive.ql.exec.spark.SparkUtilities;
 import org.apache.hadoop.hive.ql.exec.spark.Statistic.SparkStatisticsNames;
@@ -62,6 +63,7 @@
   private final SparkClient sparkClient;
   private final JobHandle<Serializable> jobHandle;
   private Throwable error;
+  private boolean wasInterrupted;
   private final transient long sparkClientTimeoutInSeconds;
 
   public RemoteSparkJobStatus(SparkClient sparkClient, JobHandle<Serializable> jobHandle, long timeoutInSeconds) {
@@ -76,13 +78,14 @@ public String getAppID() {
     Future<String> getAppID = sparkClient.run(new GetAppIDJob());
     try {
       return getAppID.get(sparkClientTimeoutInSeconds, TimeUnit.SECONDS);
+    } catch (InterruptedException e) {
+      LOG.warn("Failed to get APP ID.", e);
+      error = e;
+      wasInterrupted = true;
     } catch (Exception e) {
       LOG.warn("Failed to get APP ID.", e);
-      if (Thread.interrupted()) {
-        error = e;
-      }
-      return null;
     }
+    return null;
   }
 
   @Override
@@ -157,13 +160,14 @@ public String getWebUIURL() {
     Future<String> getWebUIURL = sparkClient.run(new GetWebUIURLJob());
     try {
       return getWebUIURL.get(sparkClientTimeoutInSeconds, TimeUnit.SECONDS);
+    } catch (InterruptedException e) {
+      LOG.warn("Failed to get web UI URL.", e);
+      error = e;
+      wasInterrupted = true;
     } catch (Exception e) {
       LOG.warn("Failed to get web UI URL.", e);
-      if (Thread.interrupted()) {
-        error = e;
-      }
-      return "UNKNOWN";
     }
+    return "UNKNOWN";
   }
 
   @Override
@@ -238,13 +242,14 @@ private SparkJobInfo getSparkJobInfo() throws HiveException {
   }
 
   public JobHandle.State getRemoteJobState() {
-    if (error != null) {
-      return JobHandle.State.FAILED;
+    if (wasInterrupted) {
+      return JobHandle.State.CANCELLED;
     }
     return jobHandle.getState();
   }
 
-  private static class GetSparkStagesInfoJob implements Job<ArrayList<SparkStageInfo>> {
+  @VisibleForTesting
+  public static class GetSparkStagesInfoJob implements Job<ArrayList<SparkStageInfo>> {
     private final String clientJobId;
     private final int sparkJobId;
@@ -289,7 +294,9 @@ private GetSparkStagesInfoJob() {
       return sparkStageInfos;
     }
   }
-  private static class GetJobInfoJob implements Job<SparkJobInfo> {
+
+  @VisibleForTesting
+  public static class GetJobInfoJob implements Job<SparkJobInfo> {
     private final String clientJobId;
     private final int sparkJobId;
@@ -351,7 +358,8 @@ public JobExecutionStatus status() {
     };
   }
 
-  private static class GetAppIDJob implements Job<String> {
+  @VisibleForTesting
+  public static class GetAppIDJob implements Job<String> {
 
     public GetAppIDJob() {
     }
@@ -362,7 +370,8 @@ public String call(JobContext jc) throws Exception {
     }
   }
 
-  private static class GetWebUIURLJob implements Job<String> {
+  @VisibleForTesting
+  public static class GetWebUIURLJob implements Job<String> {
 
     public GetWebUIURLJob() {
     }
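
The key behavioral fix in RemoteSparkJobStatus is the explicit InterruptedException catch. Future.get() clears the thread's interrupt status when it throws InterruptedException, so the old `if (Thread.interrupted())` check inside the generic catch block would usually see an already-cleared flag, the interrupt went unrecorded, and getRemoteJobState() kept returning the handle's state as if nothing had happened. A self-contained sketch of the new pattern — class and method names here are illustrative, not Hive API:

    import java.util.concurrent.*;

    class StatusProbe {
      private volatile Throwable error;
      private volatile boolean wasInterrupted;

      // Mirrors getAppID()/getWebUIURL(): record the interrupt directly instead
      // of relying on a Thread.interrupted() check that can no longer fire once
      // Future.get() has thrown (the throw clears the interrupt status).
      String fetch(Future<String> pending, long timeoutSeconds) {
        try {
          return pending.get(timeoutSeconds, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
          error = e;                            // recorded, as in the patch
          wasInterrupted = true;
        } catch (ExecutionException | TimeoutException e) {
          // non-interrupt failures fall through to the default return
        }
        return null;
      }

      // Mirrors getRemoteJobState(): an interrupt now surfaces as a cancelled
      // state, which RemoteSparkJobMonitor translates into rc = 5 and
      // SparkTask.execute() turns into killJob().
      boolean isCancelled() {
        return wasInterrupted;
      }
    }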
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/spark/TestSparkTask.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/spark/TestSparkTask.java
index 2017fc15f9..c64372556f 100644
--- a/ql/src/test/org/apache/hadoop/hive/ql/exec/spark/TestSparkTask.java
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/spark/TestSparkTask.java
@@ -18,7 +18,12 @@
 package org.apache.hadoop.hive.ql.exec.spark;
 
 import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyBoolean;
+import static org.mockito.Matchers.anyLong;
+import static org.mockito.Matchers.isA;
 import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.times;
@@ -26,9 +31,14 @@
 import static org.mockito.Mockito.when;
 
 import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeoutException;
 
+import com.google.common.collect.Lists;
 import org.apache.hadoop.hive.common.metrics.common.Metrics;
 import org.apache.hadoop.hive.common.metrics.common.MetricsConstant;
 import org.apache.hadoop.hive.conf.HiveConf;
@@ -40,16 +50,22 @@
 import org.apache.hadoop.hive.ql.exec.spark.status.RemoteSparkJobMonitor;
 import org.apache.hadoop.hive.ql.exec.spark.status.SparkJobRef;
 import org.apache.hadoop.hive.ql.exec.spark.status.SparkJobStatus;
+import org.apache.hadoop.hive.ql.exec.spark.status.impl.RemoteSparkJobRef;
 import org.apache.hadoop.hive.ql.exec.spark.status.impl.RemoteSparkJobStatus;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.plan.BaseWork;
 import org.apache.hadoop.hive.ql.plan.MapWork;
 import org.apache.hadoop.hive.ql.plan.SparkWork;
 import org.apache.hadoop.hive.ql.session.SessionState;
+import org.apache.hive.spark.client.JobHandle;
 import org.apache.hive.spark.client.JobHandle.State;
+import org.apache.hive.spark.client.SparkClient;
+import org.apache.spark.JobExecutionStatus;
 import org.apache.spark.SparkException;
+import org.apache.spark.SparkJobInfo;
+import org.apache.spark.SparkStageInfo;
 import org.junit.Assert;
 import org.junit.Test;
@@ -107,7 +123,7 @@ public void testRemoteSparkCancel() {
     HiveConf hiveConf = new HiveConf();
     SessionState.start(hiveConf);
     RemoteSparkJobMonitor remoteSparkJobMonitor = new RemoteSparkJobMonitor(hiveConf, jobSts);
-    Assert.assertEquals(remoteSparkJobMonitor.startMonitor(), 3);
+    Assert.assertEquals(remoteSparkJobMonitor.startMonitor(), 5);
   }
 
   @Test
@@ -191,16 +207,9 @@ public void testSparkExceptionAndMonitorError() {
   @Test
   public void testHandleInterruptedException() throws Exception {
     HiveConf hiveConf = new HiveConf();
-
-    SparkTask sparkTask = new SparkTask();
-    sparkTask.setWork(mock(SparkWork.class));
-
     DriverContext mockDriverContext = mock(DriverContext.class);
 
-    QueryState mockQueryState = mock(QueryState.class);
-    when(mockQueryState.getConf()).thenReturn(hiveConf);
-
-    sparkTask.initialize(mockQueryState, null, mockDriverContext, null);
+    SparkTask sparkTask = createMockSparkTask(hiveConf, mockDriverContext);
 
     SparkJobStatus mockSparkJobStatus = mock(SparkJobStatus.class);
     when(mockSparkJobStatus.getMonitorError()).thenReturn(new InterruptedException());
@@ -227,6 +236,180 @@ public void testHandleInterruptedException() throws Exception {
     verify(mockSparkJobRef, atLeastOnce()).cancelJob();
   }
 
+  @Test
+  public void testHandleGetStateInterruptedException() throws Exception {
+    HiveConf hiveConf = new HiveConf();
+    DriverContext mockDriverContext = mock(DriverContext.class);
+
+    SparkTask sparkTask = createMockSparkTask(hiveConf, mockDriverContext);
+
+    SparkClient mockSparkClient = createActiveMockSparkClient();
+
+    Future mockFuture = createFutureThrowInterrupt();
+    doReturn(mockFuture).when(mockSparkClient).run(isA(RemoteSparkJobStatus.GetJobInfoJob.class));
+
+    JobHandle<Serializable> mockJobHandle = createMockJobHandle();
+
+    RemoteSparkJobStatus remoteSparkJobStatus = new RemoteSparkJobStatus(mockSparkClient,
+        mockJobHandle, 0);
+
+    setupSparkSession(hiveConf, mockJobHandle, remoteSparkJobStatus);
+
+    sparkTask.execute(mockDriverContext);
+
+    verify(mockJobHandle, atLeastOnce()).cancel(anyBoolean());
+  }
+
+  @Test
+  public void testHandleGetSparkStageProgressInterruptedException() throws Exception {
+    HiveConf hiveConf = new HiveConf();
+    DriverContext mockDriverContext = mock(DriverContext.class);
+
+    SparkTask sparkTask = createMockSparkTask(hiveConf, mockDriverContext);
+
+    SparkClient mockSparkClient = createActiveMockSparkClient();
+
+    mockGetJobInfoJob(mockSparkClient);
+
+    Future getSparkStagesInfoJob = createFutureThrowInterrupt();
+    doReturn(getSparkStagesInfoJob).when(mockSparkClient).run(any(RemoteSparkJobStatus.GetSparkStagesInfoJob.class));
+
+    JobHandle<Serializable> mockJobHandle = createMockJobHandle();
+
+    RemoteSparkJobStatus remoteSparkJobStatus = new RemoteSparkJobStatus(mockSparkClient,
+        mockJobHandle, 0);
+
+    setupSparkSession(hiveConf, mockJobHandle, remoteSparkJobStatus);
+
+    sparkTask.execute(mockDriverContext);
+
+    verify(mockJobHandle, atLeastOnce()).cancel(anyBoolean());
+  }
+
+  @Test
+  public void testHandleGetAppIdInterruptedException() throws Exception {
+    HiveConf hiveConf = new HiveConf();
+    DriverContext mockDriverContext = mock(DriverContext.class);
+
+    SparkTask sparkTask = createMockSparkTask(hiveConf, mockDriverContext);
+    SparkClient mockSparkClient = createActiveMockSparkClient();
+
+    mockGetJobInfoJob(mockSparkClient);
+    mockGetSparkStagesInfoJob(mockSparkClient);
+
+    Future mockFuture = createFutureThrowInterrupt();
+    doReturn(mockFuture).when(mockSparkClient).run(isA(RemoteSparkJobStatus.GetAppIDJob.class));
+
+    JobHandle<Serializable> mockJobHandle = createMockJobHandle();
+
+    RemoteSparkJobStatus remoteSparkJobStatus = new RemoteSparkJobStatus(mockSparkClient,
+        mockJobHandle, 0);
+
+    setupSparkSession(hiveConf, mockJobHandle, remoteSparkJobStatus);
+
+    sparkTask.execute(mockDriverContext);
+
+    verify(mockJobHandle, atLeastOnce()).cancel(anyBoolean());
+  }
+
+  @Test
+  public void testHandleGetWebUIURLInterruptedException() throws Exception {
+    HiveConf hiveConf = new HiveConf();
+    DriverContext mockDriverContext = mock(DriverContext.class);
+
+    SparkTask sparkTask = createMockSparkTask(hiveConf, mockDriverContext);
+    SparkClient mockSparkClient = createActiveMockSparkClient();
+
+    mockGetJobInfoJob(mockSparkClient);
+    mockGetSparkStagesInfoJob(mockSparkClient);
+
+    Future mockFuture = createFutureThrowInterrupt();
+    doReturn(mockFuture).when(mockSparkClient).run(isA(RemoteSparkJobStatus.GetWebUIURLJob.class));
+
+    JobHandle<Serializable> mockJobHandle = createMockJobHandle();
+
+    RemoteSparkJobStatus remoteSparkJobStatus = new RemoteSparkJobStatus(mockSparkClient,
+        mockJobHandle, 0);
+
+    setupSparkSession(hiveConf, mockJobHandle, remoteSparkJobStatus);
+
+    sparkTask.execute(mockDriverContext);
+
+    verify(mockJobHandle, atLeastOnce()).cancel(anyBoolean());
+  }
+
+  private void mockGetSparkStagesInfoJob(
+      SparkClient mockSparkClient) throws InterruptedException, ExecutionException, TimeoutException {
+    ArrayList<SparkStageInfo> sparkStageInfos = new ArrayList<>();
+
+    Future getSparkStagesInfoJobFuture = mock(Future.class);
+    doReturn(sparkStageInfos).when(getSparkStagesInfoJobFuture).get(anyLong(), any());
+
+    doReturn(getSparkStagesInfoJobFuture).when(mockSparkClient).run(isA(RemoteSparkJobStatus
+        .GetSparkStagesInfoJob.class));
+  }
+
+  private void mockGetJobInfoJob(
+      SparkClient mockSparkClient) throws InterruptedException, ExecutionException, TimeoutException {
+    SparkJobInfo sparkJobInfo = mock(SparkJobInfo.class);
+    when(sparkJobInfo.status()).thenReturn(JobExecutionStatus.RUNNING);
+
+    Future getJobInfoJobFuture = mock(Future.class);
+    doReturn(sparkJobInfo).when(getJobInfoJobFuture).get(anyLong(), any());
+
+    doReturn(getJobInfoJobFuture).when(mockSparkClient).run(isA(RemoteSparkJobStatus.GetJobInfoJob
+        .class));
+  }
+
+  private JobHandle<Serializable> createMockJobHandle() {
+    JobHandle<Serializable> mockJobHandle = (JobHandle<Serializable>) mock(JobHandle.class);
+    when(mockJobHandle.getState()).thenReturn(State.STARTED);
+    when(mockJobHandle.getSparkJobIds()).thenReturn(Lists.newArrayList(0));
+    return mockJobHandle;
+  }
+
+  private void setupSparkSession(HiveConf hiveConf, JobHandle<Serializable> mockJobHandle,
+      RemoteSparkJobStatus remoteSparkJobStatus) throws Exception {
+    SparkSession mockSparkSession = mock(SparkSession.class);
+
+    SparkJobRef sparkJobRef = new RemoteSparkJobRef(hiveConf, mockJobHandle, remoteSparkJobStatus);
+
+    when(mockSparkSession.submit(any(), any())).thenReturn(sparkJobRef);
+
+    SessionState.start(hiveConf);
+    SessionState.get().setSparkSession(mockSparkSession);
+  }
+
+  private SparkTask createMockSparkTask(HiveConf hiveConf, DriverContext driverContext) {
+    SparkTask sparkTask = new SparkTask();
+    sparkTask.setWork(mock(SparkWork.class));
+
+    QueryState mockQueryState = mock(QueryState.class);
+    when(mockQueryState.getConf()).thenReturn(hiveConf);
+
+    sparkTask.initialize(mockQueryState, null, driverContext, null);
+
+    return sparkTask;
+  }
+
+  private SparkClient createActiveMockSparkClient() {
+    SparkClient mockSparkClient = mock(SparkClient.class);
+    when(mockSparkClient.isActive()).thenReturn(true);
+
+    return mockSparkClient;
+  }
+
+  private Future createFutureThrowInterrupt() throws InterruptedException, ExecutionException, TimeoutException {
+    Future mockFuture = mock(Future.class);
+    doThrow(InterruptedException.class).when(mockFuture).get();
+    doThrow(InterruptedException.class).when(mockFuture).get(anyLong(), any());
+
+    return mockFuture;
+  }
+
   private boolean isEmptySparkWork(SparkWork sparkWork) {
     List<BaseWork> allWorks = sparkWork.getAllWork();
     boolean allWorksIsEmtpy = true;
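
Zooming out, the four new tests share one recipe: every status RPC that should succeed gets a stubbed Future returning a canned value, exactly one RPC gets a Future whose get() throws InterruptedException, and the assertion is that SparkTask.execute() ends up cancelling the JobHandle. A standalone condensation of that stubbing idiom (illustrative, not the literal test code):

    // One SparkClient mock, dispatching a different Future per job class via isA(...)
    SparkClient client = mock(SparkClient.class);
    when(client.isActive()).thenReturn(true);

    // Happy-path stub: the job-info probe reports RUNNING
    SparkJobInfo info = mock(SparkJobInfo.class);
    when(info.status()).thenReturn(JobExecutionStatus.RUNNING);
    Future infoFuture = mock(Future.class);
    doReturn(info).when(infoFuture).get(anyLong(), any());
    doReturn(infoFuture).when(client).run(isA(RemoteSparkJobStatus.GetJobInfoJob.class));

    // Failure under test: only the APP ID probe is interrupted
    Future interrupted = mock(Future.class);
    doThrow(InterruptedException.class).when(interrupted).get(anyLong(), any());
    doReturn(interrupted).when(client).run(isA(RemoteSparkJobStatus.GetAppIDJob.class));

This is also why the patch widens the inner job classes (GetAppIDJob, GetWebUIURLJob, GetJobInfoJob, GetSparkStagesInfoJob) from private to public with @VisibleForTesting: isA(...) needs the concrete class to target one RPC at a time.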