diff --git common/src/java/org/apache/hadoop/hive/conf/HiveConf.java common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
index 5dbd1a5..871e2bc 100644
--- common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
+++ common/src/java/org/apache/hadoop/hive/conf/HiveConf.java
@@ -885,6 +885,8 @@
     HIVE_SERVER2_SSL_KEYSTORE_PATH("hive.server2.keystore.path", ""),
     HIVE_SERVER2_SSL_KEYSTORE_PASSWORD("hive.server2.keystore.password", ""),
 
+    HIVE_SERVER2_SASL_MESSAGE_LIMIT("hive.server2.sasl.message.limit", -1),
+
     HIVE_SECURITY_COMMAND_WHITELIST("hive.security.command.whitelist", "set,reset,dfs,add,delete,compile"),
 
     HIVE_CONF_RESTRICTED_LIST("hive.conf.restricted.list", "hive.security.authenticator.manager,hive.security.authorization.manager"),
diff --git service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java
index d8ba3aa..5da62cf 100644
--- service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java
+++ service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java
@@ -65,24 +65,44 @@ public String getAuthName() {
   };
 
-  private HadoopThriftAuthBridge.Server saslServer = null;
-  private String authTypeStr;
-  HiveConf conf;
+  private final HadoopThriftAuthBridge.Server saslServer;
+  private final String authTypeStr;
+  private final AuthTypes authType;
+  private final int saslMessageLimit;
+  private final HiveConf conf;
 
   public HiveAuthFactory() throws TTransportException {
     conf = new HiveConf();
-    authTypeStr = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_AUTHENTICATION);
-    if (authTypeStr == null) {
-      authTypeStr = AuthTypes.NONE.getAuthName();
+    String authConf = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_AUTHENTICATION);
+    if (authConf == null) {
+      authType = AuthTypes.NONE;
+      authTypeStr = authType.getAuthName();
+    } else {
+      authType = AuthTypes.valueOf(authConf.toUpperCase());
+      authTypeStr = authConf;
     }
-    if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())
+    if (authType == AuthTypes.KERBEROS
         && ShimLoader.getHadoopShims().isSecureShimImpl()) {
       saslServer = ShimLoader.getHadoopThriftAuthBridge().createServer(
-        conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB),
-        conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL)
-      );
+          conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB),
+          conf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL)
+          );
+    } else {
+      saslServer = null;
     }
+    saslMessageLimit = conf.getIntVar(ConfVars.HIVE_SERVER2_SASL_MESSAGE_LIMIT);
+  }
+
+  public TTransportFactory getAuthTransFactory() throws LoginException {
+    if (authType == AuthTypes.KERBEROS) {
+      try {
+        return saslServer.createTransportFactory(getSaslProperties());
+      } catch (TTransportException e) {
+        throw new LoginException(e.getMessage());
+      }
+    }
+    return PlainSaslHelper.getPlainTransportFactory(authTypeStr, saslMessageLimit);
   }
 
   public Map<String, String> getSaslProperties() {
@@ -104,33 +124,9 @@ public HiveAuthFactory() throws TTransportException {
     return saslProps;
   }
 
-  public TTransportFactory getAuthTransFactory() throws LoginException {
-
-    TTransportFactory transportFactory;
-
-    if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())) {
-      try {
-        transportFactory = saslServer.createTransportFactory(getSaslProperties());
-      } catch (TTransportException e) {
-        throw new LoginException(e.getMessage());
-      }
-    } else if (authTypeStr.equalsIgnoreCase(AuthTypes.NONE.getAuthName())) {
-      transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
-    } else if (authTypeStr.equalsIgnoreCase(AuthTypes.LDAP.getAuthName())) {
-      transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
-    } else if (authTypeStr.equalsIgnoreCase(AuthTypes.NOSASL.getAuthName())) {
-      transportFactory = new TTransportFactory();
-    } else if (authTypeStr.equalsIgnoreCase(AuthTypes.CUSTOM.getAuthName())) {
-      transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
-    } else {
-      throw new LoginException("Unsupported authentication type " + authTypeStr);
-    }
-    return transportFactory;
-  }
-
   public TProcessorFactory getAuthProcFactory(ThriftCLIService service) throws LoginException {
-    if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())) {
+    if (authType == AuthTypes.KERBEROS) {
       return KerberosSaslHelper.getKerberosProcessorFactory(saslServer, service);
     } else {
       return PlainSaslHelper.getPlainProcessorFactory(service);
diff --git service/src/java/org/apache/hive/service/auth/PlainSaslHelper.java service/src/java/org/apache/hive/service/auth/PlainSaslHelper.java
index 15b1675..3c9ae2c 100644
--- service/src/java/org/apache/hive/service/auth/PlainSaslHelper.java
+++ service/src/java/org/apache/hive/service/auth/PlainSaslHelper.java
@@ -18,7 +18,12 @@
 package org.apache.hive.service.auth;
 
 import java.io.IOException;
+import java.io.UnsupportedEncodingException;
+import java.lang.ref.WeakReference;
+import java.util.Collections;
 import java.util.HashMap;
+import java.util.Map;
+import java.util.WeakHashMap;
 
 import javax.security.auth.callback.Callback;
 import javax.security.auth.callback.CallbackHandler;
@@ -34,12 +39,16 @@
 import org.apache.hive.service.cli.thrift.TCLIService;
 import org.apache.hive.service.cli.thrift.TCLIService.Iface;
 import org.apache.hive.service.cli.thrift.ThriftCLIService;
+import org.apache.thrift.EncodingUtils;
 import org.apache.thrift.TProcessor;
 import org.apache.thrift.TProcessorFactory;
 import org.apache.thrift.transport.TSaslClientTransport;
 import org.apache.thrift.transport.TSaslServerTransport;
 import org.apache.thrift.transport.TTransport;
+import org.apache.thrift.transport.TTransportException;
 import org.apache.thrift.transport.TTransportFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 public class PlainSaslHelper {
 
@@ -123,7 +132,14 @@ public static TProcessorFactory getPlainProcessorFactory(ThriftCLIService servic
     java.security.Security.addProvider(new SaslPlainProvider());
   }
 
-  public static TTransportFactory getPlainTransportFactory(String authTypeStr) {
+  public static TTransportFactory getPlainTransportFactory(String authTypeStr, int saslMessageLimit) {
+    if (saslMessageLimit > 0) {
+      PlainSaslHelper.Factory saslFactory = new PlainSaslHelper.Factory(saslMessageLimit);
+      saslFactory.addServerDefinition("PLAIN",
+          authTypeStr, null, new HashMap<String, String>(),
+          new PlainServerCallbackHandler());
+      return saslFactory;
+    }
     TSaslServerTransport.Factory saslFactory = new TSaslServerTransport.Factory();
     saslFactory.addServerDefinition("PLAIN",
         authTypeStr, null, new HashMap<String, String>(),
@@ -138,4 +154,107 @@ public static TTransport getPlainTransport(String userName, String passwd,
         new PlainClientbackHandler(userName, passwd), underlyingTransport);
   }
 
+  public static class Factory extends TTransportFactory {
+
+    private static final Logger LOGGER = LoggerFactory.getLogger(TSaslServerTransport.class);
+
+    private final int saslMessageLimit;
+
+    public Factory(int saslMessageLimit) {
+      this.saslMessageLimit = saslMessageLimit;
+    }
+
+    private static class TSaslServerDefinition {
+      public String mechanism;
+      public String protocol;
+      public String serverName;
+      public Map<String, String> props;
+      public CallbackHandler cbh;
+
+      public TSaslServerDefinition(String mechanism, String protocol, String serverName,
+          Map<String, String> props, CallbackHandler cbh) {
+        this.mechanism = mechanism;
+        this.protocol = protocol;
+        this.serverName = serverName;
+        this.props = props;
+        this.cbh = cbh;
+      }
+    }
+
+    private static Map<TTransport, WeakReference<TSaslServerTransport>> transportMap =
+        Collections.synchronizedMap(
+            new WeakHashMap<TTransport, WeakReference<TSaslServerTransport>>());
+
+    private Map<String, TSaslServerDefinition> serverDefinitionMap =
+        new HashMap<String, TSaslServerDefinition>();
+
+    public void addServerDefinition(String mechanism, String protocol, String serverName,
+        Map<String, String> props, CallbackHandler cbh) {
+      serverDefinitionMap.put(mechanism, new TSaslServerDefinition(mechanism, protocol, serverName,
+          props, cbh));
+    }
+
+    @Override
+    public TTransport getTransport(TTransport base) {
+      WeakReference<TSaslServerTransport> ret = transportMap.get(base);
+      TSaslServerTransport transport = ret == null ? null : ret.get();
+      if (transport == null) {
+        LOGGER.debug("transport map does not contain key {}", base);
+        transport = newSaslTransport(base);
+        try {
+          transport.open();
+        } catch (TTransportException e) {
+          LOGGER.debug("failed to open server transport", e);
+          throw new RuntimeException(e);
+        }
+        transportMap.put(base, new WeakReference<TSaslServerTransport>(transport));
+      } else {
+        LOGGER.debug("transport map does contain key {}", base);
+      }
+      return transport;
+    }
+
+    private TSaslServerTransport newSaslTransport(final TTransport base) {
+      TSaslServerTransport transport = new TSaslServerTransport(base) {
+
+        private final byte[] messageHeader = new byte[STATUS_BYTES + PAYLOAD_LENGTH_BYTES];
+
+        @Override
+        protected SaslResponse receiveSaslMessage() throws TTransportException {
+          underlyingTransport.readAll(messageHeader, 0, messageHeader.length);
+
+          byte statusByte = messageHeader[0];
+          int length = EncodingUtils.decodeBigEndian(messageHeader, STATUS_BYTES);
+          if (length > saslMessageLimit) {
+            base.close();
+            throw new TTransportException("Sasl message is too big (" + length + " bytes)");
+          }
+          byte[] payload = new byte[length];
+          underlyingTransport.readAll(payload, 0, payload.length);
+
+          NegotiationStatus status = NegotiationStatus.byValue(statusByte);
+          if (status == null) {
+            sendAndThrowMessage(NegotiationStatus.ERROR, "Invalid status " + statusByte);
+          } else if (status == NegotiationStatus.BAD || status == NegotiationStatus.ERROR) {
+            try {
+              String remoteMessage = new String(payload, "UTF-8");
+              throw new TTransportException("Peer indicated failure: " + remoteMessage);
+            } catch (UnsupportedEncodingException e) {
+              throw new TTransportException(e);
+            }
+          }
+          if (LOGGER.isDebugEnabled()) {
+            LOGGER.debug(getRole() + ": Received message with status {} and payload length {}",
+                status, payload.length);
+          }
+          return new SaslResponse(status, payload);
+        }
+      };
+      for (Map.Entry<String, TSaslServerDefinition> entry : serverDefinitionMap.entrySet()) {
+        TSaslServerDefinition definition = entry.getValue();
+        transport.addServerDefinition(entry.getKey(),
+            definition.protocol, definition.serverName, definition.props, definition.cbh);
+      }
+      return transport;
+    }
+  }
 }
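
Background on the mechanism, for review: a SASL negotiation frame is one status byte followed by a 4-byte big-endian payload length, and that length is supplied by a peer that has not yet authenticated. The overridden receiveSaslMessage() above therefore checks the advertised length against hive.server2.sasl.message.limit before allocating the payload buffer; with the default of -1 the new code path is skipped and the stock TSaslServerTransport.Factory is kept. Below is a minimal, self-contained sketch of that ordering under the same wire layout; the SaslFrameLimitDemo class, its LIMIT constant, and the main() scenario are hypothetical illustrations, not code from this patch.

// Hypothetical demo -- not part of the patch. It reproduces the SASL frame
// layout the patch parses (1 status byte + 4-byte big-endian payload length)
// and shows why the length must be vetted before the payload is allocated.
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;

public class SaslFrameLimitDemo {

  // Stand-in for hive.server2.sasl.message.limit (here: 1 MB).
  private static final int LIMIT = 1024 * 1024;

  // Builds a negotiation frame the way a SASL peer would.
  static byte[] frame(byte status, byte[] payload) throws IOException {
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    out.write(status);
    out.write(ByteBuffer.allocate(4).putInt(payload.length).array());
    out.write(payload);
    return out.toByteArray();
  }

  // Mirrors the check in the patched receiveSaslMessage(): read the header,
  // reject an oversized length, and only then allocate the payload buffer.
  static void receive(byte[] frame) {
    int length = ByteBuffer.wrap(frame, 1, 4).getInt();
    if (length > LIMIT) {
      throw new RuntimeException("Sasl message is too big (" + length + " bytes)");
    }
    byte[] payload = new byte[length];
    System.arraycopy(frame, 5, payload, 0, length);
    System.out.println("accepted payload of " + length + " bytes");
  }

  public static void main(String[] args) throws IOException {
    receive(frame((byte) 1, new byte[16]));        // well-formed frame: accepted

    byte[] forged = frame((byte) 1, new byte[0]);
    ByteBuffer.wrap(forged, 1, 4).putInt(Integer.MAX_VALUE); // claim a 2 GB payload
    try {
      receive(forged);
    } catch (RuntimeException e) {
      System.out.println("rejected: " + e.getMessage()); // no 2 GB allocation happened
    }
  }
}

With the patch applied, an operator would opt in by setting a positive byte count (for example, hive.server2.sasl.message.limit=1048576 in hive-site.xml); a frame advertising more than that is rejected with "Sasl message is too big" and the underlying connection is closed before any large buffer is allocated. Note that the guard applies only to the PLAIN-based paths (NONE, LDAP, CUSTOM); the Kerberos path keeps the transport factory provided by HadoopThriftAuthBridge.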