diff --git common/src/java/org/apache/hadoop/hive/conf/HiveConf.java common/src/java/org/apache/hadoop/hive/conf/HiveConf.java index c49a0f2..7d8e5bc 100644 --- common/src/java/org/apache/hadoop/hive/conf/HiveConf.java +++ common/src/java/org/apache/hadoop/hive/conf/HiveConf.java @@ -1823,7 +1823,11 @@ TEZ_DYNAMIC_PARTITION_PRUNING_MAX_EVENT_SIZE("hive.tez.dynamic.partition.pruning.max.event.size", 1*1024*1024L, "Maximum size of events sent by processors in dynamic pruning. If this size is crossed no pruning will take place."), TEZ_DYNAMIC_PARTITION_PRUNING_MAX_DATA_SIZE("hive.tez.dynamic.partition.pruning.max.data.size", 100*1024*1024L, - "Maximum total data size of events in dynamic pruning.") + "Maximum total data size of events in dynamic pruning."), + TEZ_SMB_NUMBER_WAVES( + "hive.tez.smb.number.waves", + (float) 0.5, + "The number of waves in which to run the SMB join. Account for cluster being occupied. Ideally should be 1 wave.") ; public final String varname; diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/CommonMergeJoinOperator.java ql/src/java/org/apache/hadoop/hive/ql/exec/CommonMergeJoinOperator.java index 1d1405e..55c0d8b 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/CommonMergeJoinOperator.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/CommonMergeJoinOperator.java @@ -20,7 +20,6 @@ import java.io.Serializable; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -36,7 +35,6 @@ import org.apache.hadoop.hive.ql.plan.CommonMergeJoinDesc; import org.apache.hadoop.hive.ql.plan.OperatorDesc; import org.apache.hadoop.hive.ql.plan.api.OperatorType; -import org.apache.hadoop.hive.serde2.objectinspector.InspectableObject; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption; import org.apache.hadoop.hive.serde2.objectinspector.StructField; @@ -64,7 +62,6 @@ private static final long serialVersionUID = 1L; private boolean isBigTableWork; private static final Log LOG = LogFactory.getLog(CommonMergeJoinOperator.class.getName()); - private Map aliasToInputNameMap; transient List[] keyWritables; transient List[] nextKeyWritables; transient RowContainer>[] nextGroupStorage; diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/MapOperator.java ql/src/java/org/apache/hadoop/hive/ql/exec/MapOperator.java index 8e8e9f6..00ee527 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/MapOperator.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/MapOperator.java @@ -67,6 +67,7 @@ * different from regular operators in that it starts off by processing a * Writable data structure from a Table (instead of a Hive Object). 
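For reference, the new hive.tez.smb.number.waves value is read back with HiveConf.getFloatVar() later in this patch (CustomPartitionVertex and SplitGrouper). A minimal sketch of how a caller consumes it; the class name, method name and the availableSlots input are illustrative, not part of the patch:

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hive.conf.HiveConf;

public class SmbWavesSketch {
  // Estimate how many tasks a single bucket is assumed to occupy when SMB
  // splits are re-grouped. availableSlots would come from the vertex manager's
  // resource calculation.
  static int tasksPerBucket(Configuration conf, int availableSlots) {
    float waves = HiveConf.getFloatVar(conf, HiveConf.ConfVars.TEZ_SMB_NUMBER_WAVES);
    // With the default of 0.5, each bucket targets half of the available slots.
    return Math.max(1, (int) (availableSlots * waves));
  }
}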
**/ +@SuppressWarnings("deprecation") public class MapOperator extends Operator implements Serializable, Cloneable { private static final long serialVersionUID = 1L; @@ -177,7 +178,6 @@ void initializeAsRoot(JobConf hconf, MapWork mapWork) throws Exception { private MapOpCtx initObjectInspector(Configuration hconf, MapOpCtx opCtx, StructObjectInspector tableRowOI) throws Exception { - PartitionDesc pd = opCtx.partDesc; TableDesc td = pd.getTableDesc(); @@ -616,4 +616,16 @@ public OperatorType getType() { public Map getTagToOperatorTree() { return MapRecordProcessor.getConnectOps(); } + + public void initializeContexts() { + Path fpath = getExecContext().getCurrentInputPath(); + String nominalPath = getNominalPath(fpath); + Map, MapOpCtx> contexts = opCtxMap.get(nominalPath); + currentCtxs = contexts.values().toArray(new MapOpCtx[contexts.size()]); + } + + public Deserializer getCurrentDeserializer() { + + return currentCtxs[0].deserializer; + } } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java index 7443f8a..5bdeb92 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java @@ -409,7 +409,7 @@ private static BaseWork getBaseWork(Configuration conf, String name) { } gWorkMap.put(path, gWork); } else { - LOG.debug("Found plan in cache."); + LOG.debug("Found plan in cache for name: " + name); gWork = gWorkMap.get(path); } return gWork; @@ -1610,12 +1610,13 @@ public static void renameOrMoveFiles(FileSystem fs, Path src, Path dst) throws I * Group 6: copy [copy keyword] * Group 8: 2 [copy file index] */ + private static final String COPY_KEYWORD = "_copy_"; // copy keyword private static final Pattern COPY_FILE_NAME_TO_TASK_ID_REGEX = Pattern.compile("^.*?"+ // any prefix "([0-9]+)"+ // taskId "(_)"+ // separator "([0-9]{1,6})?"+ // attemptId (limited to 6 digits) - "((_)(\\Bcopy\\B)(_)"+ // copy keyword + "((_)(\\Bcopy\\B)(_)" + "([0-9]{1,6})$)?"+ // copy file index "(\\..*)?$"); // any suffix/file extension @@ -2010,6 +2011,15 @@ public static boolean isCopyFile(String filename) { return false; } + public static String getBucketFileNameFromPathSubString(String bucketName) { + try { + return bucketName.split(COPY_KEYWORD)[0]; + } catch (Exception e) { + e.printStackTrace(); + return bucketName; + } + } + public static String getNameMessage(Exception e) { return e.getClass().getName() + "(" + e.getMessage() + ")"; } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java index 848be26..3ec6a80 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java @@ -21,16 +21,22 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; +import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Set; import java.util.TreeMap; +import java.util.TreeSet; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.ql.exec.Utilities; import org.apache.hadoop.hive.ql.plan.TezWork.VertexType; import org.apache.hadoop.hive.shims.ShimLoader; import 
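The intent of Utilities.getBucketFileNameFromPathSubString() above is to map a duplicated bucket file (the "_copy_N" suffix matched by COPY_FILE_NAME_TO_TASK_ID_REGEX) back to its base bucket file name so both files group into the same bucket. A small usage sketch; the file names are made up for the example:

import org.apache.hadoop.hive.ql.exec.Utilities;

public class BucketFileNameExample {
  public static void main(String[] args) {
    // "_copy_" files are produced for duplicate bucket files; stripping the
    // suffix lets them land in the same bucket as the original file.
    System.out.println(Utilities.getBucketFileNameFromPathSubString("000001_0_copy_2")); // 000001_0
    System.out.println(Utilities.getBucketFileNameFromPathSubString("000001_0"));        // 000001_0
  }
}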
org.apache.hadoop.io.DataOutputBuffer; @@ -38,6 +44,7 @@ import org.apache.hadoop.mapred.FileSplit; import org.apache.hadoop.mapred.InputSplit; import org.apache.hadoop.mapred.JobConf; +import org.apache.hadoop.mapred.split.TezGroupedSplit; import org.apache.hadoop.mapreduce.split.TezMapReduceSplitsGrouper; import org.apache.tez.common.TezUtils; import org.apache.tez.dag.api.EdgeProperty; @@ -67,11 +74,36 @@ import com.google.protobuf.ByteString; /* - * Only works with old mapred API - * Will only work with a single MRInput for now. + * This is the central piece for Bucket Map Join and SMB join. It has the following + * responsibilities: + * 1. Group incoming splits based on bucketing. + * 2. Generate new serialized events for the grouped splits. + * 3. Create a routing table for the bucket map join and send a serialized version as payload + * for the EdgeManager. + * 4. For SMB join, generate a grouping according to bucketing for the "small" table side. */ public class CustomPartitionVertex extends VertexManagerPlugin { + public class PathComparatorForSplit implements Comparator { + + @Override + public int compare(InputSplit inp1, InputSplit inp2) { + FileSplit fs1 = (FileSplit) inp1; + FileSplit fs2 = (FileSplit) inp2; + + int retval = fs1.getPath().compareTo(fs2.getPath()); + if (retval != 0) { + return retval; + } + + if (fs1.getStart() != fs2.getStart()) { + return (int) (fs1.getStart() - fs2.getStart()); + } + + return 0; + } + } + private static final Log LOG = LogFactory.getLog(CustomPartitionVertex.class.getName()); VertexManagerPluginContext context; @@ -89,6 +121,13 @@ private final Map> inputToGroupedSplitMap = new HashMap>(); + private int numInputsAffectingRootInputSpecUpdate = 1; + private int numInputsSeenSoFar = 0; + private final Map emMap = Maps.newHashMap(); + private final List finalSplits = Lists.newLinkedList(); + private final Map inputNameInputSpecMap = + new HashMap(); + public CustomPartitionVertex(VertexManagerPluginContext context) { super(context); } @@ -108,12 +147,13 @@ public void initialize() { this.numBuckets = vertexConf.getNumBuckets(); this.mainWorkName = vertexConf.getInputName(); this.vertexType = vertexConf.getVertexType(); + this.numInputsAffectingRootInputSpecUpdate = vertexConf.getNumInputs(); } @Override public void onVertexStarted(Map> completions) { int numTasks = context.getVertexNumTasks(context.getVertexName()); - List scheduledTasks = + List scheduledTasks = new ArrayList(numTasks); for (int i = 0; i < numTasks; ++i) { scheduledTasks.add(new VertexManagerPluginContext.TaskWithLocationHint(new Integer(i), null)); @@ -133,8 +173,8 @@ public void onVertexManagerEventReceived(VertexManagerEvent vmEvent) { @Override public void onRootVertexInitialized(String inputName, InputDescriptor inputDescriptor, List events) { + numInputsSeenSoFar++; LOG.info("On root vertex initialized " + inputName); - try { // This is using the payload from the RootVertexInitializer corresponding // to InputName. Ideally it should be using it's own configuration class - @@ -168,18 +208,21 @@ public void onRootVertexInitialized(String inputName, InputDescriptor inputDescr } boolean dataInformationEventSeen = false; - Map> pathFileSplitsMap = new TreeMap>(); + Map> pathFileSplitsMap = new TreeMap>(); for (Event event : events) { if (event instanceof InputConfigureVertexTasksEvent) { // No tasks should have been started yet. Checked by initial state // check. 
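PathComparatorForSplit above orders a bucket's splits by file path and then by start offset. Note that the int cast of the start-offset difference can overflow for very large offsets; an equivalent, overflow-safe ordering is sketched below (names are illustrative, this is not part of the patch):

import java.util.Comparator;
import org.apache.hadoop.mapred.FileSplit;
import org.apache.hadoop.mapred.InputSplit;

public class PathComparatorSketch implements Comparator<InputSplit> {
  @Override
  public int compare(InputSplit inp1, InputSplit inp2) {
    FileSplit fs1 = (FileSplit) inp1;
    FileSplit fs2 = (FileSplit) inp2;
    // Order by path first, then by start offset within the same file.
    int byPath = fs1.getPath().compareTo(fs2.getPath());
    if (byPath != 0) {
      return byPath;
    }
    return Long.compare(fs1.getStart(), fs2.getStart());
  }
}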
+ LOG.info("Got a input configure vertex event for input: " + inputName); Preconditions.checkState(dataInformationEventSeen == false); InputConfigureVertexTasksEvent cEvent = (InputConfigureVertexTasksEvent) event; // The vertex cannot be configured until all DataEvents are seen - to // build the routing table. configureVertexTaskEvent = cEvent; + LOG.info("Configure task for input name: " + inputName + " num tasks: " + + configureVertexTaskEvent.getNumTasks()); dataInformationEvents = Lists.newArrayListWithCapacity(configureVertexTaskEvent.getNumTasks()); } @@ -196,15 +239,20 @@ public void onRootVertexInitialized(String inputName, InputDescriptor inputDescr } catch (IOException e) { throw new RuntimeException("Failed to get file split for event: " + diEvent); } - List fsList = pathFileSplitsMap.get(fileSplit.getPath().getName()); + Set fsList = + pathFileSplitsMap.get(Utilities.getBucketFileNameFromPathSubString(fileSplit.getPath() + .getName())); if (fsList == null) { - fsList = new ArrayList(); - pathFileSplitsMap.put(fileSplit.getPath().getName(), fsList); + fsList = new TreeSet(new PathComparatorForSplit()); + pathFileSplitsMap.put( + Utilities.getBucketFileNameFromPathSubString(fileSplit.getPath().getName()), fsList); } fsList.add(fileSplit); } } + LOG.info("Path file splits map for input name: " + inputName + " is " + pathFileSplitsMap); + Multimap bucketToInitialSplitMap = getBucketSplitMapForPath(pathFileSplitsMap); @@ -217,50 +265,88 @@ public void onRootVertexInitialized(String inputName, InputDescriptor inputDescr int availableSlots = totalResource / taskResource; - LOG.info("Grouping splits. " + availableSlots + " available slots, " + waves + " waves."); + LOG.info("Grouping splits. " + availableSlots + " available slots, " + waves + + " waves. Bucket initial splits map: " + bucketToInitialSplitMap); JobConf jobConf = new JobConf(conf); ShimLoader.getHadoopShims().getMergedCredentials(jobConf); Multimap bucketToGroupedSplitMap = HashMultimap. create(); - for (Integer key : bucketToInitialSplitMap.keySet()) { - InputSplit[] inputSplitArray = - (bucketToInitialSplitMap.get(key).toArray(new InputSplit[0])); - Multimap groupedSplit = - HiveSplitGenerator.generateGroupedSplits(jobConf, conf, inputSplitArray, waves, - availableSlots, inputName); - bucketToGroupedSplitMap.putAll(key, groupedSplit.values()); - } - - LOG.info("We have grouped the splits into " + bucketToGroupedSplitMap); - if ((mainWorkName.isEmpty() == false) && (mainWorkName.compareTo(inputName) != 0)) { + boolean secondLevelGroupingDone = false; + if ((mainWorkName.isEmpty()) || (inputName.compareTo(mainWorkName) == 0)) { + for (Integer key : bucketToInitialSplitMap.keySet()) { + InputSplit[] inputSplitArray = + (bucketToInitialSplitMap.get(key).toArray(new InputSplit[0])); + HiveSplitGenerator hiveSplitGenerator = new HiveSplitGenerator(); + Multimap groupedSplit = + hiveSplitGenerator.generateGroupedSplits(jobConf, conf, inputSplitArray, waves, + availableSlots, inputName, mainWorkName.isEmpty()); + if (mainWorkName.isEmpty() == false) { + Multimap singleBucketToGroupedSplit = + HashMultimap. 
create(); + singleBucketToGroupedSplit.putAll(key, groupedSplit.values()); + groupedSplit = + grouper.group(jobConf, singleBucketToGroupedSplit, availableSlots, + HiveConf.getFloatVar(conf, HiveConf.ConfVars.TEZ_SMB_NUMBER_WAVES)); + secondLevelGroupingDone = true; + } + bucketToGroupedSplitMap.putAll(key, groupedSplit.values()); + } + processAllEvents(inputName, bucketToGroupedSplitMap, secondLevelGroupingDone); + } else { + // do not group across files in case of side work because there is only 1 KV reader per + // grouped split. This would affect SMB joins where we want to find the smallest key in + // all the bucket files. + for (Integer key : bucketToInitialSplitMap.keySet()) { + HiveSplitGenerator hiveSplitGenerator = new HiveSplitGenerator(); + InputSplit[] inputSplitArray = + (bucketToInitialSplitMap.get(key).toArray(new InputSplit[0])); + Multimap groupedSplit = + hiveSplitGenerator.generateGroupedSplits(jobConf, conf, inputSplitArray, waves, + availableSlots, inputName, false); + bucketToGroupedSplitMap.putAll(key, groupedSplit.values()); + } /* - * this is the small table side. In case of SMB join, we may need to send each split to the + * this is the small table side. In case of SMB join, we need to send each split to the * corresponding bucket-based task on the other side. In case a split needs to go to * multiple downstream tasks, we need to clone the event and send it to the right * destination. */ - processAllSideEvents(inputName, bucketToGroupedSplitMap); - } else { - processAllEvents(inputName, bucketToGroupedSplitMap); + LOG.info("This is the side work - multi-mr work."); + processAllSideEventsSetParallelism(inputName, bucketToGroupedSplitMap); } } catch (Exception e) { throw new RuntimeException(e); } } - private void processAllSideEvents(String inputName, + private void processAllSideEventsSetParallelism(String inputName, Multimap bucketToGroupedSplitMap) throws IOException { // the bucket to task map should have been setup by the big table. + LOG.info("Processing events for input " + inputName); if (bucketToTaskMap.isEmpty()) { + LOG.info("We don't have a routing table yet. Will need to wait for the main input" + + " initialization"); inputToGroupedSplitMap.put(inputName, bucketToGroupedSplitMap); return; } + processAllSideEvents(inputName, bucketToGroupedSplitMap); + setVertexParallelismAndRootInputSpec(inputNameInputSpecMap); + } + + private void processAllSideEvents(String inputName, + Multimap bucketToGroupedSplitMap) throws IOException { List taskEvents = new ArrayList(); + LOG.info("We have a routing table and we are going to set the destination tasks for the" + + " multi mr inputs. 
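The else-branch above is the small-table (side work) path for SMB: splits are grouped per bucket but never across files, and routing has to wait until the big-table side has built bucketToTaskMap. The deferral pattern, reduced to plain collections (all names illustrative):

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class DeferredSideRoutingSketch<E> {
  private final Map<Integer, List<Integer>> bucketToTasks = new HashMap<Integer, List<Integer>>();
  private final Map<String, List<E>> pendingByInput = new HashMap<String, List<E>>();

  void onSideInputReady(String inputName, List<E> events) {
    if (bucketToTasks.isEmpty()) {
      // No routing table yet; stash the events until the main input initializes.
      pendingByInput.put(inputName, new ArrayList<E>(events));
      return;
    }
    route(inputName, events);
  }

  void onRoutingTableReady(Map<Integer, List<Integer>> table) {
    bucketToTasks.putAll(table);
    for (Map.Entry<String, List<E>> entry : pendingByInput.entrySet()) {
      route(entry.getKey(), entry.getValue());
    }
    pendingByInput.clear();
  }

  private void route(String inputName, List<E> events) {
    // In CustomPartitionVertex this is where each split is serialized once per
    // destination task and handed to context.addRootInputEvents(inputName, ...).
  }
}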
" + bucketToTaskMap); + + Integer[] numSplitsForTask = new Integer[taskCount]; for (Entry> entry : bucketToGroupedSplitMap.asMap().entrySet()) { Collection destTasks = bucketToTaskMap.get(entry.getKey()); for (Integer task : destTasks) { + int count = 0; for (InputSplit split : entry.getValue()) { + count++; MRSplitProto serializedSplit = MRInputHelpers.createSplitProto(split); InputDataInformationEvent diEvent = InputDataInformationEvent.createWithSerializedPayload(task, serializedSplit @@ -268,26 +354,45 @@ private void processAllSideEvents(String inputName, diEvent.setTargetIndex(task); taskEvents.add(diEvent); } + numSplitsForTask[task] = count; } } + inputNameInputSpecMap.put(inputName, + InputSpecUpdate.createPerTaskInputSpecUpdate(Arrays.asList(numSplitsForTask))); + + LOG.info("For input name: " + inputName + " task events size is " + taskEvents.size()); + context.addRootInputEvents(inputName, taskEvents); } private void processAllEvents(String inputName, - Multimap bucketToGroupedSplitMap) throws IOException { + Multimap bucketToGroupedSplitMap, boolean secondLevelGroupingDone) + throws IOException { - List finalSplits = Lists.newLinkedList(); + int totalInputsCount = 0; + List numSplitsForTask = new ArrayList(); for (Entry> entry : bucketToGroupedSplitMap.asMap().entrySet()) { int bucketNum = entry.getKey(); Collection initialSplits = entry.getValue(); finalSplits.addAll(initialSplits); - for (int i = 0; i < initialSplits.size(); i++) { + for (InputSplit inputSplit : initialSplits) { bucketToTaskMap.put(bucketNum, taskCount); + if (secondLevelGroupingDone) { + TezGroupedSplit groupedSplit = (TezGroupedSplit) inputSplit; + numSplitsForTask.add(groupedSplit.getGroupedSplits().size()); + totalInputsCount += groupedSplit.getGroupedSplits().size(); + } else { + numSplitsForTask.add(1); + totalInputsCount += 1; + } taskCount++; } } + inputNameInputSpecMap.put(inputName, + InputSpecUpdate.createPerTaskInputSpecUpdate(numSplitsForTask)); + // Construct the EdgeManager descriptor to be used by all edges which need // the routing table. EdgeManagerPluginDescriptor hiveEdgeManagerDesc = null; @@ -297,7 +402,6 @@ private void processAllEvents(String inputName, UserPayload payload = getBytePayload(bucketToTaskMap); hiveEdgeManagerDesc.setUserPayload(payload); } - Map emMap = Maps.newHashMap(); // Replace the edge manager for all vertices which have routing type custom. for (Entry edgeEntry : context.getInputVertexEdgeProperties().entrySet()) { @@ -308,42 +412,66 @@ private void processAllEvents(String inputName, } } - LOG.info("Task count is " + taskCount); + LOG.info("Task count is " + taskCount + " for input name: " + inputName); - List taskEvents = - Lists.newArrayListWithCapacity(finalSplits.size()); + List taskEvents = Lists.newArrayListWithCapacity(totalInputsCount); // Re-serialize the splits after grouping. 
int count = 0; for (InputSplit inputSplit : finalSplits) { - MRSplitProto serializedSplit = MRInputHelpers.createSplitProto(inputSplit); - InputDataInformationEvent diEvent = InputDataInformationEvent.createWithSerializedPayload( - count, serializedSplit.toByteString().asReadOnlyByteBuffer()); - diEvent.setTargetIndex(count); + if (secondLevelGroupingDone) { + TezGroupedSplit tezGroupedSplit = (TezGroupedSplit)inputSplit; + for (InputSplit subSplit : tezGroupedSplit.getGroupedSplits()) { + if ((subSplit instanceof TezGroupedSplit) == false) { + throw new IOException("Unexpected split type found: " + + subSplit.getClass().getCanonicalName()); + } + MRSplitProto serializedSplit = MRInputHelpers.createSplitProto(subSplit); + InputDataInformationEvent diEvent = + InputDataInformationEvent.createWithSerializedPayload(count, serializedSplit + .toByteString().asReadOnlyByteBuffer()); + diEvent.setTargetIndex(count); + taskEvents.add(diEvent); + } + } else { + MRSplitProto serializedSplit = MRInputHelpers.createSplitProto(inputSplit); + InputDataInformationEvent diEvent = + InputDataInformationEvent.createWithSerializedPayload(count, serializedSplit + .toByteString().asReadOnlyByteBuffer()); + diEvent.setTargetIndex(count); + taskEvents.add(diEvent); + } count++; - taskEvents.add(diEvent); - } - - // Replace the Edge Managers - Map rootInputSpecUpdate = - new HashMap(); - rootInputSpecUpdate.put( - inputName, - InputSpecUpdate.getDefaultSinglePhysicalInputSpecUpdate()); - if ((mainWorkName.compareTo(inputName) == 0) || (mainWorkName.isEmpty())) { - context.setVertexParallelism( - taskCount, - VertexLocationHint.create(grouper.createTaskLocationHints(finalSplits - .toArray(new InputSplit[finalSplits.size()]))), emMap, rootInputSpecUpdate); } // Set the actual events for the tasks. + LOG.info("For input name: " + inputName + " task events size is " + taskEvents.size()); context.addRootInputEvents(inputName, taskEvents); if (inputToGroupedSplitMap.isEmpty() == false) { for (Entry> entry : inputToGroupedSplitMap.entrySet()) { processAllSideEvents(entry.getKey(), entry.getValue()); } + setVertexParallelismAndRootInputSpec(inputNameInputSpecMap); inputToGroupedSplitMap.clear(); } + + // Only done when it is a bucket map join only no SMB. + if (numInputsAffectingRootInputSpecUpdate == 1) { + setVertexParallelismAndRootInputSpec(inputNameInputSpecMap); + } + } + + private void + setVertexParallelismAndRootInputSpec(Map rootInputSpecUpdate) + throws IOException { + if (numInputsAffectingRootInputSpecUpdate != numInputsSeenSoFar) { + return; + } + + LOG.info("Setting vertex parallelism since we have seen all inputs."); + + context.setVertexParallelism(taskCount, VertexLocationHint.create(grouper + .createTaskLocationHints(finalSplits.toArray(new InputSplit[finalSplits.size()]))), emMap, + rootInputSpecUpdate); } UserPayload getBytePayload(Multimap routingTable) throws IOException { @@ -377,14 +505,14 @@ private FileSplit getFileSplitFromEvent(InputDataInformationEvent event) throws * This method generates the map of bucket to file splits. */ private Multimap getBucketSplitMapForPath( - Map> pathFileSplitsMap) { + Map> pathFileSplitsMap) { int bucketNum = 0; Multimap bucketToInitialSplitMap = ArrayListMultimap. 
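setVertexParallelismAndRootInputSpec() above is effectively a barrier: parallelism and the per-input specs are fixed only once every input feeding the custom vertex (the big table plus each small-table MultiMRInput) has initialized. The gating pattern in isolation (illustrative names, not the plugin API):

public class InputBarrierSketch {
  private final int expectedInputs;   // numInputsAffectingRootInputSpecUpdate
  private int inputsSeen = 0;         // numInputsSeenSoFar

  InputBarrierSketch(int expectedInputs) {
    this.expectedInputs = expectedInputs;
  }

  /** Returns true exactly when the last expected input has been seen. */
  boolean inputInitialized() {
    inputsSeen++;
    return inputsSeen == expectedInputs;
  }
}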
create(); - for (Map.Entry> entry : pathFileSplitsMap.entrySet()) { + for (Map.Entry> entry : pathFileSplitsMap.entrySet()) { int bucketId = bucketNum % numBuckets; for (FileSplit fsplit : entry.getValue()) { bucketToInitialSplitMap.put(bucketId, fsplit); @@ -392,6 +520,11 @@ private FileSplit getFileSplitFromEvent(InputDataInformationEvent event) throws bucketNum++; } + // this is just for SMB join use-case. The numBuckets would be equal to that of the big table + // and the small table could have lesser number of buckets. In this case, we want to send the + // data from the right buckets to the big table side. For e.g. Big table has 8 buckets and small + // table has 4 buckets, bucket 0 of small table needs to be sent to bucket 4 of the big table as + // well. if (bucketNum < numBuckets) { int loopedBucketId = 0; for (; bucketNum < numBuckets; bucketNum++) { diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomVertexConfiguration.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomVertexConfiguration.java index 4829f92..5dd7bf3 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomVertexConfiguration.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomVertexConfiguration.java @@ -29,20 +29,31 @@ * This class is the payload for custom vertex. It serializes and de-serializes * @numBuckets: the number of buckets of the "big table" * @vertexType: this is the type of vertex and differentiates between bucket map join and SMB joins - * @inputName: This is the name of the input. Used in case of SMB joins + * @numInputs: The number of inputs that are directly connected to the vertex (MRInput/MultiMRInput). + * In case of bucket map join, it is always 1. + * @inputName: This is the name of the input. Used in case of SMB joins. Empty in case of BucketMapJoin */ public class CustomVertexConfiguration implements Writable { private int numBuckets; private VertexType vertexType = VertexType.AUTO_INITIALIZED_EDGES; + private int numInputs; private String inputName; public CustomVertexConfiguration() { } - public CustomVertexConfiguration(int numBuckets, VertexType vertexType, String inputName) { + // this is the constructor to use for the Bucket map join case. + public CustomVertexConfiguration(int numBuckets, VertexType vertexType) { + this(numBuckets, vertexType, "", 1); + } + + // this is the constructor to use for SMB. 
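The comment added above describes the bucket wrap-around for SMB when the small table has fewer buckets than the big table. A worked example of the routing it produces:

public class BucketWrapAroundExample {
  public static void main(String[] args) {
    int numBuckets = 8;     // big table buckets (drives the routing table size)
    int smallBuckets = 4;   // small table buckets
    for (int bigBucket = 0; bigBucket < numBuckets; bigBucket++) {
      // Small-table bucket (bigBucket % smallBuckets) also serves big-table
      // bucket bigBucket, e.g. small bucket 0 feeds big buckets 0 and 4.
      System.out.println("big bucket " + bigBucket + " <- small bucket " + (bigBucket % smallBuckets));
    }
  }
}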
+ public CustomVertexConfiguration(int numBuckets, VertexType vertexType, String inputName, + int numInputs) { this.numBuckets = numBuckets; this.vertexType = vertexType; + this.numInputs = numInputs; this.inputName = inputName; } @@ -50,6 +61,7 @@ public CustomVertexConfiguration(int numBuckets, VertexType vertexType, String i public void write(DataOutput out) throws IOException { out.writeInt(this.vertexType.ordinal()); out.writeInt(this.numBuckets); + out.writeInt(numInputs); out.writeUTF(inputName); } @@ -57,6 +69,7 @@ public void write(DataOutput out) throws IOException { public void readFields(DataInput in) throws IOException { this.vertexType = VertexType.values()[in.readInt()]; this.numBuckets = in.readInt(); + this.numInputs = in.readInt(); this.inputName = in.readUTF(); } @@ -71,4 +84,8 @@ public VertexType getVertexType() { public String getInputName() { return inputName; } + + public int getNumInputs() { + return numInputs; + } } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DagUtils.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DagUtils.java index acf2d31..0d3c29d 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DagUtils.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DagUtils.java @@ -20,8 +20,6 @@ import com.google.common.base.Function; import com.google.common.collect.Iterators; import com.google.common.collect.Lists; -import com.google.protobuf.ByteString; - import javax.security.auth.login.LoginException; import java.io.FileNotFoundException; import java.io.IOException; @@ -49,7 +47,6 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars; import org.apache.hadoop.hive.ql.Context; import org.apache.hadoop.hive.ql.ErrorMsg; -import org.apache.hadoop.hive.ql.exec.CommonMergeJoinOperator; import org.apache.hadoop.hive.ql.exec.Operator; import org.apache.hadoop.hive.ql.exec.Utilities; import org.apache.hadoop.hive.ql.exec.mr.ExecMapper; @@ -111,16 +108,13 @@ import org.apache.tez.dag.api.VertexGroup; import org.apache.tez.dag.api.VertexManagerPluginDescriptor; import org.apache.tez.dag.library.vertexmanager.ShuffleVertexManager; -import org.apache.tez.mapreduce.common.MRInputAMSplitGenerator; import org.apache.tez.mapreduce.hadoop.MRHelpers; import org.apache.tez.mapreduce.hadoop.MRInputHelpers; import org.apache.tez.mapreduce.hadoop.MRJobConfig; -import org.apache.tez.mapreduce.input.MRInput; import org.apache.tez.mapreduce.input.MRInputLegacy; import org.apache.tez.mapreduce.input.MultiMRInput; import org.apache.tez.mapreduce.output.MROutput; import org.apache.tez.mapreduce.partition.MRPartitioner; -import org.apache.tez.mapreduce.protos.MRRuntimeProtos; import org.apache.tez.runtime.library.api.TezRuntimeConfiguration; import org.apache.tez.runtime.library.common.comparator.TezBytesComparator; import org.apache.tez.runtime.library.common.serializer.TezBytesWritableSerialization; @@ -271,8 +265,7 @@ public GroupInputEdge createEdge(VertexGroup group, JobConf vConf, Vertex w, case CUSTOM_EDGE: { mergeInputClass = ConcatenatedMergedKeyValueInput.class; int numBuckets = edgeProp.getNumBuckets(); - CustomVertexConfiguration vertexConf = - new CustomVertexConfiguration(numBuckets, vertexType, ""); + CustomVertexConfiguration vertexConf = new CustomVertexConfiguration(numBuckets, vertexType); DataOutputBuffer dob = new DataOutputBuffer(); vertexConf.write(dob); VertexManagerPluginDescriptor desc = @@ -317,8 +310,7 @@ public Edge createEdge(JobConf vConf, Vertex v, Vertex w, TezEdgeProperty edgePr switch(edgeProp.getEdgeType()) { case 
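CustomVertexConfiguration is a Writable, so write() and readFields() must stay symmetric; numInputs is therefore added to both methods, between numBuckets and inputName. A round-trip sketch (the concrete values and the AUTO_INITIALIZED_EDGES constant are chosen arbitrarily for the example):

import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import org.apache.hadoop.hive.ql.exec.tez.CustomVertexConfiguration;
import org.apache.hadoop.hive.ql.plan.TezWork.VertexType;
import org.apache.hadoop.io.DataOutputBuffer;

public class VertexConfRoundTrip {
  public static void main(String[] args) throws IOException {
    // SMB-style constructor: 8 big-table buckets, main input "bigTable", 3 inputs total.
    CustomVertexConfiguration written =
        new CustomVertexConfiguration(8, VertexType.AUTO_INITIALIZED_EDGES, "bigTable", 3);
    DataOutputBuffer dob = new DataOutputBuffer();
    written.write(dob);

    CustomVertexConfiguration read = new CustomVertexConfiguration();
    read.readFields(new DataInputStream(
        new ByteArrayInputStream(dob.getData(), 0, dob.getLength())));
    // Prints "8 3 bigTable" -- the same field order on both sides.
    System.out.println(read.getNumBuckets() + " " + read.getNumInputs() + " " + read.getInputName());
  }
}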
CUSTOM_EDGE: { int numBuckets = edgeProp.getNumBuckets(); - CustomVertexConfiguration vertexConf = - new CustomVertexConfiguration(numBuckets, vertexType, ""); + CustomVertexConfiguration vertexConf = new CustomVertexConfiguration(numBuckets, vertexType); DataOutputBuffer dob = new DataOutputBuffer(); vertexConf.write(dob); VertexManagerPluginDescriptor desc = VertexManagerPluginDescriptor.create( @@ -343,7 +335,6 @@ public Edge createEdge(JobConf vConf, Vertex v, Vertex w, TezEdgeProperty edgePr /* * Helper function to create an edge property from an edge type. */ - @SuppressWarnings("rawtypes") private EdgeProperty createEdgeProperty(TezEdgeProperty edgeProp, Configuration conf) throws IOException { MRHelpers.translateMRConfToTez(conf); @@ -435,7 +426,7 @@ public static Resource getContainerResource(Configuration conf) { HiveConf.getIntVar(conf, HiveConf.ConfVars.HIVETEZCONTAINERSIZE) : conf.getInt(MRJobConfig.MAP_MEMORY_MB, MRJobConfig.DEFAULT_MAP_MEMORY_MB); int cpus = HiveConf.getIntVar(conf, HiveConf.ConfVars.HIVETEZCPUVCORES) > 0 ? - HiveConf.getIntVar(conf, HiveConf.ConfVars.HIVETEZCPUVCORES) : + HiveConf.getIntVar(conf, HiveConf.ConfVars.HIVETEZCPUVCORES) : conf.getInt(MRJobConfig.MAP_CPU_VCORES, MRJobConfig.DEFAULT_MAP_CPU_VCORES); return Resource.newInstance(memory, cpus); } @@ -489,13 +480,9 @@ private Vertex createVertex(JobConf conf, MergeJoinWork mergeJoinWork, LocalReso if (mergeJoinWork.getMainWork() instanceof MapWork) { List mapWorkList = mergeJoinWork.getBaseWorkList(); MapWork mapWork = (MapWork) (mergeJoinWork.getMainWork()); - CommonMergeJoinOperator mergeJoinOp = mergeJoinWork.getMergeJoinOperator(); Vertex mergeVx = createVertex(conf, mapWork, appJarLr, additionalLr, fs, mrScratchDir, ctx, vertexType); - // grouping happens in execution phase. Setting the class to TezGroupedSplitsInputFormat - // here would cause pre-mature grouping which would be incorrect. - Class inputFormatClass = HiveInputFormat.class; conf.setClass("mapred.input.format.class", HiveInputFormat.class, InputFormat.class); // mapreduce.tez.input.initializer.serialize.event.payload should be set // to false when using this plug-in to avoid getting a serialized event at run-time. @@ -512,9 +499,11 @@ private Vertex createVertex(JobConf conf, MergeJoinWork mergeJoinWork, LocalReso VertexManagerPluginDescriptor desc = VertexManagerPluginDescriptor.create(CustomPartitionVertex.class.getName()); + // the +1 to the size is because of the main work. 
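Both CUSTOM_EDGE branches above serialize a CustomVertexConfiguration into the vertex manager's payload. A sketch of that wiring; the UserPayload.create()/ByteBuffer wrapping and the builder-style return of setUserPayload() are assumptions about the surrounding DagUtils code, which this hunk does not show:

import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.hadoop.hive.ql.exec.tez.CustomPartitionVertex;
import org.apache.hadoop.hive.ql.exec.tez.CustomVertexConfiguration;
import org.apache.hadoop.hive.ql.plan.TezWork.VertexType;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.api.VertexManagerPluginDescriptor;

public class VertexManagerPayloadSketch {
  static VertexManagerPluginDescriptor descriptorFor(int numBuckets, VertexType vertexType)
      throws IOException {
    // Bucket-map-join constructor: inputName is empty and numInputs defaults to 1.
    CustomVertexConfiguration vertexConf = new CustomVertexConfiguration(numBuckets, vertexType);
    DataOutputBuffer dob = new DataOutputBuffer();
    vertexConf.write(dob);
    // Assumed wrapping: the serialized bytes become the plugin's user payload.
    return VertexManagerPluginDescriptor.create(CustomPartitionVertex.class.getName())
        .setUserPayload(UserPayload.create(ByteBuffer.wrap(dob.getData(), 0, dob.getLength())));
  }
}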
CustomVertexConfiguration vertexConf = new CustomVertexConfiguration(mergeJoinWork.getMergeJoinOperator().getConf() - .getNumBuckets(), vertexType, mergeJoinWork.getBigTableAlias()); + .getNumBuckets(), vertexType, mergeJoinWork.getBigTableAlias(), + mapWorkList.size() + 1); DataOutputBuffer dob = new DataOutputBuffer(); vertexConf.write(dob); byte[] userPayload = dob.getData(); @@ -554,6 +543,7 @@ private Vertex createVertex(JobConf conf, MapWork mapWork, DataSourceDescriptor dataSource; int numTasks = -1; + @SuppressWarnings("rawtypes") Class inputFormatClass = conf.getClass("mapred.input.format.class", InputFormat.class); @@ -611,7 +601,13 @@ private Vertex createVertex(JobConf conf, MapWork mapWork, .setCustomInitializerDescriptor(descriptor).build(); } else { // Not HiveInputFormat, or a custom VertexManager will take care of grouping splits - dataSource = MRInputLegacy.createConfigBuilder(conf, inputFormatClass).groupSplits(false).build(); + if (vertexHasCustomInput) { + dataSource = + MultiMRInput.createConfigBuilder(conf, inputFormatClass).groupSplits(false).build(); + } else { + dataSource = + MRInputLegacy.createConfigBuilder(conf, inputFormatClass).groupSplits(false).build(); + } } } else { // Setup client side split generation. @@ -763,6 +759,7 @@ public PreWarmVertex createPreWarmVertex(TezConfiguration conf, * @throws LoginException if we are unable to figure user information * @throws IOException when any dfs operation fails. */ + @SuppressWarnings("deprecation") public Path getDefaultDestDir(Configuration conf) throws LoginException, IOException { UserGroupInformation ugi = ShimLoader.getHadoopShims().getUGIForConf(conf); String userName = ShimLoader.getHadoopShims().getShortUserName(ugi); @@ -875,6 +872,7 @@ public FileStatus getHiveJarDirectory(Configuration conf) throws IOException, Lo return fstatus; } + @SuppressWarnings("deprecation") public static FileStatus validateTargetDir(Path path, Configuration conf) throws IOException { FileSystem fs = path.getFileSystem(conf); FileStatus fstatus = null; @@ -1051,6 +1049,7 @@ private JobConf initializeVertexConf(JobConf conf, Context context, MergeJoinWor * @param ctx This query's context * @return Vertex */ + @SuppressWarnings("deprecation") public Vertex createVertex(JobConf conf, BaseWork work, Path scratchDir, LocalResource appJarLr, List additionalLr, FileSystem fileSystem, Context ctx, boolean hasChildren, diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java index c45479f..afe83d9 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java @@ -18,6 +18,8 @@ package org.apache.hadoop.hive.ql.exec.tez; +import java.io.IOException; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -64,7 +66,6 @@ * making sure that splits from different partitions are only grouped if they * are of the same schema, format and serde */ -@SuppressWarnings("deprecation") public class HiveSplitGenerator extends InputInitializer { private static final Log LOG = LogFactory.getLog(HiveSplitGenerator.class); @@ -72,11 +73,17 @@ private static final SplitGrouper grouper = new SplitGrouper(); private final DynamicPartitionPruner pruner = new DynamicPartitionPruner(); private InputInitializerContext context; + private static Map, Map> cache = + new HashMap, Map>(); public 
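The data-source change above matters for SMB: when the vertex has a custom input, grouping is done by CustomPartitionVertex, and MultiMRInput lets a single task be handed several physical splits, matching the per-task InputSpecUpdate the vertex manager builds. The choice in isolation, using only the builder calls shown in the hunk (the method and class names are illustrative):

import org.apache.hadoop.mapred.JobConf;
import org.apache.tez.dag.api.DataSourceDescriptor;
import org.apache.tez.mapreduce.input.MRInputLegacy;
import org.apache.tez.mapreduce.input.MultiMRInput;

public class DataSourceChoiceSketch {
  // Same decision as above: with a custom vertex manager, split grouping happens
  // at run time in CustomPartitionVertex, so groupSplits(false) in both cases.
  static DataSourceDescriptor pick(JobConf conf, Class<?> inputFormatClass,
      boolean vertexHasCustomInput) {
    return vertexHasCustomInput
        ? MultiMRInput.createConfigBuilder(conf, inputFormatClass).groupSplits(false).build()
        : MRInputLegacy.createConfigBuilder(conf, inputFormatClass).groupSplits(false).build();
  }
}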
HiveSplitGenerator(InputInitializerContext initializerContext) { super(initializerContext); } + public HiveSplitGenerator() { + this(null); + } + @Override public List initialize() throws Exception { InputInitializerContext rootInputContext = getContext(); @@ -150,58 +157,28 @@ public HiveSplitGenerator(InputInitializerContext initializerContext) { } - public static Multimap generateGroupedSplits(JobConf jobConf, + public Multimap generateGroupedSplits(JobConf jobConf, Configuration conf, InputSplit[] splits, float waves, int availableSlots) throws Exception { - return generateGroupedSplits(jobConf, conf, splits, waves, availableSlots, null); + return generateGroupedSplits(jobConf, conf, splits, waves, availableSlots, null, true); } - public static Multimap generateGroupedSplits(JobConf jobConf, - Configuration conf, InputSplit[] splits, float waves, int availableSlots, - String inputName) throws Exception { - - MapWork work = null; - if (inputName != null) { - work = (MapWork) Utilities.getMergeWork(jobConf, inputName); - // work can still be null if there is no merge work for this input - } - if (work == null) { - work = Utilities.getMapWork(jobConf); - } + public Multimap generateGroupedSplits(JobConf jobConf, + Configuration conf, InputSplit[] splits, float waves, int availableSlots, String inputName, + boolean groupAcrossFiles) throws Exception { + MapWork work = populateMapWork(jobConf, inputName); Multimap bucketSplitMultiMap = ArrayListMultimap. create(); - Class previousInputFormatClass = null; - String previousDeserializerClass = null; - Map, Map> cache = - new HashMap, Map>(); - int i = 0; - + InputSplit prevSplit = null; for (InputSplit s : splits) { // this is the bit where we make sure we don't group across partition // schema boundaries - - Path path = ((FileSplit) s).getPath(); - - PartitionDesc pd = - HiveFileFormatUtils.getPartitionDescFromPathRecursively(work.getPathToPartitionInfo(), - path, cache); - - String currentDeserializerClass = pd.getDeserializerClassName(); - Class currentInputFormatClass = pd.getInputFileFormatClass(); - - if ((currentInputFormatClass != previousInputFormatClass) - || (!currentDeserializerClass.equals(previousDeserializerClass))) { + if (schemaEvolved(s, prevSplit, groupAcrossFiles, work)) { ++i; - } - - previousInputFormatClass = currentInputFormatClass; - previousDeserializerClass = currentDeserializerClass; - - if (LOG.isDebugEnabled()) { - LOG.debug("Adding split " + path + " to src group " + i); + prevSplit = s; } bucketSplitMultiMap.put(i, s); } @@ -214,6 +191,54 @@ public HiveSplitGenerator(InputInitializerContext initializerContext) { return groupedSplits; } + private MapWork populateMapWork(JobConf jobConf, String inputName) { + MapWork work = null; + if (inputName != null) { + work = (MapWork) Utilities.getMergeWork(jobConf, inputName); + // work can still be null if there is no merge work for this input + } + if (work == null) { + work = Utilities.getMapWork(jobConf); + } + + return work; + } + + public boolean schemaEvolved(InputSplit s, InputSplit prevSplit, boolean groupAcrossFiles, + MapWork work) throws IOException { + boolean retval = false; + Path path = ((FileSplit) s).getPath(); + PartitionDesc pd = + HiveFileFormatUtils.getPartitionDescFromPathRecursively(work.getPathToPartitionInfo(), + path, cache); + String currentDeserializerClass = pd.getDeserializerClassName(); + Class currentInputFormatClass = pd.getInputFileFormatClass(); + + Class previousInputFormatClass = null; + String previousDeserializerClass = null; + if 
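generateGroupedSplits() now walks the splits once, bumping a group index whenever schemaEvolved() signals a partition-schema (or, for side work, file) boundary, so nothing is grouped across those boundaries. The pattern in isolation, with a generic predicate standing in for schemaEvolved():

import java.util.List;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;

public class BoundaryGroupingSketch<T> {
  interface Boundary<T> {
    // Like schemaEvolved(): must return true for the very first item (previous == null).
    boolean changed(T current, T previousGroupStart);
  }

  Multimap<Integer, T> group(List<T> items, Boundary<T> boundary) {
    Multimap<Integer, T> groups = ArrayListMultimap.create();
    int i = 0;
    T groupStart = null;
    for (T item : items) {
      if (boundary.changed(item, groupStart)) {
        ++i;                 // new group: schema, format or file changed
        groupStart = item;   // later items are compared against this group's first item
      }
      groups.put(i, item);
    }
    return groups;
  }
}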
(prevSplit != null) { + Path prevPath = ((FileSplit) prevSplit).getPath(); + if (!groupAcrossFiles) { + return !path.equals(prevPath); + } + PartitionDesc prevPD = + HiveFileFormatUtils.getPartitionDescFromPathRecursively(work.getPathToPartitionInfo(), + prevPath, cache); + previousDeserializerClass = prevPD.getDeserializerClassName(); + previousInputFormatClass = prevPD.getInputFileFormatClass(); + } + + if ((currentInputFormatClass != previousInputFormatClass) + || (!currentDeserializerClass.equals(previousDeserializerClass))) { + retval = true; + } + + if (LOG.isDebugEnabled()) { + LOG.debug("Adding split " + path + " to src new group? " + retval); + } + return retval; + } + private List createEventList(boolean sendSerializedEvents, InputSplitInfoMem inputSplitInfo) { List events = Lists.newArrayListWithCapacity(inputSplitInfo.getNumTasks() + 1); diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/MapRecordProcessor.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/MapRecordProcessor.java index 579d65d..bc7603e 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/MapRecordProcessor.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/MapRecordProcessor.java @@ -17,10 +17,10 @@ */ package org.apache.hadoop.hive.ql.exec.tez; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -44,15 +44,15 @@ import org.apache.hadoop.hive.ql.exec.tez.TezProcessor.TezKVOutputCollector; import org.apache.hadoop.hive.ql.exec.tez.tools.KeyValueInputMerger; import org.apache.hadoop.hive.ql.exec.vector.VectorMapOperator; -import org.apache.hadoop.hive.ql.io.IOContext; import org.apache.hadoop.hive.ql.log.PerfLogger; import org.apache.hadoop.hive.ql.plan.MapWork; import org.apache.hadoop.hive.ql.plan.OperatorDesc; +import org.apache.hadoop.hive.serde2.Deserializer; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.mapred.JobConf; import org.apache.tez.mapreduce.input.MRInputLegacy; import org.apache.tez.mapreduce.input.MultiMRInput; import org.apache.tez.mapreduce.processor.MRTaskReporter; -import org.apache.tez.runtime.api.Input; import org.apache.tez.runtime.api.LogicalInput; import org.apache.tez.runtime.api.LogicalOutput; import org.apache.tez.runtime.api.ProcessorContext; @@ -73,6 +73,7 @@ private int position = 0; private boolean foundCachedMergeWork = false; MRInputLegacy legacyMRInput = null; + MultiMRInput mainWorkMultiMRInput = null; private ExecMapperContext execContext = null; private boolean abort = false; protected static final String MAP_PLAN_KEY = "__MAP_PLAN__"; @@ -129,12 +130,14 @@ void init(JobConf jconf, ProcessorContext processorContext, MRTaskReporter mrRep perfLogger.PerfLogBegin(CLASS_NAME, PerfLogger.TEZ_INIT_OPERATORS); super.init(jconf, processorContext, mrReporter, inputs, outputs); - //Update JobConf using MRInput, info like filename comes via this + // Update JobConf using MRInput, info like filename comes via this legacyMRInput = getMRInput(inputs); - Configuration updatedConf = legacyMRInput.getConfigUpdates(); - if (updatedConf != null) { - for (Entry entry : updatedConf) { - jconf.set(entry.getKey(), entry.getValue()); + if (legacyMRInput != null) { + Configuration updatedConf = legacyMRInput.getConfigUpdates(); + if (updatedConf != null) { + for (Entry entry : updatedConf) { + jconf.set(entry.getKey(), entry.getValue()); + } } } @@ -158,8 +161,6 @@ void 
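The behavior of schemaEvolved() above, summarized as a truth table of the code in the hunk (nothing here is new logic):

//   prevSplit == null                                  -> true  (first split opens a group)
//   !groupAcrossFiles && paths differ                  -> true  (SMB side work: one group per file)
//   input format class or deserializer class differs   -> true  (partition schema boundary)
//   otherwise                                          -> false (safe to group together)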
init(JobConf jconf, ProcessorContext processorContext, MRTaskReporter mrRep if (mergeWorkList != null) { MapOperator mergeMapOp = null; for (MapWork mergeMapWork : mergeWorkList) { - processorContext.waitForAnyInputReady(Collections.singletonList((Input) (inputs - .get(mergeMapWork.getName())))); if (mergeMapWork.getVectorMode()) { mergeMapOp = new VectorMapOperator(); } else { @@ -235,11 +236,17 @@ void init(JobConf jconf, ProcessorContext processorContext, MRTaskReporter mrRep } private void initializeMapRecordSources() throws Exception { + int size = mergeMapOpList.size() + 1; // the +1 is for the main map operator itself sources = new MapRecordSource[size]; - KeyValueReader reader = legacyMRInput.getReader(); position = mapOp.getConf().getTag(); sources[position] = new MapRecordSource(); + KeyValueReader reader = null; + if (mainWorkMultiMRInput != null) { + reader = getKeyValueReader(mainWorkMultiMRInput.getKeyValueReaders(), mapOp); + } else { + reader = legacyMRInput.getReader(); + } sources[position].init(jconf, mapOp, reader); for (MapOperator mapOp : mergeMapOpList) { int tag = mapOp.getConf().getTag(); @@ -248,13 +255,28 @@ private void initializeMapRecordSources() throws Exception { MultiMRInput multiMRInput = multiMRInputMap.get(inputName); Collection kvReaders = multiMRInput.getKeyValueReaders(); l4j.debug("There are " + kvReaders.size() + " key-value readers for input " + inputName); - List kvReaderList = new ArrayList(kvReaders); - reader = new KeyValueInputMerger(kvReaderList); + reader = getKeyValueReader(kvReaders, mapOp); sources[tag].init(jconf, mapOp, reader); } ((TezContext) MapredContext.get()).setRecordSources(sources); } + @SuppressWarnings("deprecation") + private KeyValueReader getKeyValueReader(Collection keyValueReaders, + MapOperator mapOp) + throws Exception { + List kvReaderList = new ArrayList(keyValueReaders); + // this sets up the map operator contexts correctly + mapOp.initializeContexts(); + Deserializer deserializer = mapOp.getCurrentDeserializer(); + KeyValueReader reader = + new KeyValueInputMerger(kvReaderList, deserializer, + new ObjectInspector[] { deserializer.getObjectInspector() }, mapOp + .getConf() + .getSortCols()); + return reader; + } + private DummyStoreOperator getJoinParentOp(Operator mergeMapOp) { for (Operator childOp : mergeMapOp.getChildOperators()) { if ((childOp.getChildOperators() == null) || (childOp.getChildOperators().isEmpty())) { @@ -335,7 +357,17 @@ private MRInputLegacy getMRInput(Map inputs) throws Except multiMRInputMap.put(inp.getKey(), (MultiMRInput) inp.getValue()); } } - theMRInput.init(); + if (theMRInput != null) { + theMRInput.init(); + } else { + String alias = mapWork.getAliasToWork().keySet().iterator().next(); + if (inputs.get(alias) instanceof MultiMRInput) { + mainWorkMultiMRInput = (MultiMRInput) inputs.get(alias); + } else { + throw new IOException("Unexpected input type found: " + + inputs.get(alias).getClass().getCanonicalName()); + } + } return theMRInput; } } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/MapRecordSource.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/MapRecordSource.java index 6d075e8..f7d2661 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/MapRecordSource.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/MapRecordSource.java @@ -19,7 +19,6 @@ package org.apache.hadoop.hive.ql.exec.tez; import java.io.IOException; - import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.ql.exec.MapOperator; @@ 
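After the changes above, each merge-join input (tag) gets its own MapRecordSource, and for SMB the reader behind it is a KeyValueInputMerger over that table's bucket-file readers rather than the single legacy reader. An outline of the resulting wiring for one big table (tag 0) and one small table (tag 1); the tags are illustrative:

//   sources[0].reader = KeyValueInputMerger over mainWorkMultiMRInput's readers
//                       (or legacyMRInput.getReader() when the main work is a plain MRInput)
//   sources[1].reader = KeyValueInputMerger over the small table's MultiMRInput readers
//
//   Each merger is built via getKeyValueReader(readers, mapOp) with the table's
//   deserializer and the sort columns stored on its MapWork, so every source
//   emits rows in join-key order for CommonMergeJoinOperator.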
-28,7 +27,6 @@ import org.apache.hadoop.io.Writable; import org.apache.hadoop.mapred.JobConf; import org.apache.hadoop.util.StringUtils; -import org.apache.tez.mapreduce.input.MRInput; import org.apache.tez.runtime.library.api.KeyValueReader; /** diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/MergeFileRecordProcessor.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/MergeFileRecordProcessor.java index 2b38d79..2998ae7 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/MergeFileRecordProcessor.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/MergeFileRecordProcessor.java @@ -204,6 +204,7 @@ private boolean processRow(Object key, Object value) { private MRInputLegacy getMRInput(Map inputs) throws Exception { // there should be only one MRInput MRInputLegacy theMRInput = null; + LOG.info("VDK: the inputs are: " + inputs); for (Entry inp : inputs.entrySet()) { if (inp.getValue() instanceof MRInputLegacy) { if (theMRInput != null) { diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SplitGrouper.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SplitGrouper.java index 38d74d5..f9c80c2 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SplitGrouper.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SplitGrouper.java @@ -102,9 +102,19 @@ // compute the total size per bucket long totalSize = 0; + boolean earlyExit = false; for (int bucketId : bucketSplitMap.keySet()) { long size = 0; for (InputSplit s : bucketSplitMap.get(bucketId)) { + // the incoming split may not be a file split when we are re-grouping TezGroupedSplits in + // the case of SMB join. So in this case, we can do an early exit by not doing the + // calculation for bucketSizeMap. Each bucket will assume it can fill availableSlots * waves + // (preset to 0.5) for SMB join. 
+ if (!(s instanceof FileSplit)) { + bucketTaskMap.put(bucketId, (int) (availableSlots * waves)); + earlyExit = true; + continue; + } FileSplit fsplit = (FileSplit) s; size += fsplit.getLength(); totalSize += fsplit.getLength(); @@ -112,6 +122,10 @@ bucketSizeMap.put(bucketId, size); } + if (earlyExit) { + return bucketTaskMap; + } + // compute the number of tasks for (int bucketId : bucketSizeMap.keySet()) { int numEstimatedTasks = 0; diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezProcessor.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezProcessor.java index 42c7d37..1e528a9 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezProcessor.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezProcessor.java @@ -19,12 +19,9 @@ import java.io.IOException; import java.text.NumberFormat; -import java.util.Arrays; -import java.util.HashMap; +import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Map.Entry; - import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; @@ -33,11 +30,10 @@ import org.apache.hadoop.mapred.OutputCollector; import org.apache.hadoop.util.StringUtils; import org.apache.tez.common.TezUtils; -import org.apache.tez.mapreduce.input.MRInputLegacy; -import org.apache.tez.mapreduce.input.MultiMRInput; import org.apache.tez.mapreduce.processor.MRTaskReporter; import org.apache.tez.runtime.api.AbstractLogicalIOProcessor; import org.apache.tez.runtime.api.Event; +import org.apache.tez.runtime.api.Input; import org.apache.tez.runtime.api.LogicalInput; import org.apache.tez.runtime.api.LogicalOutput; import org.apache.tez.runtime.api.ProcessorContext; @@ -152,10 +148,13 @@ protected void initializeAndRunProcessor(Map inputs, // Start the actual Inputs. After MRInput initialization. for (Map.Entry inputEntry : inputs.entrySet()) { if (!cacheAccess.isInputCached(inputEntry.getKey())) { - LOG.info("Input: " + inputEntry.getKey() + " is not cached"); + LOG.info("Starting input " + inputEntry.getKey()); inputEntry.getValue().start(); + processorContext.waitForAnyInputReady(Collections.singletonList((Input) (inputEntry + .getValue()))); } else { - LOG.info("Input: " + inputEntry.getKey() + " is already cached. Skipping start"); + LOG.info("Input: " + inputEntry.getKey() + + " is already cached. Skipping start and wait for ready"); } } @@ -194,6 +193,7 @@ protected void initializeAndRunProcessor(Map inputs, * Must be initialized before it is used. 
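The early exit above applies when SplitGrouper is re-grouping TezGroupedSplits for SMB: there is no byte size to estimate from, so every bucket is simply assumed to occupy availableSlots * waves tasks, with waves coming from hive.tez.smb.number.waves. A worked example of that arithmetic:

public class SmbTaskEstimateExample {
  public static void main(String[] args) {
    int availableSlots = 20;   // cluster capacity seen by the vertex manager
    float waves = 0.5f;        // hive.tez.smb.number.waves default from this patch
    // Each bucket is assumed to fill half of the available slots: 10 tasks here.
    System.out.println((int) (availableSlots * waves));
  }
}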
* */ + @SuppressWarnings("rawtypes") static class TezKVOutputCollector implements OutputCollector { private KeyValueWriter writer; private final LogicalOutput output; diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/tools/KeyValueInputMerger.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/tools/KeyValueInputMerger.java index 516722d..c8e9606 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/tools/KeyValueInputMerger.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/tools/KeyValueInputMerger.java @@ -18,13 +18,23 @@ package org.apache.hadoop.hive.ql.exec.tez.tools; import java.io.IOException; +import java.util.ArrayList; import java.util.Comparator; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.PriorityQueue; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.hadoop.io.BinaryComparable; +import org.apache.hadoop.hive.serde2.Deserializer; +import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption; +import org.apache.hadoop.io.Writable; import org.apache.tez.runtime.library.api.KeyValueReader; /** @@ -34,16 +44,36 @@ * Uses a priority queue to pick the KeyValuesReader of the input that is next in * sort order. */ +@SuppressWarnings("deprecation") public class KeyValueInputMerger extends KeyValueReader { public static final Log l4j = LogFactory.getLog(KeyValueInputMerger.class); private PriorityQueue pQueue = null; private KeyValueReader nextKVReader = null; + private ObjectInspector[] inputObjInspectors = null; + private Deserializer deserializer = null; + private List structFields = null; + private List fieldOIs = null; + private final Map> kvReaderStandardObjMap = + new HashMap>(); - public KeyValueInputMerger(List multiMRInputs) throws Exception { + public KeyValueInputMerger(List multiMRInputs, Deserializer deserializer, + ObjectInspector[] inputObjInspectors, List sortCols) throws Exception { //get KeyValuesReaders from the LogicalInput and add them to priority queue int initialCapacity = multiMRInputs.size(); pQueue = new PriorityQueue(initialCapacity, new KVReaderComparator()); + this.inputObjInspectors = inputObjInspectors; + this.deserializer = deserializer; + fieldOIs = new ArrayList(); + structFields = new ArrayList(); + StructObjectInspector structOI = (StructObjectInspector) inputObjInspectors[0]; + for (String field : sortCols) { + StructField sf = structOI.getStructFieldRef(field); + structFields.add(sf); + ObjectInspector stdOI = + ObjectInspectorUtils.getStandardObjectInspector(sf.getFieldObjectInspector()); + fieldOIs.add(stdOI); + } l4j.info("Initialized the priority queue with multi mr inputs: " + multiMRInputs.size()); for (KeyValueReader input : multiMRInputs) { addToQueue(input); @@ -58,6 +88,7 @@ public KeyValueInputMerger(List multiMRInputs) throws Exception */ private void addToQueue(KeyValueReader kvReader) throws IOException { if (kvReader.next()) { + kvReaderStandardObjMap.remove(kvReader); pQueue.add(kvReader); } } @@ -93,12 +124,53 @@ public Object getCurrentValue() throws IOException { */ class KVReaderComparator implements Comparator { + @SuppressWarnings({ 
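KeyValueInputMerger above is a k-way merge: the priority queue holds one reader per bucket file, ordered by the deserialized sort columns of each reader's current row, and next() always advances the reader with the smallest head. The same structure on plain iterators (a sketch, not the Tez reader API):

import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;

public class KWayMergeSketch<T> {
  private final PriorityQueue<Source<T>> queue;

  private static class Source<T> {
    final Iterator<T> it;
    T head;
    Source(Iterator<T> it) { this.it = it; this.head = it.next(); }
    boolean advance() { if (it.hasNext()) { head = it.next(); return true; } return false; }
  }

  public KWayMergeSketch(List<Iterator<T>> sortedInputs, final Comparator<T> cmp) {
    queue = new PriorityQueue<Source<T>>(Math.max(1, sortedInputs.size()),
        new Comparator<Source<T>>() {
          @Override
          public int compare(Source<T> a, Source<T> b) {
            return cmp.compare(a.head, b.head);   // mirrors KVReaderComparator
          }
        });
    for (Iterator<T> it : sortedInputs) {
      if (it.hasNext()) {
        queue.add(new Source<T>(it));
      }
    }
  }

  /** Smallest remaining element across all inputs, or null when everything is drained. */
  public T next() {
    Source<T> src = queue.poll();
    if (src == null) {
      return null;
    }
    T result = src.head;
    if (src.advance()) {
      queue.add(src);   // re-queue with the new head, like addToQueue() above
    }
    return result;
  }
}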
"unchecked" }) @Override public int compare(KeyValueReader kvReadr1, KeyValueReader kvReadr2) { try { - BinaryComparable key1 = (BinaryComparable) kvReadr1.getCurrentValue(); - BinaryComparable key2 = (BinaryComparable) kvReadr2.getCurrentValue(); - return key1.compareTo(key2); + ObjectInspector oi = inputObjInspectors[0]; + List row1, row2; + try { + if (kvReaderStandardObjMap.containsKey(kvReadr1)) { + row1 = kvReaderStandardObjMap.get(kvReadr1); + } else { + // we need to copy to standard object otherwise deserializer overwrites the values + row1 = + (List) ObjectInspectorUtils.copyToStandardObject( + deserializer.deserialize((Writable) kvReadr1.getCurrentValue()), oi, + ObjectInspectorCopyOption.WRITABLE); + kvReaderStandardObjMap.put(kvReadr1, row1); + } + + if (kvReaderStandardObjMap.containsKey(kvReadr2)) { + row2 = kvReaderStandardObjMap.get(kvReadr2); + } else { + row2 = + (List) ObjectInspectorUtils.copyToStandardObject( + deserializer.deserialize((Writable) kvReadr2.getCurrentValue()), oi, + ObjectInspectorCopyOption.WRITABLE); + kvReaderStandardObjMap.put(kvReadr2, row2); + } + } catch (SerDeException e) { + throw new IOException(e); + } + + StructObjectInspector structOI = (StructObjectInspector) oi; + int compare = 0; + int index = 0; + for (StructField sf : structFields) { + int pos = structOI.getAllStructFieldRefs().indexOf(sf); + Object key1 = row1.get(pos); + Object key2 = row2.get(pos); + ObjectInspector stdOI = fieldOIs.get(index); + compare = ObjectInspectorUtils.compare(key1, stdOI, key2, stdOI); + index++; + if (compare != 0) { + return compare; + } + } + + return compare; } catch (IOException e) { l4j.error("Caught exception while reading shuffle input", e); //die! diff --git ql/src/java/org/apache/hadoop/hive/ql/io/IOContext.java ql/src/java/org/apache/hadoop/hive/ql/io/IOContext.java index cad567a..d42f568 100644 --- ql/src/java/org/apache/hadoop/hive/ql/io/IOContext.java +++ ql/src/java/org/apache/hadoop/hive/ql/io/IOContext.java @@ -21,11 +21,7 @@ import java.util.HashMap; import java.util.Map; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; import org.apache.hadoop.fs.Path; -import org.apache.hadoop.hive.ql.optimizer.ConvertJoinMapJoin; -import org.apache.hadoop.hive.ql.session.SessionState; /** @@ -44,8 +40,6 @@ }; private static Map inputNameIOContextMap = new HashMap(); - private static IOContext ioContext = new IOContext(); - public static Map getMap() { return inputNameIOContextMap; } @@ -61,7 +55,7 @@ public static IOContext get(String inputName) { public static void clear() { IOContext.threadLocal.remove(); - ioContext = new IOContext(); + inputNameIOContextMap.clear(); } long currentBlockStart; diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java index 7a3280c..41a8974 100644 --- ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java +++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/ConvertJoinMapJoin.java @@ -85,13 +85,18 @@ JoinOperator joinOp = (JoinOperator) nd; - if (!context.conf.getBoolVar(HiveConf.ConfVars.HIVECONVERTJOIN) - && !(context.conf.getBoolVar(HiveConf.ConfVars.HIVE_AUTO_SORTMERGE_JOIN))) { + TezBucketJoinProcCtx tezBucketJoinProcCtx = new TezBucketJoinProcCtx(context.conf); + if (!context.conf.getBoolVar(HiveConf.ConfVars.HIVECONVERTJOIN)) { // we are just converting to a common merge join operator. The shuffle // join in map-reduce case. 
- int pos = 0; // it doesn't matter which position we use in this case. - convertJoinSMBJoin(joinOp, context, pos, 0, false, false); - return null; + Object retval = checkAndConvertSMBJoin(context, joinOp, tezBucketJoinProcCtx); + if (retval == null) { + return retval; + } else { + int pos = 0; // it doesn't matter which position we use in this case. + convertJoinSMBJoin(joinOp, context, pos, 0, false, false); + return null; + } } // if we have traits, and table info is present in the traits, we know the @@ -99,7 +104,6 @@ // reducers from the parent operators. int numBuckets = -1; int estimatedBuckets = -1; - TezBucketJoinProcCtx tezBucketJoinProcCtx = new TezBucketJoinProcCtx(context.conf); if (context.conf.getBoolVar(HiveConf.ConfVars.HIVE_CONVERT_JOIN_BUCKET_MAPJOIN_TEZ)) { for (OperatorparentOp : joinOp.getParentOperators()) { if (parentOp.getOpTraits().getNumBuckets() > 0) { @@ -126,53 +130,15 @@ LOG.info("Estimated number of buckets " + numBuckets); int mapJoinConversionPos = getMapJoinConversionPos(joinOp, context, numBuckets); if (mapJoinConversionPos < 0) { - // we cannot convert to bucket map join, we cannot convert to - // map join either based on the size. Check if we can convert to SMB join. - if (context.conf.getBoolVar(HiveConf.ConfVars.HIVE_AUTO_SORTMERGE_JOIN) == false) { - convertJoinSMBJoin(joinOp, context, 0, 0, false, false); - return null; - } - Class bigTableMatcherClass = null; - try { - bigTableMatcherClass = - (Class) (Class.forName(HiveConf.getVar( - context.parseContext.getConf(), - HiveConf.ConfVars.HIVE_AUTO_SORTMERGE_JOIN_BIGTABLE_SELECTOR))); - } catch (ClassNotFoundException e) { - throw new SemanticException(e.getMessage()); - } - - BigTableSelectorForAutoSMJ bigTableMatcher = - ReflectionUtils.newInstance(bigTableMatcherClass, null); - JoinDesc joinDesc = joinOp.getConf(); - JoinCondDesc[] joinCondns = joinDesc.getConds(); - Set joinCandidates = MapJoinProcessor.getBigTableCandidates(joinCondns); - if (joinCandidates.isEmpty()) { - // This is a full outer join. This can never be a map-join - // of any type. So return false. - return false; - } - mapJoinConversionPos = - bigTableMatcher.getBigTablePosition(context.parseContext, joinOp, joinCandidates); - if (mapJoinConversionPos < 0) { - // contains aliases from sub-query - // we are just converting to a common merge join operator. The shuffle - // join in map-reduce case. - int pos = 0; // it doesn't matter which position we use in this case. - convertJoinSMBJoin(joinOp, context, pos, 0, false, false); - return null; - } - - if (checkConvertJoinSMBJoin(joinOp, context, mapJoinConversionPos, tezBucketJoinProcCtx)) { - convertJoinSMBJoin(joinOp, context, mapJoinConversionPos, - tezBucketJoinProcCtx.getNumBuckets(), tezBucketJoinProcCtx.isSubQuery(), true); + Object retval = checkAndConvertSMBJoin(context, joinOp, tezBucketJoinProcCtx); + if (retval == null) { + return retval; } else { - // we are just converting to a common merge join operator. The shuffle - // join in map-reduce case. - int pos = 0; // it doesn't matter which position we use in this case. - convertJoinSMBJoin(joinOp, context, pos, 0, false, false); + // only case is full outer join with SMB enabled which is not possible. Convert to regular + // join. 
+        convertJoinSMBJoin(joinOp, context, 0, 0, false, false);
+        return null;
       }
-      return null;
     }
 
     if (numBuckets > 1) {
@@ -206,6 +172,57 @@
     return null;
   }
 
+  private Object checkAndConvertSMBJoin(OptimizeTezProcContext context, JoinOperator joinOp,
+      TezBucketJoinProcCtx tezBucketJoinProcCtx) throws SemanticException {
+    // we cannot convert to bucket map join, we cannot convert to
+    // map join either based on the size. Check if we can convert to SMB join.
+    if (context.conf.getBoolVar(HiveConf.ConfVars.HIVE_AUTO_SORTMERGE_JOIN) == false) {
+      convertJoinSMBJoin(joinOp, context, 0, 0, false, false);
+      return null;
+    }
+    Class bigTableMatcherClass = null;
+    try {
+      bigTableMatcherClass =
+          (Class) (Class.forName(HiveConf.getVar(
+              context.parseContext.getConf(),
+              HiveConf.ConfVars.HIVE_AUTO_SORTMERGE_JOIN_BIGTABLE_SELECTOR)));
+    } catch (ClassNotFoundException e) {
+      throw new SemanticException(e.getMessage());
+    }
+
+    BigTableSelectorForAutoSMJ bigTableMatcher =
+        ReflectionUtils.newInstance(bigTableMatcherClass, null);
+    JoinDesc joinDesc = joinOp.getConf();
+    JoinCondDesc[] joinCondns = joinDesc.getConds();
+    Set joinCandidates = MapJoinProcessor.getBigTableCandidates(joinCondns);
+    if (joinCandidates.isEmpty()) {
+      // This is a full outer join. This can never be a map-join
+      // of any type. So return false.
+      return false;
+    }
+    int mapJoinConversionPos =
+        bigTableMatcher.getBigTablePosition(context.parseContext, joinOp, joinCandidates);
+    if (mapJoinConversionPos < 0) {
+      // contains aliases from sub-query
+      // we are just converting to a common merge join operator. The shuffle
+      // join in map-reduce case.
+      int pos = 0; // it doesn't matter which position we use in this case.
+      convertJoinSMBJoin(joinOp, context, pos, 0, false, false);
+      return null;
+    }
+
+    if (checkConvertJoinSMBJoin(joinOp, context, mapJoinConversionPos, tezBucketJoinProcCtx)) {
+      convertJoinSMBJoin(joinOp, context, mapJoinConversionPos,
+          tezBucketJoinProcCtx.getNumBuckets(), tezBucketJoinProcCtx.isSubQuery(), true);
+    } else {
+      // we are just converting to a common merge join operator. The shuffle
+      // join in map-reduce case.
+      int pos = 0; // it doesn't matter which position we use in this case.
+      convertJoinSMBJoin(joinOp, context, pos, 0, false, false);
+    }
+    return null;
+  }
+
   // replaces the join operator with a new CommonJoinOperator, removes the
   // parent reduce sinks
   private void convertJoinSMBJoin(JoinOperator joinOp, OptimizeTezProcContext context,
@@ -630,7 +647,7 @@ private boolean hasDynamicPartitionBroadcast(Operator parent) {
         hasDynamicPartitionPruning = true;
         break;
       }
-
+
       if (op instanceof ReduceSinkOperator || op instanceof FileSinkOperator) {
         // crossing reduce sink or file sink means the pruning isn't for this parent.
         break;
diff --git ql/src/java/org/apache/hadoop/hive/ql/optimizer/MergeJoinProc.java ql/src/java/org/apache/hadoop/hive/ql/optimizer/MergeJoinProc.java
index 8516643..2c83c4b 100644
--- ql/src/java/org/apache/hadoop/hive/ql/optimizer/MergeJoinProc.java
+++ ql/src/java/org/apache/hadoop/hive/ql/optimizer/MergeJoinProc.java
@@ -1,7 +1,5 @@
 package org.apache.hadoop.hive.ql.optimizer;
 
-import java.util.HashMap;
-import java.util.Map;
 import java.util.Stack;
 
 import org.apache.hadoop.hive.ql.exec.CommonMergeJoinOperator;
diff --git ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezWork.java ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezWork.java
index 516e576..d83518d 100644
--- ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezWork.java
+++ ql/src/java/org/apache/hadoop/hive/ql/parse/GenTezWork.java
@@ -123,6 +123,11 @@ public Object process(Node nd, Stack stack,
       context.rootToWorkMap.put(root, work);
     }
 
+    // this is where we set the sort columns that we will be using for KeyValueInputMerge
+    if (operator instanceof DummyStoreOperator) {
+      work.addSortCols(root.getOpTraits().getSortCols().get(0));
+    }
+
     if (!context.childToWorkMap.containsKey(operator)) {
       List workItems = new LinkedList();
       workItems.add(work);
@@ -148,6 +153,7 @@ public Object process(Node nd, Stack stack,
         context.opMergeJoinWorkMap.put(operator, mergeJoinWork);
       }
       // connect the work correctly.
+      work.addSortCols(root.getOpTraits().getSortCols().get(0));
       mergeJoinWork.addMergedWork(work, null);
       Operator parentOp =
           getParentFromStack(context.currentMergeJoinOperator, stack);
diff --git ql/src/java/org/apache/hadoop/hive/ql/plan/BaseWork.java ql/src/java/org/apache/hadoop/hive/ql/plan/BaseWork.java
index 05be1f1..9d8d52d 100644
--- ql/src/java/org/apache/hadoop/hive/ql/plan/BaseWork.java
+++ ql/src/java/org/apache/hadoop/hive/ql/plan/BaseWork.java
@@ -18,6 +18,7 @@
 package org.apache.hadoop.hive.ql.plan;
 
+import java.util.ArrayList;
 import java.util.LinkedList;
 import java.util.LinkedHashSet;
 import java.util.List;
@@ -42,6 +43,7 @@
   // schema info.
   List dummyOps;
   int tag;
+  private final List sortColNames = new ArrayList();
 
   public BaseWork() {}
 
@@ -148,4 +150,12 @@ public void setTag(int tag) {
   public int getTag() {
     return tag;
   }
+
+  public void addSortCols(List sortCols) {
+    this.sortColNames.addAll(sortCols);
+  }
+
+  public List getSortCols() {
+    return sortColNames;
+  }
 }
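
Note for reviewers, not part of the patch: both the new KeyValueReader comparator and the sort-column plumbing added to GenTezWork/BaseWork rest on the same idea of ordering rows column by column in sort-column order, stopping at the first difference. The standalone sketch below illustrates that strategy in plain Java; the class name, method name, and example rows are invented for illustration and do not appear anywhere in Hive.

import java.util.Arrays;
import java.util.List;

public class SortColsComparatorSketch {

  // Compare two rows on the given column positions, in order; the first
  // differing sort column decides the ordering, and a full tie returns 0,
  // mirroring the merge comparison strategy used in this patch.
  @SuppressWarnings({ "rawtypes", "unchecked" })
  static int compareOnSortCols(List<?> row1, List<?> row2, List<Integer> sortColPositions) {
    int compare = 0;
    for (int pos : sortColPositions) {
      Comparable key1 = (Comparable) row1.get(pos);
      Comparable key2 = (Comparable) row2.get(pos);
      compare = key1.compareTo(key2);
      if (compare != 0) {
        return compare;
      }
    }
    return compare; // rows tie on every sort column
  }

  public static void main(String[] args) {
    // Rows with schema (id, name, amount); pretend the work is sorted on (name, id),
    // i.e. column positions 1 and 0.
    List<Object> r1 = Arrays.<Object>asList(3, "apple", 10.0);
    List<Object> r2 = Arrays.<Object>asList(1, "banana", 7.5);
    // Negative result: "apple" sorts before "banana", so the id column is never consulted.
    System.out.println(compareOnSortCols(r1, r2, Arrays.asList(1, 0)));
  }
}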