diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkUtilities.java ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkUtilities.java
index d384ed6db6..fafae31850 100644
--- ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkUtilities.java
+++ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/SparkUtilities.java
@@ -7,7 +7,7 @@
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
  *
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
@@ -122,21 +122,25 @@ public static boolean isDedicatedCluster(Configuration conf) {
 
   public static SparkSession getSparkSession(HiveConf conf,
       SparkSessionManager sparkSessionManager) throws HiveException {
-    SparkSession sparkSession = SessionState.get().getSparkSession();
-    HiveConf sessionConf = SessionState.get().getConf();
-
-    // Spark configurations are updated close the existing session
-    // In case of async queries or confOverlay is not empty,
-    // sessionConf and conf are different objects
-    if (sessionConf.getSparkConfigUpdated() || conf.getSparkConfigUpdated()) {
-      sparkSessionManager.closeSession(sparkSession);
-      sparkSession = null;
-      conf.setSparkConfigUpdated(false);
-      sessionConf.setSparkConfigUpdated(false);
+    SessionState sessionState = SessionState.get();
+    synchronized (sessionState) {
+      SparkSession sparkSession = sessionState.getSparkSession();
+      HiveConf sessionConf = sessionState.getConf();
+
+      // Spark configurations are updated, close the existing session
+      // In case of async queries or when confOverlay is not empty,
+      // sessionConf and conf are different objects
+      if (sessionConf.getSparkConfigUpdated() || conf.getSparkConfigUpdated()) {
+        sparkSessionManager.closeSession(sparkSession);
+        sparkSession = null;
+        conf.setSparkConfigUpdated(false);
+        sessionConf.setSparkConfigUpdated(false);
+      }
+      sparkSession = sparkSessionManager.getSession(sparkSession, conf, true);
+      sessionState.setSparkSession(sparkSession);
+      return sparkSession;
     }
-    sparkSession = sparkSessionManager.getSession(sparkSession, conf, true);
-    SessionState.get().setSparkSession(sparkSession);
-    return sparkSession;
   }
 
   /**
diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/spark/TestSparkUtilities.java ql/src/test/org/apache/hadoop/hive/ql/exec/spark/TestSparkUtilities.java
new file mode 100644
index 0000000000..f797f309dd
--- /dev/null
+++ ql/src/test/org/apache/hadoop/hive/ql/exec/spark/TestSparkUtilities.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hive.ql.exec.spark;
+
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.spark.session.SparkSession;
+import org.apache.hadoop.hive.ql.exec.spark.session.SparkSessionManager;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.session.SessionState;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for the SparkUtilities class.
+ */
+public class TestSparkUtilities {
+
+  @Test
+  public void testGetSparkSessionUsingMultipleThreadsWithTheSameSession() throws HiveException, InterruptedException {
+
+    // The only real state required from SessionState
+    final AtomicReference<SparkSession> activeSparkSession = new AtomicReference<>();
+
+    // Mocks
+    HiveConf mockConf = mock(HiveConf.class);
+
+    SparkSessionManager mockSessionManager = mock(SparkSessionManager.class);
+    doAnswer(invocationOnMock -> {
+      SparkSession sparkSession = invocationOnMock.getArgumentAt(0, SparkSession.class);
+      if (sparkSession == null) {
+        return mock(SparkSession.class);
+      } else {
+        return sparkSession;
+      }
+    }).when(mockSessionManager).getSession(any(SparkSession.class), eq(mockConf), eq(true));
+
+    SessionState mockSessionState = mock(SessionState.class);
+    when(mockSessionState.getConf()).thenReturn(mockConf);
+    doAnswer(invocationOnMock -> {
+      activeSparkSession.set(invocationOnMock.getArgumentAt(0, SparkSession.class));
+      return null;
+    }).when(mockSessionState).setSparkSession(any(SparkSession.class));
+    doAnswer(invocationOnMock ->
+      activeSparkSession.get()
+    ).when(mockSessionState).getSparkSession();
+
+    // When
+    List<Callable<SparkSession>> callables = new ArrayList<>();
+    callables.add(new GetSparkSessionTester(mockConf, mockSessionManager, mockSessionState));
+    callables.add(new GetSparkSessionTester(mockConf, mockSessionManager, mockSessionState));
+    callables.add(new GetSparkSessionTester(mockConf, mockSessionManager, mockSessionState));
+
+    ExecutorService executorService = Executors.newFixedThreadPool(callables.size());
+    List<Future<SparkSession>> results = executorService.invokeAll(callables);
+
+    // Then
+    results.stream().map(f -> resolve(f)).forEach(ss -> assertEquals(ss, activeSparkSession.get()));
+
+  }
+
+  private SparkSession resolve(Future<SparkSession> future) {
+    try {
+      return future.get();
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  private class GetSparkSessionTester implements Callable<SparkSession> {
+
+    private HiveConf hiveConf;
+    private SparkSessionManager sparkSessionManager;
+    private SessionState sessionState;
+
+    GetSparkSessionTester(HiveConf hiveConf, SparkSessionManager sparkSessionManager,
+        SessionState sessionState) {
+      this.hiveConf = hiveConf;
+      this.sparkSessionManager = sparkSessionManager;
+      this.sessionState = sessionState;
+    }
+
+    @Override
+    public SparkSession call() throws Exception {
+      SessionState.setCurrentSessionState(sessionState);
+      return SparkUtilities.getSparkSession(hiveConf, sparkSessionManager);
+    }
+  }
+}