diff --git a/ql/src/java/org/apache/hadoop/hive/ql/processors/ResetProcessor.java b/ql/src/java/org/apache/hadoop/hive/ql/processors/ResetProcessor.java index bbd4501..b40879d 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/processors/ResetProcessor.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/processors/ResetProcessor.java @@ -23,7 +23,11 @@ import java.util.List; import java.util.Map; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Lists; + import org.apache.commons.lang3.StringUtils; + import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.conf.HiveVariableSource; import org.apache.hadoop.hive.conf.SystemVariables; @@ -33,7 +37,6 @@ import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType; import org.apache.hadoop.hive.ql.session.SessionState; -import com.google.common.collect.Lists; public class ResetProcessor implements CommandProcessor { @@ -45,8 +48,11 @@ public void init() { @Override public CommandProcessorResponse run(String command) throws CommandNeedRetryException { - SessionState ss = SessionState.get(); + return run(SessionState.get(), command); + } + @VisibleForTesting + CommandProcessorResponse run(SessionState ss, String command) throws CommandNeedRetryException { CommandProcessorResponse authErrResp = CommandUtil.authorizeCommand(ss, HiveOperationType.RESET, Arrays.asList(command)); if (authErrResp != null) { @@ -88,7 +94,7 @@ public CommandProcessorResponse run(String command) throws CommandNeedRetryExcep ? Lists.newArrayList("Resetting " + message + " to default values") : null); } - private void resetOverridesOnly(SessionState ss) { + private static void resetOverridesOnly(SessionState ss) { if (ss.getOverriddenConfigurations().isEmpty()) return; HiveConf conf = new HiveConf(); for (String key : ss.getOverriddenConfigurations().keySet()) { @@ -97,21 +103,20 @@ private void resetOverridesOnly(SessionState ss) { ss.getOverriddenConfigurations().clear(); } - private void resetOverrideOnly(SessionState ss, String varname) { + private static void resetOverrideOnly(SessionState ss, String varname) { if (!ss.getOverriddenConfigurations().containsKey(varname)) return; setSessionVariableFromConf(ss, varname, new HiveConf()); ss.getOverriddenConfigurations().remove(varname); } - private void setSessionVariableFromConf(SessionState ss, String varname, - HiveConf conf) { + private static void setSessionVariableFromConf(SessionState ss, String varname, HiveConf conf) { String value = conf.get(varname); if (value != null) { - ss.getConf().set(varname, value); + SetProcessor.setConf(ss, varname, varname, value, false); } } - private CommandProcessorResponse resetToDefault(SessionState ss, String varname) { + private static CommandProcessorResponse resetToDefault(SessionState ss, String varname) { varname = varname.trim(); try { String nonErrorMessage = null; diff --git a/ql/src/java/org/apache/hadoop/hive/ql/processors/SetProcessor.java b/ql/src/java/org/apache/hadoop/hive/ql/processors/SetProcessor.java index 0ffa182..1458211 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/processors/SetProcessor.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/processors/SetProcessor.java @@ -209,17 +209,22 @@ public static CommandProcessorResponse setVariable( : new CommandProcessorResponse(0, Lists.newArrayList(nonErrorMessage)); } + static String setConf(String varname, String key, String varvalue, boolean register) + throws IllegalArgumentException { + return setConf(SessionState.get(), varname, key, varvalue, register); + } + /** * @return A console message that is not strong enough to fail the command (e.g. deprecation). */ - static String setConf(String varname, String key, String varvalue, boolean register) + static String setConf(SessionState ss, String varname, String key, String varvalue, boolean register) throws IllegalArgumentException { String result = null; - HiveConf conf = SessionState.get().getConf(); + HiveConf conf = ss.getConf(); String value = new VariableSubstitution(new HiveVariableSource() { @Override public Map getHiveVariable() { - return SessionState.get().getHiveVariables(); + return ss.getHiveVariables(); } }).substitute(conf, varvalue); if (conf.getBoolVar(HiveConf.ConfVars.HIVECONFVALIDATION)) { @@ -246,7 +251,7 @@ static String setConf(String varname, String key, String varvalue, boolean regis conf.verifyAndSet(key, value); if (HiveConf.ConfVars.HIVE_EXECUTION_ENGINE.varname.equals(key)) { if (!"spark".equals(value)) { - SessionState.get().closeSparkSession(); + ss.closeSparkSession(); } if ("mr".equals(value)) { result = HiveConf.generateMrDeprecationWarning(); @@ -254,7 +259,7 @@ static String setConf(String varname, String key, String varvalue, boolean regis } } if (register) { - SessionState.get().getOverriddenConfigurations().put(key, value); + ss.getOverriddenConfigurations().put(key, value); } return result; } diff --git a/ql/src/test/org/apache/hadoop/hive/ql/processors/TestResetProcessor.java b/ql/src/test/org/apache/hadoop/hive/ql/processors/TestResetProcessor.java new file mode 100644 index 0000000..509bf88 --- /dev/null +++ b/ql/src/test/org/apache/hadoop/hive/ql/processors/TestResetProcessor.java @@ -0,0 +1,59 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.processors; + +import java.util.HashMap; +import java.util.Map; + +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.ql.CommandNeedRetryException; +import org.apache.hadoop.hive.ql.session.SessionState; + +import org.junit.Test; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + + +public class TestResetProcessor { + + @Test + public void testResetClosesSparkSession() throws CommandNeedRetryException { + SessionState mockSessionState = createMockSparkSessionState(); + new ResetProcessor().run(mockSessionState, ""); + verify(mockSessionState).closeSparkSession(); + } + + @Test + public void testResetExecutionEngineClosesSparkSession() throws CommandNeedRetryException { + SessionState mockSessionState = createMockSparkSessionState(); + new ResetProcessor().run(mockSessionState, HiveConf.ConfVars.HIVE_EXECUTION_ENGINE.varname); + verify(mockSessionState).closeSparkSession(); + } + + private static SessionState createMockSparkSessionState() { + SessionState mockSessionState = mock(SessionState.class); + Map overriddenConfigurations = new HashMap<>(); + overriddenConfigurations.put(HiveConf.ConfVars.HIVE_EXECUTION_ENGINE.varname, "spark"); + when(mockSessionState.getOverriddenConfigurations()).thenReturn(overriddenConfigurations); + when(mockSessionState.getConf()).thenReturn(new HiveConf()); + return mockSessionState; + } +}