diff --git common/src/java/org/apache/hadoop/hive/conf/HiveConf.java common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
index 44d9a57..3675fc6 100644
--- common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
+++ common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
@@ -947,6 +947,8 @@
HIVE_SERVER2_SSL_KEYSTORE_PATH("hive.server2.keystore.path", ""),
HIVE_SERVER2_SSL_KEYSTORE_PASSWORD("hive.server2.keystore.password", ""),
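+    // Maximum allowed size, in bytes, of an incoming SASL negotiation message; zero or less disables the check.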
+ HIVE_SERVER2_SASL_MESSAGE_LIMIT("hive.server2.sasl.message.limit", -1),
+
HIVE_SECURITY_COMMAND_WHITELIST("hive.security.command.whitelist", "set,reset,dfs,add,delete,compile"),
HIVE_CONF_RESTRICTED_LIST("hive.conf.restricted.list", "hive.security.authenticator.manager,hive.security.authorization.manager"),
diff --git conf/hive-default.xml.template conf/hive-default.xml.template
index e53df4f..79e356d 100644
--- conf/hive-default.xml.template
+++ conf/hive-default.xml.template
@@ -2709,6 +2709,12 @@
+<property>
+  <name>hive.server2.sasl.message.limit</name>
+  <value>-1</value>
+  <description>If the size of an incoming SASL message exceeds this limit (in bytes), regard the message as invalid and close the transport. A value of zero or less disables the check.</description>
+</property>
+
 <property>
   <name>hive.convert.join.bucket.mapjoin.tez</name>
   <value>false</value>
   <description>Whether joins can be automatically converted to bucket map
diff --git service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java
index 72b3e7e..462fa2a 100644
--- service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java
+++ service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java
@@ -24,7 +24,6 @@
import java.util.HashMap;
import java.util.Map;
-import javax.security.auth.login.LoginException;
import javax.security.sasl.Sasl;
import org.apache.hadoop.hive.conf.HiveConf;
@@ -48,28 +47,19 @@
private static final Logger LOG = LoggerFactory.getLogger(HiveAuthFactory.class);
public static enum AuthTypes {
- NOSASL("NOSASL"),
- NONE("NONE"),
- LDAP("LDAP"),
- KERBEROS("KERBEROS"),
- CUSTOM("CUSTOM"),
- PAM("PAM");
-
- private String authType;
-
- AuthTypes(String authType) {
- this.authType = authType;
- }
-
- public String getAuthName() {
- return authType;
- }
+ NOSASL, NONE, LDAP, KERBEROS, CUSTOM, PAM
+ }
- };
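+  // Each transport mode implies a default authentication type: HTTP defaults to NOSASL, binary to NONE.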
+ public static enum TransTypes {
+ HTTP { AuthTypes getDefaultAuthType() { return AuthTypes.NOSASL; } },
+ BINARY { AuthTypes getDefaultAuthType() { return AuthTypes.NONE; } };
+ abstract AuthTypes getDefaultAuthType();
+ }
- private HadoopThriftAuthBridge.Server saslServer = null;
- private String authTypeStr;
- private String transportMode;
+ private final HadoopThriftAuthBridge.Server saslServer;
+ private final AuthTypes authType;
+ private final TransTypes transportType;
+ private final int saslMessageLimit;
private final HiveConf conf;
public static final String HS2_PROXY_USER = "hive.server2.proxy.user";
@@ -77,33 +67,30 @@ public String getAuthName() {
public HiveAuthFactory(HiveConf conf) throws TTransportException {
this.conf = conf;
- transportMode = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_TRANSPORT_MODE);
- authTypeStr = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_AUTHENTICATION);
-
- // In http mode we use NOSASL as the default auth type
- if (transportMode.equalsIgnoreCase("http")) {
- if (authTypeStr == null) {
- authTypeStr = AuthTypes.NOSASL.getAuthName();
- }
- }
- else {
- if (authTypeStr == null) {
- authTypeStr = AuthTypes.NONE.getAuthName();
- }
- if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())
- && ShimLoader.getHadoopShims().isSecureShimImpl()) {
- saslServer = ShimLoader.getHadoopThriftAuthBridge().createServer(
- conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB),
- conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL)
- );
- // start delegation token manager
- try {
- saslServer.startDelegationTokenSecretManager(conf, null);
- } catch (IOException e) {
- throw new TTransportException("Failed to start token manager", e);
- }
+    String transTypeStr = conf.getVar(ConfVars.HIVE_SERVER2_TRANSPORT_MODE);
+ String authTypeStr = conf.getVar(ConfVars.HIVE_SERVER2_AUTHENTICATION);
+
+ transportType = TransTypes.valueOf(transTypeStr.toUpperCase());
+ authType = authTypeStr == null ? transportType.getDefaultAuthType() :
+ AuthTypes.valueOf(authTypeStr.toUpperCase());
+
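+    // A SASL server (and delegation token manager) is needed only for Kerberos over the binary transport on a secure cluster.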
+    if (transportType == TransTypes.BINARY
+        && authType == AuthTypes.KERBEROS
+ && ShimLoader.getHadoopShims().isSecureShimImpl()) {
+ saslServer = ShimLoader.getHadoopThriftAuthBridge().createServer(
+ conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB),
+ conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL)
+ );
+ // start delegation token manager
+ try {
+ saslServer.startDelegationTokenSecretManager(conf, null);
+ } catch (Exception e) {
+ throw new TTransportException("Failed to start token manager", e);
}
+ } else {
+ saslServer = null;
}
+ saslMessageLimit = conf.getIntVar(ConfVars.HIVE_SERVER2_SASL_MESSAGE_LIMIT);
}
   public Map<String, String> getSaslProperties() {
@@ -115,42 +102,24 @@ public HiveAuthFactory(HiveConf conf) throws TTransportException {
return saslProps;
}
- public TTransportFactory getAuthTransFactory() throws LoginException {
- TTransportFactory transportFactory;
- if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())) {
- try {
- transportFactory = saslServer.createTransportFactory(getSaslProperties());
- } catch (TTransportException e) {
- throw new LoginException(e.getMessage());
- }
- } else if (authTypeStr.equalsIgnoreCase(AuthTypes.NONE.getAuthName())) {
- transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
- } else if (authTypeStr.equalsIgnoreCase(AuthTypes.LDAP.getAuthName())) {
- transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
- } else if (authTypeStr.equalsIgnoreCase(AuthTypes.PAM.getAuthName())) {
- transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
- } else if (authTypeStr.equalsIgnoreCase(AuthTypes.NOSASL.getAuthName())) {
- transportFactory = new TTransportFactory();
- } else if (authTypeStr.equalsIgnoreCase(AuthTypes.CUSTOM.getAuthName())) {
- transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
- } else {
- throw new LoginException("Unsupported authentication type " + authTypeStr);
+ public TTransportFactory getAuthTransFactory() throws Exception {
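+    // KERBEROS uses the SASL server created in the constructor and NOSASL skips SASL entirely;
+    // every other auth type (NONE, LDAP, PAM, CUSTOM) negotiates the PLAIN mechanism.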
+ if (authType == AuthTypes.KERBEROS) {
+ return saslServer.createTransportFactory(getSaslProperties());
+ }
+ if (authType == AuthTypes.NOSASL) {
+ return new TTransportFactory();
}
- return transportFactory;
+ return PlainSaslHelper.getPlainTransportFactory(authType.name(), saslMessageLimit);
}
- public TProcessorFactory getAuthProcFactory(ThriftCLIService service)
- throws LoginException {
- if (transportMode.equalsIgnoreCase("http")) {
+ public TProcessorFactory getAuthProcFactory(ThriftCLIService service) {
+ if (transportType == TransTypes.HTTP) {
return HttpAuthUtils.getAuthProcFactory(service);
}
- else {
- if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())) {
- return KerberosSaslHelper.getKerberosProcessorFactory(saslServer, service);
- } else {
- return PlainSaslHelper.getPlainProcessorFactory(service);
- }
+ if (authType == AuthTypes.KERBEROS) {
+ return KerberosSaslHelper.getKerberosProcessorFactory(saslServer, service);
}
+ return PlainSaslHelper.getPlainProcessorFactory(service);
}
public String getRemoteUser() {
diff --git service/src/java/org/apache/hive/service/auth/PlainSaslHelper.java service/src/java/org/apache/hive/service/auth/PlainSaslHelper.java
index dd788c6..83511ec 100644
--- service/src/java/org/apache/hive/service/auth/PlainSaslHelper.java
+++ service/src/java/org/apache/hive/service/auth/PlainSaslHelper.java
@@ -18,7 +18,12 @@
package org.apache.hive.service.auth;
import java.io.IOException;
+import java.io.UnsupportedEncodingException;
+import java.lang.ref.WeakReference;
+import java.util.Collections;
import java.util.HashMap;
+import java.util.Map;
+import java.util.WeakHashMap;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
@@ -33,15 +38,18 @@
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hive.service.auth.PlainSaslServer.SaslPlainProvider;
import org.apache.hive.service.auth.AuthenticationProviderFactory.AuthMethods;
-import org.apache.hive.service.cli.thrift.TCLIService;
import org.apache.hive.service.cli.thrift.TCLIService.Iface;
import org.apache.hive.service.cli.thrift.ThriftCLIService;
+import org.apache.thrift.EncodingUtils;
import org.apache.thrift.TProcessor;
import org.apache.thrift.TProcessorFactory;
import org.apache.thrift.transport.TSaslClientTransport;
import org.apache.thrift.transport.TSaslServerTransport;
import org.apache.thrift.transport.TTransport;
+import org.apache.thrift.transport.TTransportException;
import org.apache.thrift.transport.TTransportFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
public class PlainSaslHelper {
@@ -132,16 +140,19 @@ public static TProcessorFactory getPlainProcessorFactory(ThriftCLIService servic
java.security.Security.addProvider(new SaslPlainProvider());
}
- public static TTransportFactory getPlainTransportFactory(String authTypeStr)
- throws LoginException {
- TSaslServerTransport.Factory saslFactory = new TSaslServerTransport.Factory();
- try {
- saslFactory.addServerDefinition("PLAIN",
+ public static TTransportFactory getPlainTransportFactory(String authTypeStr, int saslMessageLimit)
+ throws LoginException, AuthenticationException {
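+    // A positive limit installs the size-checking factory below; otherwise fall back to the stock Thrift factory.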
+ if (saslMessageLimit > 0) {
+ PlainSaslHelper.Factory factory = new PlainSaslHelper.Factory(saslMessageLimit);
+ factory.addServerDefinition("PLAIN",
authTypeStr, null, new HashMap(),
new PlainServerCallbackHandler(authTypeStr));
- } catch (AuthenticationException e) {
- throw new LoginException ("Error setting callback handler" + e);
+ return factory;
}
+ TSaslServerTransport.Factory saslFactory = new TSaslServerTransport.Factory();
+ saslFactory.addServerDefinition("PLAIN",
+        authTypeStr, null, new HashMap<String, String>(),
+ new PlainServerCallbackHandler(authTypeStr));
return saslFactory;
}
@@ -152,4 +163,107 @@ public static TTransport getPlainTransport(String userName, String passwd,
new PlainClientbackHandler(userName, passwd), underlyingTransport);
}
+ public static class Factory extends TTransportFactory {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(TSaslServerTransport.class);
+
+ private final int saslMessageLimit;
+
+ public Factory(int saslMessageLimit) {
+ this.saslMessageLimit = saslMessageLimit;
+ }
+
+ private static class TSaslServerDefinition {
+ public String mechanism;
+ public String protocol;
+ public String serverName;
+      public Map<String, String> props;
+ public CallbackHandler cbh;
+
+ public TSaslServerDefinition(String mechanism, String protocol, String serverName,
+          Map<String, String> props, CallbackHandler cbh) {
+ this.mechanism = mechanism;
+ this.protocol = protocol;
+ this.serverName = serverName;
+ this.props = props;
+ this.cbh = cbh;
+ }
+ }
+
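+    // One SASL transport is cached per underlying transport, mirroring Thrift's TSaslServerTransport.Factory;
+    // the weak references let an entry vanish once its connection has been garbage collected.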
+    private static Map<TTransport, WeakReference<TSaslServerTransport>> transportMap =
+        Collections.synchronizedMap(
+            new WeakHashMap<TTransport, WeakReference<TSaslServerTransport>>());
+
+    private Map<String, TSaslServerDefinition> serverDefinitionMap =
+        new HashMap<String, TSaslServerDefinition>();
+
+ public void addServerDefinition(String mechanism, String protocol, String serverName,
+        Map<String, String> props, CallbackHandler cbh) {
+ serverDefinitionMap.put(mechanism, new TSaslServerDefinition(mechanism, protocol, serverName,
+ props, cbh));
+ }
+
+ @Override
+ public TTransport getTransport(TTransport base) {
+      WeakReference<TSaslServerTransport> ret = transportMap.get(base);
+ TSaslServerTransport transport = ret == null ? null : ret.get();
+ if (transport == null) {
+ LOGGER.debug("transport map does not contain key {}", base);
+ transport = newSaslTransport(base);
+ try {
+ transport.open();
+ } catch (TTransportException e) {
+ LOGGER.debug("failed to open server transport", e);
+ throw new RuntimeException(e);
+ }
+        transportMap.put(base, new WeakReference<TSaslServerTransport>(transport));
+ } else {
+ LOGGER.debug("transport map does contain key {}", base);
+ }
+ return transport;
+ }
+
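+    // Builds an anonymous TSaslServerTransport whose receiveSaslMessage() enforces the configured size limit.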
+ private TSaslServerTransport newSaslTransport(final TTransport base) {
+ TSaslServerTransport transport = new TSaslServerTransport(base) {
+
+ private final byte[] messageHeader = new byte[STATUS_BYTES + PAYLOAD_LENGTH_BYTES];
+
+ @Override
+ protected SaslResponse receiveSaslMessage() throws TTransportException {
+ underlyingTransport.readAll(messageHeader, 0, messageHeader.length);
+
+ byte statusByte = messageHeader[0];
+ int length = EncodingUtils.decodeBigEndian(messageHeader, STATUS_BYTES);
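+          // Validate the declared length before allocating the payload buffer, so a hostile
+          // header cannot force a huge (or negative) allocation.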
+          if (length < 0 || length > saslMessageLimit) {
+            base.close();
+            throw new TTransportException("SASL message is too big (" + length + " bytes)");
+ }
+ byte[] payload = new byte[length];
+ underlyingTransport.readAll(payload, 0, payload.length);
+
+ NegotiationStatus status = NegotiationStatus.byValue(statusByte);
+ if (status == null) {
+ sendAndThrowMessage(NegotiationStatus.ERROR, "Invalid status " + statusByte);
+ } else if (status == NegotiationStatus.BAD || status == NegotiationStatus.ERROR) {
+ try {
+ String remoteMessage = new String(payload, "UTF-8");
+ throw new TTransportException("Peer indicated failure: " + remoteMessage);
+ } catch (UnsupportedEncodingException e) {
+ throw new TTransportException(e);
+ }
+ }
+          if (LOGGER.isDebugEnabled()) {
+            LOGGER.debug(getRole() + ": Received message with status {} and payload length {}",
+                status, payload.length);
+          }
+ return new SaslResponse(status, payload);
+ }
+ };
+      for (Map.Entry<String, TSaslServerDefinition> entry : serverDefinitionMap.entrySet()) {
+ TSaslServerDefinition definition = entry.getValue();
+ transport.addServerDefinition(entry.getKey(),
+ definition.protocol, definition.serverName, definition.props, definition.cbh);
+ }
+ return transport;
+ }
+ }
}