commit ba04a2a5a17b9367180991aae83c72ecd83b69e8 Author: Sahil Takiar Date: Thu Jan 25 17:36:30 2018 -0800 HIVE-18533: Add option to use InProcessLauncher to submit spark jobs diff --git a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java index d0eb2a4801..1c85887a80 100644 --- a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java +++ b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java @@ -4198,6 +4198,11 @@ private static void populateLlapDaemonVarsSet(Set llapDaemonVarsSetLocal "If a Spark job contains more tasks than the maximum, it will be cancelled. A value of -1 means no limit."), SPARK_STAGE_MAX_TASKS("hive.spark.stage.max.tasks", -1, "The maximum number of tasks a stage in a Spark job may have.\n" + "If a Spark job stage contains more tasks than the maximum, the job will be cancelled. A value of -1 means no limit."), + SPARK_CLIENT_TYPE("hive.spark.client.type", SparkClientType.SPARK_SUBMIT_CLIENT.toString(), + "Controls how the Spark application is launched. If " + SparkClientType.SPARK_SUBMIT_CLIENT + + " is specified (default) then the spark-submit shell script is used to launch the Spark " + + "app. If " + SparkClientType.SPARK_LAUNCHER_CLIENT + " is specified then Spark's " + + "SparkLauncher is used to programmatically launch the app."), NWAYJOINREORDER("hive.reorder.nway.joins", true, "Runs reordering of tables within single n-way join (i.e.: picks streamtable)"), HIVE_MERGE_NWAY_JOINS("hive.merge.nway.joins", true, @@ -5754,4 +5759,12 @@ public void verifyAndSetAll(Map overlay) { return ret; } + /** + * The type of launcher to use when submitting the HoS application. Used in + * {@link ConfVars#SPARK_CLIENT_TYPE}. + */ + public enum SparkClientType { + SPARK_SUBMIT_CLIENT, + SPARK_LAUNCHER_CLIENT + } } diff --git a/itests/qtest-spark/pom.xml b/itests/qtest-spark/pom.xml index c55044622c..88ada891d9 100644 --- a/itests/qtest-spark/pom.xml +++ b/itests/qtest-spark/pom.xml @@ -63,6 +63,12 @@ + + org.apache.spark + spark-yarn_${scala.binary.version} + ${spark.version} + test + org.eclipse.jetty jetty-util diff --git a/itests/src/test/resources/testconfiguration.properties b/itests/src/test/resources/testconfiguration.properties index 13c08de3c5..d0423801b5 100644 --- a/itests/src/test/resources/testconfiguration.properties +++ b/itests/src/test/resources/testconfiguration.properties @@ -1593,7 +1593,8 @@ miniSparkOnYarn.only.query.files=spark_combine_equivalent_work.q,\ spark_use_ts_stats_for_mapjoin.q,\ spark_use_op_stats.q,\ spark_explain_groupbyshuffle.q,\ - spark_opt_shuffle_serde.q + spark_opt_shuffle_serde.q,\ + spark_in_process_launcher.q miniSparkOnYarn.query.files=auto_sortmerge_join_16.q,\ bucket4.q,\ diff --git a/ql/src/test/queries/clientpositive/spark_in_process_launcher.q b/ql/src/test/queries/clientpositive/spark_in_process_launcher.q new file mode 100644 index 0000000000..4455b45ddc --- /dev/null +++ b/ql/src/test/queries/clientpositive/spark_in_process_launcher.q @@ -0,0 +1,10 @@ +--! 
qt:dataset:src + +set hive.spark.client.type=SPARK_LAUNCHER_CLIENT; + +-- Hack to restart the HoS session +set hive.execution.engine=mr; +set hive.execution.engine=spark; + +explain select key, count(*) from src group by key order by key limit 100; +select key, count(*) from src group by key order by key limit 100; diff --git a/ql/src/test/results/clientpositive/spark/spark_in_process_launcher.q.out b/ql/src/test/results/clientpositive/spark/spark_in_process_launcher.q.out new file mode 100644 index 0000000000..84afffb1f8 --- /dev/null +++ b/ql/src/test/results/clientpositive/spark/spark_in_process_launcher.q.out @@ -0,0 +1,186 @@ +PREHOOK: query: explain select key, count(*) from src group by key order by key limit 100 +PREHOOK: type: QUERY +POSTHOOK: query: explain select key, count(*) from src group by key order by key limit 100 +POSTHOOK: type: QUERY +STAGE DEPENDENCIES: + Stage-1 is a root stage + Stage-0 depends on stages: Stage-1 + +STAGE PLANS: + Stage: Stage-1 + Spark + Edges: + Reducer 2 <- Map 1 (GROUP, 2) + Reducer 3 <- Reducer 2 (SORT, 1) +#### A masked pattern was here #### + Vertices: + Map 1 + Map Operator Tree: + TableScan + alias: src + Statistics: Num rows: 500 Data size: 5312 Basic stats: COMPLETE Column stats: NONE + Select Operator + expressions: key (type: string) + outputColumnNames: key + Statistics: Num rows: 500 Data size: 5312 Basic stats: COMPLETE Column stats: NONE + Group By Operator + aggregations: count() + keys: key (type: string) + mode: hash + outputColumnNames: _col0, _col1 + Statistics: Num rows: 500 Data size: 5312 Basic stats: COMPLETE Column stats: NONE + Reduce Output Operator + key expressions: _col0 (type: string) + sort order: + + Map-reduce partition columns: _col0 (type: string) + Statistics: Num rows: 500 Data size: 5312 Basic stats: COMPLETE Column stats: NONE + TopN Hash Memory Usage: 0.1 + value expressions: _col1 (type: bigint) + Execution mode: vectorized + Reducer 2 + Execution mode: vectorized + Reduce Operator Tree: + Group By Operator + aggregations: count(VALUE._col0) + keys: KEY._col0 (type: string) + mode: mergepartial + outputColumnNames: _col0, _col1 + Statistics: Num rows: 250 Data size: 2656 Basic stats: COMPLETE Column stats: NONE + Reduce Output Operator + key expressions: _col0 (type: string) + sort order: + + Statistics: Num rows: 250 Data size: 2656 Basic stats: COMPLETE Column stats: NONE + TopN Hash Memory Usage: 0.1 + value expressions: _col1 (type: bigint) + Reducer 3 + Execution mode: vectorized + Reduce Operator Tree: + Select Operator + expressions: KEY.reducesinkkey0 (type: string), VALUE._col0 (type: bigint) + outputColumnNames: _col0, _col1 + Statistics: Num rows: 250 Data size: 2656 Basic stats: COMPLETE Column stats: NONE + Limit + Number of rows: 100 + Statistics: Num rows: 100 Data size: 1000 Basic stats: COMPLETE Column stats: NONE + File Output Operator + compressed: false + Statistics: Num rows: 100 Data size: 1000 Basic stats: COMPLETE Column stats: NONE + table: + input format: org.apache.hadoop.mapred.SequenceFileInputFormat + output format: org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat + serde: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe + + Stage: Stage-0 + Fetch Operator + limit: 100 + Processor Tree: + ListSink + +PREHOOK: query: select key, count(*) from src group by key order by key limit 100 +PREHOOK: type: QUERY +PREHOOK: Input: default@src +PREHOOK: Output: hdfs://### HDFS PATH ### +POSTHOOK: query: select key, count(*) from src group by key order by key limit 100 +POSTHOOK: 
type: QUERY +POSTHOOK: Input: default@src +POSTHOOK: Output: hdfs://### HDFS PATH ### +0 3 +10 1 +100 2 +103 2 +104 2 +105 1 +11 1 +111 1 +113 2 +114 1 +116 1 +118 2 +119 3 +12 2 +120 2 +125 2 +126 1 +128 3 +129 2 +131 1 +133 1 +134 2 +136 1 +137 2 +138 4 +143 1 +145 1 +146 2 +149 2 +15 2 +150 1 +152 2 +153 1 +155 1 +156 1 +157 1 +158 1 +160 1 +162 1 +163 1 +164 2 +165 2 +166 1 +167 3 +168 1 +169 4 +17 1 +170 1 +172 2 +174 2 +175 2 +176 2 +177 1 +178 1 +179 2 +18 2 +180 1 +181 1 +183 1 +186 1 +187 3 +189 1 +19 1 +190 1 +191 2 +192 1 +193 3 +194 1 +195 2 +196 1 +197 2 +199 3 +2 1 +20 1 +200 2 +201 1 +202 1 +203 2 +205 2 +207 2 +208 3 +209 2 +213 2 +214 1 +216 2 +217 2 +218 1 +219 2 +221 2 +222 1 +223 2 +224 2 +226 1 +228 1 +229 2 +230 5 +233 2 +235 1 +237 2 +238 2 diff --git a/spark-client/pom.xml b/spark-client/pom.xml index 00305d125f..49f7ed94ac 100644 --- a/spark-client/pom.xml +++ b/spark-client/pom.xml @@ -99,6 +99,12 @@ + + org.apache.spark + spark-yarn_${scala.binary.version} + ${spark.version} + runtime + junit junit diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java b/spark-client/src/main/java/org/apache/hive/spark/client/AbstractSparkClient.java similarity index 69% rename from spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java rename to spark-client/src/main/java/org/apache/hive/spark/client/AbstractSparkClient.java index d450515359..a67d42b8ee 100644 --- a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java +++ b/spark-client/src/main/java/org/apache/hive/spark/client/AbstractSparkClient.java @@ -25,7 +25,6 @@ import com.google.common.base.Strings; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; -import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.io.Resources; @@ -42,19 +41,17 @@ import java.io.Writer; import java.net.URI; import java.net.URL; -import java.util.ArrayList; import java.util.Collections; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Properties; import java.util.UUID; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import org.apache.commons.lang3.StringUtils; -import org.apache.hadoop.hive.common.log.LogRedirector; -import org.apache.hadoop.hive.conf.Constants; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.conf.HiveConf.ConfVars; import org.apache.hadoop.hive.shims.Utils; @@ -66,37 +63,49 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -class SparkClientImpl implements SparkClient { +/** + * An abstract implementation of {@link SparkClient} that allows sub-classes to override how the + * spark application is launched. It provides the following functionality: (1) creating the client + * connection to the {@link RemoteDriver} and managing its lifecycle, (2) monitoring the thread + * used to submit the Spark application, (3) safe shutdown of the {@link RemoteDriver}, and (4) + * configuration handling for submitting the Spark application. + * + *
+ * This class contains the client protocol used to communicate with the {@link RemoteDriver}. + * It uses this protocol to submit {@link Job}s to the {@link RemoteDriver}. + *
+ */ +abstract class AbstractSparkClient implements SparkClient { + private static final long serialVersionUID = 1L; - private static final Logger LOG = LoggerFactory.getLogger(SparkClientImpl.class); + private static final Logger LOG = LoggerFactory.getLogger(AbstractSparkClient.class); private static final long DEFAULT_SHUTDOWN_TIMEOUT = 10000; // In milliseconds private static final String OSX_TEST_OPTS = "SPARK_OSX_TEST_OPTS"; - private static final String SPARK_HOME_ENV = "SPARK_HOME"; - private static final String SPARK_HOME_KEY = "spark.home"; private static final String DRIVER_OPTS_KEY = "spark.driver.extraJavaOptions"; private static final String EXECUTOR_OPTS_KEY = "spark.executor.extraJavaOptions"; private static final String DRIVER_EXTRA_CLASSPATH = "spark.driver.extraClassPath"; private static final String EXECUTOR_EXTRA_CLASSPATH = "spark.executor.extraClassPath"; + private static final String SPARK_DEPLOY_MODE = "spark.submit.deployMode"; - private final Map conf; + protected final Map conf; private final HiveConf hiveConf; - private final Thread driverThread; + private final Future driverFuture; private final Map> jobs; private final Rpc driverRpc; private final ClientProtocol protocol; - private volatile boolean isAlive; + protected volatile boolean isAlive; - SparkClientImpl(RpcServer rpcServer, Map conf, HiveConf hiveConf, + protected AbstractSparkClient(RpcServer rpcServer, Map conf, HiveConf hiveConf, String sessionid) throws IOException { this.conf = conf; this.hiveConf = hiveConf; this.jobs = Maps.newConcurrentMap(); String secret = rpcServer.createSecret(); - this.driverThread = startDriver(rpcServer, sessionid, secret); + this.driverFuture = startDriver(rpcServer, sessionid, secret); this.protocol = new ClientProtocol(); try { @@ -116,12 +125,14 @@ errorMsg = "Error while waiting for client to connect."; } LOG.error(errorMsg, e); - driverThread.interrupt(); + driverFuture.cancel(true); try { - driverThread.join(); + driverFuture.get(); } catch (InterruptedException ie) { // Give up. 
LOG.warn("Interrupted before driver thread was finished.", ie); + } catch (ExecutionException ee) { + LOG.error("Driver thread failed", ee); } throw Throwables.propagate(e); } @@ -166,15 +177,16 @@ public void stop() { } } - long endTime = System.currentTimeMillis() + DEFAULT_SHUTDOWN_TIMEOUT; try { - driverThread.join(DEFAULT_SHUTDOWN_TIMEOUT); + driverFuture.get(DEFAULT_SHUTDOWN_TIMEOUT, TimeUnit.MILLISECONDS); + } catch (ExecutionException e) { + LOG.error("Exception while waiting for driver future to complete", e); + } catch (TimeoutException e) { + LOG.warn("Timed out shutting down remote driver, cancelling..."); + driverFuture.cancel(true); } catch (InterruptedException ie) { LOG.debug("Interrupted before driver thread was finished."); - } - if (endTime - System.currentTimeMillis() <= 0) { - LOG.warn("Timed out shutting down remote driver, interrupting..."); - driverThread.interrupt(); + driverFuture.cancel(true); } } @@ -203,26 +215,18 @@ public boolean isActive() { return isAlive && driverRpc.isActive(); } - void cancel(String jobId) { + @Override + public void cancel(String jobId) { protocol.cancel(jobId); } - private Thread startDriver(final RpcServer rpcServer, final String clientId, final String secret) - throws IOException { - Runnable runnable; + private Future startDriver(final RpcServer rpcServer, final String clientId, + final String secret) throws IOException { final String serverAddress = rpcServer.getAddress(); final String serverPort = String.valueOf(rpcServer.getPort()); - // If a Spark installation is provided, use the spark-submit script. Otherwise, call the - // SparkSubmit class directly, which has some caveats (like having to provide a proper - // version of Guava on the classpath depending on the deploy mode). - String sparkHome = Strings.emptyToNull(conf.get(SPARK_HOME_KEY)); - if (sparkHome == null) { - sparkHome = Strings.emptyToNull(System.getenv(SPARK_HOME_ENV)); - } - if (sparkHome == null) { - sparkHome = Strings.emptyToNull(System.getProperty(SPARK_HOME_KEY)); - } + String sparkHome = getSparkHome(); + String sparkLogDir = conf.get("hive.spark.log.dir"); if (sparkLogDir == null) { if (sparkHome == null) { @@ -306,66 +310,22 @@ private Thread startDriver(final RpcServer rpcServer, final String clientId, fin // SparkSubmit will take care of that for us. 
String master = conf.get("spark.master"); Preconditions.checkArgument(master != null, "spark.master is not defined."); - String deployMode = conf.get("spark.submit.deployMode"); - - List argv = Lists.newLinkedList(); - - if (sparkHome != null) { - argv.add(new File(sparkHome, "bin/spark-submit").getAbsolutePath()); - } else { - LOG.info("No spark.home provided, calling SparkSubmit directly."); - argv.add(new File(System.getProperty("java.home"), "bin/java").getAbsolutePath()); - - if (master.startsWith("local") || master.startsWith("mesos") || - SparkClientUtilities.isYarnClientMode(master, deployMode) || - master.startsWith("spark")) { - String mem = conf.get("spark.driver.memory"); - if (mem != null) { - argv.add("-Xms" + mem); - argv.add("-Xmx" + mem); - } - - String cp = conf.get("spark.driver.extraClassPath"); - if (cp != null) { - argv.add("-classpath"); - argv.add(cp); - } - - String libPath = conf.get("spark.driver.extraLibPath"); - if (libPath != null) { - argv.add("-Djava.library.path=" + libPath); - } - - String extra = conf.get(DRIVER_OPTS_KEY); - if (extra != null) { - for (String opt : extra.split("[ ]")) { - if (!opt.trim().isEmpty()) { - argv.add(opt.trim()); - } - } - } - } - - argv.add("org.apache.spark.deploy.SparkSubmit"); - } + String deployMode = conf.get(SPARK_DEPLOY_MODE); if (SparkClientUtilities.isYarnClusterMode(master, deployMode)) { String executorCores = conf.get("spark.executor.cores"); if (executorCores != null) { - argv.add("--executor-cores"); - argv.add(executorCores); + addExecutorCores(executorCores); } String executorMemory = conf.get("spark.executor.memory"); if (executorMemory != null) { - argv.add("--executor-memory"); - argv.add(executorMemory); + addExecutorMemory(executorMemory); } String numOfExecutors = conf.get("spark.executor.instances"); if (numOfExecutors != null) { - argv.add("--num-executors"); - argv.add(numOfExecutors); + addNumExecutors(numOfExecutors); } } // The options --principal/--keypad do not work with --proxy-user in spark-submit.sh @@ -378,24 +338,9 @@ private Thread startDriver(final RpcServer rpcServer, final String clientId, fin String principal = SecurityUtil.getServerPrincipal(hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL), "0.0.0.0"); String keyTabFile = hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB); + boolean isDoAsEnabled = hiveConf.getBoolVar(HiveConf.ConfVars.HIVE_SERVER2_ENABLE_DOAS); if (StringUtils.isNotBlank(principal) && StringUtils.isNotBlank(keyTabFile)) { - if (hiveConf.getBoolVar(HiveConf.ConfVars.HIVE_SERVER2_ENABLE_DOAS)) { - List kinitArgv = Lists.newLinkedList(); - kinitArgv.add("kinit"); - kinitArgv.add(principal); - kinitArgv.add("-k"); - kinitArgv.add("-t"); - kinitArgv.add(keyTabFile + ";"); - kinitArgv.addAll(argv); - argv = kinitArgv; - } else { - // if doAs is not enabled, we pass the principal/keypad to spark-submit in order to - // support the possible delegation token renewal in Spark - argv.add("--principal"); - argv.add(principal); - argv.add("--keytab"); - argv.add(keyTabFile); - } + addKeytabAndPrincipal(isDoAsEnabled, keyTabFile, principal); } } if (hiveConf.getBoolVar(HiveConf.ConfVars.HIVE_SERVER2_ENABLE_DOAS)) { @@ -404,8 +349,7 @@ private Thread startDriver(final RpcServer rpcServer, final String clientId, fin // do not do impersonation in CLI mode if (!currentUser.equals(System.getProperty("user.name"))) { LOG.info("Attempting impersonation of " + currentUser); - argv.add("--proxy-user"); - argv.add(currentUser); + addProxyUser(currentUser); } } catch (Exception e) 
{ String msg = "Cannot obtain username: " + e; @@ -415,109 +359,61 @@ private Thread startDriver(final RpcServer rpcServer, final String clientId, fin String regStr = conf.get("spark.kryo.registrator"); if (HIVE_KRYO_REG_NAME.equals(regStr)) { - argv.add("--jars"); - argv.add(SparkClientUtilities.findKryoRegistratorJar(hiveConf)); + addJars(SparkClientUtilities.findKryoRegistratorJar(hiveConf)); } - argv.add("--properties-file"); - argv.add(properties.getAbsolutePath()); - argv.add("--class"); - argv.add(RemoteDriver.class.getName()); + addPropertiesFile(properties.getAbsolutePath()); + addClass(RemoteDriver.class.getName()); String jar = "spark-internal"; if (SparkContext.jarOfClass(this.getClass()).isDefined()) { jar = SparkContext.jarOfClass(this.getClass()).get(); } - argv.add(jar); + addExecutableJar(jar); + - argv.add(RemoteDriver.REMOTE_DRIVER_HOST_CONF); - argv.add(serverAddress); - argv.add(RemoteDriver.REMOTE_DRIVER_PORT_CONF); - argv.add(serverPort); + addAppArg(RemoteDriver.REMOTE_DRIVER_HOST_CONF); + addAppArg(serverAddress); + addAppArg(RemoteDriver.REMOTE_DRIVER_PORT_CONF); + addAppArg(serverPort); //hive.spark.* keys are passed down to the RemoteDriver via REMOTE_DRIVER_CONF // so that they are not used in sparkContext but only in remote driver, //as --properties-file contains the spark.* keys that are meant for SparkConf object. for (String hiveSparkConfKey : RpcConfiguration.HIVE_SPARK_RSC_CONFIGS) { String value = RpcConfiguration.getValue(hiveConf, hiveSparkConfKey); - argv.add(RemoteDriver.REMOTE_DRIVER_CONF); - argv.add(String.format("%s=%s", hiveSparkConfKey, value)); + addAppArg(RemoteDriver.REMOTE_DRIVER_CONF); + addAppArg(String.format("%s=%s", hiveSparkConfKey, value)); } - String cmd = Joiner.on(" ").join(argv); - LOG.info("Running client driver with argv: {}", cmd); - ProcessBuilder pb = new ProcessBuilder("sh", "-c", cmd); + return launchDriver(isTesting, rpcServer, clientId); + } - // Prevent hive configurations from being visible in Spark. 
- pb.environment().remove("HIVE_HOME"); - pb.environment().remove("HIVE_CONF_DIR"); - // Add credential provider password to the child process's environment - // In case of Spark the credential provider location is provided in the jobConf when the job is submitted - String password = getSparkJobCredentialProviderPassword(); - if(password != null) { - pb.environment().put(Constants.HADOOP_CREDENTIAL_PASSWORD_ENVVAR, password); - } - if (isTesting != null) { - pb.environment().put("SPARK_TESTING", isTesting); - } + protected abstract Future launchDriver(String isTesting, RpcServer rpcServer, String + clientId) throws IOException; - final Process child = pb.start(); - String threadName = Thread.currentThread().getName(); - final List childErrorLog = Collections.synchronizedList(new ArrayList()); - final LogRedirector.LogSourceCallback callback = () -> {return isAlive;}; + protected abstract String getSparkHome(); - LogRedirector.redirect("RemoteDriver-stdout-redir-" + threadName, - new LogRedirector(child.getInputStream(), LOG, callback)); - LogRedirector.redirect("RemoteDriver-stderr-redir-" + threadName, - new LogRedirector(child.getErrorStream(), LOG, childErrorLog, callback)); + protected abstract void addAppArg(String arg); - runnable = new Runnable() { - @Override - public void run() { - try { - int exitCode = child.waitFor(); - if (exitCode != 0) { - StringBuilder errStr = new StringBuilder(); - synchronized(childErrorLog) { - Iterator iter = childErrorLog.iterator(); - while(iter.hasNext()){ - errStr.append(iter.next()); - errStr.append('\n'); - } - } - - LOG.warn("Child process exited with code {}", exitCode); - rpcServer.cancelClient(clientId, - "Child process (spark-submit) exited before connecting back with error log " + errStr.toString()); - } - } catch (InterruptedException ie) { - LOG.warn("Thread waiting on the child process (spark-submit) is interrupted, killing the child process."); - rpcServer.cancelClient(clientId, "Thread waiting on the child porcess (spark-submit) is interrupted"); - Thread.interrupted(); - child.destroy(); - } catch (Exception e) { - String errMsg = "Exception while waiting for child process (spark-submit)"; - LOG.warn(errMsg, e); - rpcServer.cancelClient(clientId, errMsg); - } - } - }; + protected abstract void addExecutableJar(String jar); - Thread thread = new Thread(runnable); - thread.setDaemon(true); - thread.setName("Driver"); - thread.start(); - return thread; - } + protected abstract void addPropertiesFile(String absolutePath); - private String getSparkJobCredentialProviderPassword() { - if (conf.containsKey("spark.yarn.appMasterEnv.HADOOP_CREDSTORE_PASSWORD")) { - return conf.get("spark.yarn.appMasterEnv.HADOOP_CREDSTORE_PASSWORD"); - } else if (conf.containsKey("spark.executorEnv.HADOOP_CREDSTORE_PASSWORD")) { - return conf.get("spark.executorEnv.HADOOP_CREDSTORE_PASSWORD"); - } - return null; - } + protected abstract void addClass(String name); + + protected abstract void addJars(String jars); + + protected abstract void addProxyUser(String proxyUser); + + protected abstract void addKeytabAndPrincipal(boolean isDoAsEnabled, String keyTabFile, + String principal); + + protected abstract void addNumExecutors(String numOfExecutors); + + protected abstract void addExecutorMemory(String executorMemory); + + protected abstract void addExecutorCores(String executorCores); private class ClientProtocol extends BaseProtocol { @@ -525,7 +421,7 @@ private String getSparkJobCredentialProviderPassword() { final String jobId = UUID.randomUUID().toString(); 
final Promise promise = driverRpc.createPromise(); final JobHandleImpl handle = - new JobHandleImpl(SparkClientImpl.this, promise, jobId, listeners); + new JobHandleImpl(AbstractSparkClient.this, promise, jobId, listeners); jobs.put(jobId, handle); final io.netty.util.concurrent.Future rpc = driverRpc.call(new JobRequest(jobId, job)); diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/JobHandleImpl.java b/spark-client/src/main/java/org/apache/hive/spark/client/JobHandleImpl.java index 2881252b0e..61489a3542 100644 --- a/spark-client/src/main/java/org/apache/hive/spark/client/JobHandleImpl.java +++ b/spark-client/src/main/java/org/apache/hive/spark/client/JobHandleImpl.java @@ -34,7 +34,7 @@ */ class JobHandleImpl implements JobHandle { - private final SparkClientImpl client; + private final SparkClient client; private final String jobId; private final MetricsCollection metrics; private final Promise promise; @@ -43,8 +43,8 @@ private volatile State state; private volatile SparkCounters sparkCounters; - JobHandleImpl(SparkClientImpl client, Promise promise, String jobId, - List> listeners) { + JobHandleImpl(SparkClient client, Promise promise, String jobId, + List> listeners) { this.client = client; this.jobId = jobId; this.promise = promise; @@ -233,7 +233,7 @@ private void fireStateChange(State newState, Listener listener) { } } - /** Last attempt at preventing stray jobs from accumulating in SparkClientImpl. */ + /** Last attempt at preventing stray jobs from accumulating in SparkClient. */ @Override protected void finalize() { if (!isDone()) { diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClient.java b/spark-client/src/main/java/org/apache/hive/spark/client/SparkClient.java index 1922e412a1..913889951e 100644 --- a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClient.java +++ b/spark-client/src/main/java/org/apache/hive/spark/client/SparkClient.java @@ -110,4 +110,11 @@ * Check if remote context is still active. */ boolean isActive(); + + /** + * Cancel the specified jobId + * + * @param jobId the jobId to cancel + */ + void cancel(String jobId); } diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java b/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java index fd9b72583a..b96eafd9f4 100644 --- a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java +++ b/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java @@ -18,13 +18,11 @@ package org.apache.hive.spark.client; import java.io.IOException; -import java.io.PrintStream; import java.util.Map; import org.apache.hadoop.hive.common.classification.InterfaceAudience; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hive.spark.client.rpc.RpcServer; -import org.apache.spark.SparkException; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; @@ -82,10 +80,19 @@ public static void stop() { * @param hiveConf Configuration for Hive, contains hive.* properties. 
*/ public static SparkClient createClient(Map sparkConf, HiveConf hiveConf, - String sessionId) - throws IOException, SparkException { + String sessionId) throws IOException { Preconditions.checkState(server != null, "initialize() not called."); - return new SparkClientImpl(server, sparkConf, hiveConf, sessionId); + switch (HiveConf.SparkClientType.valueOf( + hiveConf.getVar(HiveConf.ConfVars.SPARK_CLIENT_TYPE))) { + case SPARK_SUBMIT_CLIENT: + return new SparkSubmitSparkClient(server, sparkConf, hiveConf, sessionId); + case SPARK_LAUNCHER_CLIENT: + return new SparkLauncherSparkClient(server, sparkConf, hiveConf, sessionId); + default: + throw new IllegalArgumentException("Unknown Hive on Spark launcher type " + hiveConf.getVar( + HiveConf.ConfVars.SPARK_CLIENT_TYPE) + " valid options are " + + HiveConf.SparkClientType.SPARK_SUBMIT_CLIENT + " or " + + HiveConf.SparkClientType.SPARK_LAUNCHER_CLIENT); + } } - } diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/SparkLauncherSparkClient.java b/spark-client/src/main/java/org/apache/hive/spark/client/SparkLauncherSparkClient.java new file mode 100644 index 0000000000..132c9b7976 --- /dev/null +++ b/spark-client/src/main/java/org/apache/hive/spark/client/SparkLauncherSparkClient.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *
+ * http://www.apache.org/licenses/LICENSE-2.0 + *
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hive.spark.client; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.collect.Sets; + +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hive.spark.client.rpc.RpcServer; + +import org.apache.spark.launcher.AbstractLauncher; +import org.apache.spark.launcher.InProcessLauncher; +import org.apache.spark.launcher.SparkAppHandle; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + + +/** + * Extends the {@link AbstractSparkClient} and uses Spark's + * {@link org.apache.spark.launcher.SparkLauncher} to submit the HoS application. Specifically, + * it uses the {@link InProcessLauncher} to avoid spawning a sub-process to submit the Spark app. + * It uses an implementation of {@link Future} called {@link SparkLauncherFuture} to monitor the + * lifecycle of the Spark application. This allows the {@link AbstractSparkClient} to monitor the + * status of the spark submit code. + */ +public class SparkLauncherSparkClient extends AbstractSparkClient { + + private static final Logger LOG = LoggerFactory.getLogger( + SparkLauncherSparkClient.class.getName()); + + private static final long serialVersionUID = 2153000661341457380L; + + private static final Set FAILED_SPARK_STATES = Sets.newHashSet( + SparkAppHandle.State.FAILED, + SparkAppHandle.State.KILLED, + SparkAppHandle.State.LOST); + + private AbstractLauncher sparkLauncher; + + SparkLauncherSparkClient(RpcServer rpcServer, + Map conf, + HiveConf hiveConf, + String sessionid) throws IOException { + super(rpcServer, conf, hiveConf, sessionid); + } + + @Override + protected Future launchDriver(String isTesting, RpcServer rpcServer, + String clientId) throws IOException { + if (isTesting != null) { + System.setProperty("spark.testing", "true"); + } + + // Only allow the spark.master to be local in unit tests + if (isTesting == null) { + Preconditions.checkArgument(SparkClientUtilities.isYarnClusterMode( + this.conf.get("spark.master"), this.conf.get("spark.submit.deployMode")), + getClass().getName() + " is only supported in yarn-cluster mode"); + } + + CountDownLatch shutdownLatch = new CountDownLatch(1); + getSparkLauncher(); + + return new SparkLauncherFuture(shutdownLatch, getSparkLauncher().startApplication( + new SparkAppListener(shutdownLatch, rpcServer, clientId))); + } + + @Override + protected String getSparkHome() { + return null; + } + + @Override + protected void addAppArg(String arg) { + getSparkLauncher().addAppArgs(arg); + } + + @Override + protected void addExecutableJar(String jar) { + getSparkLauncher().setAppResource(jar); + } + + @Override + protected void addPropertiesFile(String absolutePath) { + getSparkLauncher().setPropertiesFile(absolutePath); + } + + @Override + protected void addClass(String name) { + getSparkLauncher().setMainClass(name); + } + + @Override + protected void 
addJars(String jars) { + getSparkLauncher().addJar(jars); + } + + @Override + protected void addProxyUser(String proxyUser) { + throw new UnsupportedOperationException(); +// getSparkLauncher().addSparkArg("--proxy-user", proxyUser); + } + + @Override + protected void addKeytabAndPrincipal(boolean isDoAsEnabled, String keyTabFile, String principal) { + throw new UnsupportedOperationException(); +// getSparkLauncher().addSparkArg("--principal", principal); +// getSparkLauncher().addSparkArg("--keytab", keyTabFile); + } + + @Override + protected void addNumExecutors(String numOfExecutors) { + getSparkLauncher().addSparkArg("--num-executors", numOfExecutors); + } + + @Override + protected void addExecutorMemory(String executorMemory) { + getSparkLauncher().addSparkArg("--executor-memory", executorMemory); + } + + @Override + protected void addExecutorCores(String executorCores) { + getSparkLauncher().addSparkArg("--executor-cores", executorCores); + } + + private AbstractLauncher getSparkLauncher() { + if (this.sparkLauncher == null) { + this.sparkLauncher = new InProcessLauncher(); + } + return this.sparkLauncher; + } + + @VisibleForTesting + static final class SparkLauncherFuture implements Future { + + private final CountDownLatch shutdownLatch; + private final SparkAppHandle sparkAppHandle; + + @VisibleForTesting + SparkLauncherFuture(CountDownLatch shutdownLatch, SparkAppHandle sparkAppHandle) { + this.shutdownLatch = shutdownLatch; + this.sparkAppHandle = sparkAppHandle; + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + this.sparkAppHandle.stop(); + return true; + } + + @Override + public boolean isCancelled() { + return FAILED_SPARK_STATES.contains(this.sparkAppHandle.getState()); + } + + @Override + public boolean isDone() { + return this.sparkAppHandle.getState().isFinal(); + } + + @Override + public Void get() throws InterruptedException { + try { + return get(-1, null); + } catch (TimeoutException e) { + throw new IllegalStateException("Timeouts should never occur", e); + } + } + + @Override + public Void get(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException { + if (timeout > 0) { + if (this.shutdownLatch.await(timeout, unit)) { + return null; + } else { + throw new TimeoutException("Spark Application did not reach running state within " + + "allotted timeout of " + timeout + " " + unit.toString()); + } + } else { + this.shutdownLatch.await(); + } + if (isCancelled()) { + throw new CancellationException("Spark app has been cancelled, current state " + + this.sparkAppHandle.getState()); + } + return null; + } + } + + @VisibleForTesting + static final class SparkAppListener implements SparkAppHandle.Listener { + + private final CountDownLatch shutdownLatch; + private final RpcServer rpcServer; + private final String clientId; + + SparkAppListener(CountDownLatch shutdownLatch, RpcServer rpcServer, String clientId) { + this.shutdownLatch = shutdownLatch; + this.rpcServer = rpcServer; + this.clientId = clientId; + } + + @Override + public void stateChanged(SparkAppHandle sparkAppHandle) { + LOG.info("Spark app transitioned to state = " + sparkAppHandle.getState()); + if (sparkAppHandle.getState().isFinal() || sparkAppHandle.getState().equals( + SparkAppHandle.State.RUNNING)) { + this.shutdownLatch.countDown(); + // sparkAppHandle.disconnect(); + LOG.info("Disconnected from Spark app handle"); + } + if (FAILED_SPARK_STATES.contains(sparkAppHandle.getState())) { + this.rpcServer.cancelClient(this.clientId, "Spark app launcher failed, 
transitioned " + + "to state " + sparkAppHandle.getState()); + } + } + + @Override + public void infoChanged(SparkAppHandle sparkAppHandle) { + // Do nothing + } + } +} diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/SparkSubmitSparkClient.java b/spark-client/src/main/java/org/apache/hive/spark/client/SparkSubmitSparkClient.java new file mode 100644 index 0000000000..04f3fef22d --- /dev/null +++ b/spark-client/src/main/java/org/apache/hive/spark/client/SparkSubmitSparkClient.java @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hive.spark.client; + +import com.google.common.base.Joiner; +import com.google.common.base.Preconditions; +import com.google.common.base.Strings; +import com.google.common.collect.Lists; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.Future; +import java.util.concurrent.FutureTask; + +import org.apache.hadoop.hive.common.log.LogRedirector; +import org.apache.hadoop.hive.conf.Constants; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hive.spark.client.rpc.RpcServer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * Extends the {@link AbstractSparkClient} and launches a child process to run Spark's {@code + * bin/spark-submit} script. Logs are re-directed from the child process logs. 
+ */ +class SparkSubmitSparkClient extends AbstractSparkClient { + + private static final Logger LOG = LoggerFactory.getLogger(SparkSubmitSparkClient.class); + + private static final String SPARK_HOME_ENV = "SPARK_HOME"; + private static final String SPARK_HOME_KEY = "spark.home"; + + private static final long serialVersionUID = -4272763023516238171L; + + private List argv; + + SparkSubmitSparkClient(RpcServer rpcServer, Map conf, HiveConf hiveConf, + String sessionid) throws IOException { + super(rpcServer, conf, hiveConf, sessionid); + } + + @Override + protected String getSparkHome() { + String sparkHome = Strings.emptyToNull(conf.get(SPARK_HOME_KEY)); + if (sparkHome == null) { + sparkHome = Strings.emptyToNull(System.getenv(SPARK_HOME_ENV)); + } + if (sparkHome == null) { + sparkHome = Strings.emptyToNull(System.getProperty(SPARK_HOME_KEY)); + } + + Preconditions.checkNotNull(sparkHome, "Cannot use " + HiveConf.SparkClientType + .SPARK_SUBMIT_CLIENT + " without setting Spark Home"); + String master = conf.get("spark.master"); + Preconditions.checkArgument(master != null, "spark.master is not defined."); + + argv = Lists.newLinkedList(); + argv.add(new File(sparkHome, "bin/spark-submit").getAbsolutePath()); + + return sparkHome; + } + + @Override + protected void addAppArg(String arg) { + argv.add(arg); + } + + @Override + protected void addExecutableJar(String jar) { + argv.add(jar); + } + + @Override + protected void addPropertiesFile(String absolutePath) { + argv.add("--properties-file"); + argv.add(absolutePath); + } + + @Override + protected void addClass(String name) { + argv.add("--class"); + argv.add(RemoteDriver.class.getName()); + } + + @Override + protected void addJars(String jars) { + argv.add("--jars"); + argv.add(jars); + } + + @Override + protected void addProxyUser(String proxyUser) { + argv.add("--proxy-user"); + argv.add(proxyUser); + } + + @Override + protected void addKeytabAndPrincipal(boolean isDoAsEnabled, String keyTabFile, String principal) { + if (isDoAsEnabled) { + List kinitArgv = Lists.newLinkedList(); + kinitArgv.add("kinit"); + kinitArgv.add(principal); + kinitArgv.add("-k"); + kinitArgv.add("-t"); + kinitArgv.add(keyTabFile + ";"); + kinitArgv.addAll(argv); + argv = kinitArgv; + } else { + // if doAs is not enabled, we pass the principal/keypad to spark-submit in order to + // support the possible delegation token renewal in Spark + argv.add("--principal"); + argv.add(principal); + argv.add("--keytab"); + argv.add(keyTabFile); + } + } + + @Override + protected void addNumExecutors(String numOfExecutors) { + argv.add("--num-executors"); + argv.add(numOfExecutors); + } + + @Override + protected void addExecutorMemory(String executorMemory) { + argv.add("--executor-memory"); + argv.add(executorMemory); + } + + @Override + protected void addExecutorCores(String executorCores) { + argv.add("--executor-cores"); + argv.add(executorCores); + } + + private String getSparkJobCredentialProviderPassword() { + if (conf.containsKey("spark.yarn.appMasterEnv.HADOOP_CREDSTORE_PASSWORD")) { + return conf.get("spark.yarn.appMasterEnv.HADOOP_CREDSTORE_PASSWORD"); + } else if (conf.containsKey("spark.executorEnv.HADOOP_CREDSTORE_PASSWORD")) { + return conf.get("spark.executorEnv.HADOOP_CREDSTORE_PASSWORD"); + } + return null; + } + + @Override + protected Future launchDriver(String isTesting, RpcServer rpcServer, String clientId) throws + IOException { + Callable runnable; + + String cmd = Joiner.on(" ").join(argv); + LOG.info("Running client driver with argv: {}", cmd); + 
ProcessBuilder pb = new ProcessBuilder("sh", "-c", cmd); + + // Prevent hive configurations from being visible in Spark. + pb.environment().remove("HIVE_HOME"); + pb.environment().remove("HIVE_CONF_DIR"); + // Add credential provider password to the child process's environment + // In case of Spark the credential provider location is provided in the jobConf when the job is submitted + String password = getSparkJobCredentialProviderPassword(); + if(password != null) { + pb.environment().put(Constants.HADOOP_CREDENTIAL_PASSWORD_ENVVAR, password); + } + if (isTesting != null) { + pb.environment().put("SPARK_TESTING", isTesting); + } + + final Process child = pb.start(); + String threadName = Thread.currentThread().getName(); + final List childErrorLog = Collections.synchronizedList(new ArrayList()); + final LogRedirector.LogSourceCallback callback = () -> isAlive; + + LogRedirector.redirect("RemoteDriver-stdout-redir-" + threadName, + new LogRedirector(child.getInputStream(), LOG, callback)); + LogRedirector.redirect("RemoteDriver-stderr-redir-" + threadName, + new LogRedirector(child.getErrorStream(), LOG, childErrorLog, callback)); + + runnable = () -> { + try { + int exitCode = child.waitFor(); + if (exitCode != 0) { + StringBuilder errStr = new StringBuilder(); + synchronized(childErrorLog) { + for (Object aChildErrorLog : childErrorLog) { + errStr.append(aChildErrorLog); + errStr.append('\n'); + } + } + + LOG.warn("Child process exited with code {}", exitCode); + rpcServer.cancelClient(clientId, + "Child process (spark-submit) exited before connecting back with error log " + errStr.toString()); + } + } catch (InterruptedException ie) { + LOG.warn("Thread waiting on the child process (spark-submit) is interrupted, killing the child process."); + rpcServer.cancelClient(clientId, "Thread waiting on the child process (spark-submit) is interrupted"); + Thread.interrupted(); + child.destroy(); + } catch (Exception e) { + String errMsg = "Exception while waiting for child process (spark-submit)"; + LOG.warn(errMsg, e); + rpcServer.cancelClient(clientId, errMsg); + } + return null; + }; + + FutureTask futureTask = new FutureTask<>(runnable); + + Thread driverThread = new Thread(futureTask); + driverThread.setDaemon(true); + driverThread.setName("Driver"); + driverThread.start(); + + return futureTask; + } +} diff --git a/spark-client/src/test/java/org/apache/hive/spark/client/TestJobHandle.java b/spark-client/src/test/java/org/apache/hive/spark/client/TestJobHandle.java index d6b627b630..b81a34ba71 100644 --- a/spark-client/src/test/java/org/apache/hive/spark/client/TestJobHandle.java +++ b/spark-client/src/test/java/org/apache/hive/spark/client/TestJobHandle.java @@ -32,7 +32,7 @@ @RunWith(MockitoJUnitRunner.class) public class TestJobHandle { - @Mock private SparkClientImpl client; + @Mock private SparkClient client; @Mock private Promise promise; @Mock private JobHandle.Listener listener; @Mock private JobHandle.Listener listener2; diff --git a/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java b/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java index fdf882b7ae..9a2e044970 100644 --- a/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java +++ b/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java @@ -53,6 +53,7 @@ import com.google.common.io.ByteStreams; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hive.spark.counter.SparkCounters; +import org.apache.spark.SparkContext$; 
import org.apache.spark.SparkException; import org.apache.spark.SparkFiles; import org.apache.spark.api.java.JavaFutureAction; @@ -71,6 +72,8 @@ static { HIVECONF.set("hive.spark.client.connect.timeout", "30000ms"); + HIVECONF.setVar(HiveConf.ConfVars.SPARK_CLIENT_TYPE, HiveConf.SparkClientType + .SPARK_LAUNCHER_CLIENT.toString()); } private Map createConf() { @@ -81,6 +84,7 @@ conf.put("spark.app.name", "SparkClientSuite Remote App"); conf.put("spark.driver.extraClassPath", classpath); conf.put("spark.executor.extraClassPath", classpath); + conf.put("spark.testing", "true"); if (!Strings.isNullOrEmpty(System.getProperty("spark.home"))) { conf.put("spark.home", System.getProperty("spark.home")); @@ -338,6 +342,26 @@ private void runTest(TestFunction test) throws Exception { client.stop(); } SparkClientFactory.stop(); + waitForSparkContextShutdown(); + } + } + + /** + * This was added to avoid a race condition where we try to create multiple SparkContexts in + * the same process. Since spark.master = local everything is run in the same JVM. Since we + * don't wait for the RemoteDriver to shutdown it's SparkContext, its possible that we finish a + * test before the SparkContext has been shutdown. In order to avoid the multiple SparkContexts + * in a single JVM exception, we wait for the SparkContext to shutdown after each test. + */ + private void waitForSparkContextShutdown() throws InterruptedException { + for (int i = 0; i < 100; i++) { + if (SparkContext$.MODULE$.getActive().isEmpty()) { + break; + } + Thread.sleep(100); + } + if (!SparkContext$.MODULE$.getActive().isEmpty()) { + throw new IllegalStateException("SparkContext did not shutdown in time"); } } diff --git a/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkLauncherSparkClient.java b/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkLauncherSparkClient.java new file mode 100644 index 0000000000..247b5e28c1 --- /dev/null +++ b/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkLauncherSparkClient.java @@ -0,0 +1,99 @@ +package org.apache.hive.spark.client; + +import com.google.common.collect.ImmutableMap; + +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hive.spark.client.rpc.RpcServer; +import org.apache.spark.launcher.SparkAppHandle; + +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class TestSparkLauncherSparkClient { + + @Test + public void testSparkLauncherLocal() throws IOException { + HiveConf hiveConf = new HiveConf(); + hiveConf.setVar(HiveConf.ConfVars.SPARK_CLIENT_TYPE, + HiveConf.SparkClientType.SPARK_LAUNCHER_CLIENT.toString()); + + ImmutableMap.Builder sparkConfBuilder = new ImmutableMap.Builder<>(); + + sparkConfBuilder.put("spark.master", "local"); + sparkConfBuilder.put("spark.app.name", "TestSparkLauncherSparkClient App"); + sparkConfBuilder.put("spark.testing", "true"); + + Map sparkConf = sparkConfBuilder.build(); + + SparkClientFactory.initialize(sparkConf); + SparkClient client = null; + try { + client = SparkClientFactory.createClient(sparkConf, hiveConf, UUID.randomUUID().toString()); + client.submit(jc -> jc.sc().parallelize(Arrays.asList(1, 2, 3, 4, 5)).count()); + } finally { + if (client != null) { + 
client.stop(); + } + SparkClientFactory.stop(); + } + } + + @Test + public void testSparkLauncherFutureGet() { + testChainOfStates(SparkAppHandle.State.CONNECTED, SparkAppHandle.State.SUBMITTED, + SparkAppHandle.State.RUNNING); + testChainOfStates(SparkAppHandle.State.CONNECTED, SparkAppHandle.State.SUBMITTED, + SparkAppHandle.State.FINISHED); + testChainOfStates(SparkAppHandle.State.CONNECTED, SparkAppHandle.State.SUBMITTED, + SparkAppHandle.State.FAILED); + testChainOfStates(SparkAppHandle.State.CONNECTED, SparkAppHandle.State.SUBMITTED, + SparkAppHandle.State.KILLED); + + testChainOfStates(SparkAppHandle.State.LOST); + testChainOfStates(SparkAppHandle.State.CONNECTED, SparkAppHandle.State.LOST); + testChainOfStates(SparkAppHandle.State.CONNECTED, SparkAppHandle.State.SUBMITTED, + SparkAppHandle.State.LOST); + } + + private void testChainOfStates(SparkAppHandle.State... states) { + SparkAppHandle sparkAppHandle = mock(SparkAppHandle.class); + + CountDownLatch shutdownLatch = new CountDownLatch(1); + + SparkLauncherSparkClient.SparkAppListener sparkAppListener = new SparkLauncherSparkClient.SparkAppListener( + shutdownLatch, mock(RpcServer.class), ""); + SparkLauncherSparkClient.SparkLauncherFuture sparkLauncherFuture = new SparkLauncherSparkClient.SparkLauncherFuture( + shutdownLatch, sparkAppHandle); + + CompletableFuture future = CompletableFuture.runAsync(() -> { + try { + sparkLauncherFuture.get(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + + for (int i = 0; i < states.length - 1; i++) { + when(sparkAppHandle.getState()).thenReturn(states[i]); + sparkAppListener.stateChanged(sparkAppHandle); + Assert.assertTrue(!future.isDone()); + } + + when(sparkAppHandle.getState()).thenReturn(states[states.length - 1]); + sparkAppListener.stateChanged(sparkAppHandle); + Assert.assertTrue(future.isDone()); + Assert.assertEquals(shutdownLatch.getCount(), 0); + verify(sparkAppHandle).disconnect(); + } +}
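
For reference, the sketch below illustrates how the SPARK_LAUNCHER_CLIENT path added by this patch is driven, modeled on SparkLauncherSparkClient#launchDriver. It is a hypothetical example, not part of the patch: the properties-file path, host, and port values are placeholders, and the real client generates the properties file itself and passes RemoteDriver.REMOTE_DRIVER_HOST_CONF / REMOTE_DRIVER_PORT_CONF with the RpcServer address and port as app arguments. Only Spark launcher API calls that the patch itself exercises are used.

// Hypothetical sketch of the SPARK_LAUNCHER_CLIENT launch path; marked values are placeholders.
import java.io.IOException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.apache.spark.launcher.InProcessLauncher;
import org.apache.spark.launcher.SparkAppHandle;

public class InProcessLaunchSketch {

  public static void main(String[] args) throws IOException, InterruptedException {
    final CountDownLatch started = new CountDownLatch(1);

    // The properties file is assumed to contain spark.master=yarn and
    // spark.submit.deployMode=cluster, mirroring what AbstractSparkClient writes out.
    SparkAppHandle handle = new InProcessLauncher()
        .setPropertiesFile("/tmp/spark-submit.placeholder.properties") // placeholder path
        .setMainClass("org.apache.hive.spark.client.RemoteDriver")
        .setAppResource("spark-internal") // same marker jar name the patch uses
        .addAppArgs("--remote-host", "hs2-host", "--remote-port", "12345") // placeholder args
        .startApplication(new SparkAppHandle.Listener() {
          @Override
          public void stateChanged(SparkAppHandle h) {
            // Like SparkAppListener: unblock once the app is RUNNING or has reached a final state.
            if (h.getState() == SparkAppHandle.State.RUNNING || h.getState().isFinal()) {
              started.countDown();
            }
          }

          @Override
          public void infoChanged(SparkAppHandle h) {
            // no-op, as in the patch
          }
        });

    // SparkLauncherFuture blocks on the same kind of latch; here we simply time out and stop.
    if (!started.await(60, TimeUnit.SECONDS)) {
      handle.stop();
      throw new IllegalStateException("Spark application did not start within 60 seconds");
    }
  }
}

Note that this path needs spark-yarn on the classpath (hence the new runtime dependency in spark-client/pom.xml) and is enabled with set hive.spark.client.type=SPARK_LAUNCHER_CLIENT;, as exercised by spark_in_process_launcher.q.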