diff --git a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
index 8724930..797568b 100644
--- a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
+++ b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
@@ -3342,6 +3342,9 @@ private static void populateLlapDaemonVarsSet(Set<String> llapDaemonVarsSetLocal
         "If this is set to true, mapjoin optimization in Hive/Spark will use statistics from\n" +
         "TableScan operators at the root of operator tree, instead of parent ReduceSink\n" +
         "operators of the Join operator."),
+    SPARK_OPTIMIZE_SHUFFLE_SERDE("hive.spark.optimize.shuffle.serde", true,
+        "If this is set to true, Hive on Spark will register custom serializers for data types\n" +
+        "in shuffle. This should result in less shuffled data."),
     SPARK_CLIENT_FUTURE_TIMEOUT("hive.spark.client.future.timeout", "60s",
       new TimeValidator(TimeUnit.SECONDS),
       "Timeout for requests from Hive client to remote Spark driver."),
diff --git a/itests/src/test/resources/testconfiguration.properties b/itests/src/test/resources/testconfiguration.properties
index 772113a..5401d72 100644
--- a/itests/src/test/resources/testconfiguration.properties
+++ b/itests/src/test/resources/testconfiguration.properties
@@ -1406,7 +1406,8 @@ miniSparkOnYarn.only.query.files=spark_combine_equivalent_work.q,\
   spark_vectorized_dynamic_partition_pruning.q,\
   spark_use_ts_stats_for_mapjoin.q,\
   spark_use_op_stats.q,\
-  spark_explain_groupbyshuffle.q
+  spark_explain_groupbyshuffle.q,\
+  spark_opt_shuffle_serde.q
 
 miniSparkOnYarn.query.files=auto_sortmerge_join_16.q,\
   bucket4.q,\
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveSparkClientFactory.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveSparkClientFactory.java
index 6e9ba7c..f683303 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveSparkClientFactory.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveSparkClientFactory.java
@@ -29,7 +29,6 @@
 import org.apache.commons.compress.utils.CharsetNames;
 import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
 import org.apache.hadoop.hive.common.LogUtils;
-import org.apache.hadoop.hive.ql.exec.Utilities;
 import org.apache.hadoop.hive.ql.session.SessionState;
 import org.apache.hive.spark.client.SparkClientUtilities;
 import org.slf4j.Logger;
@@ -198,13 +197,20 @@ public static HiveSparkClient createHiveSparkClient(HiveConf hiveconf) throws Ex
       }
     }
 
+    final boolean optShuffleSerDe = hiveConf.getBoolVar(
+        HiveConf.ConfVars.SPARK_OPTIMIZE_SHUFFLE_SERDE);
+
     Set<String> classes = Sets.newHashSet(
-      Splitter.on(",").trimResults().omitEmptyStrings().split(
-      Strings.nullToEmpty(sparkConf.get("spark.kryo.classesToRegister"))));
+        Splitter.on(",").trimResults().omitEmptyStrings().split(
+        Strings.nullToEmpty(sparkConf.get("spark.kryo.classesToRegister"))));
     classes.add(Writable.class.getName());
     classes.add(VectorizedRowBatch.class.getName());
-    classes.add(BytesWritable.class.getName());
-    classes.add(HiveKey.class.getName());
+    if (!optShuffleSerDe) {
+      classes.add(HiveKey.class.getName());
+      classes.add(BytesWritable.class.getName());
+    } else {
+      sparkConf.put("spark.kryo.registrator", SparkClientUtilities.HIVE_KRYO_REG_FULLNAME);
+    }
     sparkConf.put("spark.kryo.classesToRegister", Joiner.on(",").join(classes));
 
     // set yarn queue name
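Reviewer note, not part of the patch: the factory change above only decides which Kryo-related properties are present on the SparkConf before the context is launched. A minimal, stand-alone sketch of the two outcomes under each setting of hive.spark.optimize.shuffle.serde follows; the class and variable names are invented for illustration, and only standard Spark property names and SparkConf methods are used.

import org.apache.spark.SparkConf;

public class ShuffleSerdeConfSketch {
  public static void main(String[] args) {
    boolean optShuffleSerDe = true;  // stand-in for hive.spark.optimize.shuffle.serde
    SparkConf conf = new SparkConf();
    if (!optShuffleSerDe) {
      // Old behaviour: the shuffle classes are merely registered by name,
      // so Kryo keeps using whatever default serializer it picks for them.
      conf.set("spark.kryo.classesToRegister",
          "org.apache.hadoop.hive.ql.io.HiveKey,org.apache.hadoop.io.BytesWritable");
    } else {
      // New behaviour: point Spark at the custom registrator, which binds
      // compact length-prefixed serializers to HiveKey and BytesWritable.
      conf.set("spark.kryo.registrator",
          "org.apache.hadoop.hive.ql.exec.spark.HiveKryoRegistrator");
    }
    System.out.println(conf.toDebugString());
  }
}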
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/LocalHiveSparkClient.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/LocalHiveSparkClient.java
index beeafd0..ec90843 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/LocalHiveSparkClient.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/LocalHiveSparkClient.java
@@ -18,10 +18,17 @@
 
 package org.apache.hadoop.hive.ql.exec.spark;
 
+import java.io.File;
+import java.io.IOException;
+import java.net.URISyntaxException;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ExecutionException;
 
+import com.google.common.io.Files;
+import org.apache.commons.io.FileUtils;
+import org.apache.hive.spark.client.SparkClientUtilities;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.apache.hadoop.fs.FileSystem;
@@ -66,7 +73,8 @@
 
   private static LocalHiveSparkClient client;
 
-  public static synchronized LocalHiveSparkClient getInstance(SparkConf sparkConf) {
+  public static synchronized LocalHiveSparkClient getInstance(SparkConf sparkConf)
+      throws IOException, ExecutionException, InterruptedException, URISyntaxException {
     if (client == null) {
       client = new LocalHiveSparkClient(sparkConf);
     }
@@ -81,8 +89,20 @@ public static synchronized LocalHiveSparkClient getInstance(SparkConf sparkConf)
 
   private final JobMetricsListener jobMetricsListener;
 
-  private LocalHiveSparkClient(SparkConf sparkConf) {
+  private File localTmpDir = null;
+
+  private LocalHiveSparkClient(SparkConf sparkConf)
+      throws IOException, ExecutionException, InterruptedException, URISyntaxException {
+    File regJar = null;
+    if (SparkClientUtilities.needKryoRegJar(sparkConf)) {
+      localTmpDir = Files.createTempDir();
+      regJar = SparkClientUtilities.getKryoRegistratorJar(localTmpDir);
+      SparkClientUtilities.addJarToContextLoader(regJar);
+    }
     sc = new JavaSparkContext(sparkConf);
+    if (regJar != null) {
+      sc.addJar(regJar.getPath());
+    }
     jobMetricsListener = new JobMetricsListener();
     sc.sc().listenerBus().addListener(jobMetricsListener);
   }
@@ -217,5 +237,13 @@ public void close() {
     if (sc != null) {
       sc.stop();
     }
+    if (localTmpDir != null) {
+      try {
+        FileUtils.deleteDirectory(localTmpDir);
+      } catch (IOException e) {
+        LOG.warn("Failed to delete local tmp dir " + localTmpDir, e);
+      }
+    }
+    localTmpDir = null;
   }
 }
diff --git a/ql/src/main/resources/spark/HiveKryoRegistrator.java b/ql/src/main/resources/spark/HiveKryoRegistrator.java
new file mode 100644
index 0000000..2f5e525
--- /dev/null
+++ b/ql/src/main/resources/spark/HiveKryoRegistrator.java
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.spark;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.Serializer;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import org.apache.spark.serializer.KryoRegistrator;
+import org.apache.hadoop.hive.ql.io.HiveKey;
+import org.apache.hadoop.io.BytesWritable;
+
+public class HiveKryoRegistrator implements KryoRegistrator {
+  @Override
+  public void registerClasses(Kryo kryo) {
+    kryo.register(HiveKey.class, new HiveKeySerializer());
+    kryo.register(BytesWritable.class, new BytesWritableSerializer());
+  }
+
+  private static class HiveKeySerializer extends Serializer<HiveKey> {
+
+    public void write(Kryo kryo, Output output, HiveKey object) {
+      output.writeVarInt(object.getLength(), true);
+      output.write(object.getBytes(), 0, object.getLength());
+      output.writeVarInt(object.hashCode(), false);
+    }
+
+    public HiveKey read(Kryo kryo, Input input, Class<HiveKey> type) {
+      int len = input.readVarInt(true);
+      byte[] bytes = new byte[len];
+      input.readBytes(bytes);
+      return new HiveKey(bytes, input.readVarInt(false));
+    }
+  }
+
+  private static class BytesWritableSerializer extends Serializer<BytesWritable> {
+
+    public void write(Kryo kryo, Output output, BytesWritable object) {
+      output.writeVarInt(object.getLength(), true);
+      output.write(object.getBytes(), 0, object.getLength());
+    }
+
+    public BytesWritable read(Kryo kryo, Input input, Class<BytesWritable> type) {
+      int len = input.readVarInt(true);
+      byte[] bytes = new byte[len];
+      input.readBytes(bytes);
+      return new BytesWritable(bytes);
+    }
+
+  }
+}
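Reviewer note, not part of the patch: the two serializers above write only a varint length, the valid bytes, and, for HiveKey, the pre-computed hash code, rather than whatever default serializer Kryo would otherwise choose for these Writables. A self-contained round-trip sketch follows, assuming the unshaded com.esotericsoftware Kryo and hive-exec are on the classpath; the class and variable names are invented for illustration.

import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import org.apache.hadoop.hive.ql.exec.spark.HiveKryoRegistrator;
import org.apache.hadoop.hive.ql.io.HiveKey;

public class HiveKeyRoundTripSketch {
  public static void main(String[] args) {
    Kryo kryo = new Kryo();
    new HiveKryoRegistrator().registerClasses(kryo);   // installs the custom serializers

    HiveKey key = new HiveKey("k1".getBytes(), "k1".hashCode());
    Output out = new Output(64, -1);                   // small buffer, allowed to grow
    kryo.writeObject(out, key);                        // varint length + bytes + varint hash

    HiveKey copy = kryo.readObject(new Input(out.toBytes()), HiveKey.class);
    System.out.println(copy.getLength() + " bytes, hash " + copy.hashCode());
  }
}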
diff --git a/ql/src/test/queries/clientpositive/spark_opt_shuffle_serde.q b/ql/src/test/queries/clientpositive/spark_opt_shuffle_serde.q
new file mode 100644
index 0000000..2c4691a
--- /dev/null
+++ b/ql/src/test/queries/clientpositive/spark_opt_shuffle_serde.q
@@ -0,0 +1,7 @@
+set hive.spark.optimize.shuffle.serde=true;
+
+set hive.spark.use.groupby.shuffle=true;
+select key, count(*) from src group by key order by key limit 100;
+
+set hive.spark.use.groupby.shuffle=false;
+select key, count(*) from src group by key order by key limit 100;
diff --git a/ql/src/test/results/clientpositive/spark/spark_opt_shuffle_serde.q.out b/ql/src/test/results/clientpositive/spark/spark_opt_shuffle_serde.q.out
new file mode 100644
index 0000000..cd9c7bc
--- /dev/null
+++ b/ql/src/test/results/clientpositive/spark/spark_opt_shuffle_serde.q.out
@@ -0,0 +1,216 @@
+PREHOOK: query: select key, count(*) from src group by key order by key limit 100
+PREHOOK: type: QUERY
+PREHOOK: Input: default@src
+#### A masked pattern was here ####
+POSTHOOK: query: select key, count(*) from src group by key order by key limit 100
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@src
+#### A masked pattern was here ####
+0	3
+10	1
+100	2
+103	2
+104	2
+105	1
+11	1
+111	1
+113	2
+114	1
+116	1
+118	2
+119	3
+12	2
+120	2
+125	2
+126	1
+128	3
+129	2
+131	1
+133	1
+134	2
+136	1
+137	2
+138	4
+143	1
+145	1
+146	2
+149	2
+15	2
+150	1
+152	2
+153	1
+155	1
+156	1
+157	1
+158	1
+160	1
+162	1
+163	1
+164	2
+165	2
+166	1
+167	3
+168	1
+169	4
+17	1
+170	1
+172	2
+174	2
+175	2
+176	2
+177	1
+178	1
+179	2
+18	2
+180	1
+181	1
+183	1
+186	1
+187	3
+189	1
+19	1
+190	1
+191	2
+192	1
+193	3
+194	1
+195	2
+196	1
+197	2
+199	3
+2	1
+20	1
+200	2
+201	1
+202	1
+203	2
+205	2
+207	2
+208	3
+209	2
+213	2
+214	1
+216	2
+217	2
+218	1
+219	2
+221	2
+222	1
+223	2
+224	2
+226	1
+228	1
+229	2
+230	5
+233	2
+235	1
+237	2
+238	2
+PREHOOK: query: select key, count(*) from src group by key order by key limit 100
+PREHOOK: type: QUERY
+PREHOOK: Input: default@src
+#### A masked pattern was here ####
+POSTHOOK: query: select key, count(*) from src group by key order by key limit 100
+POSTHOOK: type: QUERY
+POSTHOOK: Input: default@src
+#### A masked pattern was here ####
+0	3
+10	1
+100	2
+103	2
+104	2
+105	1
+11	1
+111	1
+113	2
+114	1
+116	1
+118	2
+119	3
+12	2
+120	2
+125	2
+126	1
+128	3
+129	2
+131	1
+133	1
+134	2
+136	1
+137	2
+138	4
+143	1
+145	1
+146	2
+149	2
+15	2
+150	1
+152	2
+153	1
+155	1
+156	1
+157	1
+158	1
+160	1
+162	1
+163	1
+164	2
+165	2
+166	1
+167	3
+168	1
+169	4
+17	1
+170	1
+172	2
+174	2
+175	2
+176	2
+177	1
+178	1
+179	2
+18	2
+180	1
+181	1
+183	1
+186	1
+187	3
+189	1
+19	1
+190	1
+191	2
+192	1
+193	3
+194	1
+195	2
+196	1
+197	2
+199	3
+2	1
+20	1
+200	2
+201	1
+202	1
+203	2
+205	2
+207	2
+208	3
+209	2
+213	2
+214	1
+216	2
+217	2
+218	1
+219	2
+221	2
+222	1
+223	2
+224	2
+226	1
+228	1
+229	2
+230	5
+233	2
+235	1
+237	2
+238	2
diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java b/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java
index ede8ce9..65597d5 100644
--- a/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java
+++ b/spark-client/src/main/java/org/apache/hive/spark/client/RemoteDriver.java
@@ -116,6 +116,15 @@ private RemoteDriver(String[] args) throws Exception {
       }
     }
 
+    // Since kryo is relocated in Hive, we have to compile the registrator class at runtime
+    // Do this before setting up Rpc and SparkContext, because this may change thread
+    // context class loader. see HIVE-15104
+    File regJar = null;
+    if (SparkClientUtilities.needKryoRegJar(conf)) {
+      regJar = SparkClientUtilities.getKryoRegistratorJar(localTmpDir);
+      SparkClientUtilities.addJarToContextLoader(regJar);
+    }
+
     executor = Executors.newCachedThreadPool();
 
     LOG.info("Connecting to: {}:{}", serverAddress, serverPort);
@@ -155,6 +164,19 @@ public void rpcClosed(Rpc rpc) {
 
     try {
       JavaSparkContext sc = new JavaSparkContext(conf);
+      if (regJar != null) {
+        File target;
+        String master = conf.get("spark.master", "");
+        String deploy = conf.get("spark.submit.deployMode", "");
+        // in yarn-cluster mode, need to mv the jar to driver's working dir
+        if (SparkClientUtilities.isYarnClusterMode(master, deploy)) {
+          target = new File(regJar.getName());
+          Files.move(regJar, target);
+        } else {
+          target = regJar;
+        }
+        sc.addJar(target.getPath());
+      }
       sc.sc().addSparkListener(new ClientListener());
       synchronized (jcLock) {
         jc = new JobContextImpl(sc, localTmpDir);
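Reviewer note, not part of the patch: the driver-side block above mirrors what HiveSparkClientFactory does on the client. The factory sets spark.kryo.registrator to the Hive registrator, and the driver later keys off that same property (via needKryoRegJar) to decide whether the registrator jar must be compiled and shipped before the SparkContext is created. A minimal sketch of that handshake, using only helpers introduced by this patch; the class name is invented for illustration.

import org.apache.spark.SparkConf;
import org.apache.hive.spark.client.SparkClientUtilities;

public class RegistratorHandshakeSketch {
  public static void main(String[] args) {
    // Client side (HiveSparkClientFactory): point Spark at the Hive registrator.
    SparkConf conf = new SparkConf()
        .set("spark.kryo.registrator", SparkClientUtilities.HIVE_KRYO_REG_FULLNAME);

    // Driver side (RemoteDriver / LocalHiveSparkClient): the same property tells us
    // whether the registrator jar has to be compiled and added to the context.
    System.out.println("need registrator jar: " + SparkClientUtilities.needKryoRegJar(conf));
  }
}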
diff --git a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientUtilities.java b/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientUtilities.java
index 210da2a..91784ac 100644
--- a/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientUtilities.java
+++ b/spark-client/src/main/java/org/apache/hive/spark/client/SparkClientUtilities.java
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -18,28 +18,58 @@ package org.apache.hive.spark.client;
 
+import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
 
 import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.net.MalformedURLException;
+import java.net.URI;
+import java.net.URISyntaxException;
 import java.net.URL;
 import java.net.URLClassLoader;
+import java.net.URLDecoder;
+import java.nio.charset.Charset;
+import java.nio.file.Paths;
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Enumeration;
 import java.util.List;
 import java.util.Map;
+import java.util.UUID;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutionException;
+import java.util.jar.JarEntry;
+import java.util.jar.JarOutputStream;
 
+import com.google.common.io.ByteStreams;
+import com.google.common.io.Resources;
 import org.apache.commons.lang.StringUtils;
+import org.apache.spark.SparkConf;
+import org.apache.spark.util.MutableURLClassLoader;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 
+import javax.tools.JavaCompiler;
+import javax.tools.JavaFileObject;
+import javax.tools.SimpleJavaFileObject;
+import javax.tools.ToolProvider;
+
 public class SparkClientUtilities {
   protected static final transient Logger LOG =
       LoggerFactory.getLogger(SparkClientUtilities.class);
 
   private static final Map downloadedFiles = new ConcurrentHashMap<>();
 
+  private static final String REG_PACK_NAME = "org.apache.hadoop.hive.ql.exec.spark";
+  private static final String HIVE_KRYO_REG = "HiveKryoRegistrator";
+  public static final String HIVE_KRYO_REG_FULLNAME = REG_PACK_NAME + "." + HIVE_KRYO_REG;
+
   /**
    * Add new elements to the classpath.
    *
@@ -74,7 +104,8 @@
   /**
    * Create a URL from a string representing a path to a local file.
    * The path string can be just a path, or can start with file:/, file:///
-   * @param path path string
+   *
+   * @param path path string
    * @return
    */
   private static URL urlFromPathString(String path, Long timeStamp,
@@ -136,4 +167,114 @@ public static String getDeployModeFromMaster(String master) {
     }
     return null;
   }
+
+  public static boolean needKryoRegJar(SparkConf sparkConf) {
+    String registrators = sparkConf.get("spark.kryo.registrator", "");
+    return registrators != null && (registrators.contains(HIVE_KRYO_REG_FULLNAME));
+  }
+
+  // copied from Utilities
+  private static String jarFinderGetJar(Class klass) {
+    Preconditions.checkNotNull(klass, "klass");
+    ClassLoader loader = klass.getClassLoader();
+    if (loader != null) {
+      String class_file = klass.getName().replaceAll("\\.", "/") + ".class";
+      try {
+        for (Enumeration itr = loader.getResources(class_file); itr.hasMoreElements(); ) {
+          URL url = (URL) itr.nextElement();
+          String path = url.getPath();
+          if (path.startsWith("file:")) {
+            path = path.substring("file:".length());
+          }
+          path = URLDecoder.decode(path, "UTF-8");
+          if ("jar".equals(url.getProtocol())) {
+            path = URLDecoder.decode(path, "UTF-8");
+            return path.replaceAll("!.*$", "");
+          }
+        }
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+    return null;
+  }
+
+  public static File getKryoRegistratorJar(File tmpDir)
+      throws IOException, ExecutionException, InterruptedException, URISyntaxException {
+    final File srcDir = new File(tmpDir, "src_" + UUID.randomUUID());
+    Preconditions.checkState(srcDir.mkdir(), "Failed to create tmp dir for source files.");
+    List<String> options = new ArrayList<>();
+    options.add("-d");
+    options.add(srcDir.getPath());
+
+    String cp = System.getProperty("java.class.path") + System.getProperty("path.separator")
+        + jarFinderGetJar(SparkClientUtilities.class);
+    options.add("-classpath");
+    options.add(cp);
+
+    List<String> sourceFiles = Collections.singletonList("/spark/" + HIVE_KRYO_REG + ".java");
+
+    Preconditions.checkState(compileSource(sourceFiles, options).call(),
+        "Failed to compile sources.");
+
+    java.nio.file.Path prefix = Paths.get("", REG_PACK_NAME.split("\\."));
+    File output = new File(srcDir, prefix.toString());
+    Preconditions.checkArgument(output.isDirectory());
+
+    File jarFile = new File(tmpDir, "kryo-registrator.jar");
+    buildJar(Arrays.asList(output.listFiles()), jarFile, prefix.toString());
+    return jarFile;
+  }
+
+  private static JavaCompiler.CompilationTask compileSource(List<String> fileNames,
+      Iterable<String> options) throws IOException, URISyntaxException {
+    List<JavaFileObject> sources = new ArrayList<>(fileNames.size());
+    for (String fileName : fileNames) {
+      String code = Resources.toString(Resources.getResource(SparkClientUtilities.class, fileName),
+          Charset.defaultCharset());
+      sources.add(new JavaSourceFromFile(new File(fileName).toURI(), code));
+    }
+    JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
+    return compiler.getTask(null, null, null, options, null, sources);
+  }
+
+  private static void buildJar(List<File> classFiles, File jarFile, String prefix)
+      throws IOException {
+    try (FileOutputStream jarFileStream = new FileOutputStream(jarFile);
+         JarOutputStream jarStream = new JarOutputStream(
+             jarFileStream, new java.util.jar.Manifest())) {
+      for (File file : classFiles) {
+        JarEntry jarEntry = new JarEntry(Paths.get(prefix, file.getName()).toString());
+        jarStream.putNextEntry(jarEntry);
+        try (FileInputStream in = new FileInputStream(file)) {
+          ByteStreams.copy(in, jarStream);
+        }
+      }
+    }
+  }
+
+  private static class JavaSourceFromFile extends SimpleJavaFileObject {
+    private final String code;
+
+    JavaSourceFromFile(URI uri, String code) throws IOException {
+      super(uri, Kind.SOURCE);
+      this.code = code;
+    }
+
+    @Override
+    public CharSequence getCharContent(boolean ignoreEncodingErrors) throws IOException {
+      return code;
+    }
+  }
+
+  public static void addJarToContextLoader(File jar) throws MalformedURLException {
+    ClassLoader loader = Thread.currentThread().getContextClassLoader();
+    if (loader instanceof MutableURLClassLoader) {
+      ((MutableURLClassLoader) loader).addURL(jar.toURI().toURL());
+    } else {
+      URLClassLoader newLoader =
+          new URLClassLoader(new URL[]{jar.toURI().toURL()}, loader);
+      Thread.currentThread().setContextClassLoader(newLoader);
+    }
+  }
 }
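Reviewer note, not part of the patch: getKryoRegistratorJar above leans entirely on the JDK's javax.tools API, so the registrator source bundled as a resource can be compiled at runtime against whichever (possibly relocated) Kryo is actually on the classpath. A stripped-down, self-contained sketch of just that compile step follows; the class names are invented for illustration, and ToolProvider.getSystemJavaCompiler() returns null when only a JRE without the compiler is present.

import java.io.File;
import java.net.URI;
import java.util.Arrays;
import java.util.Collections;
import javax.tools.JavaCompiler;
import javax.tools.SimpleJavaFileObject;
import javax.tools.ToolProvider;

public class CompileSketch {
  // In-memory source object, the same idea as JavaSourceFromFile above.
  static class StringSource extends SimpleJavaFileObject {
    private final String code;
    StringSource(String className, String code) {
      super(URI.create("string:///" + className.replace('.', '/') + ".java"), Kind.SOURCE);
      this.code = code;
    }
    @Override
    public CharSequence getCharContent(boolean ignoreEncodingErrors) {
      return code;
    }
  }

  public static void main(String[] args) {
    File outDir = new File(System.getProperty("java.io.tmpdir"));
    JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
    boolean ok = compiler.getTask(null, null, null,
        Arrays.asList("-d", outDir.getPath()),          // write .class files to outDir
        null,
        Collections.singletonList(new StringSource("Hello", "public class Hello {}"))).call();
    System.out.println("compiled: " + ok + " -> " + new File(outDir, "Hello.class"));
  }
}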