commit 69d398723e01f6bcb5e423f352dc062e1fde0752
Author: Sahil Takiar
Date:   Tue Jan 2 13:35:13 2018 -0800

    HIVE-16484: Investigate SparkLauncher for HoS as alternative to bin/spark-submit

diff --git a/common/src/java/org/apache/hadoop/hive/common/ProcessRunner.java b/common/src/java/org/apache/hadoop/hive/common/ProcessRunner.java
new file mode 100644
index 0000000000..ff7658fd3c
--- /dev/null
+++ b/common/src/java/org/apache/hadoop/hive/common/ProcessRunner.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.common;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import com.google.common.base.Joiner;
+
+import org.slf4j.Logger;
+
+
+public class ProcessRunner {
+
+  private static final long MAX_ERR_LOG_LINES = 1000;
+  private final Logger log;
+
+  public ProcessRunner(Logger log) {
+    this.log = log;
+  }
+
+  public void run(List<String> args) throws IOException {
+    String cmd = Joiner.on(" ").join(args);
+    log.info("Running command with argv: {}", cmd);
+    ProcessBuilder pb = new ProcessBuilder("sh", "-c", cmd);
+
+    final Process child = pb.start();
+    final List<String> childErrorLog = Collections.synchronizedList(new ArrayList<String>());
+
+    String threadName = Thread.currentThread().getName();
+    redirect("stdout-redir-" + threadName, new Redirector(child.getInputStream()));
+    redirect("stderr-redir-" + threadName, new Redirector(child.getErrorStream(), childErrorLog));
+
+    try {
+      int exitCode = child.waitFor();
+      if (exitCode != 0) {
+        StringBuilder errStr = new StringBuilder();
+        synchronized(childErrorLog) {
+          for (String aChildErrorLog : childErrorLog) {
+            errStr.append(aChildErrorLog);
+            errStr.append('\n');
+          }
+        }
+        log.warn("Child process exited with error log " + errStr.toString());
+        log.warn("Child process exited with code {}", exitCode);
+      }
+    } catch (InterruptedException ie) {
+      log.warn("Waiting thread interrupted, killing child process.");
+      Thread.interrupted();
+      child.destroy();
+    } catch (Exception e) {
+      log.warn("Exception while waiting for child process.", e);
+    }
+  }
+
+  private void redirect(String name, Redirector redirector) {
+    Thread thread = new Thread(redirector);
+    thread.setName(name);
+    thread.setDaemon(true);
+    thread.start();
+  }
+
+  private class Redirector implements Runnable {
+
+    private final BufferedReader in;
+    private List<String> errLogs;
+    private int numErrLogLines = 0;
+
+    Redirector(InputStream in) {
+      this.in = new BufferedReader(new InputStreamReader(in));
+    }
+
+    Redirector(InputStream in, List<String> errLogs) {
+      this.in = new BufferedReader(new InputStreamReader(in));
+      this.errLogs = errLogs;
+    }
+
+    @Override
+    public void run() {
+      try {
+        String line;
+        while ((line = in.readLine()) != null) {
+          log.info(line);
+          if (errLogs != null) {
+            if (numErrLogLines++ < MAX_ERR_LOG_LINES) {
+              errLogs.add(line);
+            }
+          }
+        }
+      } catch (Exception e) {
+        log.warn("Error in redirector thread.", e);
+      }
+    }
+  }
+}
diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java b/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java
index ede8ce9e40..4abb8da800 100644
--- a/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java
+++ b/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java
@@ -68,6 +68,13 @@
 @InterfaceAudience.Private
 public class RemoteDriver {
+  static final String REMOTE_HOST_KEY = "--remote-host";
+  static final String REMOTE_PORT_KEY = "--remote-port";
+  static final String CONF_KEY = "--conf";
+
+  private static final String CLIENT_ID_KEY = "--client-id";
+  private static final String SECRET_KEY = "--secret";
+
 
   private static final Logger LOG = LoggerFactory.getLogger(RemoteDriver.class);
 
   private final Map<String, JobWrapper<?>> activeJobs;
@@ -99,15 +106,15 @@ private RemoteDriver(String[] args) throws Exception {
     int serverPort = -1;
     for (int idx = 0; idx < args.length; idx += 2) {
       String key = args[idx];
-      if (key.equals("--remote-host")) {
+      if (key.equals(REMOTE_HOST_KEY)) {
         serverAddress = getArg(args, idx);
-      } else if (key.equals("--remote-port")) {
+      } else if (key.equals(REMOTE_PORT_KEY)) {
         serverPort = Integer.parseInt(getArg(args, idx));
-      } else if (key.equals("--client-id")) {
+      } else if (key.equals(CLIENT_ID_KEY)) {
         conf.set(SparkClientFactory.CONF_CLIENT_ID, getArg(args, idx));
-      } else if (key.equals("--secret")) {
+      } else if (key.equals(SECRET_KEY)) {
         conf.set(SparkClientFactory.CONF_KEY_SECRET, getArg(args, idx));
-      } else if (key.equals("--conf")) {
+      } else if (key.equals(CONF_KEY)) {
         String[] val = getArg(args, idx).split("[=]", 2);
         conf.set(val[0], val[1]);
       } else {
diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java b/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java
index 50c7bb20c4..c6b6494c45 100644
--- a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java
+++ b/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientFactory.java
@@ -34,9 +34,6 @@
 @InterfaceAudience.Private
 public final class SparkClientFactory {
 
-  /** Used to run the driver in-process, mostly for testing. */
-  static final String CONF_KEY_IN_PROCESS = "spark.client.do_not_use.run_driver_in_process";
-
   /** Used by client and driver to share a client ID for establishing an RPC session. */
   static final String CONF_CLIENT_ID = "spark.client.authentication.client_id";
 
@@ -84,9 +81,8 @@ public static void stop() {
    * @param hiveConf Configuration for Hive, contains hive.* properties.
*/ public static SparkClient createClient(Map sparkConf, HiveConf hiveConf, String sessionId) - throws IOException, SparkException { + throws IOException, InterruptedException { Preconditions.checkState(server != null, "initialize() not called."); return new SparkClientImpl(server, sparkConf, hiveConf, sessionId); } - } diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java b/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java index 49b7deb5ee..93e1cd9578 100644 --- a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java +++ b/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientImpl.java @@ -25,8 +25,10 @@ import com.google.common.base.Strings; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import com.google.common.collect.Sets; import com.google.common.io.Resources; import io.netty.channel.ChannelHandlerContext; @@ -42,32 +44,39 @@ import java.io.Writer; import java.net.URI; import java.net.URL; -import java.util.ArrayList; import java.util.Collections; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Properties; +import java.util.Set; import java.util.UUID; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import org.apache.commons.lang3.StringUtils; -import org.apache.hadoop.hive.common.log.LogRedirector; +import org.apache.hadoop.hive.common.ProcessRunner; import org.apache.hadoop.hive.conf.Constants; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.conf.HiveConf.ConfVars; import org.apache.hadoop.hive.shims.Utils; import org.apache.hadoop.security.SecurityUtil; + import org.apache.hive.spark.client.rpc.Rpc; import org.apache.hive.spark.client.rpc.RpcConfiguration; import org.apache.hive.spark.client.rpc.RpcServer; + import org.apache.spark.SparkContext; import org.apache.spark.SparkException; +import org.apache.spark.launcher.SparkAppHandle; +import org.apache.spark.launcher.SparkLauncher; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; + class SparkClientImpl implements SparkClient { + private static final long serialVersionUID = 1L; private static final Logger LOG = LoggerFactory.getLogger(SparkClientImpl.class); @@ -82,27 +91,30 @@ private static final String DRIVER_EXTRA_CLASSPATH = "spark.driver.extraClassPath"; private static final String EXECUTOR_EXTRA_CLASSPATH = "spark.executor.extraClassPath"; - private final Map conf; + protected final Map conf; private final HiveConf hiveConf; - private final Thread driverThread; + private final SparkAppHandle sparkAppHandle; private final Map> jobs; private final Rpc driverRpc; private final ClientProtocol protocol; private volatile boolean isAlive; SparkClientImpl(RpcServer rpcServer, Map conf, HiveConf hiveConf, - String sessionid) throws IOException, SparkException { + String sessionId) throws IOException, InterruptedException { this.conf = conf; this.hiveConf = hiveConf; this.jobs = Maps.newConcurrentMap(); String secret = rpcServer.createSecret(); - this.driverThread = startDriver(rpcServer, sessionid, secret); + String serverAddress = rpcServer.getAddress(); + String serverPort = String.valueOf(rpcServer.getPort()); + + this.sparkAppHandle = startDriver(sessionId, secret, 
serverAddress, serverPort); this.protocol = new ClientProtocol(); try { // The RPC server will take care of timeouts here. - this.driverRpc = rpcServer.registerClient(sessionid, secret, protocol).get(); + this.driverRpc = rpcServer.registerClient(sessionId, secret, protocol).get(); } catch (Throwable e) { String errorMsg = null; if (e.getCause() instanceof TimeoutException) { @@ -117,12 +129,12 @@ errorMsg = "Error while waiting for client to connect."; } LOG.error(errorMsg, e); - driverThread.interrupt(); + try { - driverThread.join(); - } catch (InterruptedException ie) { - // Give up. - LOG.warn("Interrupted before driver thread was finished.", ie); + shutdownSparkHandle(); + } catch (Throwable shutdownException) { + // LOG any exception but don't propagate it, instead propagate the exception from RPC client + LOG.error("Failed to properly shutdown the Spark job", e); } throw Throwables.propagate(e); } @@ -167,16 +179,7 @@ public void stop() { } } - long endTime = System.currentTimeMillis() + DEFAULT_SHUTDOWN_TIMEOUT; - try { - driverThread.join(DEFAULT_SHUTDOWN_TIMEOUT); - } catch (InterruptedException ie) { - LOG.debug("Interrupted before driver thread was finished."); - } - if (endTime - System.currentTimeMillis() <= 0) { - LOG.warn("Timed out shutting down remote driver, interrupting..."); - driverThread.interrupt(); - } + shutdownSparkHandle(); } @Override @@ -208,336 +211,203 @@ void cancel(String jobId) { protocol.cancel(jobId); } - private Thread startDriver(final RpcServer rpcServer, final String clientId, final String secret) - throws IOException { - Runnable runnable; - final String serverAddress = rpcServer.getAddress(); - final String serverPort = String.valueOf(rpcServer.getPort()); - - if (conf.containsKey(SparkClientFactory.CONF_KEY_IN_PROCESS)) { - // Mostly for testing things quickly. Do not do this in production. - // when invoked in-process it inherits the environment variables of the parent - LOG.warn("!!!! Running remote driver in-process. !!!!"); - runnable = new Runnable() { - @Override - public void run() { - List args = Lists.newArrayList(); - args.add("--remote-host"); - args.add(serverAddress); - args.add("--remote-port"); - args.add(serverPort); - args.add("--client-id"); - args.add(clientId); - args.add("--secret"); - args.add(secret); - - for (Map.Entry e : conf.entrySet()) { - args.add("--conf"); - args.add(String.format("%s=%s", e.getKey(), conf.get(e.getKey()))); - } - try { - RemoteDriver.main(args.toArray(new String[args.size()])); - } catch (Exception e) { - LOG.error("Error running driver.", e); - } - } - }; - } else { - // If a Spark installation is provided, use the spark-submit script. Otherwise, call the - // SparkSubmit class directly, which has some caveats (like having to provide a proper - // version of Guava on the classpath depending on the deploy mode). 
- String sparkHome = Strings.emptyToNull(conf.get(SPARK_HOME_KEY)); - if (sparkHome == null) { - sparkHome = Strings.emptyToNull(System.getenv(SPARK_HOME_ENV)); - } + protected SparkAppHandle startDriver(final String clientId, final String secret, final String serverAddress, + final String serverPort) throws IOException, InterruptedException { + + String sparkHome = Strings.emptyToNull(conf.get(SPARK_HOME_KEY)); + if (sparkHome == null) { + sparkHome = Strings.emptyToNull(System.getenv(SPARK_HOME_ENV)); + } + if (sparkHome == null) { + sparkHome = Strings.emptyToNull(System.getProperty(SPARK_HOME_KEY)); + } + String sparkLogDir = conf.get("hive.spark.log.dir"); + if (sparkLogDir == null) { if (sparkHome == null) { - sparkHome = Strings.emptyToNull(System.getProperty(SPARK_HOME_KEY)); - } - String sparkLogDir = conf.get("hive.spark.log.dir"); - if (sparkLogDir == null) { - if (sparkHome == null) { - sparkLogDir = "./target/"; - } else { - sparkLogDir = sparkHome + "/logs/"; - } + sparkLogDir = "./target/"; + } else { + sparkLogDir = sparkHome + "/logs/"; } + } - String osxTestOpts = ""; - if (Strings.nullToEmpty(System.getProperty("os.name")).toLowerCase().contains("mac")) { - osxTestOpts = Strings.nullToEmpty(System.getenv(OSX_TEST_OPTS)); - } + String osxTestOpts = ""; + if (Strings.nullToEmpty(System.getProperty("os.name")).toLowerCase().contains("mac")) { + osxTestOpts = Strings.nullToEmpty(System.getenv(OSX_TEST_OPTS)); + } - String driverJavaOpts = Joiner.on(" ").skipNulls().join( - "-Dhive.spark.log.dir=" + sparkLogDir, osxTestOpts, conf.get(DRIVER_OPTS_KEY)); - String executorJavaOpts = Joiner.on(" ").skipNulls().join( - "-Dhive.spark.log.dir=" + sparkLogDir, osxTestOpts, conf.get(EXECUTOR_OPTS_KEY)); - - // Create a file with all the job properties to be read by spark-submit. Change the - // file's permissions so that only the owner can read it. This avoid having the - // connection secret show up in the child process's command line. - File properties = File.createTempFile("spark-submit.", ".properties"); - if (!properties.setReadable(false) || !properties.setReadable(true, true)) { - throw new IOException("Cannot change permissions of job properties file."); - } - properties.deleteOnExit(); + String driverJavaOpts = Joiner.on(" ").skipNulls().join( + "-Dhive.spark.log.dir=" + sparkLogDir, osxTestOpts, conf.get(DRIVER_OPTS_KEY)); + String executorJavaOpts = Joiner.on(" ").skipNulls().join( + "-Dhive.spark.log.dir=" + sparkLogDir, osxTestOpts, conf.get(EXECUTOR_OPTS_KEY)); + + // Create a file with all the job properties to be read by spark-submit. Change the + // file's permissions so that only the owner can read it. This avoid having the + // connection secret show up in the child process's command line. 
+ File properties = File.createTempFile("spark-submit.", ".properties"); + if (!properties.setReadable(false) || !properties.setReadable(true, true)) { + throw new IOException("Cannot change permissions of job properties file."); + } + properties.deleteOnExit(); - Properties allProps = new Properties(); - // first load the defaults from spark-defaults.conf if available - try { - URL sparkDefaultsUrl = Thread.currentThread().getContextClassLoader().getResource("spark-defaults.conf"); - if (sparkDefaultsUrl != null) { - LOG.info("Loading spark defaults: " + sparkDefaultsUrl); - allProps.load(new ByteArrayInputStream(Resources.toByteArray(sparkDefaultsUrl))); - } - } catch (Exception e) { - String msg = "Exception trying to load spark-defaults.conf: " + e; - throw new IOException(msg, e); - } - // then load the SparkClientImpl config - for (Map.Entry e : conf.entrySet()) { - allProps.put(e.getKey(), conf.get(e.getKey())); + Properties allProps = new Properties(); + // first load the defaults from spark-defaults.conf if available + try { + URL sparkDefaultsUrl = Thread.currentThread().getContextClassLoader().getResource("spark-defaults.conf"); + if (sparkDefaultsUrl != null) { + LOG.info("Loading spark defaults: " + sparkDefaultsUrl); + allProps.load(new ByteArrayInputStream(Resources.toByteArray(sparkDefaultsUrl))); } - allProps.put(SparkClientFactory.CONF_CLIENT_ID, clientId); - allProps.put(SparkClientFactory.CONF_KEY_SECRET, secret); - allProps.put(DRIVER_OPTS_KEY, driverJavaOpts); - allProps.put(EXECUTOR_OPTS_KEY, executorJavaOpts); - - String isTesting = conf.get("spark.testing"); - if (isTesting != null && isTesting.equalsIgnoreCase("true")) { - String hiveHadoopTestClasspath = Strings.nullToEmpty(System.getenv("HIVE_HADOOP_TEST_CLASSPATH")); - if (!hiveHadoopTestClasspath.isEmpty()) { - String extraDriverClasspath = Strings.nullToEmpty((String)allProps.get(DRIVER_EXTRA_CLASSPATH)); - if (extraDriverClasspath.isEmpty()) { - allProps.put(DRIVER_EXTRA_CLASSPATH, hiveHadoopTestClasspath); - } else { - extraDriverClasspath = extraDriverClasspath.endsWith(File.pathSeparator) ? extraDriverClasspath : extraDriverClasspath + File.pathSeparator; - allProps.put(DRIVER_EXTRA_CLASSPATH, extraDriverClasspath + hiveHadoopTestClasspath); - } - - String extraExecutorClasspath = Strings.nullToEmpty((String)allProps.get(EXECUTOR_EXTRA_CLASSPATH)); - if (extraExecutorClasspath.isEmpty()) { - allProps.put(EXECUTOR_EXTRA_CLASSPATH, hiveHadoopTestClasspath); - } else { - extraExecutorClasspath = extraExecutorClasspath.endsWith(File.pathSeparator) ? 
extraExecutorClasspath : extraExecutorClasspath + File.pathSeparator; - allProps.put(EXECUTOR_EXTRA_CLASSPATH, extraExecutorClasspath + hiveHadoopTestClasspath); - } + } catch (Exception e) { + String msg = "Exception trying to load spark-defaults.conf: " + e; + throw new IOException(msg, e); + } + // then load the SparkClientImpl config + for (Map.Entry e : conf.entrySet()) { + allProps.put(e.getKey(), conf.get(e.getKey())); + } + allProps.put(SparkClientFactory.CONF_CLIENT_ID, clientId); + allProps.put(SparkClientFactory.CONF_KEY_SECRET, secret); + allProps.put(DRIVER_OPTS_KEY, driverJavaOpts); + allProps.put(EXECUTOR_OPTS_KEY, executorJavaOpts); + + String isTesting = conf.get("spark.testing"); + if (isTesting != null && isTesting.equalsIgnoreCase("true")) { + String hiveHadoopTestClasspath = Strings.nullToEmpty(System.getenv("HIVE_HADOOP_TEST_CLASSPATH")); + if (!hiveHadoopTestClasspath.isEmpty()) { + String extraDriverClasspath = Strings.nullToEmpty((String) allProps.get(DRIVER_EXTRA_CLASSPATH)); + if (extraDriverClasspath.isEmpty()) { + allProps.put(DRIVER_EXTRA_CLASSPATH, hiveHadoopTestClasspath); + } else { + extraDriverClasspath = extraDriverClasspath.endsWith( + File.pathSeparator) ? extraDriverClasspath : extraDriverClasspath + File.pathSeparator; + allProps.put(DRIVER_EXTRA_CLASSPATH, extraDriverClasspath + hiveHadoopTestClasspath); } - } - Writer writer = new OutputStreamWriter(new FileOutputStream(properties), Charsets.UTF_8); - try { - allProps.store(writer, "Spark Context configuration"); - } finally { - writer.close(); + String extraExecutorClasspath = Strings.nullToEmpty((String) allProps.get(EXECUTOR_EXTRA_CLASSPATH)); + if (extraExecutorClasspath.isEmpty()) { + allProps.put(EXECUTOR_EXTRA_CLASSPATH, hiveHadoopTestClasspath); + } else { + extraExecutorClasspath = extraExecutorClasspath.endsWith( + File.pathSeparator) ? extraExecutorClasspath : extraExecutorClasspath + File.pathSeparator; + allProps.put(EXECUTOR_EXTRA_CLASSPATH, extraExecutorClasspath + hiveHadoopTestClasspath); + } } + } - // Define how to pass options to the child process. If launching in client (or local) - // mode, the driver options need to be passed directly on the command line. Otherwise, - // SparkSubmit will take care of that for us. - String master = conf.get("spark.master"); - Preconditions.checkArgument(master != null, "spark.master is not defined."); - String deployMode = conf.get("spark.submit.deployMode"); - - List argv = Lists.newLinkedList(); - - if (sparkHome != null) { - argv.add(new File(sparkHome, "bin/spark-submit").getAbsolutePath()); - } else { - LOG.info("No spark.home provided, calling SparkSubmit directly."); - argv.add(new File(System.getProperty("java.home"), "bin/java").getAbsolutePath()); - - if (master.startsWith("local") || master.startsWith("mesos") || - SparkClientUtilities.isYarnClientMode(master, deployMode) || - master.startsWith("spark")) { - String mem = conf.get("spark.driver.memory"); - if (mem != null) { - argv.add("-Xms" + mem); - argv.add("-Xmx" + mem); - } - - String cp = conf.get("spark.driver.extraClassPath"); - if (cp != null) { - argv.add("-classpath"); - argv.add(cp); - } + try (Writer writer = new OutputStreamWriter(new FileOutputStream(properties), Charsets.UTF_8)) { + allProps.store(writer, "Spark Context configuration"); + } - String libPath = conf.get("spark.driver.extraLibPath"); - if (libPath != null) { - argv.add("-Djava.library.path=" + libPath); - } + // Define how to pass options to the child process. 
If launching in client (or local) + // mode, the driver options need to be passed directly on the command line. Otherwise, + // SparkSubmit will take care of that for us. + String master = conf.get("spark.master"); + Preconditions.checkArgument(master != null, "spark.master is not defined."); + String deployMode = conf.get("spark.submit.deployMode"); + + // Add credential provider password to the child process's environment + // In case of Spark the credential provider location is provided in the jobConf when the job is submitted + String password = getSparkJobCredentialProviderPassword(); + ImmutableMap.Builder env = new ImmutableMap.Builder<>(); + if (password != null) { + env.put(Constants.HADOOP_CREDENTIAL_PASSWORD_ENVVAR, password); + } + if (isTesting != null) { + env.put("SPARK_TESTING", isTesting); + } - String extra = conf.get(DRIVER_OPTS_KEY); - if (extra != null) { - for (String opt : extra.split("[ ]")) { - if (!opt.trim().isEmpty()) { - argv.add(opt.trim()); - } - } - } - } + SparkLauncher sparkLauncher = new SparkLauncher(env.build()); + if (sparkHome != null) { + sparkLauncher.setSparkHome(sparkHome); + } - argv.add("org.apache.spark.deploy.SparkSubmit"); + if (SparkClientUtilities.isYarnClusterMode(master, deployMode)) { + String executorCores = conf.get("spark.executor.cores"); + if (executorCores != null) { + sparkLauncher.addSparkArg("--executor-cores", executorCores); } - if (SparkClientUtilities.isYarnClusterMode(master, deployMode)) { - String executorCores = conf.get("spark.executor.cores"); - if (executorCores != null) { - argv.add("--executor-cores"); - argv.add(executorCores); - } - - String executorMemory = conf.get("spark.executor.memory"); - if (executorMemory != null) { - argv.add("--executor-memory"); - argv.add(executorMemory); - } - - String numOfExecutors = conf.get("spark.executor.instances"); - if (numOfExecutors != null) { - argv.add("--num-executors"); - argv.add(numOfExecutors); - } + String executorMemory = conf.get("spark.executor.memory"); + if (executorMemory != null) { + sparkLauncher.addSparkArg("--executor-memory", executorMemory); } - // The options --principal/--keypad do not work with --proxy-user in spark-submit.sh - // (see HIVE-15485, SPARK-5493, SPARK-19143), so Hive could only support doAs or - // delegation token renewal, but not both. Since doAs is a more common case, if both - // are needed, we choose to favor doAs. So when doAs is enabled, we use kinit command, - // otherwise, we pass the principal/keypad to spark to support the token renewal for - // long-running application. 
- if ("kerberos".equals(hiveConf.get(HADOOP_SECURITY_AUTHENTICATION))) { - String principal = SecurityUtil.getServerPrincipal(hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL), - "0.0.0.0"); - String keyTabFile = hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB); - if (StringUtils.isNotBlank(principal) && StringUtils.isNotBlank(keyTabFile)) { - if (hiveConf.getBoolVar(HiveConf.ConfVars.HIVE_SERVER2_ENABLE_DOAS)) { - List kinitArgv = Lists.newLinkedList(); - kinitArgv.add("kinit"); - kinitArgv.add(principal); - kinitArgv.add("-k"); - kinitArgv.add("-t"); - kinitArgv.add(keyTabFile + ";"); - kinitArgv.addAll(argv); - argv = kinitArgv; - } else { - // if doAs is not enabled, we pass the principal/keypad to spark-submit in order to - // support the possible delegation token renewal in Spark - argv.add("--principal"); - argv.add(principal); - argv.add("--keytab"); - argv.add(keyTabFile); - } - } + + String numOfExecutors = conf.get("spark.executor.instances"); + if (numOfExecutors != null) { + sparkLauncher.addSparkArg("--num-executors", numOfExecutors); } + } + // The options --principal/--keypad do not work with --proxy-user in spark-submit.sh + // (see HIVE-15485, SPARK-5493, SPARK-19143), so Hive could only support doAs or + // delegation token renewal, but not both. Since doAs is a more common case, if both + // are needed, we choose to favor doAs. So when doAs is enabled, we use kinit command, + // otherwise, we pass the principal/keypad to spark to support the token renewal for + // long-running application. + if ("kerberos".equals(hiveConf.get(HADOOP_SECURITY_AUTHENTICATION))) { + String principal = SecurityUtil.getServerPrincipal(hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL), + "0.0.0.0"); + String keyTabFile = hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB); if (hiveConf.getBoolVar(HiveConf.ConfVars.HIVE_SERVER2_ENABLE_DOAS)) { - try { - String currentUser = Utils.getUGI().getShortUserName(); - // do not do impersonation in CLI mode - if (!currentUser.equals(System.getProperty("user.name"))) { - LOG.info("Attempting impersonation of " + currentUser); - argv.add("--proxy-user"); - argv.add(currentUser); - } - } catch (Exception e) { - String msg = "Cannot obtain username: " + e; - throw new IllegalStateException(msg, e); - } + runKinit(principal, keyTabFile); + } else { + // if doAs is not enabled, we pass the principal/keypad to spark-submit in order to + // support the possible delegation token renewal in Spark + sparkLauncher.addSparkArg("--principal", principal); + sparkLauncher.addSparkArg("--keytab", keyTabFile); } - - String regStr = conf.get("spark.kryo.registrator"); - if (HIVE_KRYO_REG_NAME.equals(regStr)) { - argv.add("--jars"); - argv.add(SparkClientUtilities.findKryoRegistratorJar(hiveConf)); + } + if (hiveConf.getBoolVar(HiveConf.ConfVars.HIVE_SERVER2_ENABLE_DOAS)) { + try { + String currentUser = Utils.getUGI().getShortUserName(); + // do not do impersonation in CLI mode + if (!currentUser.equals(System.getProperty("user.name"))) { + LOG.info("Attempting impersonation of " + currentUser); + sparkLauncher.addSparkArg("--proxy-user", currentUser); + } + } catch (Exception e) { + String msg = "Cannot obtain username: " + e; + throw new IllegalStateException(msg, e); } + } - argv.add("--properties-file"); - argv.add(properties.getAbsolutePath()); - argv.add("--class"); - argv.add(RemoteDriver.class.getName()); - - String jar = "spark-internal"; - if (SparkContext.jarOfClass(this.getClass()).isDefined()) { - jar = 
SparkContext.jarOfClass(this.getClass()).get(); - } - argv.add(jar); - - argv.add("--remote-host"); - argv.add(serverAddress); - argv.add("--remote-port"); - argv.add(serverPort); - - //hive.spark.* keys are passed down to the RemoteDriver via --conf, - //as --properties-file contains the spark.* keys that are meant for SparkConf object. - for (String hiveSparkConfKey : RpcConfiguration.HIVE_SPARK_RSC_CONFIGS) { - String value = RpcConfiguration.getValue(hiveConf, hiveSparkConfKey); - argv.add("--conf"); - argv.add(String.format("%s=%s", hiveSparkConfKey, value)); - } + String regStr = conf.get("spark.kryo.registrator"); + if (HIVE_KRYO_REG_NAME.equals(regStr)) { + sparkLauncher.addJar(SparkClientUtilities.findKryoRegistratorJar(hiveConf)); + } - String cmd = Joiner.on(" ").join(argv); - LOG.info("Running client driver with argv: {}", cmd); - ProcessBuilder pb = new ProcessBuilder("sh", "-c", cmd); - - // Prevent hive configurations from being visible in Spark. - pb.environment().remove("HIVE_HOME"); - pb.environment().remove("HIVE_CONF_DIR"); - // Add credential provider password to the child process's environment - // In case of Spark the credential provider location is provided in the jobConf when the job is submitted - String password = getSparkJobCredentialProviderPassword(); - if(password != null) { - pb.environment().put(Constants.HADOOP_CREDENTIAL_PASSWORD_ENVVAR, password); - } - if (isTesting != null) { - pb.environment().put("SPARK_TESTING", isTesting); - } + sparkLauncher.setPropertiesFile(properties.getAbsolutePath()); + sparkLauncher.setMainClass(RemoteDriver.class.getName()); - final Process child = pb.start(); - String threadName = Thread.currentThread().getName(); - final List childErrorLog = Collections.synchronizedList(new ArrayList()); - final LogRedirector.LogSourceCallback callback = () -> {return isAlive;}; + String jar = "spark-internal"; + if (SparkContext.jarOfClass(this.getClass()).isDefined()) { + jar = SparkContext.jarOfClass(this.getClass()).get(); + } + sparkLauncher.setAppResource(jar); - LogRedirector.redirect("RemoteDriver-stdout-redir-" + threadName, - new LogRedirector(child.getInputStream(), LOG, callback)); - LogRedirector.redirect("RemoteDriver-stderr-redir-" + threadName, - new LogRedirector(child.getErrorStream(), LOG, childErrorLog, callback)); + sparkLauncher.addAppArgs(RemoteDriver.REMOTE_HOST_KEY, serverAddress); + sparkLauncher.addAppArgs(RemoteDriver.REMOTE_PORT_KEY, serverPort); - runnable = new Runnable() { - @Override - public void run() { - try { - int exitCode = child.waitFor(); - if (exitCode != 0) { - StringBuilder errStr = new StringBuilder(); - synchronized(childErrorLog) { - Iterator iter = childErrorLog.iterator(); - while(iter.hasNext()){ - errStr.append(iter.next()); - errStr.append('\n'); - } - } - - LOG.warn("Child process exited with code {}", exitCode); - rpcServer.cancelClient(clientId, - "Child process (spark-submit) exited before connecting back with error log " + errStr.toString()); - } - } catch (InterruptedException ie) { - LOG.warn("Thread waiting on the child process (spark-submit) is interrupted, killing the child process."); - rpcServer.cancelClient(clientId, "Thread waiting on the child porcess (spark-submit) is interrupted"); - Thread.interrupted(); - child.destroy(); - } catch (Exception e) { - String errMsg = "Exception while waiting for child process (spark-submit)"; - LOG.warn(errMsg, e); - rpcServer.cancelClient(clientId, errMsg); - } - } - }; + //hive.spark.* keys are passed down to the RemoteDriver via 
--conf, + //as --properties-file contains the spark.* keys that are meant for SparkConf object. + for (String hiveSparkConfKey : RpcConfiguration.HIVE_SPARK_RSC_CONFIGS) { + String value = RpcConfiguration.getValue(hiveConf, hiveSparkConfKey); + if (value != null) { + sparkLauncher.addAppArgs(RemoteDriver.CONF_KEY, String.format("%s=%s", hiveSparkConfKey, value)); + } } - Thread thread = new Thread(runnable); - thread.setDaemon(true); - thread.setName("Driver"); - thread.start(); - return thread; + LOG.info("Running client driver"); + return sparkLauncher.startApplication(); + } + + private void runKinit(String principal, String keyTabFile) throws IOException { + List kinitArgv = Lists.newLinkedList(); + kinitArgv.add("kinit"); + kinitArgv.add(principal); + kinitArgv.add("-k"); + kinitArgv.add("-t"); + kinitArgv.add(keyTabFile + ";"); + new ProcessRunner(LOG).run(kinitArgv); } private String getSparkJobCredentialProviderPassword() { @@ -549,6 +419,49 @@ private String getSparkJobCredentialProviderPassword() { return null; } + /** + * Performs a safe shutdown of a Spark job by using the {@link SparkAppHandle}. It first adds a custom + * {@link SparkAppHandle.Listener} that sends a signal to the current thread if the Spark {@link SparkAppHandle.State} + * changes from running to stopped. Then it invokes {@link SparkAppHandle#stop()} so that the Spark job can perform + * a safe shutdown. Then it waits to see if the listener has sent the signal for the state change from running to + * killed. It waits for {@link #DEFAULT_SHUTDOWN_TIMEOUT} milliseconds. If the Spark job hasn't shutdown then + * {@link SparkAppHandle#kill()} is invoked to perform a forced shutdown of the Spark application. + */ + private void shutdownSparkHandle() { + Set notRunningSparkStates = Sets.newHashSet(SparkAppHandle.State.FAILED, + SparkAppHandle.State.FINISHED, SparkAppHandle.State.KILLED, SparkAppHandle.State.LOST); + + if (!notRunningSparkStates.contains(this.sparkAppHandle.getState())) { + final CountDownLatch shutdownLatch = new CountDownLatch(1); + + this.sparkAppHandle.addListener(new SparkAppHandle.Listener() { + @Override + public void stateChanged(SparkAppHandle sparkAppHandle) { + if (notRunningSparkStates.contains(sparkAppHandle.getState())) { + shutdownLatch.countDown(); + } + } + + @Override + public void infoChanged(SparkAppHandle sparkAppHandle) { + // Do nothing + } + }); + + if (!notRunningSparkStates.contains(this.sparkAppHandle.getState())) { + this.sparkAppHandle.stop(); + try { + if (!shutdownLatch.await(DEFAULT_SHUTDOWN_TIMEOUT, TimeUnit.MILLISECONDS)) { + this.sparkAppHandle.kill(); + } + } catch (InterruptedException e) { + LOG.warn("Interrupted while waiting for Spark job to shutdown", e); + Thread.currentThread().interrupt(); + } + } + } + } + private class ClientProtocol extends BaseProtocol { JobHandleImpl submit(Job job, List> listeners) { diff --git a/spark-client/src/test/java/org/apache/hive/spark/client/RemoteDriverLocalRunner.java b/spark-client/src/test/java/org/apache/hive/spark/client/RemoteDriverLocalRunner.java new file mode 100644 index 0000000000..e12720a18a --- /dev/null +++ b/spark-client/src/test/java/org/apache/hive/spark/client/RemoteDriverLocalRunner.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hive.spark.client; + +import java.util.List; +import java.util.Map; + +import com.google.common.collect.Lists; + +import org.apache.spark.launcher.SparkAppHandle; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * Runs the {@link RemoteDriver} locally inside a dedicated thread. Implements the {@link SparkAppHandle} interface + * so that the driver can be stopped and its status can be queried. This class is mainly used for testing, it should + * not be used in a production environment. + */ +class RemoteDriverLocalRunner implements SparkAppHandle { + + private static final Logger LOG = LoggerFactory.getLogger(RemoteDriverLocalRunner.class); + + private Thread thread; + + RemoteDriverLocalRunner(Map conf, String serverAddress, String serverPort, + String clientId, String secret) { + thread = new Thread(new Runnable() { + @Override + public void run() { + List args = Lists.newArrayList(); + args.add("--remote-host"); + args.add(serverAddress); + args.add("--remote-port"); + args.add(serverPort); + args.add("--client-id"); + args.add(clientId); + args.add("--secret"); + args.add(secret); + + for (Map.Entry e : conf.entrySet()) { + args.add("--conf"); + args.add(String.format("%s=%s", e.getKey(), conf.get(e.getKey()))); + } + try { + RemoteDriver.main(args.toArray(new String[args.size()])); + } catch (Exception e) { + LOG.error("Error running driver.", e); + } + } + + }); + thread.setDaemon(true); + thread.setName("Driver"); + } + + @Override + public void addListener(Listener listener) { + // Do nothing + } + + @Override + public State getState() { + return null; + } + + @Override + public String getAppId() { + return null; + } + + @Override + public void stop() { + try { + thread.join(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + @Override + public void kill() { + thread.interrupt(); + } + + @Override + public void disconnect() { + try { + thread.join(); + } catch (InterruptedException e) { + LOG.error("RemoteDriver thread failed to stop", e); + Thread.currentThread().interrupt(); + } + } + + void run() { + thread.start(); + } +} diff --git a/spark-client/src/test/java/org/apache/hive/spark/client/SparkClientTest.java b/spark-client/src/test/java/org/apache/hive/spark/client/SparkClientTest.java new file mode 100644 index 0000000000..057c96e710 --- /dev/null +++ b/spark-client/src/test/java/org/apache/hive/spark/client/SparkClientTest.java @@ -0,0 +1,33 @@ +package org.apache.hive.spark.client; + +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hive.spark.client.rpc.RpcServer; +import org.apache.spark.SparkException; +import org.apache.spark.launcher.SparkAppHandle; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Map; + +class SparkClientTest extends SparkClientImpl { + + private static final Logger LOG = LoggerFactory.getLogger(SparkClientImpl.class); + 
+ SparkClientTest(RpcServer rpcServer, Map conf, HiveConf hiveConf, + String sessionId) throws IOException, InterruptedException { + super(rpcServer, conf, hiveConf, sessionId); + } + + @Override + protected SparkAppHandle startDriver(final String clientId, final String secret, final String serverAddress, + final String serverPort) { + // Mostly for testing things quickly. Do not do this in production. + // when invoked in-process it inherits the environment variables of the parent + LOG.warn("!!!! Running remote driver in-process. !!!!"); + RemoteDriverLocalRunner remoteDriverLocalRunner = new RemoteDriverLocalRunner(conf, serverAddress, serverPort, + clientId, secret); + remoteDriverLocalRunner.run(); + return remoteDriverLocalRunner; + } +} diff --git a/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java b/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java index 697d8d144d..a16f7cdfff 100644 --- a/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java +++ b/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java @@ -48,17 +48,21 @@ import java.util.zip.ZipEntry; import com.google.common.base.Objects; -import com.google.common.base.Strings; import com.google.common.io.ByteStreams; + import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hive.spark.client.rpc.RpcServer; import org.apache.hive.spark.counter.SparkCounters; + import org.apache.spark.SparkException; import org.apache.spark.SparkFiles; import org.apache.spark.api.java.JavaFutureAction; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.VoidFunction; + import org.junit.Test; + import static org.junit.Assert.*; import static org.mockito.Mockito.*; @@ -68,30 +72,16 @@ private static final long TIMEOUT = 20; private static final HiveConf HIVECONF = new HiveConf(); - private Map createConf(boolean local) { + private Map createConf() { Map conf = new HashMap(); - if (local) { - conf.put(SparkClientFactory.CONF_KEY_IN_PROCESS, "true"); - conf.put("spark.master", "local"); - conf.put("spark.app.name", "SparkClientSuite Local App"); - } else { - String classpath = System.getProperty("java.class.path"); - conf.put("spark.master", "local"); - conf.put("spark.app.name", "SparkClientSuite Remote App"); - conf.put("spark.driver.extraClassPath", classpath); - conf.put("spark.executor.extraClassPath", classpath); - } - - if (!Strings.isNullOrEmpty(System.getProperty("spark.home"))) { - conf.put("spark.home", System.getProperty("spark.home")); - } - + conf.put("spark.master", "local"); + conf.put("spark.app.name", "SparkClientSuite Local App"); return conf; } @Test public void testJobSubmission() throws Exception { - runTest(true, new TestFunction() { + runTest(new TestFunction() { @Override public void call(SparkClient client) throws Exception { JobHandle.Listener listener = newListener(); @@ -112,7 +102,7 @@ public void call(SparkClient client) throws Exception { @Test public void testSimpleSparkJob() throws Exception { - runTest(true, new TestFunction() { + runTest(new TestFunction() { @Override public void call(SparkClient client) throws Exception { JobHandle handle = client.submit(new SparkJob()); @@ -123,7 +113,7 @@ public void call(SparkClient client) throws Exception { @Test public void testErrorJob() throws Exception { - runTest(true, new TestFunction() { + runTest(new TestFunction() { @Override public void call(SparkClient client) throws Exception { 
JobHandle.Listener listener = newListener(); @@ -151,7 +141,7 @@ public void call(SparkClient client) throws Exception { @Test public void testSyncRpc() throws Exception { - runTest(true, new TestFunction() { + runTest(new TestFunction() { @Override public void call(SparkClient client) throws Exception { Future result = client.run(new SyncRpc()); @@ -162,7 +152,7 @@ public void call(SparkClient client) throws Exception { @Test public void testRemoteClient() throws Exception { - runTest(false, new TestFunction() { + runTest(new TestFunction() { @Override public void call(SparkClient client) throws Exception { JobHandle handle = client.submit(new SparkJob()); @@ -173,7 +163,7 @@ public void call(SparkClient client) throws Exception { @Test public void testMetricsCollection() throws Exception { - runTest(true, new TestFunction() { + runTest(new TestFunction() { @Override public void call(SparkClient client) throws Exception { JobHandle.Listener listener = newListener(); @@ -202,7 +192,7 @@ public void call(SparkClient client) throws Exception { @Test public void testAddJarsAndFiles() throws Exception { - runTest(true, new TestFunction() { + runTest(new TestFunction() { @Override public void call(SparkClient client) throws Exception { File jar = null; @@ -256,7 +246,7 @@ public void call(SparkClient client) throws Exception { @Test public void testCounters() throws Exception { - runTest(true, new TestFunction() { + runTest(new TestFunction() { @Override public void call(SparkClient client) throws Exception { JobHandle job = client.submit(new CounterIncrementJob()); @@ -308,19 +298,18 @@ public Void answer(InvocationOnMock invocation) throws Throwable { }).when(listener); } - private void runTest(boolean local, TestFunction test) throws Exception { - Map conf = createConf(local); - SparkClientFactory.initialize(conf); + private void runTest(TestFunction test) throws Exception { + Map conf = createConf(); SparkClient client = null; - try { + try (RpcServer rpcServer = new RpcServer(conf)) { test.config(conf); - client = SparkClientFactory.createClient(conf, HIVECONF, UUID.randomUUID().toString()); + client = new SparkClientTest(rpcServer, conf, HIVECONF, UUID.randomUUID().toString()); test.call(client); } finally { if (client != null) { client.stop(); } - SparkClientFactory.stop(); + } }
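
For context, the pattern this patch adopts (launching the driver through SparkLauncher and supervising it through the returned SparkAppHandle, instead of assembling a bin/spark-submit command line and babysitting the child process) looks roughly like the standalone sketch below. It is illustrative only, not code from the patch: the Spark home, application jar, main class, master and timeouts are placeholder values.

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.apache.spark.launcher.SparkAppHandle;
import org.apache.spark.launcher.SparkLauncher;

public class SparkLauncherSketch {

  public static void main(String[] args) throws Exception {
    final CountDownLatch done = new CountDownLatch(1);

    // startApplication() forks spark-submit under the covers and hands back a
    // SparkAppHandle, so the caller no longer has to build the command line or
    // redirect the child's stdout/stderr itself.
    SparkAppHandle handle = new SparkLauncher()
        .setSparkHome("/opt/spark")                 // placeholder install location
        .setAppResource("/tmp/example-app.jar")     // placeholder application jar
        .setMainClass("com.example.ExampleMain")    // placeholder main class
        .setMaster("local[2]")
        .setConf(SparkLauncher.DRIVER_MEMORY, "1g")
        .addAppArgs("--remote-host", "localhost")   // app args, analogous to RemoteDriver's flags
        .startApplication(new SparkAppHandle.Listener() {
          @Override
          public void stateChanged(SparkAppHandle h) {
            // isFinal() is true for FINISHED, FAILED, KILLED and LOST.
            if (h.getState().isFinal()) {
              done.countDown();
            }
          }

          @Override
          public void infoChanged(SparkAppHandle h) {
            // Application id / info updates arrive here; nothing to do.
          }
        });

    // Wait for a terminal state; if the app does not finish in time, ask for a
    // graceful stop() first and only fall back to kill() after a further timeout,
    // the same escalation shutdownSparkHandle() uses in the patch above.
    if (!done.await(60, TimeUnit.SECONDS)) {
      handle.stop();
      if (!done.await(10, TimeUnit.SECONDS)) {
        handle.kill();
      }
    }
    System.out.println("Final state: " + handle.getState() + ", appId: " + handle.getAppId());
  }
}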