diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java
index 2895d80..83057eb 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/HashTableLoader.java
@@ -17,29 +17,19 @@
  */
 package org.apache.hadoop.hive.ql.exec.spark;
 
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.hive.ql.exec.HashTableSinkOperator;
 import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
-import org.apache.hadoop.hive.ql.exec.Operator;
-import org.apache.hadoop.hive.ql.exec.TemporaryHashSinkOperator;
 import org.apache.hadoop.hive.ql.exec.Utilities;
 import org.apache.hadoop.hive.ql.exec.mr.ExecMapperContext;
-import org.apache.hadoop.hive.ql.exec.mr.MapredLocalTask;
 import org.apache.hadoop.hive.ql.exec.persistence.MapJoinTableContainer;
 import org.apache.hadoop.hive.ql.exec.persistence.MapJoinTableContainerSerDe;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
 import org.apache.hadoop.hive.ql.plan.MapredLocalWork;
-import org.apache.hadoop.hive.ql.plan.OperatorDesc;
-import org.apache.hadoop.mapred.JobConf;
 
 /**
  * HashTableLoader for Spark to load the hashtable for MapJoins.
@@ -66,14 +56,9 @@ public void init(ExecMapperContext context, Configuration hconf, MapJoinOperator
   public void load(
       MapJoinTableContainer[] mapJoinTables,
       MapJoinTableContainerSerDe[] mapJoinTableSerdes, long memUsage) throws HiveException {
-
-    String currentInputPath = context.getCurrentInputPath().toString();
-    LOG.info("******* Load from HashTable for input file: " + currentInputPath);
+    Path currentInputPath = context.getCurrentInputPath();
     MapredLocalWork localWork = context.getLocalWork();
     try {
-      if (localWork.getDirectFetchOp() != null) {
-        loadDirectly(mapJoinTables, currentInputPath);
-      }
       // All HashTables share the same base dir,
       // which is passed in as the tmp path
       Path baseDir = localWork.getTmpPath();
@@ -81,7 +66,14 @@ public void load(
         return;
       }
       FileSystem fs = FileSystem.get(baseDir.toUri(), hconf);
-      String fileName = localWork.getBucketFileName(currentInputPath);
+
+      // Note: it's possible that a MJ operator is in a ReduceWork, in which case the
+      // currentInputPath will be null. But, since currentInputPath is only interesting
+      // for bucket join case, and for bucket join the MJ operator will always be in
+      // a MapWork, this should be OK.
+      String fileName = localWork.getBucketFileName(
+          currentInputPath == null ? null : currentInputPath.toString());
+
       for (int pos = 0; pos < mapJoinTables.length; pos++) {
         if (pos == desc.getPosBigTable() || mapJoinTables[pos] != null) {
           continue;
@@ -94,36 +86,4 @@ public void load(
       throw new HiveException(e);
     }
   }
-
-  @SuppressWarnings("unchecked")
-  private void loadDirectly(MapJoinTableContainer[] mapJoinTables, String inputFileName)
-      throws Exception {
-    MapredLocalWork localWork = context.getLocalWork();
-    List<Operator<?>> directWorks = localWork.getDirectFetchOp().get(joinOp);
-    if (directWorks == null || directWorks.isEmpty()) {
-      return;
-    }
-    JobConf job = new JobConf(hconf);
-    MapredLocalTask localTask = new MapredLocalTask(localWork, job, false);
-
-    HashTableSinkOperator sink = new TemporaryHashSinkOperator(desc);
-    sink.setParentOperators(new ArrayList<Operator<? extends OperatorDesc>>(directWorks));
-
-    for (Operator<?> operator : directWorks) {
-      if (operator != null) {
-        operator.setChildOperators(Arrays.<Operator<? extends OperatorDesc>>asList(sink));
-      }
-    }
-    localTask.setExecContext(context);
-    localTask.startForward(inputFileName);
-
-    MapJoinTableContainer[] tables = sink.getMapJoinTables();
-    for (int i = 0; i < sink.getNumParent(); i++) {
-      if (sink.getParentOperators().get(i) != null) {
-        mapJoinTables[i] = tables[i];
-      }
-    }
-
-    Arrays.fill(tables, null);
-  }
 }
diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkReduceRecordHandler.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkReduceRecordHandler.java
index 141ae6f..a9fbf6c 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkReduceRecordHandler.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkReduceRecordHandler.java
@@ -34,11 +34,13 @@
 import org.apache.hadoop.hive.ql.exec.Utilities;
 import org.apache.hadoop.hive.ql.exec.mr.ExecMapper.ReportStats;
 import org.apache.hadoop.hive.ql.exec.mr.ExecMapperContext;
+import org.apache.hadoop.hive.ql.exec.mr.MapredLocalTask;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedBatchUtil;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriter;
 import org.apache.hadoop.hive.ql.exec.vector.expressions.VectorExpressionWriterFactory;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.plan.MapredLocalWork;
 import org.apache.hadoop.hive.ql.plan.OperatorDesc;
 import org.apache.hadoop.hive.ql.plan.ReduceWork;
 import org.apache.hadoop.hive.ql.plan.TableDesc;
@@ -101,6 +103,7 @@
   private StructObjectInspector[] valueStructInspectors;
   /* this is only used in the error code path */
   private List<VectorExpressionWriter>[] valueStringWriters;
+  private MapredLocalWork localWork = null;
 
   public void init(JobConf job, OutputCollector output, Reporter reporter) {
     super.init(job, output, reporter);
@@ -197,8 +200,9 @@ public void init(JobConf job, OutputCollector output, Reporter reporter) {
     }
 
     ExecMapperContext execContext = new ExecMapperContext(job);
+    localWork = gWork.getMapRedLocalWork();
     execContext.setJc(jc);
-    execContext.setLocalWork(gWork.getMapRedLocalWork());
+    execContext.setLocalWork(localWork);
     reducer.setExecContext(execContext);
     reducer.setReporter(rp);
 
@@ -209,6 +213,14 @@ public void init(JobConf job, OutputCollector output, Reporter reporter) {
     try {
       LOG.info(reducer.dump(0));
      reducer.initialize(jc, rowObjectInspector);
+
+      if (localWork != null) {
+        for (Operator<? extends OperatorDesc> dummyOp : localWork.getDummyParentOp()) {
+          dummyOp.setExecContext(execContext);
+          dummyOp.initialize(jc, null);
+        }
+      }
+
     } catch (Throwable e) {
       abort = true;
       if (e instanceof OutOfMemoryError) {
@@ -218,6 +230,7 @@ public void init(JobConf job, OutputCollector output, Reporter reporter) {
         throw new RuntimeException("Reduce operator initialization failed", e);
       }
     }
+
   }
 
   @Override
@@ -416,6 +429,13 @@ public void close() {
       }
 
       reducer.close(abort);
+
+      if (localWork != null) {
+        for (Operator<? extends OperatorDesc> dummyOp : localWork.getDummyParentOp()) {
+          dummyOp.close(abort);
+        }
+      }
+
       ReportStats rps = new ReportStats(rp, jc);
       reducer.preorderMap(rps);
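
For context on the HashTableLoader change: the patch tolerates a null currentInputPath instead of guarding the call site, because only bucket map joins ever consult the bucket file name and those always run inside a MapWork. The following standalone Java sketch (hadoop-common on the classpath) shows just that pattern; NullSafeBucketLookupSketch and bucketFileName() are illustrative stand-ins for MapredLocalWork.getBucketFileName(String), and the sample paths are made up, not taken from the patch.

import org.apache.hadoop.fs.Path;

// Standalone sketch (not part of the patch) of the null-safe lookup now done in
// HashTableLoader.load().
public class NullSafeBucketLookupSketch {

  // Illustrative stand-in for MapredLocalWork.getBucketFileName(String): only bucket
  // map joins ever consult the result, so a null current input file simply maps to null.
  static String bucketFileName(String currentInputFile) {
    return currentInputFile == null ? null : new Path(currentInputFile).getName();
  }

  public static void main(String[] args) {
    // MapWork case: the record handler knows which split it is reading.
    Path mapSideInput = new Path("hdfs://nn:8020/warehouse/t1/bucket_00000");
    System.out.println(bucketFileName(mapSideInput.toString()));       // bucket_00000

    // ReduceWork case: there is no current input file, so the path is null,
    // which is acceptable because bucket joins always sit in a MapWork.
    Path reduceSideInput = null;
    System.out.println(bucketFileName(
        reduceSideInput == null ? null : reduceSideInput.toString())); // null
  }
}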
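For the SparkReduceRecordHandler change: whenever local work is attached, every dummy parent operator initialized in init() must also be closed in close(). The sketch below shows only that lifecycle pairing; DummyOp and DummyParentLifecycleSketch are illustrative stand-ins rather than Hive classes, whereas the real code iterates localWork.getDummyParentOp() and additionally sets the exec context before initializing each operator.

import java.util.ArrayList;
import java.util.List;

// Lifecycle sketch mirroring the "localWork != null" blocks added to
// SparkReduceRecordHandler; names here are illustrative only.
public class DummyParentLifecycleSketch {

  // Plays the role of a dummy parent operator that feeds small-table data
  // to a map join running on the reduce side.
  static class DummyOp {
    void initialize() {
      System.out.println("dummy parent initialized");
    }
    void close(boolean abort) {
      System.out.println("dummy parent closed, abort=" + abort);
    }
  }

  private final List<DummyOp> dummyParents; // null when there is no local work

  DummyParentLifecycleSketch(List<DummyOp> dummyParents) {
    this.dummyParents = dummyParents;
  }

  void init() {
    if (dummyParents != null) {             // mirrors "if (localWork != null)"
      for (DummyOp op : dummyParents) {
        op.initialize();
      }
    }
  }

  void close(boolean abort) {
    if (dummyParents != null) {             // every initialized dummy must be closed
      for (DummyOp op : dummyParents) {
        op.close(abort);
      }
    }
  }

  public static void main(String[] args) {
    List<DummyOp> dummies = new ArrayList<>();
    dummies.add(new DummyOp());

    DummyParentLifecycleSketch handler = new DummyParentLifecycleSketch(dummies);
    handler.init();        // initialize dummy parents alongside the reducer
    handler.close(false);  // close them even on the normal (non-abort) path
  }
}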