diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/GroupByShuffler.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/GroupByShuffler.java
index abd4718..3d06275 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/GroupByShuffler.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/GroupByShuffler.java
@@ -7,8 +7,11 @@
 
   @Override
   public JavaPairRDD<BytesWritable, Iterable<BytesWritable>> shuffle(
-      JavaPairRDD<BytesWritable, BytesWritable> input) {
-    return input.groupByKey(/* default to hash partition */);
+      JavaPairRDD<BytesWritable, BytesWritable> input, int numPartitions) {
+    if (numPartitions > 0) {
+      return input.groupByKey(numPartitions);
+    }
+    return input.groupByKey();
   }
 
 }
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java
index b2fe482..76b74e7 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java
@@ -24,11 +24,12 @@ public class ReduceTran implements SparkTran {
 
   private SparkShuffler shuffler;
   private HiveReduceFunction reduceFunc;
+  private int numPartitions;
 
   @Override
   public JavaPairRDD<BytesWritable, BytesWritable> transform(
       JavaPairRDD<BytesWritable, BytesWritable> input) {
-    return shuffler.shuffle(input).mapPartitionsToPair(reduceFunc);
+    return shuffler.shuffle(input, numPartitions).mapPartitionsToPair(reduceFunc);
   }
 
   public void setReduceFunction(HiveReduceFunction redFunc) {
@@ -39,4 +40,8 @@ public void setShuffler(SparkShuffler shuffler) {
     this.shuffler = shuffler;
   }
 
+  public void setNumPartitions(int numPartitions) {
+    this.numPartitions = numPartitions;
+  }
+
 }
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SortByShuffler.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SortByShuffler.java
index f262065..ca85e02 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SortByShuffler.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SortByShuffler.java
@@ -18,8 +18,12 @@
 
 package org.apache.hadoop.hive.ql.exec.spark;
 
+import java.util.ArrayList;
+import java.util.Comparator;
 import java.util.Iterator;
+import java.util.List;
 
+import com.google.common.collect.Ordering;
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.function.PairFlatMapFunction;
@@ -30,10 +34,12 @@
 
   @Override
   public JavaPairRDD<BytesWritable, Iterable<BytesWritable>> shuffle(
-      JavaPairRDD<BytesWritable, BytesWritable> input) {
-    JavaPairRDD<BytesWritable, BytesWritable> rdd = input.sortByKey();
+      JavaPairRDD<BytesWritable, BytesWritable> input, int numPartitions) {
+    Comparator<BytesWritable> comp = Ordering.natural();
+    // Due to HIVE-7540, numPartitions must be set to 1.
+    JavaPairRDD<BytesWritable, BytesWritable> rdd = input.sortByKey(comp, true, 1);
     return rdd.mapPartitionsToPair(new ShuffleFunction());
-  };
+  }
 
   private static class ShuffleFunction implements
       PairFlatMapFunction<Iterator<Tuple2<BytesWritable, BytesWritable>>,
@@ -48,7 +54,7 @@
       final Iterator<Tuple2<BytesWritable, Iterable<BytesWritable>>> resultIt =
           new Iterator<Tuple2<BytesWritable, Iterable<BytesWritable>>>() {
             BytesWritable curKey = null;
-            BytesWritable curValue = null;
+            List<BytesWritable> curValues = new ArrayList<BytesWritable>();
 
             @Override
             public boolean hasNext() {
@@ -60,13 +66,30 @@ public boolean hasNext() {
               // TODO: implement this by accumulating rows with the same key into a list.
               // Note that this list needs to improved to prevent excessive memory usage, but this
               // can be done in later phase.
-              return null;
+              while (it.hasNext()) {
+                Tuple2<BytesWritable, BytesWritable> pair = it.next();
+                if (curKey != null && !curKey.equals(pair._1())) {
+                  BytesWritable key = curKey;
+                  List<BytesWritable> values = curValues;
+                  curKey = pair._1();
+                  curValues = new ArrayList<BytesWritable>();
+                  curValues.add(pair._2());
+                  return new Tuple2<BytesWritable, Iterable<BytesWritable>>(key, values);
+                }
+                curKey = pair._1();
+                curValues.add(pair._2());
+              }
+              // if we get here, this should be the last element we have
+              BytesWritable key = curKey;
+              curKey = null;
+              return new Tuple2<BytesWritable, Iterable<BytesWritable>>(key, curValues);
             }
 
             @Override
             public void remove() {
               // Not implemented.
               // throw Unsupported Method Invocation Exception.
+              throw new UnsupportedOperationException();
             }
           };
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java
index 73553ee..91d7c40 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkPlanGenerator.java
@@ -67,6 +67,7 @@ public SparkPlan generate(SparkWork sparkWork) throws Exception {
       SparkShuffler st = generate(edge);
       ReduceTran rt = generate(child);
       rt.setShuffler(st);
+      rt.setNumPartitions(edge.getNumPartitions());
       trans.add(rt);
       w = child;
     }
@@ -110,7 +111,9 @@ private MapTran generate(MapWork mw) throws IOException {
   }
 
   private SparkShuffler generate(SparkEdgeProperty edge) {
-    // TODO: create different shuffler based on edge prop.
+    if (edge.isShuffleSort()) {
+      return new SortByShuffler();
+    }
     return new GroupByShuffler();
   }
 
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkShuffler.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkShuffler.java
index f9f9c10..2475359 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkShuffler.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkShuffler.java
@@ -24,6 +24,6 @@
 public interface SparkShuffler {
 
   JavaPairRDD<BytesWritable, Iterable<BytesWritable>> shuffle(
-      JavaPairRDD<BytesWritable, BytesWritable> input);
+      JavaPairRDD<BytesWritable, BytesWritable> input, int numPartitions);
 
 }
diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkUtils.java ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkUtils.java
index 25eea14..ca6daae 100644
--- ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkUtils.java
+++ ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkUtils.java
@@ -33,26 +33,12 @@
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.hadoop.hive.ql.exec.FetchTask;
-import org.apache.hadoop.hive.ql.exec.Operator;
-import org.apache.hadoop.hive.ql.exec.FileSinkOperator;
-import org.apache.hadoop.hive.ql.exec.HashTableDummyOperator;
-import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
-import org.apache.hadoop.hive.ql.exec.TableScanOperator;
-import org.apache.hadoop.hive.ql.exec.UnionOperator;
-import org.apache.hadoop.hive.ql.exec.Utilities;
+import org.apache.hadoop.hive.ql.exec.*;
 import org.apache.hadoop.hive.ql.optimizer.GenMapRedUtils;
 import org.apache.hadoop.hive.ql.parse.ParseContext;
 import org.apache.hadoop.hive.ql.parse.PrunedPartitionList;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
-import org.apache.hadoop.hive.ql.plan.BaseWork;
-import org.apache.hadoop.hive.ql.plan.FileSinkDesc;
-import org.apache.hadoop.hive.ql.plan.MapWork;
-import org.apache.hadoop.hive.ql.plan.OperatorDesc;
-import org.apache.hadoop.hive.ql.plan.ReduceWork;
-import org.apache.hadoop.hive.ql.plan.SparkEdgeProperty;
-import org.apache.hadoop.hive.ql.plan.SparkWork;
-import org.apache.hadoop.hive.ql.plan.UnionWork;
+import org.apache.hadoop.hive.ql.plan.*;
 
 /**
  * GenSparkUtils is a collection of shared helper methods to produce SparkWork
@@ -137,12 +123,15 @@ public ReduceWork createReduceWork(GenSparkProcContext context, Operator root
 
     sparkWork.add(reduceWork);
 
-    SparkEdgeProperty edgeProp;
-    if (reduceWork.isAutoReduceParallelism()) {
-      edgeProp =
-          new SparkEdgeProperty(0);
-    } else {
-      edgeProp = new SparkEdgeProperty(0);
+    SparkEdgeProperty edgeProp = new SparkEdgeProperty(SparkEdgeProperty.SHUFFLE_NONE,
+        reduceSink.getConf().getNumReducers());
+
+    if (root instanceof GroupByOperator) {
+      edgeProp.setShuffleGroup();
+    }
+    String sortOrder = reduceSink.getConf().getOrder();
+    if (sortOrder != null && !sortOrder.trim().isEmpty()) {
+      edgeProp.setShuffleSort();
     }
 
     sparkWork.connect(
diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkWork.java ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkWork.java
index ceb7b6c..79bda49 100644
--- ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkWork.java
+++ ql/src/java/org/apache/hadoop/hive/ql/parse/spark/GenSparkWork.java
@@ -28,11 +28,7 @@
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.hadoop.hive.ql.exec.HashTableDummyOperator;
-import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
-import org.apache.hadoop.hive.ql.exec.Operator;
-import org.apache.hadoop.hive.ql.exec.OperatorFactory;
-import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
+import org.apache.hadoop.hive.ql.exec.*;
 import org.apache.hadoop.hive.ql.lib.Node;
 import org.apache.hadoop.hive.ql.lib.NodeProcessor;
 import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
@@ -263,13 +259,14 @@ public Object process(Node nd, Stack<Node> stack,
 
         if (!context.connectedReduceSinks.contains(rs)) {
           // add dependency between the two work items
-          SparkEdgeProperty edgeProp;
-          if (rWork.isAutoReduceParallelism()) {
-            edgeProp =
-                new SparkEdgeProperty(0/*context.conf, EdgeType.SIMPLE_EDGE, true,
-                    rWork.getMinReduceTasks(), rWork.getMaxReduceTasks(), bytesPerReducer*/);
-          } else {
-            edgeProp = new SparkEdgeProperty(0/*EdgeType.SIMPLE_EDGE*/);
+          SparkEdgeProperty edgeProp = new SparkEdgeProperty(SparkEdgeProperty.SHUFFLE_NONE,
+              rs.getConf().getNumReducers());
+          if (rWork.getReducer() instanceof GroupByOperator) {
+            edgeProp.setShuffleGroup();
+          }
+          String sortOrder = rs.getConf().getOrder();
+          if (sortOrder != null && !sortOrder.trim().isEmpty()) {
+            edgeProp.setShuffleSort();
           }
           sparkWork.connect(work, rWork, edgeProp);
           context.connectedReduceSinks.add(rs);
diff --git ql/src/java/org/apache/hadoop/hive/ql/plan/SparkEdgeProperty.java ql/src/java/org/apache/hadoop/hive/ql/plan/SparkEdgeProperty.java
index 9447578..d9cfeb5 100644
--- ql/src/java/org/apache/hadoop/hive/ql/plan/SparkEdgeProperty.java
+++ ql/src/java/org/apache/hadoop/hive/ql/plan/SparkEdgeProperty.java
@@ -34,7 +34,7 @@ public SparkEdgeProperty(long edgeType, int numPartitions) {
     this.numPartitions = numPartitions;
   }
 
-  public SparkEdgeProperty(int edgeType) {
+  public SparkEdgeProperty(long edgeType) {
     this.edgeType = edgeType;
   }
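
Note on the new SortByShuffler logic: after sortByKey(), records arrive ordered by key within the single partition, so ShuffleFunction only needs to fold consecutive pairs that share a key into one (key, values) tuple. The standalone sketch below illustrates that same accumulation pattern with plain Strings in place of BytesWritable/Tuple2; the class and method names are illustrative only and are not part of the patch.

// Illustrative sketch (not part of the patch): grouping an already-sorted
// stream of key/value pairs, the same idea used by SortByShuffler.ShuffleFunction.
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;

public class GroupSortedSketch {

  // Folds consecutive entries that share a key into one (key, values) pair.
  // Assumes the input iterator is already sorted by key, which is what
  // sortByKey() guarantees within a partition.
  static List<Entry<String, List<String>>> group(Iterator<Entry<String, String>> it) {
    List<Entry<String, List<String>>> out = new ArrayList<>();
    String curKey = null;
    List<String> curValues = new ArrayList<>();
    while (it.hasNext()) {
      Entry<String, String> pair = it.next();
      if (curKey != null && !curKey.equals(pair.getKey())) {
        out.add(new SimpleEntry<>(curKey, curValues));  // key changed: emit the finished group
        curValues = new ArrayList<>();
      }
      curKey = pair.getKey();
      curValues.add(pair.getValue());
    }
    if (curKey != null) {
      out.add(new SimpleEntry<>(curKey, curValues));    // emit the last group
    }
    return out;
  }

  public static void main(String[] args) {
    Iterator<Entry<String, String>> sorted = Arrays.<Entry<String, String>>asList(
        new SimpleEntry<>("a", "1"), new SimpleEntry<>("a", "2"),
        new SimpleEntry<>("b", "3")).iterator();
    System.out.println(group(sorted));  // prints [a=[1, 2], b=[3]]
  }
}

As in the patch, the list of values for the current key is held in memory until the key changes, which is why the existing comment about limiting memory usage for large groups still applies.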