diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveMapFunction.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveMapFunction.java
index 4d6e197..6629922 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveMapFunction.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveMapFunction.java
@@ -23,12 +23,14 @@
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.hadoop.mapred.JobConf;
 import org.apache.hadoop.mapred.Reporter;
-import org.apache.spark.api.java.function.PairFlatMapFunction;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.api.java.function.PairFlatMapFunction2;
 
 import scala.Tuple2;
 
-public class HiveMapFunction implements PairFlatMapFunction<Iterator<Tuple2<BytesWritable, BytesWritable>>,
-BytesWritable, BytesWritable> {
+public class HiveMapFunction implements PairFlatMapFunction2<TaskContext, Iterator<Tuple2<BytesWritable, BytesWritable>>, BytesWritable, BytesWritable> {
 
   private static final long serialVersionUID = 1L;
   private transient JobConf jobConf;
@@ -40,11 +42,12 @@ public HiveMapFunction(byte[] buffer) {
   }
 
   @Override
-  public Iterable<Tuple2<BytesWritable, BytesWritable>>
-  call(Iterator<Tuple2<BytesWritable, BytesWritable>> it) throws Exception {
+  public Iterable<Tuple2<BytesWritable, BytesWritable>> call(TaskContext context,
+      Iterator<Tuple2<BytesWritable, BytesWritable>> it) throws Exception {
     if (jobConf == null) {
       jobConf = KryoSerializer.deserializeJobConf(this.buffer);
     }
+    jobConf.setInt("mapred.task.partition", context.partitionId());
 
     SparkMapRecordHandler mapRecordHandler = new SparkMapRecordHandler();
     HiveMapFunctionResultList result = new HiveMapFunctionResultList(jobConf, it, mapRecordHandler);
@@ -53,5 +56,4 @@ public HiveMapFunction(byte[] buffer) {
 
     return result;
   }
-
 }
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveReduceFunction.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveReduceFunction.java
index 1dd5a93..f77a2b6 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveReduceFunction.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HiveReduceFunction.java
@@ -20,16 +20,17 @@
 import java.util.Iterator;
 
-import org.apache.hadoop.hive.ql.exec.mr.ExecReducer;
 import org.apache.hadoop.io.BytesWritable;
 import org.apache.hadoop.mapred.JobConf;
 import org.apache.hadoop.mapred.Reporter;
-import org.apache.spark.api.java.function.PairFlatMapFunction;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.api.java.function.PairFlatMapFunction2;
 
 import scala.Tuple2;
 
-public class HiveReduceFunction implements PairFlatMapFunction<Iterator<Tuple2<BytesWritable, Iterable<BytesWritable>>>,
-BytesWritable, BytesWritable> {
+public class HiveReduceFunction implements PairFlatMapFunction2<TaskContext, Iterator<Tuple2<BytesWritable, Iterable<BytesWritable>>>, BytesWritable, BytesWritable> {
 
   private static final long serialVersionUID = 1L;
   private transient JobConf jobConf;
@@ -41,14 +42,17 @@ public HiveReduceFunction(byte[] buffer) {
   }
 
   @Override
-  public Iterable<Tuple2<BytesWritable, BytesWritable>>
-  call(Iterator<Tuple2<BytesWritable, Iterable<BytesWritable>>> it) throws Exception {
+  public Iterable<Tuple2<BytesWritable, BytesWritable>> call(TaskContext context,
+      Iterator<Tuple2<BytesWritable, Iterable<BytesWritable>>> it) throws Exception {
     if (jobConf == null) {
       jobConf = KryoSerializer.deserializeJobConf(this.buffer);
     }
+    jobConf.setInt("mapred.task.partition", context.partitionId());
+
     SparkReduceRecordHandler reducerRecordhandler = new SparkReduceRecordHandler();
-    HiveReduceFunctionResultList result = new HiveReduceFunctionResultList(jobConf, it, reducerRecordhandler);
+    HiveReduceFunctionResultList result =
+        new HiveReduceFunctionResultList(jobConf, it, reducerRecordhandler);
     reducerRecordhandler.init(jobConf, result, Reporter.NULL);
 
     return result;
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapTran.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapTran.java
index b03a51c..542fd27 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapTran.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/MapTran.java
@@ -27,7 +27,7 @@
   @Override
   public JavaPairRDD<BytesWritable, BytesWritable> transform(
       JavaPairRDD<BytesWritable, BytesWritable> input) {
-    return input.mapPartitionsToPair(mapFunc);
+    return input.mapPartitionsToPairWithContext(mapFunc, false);
   }
 
   public void setMapFunction(HiveMapFunction mapFunc) {
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java
index 76b74e7..a047592 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ReduceTran.java
@@ -29,7 +29,11 @@
   @Override
   public JavaPairRDD<BytesWritable, BytesWritable> transform(
      JavaPairRDD<BytesWritable, BytesWritable> input) {
-    return shuffler.shuffle(input, numPartitions).mapPartitionsToPair(reduceFunc);
+    // TODO: The second parameter of mapPartitionsToPairWithContext indicates whether
+    // reduceFunc preserves record partitioning; this may be an optimization point
+    // we could research later.
+    return shuffler.shuffle(input, numPartitions)
+        .mapPartitionsToPairWithContext(reduceFunc, false);
   }
 
   public void setReduceFunction(HiveReduceFunction redFunc) {
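Reviewer note: PairFlatMapFunction2 and JavaPairRDD.mapPartitionsToPairWithContext are not part of the stock Spark Java API, so this patch only compiles against a companion Spark-side change that adds them. The minimal sketch below illustrates the shape the Hive code above expects, modeled on Spark's existing FlatMapFunction2; the Spark-side names and signatures shown here are assumptions inferred from the usage in this patch, not confirmed API.

// Hypothetical sketch of the Spark-side interface this patch relies on; the real
// definition would live in Spark and may differ in naming and placement.
package org.apache.spark.api.java.function;

import java.io.Serializable;

import scala.Tuple2;

// Like PairFlatMapFunction, but call() takes an extra first argument. In the Hive
// usage above, T1 is org.apache.spark.TaskContext, which lets the function read
// context.partitionId() and set "mapred.task.partition" before processing records.
public interface PairFlatMapFunction2<T1, T2, K, V> extends Serializable {
  Iterable<Tuple2<K, V>> call(T1 t1, T2 t2) throws Exception;
}

// Assumed companion method on JavaPairRDD<K, V> (signature is a guess):
//
//   <K2, V2> JavaPairRDD<K2, V2> mapPartitionsToPairWithContext(
//       PairFlatMapFunction2<TaskContext, java.util.Iterator<Tuple2<K, V>>, K2, V2> f,
//       boolean preservesPartitioning);
//
// It would hand each partition's TaskContext and record iterator to f, with
// preservesPartitioning playing the same role as in mapPartitionsToPair.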