diff --git a/ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java b/ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java index 9d46cac..9ca5544 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/plan/SparkWork.java @@ -30,6 +30,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; @@ -47,7 +48,8 @@ @Explain(displayName = "Spark", explainLevels = { Level.USER, Level.DEFAULT, Level.EXTENDED }, vectorization = Vectorization.SUMMARY_PATH) public class SparkWork extends AbstractOperatorDesc { - private static int counter; + + private static final AtomicInteger counter = new AtomicInteger(1); private final String name; private final Set roots = new LinkedHashSet(); @@ -65,7 +67,7 @@ private Map cloneToWork; public SparkWork(String name) { - this.name = name + ":" + (++counter); + this.name = name + ":" + counter.getAndIncrement(); cloneToWork = new HashMap(); } diff --git a/ql/src/test/org/apache/hadoop/hive/ql/plan/TestExecutionEngineWorkConcurrency.java b/ql/src/test/org/apache/hadoop/hive/ql/plan/TestExecutionEngineWorkConcurrency.java new file mode 100644 index 0000000..a7fcad0 --- /dev/null +++ b/ql/src/test/org/apache/hadoop/hive/ql/plan/TestExecutionEngineWorkConcurrency.java @@ -0,0 +1,119 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hive.ql.plan; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Set; +import java.util.concurrent.Callable; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.FutureTask; + +import static org.junit.Assert.assertEquals; + + +@RunWith(Parameterized.class) +public final class TestExecutionEngineWorkConcurrency { + + @Parameterized.Parameters + public static Collection data() { + return Arrays.asList(new Object[][]{{new TezDagIdProvider()}, {new SparkDagIdProvider()}}); + } + + private final ExecutionEngineDagIdGenerator executionEngineDagIdGenerator; + + public TestExecutionEngineWorkConcurrency(ExecutionEngineDagIdGenerator executionEngineDagIdGenerator) { + this.executionEngineDagIdGenerator = executionEngineDagIdGenerator; + } + + @Test + public void ensureDagIdIsUnique() throws Exception { + final int threadCount = 5; + final CountDownLatch threadReadyToStartSignal = new CountDownLatch(threadCount); + final CountDownLatch startThreadSignal = new CountDownLatch(1); + final int numberOfWorkToCreatePerThread = 100; + + List>> tasks = Lists.newArrayList(); + for (int i = 0; i < threadCount; i++) { + tasks.add(new FutureTask<>(new Callable>() { + @Override + public Set call() throws Exception { + threadReadyToStartSignal.countDown(); + startThreadSignal.await(); + return generateWorkDagIds(numberOfWorkToCreatePerThread); + } + })); + } + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + for (FutureTask> task : tasks) { + executor.execute(task); + } + threadReadyToStartSignal.await(); + startThreadSignal.countDown(); + Set allWorkDagIds = getAllWorkDagIds(tasks); + assertEquals(threadCount * numberOfWorkToCreatePerThread, allWorkDagIds.size()); + } + + private Set generateWorkDagIds(int numberOfNames) { + Set workIds = Sets.newHashSet(); + for (int i = 0; i < numberOfNames; i++) { + workIds.add(executionEngineDagIdGenerator.getDagId()); + } + return workIds; + } + + private static Set getAllWorkDagIds(List>> tasks) + throws ExecutionException, InterruptedException { + Set allWorkDagIds = Sets.newHashSet(); + for (FutureTask> task : tasks) { + allWorkDagIds.addAll(task.get()); + } + return allWorkDagIds; + } + + private interface ExecutionEngineDagIdGenerator { + String getDagId(); + } + + private static final class TezDagIdProvider implements ExecutionEngineDagIdGenerator { + + @Override + public String getDagId() { + return new TezWork("query-id").getDagId(); + } + } + + private static final class SparkDagIdProvider implements ExecutionEngineDagIdGenerator { + + @Override + public String getDagId() { + return new SparkWork("query-id").getName(); + } + } +} diff --git a/ql/src/test/org/apache/hadoop/hive/ql/plan/TestTezWorkConcurrency.java b/ql/src/test/org/apache/hadoop/hive/ql/plan/TestTezWorkConcurrency.java deleted file mode 100644 index 9af1c1b..0000000 --- a/ql/src/test/org/apache/hadoop/hive/ql/plan/TestTezWorkConcurrency.java +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.hadoop.hive.ql.plan; - -import com.google.common.collect.Lists; -import com.google.common.collect.Sets; -import org.junit.Test; - -import java.util.List; -import java.util.Set; -import java.util.concurrent.Callable; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.FutureTask; - -import static org.junit.Assert.assertEquals; - -public final class TestTezWorkConcurrency { - - @Test - public void ensureDagIdIsUnique() throws Exception { - final int threadCount = 5; - final CountDownLatch threadReadyToStartSignal = new CountDownLatch(threadCount); - final CountDownLatch startThreadSignal = new CountDownLatch(1); - final int numberOfTezWorkToCreatePerThread = 100; - - List>> tasks = Lists.newArrayList(); - for (int i = 0; i < threadCount; i++) { - tasks.add(new FutureTask<>(new Callable>() { - @Override - public Set call() throws Exception { - threadReadyToStartSignal.countDown(); - startThreadSignal.await(); - return generateTezWorkDagIds(numberOfTezWorkToCreatePerThread); - } - })); - } - ExecutorService executor = Executors.newFixedThreadPool(threadCount); - for (FutureTask> task : tasks) { - executor.execute(task); - } - threadReadyToStartSignal.await(); - startThreadSignal.countDown(); - Set allTezWorkDagIds = getAllTezWorkDagIds(tasks); - assertEquals(threadCount * numberOfTezWorkToCreatePerThread, allTezWorkDagIds.size()); - } - - private static Set generateTezWorkDagIds(int numberOfNames) { - Set tezWorkIds = Sets.newHashSet(); - for (int i = 0; i < numberOfNames; i++) { - TezWork work = new TezWork("query-id"); - tezWorkIds.add(work.getDagId()); - } - return tezWorkIds; - } - - private static Set getAllTezWorkDagIds(List>> tasks) - throws ExecutionException, InterruptedException { - Set allTezWorkDagIds = Sets.newHashSet(); - for (FutureTask> task : tasks) { - allTezWorkDagIds.addAll(task.get()); - } - return allTezWorkDagIds; - } -}