diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/GraphTran.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/GraphTran.java
index 5d4414a..ef66104 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/GraphTran.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/GraphTran.java
@@ -18,48 +18,43 @@
 package org.apache.hadoop.hive.ql.exec.spark;
 
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
+import java.util.*;
 
+import com.google.common.base.Preconditions;
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.spark.api.java.JavaPairRDD;
 
-public class GraphTran {
+public class GraphTran implements SparkTran {
 
-  private final Set<SparkTran> rootTrans = new HashSet<SparkTran>();
+  private final List<SparkTran> rootTrans = new LinkedList<SparkTran>();
   private final Set<SparkTran> leafTrans = new HashSet<SparkTran>();
   private final Map<SparkTran, List<SparkTran>> transGraph = new HashMap<SparkTran, List<SparkTran>>();
   private final Map<SparkTran, List<SparkTran>> invertedTransGraph = new HashMap<SparkTran, List<SparkTran>>();
   private final Map<SparkTran, List<JavaPairRDD<BytesWritable, BytesWritable>>> unionInputs = new HashMap<SparkTran, List<JavaPairRDD<BytesWritable, BytesWritable>>>();
-  private final Map<SparkTran, JavaPairRDD<BytesWritable, BytesWritable>> mapInputs = new HashMap<SparkTran, JavaPairRDD<BytesWritable, BytesWritable>>();
 
-  public void addRootTranWithInput(SparkTran tran, JavaPairRDD<BytesWritable, BytesWritable> input) {
+  public void addRootTranWithInput(SparkTran tran) {
     if (!rootTrans.contains(tran)) {
       rootTrans.add(tran);
       leafTrans.add(tran);
       transGraph.put(tran, new LinkedList<SparkTran>());
       invertedTransGraph.put(tran, new LinkedList<SparkTran>());
     }
-    if (input != null) {
-      mapInputs.put(tran, input);
-    }
   }
 
-  public void execute() throws IllegalStateException {
+  public JavaPairRDD<BytesWritable, BytesWritable> transform(JavaPairRDD<BytesWritable, BytesWritable>... inputs) {
     Map<SparkTran, JavaPairRDD<BytesWritable, BytesWritable>> resultRDDs = new HashMap<SparkTran, JavaPairRDD<BytesWritable, BytesWritable>>();
-    for (SparkTran tran : rootTrans) {
+    for (int i = 0; i < rootTrans.size(); i++) {
       // make sure all the root trans are MapTran
+      SparkTran tran = rootTrans.get(i);
+      JavaPairRDD<BytesWritable, BytesWritable> input = inputs[i];
+
       if (!(tran instanceof MapTran)) {
         throw new IllegalStateException("root transformations must be MapTran!");
       }
-      JavaPairRDD<BytesWritable, BytesWritable> input = mapInputs.get(tran);
       if (input == null) {
         throw new IllegalStateException("input is missing for transformation!");
       }
+
       JavaPairRDD<BytesWritable, BytesWritable> rdd = tran.transform(input);
       while (getChildren(tran).size() > 0) {
@@ -79,10 +74,9 @@ public void execute() throws IllegalStateException {
           break;
         } else if (unionInputList.size() == this.getParents(childTran).size() - 1) {
           // process the last input RDD
-          for (JavaPairRDD<BytesWritable, BytesWritable> inputRDD : unionInputList) {
-            ((UnionTran) childTran).setOtherInput(inputRDD);
-            rdd = childTran.transform(rdd);
-          }
+          UnionTran unionTran = (UnionTran) childTran;
+          unionInputList.add(rdd);
+          rdd = unionTran.transform(unionInputList.toArray(new JavaPairRDD[unionInputList.size()]));
         }
       } else {
         rdd = childTran.transform(rdd);
@@ -94,9 +88,10 @@ public void execute() throws IllegalStateException {
         resultRDDs.put(tran, rdd);
       }
     }
-    for (JavaPairRDD<BytesWritable, BytesWritable> resultRDD : resultRDDs.values()) {
-      resultRDD.foreach(HiveVoidFunction.getInstance());
-    }
+
+    Preconditions.checkArgument(resultRDDs.size() == 1,
+        "AssertionError: resultRDDs should contain exactly one entry at the end.");
+    return resultRDDs.values().iterator().next();
   }
 
   /**
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapTran.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapTran.java
index b03a51c..b13956b 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapTran.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapTran.java
@@ -18,6 +18,7 @@
 package org.apache.hadoop.hive.ql.exec.spark;
 
+import com.google.common.base.Preconditions;
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.spark.api.java.JavaPairRDD;
 
@@ -26,8 +27,10 @@
   @Override
   public JavaPairRDD<BytesWritable, BytesWritable> transform(
-      JavaPairRDD<BytesWritable, BytesWritable> input) {
-    return input.mapPartitionsToPair(mapFunc);
+      JavaPairRDD<BytesWritable, BytesWritable>... input) {
+    Preconditions.checkArgument(input.length == 1,
+        "AssertionError: MapTran should only take 1 input");
+    return input[0].mapPartitionsToPair(mapFunc);
   }
 
   public void setMapFunction(HiveMapFunction mapFunc) {
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java
index 76b74e7..d9d1653 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java
@@ -18,6 +18,7 @@
 package org.apache.hadoop.hive.ql.exec.spark;
 
+import com.google.common.base.Preconditions;
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.spark.api.java.JavaPairRDD;
 
@@ -28,8 +29,10 @@
   @Override
   public JavaPairRDD<BytesWritable, BytesWritable> transform(
-      JavaPairRDD<BytesWritable, BytesWritable> input) {
-    return shuffler.shuffle(input, numPartitions).mapPartitionsToPair(reduceFunc);
+      JavaPairRDD<BytesWritable, BytesWritable>... input) {
+    Preconditions.checkArgument(input.length == 1,
+        "AssertionError: ReduceTran should only take 1 input");
+    return shuffler.shuffle(input[0], numPartitions).mapPartitionsToPair(reduceFunc);
   }
 
   public void setReduceFunction(HiveReduceFunction redFunc) {
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlan.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlan.java
index 46e4b6d..2089267 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlan.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlan.java
@@ -18,12 +18,19 @@
 package org.apache.hadoop.hive.ql.exec.spark;
 
+import org.apache.hadoop.io.BytesWritable;
+import org.apache.spark.api.java.JavaPairRDD;
+
+import java.util.List;
+
 public class SparkPlan {
   private GraphTran tran;
+  private List<JavaPairRDD<BytesWritable, BytesWritable>> inputs;
 
   public void execute() throws Exception {
-    tran.execute();
+    JavaPairRDD<BytesWritable, BytesWritable> finalRDD = tran.transform(inputs.toArray(new JavaPairRDD[inputs.size()]));
+    finalRDD.foreach(HiveVoidFunction.getInstance());
   }
 
   public void setTran(GraphTran tran) {
@@ -33,4 +40,13 @@ public void setTran(GraphTran tran) {
   public GraphTran getTran() {
     return tran;
   }
+
+  public void setInputs(List<JavaPairRDD<BytesWritable, BytesWritable>> inputs) {
+    this.inputs = inputs;
+  }
+
+  public List<JavaPairRDD<BytesWritable, BytesWritable>> getInputs() {
+    return inputs;
+  }
+
 }
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java
index 9b11fe4..cd54e85 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java
@@ -18,11 +18,7 @@
 package org.apache.hadoop.hive.ql.exec.spark;
 
-import java.io.IOException;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
+import java.util.*;
 
 import com.google.common.base.Preconditions;
 import org.apache.commons.lang.StringUtils;
@@ -77,6 +73,8 @@ public SparkPlan generate(SparkWork sparkWork) throws Exception {
     SparkPlan plan = new SparkPlan();
     GraphTran trans = new GraphTran();
     Set<BaseWork> roots = sparkWork.getRoots();
+    List<JavaPairRDD<BytesWritable, BytesWritable>> inputs = new ArrayList<JavaPairRDD<BytesWritable, BytesWritable>>();
+
     for (BaseWork w : roots) {
       if (!(w instanceof MapWork)) {
         throw new Exception(
@@ -86,13 +84,15 @@ public SparkPlan generate(SparkWork sparkWork) throws Exception {
       JobConf newJobConf = cloneJobConf(mapWork);
       SparkTran tran = generate(newJobConf, mapWork);
       JavaPairRDD<BytesWritable, BytesWritable> input = generateRDD(newJobConf, mapWork);
-      trans.addRootTranWithInput(tran, input);
+      trans.addRootTranWithInput(tran);
+      inputs.add(input);
 
       while (sparkWork.getChildren(w).size() > 0) {
         BaseWork child = sparkWork.getChildren(w).get(0);
+        SparkTran childTran = childWorkTrans.get(child);
         if (child instanceof ReduceWork) {
-          ReduceTran rt = null;
+          ReduceTran rt;
           if (((ReduceWork) child).getReducer() instanceof JoinOperator) {
             // Reduce-side join operator: the strategy is to insert a UnionTran (UT) to union the output
             // of the two separate input map-trans (MT), which are then shuffled to the appropriate partition
@@ -141,6 +141,7 @@ public SparkPlan generate(SparkWork sparkWork) throws Exception {
     }
     childWorkTrans.clear();
     plan.setTran(trans);
+    plan.setInputs(inputs);
     return plan;
   }
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkTran.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkTran.java
index 19894b0..b74aaae 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkTran.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkTran.java
@@ -23,5 +23,5 @@
 public interface SparkTran {
   JavaPairRDD<BytesWritable, BytesWritable> transform(
-      JavaPairRDD<BytesWritable, BytesWritable> input);
+      JavaPairRDD<BytesWritable, BytesWritable>... input);
 }
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/UnionTran.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/UnionTran.java
index 5ec7d0f..569ee3a 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/UnionTran.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/spark/UnionTran.java
@@ -18,23 +18,21 @@
 package org.apache.hadoop.hive.ql.exec.spark;
 
+import com.google.common.base.Preconditions;
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.spark.api.java.JavaPairRDD;
 
 public class UnionTran implements SparkTran {
-  JavaPairRDD<BytesWritable, BytesWritable> otherInput;
 
   @Override
   public JavaPairRDD<BytesWritable, BytesWritable> transform(
-      JavaPairRDD<BytesWritable, BytesWritable> input) {
-    return input.union(otherInput);
-  }
-
-  public void setOtherInput(JavaPairRDD<BytesWritable, BytesWritable> otherInput) {
-    this.otherInput = otherInput;
-  }
-
-  public JavaPairRDD<BytesWritable, BytesWritable> getOtherInput() {
-    return this.otherInput;
+      JavaPairRDD<BytesWritable, BytesWritable>... input) {
+    Preconditions.checkArgument(input.length > 0,
+        "AssertionError: transform should take at least 1 input");
+    JavaPairRDD<BytesWritable, BytesWritable> rdd = input[0];
+    for (int i = 1; i < input.length; i++) {
+      rdd = rdd.union(input[i]);
+    }
+    return rdd;
   }
 }
diff --git a/ql/src/test/results/clientpositive/spark/union17.q.out.sorted b/ql/src/test/results/clientpositive/spark/union17.q.out.sorted
new file mode 100644
index 0000000..e69de29
diff --git a/ql/src/test/results/clientpositive/spark/union20.q.out.sorted b/ql/src/test/results/clientpositive/spark/union20.q.out.sorted
new file mode 100644
index 0000000..e69de29
diff --git a/ql/src/test/results/clientpositive/spark/union21.q.out.sorted b/ql/src/test/results/clientpositive/spark/union21.q.out.sorted
new file mode 100644
index 0000000..e69de29
diff --git a/ql/src/test/results/clientpositive/spark/union27.q.out.sorted b/ql/src/test/results/clientpositive/spark/union27.q.out.sorted
new file mode 100644
index 0000000..e69de29
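
For illustration only, not part of the patch: a minimal sketch of how the reworked varargs SparkTran API is driven end to end. The HiveMapFunction `mapFunc` and the `inputRDD` below are assumed to exist; in real execution SparkPlanGenerator produces them via generate(...) and generateRDD(...), and SparkPlan.execute() issues the transform/foreach calls.

    // Hypothetical setup: mapFunc and inputRDD are assumed here, not defined by this patch.
    GraphTran graph = new GraphTran();
    MapTran mapTran = new MapTran();
    mapTran.setMapFunction(mapFunc);
    // The root tran is now registered without its input RDD...
    graph.addRootTranWithInput(mapTran);
    // ...because root inputs are passed positionally, one per root tran,
    // when the whole graph is executed as a single SparkTran:
    JavaPairRDD<BytesWritable, BytesWritable> result = graph.transform(inputRDD);
    result.foreach(HiveVoidFunction.getInstance());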