diff --git a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java index a3c853a..ec341f5 100644 --- a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java +++ b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java @@ -2471,6 +2471,15 @@ private static void populateLlapDaemonVarsSet(Set llapDaemonVarsSetLocal " PERFORMANCE: Execution + Performance logs \n" + " VERBOSE: All logs" ), + // HS2 guard rails + HIVE_SERVER2_TRACK_CONNECTIONS("hive.server2.track.connections", "user-ip-address", new StringSet("user", + "ip-address", "user-ip-address"), "Used in conjunction with hive.server2.limit.connections to track and limit" + + "the number of HS2 connections based on user name, ip-address or both."), + HIVE_SERVER2_LIMIT_CONNECTIONS("hive.server2.limit.connections", 0, "Used in conjunction with" + + "hive.server2.limit.connections to track and limit the number of HS2 connections based on user name, " + + "ip-address or both. Connections exceeding this limit for any user/ip-address/both will be dropped. " + + "Default: 0 means tracking is disabled."), + // Enable metric collection for HiveServer2 HIVE_SERVER2_METRICS_ENABLED("hive.server2.metrics.enabled", false, "Enable metrics on the HiveServer2."), diff --git a/service/src/java/org/apache/hive/service/cli/session/SessionManager.java b/service/src/java/org/apache/hive/service/cli/session/SessionManager.java index 9b2ae57..bcf29d5 100644 --- a/service/src/java/org/apache/hive/service/cli/session/SessionManager.java +++ b/service/src/java/org/apache/hive/service/cli/session/SessionManager.java @@ -63,11 +63,21 @@ */ public class SessionManager extends CompositeService { + enum TrackComponent { + USER, + IP_ADDRESS, + USER_IP_ADDRESS, + DISABLE + } + public static final String HIVERCFILE = ".hiverc"; private static final Logger LOG = LoggerFactory.getLogger(CompositeService.class); private HiveConf hiveConf; private final Map handleToSession = new ConcurrentHashMap(); + private final Map connectionsCount = new ConcurrentHashMap<>(); + private TrackComponent trackComponent; + private int connectionLimit; private final OperationManager operationManager = new OperationManager(); private ThreadPoolExecutor backgroundOperationPool; private boolean isOperationLogEnabled; @@ -103,6 +113,18 @@ public synchronized void init(HiveConf hiveConf) { registerOpenSesssionMetrics(metrics); registerActiveSesssionMetrics(metrics); } + + String trackConnectionsStr = hiveConf.getVar(ConfVars.HIVE_SERVER2_TRACK_CONNECTIONS); + trackConnectionsStr = trackConnectionsStr.replaceAll("-", "_").toUpperCase(); + trackComponent = TrackComponent.valueOf(trackConnectionsStr); + + connectionLimit = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_LIMIT_CONNECTIONS); + if (connectionLimit <= 0) { + trackComponent = TrackComponent.DISABLE; + LOG.info("Tracking of connections is disabled as limit is set to 0"); + } else { + LOG.info("Tracking of connections enabled for {} with limit {}", trackComponent, connectionLimit); + } super.init(hiveConf); } @@ -368,6 +390,13 @@ public HiveSession createSession(SessionHandle sessionHandle, TProtocolVersion p String delegationToken) throws HiveSQLException { + if (!trackComponent.equals(TrackComponent.DISABLE)) { + String trackConnectionStr = getConnectionTrackingStr(username, ipAddress); + if (trackConnectionStr != null) { + incrementConnections(trackConnectionStr); + } + } + HiveSession session; // If doAs is set to true for HiveServer2, we will create a proxy object for the session impl. // Within the proxy object, we wrap the method call in a UserGroupInformation#doAs @@ -439,6 +468,70 @@ public HiveSession createSession(SessionHandle sessionHandle, TProtocolVersion p return session; } + private void incrementConnections(final String trackConnectionStr) throws HiveSQLException { + synchronized (connectionsCount) { + if (connectionsCount.containsKey(trackConnectionStr)) { + final int connectionCount = connectionsCount.get(trackConnectionStr) + 1; + if (connectionCount > connectionLimit) { + String msg = "Connection limit exceeded for " + trackConnectionStr + + ". Current value: " + connectionCount + " Limit: " + connectionLimit; + LOG.info(msg); + throw new HiveSQLException(msg); + } + connectionsCount.put(trackConnectionStr, connectionCount); + } else { + connectionsCount.put(trackConnectionStr, 1); + } + } + LOG.info("Incremented #connections open for {} to {}. limit: {}", trackConnectionStr, + connectionsCount.get(trackConnectionStr), connectionLimit); + } + + private void decrementConnections(final HiveSession session) { + final String trackConnectionStr = getConnectionTrackingStr(session.getUserName(), + session.getIpAddress()); + if (trackConnectionStr != null) { + synchronized (connectionsCount) { + if (connectionsCount.containsKey(trackConnectionStr)) { + int newConnCount = connectionsCount.get(trackConnectionStr) - 1; + if (newConnCount <= 0) { + connectionsCount.remove(trackConnectionStr); + } else { + connectionsCount.put(trackConnectionStr, newConnCount); + } + } + LOG.info("Decremented #connections open for {} to {}. limit: {}", trackConnectionStr, + connectionsCount.get(trackConnectionStr), connectionLimit); + } + } + } + + private String getConnectionTrackingStr(final String username, final String ipAddress) { + String trackConnectionStr = null; + switch (trackComponent) { + case USER: + trackConnectionStr = username == null ? null : username.trim(); + break; + case IP_ADDRESS: + trackConnectionStr = ipAddress == null ? null : ipAddress.trim(); + break; + case USER_IP_ADDRESS: + if (username == null || ipAddress == null || username.isEmpty() || ipAddress.isEmpty()) { + trackConnectionStr = null; + } else { + trackConnectionStr = username.trim() + "-" + ipAddress.trim(); + } + break; + } + + if (trackConnectionStr == null || trackConnectionStr.isEmpty()) { + LOG.warn("Skipping to track limits for connections at level: " + trackComponent + " as " + + "username ({})/ip-address ({}) is missing", username, ipAddress); + trackConnectionStr = null; + } + return trackConnectionStr; + } + public synchronized void closeSession(SessionHandle sessionHandle) throws HiveSQLException { HiveSession session = handleToSession.remove(sessionHandle); if (session == null) { @@ -448,6 +541,7 @@ public synchronized void closeSession(SessionHandle sessionHandle) throws HiveSQ try { session.close(); } finally { + decrementConnections(session); // Shutdown HiveServer2 if it has been deregistered from ZooKeeper and has no active sessions if (!(hiveServer2 == null) && (hiveConf.getBoolVar(ConfVars.HIVE_SERVER2_SUPPORT_DYNAMIC_SERVICE_DISCOVERY)) && (hiveServer2.isDeregisteredWithZooKeeper())) { diff --git a/service/src/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java b/service/src/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java index 6354c8c..fc9e6b2 100644 --- a/service/src/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java +++ b/service/src/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java @@ -320,8 +320,6 @@ public TOpenSessionResp OpenSession(TOpenSessionReq req) throws TException { LOG.info("Client protocol version: " + req.getClient_protocol()); TOpenSessionResp resp = new TOpenSessionResp(); try { - Map openConf = req.getConfiguration(); - SessionHandle sessionHandle = getSessionHandle(req, resp); resp.setSessionHandle(sessionHandle.toTSessionHandle()); Map configurationMap = new HashMap(); diff --git a/service/src/test/org/apache/hive/service/cli/TestCLIServiceConnectionLimits.java b/service/src/test/org/apache/hive/service/cli/TestCLIServiceConnectionLimits.java new file mode 100644 index 0000000..b0438b6 --- /dev/null +++ b/service/src/test/org/apache/hive/service/cli/TestCLIServiceConnectionLimits.java @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hive.service.cli; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.hadoop.hive.conf.HiveConf; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +public class TestCLIServiceConnectionLimits { + @org.junit.Rule + public ExpectedException thrown = ExpectedException.none(); + + private int limit = 10; + + @Test + public void testNoLimit() throws HiveSQLException { + CLIService service = getService("user", 0); + List sessionHandles = new ArrayList<>(); + try { + String user = "foo"; + String ipaddress = "127.0.0.1"; + for (int i = 0; i < limit + 1; i++) { + SessionHandle session = service.openSession(CLIService.SERVER_VERSION, user, "bar", ipaddress, null); + sessionHandles.add(session); + } + } finally { + for (SessionHandle sessionHandle : sessionHandles) { + service.closeSession(sessionHandle); + } + service.stop(); + } + } + + @Test + public void testIncrementAndDecrementConnectionsUser() throws HiveSQLException { + CLIService service = getService("user", 10); + List sessionHandles = new ArrayList<>(); + try { + String user = "foo"; + String ipaddress = "127.0.0.1"; + // open 5 connections + for (int i = 0; i < limit / 2; i++) { + SessionHandle session = service.openSession(CLIService.SERVER_VERSION, user, "bar", ipaddress, null); + sessionHandles.add(session); + } + + // close them all + for (SessionHandle sessionHandle : sessionHandles) { + service.closeSession(sessionHandle); + } + sessionHandles.clear(); + + // open till limit but not exceed + for (int i = 0; i < limit; i++) { + SessionHandle session = service.openSession(CLIService.SERVER_VERSION, user, "bar", ipaddress, null); + sessionHandles.add(session); + } + } finally { + for (SessionHandle sessionHandle : sessionHandles) { + service.closeSession(sessionHandle); + } + service.stop(); + } + } + + @Test + public void testInvalidUserName() throws HiveSQLException { + CLIService service = getService("user", 10); + List sessionHandles = new ArrayList<>(); + try { + String user = null; + String ipaddress = "127.0.0.1"; + for (int i = 0; i < limit + 1; i++) { + SessionHandle session = service.openSession(CLIService.SERVER_VERSION, user, "bar", ipaddress, null); + sessionHandles.add(session); + } + + user = ""; + for (int i = 0; i < limit + 1; i++) { + SessionHandle session = service.openSession(CLIService.SERVER_VERSION, user, "bar", ipaddress, null); + sessionHandles.add(session); + } + } finally { + for (SessionHandle sessionHandle : sessionHandles) { + service.closeSession(sessionHandle); + } + service.stop(); + } + } + + @Test + public void testInvalidIpaddress() throws HiveSQLException { + CLIService service = getService("ip-address", 10); + List sessionHandles = new ArrayList<>(); + try { + String user = "foo"; + String ipaddress = null; + for (int i = 0; i < limit + 1; i++) { + SessionHandle session = service.openSession(CLIService.SERVER_VERSION, user, "bar", ipaddress, null); + sessionHandles.add(session); + } + + ipaddress = ""; + for (int i = 0; i < limit + 1; i++) { + SessionHandle session = service.openSession(CLIService.SERVER_VERSION, user, "bar", ipaddress, null); + sessionHandles.add(session); + } + } finally { + for (SessionHandle sessionHandle : sessionHandles) { + service.closeSession(sessionHandle); + } + service.stop(); + } + } + + @Test + public void testInvalidUserIpaddress() throws HiveSQLException { + CLIService service = getService("user-ip-address", 10); + List sessionHandles = new ArrayList<>(); + try { + String user = " "; + String ipaddress = null; + for (int i = 0; i < limit + 1; i++) { + SessionHandle session = service.openSession(CLIService.SERVER_VERSION, user, "bar", ipaddress, null); + sessionHandles.add(session); + } + } finally { + for (SessionHandle sessionHandle : sessionHandles) { + service.closeSession(sessionHandle); + } + service.stop(); + } + } + + @Test + public void testConnectionLimitPerUser() throws HiveSQLException { + String track = "user"; + String user = "foo"; + String ipaddress = "127.0.0.1"; + String errMsg = "Connection limit exceeded for " + user + ". Current value: " + (limit + 1) + " Limit: " + limit; + thrown.expect(HiveSQLException.class); + thrown.expectMessage(errMsg); + + CLIService service = getService(track, 10); + List sessionHandles = new ArrayList<>(); + try { + for (int i = 0; i < limit + 1; i++) { + SessionHandle session = service.openSession(CLIService.SERVER_VERSION, user, "bar", ipaddress, null); + sessionHandles.add(session); + } + + } finally { + for (SessionHandle sessionHandle : sessionHandles) { + service.closeSession(sessionHandle); + } + service.stop(); + } + } + + @Test + public void testConnectionLimitPerIpAddress() throws HiveSQLException { + String track = "ip-address"; + String user = "foo"; + String ipaddress = "127.0.0.1"; + String errMsg = + "Connection limit exceeded for " + ipaddress + ". Current value: " + (limit + 1) + " Limit: " + limit; + thrown.expect(HiveSQLException.class); + thrown.expectMessage(errMsg); + + CLIService service = getService(track, 10); + List sessionHandles = new ArrayList<>(); + try { + for (int i = 0; i < limit + 1; i++) { + SessionHandle session = service.openSession(CLIService.SERVER_VERSION, user, "bar", ipaddress, null); + sessionHandles.add(session); + } + + } finally { + for (SessionHandle sessionHandle : sessionHandles) { + service.closeSession(sessionHandle); + } + service.stop(); + } + } + + @Test + public void testConnectionLimitPerUserIpAddress() throws HiveSQLException { + String track = "user-ip-address"; + String user = "foo"; + String ipaddress = "127.0.0.1"; + String errMsg = "Connection limit exceeded for " + (user + "-" + ipaddress) + ". Current value: " + (limit + 1) + + " Limit: " + limit; + thrown.expect(HiveSQLException.class); + thrown.expectMessage(errMsg); + + CLIService service = getService(track, 10); + List sessionHandles = new ArrayList<>(); + try { + for (int i = 0; i < limit + 1; i++) { + SessionHandle session = service.openSession(CLIService.SERVER_VERSION, "foo", "bar", ipaddress, null); + sessionHandles.add(session); + } + + } finally { + for (SessionHandle sessionHandle : sessionHandles) { + service.closeSession(sessionHandle); + } + service.stop(); + } + } + + private CLIService getService(String track, int limit) { + HiveConf conf = new HiveConf(); + conf.setVar(HiveConf.ConfVars.HIVE_AUTHORIZATION_MANAGER, + "org.apache.hadoop.hive.ql.security.authorization.plugin.sqlstd.SQLStdHiveAuthorizerFactory"); + if (track != null) { + conf.setVar(HiveConf.ConfVars.HIVE_SERVER2_TRACK_CONNECTIONS, track); + conf.setIntVar(HiveConf.ConfVars.HIVE_SERVER2_LIMIT_CONNECTIONS, limit); + } + CLIService service = new CLIService(null); + service.init(conf); + service.start(); + return service; + } +}