diff --git a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
index 2e2bf5a..2220af9 100644
--- a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
+++ b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
@@ -1624,7 +1624,9 @@
         "Transport mode of HiveServer2."),
     HIVE_SERVER2_THRIFT_BIND_HOST("hive.server2.thrift.bind.host", "",
         "Bind host on which to run the HiveServer2 Thrift service."),
-
+    HIVE_THRIFT_SASL_MESSAGE_LIMIT("hive.thrift.sasl.message.limit", 104857600,
+        "If the length of an incoming SASL message is greater than this, regard it as invalid " +
+        "and close the transport. A value of zero or less disables the check. Default is 100MB."),
     // http (over thrift) transport settings
     HIVE_SERVER2_THRIFT_HTTP_PORT("hive.server2.thrift.http.port", 10001,
         "Port number of HiveServer2 Thrift interface when hive.server2.transport.mode is 'http'."),
diff --git a/itests/hive-unit-hadoop2/src/test/java/org/apache/hadoop/hive/thrift/TestHadoop20SAuthBridge.java b/itests/hive-unit-hadoop2/src/test/java/org/apache/hadoop/hive/thrift/TestHadoop20SAuthBridge.java
index 3e1ce53..f61526d 100644
--- a/itests/hive-unit-hadoop2/src/test/java/org/apache/hadoop/hive/thrift/TestHadoop20SAuthBridge.java
+++ b/itests/hive-unit-hadoop2/src/test/java/org/apache/hadoop/hive/thrift/TestHadoop20SAuthBridge.java
@@ -80,8 +80,8 @@ public Server() throws TTransportException {
       super();
     }
     @Override
-    public TTransportFactory createTransportFactory(Map<String, String> saslProps)
-        throws TTransportException {
+    public TTransportFactory createTransportFactory(Map<String, String> saslProps,
+        int saslMessageLimit) throws TTransportException {
       TSaslServerTransport.Factory transFactory =
           new TSaslServerTransport.Factory();
       transFactory.addServerDefinition(AuthMethod.DIGEST.getMechanismName(),
diff --git a/metastore/src/java/org/apache/hadoop/hive/metastore/HiveMetaStore.java b/metastore/src/java/org/apache/hadoop/hive/metastore/HiveMetaStore.java
index 2547da7..c7aa9c6 100644
--- a/metastore/src/java/org/apache/hadoop/hive/metastore/HiveMetaStore.java
+++ b/metastore/src/java/org/apache/hadoop/hive/metastore/HiveMetaStore.java
@@ -5818,8 +5818,10 @@ public static void startMetaStore(int port, HadoopThriftAuthBridge bridge,
           conf.getVar(HiveConf.ConfVars.METASTORE_KERBEROS_PRINCIPAL));
       // start delegation token manager
       saslServer.startDelegationTokenSecretManager(conf, baseHandler.getMS(), ServerMode.METASTORE);
-      transFactory = saslServer.createTransportFactory(
-          MetaStoreUtils.getMetaStoreSaslProperties(conf));
+      int saslMessageLimit = conf.getIntVar(ConfVars.HIVE_THRIFT_SASL_MESSAGE_LIMIT);
+      transFactory =
+          saslServer.createTransportFactory(MetaStoreUtils.getMetaStoreSaslProperties(conf),
+              saslMessageLimit);
       processor = saslServer.wrapProcessor(
           new ThriftHiveMetastore.Processor(handler));
       LOG.info("Starting DB backed MetaStore Server in Secure Mode");
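For reference, a minimal sketch of reading and overriding the new limit programmatically, assuming the patched `HiveConf` above is on the classpath. `SaslLimitExample` is an invented name for illustration; `getIntVar`/`setIntVar` are the existing `HiveConf` accessors. The same value can of course be set as `hive.thrift.sasl.message.limit` in hive-site.xml.

```java
import org.apache.hadoop.hive.conf.HiveConf;

public class SaslLimitExample {
  public static void main(String[] args) {
    HiveConf conf = new HiveConf();
    // Tighten the default 100MB cap to 16MB; zero or a negative value disables the check.
    conf.setIntVar(HiveConf.ConfVars.HIVE_THRIFT_SASL_MESSAGE_LIMIT, 16 * 1024 * 1024);
    int limit = conf.getIntVar(HiveConf.ConfVars.HIVE_THRIFT_SASL_MESSAGE_LIMIT);
    System.out.println("sasl message limit: " + limit + " bytes");
  }
}
```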
diff --git a/service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java b/service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java
index 8352951..627fb5e 100644
--- a/service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java
+++ b/service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java
@@ -18,7 +18,6 @@ package org.apache.hive.service.auth;

 import java.io.IOException;
-import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.UnknownHostException;
 import java.util.ArrayList;
@@ -28,7 +27,6 @@ import java.util.Map;

 import javax.net.ssl.SSLServerSocket;
-import javax.security.auth.login.LoginException;
 import javax.security.sasl.Sasl;

 import org.apache.hadoop.hive.conf.HiveConf;
@@ -61,28 +59,27 @@ public enum AuthTypes {
-    NOSASL("NOSASL"),
-    NONE("NONE"),
-    LDAP("LDAP"),
-    KERBEROS("KERBEROS"),
-    CUSTOM("CUSTOM"),
-    PAM("PAM");
-
-    private final String authType;
-
-    AuthTypes(String authType) {
-      this.authType = authType;
-    }
-
-    public String getAuthName() {
-      return authType;
-    }
+    NOSASL, NONE, LDAP, KERBEROS, CUSTOM, PAM
+  }
+
+  public static enum TransTypes {
+    HTTP {
+      AuthTypes getDefaultAuthType() {
+        return AuthTypes.NOSASL;
+      }
+    },
+    BINARY {
+      AuthTypes getDefaultAuthType() {
+        return AuthTypes.NONE;
+      }
+    };
+    abstract AuthTypes getDefaultAuthType();
   }

-  private HadoopThriftAuthBridge.Server saslServer;
-  private String authTypeStr;
-  private final String transportMode;
+  private final HadoopThriftAuthBridge.Server saslServer;
+  private final AuthTypes authType;
+  private final TransTypes transportType;
+  private final int saslMessageLimit;
   private final HiveConf conf;

   public static final String HS2_PROXY_USER = "hive.server2.proxy.user";
@@ -90,29 +87,27 @@ public String getAuthName() {
   public HiveAuthFactory(HiveConf conf) throws TTransportException {
     this.conf = conf;
-    transportMode = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_TRANSPORT_MODE);
-    authTypeStr = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_AUTHENTICATION);
-
-    // In http mode we use NOSASL as the default auth type
-    if ("http".equalsIgnoreCase(transportMode)) {
-      if (authTypeStr == null) {
-        authTypeStr = AuthTypes.NOSASL.getAuthName();
+    saslMessageLimit = conf.getIntVar(ConfVars.HIVE_THRIFT_SASL_MESSAGE_LIMIT);
+    String transTypeStr = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_TRANSPORT_MODE);
+    String authTypeStr = conf.getVar(ConfVars.HIVE_SERVER2_AUTHENTICATION);
+    transportType = TransTypes.valueOf(transTypeStr.toUpperCase());
+    authType = authTypeStr == null ? transportType.getDefaultAuthType()
+        : AuthTypes.valueOf(authTypeStr.toUpperCase());
+    if (transportType == TransTypes.BINARY
+        && authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.name())) {
+      saslServer = ShimLoader.getHadoopThriftAuthBridge().createServer(
+          conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB),
+          conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL));
+      // start delegation token manager
+      try {
+        saslServer.startDelegationTokenSecretManager(conf, null, ServerMode.HIVESERVER2);
+      } catch (Exception e) {
+        throw new TTransportException("Failed to start token manager", e);
       }
     } else {
-      if (authTypeStr == null) {
-        authTypeStr = AuthTypes.NONE.getAuthName();
-      }
-      if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())) {
-        saslServer = ShimLoader.getHadoopThriftAuthBridge()
-            .createServer(conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB),
-                conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL));
-        // start delegation token manager
-        try {
-          saslServer.startDelegationTokenSecretManager(conf, null, ServerMode.HIVESERVER2);
-        } catch (IOException e) {
-          throw new TTransportException("Failed to start token manager", e);
-        }
-      }
+      saslServer = null;
     }
   }
@@ -124,42 +119,28 @@ public HiveAuthFactory(HiveConf conf) throws TTransportException {
     return saslProps;
   }

-  public TTransportFactory getAuthTransFactory() throws LoginException {
-    TTransportFactory transportFactory;
-    if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())) {
-      try {
-        transportFactory = saslServer.createTransportFactory(getSaslProperties());
-      } catch (TTransportException e) {
-        throw new LoginException(e.getMessage());
-      }
-    } else if (authTypeStr.equalsIgnoreCase(AuthTypes.NONE.getAuthName())) {
-      transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
-    } else if (authTypeStr.equalsIgnoreCase(AuthTypes.LDAP.getAuthName())) {
-      transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
-    } else if (authTypeStr.equalsIgnoreCase(AuthTypes.PAM.getAuthName())) {
-      transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
-    } else if (authTypeStr.equalsIgnoreCase(AuthTypes.NOSASL.getAuthName())) {
-      transportFactory = new TTransportFactory();
-    } else if (authTypeStr.equalsIgnoreCase(AuthTypes.CUSTOM.getAuthName())) {
-      transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
-    } else {
-      throw new LoginException("Unsupported authentication type " + authTypeStr);
+  public TTransportFactory getAuthTransFactory() throws Exception {
+    if (authType == AuthTypes.KERBEROS) {
+      return saslServer.createTransportFactory(getSaslProperties(), saslMessageLimit);
     }
-    return transportFactory;
+    if (authType == AuthTypes.NOSASL) {
+      return new TTransportFactory();
+    }
+    return PlainSaslHelper.getPlainTransportFactory(authType.name(), saslMessageLimit);
   }

   /**
    * Returns the thrift processor factory for HiveServer2 running in binary mode
+   *
    * @param service
    * @return
    * @throws LoginException
    */
-  public TProcessorFactory getAuthProcFactory(ThriftCLIService service) throws LoginException {
-    if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())) {
+  public TProcessorFactory getAuthProcFactory(ThriftCLIService service) {
+    if (authType == AuthTypes.KERBEROS) {
       return KerberosSaslHelper.getKerberosProcessorFactory(saslServer, service);
-    } else {
-      return PlainSaslHelper.getPlainProcessorFactory(service);
     }
+    return PlainSaslHelper.getPlainProcessorFactory(service);
   }

   public String getRemoteUser() {
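A self-contained sketch of the enum-based default resolution the rewritten constructor performs; `AuthDefaultsDemo` and `resolve` are invented names for illustration, mirroring the `TransTypes`/`AuthTypes` pattern above:

```java
public class AuthDefaultsDemo {
  enum AuthTypes { NOSASL, NONE, LDAP, KERBEROS, CUSTOM, PAM }

  enum TransTypes {
    HTTP   { AuthTypes getDefaultAuthType() { return AuthTypes.NOSASL; } },
    BINARY { AuthTypes getDefaultAuthType() { return AuthTypes.NONE; } };
    abstract AuthTypes getDefaultAuthType();
  }

  // Unset auth type falls back to the transport's default, as in the patched constructor.
  static AuthTypes resolve(String transportMode, String authTypeStr) {
    TransTypes transport = TransTypes.valueOf(transportMode.toUpperCase());
    return authTypeStr == null
        ? transport.getDefaultAuthType()
        : AuthTypes.valueOf(authTypeStr.toUpperCase());
  }

  public static void main(String[] args) {
    System.out.println(resolve("http", null));          // NOSASL
    System.out.println(resolve("binary", null));        // NONE
    System.out.println(resolve("binary", "kerberos"));  // KERBEROS
  }
}
```

One observation on the constructor itself: it guards `authTypeStr == null` in the ternary but then calls `authTypeStr.equalsIgnoreCase(...)` on the next line; comparing the already-resolved `authType == AuthTypes.KERBEROS` would be null-safe if the config value can in fact come back null.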
diff --git a/service/src/java/org/apache/hive/service/auth/PlainSaslHelper.java b/service/src/java/org/apache/hive/service/auth/PlainSaslHelper.java
index afc1441..ba987f3 100644
--- a/service/src/java/org/apache/hive/service/auth/PlainSaslHelper.java
+++ b/service/src/java/org/apache/hive/service/auth/PlainSaslHelper.java
@@ -30,6 +30,7 @@
 import javax.security.sasl.AuthorizeCallback;
 import javax.security.sasl.SaslException;

+import org.apache.hadoop.hive.thrift.HadoopThriftAuthBridge;
 import org.apache.hive.service.auth.AuthenticationProviderFactory.AuthMethods;
 import org.apache.hive.service.auth.PlainSaslServer.SaslPlainProvider;
 import org.apache.hive.service.cli.thrift.TCLIService.Iface;
@@ -42,7 +43,6 @@ import org.apache.thrift.transport.TTransportFactory;

 public final class PlainSaslHelper {
-
   public static TProcessorFactory getPlainProcessorFactory(ThriftCLIService service) {
     return new SQLPlainProcessorFactory(service);
   }
@@ -52,16 +52,18 @@ public static TProcessorFactory getPlainProcessorFactory(ThriftCLIService servic
     Security.addProvider(new SaslPlainProvider());
   }

-  public static TTransportFactory getPlainTransportFactory(String authTypeStr)
-      throws LoginException {
-    TSaslServerTransport.Factory saslFactory = new TSaslServerTransport.Factory();
-    try {
-      saslFactory.addServerDefinition("PLAIN", authTypeStr, null, new HashMap<String, String>(),
-          new PlainServerCallbackHandler(authTypeStr));
-    } catch (AuthenticationException e) {
-      throw new LoginException("Error setting callback handler" + e);
+  public static TTransportFactory getPlainTransportFactory(String authTypeStr, int saslMessageLimit)
+      throws LoginException, AuthenticationException {
+    TSaslServerTransport.Factory saslTransportFactory;
+    if (saslMessageLimit > 0) {
+      saslTransportFactory =
+          new HadoopThriftAuthBridge.HiveSaslServerTransportFactory(saslMessageLimit);
+    } else {
+      saslTransportFactory = new TSaslServerTransport.Factory();
     }
-    return saslFactory;
+    saslTransportFactory.addServerDefinition("PLAIN", authTypeStr, null,
+        new HashMap<String, String>(), new PlainServerCallbackHandler(authTypeStr));
+    return saslTransportFactory;
   }

   public static TTransport getPlainTransport(String username, String password,
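A hypothetical wiring sketch for the factory returned above, using stock Thrift server classes; `PlainServerWiring` and `serve` are invented names, and the `TProcessor` is left to the caller. It is only meant to show where the size-limited transport factory plugs in:

```java
import org.apache.hive.service.auth.PlainSaslHelper;
import org.apache.thrift.TProcessor;
import org.apache.thrift.server.TServer;
import org.apache.thrift.server.TThreadPoolServer;
import org.apache.thrift.transport.TServerSocket;
import org.apache.thrift.transport.TTransportFactory;

public class PlainServerWiring {
  public static TServer serve(TProcessor processor, int port, int saslMessageLimit)
      throws Exception {
    // With saslMessageLimit > 0 this returns the size-checking factory;
    // otherwise it falls back to the stock TSaslServerTransport.Factory.
    TTransportFactory transportFactory =
        PlainSaslHelper.getPlainTransportFactory("NONE", saslMessageLimit);
    TServerSocket serverSocket = new TServerSocket(port);
    TThreadPoolServer.Args args = new TThreadPoolServer.Args(serverSocket)
        .processor(processor)
        .transportFactory(transportFactory);
    return new TThreadPoolServer(args);
  }
}
```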
diff --git a/shims/common/src/main/java/org/apache/hadoop/hive/thrift/HadoopThriftAuthBridge.java b/shims/common/src/main/java/org/apache/hadoop/hive/thrift/HadoopThriftAuthBridge.java
index 348c419..036a79d 100644
--- a/shims/common/src/main/java/org/apache/hadoop/hive/thrift/HadoopThriftAuthBridge.java
+++ b/shims/common/src/main/java/org/apache/hadoop/hive/thrift/HadoopThriftAuthBridge.java
@@ -20,12 +20,17 @@
 import static org.apache.hadoop.fs.CommonConfigurationKeys.HADOOP_SECURITY_AUTHENTICATION;

 import java.io.IOException;
+import java.io.UnsupportedEncodingException;
+import java.lang.ref.WeakReference;
 import java.net.InetAddress;
 import java.net.Socket;
 import java.security.PrivilegedAction;
 import java.security.PrivilegedExceptionAction;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.Locale;
 import java.util.Map;
+import java.util.WeakHashMap;

 import javax.security.auth.callback.Callback;
 import javax.security.auth.callback.CallbackHandler;
@@ -51,6 +56,7 @@
 import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
 import org.apache.hadoop.security.SecurityUtil;
 import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.thrift.EncodingUtils;
 import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod;
 import org.apache.hadoop.security.authorize.AuthorizationException;
 import org.apache.hadoop.security.authorize.ProxyUsers;
@@ -67,6 +73,8 @@
 import org.apache.thrift.transport.TTransport;
 import org.apache.thrift.transport.TTransportException;
 import org.apache.thrift.transport.TTransportFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;

 /**
  * Functions that bridge Thrift's SASL transports to Hadoop's
@@ -365,8 +373,8 @@ protected Server(String keytabFile, String principalConf)
      * @param saslProps Map of SASL properties
      */
-    public TTransportFactory createTransportFactory(Map<String, String> saslProps)
-        throws TTransportException {
+    public TTransportFactory createTransportFactory(Map<String, String> saslProps,
+        int saslMessageLimit) throws TTransportException {
       // Parse out the kerberos principal, host, realm.
       String kerberosName = realUgi.getUserName();
       final String names[] = SaslRpcServer.splitKerberosName(kerberosName);
@@ -374,17 +382,19 @@
       if (names.length != 3) {
         throw new TTransportException("Kerberos principal should have 3 parts: " + kerberosName);
       }

-      TSaslServerTransport.Factory transFactory = new TSaslServerTransport.Factory();
-      transFactory.addServerDefinition(
-          AuthMethod.KERBEROS.getMechanismName(),
-          names[0], names[1],  // two parts of kerberos principal
-          saslProps,
-          new SaslRpcServer.SaslGssCallbackHandler());
-      transFactory.addServerDefinition(AuthMethod.DIGEST.getMechanismName(),
-          null, SaslRpcServer.SASL_DEFAULT_REALM,
-          saslProps, new SaslDigestCallbackHandler(secretManager));
-
-      return new TUGIAssumingTransportFactory(transFactory, realUgi);
+      TSaslServerTransport.Factory saslTransportFactory;
+      if (saslMessageLimit > 0) {
+        saslTransportFactory = new HiveSaslServerTransportFactory(saslMessageLimit);
+      } else {
+        saslTransportFactory = new TSaslServerTransport.Factory();
+      }
+      saslTransportFactory.addServerDefinition(AuthMethod.KERBEROS.getMechanismName(),
+          names[0], names[1], saslProps, new SaslRpcServer.SaslGssCallbackHandler());
+      saslTransportFactory.addServerDefinition(AuthMethod.DIGEST.getMechanismName(), null,
+          SaslRpcServer.SASL_DEFAULT_REALM, saslProps,
+          new SaslDigestCallbackHandler(secretManager));
+      return new TUGIAssumingTransportFactory(saslTransportFactory, realUgi);
     }

     /**
@@ -737,4 +747,101 @@ public TTransport run() {
         }
       }
     }
+
+  public static class HiveSaslServerTransportFactory extends TSaslServerTransport.Factory {
+    private static final Logger LOGGER = LoggerFactory.getLogger(TSaslServerTransport.class);
+
+    private final int saslMessageLimit;
+
+    public HiveSaslServerTransportFactory(int saslMessageLimit) {
+      this.saslMessageLimit = saslMessageLimit;
+    }
+
+    private static class TSaslServerDefinition {
+      public String mechanism;
+      public String protocol;
+      public String serverName;
+      public Map<String, String> props;
+      public CallbackHandler cbh;
+
+      public TSaslServerDefinition(String mechanism, String protocol, String serverName,
+          Map<String, String> props, CallbackHandler cbh) {
+        this.mechanism = mechanism;
+        this.protocol = protocol;
+        this.serverName = serverName;
+        this.props = props;
+        this.cbh = cbh;
+      }
+    }
+
+    private static Map<TTransport, WeakReference<TSaslServerTransport>> transportMap =
+        Collections.synchronizedMap(
+            new WeakHashMap<TTransport, WeakReference<TSaslServerTransport>>());
+    private Map<String, TSaslServerDefinition> serverDefinitionMap =
+        new HashMap<String, TSaslServerDefinition>();
+
+    @Override
+    public void addServerDefinition(String mechanism, String protocol, String serverName,
+        Map<String, String> props, CallbackHandler cbh) {
+      serverDefinitionMap.put(mechanism,
+          new TSaslServerDefinition(mechanism, protocol, serverName, props, cbh));
+    }
+
+    @Override
+    public TTransport getTransport(TTransport base) {
+      WeakReference<TSaslServerTransport> ret = transportMap.get(base);
+      TSaslServerTransport transport = ret == null ? null : ret.get();
+      if (transport == null) {
+        LOGGER.debug("transport map does not contain key {}", base);
+        transport = newSaslTransport(base);
+        try {
+          transport.open();
+        } catch (TTransportException e) {
+          LOGGER.debug("failed to open server transport", e);
+          throw new RuntimeException(e);
+        }
+        transportMap.put(base, new WeakReference<TSaslServerTransport>(transport));
+      } else {
+        LOGGER.debug("transport map does contain key {}", base);
+      }
+      return transport;
+    }
+
+    private TSaslServerTransport newSaslTransport(final TTransport base) {
+      // Anonymous subclass of TSaslServerTransport. TSaslServerTransport#receiveSaslMessage
+      // is replaced with one that has an additional check for the message size.
+      TSaslServerTransport transport = new TSaslServerTransport(base) {
+        private final byte[] messageHeader = new byte[STATUS_BYTES + PAYLOAD_LENGTH_BYTES];
+
+        @Override
+        protected SaslResponse receiveSaslMessage() throws TTransportException {
+          underlyingTransport.readAll(messageHeader, 0, messageHeader.length);
+          byte statusByte = messageHeader[0];
+          int length = EncodingUtils.decodeBigEndian(messageHeader, STATUS_BYTES);
+          if (length > saslMessageLimit) {
+            base.close();
+            throw new TTransportException("Sasl message is too big (" + length + " bytes). "
+                + "The peer connection is possibly using a protocol other than thrift.");
+          }
+          byte[] payload = new byte[length];
+          underlyingTransport.readAll(payload, 0, payload.length);
+          NegotiationStatus status = NegotiationStatus.byValue(statusByte);
+          if (status == null) {
+            sendAndThrowMessage(NegotiationStatus.ERROR, "Invalid status " + statusByte);
+          } else if (status == NegotiationStatus.BAD || status == NegotiationStatus.ERROR) {
+            try {
+              String remoteMessage = new String(payload, "UTF-8");
+              throw new TTransportException("Peer indicated failure: " + remoteMessage);
+            } catch (UnsupportedEncodingException e) {
+              throw new TTransportException(e);
+            }
+          }
+          if (LOGGER.isDebugEnabled()) {
+            LOGGER.debug(getRole() + ": Received message with status {} and payload length {}",
+                status, payload.length);
+          }
+          return new SaslResponse(status, payload);
+        }
+      };
+      for (Map.Entry<String, TSaslServerDefinition> entry : serverDefinitionMap.entrySet()) {
+        TSaslServerDefinition definition = entry.getValue();
+        transport.addServerDefinition(entry.getKey(), definition.protocol,
+            definition.serverName, definition.props, definition.cbh);
+      }
+      return transport;
+    }
+  }
 }
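To make the new check concrete: a SASL negotiation frame is one status byte followed by a 4-byte big-endian payload length, so the leading bytes of a non-Thrift request decode into an enormous bogus length. Below is a self-contained sketch of the same arithmetic; `SaslFrameCheck` is an invented name, and the real check reads the header from the underlying transport via `EncodingUtils.decodeBigEndian`.

```java
public class SaslFrameCheck {
  static final int STATUS_BYTES = 1;
  static final int PAYLOAD_LENGTH_BYTES = 4;

  // Equivalent to EncodingUtils.decodeBigEndian(header, STATUS_BYTES).
  static int payloadLength(byte[] header) {
    return ((header[1] & 0xff) << 24) | ((header[2] & 0xff) << 16)
         | ((header[3] & 0xff) << 8)  |  (header[4] & 0xff);
  }

  public static void main(String[] args) {
    int limit = 104857600; // the patch's default: 100MB
    // The first five bytes of an HTTP request, misread as a SASL header:
    // 'G' becomes the status byte, "ET /" the payload length.
    byte[] header = {'G', 'E', 'T', ' ', '/'};
    int length = payloadLength(header);
    if (length > limit) {
      // length is 1163141167, i.e. a declared payload of roughly 1.1GB.
      System.out.println("reject: declared payload of " + length
          + " bytes exceeds limit; peer is probably not speaking thrift SASL");
    }
  }
}
```

Rejecting the frame up front and closing the transport avoids allocating the bogus payload buffer, which is the failure mode the limit is meant to prevent.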