diff --git common/src/java/org/apache/hadoop/hive/conf/HiveConf.java common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
index 44d9a57..3675fc6 100644
--- common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
+++ common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
@@ -947,6 +947,9 @@
     HIVE_SERVER2_SSL_KEYSTORE_PATH("hive.server2.keystore.path", ""),
     HIVE_SERVER2_SSL_KEYSTORE_PASSWORD("hive.server2.keystore.password", ""),
 
+    // Max length in bytes of an incoming SASL message; zero or less disables the check.
+    HIVE_SERVER2_SASL_MESSAGE_LIMIT("hive.server2.sasl.message.limit", -1),
+
     HIVE_SECURITY_COMMAND_WHITELIST("hive.security.command.whitelist", "set,reset,dfs,add,delete,compile"),
 
     HIVE_CONF_RESTRICTED_LIST("hive.conf.restricted.list", "hive.security.authenticator.manager,hive.security.authorization.manager"),
diff --git conf/hive-default.xml.template conf/hive-default.xml.template
index e53df4f..79e356d 100644
--- conf/hive-default.xml.template
+++ conf/hive-default.xml.template
@@ -2709,6 +2709,12 @@
 </property>
 
 <property>
+  <name>hive.server2.sasl.message.limit</name>
+  <value>-1</value>
+  <description>If the length in bytes of an incoming SASL message is greater than this, the message is regarded as invalid and the transport is closed. Zero or a negative value disables the check. Only applies to PLAIN SASL authentication (NONE, LDAP, PAM, CUSTOM).</description>
+</property>
+
+<property>
   <name>hive.convert.join.bucket.mapjoin.tez</name>
   <value>false</value>
   <description>Whether joins can be automatically converted to bucket map
diff --git service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java
index 72b3e7e..462fa2a 100644
--- service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java
+++ service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java
@@ -24,7 +24,6 @@
 import java.util.HashMap;
 import java.util.Map;
 
-import javax.security.auth.login.LoginException;
 import javax.security.sasl.Sasl;
 
 import org.apache.hadoop.hive.conf.HiveConf;
@@ -48,28 +47,20 @@
   private static final Logger LOG = LoggerFactory.getLogger(HiveAuthFactory.class);
 
   public static enum AuthTypes {
-    NOSASL("NOSASL"),
-    NONE("NONE"),
-    LDAP("LDAP"),
-    KERBEROS("KERBEROS"),
-    CUSTOM("CUSTOM"),
-    PAM("PAM");
-
-    private String authType;
-
-    AuthTypes(String authType) {
-      this.authType = authType;
-    }
-
-    public String getAuthName() {
-      return authType;
-    }
+    NOSASL, NONE, LDAP, KERBEROS, CUSTOM, PAM
+  }
 
-  };
+  /** Server transport modes; each supplies the auth type used when none is configured. */
+  public static enum TransTypes {
+    HTTP { AuthTypes getDefaultAuthType() { return AuthTypes.NOSASL; } },
+    BINARY { AuthTypes getDefaultAuthType() { return AuthTypes.NONE; } };
+    abstract AuthTypes getDefaultAuthType();
+  }
 
-  private HadoopThriftAuthBridge.Server saslServer = null;
-  private String authTypeStr;
-  private String transportMode;
+  private final HadoopThriftAuthBridge.Server saslServer;
+  private final AuthTypes authType;
+  private final TransTypes transportType;
+  private final int saslMessageLimit;
   private final HiveConf conf;
 
   public static final String HS2_PROXY_USER = "hive.server2.proxy.user";
@@ -77,33 +68,32 @@ public String getAuthName() {
 
   public HiveAuthFactory(HiveConf conf) throws TTransportException {
     this.conf = conf;
-    transportMode = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_TRANSPORT_MODE);
-    authTypeStr = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_AUTHENTICATION);
-
-    // In http mode we use NOSASL as the default auth type
-    if (transportMode.equalsIgnoreCase("http")) {
-      if (authTypeStr == null) {
-        authTypeStr = AuthTypes.NOSASL.getAuthName();
-      }
-    }
-    else {
-      if (authTypeStr == null) {
-        authTypeStr = AuthTypes.NONE.getAuthName();
-      }
-      if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())
-          && ShimLoader.getHadoopShims().isSecureShimImpl()) {
-        saslServer = ShimLoader.getHadoopThriftAuthBridge().createServer(
-            conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB),
-            conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL)
-            );
-        // start delegation token manager
-        try {
-          saslServer.startDelegationTokenSecretManager(conf, null);
-        } catch (IOException e) {
-          throw new TTransportException("Failed to start token manager", e);
-        }
+    String transTypeStr = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_TRANSPORT_MODE);
+    String authTypeStr = conf.getVar(ConfVars.HIVE_SERVER2_AUTHENTICATION);
+
+    // Locale.ROOT keeps the upper-casing independent of the JVM default locale.
+    transportType = TransTypes.valueOf(transTypeStr.toUpperCase(java.util.Locale.ROOT));
+    authType = authTypeStr == null ? transportType.getDefaultAuthType() :
+        AuthTypes.valueOf(authTypeStr.toUpperCase(java.util.Locale.ROOT));
+
+    // Compare the resolved enum; authTypeStr may be null when the default was applied.
+    if (transportType == TransTypes.BINARY
+        && authType == AuthTypes.KERBEROS
+        && ShimLoader.getHadoopShims().isSecureShimImpl()) {
+      saslServer = ShimLoader.getHadoopThriftAuthBridge().createServer(
+          conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB),
+          conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL)
+          );
+      // start delegation token manager
+      try {
+        saslServer.startDelegationTokenSecretManager(conf, null);
+      } catch (IOException e) {
+        throw new TTransportException("Failed to start token manager", e);
       }
+    } else {
+      saslServer = null;
     }
+    saslMessageLimit = conf.getIntVar(ConfVars.HIVE_SERVER2_SASL_MESSAGE_LIMIT);
   }
 
   public Map<String, String> getSaslProperties() {
@@ -115,42 +105,28 @@ public HiveAuthFactory(HiveConf conf) throws TTransportException {
     return saslProps;
   }
 
-  public TTransportFactory getAuthTransFactory() throws LoginException {
-    TTransportFactory transportFactory;
-    if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())) {
-      try {
-        transportFactory = saslServer.createTransportFactory(getSaslProperties());
-      } catch (TTransportException e) {
-        throw new LoginException(e.getMessage());
-      }
-    } else if (authTypeStr.equalsIgnoreCase(AuthTypes.NONE.getAuthName())) {
-      transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
-    } else if (authTypeStr.equalsIgnoreCase(AuthTypes.LDAP.getAuthName())) {
-      transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
-    } else if (authTypeStr.equalsIgnoreCase(AuthTypes.PAM.getAuthName())) {
-      transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
-    } else if (authTypeStr.equalsIgnoreCase(AuthTypes.NOSASL.getAuthName())) {
-      transportFactory = new TTransportFactory();
-    } else if (authTypeStr.equalsIgnoreCase(AuthTypes.CUSTOM.getAuthName())) {
-      transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
-    } else {
-      throw new LoginException("Unsupported authentication type " + authTypeStr);
+  /** Returns the server transport factory for the configured authentication type. */
+  public TTransportFactory getAuthTransFactory() throws Exception {
+    if (authType == AuthTypes.KERBEROS) {
+      // NOTE(review): saslServer is null unless binary mode + a secure shim; this would NPE.
+      return saslServer.createTransportFactory(getSaslProperties());
+    }
+    if (authType == AuthTypes.NOSASL) {
+      return new TTransportFactory();
     }
-    return transportFactory;
+    // NONE, LDAP, PAM and CUSTOM all use PLAIN SASL with the configured message limit.
+    return PlainSaslHelper.getPlainTransportFactory(authType.name(), saslMessageLimit);
   }
 
-  public TProcessorFactory getAuthProcFactory(ThriftCLIService service)
-      throws LoginException {
-    if (transportMode.equalsIgnoreCase("http")) {
+  /** Returns the processor factory: HTTP delegates to HttpAuthUtils, binary picks by auth type. */
+  public TProcessorFactory getAuthProcFactory(ThriftCLIService service) {
+    if (transportType == TransTypes.HTTP) {
       return HttpAuthUtils.getAuthProcFactory(service);
     }
-    else {
-      if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())) {
-        return KerberosSaslHelper.getKerberosProcessorFactory(saslServer, service);
-      } else {
-        return PlainSaslHelper.getPlainProcessorFactory(service);
-      }
+    if (authType == AuthTypes.KERBEROS) {
+      return KerberosSaslHelper.getKerberosProcessorFactory(saslServer, service);
     }
+    return PlainSaslHelper.getPlainProcessorFactory(service);
   }
 
   public String getRemoteUser() {
diff --git service/src/java/org/apache/hive/service/auth/PlainSaslHelper.java service/src/java/org/apache/hive/service/auth/PlainSaslHelper.java
index dd788c6..83511ec 100644
--- service/src/java/org/apache/hive/service/auth/PlainSaslHelper.java
+++ service/src/java/org/apache/hive/service/auth/PlainSaslHelper.java
@@ -18,7 +18,12 @@
 package org.apache.hive.service.auth;
 
 import java.io.IOException;
+import java.io.UnsupportedEncodingException;
+import java.lang.ref.WeakReference;
+import java.util.Collections;
 import java.util.HashMap;
+import java.util.Map;
+import java.util.WeakHashMap;
 
 import javax.security.auth.callback.Callback;
 import javax.security.auth.callback.CallbackHandler;
@@ -33,15 +38,18 @@
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hive.service.auth.PlainSaslServer.SaslPlainProvider;
 import org.apache.hive.service.auth.AuthenticationProviderFactory.AuthMethods;
-import org.apache.hive.service.cli.thrift.TCLIService;
 import org.apache.hive.service.cli.thrift.TCLIService.Iface;
 import org.apache.hive.service.cli.thrift.ThriftCLIService;
+import org.apache.thrift.EncodingUtils;
 import org.apache.thrift.TProcessor;
 import org.apache.thrift.TProcessorFactory;
 import org.apache.thrift.transport.TSaslClientTransport;
 import org.apache.thrift.transport.TSaslServerTransport;
 import org.apache.thrift.transport.TTransport;
+import org.apache.thrift.transport.TTransportException;
 import org.apache.thrift.transport.TTransportFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 public class PlainSaslHelper {
 
@@ -132,16 +140,23 @@ public static TProcessorFactory getPlainProcessorFactory(ThriftCLIService servic
     java.security.Security.addProvider(new SaslPlainProvider());
   }
 
-  public static TTransportFactory getPlainTransportFactory(String authTypeStr)
-      throws LoginException {
-    TSaslServerTransport.Factory saslFactory = new TSaslServerTransport.Factory();
-    try {
-      saslFactory.addServerDefinition("PLAIN",
+  /**
+   * Builds a PLAIN SASL server transport factory. A positive {@code saslMessageLimit} selects
+   * the length-checking {@link Factory}; otherwise Thrift's stock factory is used.
+   */
+  public static TTransportFactory getPlainTransportFactory(String authTypeStr, int saslMessageLimit)
+      throws LoginException, AuthenticationException {
+    if (saslMessageLimit > 0) {
+      PlainSaslHelper.Factory factory = new PlainSaslHelper.Factory(saslMessageLimit);
+      factory.addServerDefinition("PLAIN",
           authTypeStr, null, new HashMap<String, String>(),
           new PlainServerCallbackHandler(authTypeStr));
-    } catch (AuthenticationException e) {
-      throw new LoginException ("Error setting callback handler" + e);
+      return factory;
     }
+    TSaslServerTransport.Factory saslFactory = new TSaslServerTransport.Factory();
+    saslFactory.addServerDefinition("PLAIN",
+        authTypeStr, null, new HashMap<String, String>(),
+        new PlainServerCallbackHandler(authTypeStr));
     return saslFactory;
   }
 
@@ -152,4 +167,123 @@ public static TTransport getPlainTransport(String userName, String passwd,
       new PlainClientbackHandler(userName, passwd), underlyingTransport);
   }
 
+  /**
+   * Server transport factory whose SASL transports reject any negotiation message declaring
+   * a negative payload length or one larger than {@code saslMessageLimit} bytes; the
+   * underlying transport is closed before the exception is thrown.
+   */
+  public static class Factory extends TTransportFactory {
+
+    private static final Logger LOGGER = LoggerFactory.getLogger(TSaslServerTransport.class);
+
+    private final int saslMessageLimit;
+
+    public Factory(int saslMessageLimit) {
+      this.saslMessageLimit = saslMessageLimit;
+    }
+
+    private static class TSaslServerDefinition {
+      public final String mechanism;
+      public final String protocol;
+      public final String serverName;
+      public final Map<String, String> props;
+      public final CallbackHandler cbh;
+
+      public TSaslServerDefinition(String mechanism, String protocol, String serverName,
+          Map<String, String> props, CallbackHandler cbh) {
+        this.mechanism = mechanism;
+        this.protocol = protocol;
+        this.serverName = serverName;
+        this.props = props;
+        this.cbh = cbh;
+      }
+    }
+
+    // Static, so shared by all factories. Weak keys and weakly-referenced values mean the
+    // cache itself does not keep transports reachable.
+    private static final Map<TTransport, WeakReference<TSaslServerTransport>> transportMap =
+        Collections.synchronizedMap(
+            new WeakHashMap<TTransport, WeakReference<TSaslServerTransport>>());
+
+    private final Map<String, TSaslServerDefinition> serverDefinitionMap =
+        new HashMap<String, TSaslServerDefinition>();
+
+    public void addServerDefinition(String mechanism, String protocol, String serverName,
+        Map<String, String> props, CallbackHandler cbh) {
+      serverDefinitionMap.put(mechanism, new TSaslServerDefinition(mechanism, protocol, serverName,
+          props, cbh));
+    }
+
+    /**
+     * Returns the cached SASL transport for {@code base}, or creates, opens and caches a new
+     * one. A failure to open is rethrown as a RuntimeException.
+     */
+    @Override
+    public TTransport getTransport(TTransport base) {
+      WeakReference<TSaslServerTransport> ret = transportMap.get(base);
+      TSaslServerTransport transport = ret == null ? null : ret.get();
+      if (transport == null) {
+        LOGGER.debug("transport map does not contain key {}", base);
+        transport = newSaslTransport(base);
+        try {
+          transport.open();
+        } catch (TTransportException e) {
+          LOGGER.debug("failed to open server transport", e);
+          throw new RuntimeException(e);
+        }
+        transportMap.put(base, new WeakReference<TSaslServerTransport>(transport));
+      } else {
+        LOGGER.debug("transport map does contain key {}", base);
+      }
+      return transport;
+    }
+
+    /** Creates a server transport over {@code base} with every registered definition. */
+    private TSaslServerTransport newSaslTransport(final TTransport base) {
+      TSaslServerTransport transport = new TSaslServerTransport(base) {
+
+        private final byte[] messageHeader = new byte[STATUS_BYTES + PAYLOAD_LENGTH_BYTES];
+
+        // Validates the declared payload length before allocating the payload buffer.
+        @Override
+        protected SaslResponse receiveSaslMessage() throws TTransportException {
+          underlyingTransport.readAll(messageHeader, 0, messageHeader.length);
+
+          byte statusByte = messageHeader[0];
+          int length = EncodingUtils.decodeBigEndian(messageHeader, STATUS_BYTES);
+          // Reject negative lengths too: new byte[length] would throw NegativeArraySizeException.
+          if (length < 0 || length > saslMessageLimit) {
+            base.close();
+            throw new TTransportException("Invalid sasl message length " + length
+                + " (limit " + saslMessageLimit + " bytes)");
+          }
+          byte[] payload = new byte[length];
+          underlyingTransport.readAll(payload, 0, payload.length);
+
+          NegotiationStatus status = NegotiationStatus.byValue(statusByte);
+          if (status == null) {
+            sendAndThrowMessage(NegotiationStatus.ERROR, "Invalid status " + statusByte);
+          } else if (status == NegotiationStatus.BAD || status == NegotiationStatus.ERROR) {
+            try {
+              String remoteMessage = new String(payload, "UTF-8");
+              throw new TTransportException("Peer indicated failure: " + remoteMessage);
+            } catch (UnsupportedEncodingException e) {
+              throw new TTransportException(e);
+            }
+          }
+          if (LOGGER.isDebugEnabled()) {
+            LOGGER.debug(getRole() + ": Received message with status {} and payload length {}",
+                status, payload.length);
+          }
+          return new SaslResponse(status, payload);
+        }
+      };
+      for (Map.Entry<String, TSaslServerDefinition> entry : serverDefinitionMap.entrySet()) {
+        TSaslServerDefinition definition = entry.getValue();
+        transport.addServerDefinition(entry.getKey(),
+            definition.protocol, definition.serverName, definition.props, definition.cbh);
+      }
+      return transport;
+    }
+  }
 }