diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java index f7612d6..446916c 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/CustomPartitionVertex.java @@ -274,9 +274,8 @@ public void onRootVertexInitialized(String inputName, InputDescriptor inputDescr for (Integer key : bucketToInitialSplitMap.keySet()) { InputSplit[] inputSplitArray = (bucketToInitialSplitMap.get(key).toArray(new InputSplit[0])); - HiveSplitGenerator hiveSplitGenerator = new HiveSplitGenerator(); Multimap groupedSplit = - hiveSplitGenerator.generateGroupedSplits(jobConf, conf, inputSplitArray, waves, + grouper.generateGroupedSplits(jobConf, conf, inputSplitArray, waves, availableSlots, inputName, mainWorkName.isEmpty()); if (mainWorkName.isEmpty() == false) { Multimap singleBucketToGroupedSplit = @@ -295,11 +294,10 @@ public void onRootVertexInitialized(String inputName, InputDescriptor inputDescr // grouped split. This would affect SMB joins where we want to find the smallest key in // all the bucket files. for (Integer key : bucketToInitialSplitMap.keySet()) { - HiveSplitGenerator hiveSplitGenerator = new HiveSplitGenerator(); InputSplit[] inputSplitArray = (bucketToInitialSplitMap.get(key).toArray(new InputSplit[0])); Multimap groupedSplit = - hiveSplitGenerator.generateGroupedSplits(jobConf, conf, inputSplitArray, waves, + grouper.generateGroupedSplits(jobConf, conf, inputSplitArray, waves, availableSlots, inputName, false); bucketToGroupedSplitMap.putAll(key, groupedSplit.values()); } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DynamicPartitionPruner.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DynamicPartitionPruner.java index 696874e..7abd94d 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DynamicPartitionPruner.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DynamicPartitionPruner.java @@ -31,12 +31,13 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.BlockingQueue; -import java.util.concurrent.ConcurrentSkipListSet; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicBoolean; -import javolution.testing.AssertionException; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import org.apache.commons.lang3.mutable.MutableInt; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator; @@ -74,30 +75,47 @@ private static final Log LOG = LogFactory.getLog(DynamicPartitionPruner.class); + private final InputInitializerContext context; + private final MapWork work; + private final JobConf jobConf; + + private final Map> sourceInfoMap = new HashMap>(); private final BytesWritable writable = new BytesWritable(); + /* Keeps track of all events that need to be processed - irrespective of the source */ private final BlockingQueue queue = new LinkedBlockingQueue(); + /* Keeps track of vertices from which events are expected */ private final Set sourcesWaitingForEvents = new HashSet(); + // Stores negative values to count columns. Eventually set to #tasks X #columns after the source vertex completes. 
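The counting scheme described in the comment above can be illustrated with a small, self-contained sketch (the class and method names are illustrative only, and plain Integer values stand in for MutableInt): each pruning column registered for a source decrements that source's counter during initialization, and when the source vertex reports success the counter is flipped to #columns X #tasks, which is the number of events the pruner then waits for.

// Simplified sketch of the expected-event accounting, not part of this patch.
import java.util.HashMap;
import java.util.Map;

class ExpectedEventAccounting {
  // Negative while the source vertex is still running: -(number of pruning columns).
  private final Map<String, Integer> expected = new HashMap<>();
  private final Map<String, Integer> seen = new HashMap<>();

  // Called once per (source, column) pair during initialization.
  void registerColumn(String source) {
    Integer current = expected.get(source);
    expected.put(source, (current == null ? 0 : current) - 1);
    if (!seen.containsKey(source)) {
      seen.put(source, 0);
    }
  }

  // Called when the source vertex is reported as SUCCEEDED.
  void onVertexSuccess(String source, int numTasks) {
    // e.g. 2 columns and 3 tasks: -(-2) * 3 = 6 expected events.
    expected.put(source, -expected.get(source) * numTasks);
  }

  // Called for every event; returns true once the source has sent everything.
  boolean onEvent(String source) {
    int count = seen.get(source) + 1;
    seen.put(source, count);
    int expectedCount = expected.get(source);
    return expectedCount >= 0 && count == expectedCount; // negative means the vertex has not finished yet
  }
}

A source with two pruning columns and three tasks is therefore complete after exactly six events, which is the condition checkForSourceCompletion() below checks from both addEvent() and processVertex().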
+ private final Map numExpectedEventsPerSource = new HashMap<>(); + private final Map numEventsSeenPerSource = new HashMap<>(); + private int sourceInfoCount = 0; private final Object endOfEvents = new Object(); private int totalEventCount = 0; - public DynamicPartitionPruner() { + public DynamicPartitionPruner(InputInitializerContext context, MapWork work, JobConf jobConf) throws + SerDeException { + this.context = context; + this.work = work; + this.jobConf = jobConf; + synchronized (this) { + initialize(); + } } - public void prune(MapWork work, JobConf jobConf, InputInitializerContext context) + public void prune() throws SerDeException, IOException, InterruptedException, HiveException { synchronized(sourcesWaitingForEvents) { - initialize(work, jobConf); if (sourcesWaitingForEvents.isEmpty()) { return; @@ -112,11 +130,11 @@ public void prune(MapWork work, JobConf jobConf, InputInitializerContext context } } - LOG.info("Waiting for events (" + sourceInfoCount + " items) ..."); + LOG.info("Waiting for events (" + sourceInfoCount + " sources) ..."); // synchronous event processing loop. Won't return until all events have // been processed. this.processEvents(); - this.prunePartitions(work, context); + this.prunePartitions(); LOG.info("Ok to proceed."); } @@ -129,25 +147,38 @@ private void clear() { sourceInfoCount = 0; } - public void initialize(MapWork work, JobConf jobConf) throws SerDeException { + private void initialize() throws SerDeException { this.clear(); Map columnMap = new HashMap(); + // sources represent vertex names Set sources = work.getEventSourceTableDescMap().keySet(); sourcesWaitingForEvents.addAll(sources); for (String s : sources) { + // Set to 0 to start with. This will be decremented for all columns for which events + // are generated by this source - which is eventually used to determine number of expected + // events for the source. #columns X #tasks + numExpectedEventsPerSource.put(s, new MutableInt(0)); + numEventsSeenPerSource.put(s, new MutableInt(0)); + // Virtual relation generated by the reduce sink List tables = work.getEventSourceTableDescMap().get(s); + // Real column name - on which the operation is being performed List columnNames = work.getEventSourceColumnNameMap().get(s); + // Expression for the operation. e.g. N^2 > 10 List partKeyExprs = work.getEventSourcePartKeyExprMap().get(s); + // eventSourceTableDesc, eventSourceColumnName, eventSourcePartKeyExpr move in lock-step. + // One entry is added to each at the same time. Iterator cit = columnNames.iterator(); Iterator pit = partKeyExprs.iterator(); + // A single source can process multiple columns, and will send an event for each of them. for (TableDesc t : tables) { + numExpectedEventsPerSource.get(s).decrement(); ++sourceInfoCount; String columnName = cit.next(); ExprNodeDesc partKeyExpr = pit.next(); - SourceInfo si = new SourceInfo(t, partKeyExpr, columnName, jobConf); + SourceInfo si = createSourceInfo(t, partKeyExpr, columnName, jobConf); if (!sourceInfoMap.containsKey(s)) { sourceInfoMap.put(s, new ArrayList()); } @@ -157,6 +188,8 @@ public void initialize(MapWork work, JobConf jobConf) throws SerDeException { // We could have multiple sources restrict the same column, need to take // the union of the values in that case. if (columnMap.containsKey(columnName)) { + // All sources are initialized up front. Events from different sources will end up getting added to the same list.
+ // Pruning is disabled if either source sends in an event which causes pruning to be skipped si.values = columnMap.get(columnName).values; si.skipPruning = columnMap.get(columnName).skipPruning; } @@ -165,25 +198,27 @@ public void initialize(MapWork work, JobConf jobConf) throws SerDeException { } } - private void prunePartitions(MapWork work, InputInitializerContext context) throws HiveException { + private void prunePartitions() throws HiveException { int expectedEvents = 0; - for (String source : this.sourceInfoMap.keySet()) { - for (SourceInfo si : this.sourceInfoMap.get(source)) { + for (Map.Entry> entry : this.sourceInfoMap.entrySet()) { + String source = entry.getKey(); + for (SourceInfo si : entry.getValue()) { int taskNum = context.getVertexNumTasks(source); - LOG.info("Expecting " + taskNum + " events for vertex " + source); + LOG.info("Expecting " + taskNum + " events for vertex " + source + ", for column " + si.columnName); expectedEvents += taskNum; - prunePartitionSingleSource(source, si, work); + prunePartitionSingleSource(source, si); } } // sanity check. all tasks must submit events for us to succeed. if (expectedEvents != totalEventCount) { LOG.error("Expecting: " + expectedEvents + ", received: " + totalEventCount); - throw new HiveException("Incorrect event count in dynamic parition pruning"); + throw new HiveException("Incorrect event count in dynamic partition pruning"); } } - private void prunePartitionSingleSource(String source, SourceInfo si, MapWork work) + @VisibleForTesting + protected void prunePartitionSingleSource(String source, SourceInfo si) throws HiveException { if (si.skipPruning.get()) { @@ -223,11 +258,11 @@ private void prunePartitionSingleSource(String source, SourceInfo si, MapWork wo ExprNodeEvaluator eval = ExprNodeEvaluatorFactory.get(si.partKey); eval.initialize(soi); - applyFilterToPartitions(work, converter, eval, columnName, values); + applyFilterToPartitions(converter, eval, columnName, values); } @SuppressWarnings("rawtypes") - private void applyFilterToPartitions(MapWork work, Converter converter, ExprNodeEvaluator eval, + private void applyFilterToPartitions(Converter converter, ExprNodeEvaluator eval, String columnName, Set values) throws HiveException { Object[] row = new Object[1]; @@ -238,12 +273,12 @@ private void applyFilterToPartitions(MapWork work, Converter converter, ExprNode PartitionDesc desc = work.getPathToPartitionInfo().get(p); Map spec = desc.getPartSpec(); if (spec == null) { - throw new AssertionException("No partition spec found in dynamic pruning"); + throw new IllegalStateException("No partition spec found in dynamic pruning"); } String partValueString = spec.get(columnName); if (partValueString == null) { - throw new AssertionException("Could not find partition value for column: " + columnName); + throw new IllegalStateException("Could not find partition value for column: " + columnName); } Object partValue = converter.convert(partValueString); @@ -267,17 +302,38 @@ private void applyFilterToPartitions(MapWork work, Converter converter, ExprNode } } + @VisibleForTesting + protected SourceInfo createSourceInfo(TableDesc t, ExprNodeDesc partKeyExpr, String columnName, + JobConf jobConf) throws + SerDeException { + return new SourceInfo(t, partKeyExpr, columnName, jobConf); + + } + @SuppressWarnings("deprecation") - private static class SourceInfo { + @VisibleForTesting + static class SourceInfo { public final ExprNodeDesc partKey; public final Deserializer deserializer; public final StructObjectInspector soi; public 
final StructField field; public final ObjectInspector fieldInspector; + /* List of partitions that are required - populated from processing each event */ public Set values = new HashSet(); + /* Whether to skipPruning - depends on the payload from an event which may signal skip - if the event payload is too large */ public AtomicBoolean skipPruning = new AtomicBoolean(); public final String columnName; + @VisibleForTesting // Only used for testing. + SourceInfo(TableDesc table, ExprNodeDesc partKey, String columnName, JobConf jobConf, Object forTesting) { + this.partKey = partKey; + this.columnName = columnName; + this.deserializer = null; + this.soi = null; + this.field = null; + this.fieldInspector = null; + } + public SourceInfo(TableDesc table, ExprNodeDesc partKey, String columnName, JobConf jobConf) throws SerDeException { @@ -328,52 +384,60 @@ private void processEvents() throws SerDeException, IOException, InterruptedExce } @SuppressWarnings("deprecation") - private String processPayload(ByteBuffer payload, String sourceName) throws SerDeException, + @VisibleForTesting + protected String processPayload(ByteBuffer payload, String sourceName) throws SerDeException, IOException { DataInputStream in = new DataInputStream(new ByteBufferBackedInputStream(payload)); - String columnName = in.readUTF(); - boolean skip = in.readBoolean(); + try { + String columnName = in.readUTF(); - LOG.info("Source of event: " + sourceName); + LOG.info("Source of event: " + sourceName); - List infos = this.sourceInfoMap.get(sourceName); - if (infos == null) { - in.close(); - throw new AssertionException("no source info for event source: " + sourceName); - } - - SourceInfo info = null; - for (SourceInfo si : infos) { - if (columnName.equals(si.columnName)) { - info = si; - break; + List infos = this.sourceInfoMap.get(sourceName); + if (infos == null) { + throw new IllegalStateException("no source info for event source: " + sourceName); } - } - - if (info == null) { - in.close(); - throw new AssertionException("no source info for column: " + columnName); - } - if (skip) { - info.skipPruning.set(true); - } - - while (payload.hasRemaining()) { - writable.readFields(in); - - Object row = info.deserializer.deserialize(writable); + SourceInfo info = null; + for (SourceInfo si : infos) { + if (columnName.equals(si.columnName)) { + info = si; + break; + } + } - Object value = info.soi.getStructFieldData(row, info.field); - value = ObjectInspectorUtils.copyToStandardObject(value, info.fieldInspector); + if (info == null) { + throw new IllegalStateException("no source info for column: " + columnName); + } - if (LOG.isDebugEnabled()) { - LOG.debug("Adding: " + value + " to list of required partitions"); + if (info.skipPruning.get()) { + // Marked as skipped previously. Don't bother processing the rest of the payload. 
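For reference, the reader above implies a simple payload layout: a UTF column name, a boolean skip flag, and then zero or more records that are read back through a BytesWritable and handed to the deserializer. A hypothetical test-side builder for such a payload could look like the sketch below; the class name and the choice to serialize rows as BytesWritable are assumptions drawn from the reading code, not something this patch adds.

// Hypothetical helper, mirroring the read path in processPayload(): writeUTF/readUTF,
// writeBoolean/readBoolean, then BytesWritable.write()/readFields() per record.
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.hadoop.io.BytesWritable;

final class PrunerPayloadBuilder {
  static ByteBuffer build(String columnName, boolean skip, byte[]... serializedRows)
      throws IOException {
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    DataOutputStream out = new DataOutputStream(bytes);
    out.writeUTF(columnName);   // consumed by in.readUTF()
    out.writeBoolean(skip);     // consumed by in.readBoolean()
    if (!skip) {
      for (byte[] row : serializedRows) {
        new BytesWritable(row).write(out);  // consumed by writable.readFields(in)
      }
    }
    out.flush();
    return ByteBuffer.wrap(bytes.toByteArray());
  }
}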
+ } else { + boolean skip = in.readBoolean(); + if (skip) { + info.skipPruning.set(true); + } else { + while (payload.hasRemaining()) { + writable.readFields(in); + + Object row = info.deserializer.deserialize(writable); + + Object value = info.soi.getStructFieldData(row, info.field); + value = ObjectInspectorUtils.copyToStandardObject(value, info.fieldInspector); + + if (LOG.isDebugEnabled()) { + LOG.debug("Adding: " + value + " to list of required partitions"); + } + info.values.add(value); + } + } + } + } finally { + if (in != null) { + in.close(); } - info.values.add(value); } - in.close(); return sourceName; } @@ -409,23 +473,47 @@ public void addEvent(InputInitializerEvent event) { synchronized(sourcesWaitingForEvents) { if (sourcesWaitingForEvents.contains(event.getSourceVertexName())) { ++totalEventCount; + numEventsSeenPerSource.get(event.getSourceVertexName()).increment(); queue.offer(event); + checkForSourceCompletion(event.getSourceVertexName()); } } } public void processVertex(String name) { LOG.info("Vertex succeeded: " + name); - synchronized(sourcesWaitingForEvents) { - sourcesWaitingForEvents.remove(name); + // Get a deterministic count of number of tasks for the vertex. + MutableInt prevVal = numExpectedEventsPerSource.get(name); + int prevValInt = prevVal.intValue(); + Preconditions.checkState(prevValInt < 0, + "Invalid value for numExpectedEvents for source: " + name + ", oldVal=" + prevValInt); + prevVal.setValue((-1) * prevValInt * context.getVertexNumTasks(name)); + checkForSourceCompletion(name); + } + } - if (sourcesWaitingForEvents.isEmpty()) { - // we've got what we need; mark the queue - queue.offer(endOfEvents); - } else { - LOG.info("Waiting for " + sourcesWaitingForEvents.size() + " events."); + private void checkForSourceCompletion(String name) { + int expectedEvents = numExpectedEventsPerSource.get(name).getValue(); + if (expectedEvents < 0) { + // Expected events not updated yet - vertex SUCCESS notification not received. 
+ return; + } else { + int processedEvents = numEventsSeenPerSource.get(name).getValue(); + if (processedEvents == expectedEvents) { + sourcesWaitingForEvents.remove(name); + if (sourcesWaitingForEvents.isEmpty()) { + // we've got what we need; mark the queue + queue.offer(endOfEvents); + } else { + LOG.info("Waiting for " + sourcesWaitingForEvents.size() + " sources."); + } + } else if (processedEvents > expectedEvents) { + throw new IllegalStateException( + "Received too many events for " + name + ", Expected=" + expectedEvents + + ", Received=" + processedEvents); } + return; } } } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java index a6f4b55..ccaecdc 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java @@ -19,22 +19,17 @@ package org.apache.hadoop.hive.ql.exec.tez; import java.io.IOException; -import java.util.Collection; -import java.util.HashMap; import java.util.List; -import java.util.Map; +import com.google.common.base.Preconditions; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.common.JavaUtils; import org.apache.hadoop.hive.ql.exec.Utilities; -import org.apache.hadoop.hive.ql.io.HiveFileFormatUtils; import org.apache.hadoop.hive.ql.plan.MapWork; -import org.apache.hadoop.hive.ql.plan.PartitionDesc; +import org.apache.hadoop.hive.serde2.SerDeException; import org.apache.hadoop.hive.shims.ShimLoader; -import org.apache.hadoop.mapred.FileSplit; import org.apache.hadoop.mapred.InputFormat; import org.apache.hadoop.mapred.InputSplit; import org.apache.hadoop.mapred.JobConf; @@ -57,7 +52,6 @@ import org.apache.tez.runtime.api.events.InputDataInformationEvent; import org.apache.tez.runtime.api.events.InputInitializerEvent; -import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.Lists; import com.google.common.collect.Multimap; @@ -71,43 +65,44 @@ private static final Log LOG = LogFactory.getLog(HiveSplitGenerator.class); - private static final SplitGrouper grouper = new SplitGrouper(); - private final DynamicPartitionPruner pruner = new DynamicPartitionPruner(); - private InputInitializerContext context; - private static Map, Map> cache = - new HashMap, Map>(); + private final DynamicPartitionPruner pruner; + private final Configuration conf; + private final JobConf jobConf; + private final MRInputUserPayloadProto userPayloadProto; + private final SplitGrouper splitGrouper = new SplitGrouper(); - public HiveSplitGenerator(InputInitializerContext initializerContext) { + + public HiveSplitGenerator(InputInitializerContext initializerContext) throws IOException, + SerDeException { super(initializerContext); - } + Preconditions.checkNotNull(initializerContext); + userPayloadProto = + MRInputHelpers.parseMRInputPayload(initializerContext.getInputUserPayload()); - public HiveSplitGenerator() { - this(null); - } + this.conf = + TezUtils.createConfFromByteString(userPayloadProto.getConfigurationBytes()); - @Override - public List initialize() throws Exception { - InputInitializerContext rootInputContext = getContext(); + this.jobConf = new JobConf(conf); + // Read all credentials into the credentials instance stored in JobConf. 
+ ShimLoader.getHadoopShims().getMergedCredentials(jobConf); - context = rootInputContext; + MapWork work = Utilities.getMapWork(jobConf); - MRInputUserPayloadProto userPayloadProto = - MRInputHelpers.parseMRInputPayload(rootInputContext.getInputUserPayload()); + // Events can start coming in the moment the InputInitializer is created. The pruner + must be set up and initialized here so that it sets up its structures to start accepting events. + Setting it up in initialize leads to a window where events may come in before the pruner is + initialized, which may cause it to drop events. + pruner = new DynamicPartitionPruner(initializerContext, work, jobConf); - Configuration conf = - TezUtils.createConfFromByteString(userPayloadProto.getConfigurationBytes()); + } + @Override + public List initialize() throws Exception { boolean sendSerializedEvents = conf.getBoolean("mapreduce.tez.input.initializer.serialize.event.payload", true); - // Read all credentials into the credentials instance stored in JobConf. - JobConf jobConf = new JobConf(conf); - ShimLoader.getHadoopShims().getMergedCredentials(jobConf); - - MapWork work = Utilities.getMapWork(jobConf); - // perform dynamic partition pruning - pruner.prune(work, jobConf, context); + pruner.prune(); InputSplitInfoMem inputSplitInfo = null; String realInputFormatName = conf.get("mapred.input.format.class"); @@ -118,8 +113,8 @@ public HiveSplitGenerator() { (InputFormat) ReflectionUtils.newInstance(JavaUtils.loadClass(realInputFormatName), jobConf); - int totalResource = rootInputContext.getTotalAvailableResource().getMemory(); - int taskResource = rootInputContext.getVertexTaskResource().getMemory(); + int totalResource = getContext().getTotalAvailableResource().getMemory(); + int taskResource = getContext().getVertexTaskResource().getMemory(); int availableSlots = totalResource / taskResource; // Create the un-grouped splits @@ -132,12 +127,12 @@ public HiveSplitGenerator() { + " available slots, " + waves + " waves. Input format is: " + realInputFormatName); Multimap groupedSplits = - generateGroupedSplits(jobConf, conf, splits, waves, availableSlots); + splitGrouper.generateGroupedSplits(jobConf, conf, splits, waves, availableSlots); // And finally return them in a flat array InputSplit[] flatSplits = groupedSplits.values().toArray(new InputSplit[0]); LOG.info("Number of grouped splits: " + flatSplits.length); - List locationHints = grouper.createTaskLocationHints(flatSplits); + List locationHints = splitGrouper.createTaskLocationHints(flatSplits); Utilities.clearWork(jobConf); @@ -158,87 +153,7 @@ public HiveSplitGenerator() { } - public Multimap generateGroupedSplits(JobConf jobConf, - Configuration conf, InputSplit[] splits, float waves, int availableSlots) - throws Exception { - return generateGroupedSplits(jobConf, conf, splits, waves, availableSlots, null, true); - } - public Multimap generateGroupedSplits(JobConf jobConf, - Configuration conf, InputSplit[] splits, float waves, int availableSlots, String inputName, - boolean groupAcrossFiles) throws Exception { - - MapWork work = populateMapWork(jobConf, inputName); - Multimap bucketSplitMultiMap = - ArrayListMultimap.
create(); - - int i = 0; - InputSplit prevSplit = null; - for (InputSplit s : splits) { - // this is the bit where we make sure we don't group across partition - // schema boundaries - if (schemaEvolved(s, prevSplit, groupAcrossFiles, work)) { - ++i; - prevSplit = s; - } - bucketSplitMultiMap.put(i, s); - } - LOG.info("# Src groups for split generation: " + (i + 1)); - - // group them into the chunks we want - Multimap groupedSplits = - grouper.group(jobConf, bucketSplitMultiMap, availableSlots, waves); - - return groupedSplits; - } - - private MapWork populateMapWork(JobConf jobConf, String inputName) { - MapWork work = null; - if (inputName != null) { - work = (MapWork) Utilities.getMergeWork(jobConf, inputName); - // work can still be null if there is no merge work for this input - } - if (work == null) { - work = Utilities.getMapWork(jobConf); - } - - return work; - } - - public boolean schemaEvolved(InputSplit s, InputSplit prevSplit, boolean groupAcrossFiles, - MapWork work) throws IOException { - boolean retval = false; - Path path = ((FileSplit) s).getPath(); - PartitionDesc pd = - HiveFileFormatUtils.getPartitionDescFromPathRecursively(work.getPathToPartitionInfo(), - path, cache); - String currentDeserializerClass = pd.getDeserializerClassName(); - Class currentInputFormatClass = pd.getInputFileFormatClass(); - - Class previousInputFormatClass = null; - String previousDeserializerClass = null; - if (prevSplit != null) { - Path prevPath = ((FileSplit) prevSplit).getPath(); - if (!groupAcrossFiles) { - return !path.equals(prevPath); - } - PartitionDesc prevPD = - HiveFileFormatUtils.getPartitionDescFromPathRecursively(work.getPathToPartitionInfo(), - prevPath, cache); - previousDeserializerClass = prevPD.getDeserializerClassName(); - previousInputFormatClass = prevPD.getInputFileFormatClass(); - } - - if ((currentInputFormatClass != previousInputFormatClass) - || (!currentDeserializerClass.equals(previousDeserializerClass))) { - retval = true; - } - - if (LOG.isDebugEnabled()) { - LOG.debug("Adding split " + path + " to src new group? " + retval); - } - return retval; - } private List createEventList(boolean sendSerializedEvents, InputSplitInfoMem inputSplitInfo) { diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SplitGrouper.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SplitGrouper.java index f9c80c2..c169677 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SplitGrouper.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SplitGrouper.java @@ -26,13 +26,20 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.hive.ql.exec.Utilities; +import org.apache.hadoop.hive.ql.io.HiveFileFormatUtils; import org.apache.hadoop.hive.ql.io.HiveInputFormat; +import org.apache.hadoop.hive.ql.plan.MapWork; +import org.apache.hadoop.hive.ql.plan.PartitionDesc; import org.apache.hadoop.mapred.FileSplit; import org.apache.hadoop.mapred.InputSplit; +import org.apache.hadoop.mapred.JobConf; import org.apache.hadoop.mapred.split.TezGroupedSplit; import org.apache.hadoop.mapred.split.TezMapredSplitsGrouper; import org.apache.tez.dag.api.TaskLocationHint; @@ -49,8 +56,15 @@ private static final Log LOG = LogFactory.getLog(SplitGrouper.class); + // TODO This needs to be looked at. Map of Map to Map... 
Made concurrent for now since split generation + // can happen in parallel. + private static final Map, Map> cache = + new ConcurrentHashMap<>(); + private final TezMapredSplitsGrouper tezGrouper = new TezMapredSplitsGrouper(); + + /** * group splits for each bucket separately - while evenly filling all the * available slots with tasks @@ -87,12 +101,83 @@ return bucketGroupedSplitMultimap; } + + /** + * Create task location hints from a set of input splits + * @param splits the actual splits + * @return taskLocationHints - 1 per input split specified + * @throws IOException + */ + public List createTaskLocationHints(InputSplit[] splits) throws IOException { + + List locationHints = Lists.newArrayListWithCapacity(splits.length); + + for (InputSplit split : splits) { + String rack = (split instanceof TezGroupedSplit) ? ((TezGroupedSplit) split).getRack() : null; + if (rack == null) { + if (split.getLocations() != null) { + locationHints.add(TaskLocationHint.createTaskLocationHint(new HashSet(Arrays.asList(split + .getLocations())), null)); + } else { + locationHints.add(TaskLocationHint.createTaskLocationHint(null, null)); + } + } else { + locationHints.add(TaskLocationHint.createTaskLocationHint(null, Collections.singleton(rack))); + } + } + + return locationHints; + } + + /** Generate groups of splits, separated by schema evolution boundaries */ + public Multimap generateGroupedSplits(JobConf jobConf, + Configuration conf, + InputSplit[] splits, + float waves, int availableSlots) + throws Exception { + return generateGroupedSplits(jobConf, conf, splits, waves, availableSlots, null, true); + } + + /** Generate groups of splits, separated by schema evolution boundaries */ + public Multimap generateGroupedSplits(JobConf jobConf, + Configuration conf, + InputSplit[] splits, + float waves, int availableSlots, + String inputName, + boolean groupAcrossFiles) throws + Exception { + + MapWork work = populateMapWork(jobConf, inputName); + Multimap bucketSplitMultiMap = + ArrayListMultimap. create(); + + int i = 0; + InputSplit prevSplit = null; + for (InputSplit s : splits) { + // this is the bit where we make sure we don't group across partition + // schema boundaries + if (schemaEvolved(s, prevSplit, groupAcrossFiles, work)) { + ++i; + prevSplit = s; + } + bucketSplitMultiMap.put(i, s); + } + LOG.info("# Src groups for split generation: " + (i + 1)); + + // group them into the chunks we want + Multimap groupedSplits = + this.group(jobConf, bucketSplitMultiMap, availableSlots, waves); + + return groupedSplits; + } + + /** * get the size estimates for each bucket in tasks. This is used to make sure * we allocate the head room evenly */ private Map estimateBucketSizes(int availableSlots, float waves, - Map> bucketSplitMap) { + Map> bucketSplitMap) { // mapping of bucket id to size of all splits in bucket in bytes Map bucketSizeMap = new HashMap(); @@ -147,24 +232,54 @@ return bucketTaskMap; } - public List createTaskLocationHints(InputSplit[] splits) throws IOException { + private static MapWork populateMapWork(JobConf jobConf, String inputName) { + MapWork work = null; + if (inputName != null) { + work = (MapWork) Utilities.getMergeWork(jobConf, inputName); + // work can still be null if there is no merge work for this input + } + if (work == null) { + work = Utilities.getMapWork(jobConf); + } - List locationHints = Lists.newArrayListWithCapacity(splits.length); + return work; + } - for (InputSplit split : splits) { - String rack = (split instanceof TezGroupedSplit) ? 
((TezGroupedSplit) split).getRack() : null; - if (rack == null) { - if (split.getLocations() != null) { - locationHints.add(TaskLocationHint.createTaskLocationHint(new HashSet(Arrays.asList(split - .getLocations())), null)); - } else { - locationHints.add(TaskLocationHint.createTaskLocationHint(null, null)); - } - } else { - locationHints.add(TaskLocationHint.createTaskLocationHint(null, Collections.singleton(rack))); + private boolean schemaEvolved(InputSplit s, InputSplit prevSplit, boolean groupAcrossFiles, + MapWork work) throws IOException { + boolean retval = false; + Path path = ((FileSplit) s).getPath(); + PartitionDesc pd = + HiveFileFormatUtils.getPartitionDescFromPathRecursively(work.getPathToPartitionInfo(), + path, cache); + String currentDeserializerClass = pd.getDeserializerClassName(); + Class currentInputFormatClass = pd.getInputFileFormatClass(); + + Class previousInputFormatClass = null; + String previousDeserializerClass = null; + if (prevSplit != null) { + Path prevPath = ((FileSplit) prevSplit).getPath(); + if (!groupAcrossFiles) { + return !path.equals(prevPath); } + PartitionDesc prevPD = + HiveFileFormatUtils.getPartitionDescFromPathRecursively(work.getPathToPartitionInfo(), + prevPath, cache); + previousDeserializerClass = prevPD.getDeserializerClassName(); + previousInputFormatClass = prevPD.getInputFileFormatClass(); } - return locationHints; + if ((currentInputFormatClass != previousInputFormatClass) + || (!currentDeserializerClass.equals(previousDeserializerClass))) { + retval = true; + } + + if (LOG.isDebugEnabled()) { + LOG.debug("Adding split " + path + " to src new group? " + retval); + } + return retval; } + + + } diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/tez/TestDynamicPartitionPruner.java ql/src/test/org/apache/hadoop/hive/ql/exec/tez/TestDynamicPartitionPruner.java new file mode 100644 index 0000000..1ce2e09 --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/exec/tez/TestDynamicPartitionPruner.java @@ -0,0 +1,532 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.hive.ql.exec.tez; + +import static org.junit.Assert.assertFalse; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; + +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; +import org.apache.hadoop.hive.ql.plan.MapWork; +import org.apache.hadoop.hive.ql.plan.TableDesc; +import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.mapred.JobConf; +import org.apache.tez.runtime.api.InputInitializerContext; +import org.apache.tez.runtime.api.events.InputInitializerEvent; +import org.junit.Test; + +public class TestDynamicPartitionPruner { + + @Test(timeout = 5000) + public void testNoPruning() throws InterruptedException, IOException, HiveException, + SerDeException { + InputInitializerContext mockInitContext = mock(InputInitializerContext.class); + MapWork mapWork = mock(MapWork.class); + DynamicPartitionPruner pruner = + new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork); + + PruneRunnable pruneRunnable = new PruneRunnable(pruner); + Thread t = new Thread(pruneRunnable); + t.start(); + try { + pruneRunnable.start(); + pruneRunnable.awaitEnd(); + // Return immediately. No entries found for pruning. Verified via the timeout. + } finally { + t.interrupt(); + t.join(); + } + } + + @Test(timeout = 5000) + public void testSingleSourceOrdering1() throws InterruptedException, IOException, HiveException, + SerDeException { + InputInitializerContext mockInitContext = mock(InputInitializerContext.class); + doReturn(1).when(mockInitContext).getVertexNumTasks("v1"); + + MapWork mapWork = createMockMapWork(new TestSource("v1", 1)); + DynamicPartitionPruner pruner = + new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork); + + + PruneRunnable pruneRunnable = new PruneRunnable(pruner); + Thread t = new Thread(pruneRunnable); + t.start(); + try { + pruneRunnable.start(); + + InputInitializerEvent event = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + event.setSourceVertexName("v1"); + + pruner.addEvent(event); + pruner.processVertex("v1"); + + pruneRunnable.awaitEnd(); + assertFalse(pruneRunnable.inError.get()); + } finally { + t.interrupt(); + t.join(); + } + } + + @Test(timeout = 5000) + public void testSingleSourceOrdering2() throws InterruptedException, IOException, HiveException, + SerDeException { + InputInitializerContext mockInitContext = mock(InputInitializerContext.class); + doReturn(1).when(mockInitContext).getVertexNumTasks("v1"); + + MapWork mapWork = createMockMapWork(new TestSource("v1", 1)); + DynamicPartitionPruner pruner = + new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork); + + + PruneRunnable pruneRunnable = new PruneRunnable(pruner); + Thread t = new Thread(pruneRunnable); + t.start(); + try { + pruneRunnable.start(); + + InputInitializerEvent event = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + event.setSourceVertexName("v1"); + + pruner.processVertex("v1"); + pruner.addEvent(event); + + pruneRunnable.awaitEnd(); + assertFalse(pruneRunnable.inError.get()); + } finally { + t.interrupt(); + t.join(); + } + } + + 
@Test(timeout = 5000) + public void testSingleSourceMultipleFiltersOrdering1() throws InterruptedException, SerDeException { + InputInitializerContext mockInitContext = mock(InputInitializerContext.class); + doReturn(2).when(mockInitContext).getVertexNumTasks("v1"); + + MapWork mapWork = createMockMapWork(new TestSource("v1", 2)); + DynamicPartitionPruner pruner = + new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork); + + PruneRunnable pruneRunnable = new PruneRunnable(pruner); + Thread t = new Thread(pruneRunnable); + t.start(); + try { + pruneRunnable.start(); + + InputInitializerEvent event = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + event.setSourceVertexName("v1"); + + pruner.addEvent(event); + pruner.addEvent(event); + pruner.addEvent(event); + pruner.addEvent(event); + pruner.processVertex("v1"); + + pruneRunnable.awaitEnd(); + assertFalse(pruneRunnable.inError.get()); + } finally { + t.interrupt(); + t.join(); + } + } + + @Test(timeout = 5000) + public void testSingleSourceMultipleFiltersOrdering2() throws InterruptedException, SerDeException { + InputInitializerContext mockInitContext = mock(InputInitializerContext.class); + doReturn(2).when(mockInitContext).getVertexNumTasks("v1"); + + MapWork mapWork = createMockMapWork(new TestSource("v1", 2)); + DynamicPartitionPruner pruner = + new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork); + + PruneRunnable pruneRunnable = new PruneRunnable(pruner); + Thread t = new Thread(pruneRunnable); + t.start(); + try { + pruneRunnable.start(); + + InputInitializerEvent event = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + event.setSourceVertexName("v1"); + + pruner.processVertex("v1"); + pruner.addEvent(event); + pruner.addEvent(event); + pruner.addEvent(event); + pruner.addEvent(event); + + pruneRunnable.awaitEnd(); + assertFalse(pruneRunnable.inError.get()); + } finally { + t.interrupt(); + t.join(); + } + } + + @Test(timeout = 5000) + public void testMultipleSourcesOrdering1() throws InterruptedException, SerDeException { + InputInitializerContext mockInitContext = mock(InputInitializerContext.class); + doReturn(2).when(mockInitContext).getVertexNumTasks("v1"); + doReturn(3).when(mockInitContext).getVertexNumTasks("v2"); + + MapWork mapWork = createMockMapWork(new TestSource("v1", 2), new TestSource("v2", 1)); + DynamicPartitionPruner pruner = + new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork); + + PruneRunnable pruneRunnable = new PruneRunnable(pruner); + Thread t = new Thread(pruneRunnable); + t.start(); + try { + pruneRunnable.start(); + + InputInitializerEvent eventV1 = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + eventV1.setSourceVertexName("v1"); + + InputInitializerEvent eventV2 = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + eventV2.setSourceVertexName("v2"); + + // 2 X 2 events for V1. 
3 X 1 events for V2 + + pruner.addEvent(eventV1); + pruner.addEvent(eventV1); + pruner.addEvent(eventV1); + pruner.addEvent(eventV1); + pruner.addEvent(eventV2); + pruner.addEvent(eventV2); + pruner.addEvent(eventV2); + pruner.processVertex("v1"); + pruner.processVertex("v2"); + + pruneRunnable.awaitEnd(); + assertFalse(pruneRunnable.inError.get()); + } finally { + t.interrupt(); + t.join(); + } + } + + @Test(timeout = 5000) + public void testMultipleSourcesOrdering2() throws InterruptedException, SerDeException { + InputInitializerContext mockInitContext = mock(InputInitializerContext.class); + doReturn(2).when(mockInitContext).getVertexNumTasks("v1"); + doReturn(3).when(mockInitContext).getVertexNumTasks("v2"); + + MapWork mapWork = createMockMapWork(new TestSource("v1", 2), new TestSource("v2", 1)); + DynamicPartitionPruner pruner = + new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork); + + PruneRunnable pruneRunnable = new PruneRunnable(pruner); + Thread t = new Thread(pruneRunnable); + t.start(); + try { + pruneRunnable.start(); + + InputInitializerEvent eventV1 = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + eventV1.setSourceVertexName("v1"); + + InputInitializerEvent eventV2 = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + eventV2.setSourceVertexName("v2"); + + // 2 X 2 events for V1. 3 X 1 events for V2 + + pruner.processVertex("v1"); + pruner.processVertex("v2"); + pruner.addEvent(eventV1); + pruner.addEvent(eventV1); + pruner.addEvent(eventV1); + pruner.addEvent(eventV1); + pruner.addEvent(eventV2); + pruner.addEvent(eventV2); + pruner.addEvent(eventV2); + + pruneRunnable.awaitEnd(); + assertFalse(pruneRunnable.inError.get()); + } finally { + t.interrupt(); + t.join(); + } + } + + @Test(timeout = 5000) + public void testMultipleSourcesOrdering3() throws InterruptedException, SerDeException { + InputInitializerContext mockInitContext = mock(InputInitializerContext.class); + doReturn(2).when(mockInitContext).getVertexNumTasks("v1"); + doReturn(3).when(mockInitContext).getVertexNumTasks("v2"); + + MapWork mapWork = createMockMapWork(new TestSource("v1", 2), new TestSource("v2", 1)); + DynamicPartitionPruner pruner = + new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork); + + PruneRunnable pruneRunnable = new PruneRunnable(pruner); + Thread t = new Thread(pruneRunnable); + t.start(); + try { + pruneRunnable.start(); + + InputInitializerEvent eventV1 = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + eventV1.setSourceVertexName("v1"); + + InputInitializerEvent eventV2 = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + eventV2.setSourceVertexName("v2"); + + // 2 X 2 events for V1. 
3 X 1 events for V2 + pruner.addEvent(eventV1); + pruner.addEvent(eventV1); + pruner.processVertex("v1"); + pruner.addEvent(eventV1); + pruner.addEvent(eventV1); + pruner.addEvent(eventV2); + pruner.processVertex("v2"); + pruner.addEvent(eventV2); + pruner.addEvent(eventV2); + + pruneRunnable.awaitEnd(); + assertFalse(pruneRunnable.inError.get()); + } finally { + t.interrupt(); + t.join(); + } + } + + @Test(timeout = 5000, expected = IllegalStateException.class) + public void testExtraEvents() throws InterruptedException, IOException, HiveException, + SerDeException { + InputInitializerContext mockInitContext = mock(InputInitializerContext.class); + doReturn(1).when(mockInitContext).getVertexNumTasks("v1"); + + MapWork mapWork = createMockMapWork(new TestSource("v1", 1)); + DynamicPartitionPruner pruner = + new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork); + + + PruneRunnable pruneRunnable = new PruneRunnable(pruner); + Thread t = new Thread(pruneRunnable); + t.start(); + try { + pruneRunnable.start(); + + InputInitializerEvent event = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + event.setSourceVertexName("v1"); + + pruner.addEvent(event); + pruner.addEvent(event); + pruner.processVertex("v1"); + + pruneRunnable.awaitEnd(); + assertFalse(pruneRunnable.inError.get()); + } finally { + t.interrupt(); + t.join(); + } + } + + @Test(timeout = 20000) + public void testMissingEvent() throws InterruptedException, IOException, HiveException, + SerDeException { + InputInitializerContext mockInitContext = mock(InputInitializerContext.class); + doReturn(1).when(mockInitContext).getVertexNumTasks("v1"); + + MapWork mapWork = createMockMapWork(new TestSource("v1", 1)); + DynamicPartitionPruner pruner = + new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork); + + + PruneRunnable pruneRunnable = new PruneRunnable(pruner); + Thread t = new Thread(pruneRunnable); + t.start(); + try { + pruneRunnable.start(); + + InputInitializerEvent event = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + event.setSourceVertexName("v1"); + + pruner.processVertex("v1"); + Thread.sleep(3000l); + // The pruner should not have completed. 
+ assertFalse(pruneRunnable.ended.get()); + assertFalse(pruneRunnable.inError.get()); + } finally { + t.interrupt(); + t.join(); + } + } + + private static class PruneRunnable implements Runnable { + + final DynamicPartitionPruner pruner; + final ReentrantLock lock = new ReentrantLock(); + final Condition endCondition = lock.newCondition(); + final Condition startCondition = lock.newCondition(); + final AtomicBoolean started = new AtomicBoolean(false); + final AtomicBoolean ended = new AtomicBoolean(false); + final AtomicBoolean inError = new AtomicBoolean(false); + + private PruneRunnable(DynamicPartitionPruner pruner) { + this.pruner = pruner; + } + + void start() { + started.set(true); + lock.lock(); + try { + startCondition.signal(); + } finally { + lock.unlock(); + } + } + + void awaitEnd() throws InterruptedException { + lock.lock(); + try { + while (!ended.get()) { + endCondition.await(); + } + } finally { + lock.unlock(); + } + } + + @Override + public void run() { + try { + lock.lock(); + try { + while (!started.get()) { + startCondition.await(); + } + } finally { + lock.unlock(); + } + + pruner.prune(); + lock.lock(); + try { + ended.set(true); + endCondition.signal(); + } finally { + lock.unlock(); + } + } catch (SerDeException | IOException | InterruptedException | HiveException e) { + inError.set(true); + } + } + } + + + private MapWork createMockMapWork(TestSource... testSources) { + MapWork mapWork = mock(MapWork.class); + + Map> tableMap = new HashMap<>(); + Map> columnMap = new HashMap<>(); + Map> exprMap = new HashMap<>(); + + int count = 0; + for (TestSource testSource : testSources) { + + for (int i = 0; i < testSource.numExpressions; i++) { + List tableDescList = tableMap.get(testSource.vertexName); + if (tableDescList == null) { + tableDescList = new LinkedList<>(); + tableMap.put(testSource.vertexName, tableDescList); + } + tableDescList.add(mock(TableDesc.class)); + + List columnList = columnMap.get(testSource.vertexName); + if (columnList == null) { + columnList = new LinkedList<>(); + columnMap.put(testSource.vertexName, columnList); + } + columnList.add(testSource.vertexName + "c_" + count + "_" + i); + + List exprNodeDescList = exprMap.get(testSource.vertexName); + if (exprNodeDescList == null) { + exprNodeDescList = new LinkedList<>(); + exprMap.put(testSource.vertexName, exprNodeDescList); + } + exprNodeDescList.add(mock(ExprNodeDesc.class)); + } + + count++; + } + + doReturn(tableMap).when(mapWork).getEventSourceTableDescMap(); + doReturn(columnMap).when(mapWork).getEventSourceColumnNameMap(); + doReturn(exprMap).when(mapWork).getEventSourcePartKeyExprMap(); + return mapWork; + } + + private static class TestSource { + String vertexName; + int numExpressions; + + public TestSource(String vertexName, int numExpressions) { + this.vertexName = vertexName; + this.numExpressions = numExpressions; + } + } + + private static class DynamicPartitionPrunerForEventTesting extends DynamicPartitionPruner { + + + public DynamicPartitionPrunerForEventTesting( + InputInitializerContext context, MapWork work) throws SerDeException { + super(context, work, new JobConf()); + } + + @Override + protected SourceInfo createSourceInfo(TableDesc t, ExprNodeDesc partKeyExpr, String columnName, + JobConf jobConf) throws + SerDeException { + return new SourceInfo(t, partKeyExpr, columnName, jobConf, null); + } + + @Override + protected String processPayload(ByteBuffer payload, String sourceName) throws SerDeException, + IOException { + // No-op: testing events only + return sourceName; 
+ } + + @Override + protected void prunePartitionSingleSource(String source, SourceInfo si) + throws HiveException { + // No-op: testing events only + } + } +}
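As a closing note on the SplitGrouper changes earlier in this patch: generateGroupedSplits() walks the splits in order and opens a new source group whenever schemaEvolved() reports a boundary (a different input format or deserializer, or simply a different file when groupAcrossFiles is false). The sketch below mirrors that walk in simplified form, with a plain string key standing in for the PartitionDesc lookup; all names here are illustrative and not part of the patch.

// Illustrative only: the "bump the group index at each schema boundary" walk.
import java.util.List;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;

final class SchemaBoundaryGroupingSketch {
  // Stand-in for schemaEvolved(): here a "schema" is just a string key per split.
  static Multimap<Integer, String> group(List<String> splitSchemas) {
    Multimap<Integer, String> groups = ArrayListMultimap.create();
    int groupId = -1;
    String prevSchema = null;
    for (String schema : splitSchemas) {
      if (prevSchema == null || !schema.equals(prevSchema)) {
        groupId++;            // mirrors ++i when schemaEvolved() returns true
        prevSchema = schema;  // later splits are compared against the group opener
      }
      groups.put(groupId, schema);
    }
    return groups;
  }
}
// e.g. [orc, orc, text, text, orc] -> groups {0=[orc, orc], 1=[text, text], 2=[orc]}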