diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/IdentityTran.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/IdentityTran.java
index 6c3cf2f..9d2a89b 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/IdentityTran.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/IdentityTran.java
@@ -18,15 +18,13 @@
 
 package org.apache.hadoop.hive.ql.exec.spark;
 
-import org.apache.hadoop.hive.ql.io.HiveKey;
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.spark.api.java.JavaPairRDD;
 
-public class IdentityTran implements SparkTran {
+public class IdentityTran implements SparkTran<BytesWritable, BytesWritable, BytesWritable, BytesWritable> {
 
   @Override
-  public JavaPairRDD<HiveKey, BytesWritable> transform(
-      JavaPairRDD<HiveKey, BytesWritable> input) {
+  public JavaPairRDD<BytesWritable, BytesWritable> transform(JavaPairRDD<BytesWritable, BytesWritable> input) {
     return input;
   }
 }
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapInput.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapInput.java
index 0732e06..8e778c6 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapInput.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapInput.java
@@ -24,7 +24,7 @@
 
 import com.google.common.base.Preconditions;
 
-public class MapInput implements SparkTran {
+public class MapInput implements SparkTran<BytesWritable, BytesWritable, BytesWritable, BytesWritable> {
   private JavaPairRDD<BytesWritable, BytesWritable> hadoopRDD;
   private boolean toCache;
 
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapTran.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapTran.java
index e62527c..91d5bdc 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapTran.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapTran.java
@@ -22,7 +22,7 @@
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.spark.api.java.JavaPairRDD;
 
-public class MapTran implements SparkTran {
+public class MapTran implements SparkTran<BytesWritable, BytesWritable, HiveKey, BytesWritable> {
   private HiveMapFunction mapFunc;
 
   @Override
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java
index 52ac724..f4fa8ed 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java
@@ -22,27 +22,16 @@
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.spark.api.java.JavaPairRDD;
 
-public class ReduceTran implements SparkTran {
-  private SparkShuffler shuffler;
+public class ReduceTran implements SparkTran<HiveKey, Iterable<BytesWritable>, HiveKey, BytesWritable> {
   private HiveReduceFunction reduceFunc;
-  private int numPartitions;
 
   @Override
   public JavaPairRDD<HiveKey, BytesWritable> transform(
-      JavaPairRDD<HiveKey, BytesWritable> input) {
-    return shuffler.shuffle(input, numPartitions).mapPartitionsToPair(reduceFunc);
+      JavaPairRDD<HiveKey, Iterable<BytesWritable>> input) {
+    return input.mapPartitionsToPair(reduceFunc);
   }
 
   public void setReduceFunction(HiveReduceFunction redFunc) {
     this.reduceFunc = redFunc;
   }
-
-  public void setShuffler(SparkShuffler shuffler) {
-    this.shuffler = shuffler;
-  }
-
-  public void setNumPartitions(int numPartitions) {
-    this.numPartitions = numPartitions;
-  }
-
 }
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ShuffleTran.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ShuffleTran.java
new file mode 100644
index 0000000..0f07205
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ShuffleTran.java
@@ -0,0 +1,37 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.spark;
+
+import org.apache.hadoop.hive.ql.io.HiveKey;
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.spark.api.java.JavaPairRDD;
+
+public class ShuffleTran implements SparkTran<HiveKey, BytesWritable, HiveKey, Iterable<BytesWritable>> {
+  private final SparkShuffler shuffler;
+  private final int numOfPartitions;
+
+  public ShuffleTran(SparkShuffler sf, int n) {
+    shuffler = sf;
+    numOfPartitions = n;
+  }
+
+  @Override
+  public JavaPairRDD<HiveKey, Iterable<BytesWritable>> transform(JavaPairRDD<HiveKey, BytesWritable> input) {
+    return shuffler.shuffle(input, numOfPartitions);
+  }
+}
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java
index 8e251df..199dd32 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java
@@ -82,13 +82,23 @@ public SparkPlan generate(SparkWork sparkWork) throws Exception {
       if (work instanceof MapWork) {
         MapInput mapInput = generateMapInput((MapWork)work);
         sparkPlan.addTran(mapInput);
-        tran = generate(work, null);
+        tran = generate((MapWork)work);
         sparkPlan.addTran(tran);
         sparkPlan.connect(mapInput, tran);
-      } else {
+      } else if (work instanceof ReduceWork) {
         List<BaseWork> parentWorks = sparkWork.getParents(work);
-        tran = generate(work, sparkWork.getEdgeProperty(parentWorks.get(0), work));
+        tran = generate((ReduceWork)work);
+        ShuffleTran shuffleTran = generate(sparkWork.getEdgeProperty(parentWorks.get(0),work));
         sparkPlan.addTran(tran);
+        sparkPlan.addTran(shuffleTran);
+        sparkPlan.connect(shuffleTran, tran);
+        for (BaseWork parentWork : parentWorks) {
+          SparkTran parentTran = workToTranMap.get(parentWork);
+          sparkPlan.connect(parentTran, shuffleTran);
+        }
+      } else {
+        List<BaseWork> parentWorks = sparkWork.getParents(work);
+        tran = new IdentityTran();
         for (BaseWork parentWork : parentWorks) {
           SparkTran parentTran = workToTranMap.get(parentWork);
           sparkPlan.connect(parentTran, tran);
@@ -129,24 +139,6 @@ private Class getInputFormat(JobConf jobConf, MapWork mWork) throws HiveExceptio
     return inputFormatClass;
   }
 
-  public SparkTran generate(BaseWork work, SparkEdgeProperty edge) throws Exception {
-    if (work instanceof MapWork) {
-      MapWork mw = (MapWork) work;
-      return generate(mw);
-    } else if (work instanceof ReduceWork) {
-      ReduceWork rw = (ReduceWork) work;
-      ReduceTran tran = generate(rw);
-      SparkShuffler shuffler = generate(edge);
-      tran.setShuffler(shuffler);
-      tran.setNumPartitions(edge.getNumPartitions());
-      return tran;
-    } else if (work instanceof UnionWork) {
-      return new IdentityTran();
-    } else {
-      throw new HiveException("Unexpected work: " + work.getName());
-    }
-  }
-
   private MapInput generateMapInput(MapWork mapWork)
       throws Exception {
     JobConf jobConf = cloneJobConf(mapWork);
@@ -157,15 +149,18 @@ private MapInput generateMapInput(MapWork mapWork)
     return new MapInput(hadoopRDD);
   }
 
-  private SparkShuffler generate(SparkEdgeProperty edge) {
+  private ShuffleTran generate(SparkEdgeProperty edge) {
     Preconditions.checkArgument(!edge.isShuffleNone(),
         "AssertionError: SHUFFLE_NONE should only be used for UnionWork.");
+    SparkShuffler shuffler;
     if (edge.isMRShuffle()) {
-      return new SortByShuffler(false);
+      shuffler = new SortByShuffler(false);
     } else if (edge.isShuffleSort()) {
-      return new SortByShuffler(true);
+      shuffler = new SortByShuffler(true);
+    } else {
+      shuffler = new GroupByShuffler();
     }
-    return new GroupByShuffler();
+    return new ShuffleTran(shuffler, edge.getNumPartitions());
   }
 
   private MapTran generate(MapWork mw) throws Exception {
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkTran.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkTran.java
index e770158..17ec257 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkTran.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkTran.java
@@ -21,7 +21,7 @@
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.spark.api.java.JavaPairRDD;
 
-public interface SparkTran {
-  JavaPairRDD<HiveKey, BytesWritable> transform(
-      JavaPairRDD<HiveKey, BytesWritable> input);
+public interface SparkTran<KI, VI, KO, VO> {
+  JavaPairRDD<KO, VO> transform(
+      JavaPairRDD<KI, VI> input);
 }
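Illustration (not part of the patch): once this change is applied, a map-shuffle-reduce chain composes with compiler-checked key/value types at every hop, which is the point of generifying SparkTran. The sketch below is a minimal, hypothetical driver; TranChainSketch and run() are illustrative names, and it assumes hive-exec and spark-core are on the classpath. The real wiring lives in SparkPlanGenerator and SparkPlan.

import org.apache.hadoop.hive.ql.exec.spark.MapTran;
import org.apache.hadoop.hive.ql.exec.spark.ReduceTran;
import org.apache.hadoop.hive.ql.exec.spark.ShuffleTran;
import org.apache.hadoop.hive.ql.io.HiveKey;
import org.apache.hadoop.io.BytesWritable;
import org.apache.spark.api.java.JavaPairRDD;

public class TranChainSketch {
  // Chains the three transforms the way SparkPlanGenerator now connects them:
  // MapTran -> ShuffleTran -> ReduceTran.
  public static JavaPairRDD<HiveKey, BytesWritable> run(
      JavaPairRDD<BytesWritable, BytesWritable> input,
      MapTran mapTran, ShuffleTran shuffleTran, ReduceTran reduceTran) {
    // MapTran: (BytesWritable, BytesWritable) -> (HiveKey, BytesWritable)
    JavaPairRDD<HiveKey, BytesWritable> mapped = mapTran.transform(input);
    // ShuffleTran: (HiveKey, BytesWritable) -> (HiveKey, Iterable<BytesWritable>);
    // internally it delegates to SparkShuffler.shuffle(input, numOfPartitions)
    JavaPairRDD<HiveKey, Iterable<BytesWritable>> shuffled = shuffleTran.transform(mapped);
    // ReduceTran: (HiveKey, Iterable<BytesWritable>) -> (HiveKey, BytesWritable)
    return reduceTran.transform(shuffled);
  }
}

The design consequence visible in the patch: the shuffle becomes a first-class node in the plan DAG, so several parent works can all connect into one ShuffleTran, and ReduceTran sheds the setShuffler()/setNumPartitions() plumbing to become a pure partition-wise reduce.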