diff --git itests/src/test/resources/testconfiguration.properties itests/src/test/resources/testconfiguration.properties index 156b19b..6f3317f 100644 --- itests/src/test/resources/testconfiguration.properties +++ itests/src/test/resources/testconfiguration.properties @@ -241,7 +241,6 @@ minitez.query.files=bucket_map_join_tez1.q,\ tez_union_decimal.q,\ tez_union_group_by.q,\ tez_smb_main.q,\ - tez_smb_1.q,\ vectorized_dynamic_partition_pruning.q beeline.positive.exclude=add_part_exist.q,\ diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/CommonMergeJoinOperator.java ql/src/java/org/apache/hadoop/hive/ql/exec/CommonMergeJoinOperator.java index 1d1405e..55c0d8b 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/CommonMergeJoinOperator.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/CommonMergeJoinOperator.java @@ -20,7 +20,6 @@ import java.io.Serializable; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -36,7 +35,6 @@ import org.apache.hadoop.hive.ql.plan.CommonMergeJoinDesc; import org.apache.hadoop.hive.ql.plan.OperatorDesc; import org.apache.hadoop.hive.ql.plan.api.OperatorType; -import org.apache.hadoop.hive.serde2.objectinspector.InspectableObject; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption; import org.apache.hadoop.hive.serde2.objectinspector.StructField; @@ -64,7 +62,6 @@ private static final long serialVersionUID = 1L; private boolean isBigTableWork; private static final Log LOG = LogFactory.getLog(CommonMergeJoinOperator.class.getName()); - private Map aliasToInputNameMap; transient List[] keyWritables; transient List[] nextKeyWritables; transient RowContainer>[] nextGroupStorage; diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/MapOperator.java ql/src/java/org/apache/hadoop/hive/ql/exec/MapOperator.java index f624bf4..776e522 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/MapOperator.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/MapOperator.java @@ -639,4 +639,12 @@ public OperatorType getType() { public Map getTagToOperatorTree() { return MapRecordProcessor.getConnectOps(); } + + public ObjectInspector getCurrentObjectInspector() { + return current.tblRawRowObjectInspector; + } + + public Deserializer getCurrentDeserializer() { + return current.deserializer; + } } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java index 155002a..998ae40 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java @@ -409,7 +409,7 @@ private static BaseWork getBaseWork(Configuration conf, String name) { } gWorkMap.put(path, gWork); } else { - LOG.debug("Found plan in cache."); + LOG.debug("Found plan in cache for name: " + name); gWork = gWorkMap.get(path); } return gWork; diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java index 848be26..11fb2b5 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.List; @@ 
-89,6 +90,13 @@ private final Map> inputToGroupedSplitMap = new HashMap>(); + private int numInputsAffectingRootInputSpecUpdate = 1; + private int numInputsSeenSoFar = 0; + private final Map emMap = Maps.newHashMap(); + private final List finalSplits = Lists.newLinkedList(); + private final Map inputNameInputSpecMap = + new HashMap(); + public CustomPartitionVertex(VertexManagerPluginContext context) { super(context); } @@ -108,12 +116,13 @@ public void initialize() { this.numBuckets = vertexConf.getNumBuckets(); this.mainWorkName = vertexConf.getInputName(); this.vertexType = vertexConf.getVertexType(); + this.numInputsAffectingRootInputSpecUpdate = vertexConf.getNumInputs(); } @Override public void onVertexStarted(Map> completions) { int numTasks = context.getVertexNumTasks(context.getVertexName()); - List scheduledTasks = + List scheduledTasks = new ArrayList(numTasks); for (int i = 0; i < numTasks; ++i) { scheduledTasks.add(new VertexManagerPluginContext.TaskWithLocationHint(new Integer(i), null)); @@ -133,8 +142,8 @@ public void onVertexManagerEventReceived(VertexManagerEvent vmEvent) { @Override public void onRootVertexInitialized(String inputName, InputDescriptor inputDescriptor, List events) { + numInputsSeenSoFar++; LOG.info("On root vertex initialized " + inputName); - try { // This is using the payload from the RootVertexInitializer corresponding // to InputName. Ideally it should be using it's own configuration class - @@ -174,12 +183,15 @@ public void onRootVertexInitialized(String inputName, InputDescriptor inputDescr if (event instanceof InputConfigureVertexTasksEvent) { // No tasks should have been started yet. Checked by initial state // check. + LOG.info("Got a input configure vertex event for input: " + inputName); Preconditions.checkState(dataInformationEventSeen == false); InputConfigureVertexTasksEvent cEvent = (InputConfigureVertexTasksEvent) event; // The vertex cannot be configured until all DataEvents are seen - to // build the routing table. configureVertexTaskEvent = cEvent; + LOG.info("Configure task for input name: " + inputName + " num tasks: " + + configureVertexTaskEvent.getNumTasks()); dataInformationEvents = Lists.newArrayListWithCapacity(configureVertexTaskEvent.getNumTasks()); } @@ -205,6 +217,8 @@ public void onRootVertexInitialized(String inputName, InputDescriptor inputDescr } } + LOG.info("Path file splits map for input name: " + inputName + " is " + pathFileSplitsMap); + Multimap bucketToInitialSplitMap = getBucketSplitMapForPath(pathFileSplitsMap); @@ -217,30 +231,47 @@ public void onRootVertexInitialized(String inputName, InputDescriptor inputDescr int availableSlots = totalResource / taskResource; - LOG.info("Grouping splits. " + availableSlots + " available slots, " + waves + " waves."); + LOG.info("Grouping splits. " + availableSlots + " available slots, " + waves + + " waves. Bucket initial splits map: " + bucketToInitialSplitMap); JobConf jobConf = new JobConf(conf); ShimLoader.getHadoopShims().getMergedCredentials(jobConf); Multimap bucketToGroupedSplitMap = HashMultimap. 
create(); - for (Integer key : bucketToInitialSplitMap.keySet()) { - InputSplit[] inputSplitArray = - (bucketToInitialSplitMap.get(key).toArray(new InputSplit[0])); - Multimap groupedSplit = - HiveSplitGenerator.generateGroupedSplits(jobConf, conf, inputSplitArray, waves, - availableSlots, inputName); - bucketToGroupedSplitMap.putAll(key, groupedSplit.values()); + if ((mainWorkName.isEmpty()) || (inputName.compareTo(mainWorkName) == 0)) { + for (Integer key : bucketToInitialSplitMap.keySet()) { + InputSplit[] inputSplitArray = + (bucketToInitialSplitMap.get(key).toArray(new InputSplit[0])); + Multimap groupedSplit = + HiveSplitGenerator.generateGroupedSplits(jobConf, conf, inputSplitArray, waves, + availableSlots, inputName); + bucketToGroupedSplitMap.putAll(key, groupedSplit.values()); + } + } else { + // do not group in case of side work because there is only 1 KV reader per grouped split. + // This would affect SMB joins where we want to find the smallest key in all the bucket + // files. + for (Integer key : bucketToInitialSplitMap.keySet()) { + for (InputSplit inputSplit : bucketToInitialSplitMap.get(key)) { + InputSplit[] inputSplitArray = new InputSplit[1]; + inputSplitArray[0] = inputSplit; + Multimap groupedSplit = + HiveSplitGenerator.generateGroupedSplits(jobConf, conf, inputSplitArray, waves, + availableSlots, inputName); + bucketToGroupedSplitMap.putAll(key, groupedSplit.values()); + } + } } - LOG.info("We have grouped the splits into " + bucketToGroupedSplitMap); if ((mainWorkName.isEmpty() == false) && (mainWorkName.compareTo(inputName) != 0)) { /* - * this is the small table side. In case of SMB join, we may need to send each split to the + * this is the small table side. In case of SMB join, we need to send each split to the * corresponding bucket-based task on the other side. In case a split needs to go to * multiple downstream tasks, we need to clone the event and send it to the right * destination. */ - processAllSideEvents(inputName, bucketToGroupedSplitMap); + LOG.info("This is the side work/ multi-mr work"); + processAllSideEventsSetParallelism(inputName, bucketToGroupedSplitMap); } else { processAllEvents(inputName, bucketToGroupedSplitMap); } @@ -249,18 +280,33 @@ public void onRootVertexInitialized(String inputName, InputDescriptor inputDescr } } - private void processAllSideEvents(String inputName, + private void processAllSideEventsSetParallelism(String inputName, Multimap bucketToGroupedSplitMap) throws IOException { // the bucket to task map should have been setup by the big table. + LOG.info("Processing events for input " + inputName); if (bucketToTaskMap.isEmpty()) { + LOG.info("We don't have a routing table yet. Will need to wait for the main input" + + " initialization"); inputToGroupedSplitMap.put(inputName, bucketToGroupedSplitMap); return; } + processAllSideEvents(inputName, bucketToGroupedSplitMap); + setVertexParallelismAndRootInputSpec(inputNameInputSpecMap); + } + + private void processAllSideEvents(String inputName, + Multimap bucketToGroupedSplitMap) throws IOException { List taskEvents = new ArrayList(); + LOG.info("We have a routing table and we are going to set the destination tasks for the" + + " multi mr inputs. 
" + bucketToTaskMap); + + Integer[] numSplitsForTask = new Integer[taskCount]; for (Entry> entry : bucketToGroupedSplitMap.asMap().entrySet()) { Collection destTasks = bucketToTaskMap.get(entry.getKey()); for (Integer task : destTasks) { + int count = 0; for (InputSplit split : entry.getValue()) { + count++; MRSplitProto serializedSplit = MRInputHelpers.createSplitProto(split); InputDataInformationEvent diEvent = InputDataInformationEvent.createWithSerializedPayload(task, serializedSplit @@ -268,16 +314,21 @@ private void processAllSideEvents(String inputName, diEvent.setTargetIndex(task); taskEvents.add(diEvent); } + numSplitsForTask[task] = count; } } + inputNameInputSpecMap.put(inputName, + InputSpecUpdate.createPerTaskInputSpecUpdate(Arrays.asList(numSplitsForTask))); + + LOG.info("For input name: " + inputName + " task events size is " + taskEvents.size()); + context.addRootInputEvents(inputName, taskEvents); } private void processAllEvents(String inputName, Multimap bucketToGroupedSplitMap) throws IOException { - List finalSplits = Lists.newLinkedList(); for (Entry> entry : bucketToGroupedSplitMap.asMap().entrySet()) { int bucketNum = entry.getKey(); Collection initialSplits = entry.getValue(); @@ -288,6 +339,9 @@ private void processAllEvents(String inputName, } } + inputNameInputSpecMap.put(inputName, + InputSpecUpdate.getDefaultSinglePhysicalInputSpecUpdate()); + // Construct the EdgeManager descriptor to be used by all edges which need // the routing table. EdgeManagerPluginDescriptor hiveEdgeManagerDesc = null; @@ -297,7 +351,6 @@ private void processAllEvents(String inputName, UserPayload payload = getBytePayload(bucketToTaskMap); hiveEdgeManagerDesc.setUserPayload(payload); } - Map emMap = Maps.newHashMap(); // Replace the edge manager for all vertices which have routing type custom. for (Entry edgeEntry : context.getInputVertexEdgeProperties().entrySet()) { @@ -308,7 +361,7 @@ private void processAllEvents(String inputName, } } - LOG.info("Task count is " + taskCount); + LOG.info("Task count is " + taskCount + " for input name: " + inputName); List taskEvents = Lists.newArrayListWithCapacity(finalSplits.size()); @@ -323,27 +376,35 @@ private void processAllEvents(String inputName, taskEvents.add(diEvent); } - // Replace the Edge Managers - Map rootInputSpecUpdate = - new HashMap(); - rootInputSpecUpdate.put( - inputName, - InputSpecUpdate.getDefaultSinglePhysicalInputSpecUpdate()); - if ((mainWorkName.compareTo(inputName) == 0) || (mainWorkName.isEmpty())) { - context.setVertexParallelism( - taskCount, - VertexLocationHint.create(grouper.createTaskLocationHints(finalSplits - .toArray(new InputSplit[finalSplits.size()]))), emMap, rootInputSpecUpdate); - } - // Set the actual events for the tasks. + LOG.info("For input name: " + inputName + " task events size is " + taskEvents.size()); context.addRootInputEvents(inputName, taskEvents); if (inputToGroupedSplitMap.isEmpty() == false) { for (Entry> entry : inputToGroupedSplitMap.entrySet()) { processAllSideEvents(entry.getKey(), entry.getValue()); } + setVertexParallelismAndRootInputSpec(inputNameInputSpecMap); inputToGroupedSplitMap.clear(); } + + // Only done when it is a bucket map join only no SMB. 
+ if (numInputsAffectingRootInputSpecUpdate == 1) { + setVertexParallelismAndRootInputSpec(inputNameInputSpecMap); + } + } + + private void + setVertexParallelismAndRootInputSpec(Map rootInputSpecUpdate) + throws IOException { + if (numInputsAffectingRootInputSpecUpdate != numInputsSeenSoFar) { + return; + } + + LOG.info("Setting vertex parallelism since we have seen all inputs."); + + context.setVertexParallelism(taskCount, VertexLocationHint.create(grouper + .createTaskLocationHints(finalSplits.toArray(new InputSplit[finalSplits.size()]))), emMap, + rootInputSpecUpdate); } UserPayload getBytePayload(Multimap routingTable) throws IOException { @@ -392,6 +453,11 @@ private FileSplit getFileSplitFromEvent(InputDataInformationEvent event) throws bucketNum++; } + // this is just for SMB join use-case. The numBuckets would be equal to that of the big table + // and the small table could have lesser number of buckets. In this case, we want to send the + // data from the right buckets to the big table side. For e.g. Big table has 8 buckets and small + // table has 4 buckets, bucket 0 of small table needs to be sent to bucket 4 of the big table as + // well. if (bucketNum < numBuckets) { int loopedBucketId = 0; for (; bucketNum < numBuckets; bucketNum++) { diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomVertexConfiguration.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomVertexConfiguration.java index 4829f92..993377a 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomVertexConfiguration.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomVertexConfiguration.java @@ -35,14 +35,17 @@ private int numBuckets; private VertexType vertexType = VertexType.AUTO_INITIALIZED_EDGES; + private int numInputs; private String inputName; public CustomVertexConfiguration() { } - public CustomVertexConfiguration(int numBuckets, VertexType vertexType, String inputName) { + public CustomVertexConfiguration(int numBuckets, VertexType vertexType, String inputName, + int numInputs) { this.numBuckets = numBuckets; this.vertexType = vertexType; + this.numInputs = numInputs; this.inputName = inputName; } @@ -50,6 +53,7 @@ public CustomVertexConfiguration(int numBuckets, VertexType vertexType, String i public void write(DataOutput out) throws IOException { out.writeInt(this.vertexType.ordinal()); out.writeInt(this.numBuckets); + out.writeInt(numInputs); out.writeUTF(inputName); } @@ -57,6 +61,7 @@ public void write(DataOutput out) throws IOException { public void readFields(DataInput in) throws IOException { this.vertexType = VertexType.values()[in.readInt()]; this.numBuckets = in.readInt(); + this.numInputs = in.readInt(); this.inputName = in.readUTF(); } @@ -71,4 +76,8 @@ public VertexType getVertexType() { public String getInputName() { return inputName; } + + public int getNumInputs() { + return numInputs; + } } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DagUtils.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DagUtils.java index 7703d69..5efdb8b 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DagUtils.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DagUtils.java @@ -268,8 +268,9 @@ public GroupInputEdge createEdge(VertexGroup group, JobConf vConf, Vertex w, case CUSTOM_EDGE: { mergeInputClass = ConcatenatedMergedKeyValueInput.class; int numBuckets = edgeProp.getNumBuckets(); + // by default only 1 input determines the number of inputs per task. 
CustomVertexConfiguration vertexConf = - new CustomVertexConfiguration(numBuckets, vertexType, ""); + new CustomVertexConfiguration(numBuckets, vertexType, "", 1); DataOutputBuffer dob = new DataOutputBuffer(); vertexConf.write(dob); VertexManagerPluginDescriptor desc = @@ -314,8 +315,9 @@ public Edge createEdge(JobConf vConf, Vertex v, Vertex w, TezEdgeProperty edgePr switch(edgeProp.getEdgeType()) { case CUSTOM_EDGE: { int numBuckets = edgeProp.getNumBuckets(); + // by default only 1 input determines the number of inputs per task. CustomVertexConfiguration vertexConf = - new CustomVertexConfiguration(numBuckets, vertexType, ""); + new CustomVertexConfiguration(numBuckets, vertexType, "", 1); DataOutputBuffer dob = new DataOutputBuffer(); vertexConf.write(dob); VertexManagerPluginDescriptor desc = VertexManagerPluginDescriptor.create( @@ -496,9 +498,11 @@ private Vertex createVertex(JobConf conf, MergeJoinWork mergeJoinWork, LocalReso VertexManagerPluginDescriptor desc = VertexManagerPluginDescriptor.create(CustomPartitionVertex.class.getName()); + // the +1 to the size is because of the main work. CustomVertexConfiguration vertexConf = new CustomVertexConfiguration(mergeJoinWork.getMergeJoinOperator().getConf() - .getNumBuckets(), vertexType, mergeJoinWork.getBigTableAlias()); + .getNumBuckets(), vertexType, mergeJoinWork.getBigTableAlias(), + mapWorkList.size() + 1); DataOutputBuffer dob = new DataOutputBuffer(); vertexConf.write(dob); byte[] userPayload = dob.getData(); diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/MapRecordProcessor.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/MapRecordProcessor.java index c77e081..3db871c 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/MapRecordProcessor.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/MapRecordProcessor.java @@ -44,10 +44,10 @@ import org.apache.hadoop.hive.ql.exec.tez.TezProcessor.TezKVOutputCollector; import org.apache.hadoop.hive.ql.exec.tez.tools.KeyValueInputMerger; import org.apache.hadoop.hive.ql.exec.vector.VectorMapOperator; -import org.apache.hadoop.hive.ql.io.IOContext; import org.apache.hadoop.hive.ql.log.PerfLogger; import org.apache.hadoop.hive.ql.plan.MapWork; import org.apache.hadoop.hive.ql.plan.OperatorDesc; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.mapred.JobConf; import org.apache.tez.mapreduce.input.MRInputLegacy; import org.apache.tez.mapreduce.input.MultiMRInput; @@ -158,8 +158,6 @@ void init(JobConf jconf, ProcessorContext processorContext, MRTaskReporter mrRep if (mergeWorkList != null) { MapOperator mergeMapOp = null; for (MapWork mergeMapWork : mergeWorkList) { - processorContext.waitForAnyInputReady(Collections.singletonList((Input) (inputs - .get(mergeMapWork.getName())))); if (mergeMapWork.getVectorMode()) { mergeMapOp = new VectorMapOperator(); } else { @@ -249,7 +247,10 @@ private void initializeMapRecordSources() throws Exception { Collection kvReaders = multiMRInput.getKeyValueReaders(); l4j.debug("There are " + kvReaders.size() + " key-value readers for input " + inputName); List kvReaderList = new ArrayList(kvReaders); - reader = new KeyValueInputMerger(kvReaderList); + reader = + new KeyValueInputMerger(kvReaderList, mapOp.getCurrentDeserializer(), + new ObjectInspector[] { mapOp.getCurrentObjectInspector() }, mapOp.getConf() + .getSortCols()); sources[tag].init(jconf, mapOp, reader); } ((TezContext) MapredContext.get()).setRecordSources(sources); diff --git 
ql/src/java/org/apache/hadoop/hive/ql/exec/tez/MapRecordSource.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/MapRecordSource.java index 0419568..7c880eb 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/MapRecordSource.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/MapRecordSource.java @@ -19,6 +19,7 @@ package org.apache.hadoop.hive.ql.exec.tez; import java.io.IOException; +import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -44,8 +45,9 @@ private KeyValueReader reader = null; private final boolean grouped = false; - void init(JobConf jconf, MapOperator mapOp, KeyValueReader reader) throws IOException { - execContext = new ExecMapperContext(jconf); + void init(JobConf jconf, MapOperator mapOp, KeyValueReader reader) + throws IOException { + execContext = mapOp.getExecContext(); this.mapOp = mapOp; this.reader = reader; } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezProcessor.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezProcessor.java index 42c7d37..0588146 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezProcessor.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezProcessor.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.text.NumberFormat; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -38,6 +39,7 @@ import org.apache.tez.mapreduce.processor.MRTaskReporter; import org.apache.tez.runtime.api.AbstractLogicalIOProcessor; import org.apache.tez.runtime.api.Event; +import org.apache.tez.runtime.api.Input; import org.apache.tez.runtime.api.LogicalInput; import org.apache.tez.runtime.api.LogicalOutput; import org.apache.tez.runtime.api.ProcessorContext; @@ -154,8 +156,11 @@ protected void initializeAndRunProcessor(Map inputs, if (!cacheAccess.isInputCached(inputEntry.getKey())) { LOG.info("Input: " + inputEntry.getKey() + " is not cached"); inputEntry.getValue().start(); + processorContext.waitForAnyInputReady(Collections.singletonList((Input) (inputEntry + .getValue()))); } else { - LOG.info("Input: " + inputEntry.getKey() + " is already cached. Skipping start"); + LOG.info("Input: " + inputEntry.getKey() + + " is already cached. Skipping start and wait for ready"); } } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/tools/KeyValueInputMerger.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/tools/KeyValueInputMerger.java index 516722d..a1bf44d 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/tools/KeyValueInputMerger.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/tools/KeyValueInputMerger.java @@ -24,7 +24,14 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.hadoop.io.BinaryComparable; +import org.apache.hadoop.hive.serde2.Deserializer; +import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption; +import org.apache.hadoop.io.Writable; import org.apache.tez.runtime.library.api.KeyValueReader; /** @@ -34,16 +41,25 @@ * Uses a priority queue to pick the KeyValuesReader of the input that is next in * sort order. 
*/ +@SuppressWarnings("deprecation") public class KeyValueInputMerger extends KeyValueReader { public static final Log l4j = LogFactory.getLog(KeyValueInputMerger.class); private PriorityQueue pQueue = null; private KeyValueReader nextKVReader = null; + private ObjectInspector[] inputObjInspectors = null; + private Deserializer deserializer = null; + private List sortCols = null; - public KeyValueInputMerger(List multiMRInputs) throws Exception { + public KeyValueInputMerger(List multiMRInputs, + Deserializer deserializer, + ObjectInspector[] inputObjInspectors, List sortCols) throws Exception { //get KeyValuesReaders from the LogicalInput and add them to priority queue int initialCapacity = multiMRInputs.size(); pQueue = new PriorityQueue(initialCapacity, new KVReaderComparator()); + this.inputObjInspectors = inputObjInspectors; + this.deserializer = deserializer; + this.sortCols = sortCols; l4j.info("Initialized the priority queue with multi mr inputs: " + multiMRInputs.size()); for (KeyValueReader input : multiMRInputs) { addToQueue(input); @@ -93,12 +109,42 @@ public Object getCurrentValue() throws IOException { */ class KVReaderComparator implements Comparator { + @SuppressWarnings({ "unchecked" }) @Override public int compare(KeyValueReader kvReadr1, KeyValueReader kvReadr2) { try { - BinaryComparable key1 = (BinaryComparable) kvReadr1.getCurrentValue(); - BinaryComparable key2 = (BinaryComparable) kvReadr2.getCurrentValue(); - return key1.compareTo(key2); + ObjectInspector oi = inputObjInspectors[0]; + List row1, row2; + try { + // we need to copy to standard object otherwise deserializer overwrites the values + row1 = + (List) ObjectInspectorUtils.copyToStandardObject( + deserializer.deserialize((Writable) kvReadr1.getCurrentValue()), oi, + ObjectInspectorCopyOption.WRITABLE); + row2 = + (List) ObjectInspectorUtils.copyToStandardObject( + deserializer.deserialize((Writable) kvReadr2.getCurrentValue()), oi, + ObjectInspectorCopyOption.WRITABLE); + } catch (SerDeException e) { + throw new IOException(e); + } + + StructObjectInspector structOI = (StructObjectInspector) oi; + int compare = 0; + for (String field : sortCols) { + StructField sf = structOI.getStructFieldRef(field); + int pos = structOI.getAllStructFieldRefs().indexOf(sf); + Object key1 = row1.get(pos); + Object key2 = row2.get(pos); + ObjectInspector stdOI = + ObjectInspectorUtils.getStandardObjectInspector(sf.getFieldObjectInspector()); + compare = ObjectInspectorUtils.compare(key1, stdOI, key2, stdOI); + if (compare != 0) { + return compare; + } + } + + return compare; } catch (IOException e) { l4j.error("Caught exception while reading shuffle input", e); //die! 
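
Note on the KeyValueInputMerger change above: the comparator no longer compares raw BinaryComparable values; it deserializes each reader's current row (copying to a standard object, since the deserializer's buffer reuse would otherwise overwrite it) and compares field by field on the work's sort columns. Below is a minimal, self-contained sketch of that k-way merge pattern, assuming every input is already sorted on those columns; RowReader, SortedRowMerger, and the integer column positions are illustrative stand-ins, not the Hive/Tez classes.

import java.util.List;
import java.util.PriorityQueue;

// Each input is a sorted stream of deserialized rows; current() must return a
// stable copy of the row (mirroring the copyToStandardObject call in the patch).
interface RowReader {
  boolean advance();      // move to the next row, false when exhausted
  List<Object> current(); // the current deserialized row
}

public class SortedRowMerger {

  private final PriorityQueue<RowReader> queue;
  private final int[] sortColPositions;  // positions of the sort columns within a row

  public SortedRowMerger(List<RowReader> inputs, int[] sortColPositions) {
    this.sortColPositions = sortColPositions;
    this.queue = new PriorityQueue<>(Math.max(1, inputs.size()), this::compareHeads);
    for (RowReader reader : inputs) {
      if (reader.advance()) {            // only enqueue readers that have at least one row
        queue.add(reader);
      }
    }
  }

  // Order readers by their current rows, field by field on the sort columns.
  @SuppressWarnings({"unchecked", "rawtypes"})
  private int compareHeads(RowReader a, RowReader b) {
    for (int pos : sortColPositions) {
      int c = ((Comparable) a.current().get(pos)).compareTo(b.current().get(pos));
      if (c != 0) {
        return c;
      }
    }
    return 0;
  }

  // Next row in global sort order, or null when all inputs are drained.
  public List<Object> next() {
    RowReader head = queue.poll();
    if (head == null) {
      return null;
    }
    List<Object> row = head.current();
    if (head.advance()) {
      queue.add(head);                   // re-insert so it is re-ordered on its new current row
    }
    return row;
  }
}

Polling a reader, emitting its current row, then advancing and re-inserting it keeps the queue ordered on the new head row. This is also why the patch stops grouping splits for the small-table side in CustomPartitionVertex: there is only one KV reader per grouped split, and the SMB merge needs one reader per bucket file to find the smallest key across all of them.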
diff --git ql/src/java/org/apache/hadoop/hive/ql/io/IOContext.java ql/src/java/org/apache/hadoop/hive/ql/io/IOContext.java index cad567a..3fce23d 100644 --- ql/src/java/org/apache/hadoop/hive/ql/io/IOContext.java +++ ql/src/java/org/apache/hadoop/hive/ql/io/IOContext.java @@ -61,6 +61,7 @@ public static IOContext get(String inputName) { public static void clear() { IOContext.threadLocal.remove(); + inputNameIOContextMap.clear(); ioContext = new IOContext(); } diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java index 7a3280c..41a8974 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java @@ -85,13 +85,18 @@ JoinOperator joinOp = (JoinOperator) nd; - if (!context.conf.getBoolVar(HiveConf.ConfVars.HIVECONVERTJOIN) - && !(context.conf.getBoolVar(HiveConf.ConfVars.HIVE_AUTO_SORTMERGE_JOIN))) { + TezBucketJoinProcCtx tezBucketJoinProcCtx = new TezBucketJoinProcCtx(context.conf); + if (!context.conf.getBoolVar(HiveConf.ConfVars.HIVECONVERTJOIN)) { // we are just converting to a common merge join operator. The shuffle // join in map-reduce case. - int pos = 0; // it doesn't matter which position we use in this case. - convertJoinSMBJoin(joinOp, context, pos, 0, false, false); - return null; + Object retval = checkAndConvertSMBJoin(context, joinOp, tezBucketJoinProcCtx); + if (retval == null) { + return retval; + } else { + int pos = 0; // it doesn't matter which position we use in this case. + convertJoinSMBJoin(joinOp, context, pos, 0, false, false); + return null; + } } // if we have traits, and table info is present in the traits, we know the @@ -99,7 +104,6 @@ // reducers from the parent operators. int numBuckets = -1; int estimatedBuckets = -1; - TezBucketJoinProcCtx tezBucketJoinProcCtx = new TezBucketJoinProcCtx(context.conf); if (context.conf.getBoolVar(HiveConf.ConfVars.HIVE_CONVERT_JOIN_BUCKET_MAPJOIN_TEZ)) { for (OperatorparentOp : joinOp.getParentOperators()) { if (parentOp.getOpTraits().getNumBuckets() > 0) { @@ -126,53 +130,15 @@ LOG.info("Estimated number of buckets " + numBuckets); int mapJoinConversionPos = getMapJoinConversionPos(joinOp, context, numBuckets); if (mapJoinConversionPos < 0) { - // we cannot convert to bucket map join, we cannot convert to - // map join either based on the size. Check if we can convert to SMB join. - if (context.conf.getBoolVar(HiveConf.ConfVars.HIVE_AUTO_SORTMERGE_JOIN) == false) { - convertJoinSMBJoin(joinOp, context, 0, 0, false, false); - return null; - } - Class bigTableMatcherClass = null; - try { - bigTableMatcherClass = - (Class) (Class.forName(HiveConf.getVar( - context.parseContext.getConf(), - HiveConf.ConfVars.HIVE_AUTO_SORTMERGE_JOIN_BIGTABLE_SELECTOR))); - } catch (ClassNotFoundException e) { - throw new SemanticException(e.getMessage()); - } - - BigTableSelectorForAutoSMJ bigTableMatcher = - ReflectionUtils.newInstance(bigTableMatcherClass, null); - JoinDesc joinDesc = joinOp.getConf(); - JoinCondDesc[] joinCondns = joinDesc.getConds(); - Set joinCandidates = MapJoinProcessor.getBigTableCandidates(joinCondns); - if (joinCandidates.isEmpty()) { - // This is a full outer join. This can never be a map-join - // of any type. So return false. 
- return false; - } - mapJoinConversionPos = - bigTableMatcher.getBigTablePosition(context.parseContext, joinOp, joinCandidates); - if (mapJoinConversionPos < 0) { - // contains aliases from sub-query - // we are just converting to a common merge join operator. The shuffle - // join in map-reduce case. - int pos = 0; // it doesn't matter which position we use in this case. - convertJoinSMBJoin(joinOp, context, pos, 0, false, false); - return null; - } - - if (checkConvertJoinSMBJoin(joinOp, context, mapJoinConversionPos, tezBucketJoinProcCtx)) { - convertJoinSMBJoin(joinOp, context, mapJoinConversionPos, - tezBucketJoinProcCtx.getNumBuckets(), tezBucketJoinProcCtx.isSubQuery(), true); + Object retval = checkAndConvertSMBJoin(context, joinOp, tezBucketJoinProcCtx); + if (retval == null) { + return retval; } else { - // we are just converting to a common merge join operator. The shuffle - // join in map-reduce case. - int pos = 0; // it doesn't matter which position we use in this case. - convertJoinSMBJoin(joinOp, context, pos, 0, false, false); + // only case is full outer join with SMB enabled which is not possible. Convert to regular + // join. + convertJoinSMBJoin(joinOp, context, 0, 0, false, false); + return null; } - return null; } if (numBuckets > 1) { @@ -206,6 +172,57 @@ return null; } + private Object checkAndConvertSMBJoin(OptimizeTezProcContext context, JoinOperator joinOp, + TezBucketJoinProcCtx tezBucketJoinProcCtx) throws SemanticException { + // we cannot convert to bucket map join, we cannot convert to + // map join either based on the size. Check if we can convert to SMB join. + if (context.conf.getBoolVar(HiveConf.ConfVars.HIVE_AUTO_SORTMERGE_JOIN) == false) { + convertJoinSMBJoin(joinOp, context, 0, 0, false, false); + return null; + } + Class bigTableMatcherClass = null; + try { + bigTableMatcherClass = + (Class) (Class.forName(HiveConf.getVar( + context.parseContext.getConf(), + HiveConf.ConfVars.HIVE_AUTO_SORTMERGE_JOIN_BIGTABLE_SELECTOR))); + } catch (ClassNotFoundException e) { + throw new SemanticException(e.getMessage()); + } + + BigTableSelectorForAutoSMJ bigTableMatcher = + ReflectionUtils.newInstance(bigTableMatcherClass, null); + JoinDesc joinDesc = joinOp.getConf(); + JoinCondDesc[] joinCondns = joinDesc.getConds(); + Set joinCandidates = MapJoinProcessor.getBigTableCandidates(joinCondns); + if (joinCandidates.isEmpty()) { + // This is a full outer join. This can never be a map-join + // of any type. So return false. + return false; + } + int mapJoinConversionPos = + bigTableMatcher.getBigTablePosition(context.parseContext, joinOp, joinCandidates); + if (mapJoinConversionPos < 0) { + // contains aliases from sub-query + // we are just converting to a common merge join operator. The shuffle + // join in map-reduce case. + int pos = 0; // it doesn't matter which position we use in this case. + convertJoinSMBJoin(joinOp, context, pos, 0, false, false); + return null; + } + + if (checkConvertJoinSMBJoin(joinOp, context, mapJoinConversionPos, tezBucketJoinProcCtx)) { + convertJoinSMBJoin(joinOp, context, mapJoinConversionPos, + tezBucketJoinProcCtx.getNumBuckets(), tezBucketJoinProcCtx.isSubQuery(), true); + } else { + // we are just converting to a common merge join operator. The shuffle + // join in map-reduce case. + int pos = 0; // it doesn't matter which position we use in this case. 
+ convertJoinSMBJoin(joinOp, context, pos, 0, false, false); + } + return null; +} + // replaces the join operator with a new CommonJoinOperator, removes the // parent reduce sinks private void convertJoinSMBJoin(JoinOperator joinOp, OptimizeTezProcContext context, @@ -630,7 +647,7 @@ private boolean hasDynamicPartitionBroadcast(Operator parent) { hasDynamicPartitionPruning = true; break; } - + if (op instanceof ReduceSinkOperator || op instanceof FileSinkOperator) { // crossing reduce sink or file sink means the pruning isn't for this parent. break; diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/MergeJoinProc.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/MergeJoinProc.java index 8516643..2c83c4b 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/MergeJoinProc.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/MergeJoinProc.java @@ -1,7 +1,5 @@ package org.apache.hadoop.hive.ql.optimizer; -import java.util.HashMap; -import java.util.Map; import java.util.Stack; import org.apache.hadoop.hive.ql.exec.CommonMergeJoinOperator; diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezWork.java ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezWork.java index 516e576..d56dee0 100644 --- ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezWork.java +++ ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezWork.java @@ -123,6 +123,11 @@ public Object process(Node nd, Stack stack, context.rootToWorkMap.put(root, work); } + // this is where we set the sort columns that we will be using for KeyValueInputMerge + if (operator instanceof DummyStoreOperator) { + work.addSortCols(root.getOpTraits().getSortCols().get(0)); + } + if (!context.childToWorkMap.containsKey(operator)) { List workItems = new LinkedList(); workItems.add(work); diff --git ql/src/java/org/apache/hadoop/hive/ql/plan/BaseWork.java ql/src/java/org/apache/hadoop/hive/ql/plan/BaseWork.java index 05be1f1..9d8d52d 100644 --- ql/src/java/org/apache/hadoop/hive/ql/plan/BaseWork.java +++ ql/src/java/org/apache/hadoop/hive/ql/plan/BaseWork.java @@ -18,6 +18,7 @@ package org.apache.hadoop.hive.ql.plan; +import java.util.ArrayList; import java.util.LinkedList; import java.util.LinkedHashSet; import java.util.List; @@ -42,6 +43,7 @@ // schema info. List dummyOps; int tag; + private final List sortColNames = new ArrayList(); public BaseWork() {} @@ -148,4 +150,12 @@ public void setTag(int tag) { public int getTag() { return tag; } + + public void addSortCols(List sortCols) { + this.sortColNames.addAll(sortCols); + } + + public List getSortCols() { + return sortColNames; + } }
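
Note on the CustomPartitionVertex gating above: vertex parallelism is no longer set as soon as the main input's events are processed. With SMB the vertex has several root inputs (the merged map works plus the main work, which is why DagUtils passes mapWorkList.size() + 1), and setVertexParallelismAndRootInputSpec returns early until all of them have been seen. The following is a simplified, hypothetical stand-in for that counting logic only, not the Tez VertexManagerPlugin API; it collapses the early-return check into a single callback.

// Simplified stand-in for the new gating fields in CustomPartitionVertex
// (numInputsAffectingRootInputSpecUpdate / numInputsSeenSoFar); not the Tez API.
public class RootInputGate {
  private final int numInputsAffectingRootInputSpecUpdate; // mapWorkList.size() + 1 in DagUtils
  private int numInputsSeenSoFar = 0;

  public RootInputGate(int numInputsAffectingRootInputSpecUpdate) {
    this.numInputsAffectingRootInputSpecUpdate = numInputsAffectingRootInputSpecUpdate;
  }

  // Called once per root input initialization (onRootVertexInitialized in the patch).
  // The action stands in for context.setVertexParallelism(...) plus the root input spec update.
  public void onRootInputInitialized(Runnable setParallelismAndRootInputSpec) {
    numInputsSeenSoFar++;
    if (numInputsSeenSoFar == numInputsAffectingRootInputSpecUpdate) {
      setParallelismAndRootInputSpec.run();
    }
  }
}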
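
Note on the comment added to getFileSplitFromEvent: when the small table has fewer buckets than the big table, each small-table bucket is fanned out to every big-table bucket with the same id modulo the small-table bucket count (with 8 big-table and 4 small-table buckets, small bucket 0 also feeds big bucket 4). A hypothetical helper that computes the destinations implied by that rule; the names are illustrative and this is not the CustomPartitionVertex code.

import java.util.ArrayList;
import java.util.List;

public class SmbBucketFanout {

  // Big-table buckets that should receive the given small-table bucket.
  static List<Integer> destinationBigTableBuckets(int smallTableBucket, int smallTableBuckets,
      int bigTableBuckets) {
    List<Integer> destinations = new ArrayList<>();
    for (int bucket = smallTableBucket; bucket < bigTableBuckets; bucket += smallTableBuckets) {
      destinations.add(bucket);
    }
    return destinations;
  }

  public static void main(String[] args) {
    // The example from the patch comment: 8 big-table buckets, 4 small-table buckets.
    System.out.println(destinationBigTableBuckets(0, 4, 8)); // [0, 4]
    System.out.println(destinationBigTableBuckets(3, 4, 8)); // [3, 7]
  }
}

This is consistent with the usual bucket map-join convention in which big-table bucket b reads small-table bucket b % smallTableBuckets.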