diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DynamicPartitionPruner.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DynamicPartitionPruner.java
index 696874e..e0922cf 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DynamicPartitionPruner.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/DynamicPartitionPruner.java
@@ -31,10 +31,12 @@
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.BlockingQueue;
-import java.util.concurrent.ConcurrentSkipListSet;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
 import javolution.testing.AssertionException;
 
 import org.apache.commons.logging.Log;
@@ -74,30 +76,41 @@ private static final Log LOG = LogFactory.getLog(DynamicPartitionPruner.class);
 
+  private final MapWork work;
+  private final InputInitializerContext context;
+
   private final Map<String, List<SourceInfo>> sourceInfoMap = new HashMap<String, List<SourceInfo>>();
 
   private final BytesWritable writable = new BytesWritable();
 
+  /* Keeps track of all events that need to be processed - irrespective of the source */
   private final BlockingQueue<Object> queue = new LinkedBlockingQueue<Object>();
 
+  /* Keeps track of vertices from which events are expected */
   private final Set<String> sourcesWaitingForEvents = new HashSet<String>();
 
+  private final Map<String, AtomicInteger> numExpectedEventsPerSource = new HashMap<>();
+  private final Map<String, AtomicInteger> numEventsSeenPerSource = new HashMap<>();
+
   private int sourceInfoCount = 0;
 
   private final Object endOfEvents = new Object();
 
   private int totalEventCount = 0;
 
-  public DynamicPartitionPruner() {
+  public DynamicPartitionPruner(InputInitializerContext context, MapWork work, JobConf jobConf) throws
+      SerDeException {
+    this.context = context;
+    this.work = work;
+    initialize(work, jobConf);
   }
 
-  public void prune(MapWork work, JobConf jobConf, InputInitializerContext context)
+  public void prune()
       throws SerDeException, IOException, InterruptedException, HiveException {
 
     synchronized(sourcesWaitingForEvents) {
-      initialize(work, jobConf);
 
       if (sourcesWaitingForEvents.isEmpty()) {
         return;
@@ -112,11 +125,11 @@ public void prune(MapWork work, JobConf jobConf, InputInitializerContext context
       }
     }
 
-    LOG.info("Waiting for events (" + sourceInfoCount + " items) ...");
+    LOG.info("Waiting for events (" + sourceInfoCount + " sources) ...");
 
     // synchronous event processing loop. Won't return until all events have
     // been processed.
     this.processEvents();
-    this.prunePartitions(work, context);
+    this.prunePartitions(work);
     LOG.info("Ok to proceed.");
   }
 
@@ -132,22 +145,32 @@ private void clear() {
 
   public void initialize(MapWork work, JobConf jobConf) throws SerDeException {
     this.clear();
     Map<String, SourceInfo> columnMap = new HashMap<String, SourceInfo>();
+    // sources represent vertex names
     Set<String> sources = work.getEventSourceTableDescMap().keySet();
 
     sourcesWaitingForEvents.addAll(sources);
 
     for (String s : sources) {
+      numExpectedEventsPerSource.put(s, new AtomicInteger(0));
+      numEventsSeenPerSource.put(s, new AtomicInteger(0));
+      // Virtual relation generated by the reduce sink
       List<TableDesc> tables = work.getEventSourceTableDescMap().get(s);
+      // Real column name on which the operation is being performed
      List<String> columnNames = work.getEventSourceColumnNameMap().get(s);
+      // Expression for the operation, e.g. N^2 > 10
      List<ExprNodeDesc> partKeyExprs = work.getEventSourcePartKeyExprMap().get(s);
+      // eventSourceTableDesc, eventSourceColumnName, eventSourcePartKeyExpr move in lock-step.
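Note (illustrative, not part of the patch): the lock-step traversal in initialize() is easiest to see in isolation. The sketch below uses plain strings in place of TableDesc/ExprNodeDesc and hypothetical names; it only illustrates the contract that the three per-source lists in MapWork stay the same length, one entry per (source vertex, column) pair.

```java
import java.util.*;

public class LockStepSketch {
  public static void main(String[] args) {
    // Stand-ins for getEventSourceTableDescMap / getEventSourceColumnNameMap /
    // getEventSourcePartKeyExprMap, keyed by source vertex name.
    Map<String, List<String>> tables = new HashMap<>();
    Map<String, List<String>> columns = new HashMap<>();
    Map<String, List<String>> partKeyExprs = new HashMap<>();

    tables.put("v1", Arrays.asList("t1", "t1"));
    columns.put("v1", Arrays.asList("part_col_a", "part_col_b"));
    partKeyExprs.put("v1", Arrays.asList("expr_a", "expr_b"));

    for (String source : tables.keySet()) {
      Iterator<String> cit = columns.get(source).iterator();
      Iterator<String> pit = partKeyExprs.get(source).iterator();
      for (String table : tables.get(source)) {
        // One SourceInfo would be created per (table, column, expression) triple.
        System.out.println(source + ": " + table + " / " + cit.next() + " / " + pit.next());
      }
    }
  }
}
```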
+ // One entry is added to each at the same time Iterator cit = columnNames.iterator(); Iterator pit = partKeyExprs.iterator(); + // A single source can process multiple columns, and will send an event for each of them. for (TableDesc t : tables) { + numExpectedEventsPerSource.get(s).decrementAndGet(); ++sourceInfoCount; String columnName = cit.next(); ExprNodeDesc partKeyExpr = pit.next(); - SourceInfo si = new SourceInfo(t, partKeyExpr, columnName, jobConf); + SourceInfo si = createSourceInfo(t, partKeyExpr, columnName, jobConf); if (!sourceInfoMap.containsKey(s)) { sourceInfoMap.put(s, new ArrayList()); } @@ -157,6 +180,8 @@ public void initialize(MapWork work, JobConf jobConf) throws SerDeException { // We could have multiple sources restrict the same column, need to take // the union of the values in that case. if (columnMap.containsKey(columnName)) { + // All Sources are initialized up front. Events from different sources will end up getting added to the same list. + // Pruning is disabled if either source sends in an event which causes pruning to be skipped si.values = columnMap.get(columnName).values; si.skipPruning = columnMap.get(columnName).skipPruning; } @@ -165,12 +190,13 @@ public void initialize(MapWork work, JobConf jobConf) throws SerDeException { } } - private void prunePartitions(MapWork work, InputInitializerContext context) throws HiveException { + private void prunePartitions(MapWork work) throws HiveException { int expectedEvents = 0; - for (String source : this.sourceInfoMap.keySet()) { - for (SourceInfo si : this.sourceInfoMap.get(source)) { + for (Map.Entry> entry : this.sourceInfoMap.entrySet()) { + String source = entry.getKey(); + for (SourceInfo si : entry.getValue()) { int taskNum = context.getVertexNumTasks(source); - LOG.info("Expecting " + taskNum + " events for vertex " + source); + LOG.info("Expecting " + taskNum + " events for vertex " + source + ", for column " + si.columnName); expectedEvents += taskNum; prunePartitionSingleSource(source, si, work); } @@ -179,11 +205,12 @@ private void prunePartitions(MapWork work, InputInitializerContext context) thro // sanity check. all tasks must submit events for us to succeed. 
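Note (illustrative, not part of the patch): the sanity check here compares the number of events actually received against a per-source expectation of one event per (column, task) pair. A minimal sketch of that arithmetic, with hypothetical task and column counts matching the multi-source tests later in this patch:

```java
import java.util.HashMap;
import java.util.Map;

public class ExpectedEventsSketch {
  public static void main(String[] args) {
    // Hypothetical sources: v1 has 2 tasks and 2 pruning columns, v2 has 3 tasks and 1 column.
    Map<String, Integer> numTasks = new HashMap<>();
    numTasks.put("v1", 2);
    numTasks.put("v2", 3);
    Map<String, Integer> numColumns = new HashMap<>();
    numColumns.put("v1", 2);
    numColumns.put("v2", 1);

    int expectedEvents = 0;
    for (String source : numTasks.keySet()) {
      // One event per (column, task) pair of the source vertex.
      expectedEvents += numTasks.get(source) * numColumns.get(source);
    }
    System.out.println("expected events: " + expectedEvents); // 2*2 + 3*1 = 7
  }
}
```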
if (expectedEvents != totalEventCount) { LOG.error("Expecting: " + expectedEvents + ", received: " + totalEventCount); - throw new HiveException("Incorrect event count in dynamic parition pruning"); + throw new HiveException("Incorrect event count in dynamic partition pruning"); } } - private void prunePartitionSingleSource(String source, SourceInfo si, MapWork work) + @VisibleForTesting + protected void prunePartitionSingleSource(String source, SourceInfo si, MapWork work) throws HiveException { if (si.skipPruning.get()) { @@ -267,17 +294,38 @@ private void applyFilterToPartitions(MapWork work, Converter converter, ExprNode } } + @VisibleForTesting + protected SourceInfo createSourceInfo(TableDesc t, ExprNodeDesc partKeyExpr, String columnName, + JobConf jobConf) throws + SerDeException { + return new SourceInfo(t, partKeyExpr, columnName, jobConf); + + } + @SuppressWarnings("deprecation") - private static class SourceInfo { + @VisibleForTesting + static class SourceInfo { public final ExprNodeDesc partKey; public final Deserializer deserializer; public final StructObjectInspector soi; public final StructField field; public final ObjectInspector fieldInspector; + /* List of partitions that are required - populated from processing each event */ public Set values = new HashSet(); + /* Whether to skipPruning - depends on the payload from an event which may signal skip - if the event payload is too large */ public AtomicBoolean skipPruning = new AtomicBoolean(); public final String columnName; + @VisibleForTesting // Only used for testing. + SourceInfo(TableDesc table, ExprNodeDesc partKey, String columnName, JobConf jobConf, Object forTesting) { + this.partKey = partKey; + this.columnName = columnName; + this.deserializer = null; + this.soi = null; + this.field = null; + this.fieldInspector = null; + } + public SourceInfo(TableDesc table, ExprNodeDesc partKey, String columnName, JobConf jobConf) throws SerDeException { @@ -328,12 +376,12 @@ private void processEvents() throws SerDeException, IOException, InterruptedExce } @SuppressWarnings("deprecation") - private String processPayload(ByteBuffer payload, String sourceName) throws SerDeException, + @VisibleForTesting + protected String processPayload(ByteBuffer payload, String sourceName) throws SerDeException, IOException { DataInputStream in = new DataInputStream(new ByteBufferBackedInputStream(payload)); String columnName = in.readUTF(); - boolean skip = in.readBoolean(); LOG.info("Source of event: " + sourceName); @@ -356,24 +404,31 @@ private String processPayload(ByteBuffer payload, String sourceName) throws SerD throw new AssertionException("no source info for column: " + columnName); } - if (skip) { - info.skipPruning.set(true); + if (info.skipPruning.get()) { + // Marked as skipped previously. Don't bother processing the rest of the payload. 
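Note (illustrative, not part of the patch): reconstructed from the read calls in processPayload() (readUTF column name, readBoolean skip flag, then BytesWritable-serialized rows), an event payload roughly follows the layout sketched below. The writer here is only a test-style illustration; the real producer runs on the Tez task side and is not shown in this patch.

```java
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;

public class PayloadSketch {
  // Builds a payload carrying only the header fields processPayload() reads first.
  static ByteBuffer emptyPruningPayload(String columnName, boolean skip) throws IOException {
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    DataOutputStream out = new DataOutputStream(bytes);
    out.writeUTF(columnName);   // column the partition values apply to
    out.writeBoolean(skip);     // true => give up pruning for this column
    // BytesWritable-serialized value records would follow here
    out.close();
    return ByteBuffer.wrap(bytes.toByteArray());
  }

  public static void main(String[] args) throws IOException {
    System.out.println(emptyPruningPayload("part_col_a", false).remaining() + " bytes");
  }
}
```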
+ in.close(); } - while (payload.hasRemaining()) { - writable.readFields(in); + boolean skip = in.readBoolean(); + if (skip) { + info.skipPruning.set(true); + in.close(); + } else { + while (payload.hasRemaining()) { + writable.readFields(in); - Object row = info.deserializer.deserialize(writable); + Object row = info.deserializer.deserialize(writable); - Object value = info.soi.getStructFieldData(row, info.field); - value = ObjectInspectorUtils.copyToStandardObject(value, info.fieldInspector); + Object value = info.soi.getStructFieldData(row, info.field); + value = ObjectInspectorUtils.copyToStandardObject(value, info.fieldInspector); - if (LOG.isDebugEnabled()) { - LOG.debug("Adding: " + value + " to list of required partitions"); + if (LOG.isDebugEnabled()) { + LOG.debug("Adding: " + value + " to list of required partitions"); + } + info.values.add(value); } - info.values.add(value); } - in.close(); + return sourceName; } @@ -409,23 +464,47 @@ public void addEvent(InputInitializerEvent event) { synchronized(sourcesWaitingForEvents) { if (sourcesWaitingForEvents.contains(event.getSourceVertexName())) { ++totalEventCount; + numEventsSeenPerSource.get(event.getSourceVertexName()).incrementAndGet(); queue.offer(event); + checkForSourceCompletion(event.getSourceVertexName()); } } } public void processVertex(String name) { LOG.info("Vertex succeeded: " + name); - synchronized(sourcesWaitingForEvents) { - sourcesWaitingForEvents.remove(name); + // Get a deterministic count of number of tasks for the vertex. + AtomicInteger prevVal = numExpectedEventsPerSource.get(name); + int prevValInt = prevVal.get(); + Preconditions.checkState(prevValInt < 0, + "Invalid value for numExpectedEvents for source: " + name + ", oldVal=" + prevValInt); + prevVal.set((-1) * prevValInt * context.getVertexNumTasks(name)); + checkForSourceCompletion(name); + } + } - if (sourcesWaitingForEvents.isEmpty()) { - // we've got what we need; mark the queue - queue.offer(endOfEvents); - } else { - LOG.info("Waiting for " + sourcesWaitingForEvents.size() + " events."); + private void checkForSourceCompletion(String name) { + int expectedEvents = numExpectedEventsPerSource.get(name).get(); + if (expectedEvents < 0) { + // Expected events not updated yet - vertex SUCCESS notification not received. 
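Note (illustrative, not part of the patch): the expected-event bookkeeping encodes two facts in one AtomicInteger. While the counter is negative it records how many columns were registered for a still-running source; once the vertex succeeds, processVertex() flips it to the positive number of events to wait for. A standalone sketch of that protocol, with hypothetical column and task counts:

```java
import java.util.concurrent.atomic.AtomicInteger;

public class ExpectedEventCounterSketch {
  public static void main(String[] args) {
    AtomicInteger expected = new AtomicInteger(0);

    int numColumns = 2;                    // columns registered during initialize()
    for (int i = 0; i < numColumns; i++) {
      expected.decrementAndGet();          // expected == -2: vertex not finished yet
    }

    int numTasks = 3;                      // reported by the context once the vertex succeeds
    int prev = expected.get();
    expected.set((-1) * prev * numTasks);  // expected == 6: one event per (column, task)

    System.out.println("events to wait for: " + expected.get());
  }
}
```

The negative encoding lets checkForSourceCompletion() distinguish "task count not known yet" from a genuine expectation of zero or more events, since the column count is known at construction time but the task count only arrives with the vertex SUCCESS notification.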
+ return; + } else { + int processedEvents = numEventsSeenPerSource.get(name).get(); + if (processedEvents == expectedEvents) { + sourcesWaitingForEvents.remove(name); + if (sourcesWaitingForEvents.isEmpty()) { + // we've got what we need; mark the queue + queue.offer(endOfEvents); + } else { + LOG.info("Waiting for " + sourcesWaitingForEvents.size() + " sources."); + } + } else if (processedEvents > expectedEvents) { + throw new IllegalStateException( + "Received too many events for " + name + ", Expected=" + expectedEvents + + ", Received=" + processedEvents); } + return; } } } diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java index a6f4b55..0e3bdc6 100644 --- ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java +++ ql/src/java/org/apache/hadoop/hive/ql/exec/tez/HiveSplitGenerator.java @@ -33,6 +33,7 @@ import org.apache.hadoop.hive.ql.io.HiveFileFormatUtils; import org.apache.hadoop.hive.ql.plan.MapWork; import org.apache.hadoop.hive.ql.plan.PartitionDesc; +import org.apache.hadoop.hive.serde2.SerDeException; import org.apache.hadoop.hive.shims.ShimLoader; import org.apache.hadoop.mapred.FileSplit; import org.apache.hadoop.mapred.InputFormat; @@ -72,42 +73,57 @@ private static final Log LOG = LogFactory.getLog(HiveSplitGenerator.class); private static final SplitGrouper grouper = new SplitGrouper(); - private final DynamicPartitionPruner pruner = new DynamicPartitionPruner(); - private InputInitializerContext context; private static Map, Map> cache = new HashMap, Map>(); - public HiveSplitGenerator(InputInitializerContext initializerContext) { + private final DynamicPartitionPruner pruner; + private final Configuration conf; + private final JobConf jobConf; + private final MRInputUserPayloadProto userPayloadProto; + + + public HiveSplitGenerator(InputInitializerContext initializerContext) throws IOException, + SerDeException { super(initializerContext); + + + if (initializerContext != null) { + userPayloadProto = + MRInputHelpers.parseMRInputPayload(initializerContext.getInputUserPayload()); + + this.conf = + TezUtils.createConfFromByteString(userPayloadProto.getConfigurationBytes()); + + this.jobConf = new JobConf(conf); + // Read all credentials into the credentials instance stored in JobConf. + ShimLoader.getHadoopShims().getMergedCredentials(jobConf); + + MapWork work = Utilities.getMapWork(jobConf); + + // Setup the pruner with all parameters so that it can initialize itself + pruner = new DynamicPartitionPruner(initializerContext, work, jobConf); + } else { + // This could be invoked directly by the CustomPartitionVertex - in which case it'll send + // in required parameters for split grouping. + this.userPayloadProto = null; + this.conf = null; + this.jobConf = null; + this.pruner = null; + } + } - public HiveSplitGenerator() { + public HiveSplitGenerator() throws IOException, SerDeException { this(null); } @Override public List initialize() throws Exception { - InputInitializerContext rootInputContext = getContext(); - - context = rootInputContext; - - MRInputUserPayloadProto userPayloadProto = - MRInputHelpers.parseMRInputPayload(rootInputContext.getInputUserPayload()); - - Configuration conf = - TezUtils.createConfFromByteString(userPayloadProto.getConfigurationBytes()); - boolean sendSerializedEvents = conf.getBoolean("mapreduce.tez.input.initializer.serialize.event.payload", true); - // Read all credentials into the credentials instance stored in JobConf. 
- JobConf jobConf = new JobConf(conf); - ShimLoader.getHadoopShims().getMergedCredentials(jobConf); - - MapWork work = Utilities.getMapWork(jobConf); - // perform dynamic partition pruning - pruner.prune(work, jobConf, context); + pruner.prune(); InputSplitInfoMem inputSplitInfo = null; String realInputFormatName = conf.get("mapred.input.format.class"); @@ -118,8 +134,8 @@ public HiveSplitGenerator() { (InputFormat) ReflectionUtils.newInstance(JavaUtils.loadClass(realInputFormatName), jobConf); - int totalResource = rootInputContext.getTotalAvailableResource().getMemory(); - int taskResource = rootInputContext.getVertexTaskResource().getMemory(); + int totalResource = getContext().getTotalAvailableResource().getMemory(); + int taskResource = getContext().getVertexTaskResource().getMemory(); int availableSlots = totalResource / taskResource; // Create the un-grouped splits diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/tez/TestDynamicPartitionPruner.java ql/src/test/org/apache/hadoop/hive/ql/exec/tez/TestDynamicPartitionPruner.java new file mode 100644 index 0000000..4490e38 --- /dev/null +++ ql/src/test/org/apache/hadoop/hive/ql/exec/tez/TestDynamicPartitionPruner.java @@ -0,0 +1,532 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec.tez; + +import static org.junit.Assert.assertFalse; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; + +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; +import org.apache.hadoop.hive.ql.plan.MapWork; +import org.apache.hadoop.hive.ql.plan.TableDesc; +import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.mapred.JobConf; +import org.apache.tez.runtime.api.InputInitializerContext; +import org.apache.tez.runtime.api.events.InputInitializerEvent; +import org.junit.Test; + +public class TestDynamicPartitionPruner { + + @Test(timeout = 5000) + public void testNoPruning() throws InterruptedException, IOException, HiveException, + SerDeException { + InputInitializerContext mockInitContext = mock(InputInitializerContext.class); + MapWork mapWork = mock(MapWork.class); + DynamicPartitionPruner pruner = + new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork); + + PruneRunnable pruneRunnable = new PruneRunnable(pruner); + Thread t = new Thread(pruneRunnable); + t.start(); + try { + pruneRunnable.start(); + pruneRunnable.awaitEnd(); + // Return immediately. No entries found for pruning. Verified via the timeout. 
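Note (illustrative, not part of the patch): in the HiveSplitGenerator.initialize() changes above, the number of available slots is derived from the vertex resources before the splits are grouped. A tiny sketch of that arithmetic, with made-up memory figures:

```java
public class SlotsSketch {
  public static void main(String[] args) {
    int totalResourceMb = 16384; // total memory available to the vertex (hypothetical)
    int taskResourceMb = 1024;   // memory requested per task (hypothetical)
    // Same arithmetic as initialize(): how many tasks can run concurrently.
    int availableSlots = totalResourceMb / taskResourceMb;
    System.out.println("available slots: " + availableSlots); // 16
  }
}
```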
+ } finally { + t.interrupt(); + t.join(); + } + } + + @Test(timeout = 5000) + public void testSingleSourceOrdering1() throws InterruptedException, IOException, HiveException, + SerDeException { + InputInitializerContext mockInitContext = mock(InputInitializerContext.class); + doReturn(1).when(mockInitContext).getVertexNumTasks("v1"); + + MapWork mapWork = createMockMapWork(new TestSource("v1", 1)); + DynamicPartitionPruner pruner = + new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork); + + + PruneRunnable pruneRunnable = new PruneRunnable(pruner); + Thread t = new Thread(pruneRunnable); + t.start(); + try { + pruneRunnable.start(); + + InputInitializerEvent event = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + event.setSourceVertexName("v1"); + + pruner.addEvent(event); + pruner.processVertex("v1"); + + pruneRunnable.awaitEnd(); + assertFalse(pruneRunnable.inError.get()); + } finally { + t.interrupt(); + t.join(); + } + } + + @Test(timeout = 5000) + public void testSingleSourceOrdering2() throws InterruptedException, IOException, HiveException, + SerDeException { + InputInitializerContext mockInitContext = mock(InputInitializerContext.class); + doReturn(1).when(mockInitContext).getVertexNumTasks("v1"); + + MapWork mapWork = createMockMapWork(new TestSource("v1", 1)); + DynamicPartitionPruner pruner = + new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork); + + + PruneRunnable pruneRunnable = new PruneRunnable(pruner); + Thread t = new Thread(pruneRunnable); + t.start(); + try { + pruneRunnable.start(); + + InputInitializerEvent event = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + event.setSourceVertexName("v1"); + + pruner.processVertex("v1"); + pruner.addEvent(event); + + pruneRunnable.awaitEnd(); + assertFalse(pruneRunnable.inError.get()); + } finally { + t.interrupt(); + t.join(); + } + } + + @Test(timeout = 5000) + public void testSingleSourceMultipleFiltersOrdering1() throws InterruptedException, SerDeException { + InputInitializerContext mockInitContext = mock(InputInitializerContext.class); + doReturn(2).when(mockInitContext).getVertexNumTasks("v1"); + + MapWork mapWork = createMockMapWork(new TestSource("v1", 2)); + DynamicPartitionPruner pruner = + new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork); + + PruneRunnable pruneRunnable = new PruneRunnable(pruner); + Thread t = new Thread(pruneRunnable); + t.start(); + try { + pruneRunnable.start(); + + InputInitializerEvent event = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + event.setSourceVertexName("v1"); + + pruner.addEvent(event); + pruner.addEvent(event); + pruner.addEvent(event); + pruner.addEvent(event); + pruner.processVertex("v1"); + + pruneRunnable.awaitEnd(); + assertFalse(pruneRunnable.inError.get()); + } finally { + t.interrupt(); + t.join(); + } + } + + @Test(timeout = 5000) + public void testSingleSourceMultipleFiltersOrdering2() throws InterruptedException, SerDeException { + InputInitializerContext mockInitContext = mock(InputInitializerContext.class); + doReturn(2).when(mockInitContext).getVertexNumTasks("v1"); + + MapWork mapWork = createMockMapWork(new TestSource("v1", 2)); + DynamicPartitionPruner pruner = + new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork); + + PruneRunnable pruneRunnable = new PruneRunnable(pruner); + Thread t = new Thread(pruneRunnable); + t.start(); + try { + pruneRunnable.start(); + + 
InputInitializerEvent event = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + event.setSourceVertexName("v1"); + + pruner.processVertex("v1"); + pruner.addEvent(event); + pruner.addEvent(event); + pruner.addEvent(event); + pruner.addEvent(event); + + pruneRunnable.awaitEnd(); + assertFalse(pruneRunnable.inError.get()); + } finally { + t.interrupt(); + t.join(); + } + } + + @Test(timeout = 5000) + public void testMultipleSourcesOrdering1() throws InterruptedException, SerDeException { + InputInitializerContext mockInitContext = mock(InputInitializerContext.class); + doReturn(2).when(mockInitContext).getVertexNumTasks("v1"); + doReturn(3).when(mockInitContext).getVertexNumTasks("v2"); + + MapWork mapWork = createMockMapWork(new TestSource("v1", 2), new TestSource("v2", 1)); + DynamicPartitionPruner pruner = + new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork); + + PruneRunnable pruneRunnable = new PruneRunnable(pruner); + Thread t = new Thread(pruneRunnable); + t.start(); + try { + pruneRunnable.start(); + + InputInitializerEvent eventV1 = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + eventV1.setSourceVertexName("v1"); + + InputInitializerEvent eventV2 = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + eventV2.setSourceVertexName("v2"); + + // 2 X 2 events for V1. 3 X 1 events for V2 + + pruner.addEvent(eventV1); + pruner.addEvent(eventV1); + pruner.addEvent(eventV1); + pruner.addEvent(eventV1); + pruner.addEvent(eventV2); + pruner.addEvent(eventV2); + pruner.addEvent(eventV2); + pruner.processVertex("v1"); + pruner.processVertex("v2"); + + pruneRunnable.awaitEnd(); + assertFalse(pruneRunnable.inError.get()); + } finally { + t.interrupt(); + t.join(); + } + } + + @Test(timeout = 5000) + public void testMultipleSourcesOrdering2() throws InterruptedException, SerDeException { + InputInitializerContext mockInitContext = mock(InputInitializerContext.class); + doReturn(2).when(mockInitContext).getVertexNumTasks("v1"); + doReturn(3).when(mockInitContext).getVertexNumTasks("v2"); + + MapWork mapWork = createMockMapWork(new TestSource("v1", 2), new TestSource("v2", 1)); + DynamicPartitionPruner pruner = + new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork); + + PruneRunnable pruneRunnable = new PruneRunnable(pruner); + Thread t = new Thread(pruneRunnable); + t.start(); + try { + pruneRunnable.start(); + + InputInitializerEvent eventV1 = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + eventV1.setSourceVertexName("v1"); + + InputInitializerEvent eventV2 = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + eventV2.setSourceVertexName("v2"); + + // 2 X 2 events for V1. 
3 X 1 events for V2 + + pruner.processVertex("v1"); + pruner.processVertex("v2"); + pruner.addEvent(eventV1); + pruner.addEvent(eventV1); + pruner.addEvent(eventV1); + pruner.addEvent(eventV1); + pruner.addEvent(eventV2); + pruner.addEvent(eventV2); + pruner.addEvent(eventV2); + + pruneRunnable.awaitEnd(); + assertFalse(pruneRunnable.inError.get()); + } finally { + t.interrupt(); + t.join(); + } + } + + @Test(timeout = 5000) + public void testMultipleSourcesOrdering3() throws InterruptedException, SerDeException { + InputInitializerContext mockInitContext = mock(InputInitializerContext.class); + doReturn(2).when(mockInitContext).getVertexNumTasks("v1"); + doReturn(3).when(mockInitContext).getVertexNumTasks("v2"); + + MapWork mapWork = createMockMapWork(new TestSource("v1", 2), new TestSource("v2", 1)); + DynamicPartitionPruner pruner = + new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork); + + PruneRunnable pruneRunnable = new PruneRunnable(pruner); + Thread t = new Thread(pruneRunnable); + t.start(); + try { + pruneRunnable.start(); + + InputInitializerEvent eventV1 = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + eventV1.setSourceVertexName("v1"); + + InputInitializerEvent eventV2 = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + eventV2.setSourceVertexName("v2"); + + // 2 X 2 events for V1. 3 X 1 events for V2 + pruner.addEvent(eventV1); + pruner.addEvent(eventV1); + pruner.processVertex("v1"); + pruner.addEvent(eventV1); + pruner.addEvent(eventV1); + pruner.addEvent(eventV2); + pruner.processVertex("v2"); + pruner.addEvent(eventV2); + pruner.addEvent(eventV2); + + pruneRunnable.awaitEnd(); + assertFalse(pruneRunnable.inError.get()); + } finally { + t.interrupt(); + t.join(); + } + } + + @Test(timeout = 5000, expected = IllegalStateException.class) + public void testExtraEvents() throws InterruptedException, IOException, HiveException, + SerDeException { + InputInitializerContext mockInitContext = mock(InputInitializerContext.class); + doReturn(1).when(mockInitContext).getVertexNumTasks("v1"); + + MapWork mapWork = createMockMapWork(new TestSource("v1", 1)); + DynamicPartitionPruner pruner = + new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork); + + + PruneRunnable pruneRunnable = new PruneRunnable(pruner); + Thread t = new Thread(pruneRunnable); + t.start(); + try { + pruneRunnable.start(); + + InputInitializerEvent event = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + event.setSourceVertexName("v1"); + + pruner.addEvent(event); + pruner.addEvent(event); + pruner.processVertex("v1"); + + pruneRunnable.awaitEnd(); + assertFalse(pruneRunnable.inError.get()); + } finally { + t.interrupt(); + t.join(); + } + } + + @Test(timeout = 20000) + public void testMissingEvent() throws InterruptedException, IOException, HiveException, + SerDeException { + InputInitializerContext mockInitContext = mock(InputInitializerContext.class); + doReturn(1).when(mockInitContext).getVertexNumTasks("v1"); + + MapWork mapWork = createMockMapWork(new TestSource("v1", 1)); + DynamicPartitionPruner pruner = + new DynamicPartitionPrunerForEventTesting(mockInitContext, mapWork); + + + PruneRunnable pruneRunnable = new PruneRunnable(pruner); + Thread t = new Thread(pruneRunnable); + t.start(); + try { + pruneRunnable.start(); + + InputInitializerEvent event = + InputInitializerEvent.create("FakeTarget", "TargetInput", ByteBuffer.allocate(0)); + 
event.setSourceVertexName("v1"); + + pruner.processVertex("v1"); + Thread.sleep(3000l); + // The pruner should not have completed. + assertFalse(pruneRunnable.ended.get()); + assertFalse(pruneRunnable.inError.get()); + } finally { + t.interrupt(); + t.join(); + } + } + + private static class PruneRunnable implements Runnable { + + final DynamicPartitionPruner pruner; + final ReentrantLock lock = new ReentrantLock(); + final Condition endCondition = lock.newCondition(); + final Condition startCondition = lock.newCondition(); + final AtomicBoolean started = new AtomicBoolean(false); + final AtomicBoolean ended = new AtomicBoolean(false); + final AtomicBoolean inError = new AtomicBoolean(false); + + private PruneRunnable(DynamicPartitionPruner pruner) { + this.pruner = pruner; + } + + void start() { + started.set(true); + lock.lock(); + try { + startCondition.signal(); + } finally { + lock.unlock(); + } + } + + void awaitEnd() throws InterruptedException { + lock.lock(); + try { + while (!ended.get()) { + endCondition.await(); + } + } finally { + lock.unlock(); + } + } + + @Override + public void run() { + try { + lock.lock(); + try { + while (!started.get()) { + startCondition.await(); + } + } finally { + lock.unlock(); + } + + pruner.prune(); + lock.lock(); + try { + ended.set(true); + endCondition.signal(); + } finally { + lock.unlock(); + } + } catch (SerDeException | IOException | InterruptedException | HiveException e) { + inError.set(true); + } + } + } + + + private MapWork createMockMapWork(TestSource... testSources) { + MapWork mapWork = mock(MapWork.class); + + Map> tableMap = new HashMap<>(); + Map> columnMap = new HashMap<>(); + Map> exprMap = new HashMap<>(); + + int count = 0; + for (TestSource testSource : testSources) { + + for (int i = 0; i < testSource.numExpressions; i++) { + List tableDescList = tableMap.get(testSource.vertexName); + if (tableDescList == null) { + tableDescList = new LinkedList<>(); + tableMap.put(testSource.vertexName, tableDescList); + } + tableDescList.add(mock(TableDesc.class)); + + List columnList = columnMap.get(testSource.vertexName); + if (columnList == null) { + columnList = new LinkedList<>(); + columnMap.put(testSource.vertexName, columnList); + } + columnList.add(testSource.vertexName + "c_" + count + "_" + i); + + List exprNodeDescList = exprMap.get(testSource.vertexName); + if (exprNodeDescList == null) { + exprNodeDescList = new LinkedList<>(); + exprMap.put(testSource.vertexName, exprNodeDescList); + } + exprNodeDescList.add(mock(ExprNodeDesc.class)); + } + + count++; + } + + doReturn(tableMap).when(mapWork).getEventSourceTableDescMap(); + doReturn(columnMap).when(mapWork).getEventSourceColumnNameMap(); + doReturn(exprMap).when(mapWork).getEventSourcePartKeyExprMap(); + return mapWork; + } + + private static class TestSource { + String vertexName; + int numExpressions; + + public TestSource(String vertexName, int numExpressions) { + this.vertexName = vertexName; + this.numExpressions = numExpressions; + } + } + + private static class DynamicPartitionPrunerForEventTesting extends DynamicPartitionPruner { + + + public DynamicPartitionPrunerForEventTesting( + InputInitializerContext context, MapWork work) throws SerDeException { + super(context, work, new JobConf()); + } + + @Override + protected SourceInfo createSourceInfo(TableDesc t, ExprNodeDesc partKeyExpr, String columnName, + JobConf jobConf) throws + SerDeException { + return new SourceInfo(t, partKeyExpr, columnName, jobConf, null); + } + + @Override + protected String 
processPayload(ByteBuffer payload, String sourceName) throws SerDeException,
+        IOException {
+      // No-op: testing events only
+      return sourceName;
+    }
+
+    @Override
+    protected void prunePartitionSingleSource(String source, SourceInfo si, MapWork work)
+        throws HiveException {
+      // No-op: testing events only
+    }
+  }
+}
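Note (illustrative, not part of the patch): PruneRunnable hand-rolls its start/end handshake with a ReentrantLock and two Conditions. The same handshake can be expressed with CountDownLatches, shown here purely as a design comparison; the pruner call site is a placeholder and the lock/condition version in the patch behaves equivalently.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;

public class LatchHarnessSketch implements Runnable {
  private final CountDownLatch startLatch = new CountDownLatch(1);
  private final CountDownLatch endLatch = new CountDownLatch(1);
  final AtomicBoolean inError = new AtomicBoolean(false);

  void start() { startLatch.countDown(); }          // release the worker thread

  void awaitEnd() throws InterruptedException {     // block until run() finishes
    endLatch.await();
  }

  @Override
  public void run() {
    try {
      startLatch.await();
      // pruner.prune() would run here in the real harness
    } catch (InterruptedException e) {
      inError.set(true);
    } finally {
      endLatch.countDown();
    }
  }
}
```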