diff --git a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java index 803d52b..e7402f2 100644 --- a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java +++ b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java @@ -1927,6 +1927,10 @@ public void setSparkConfigUpdated(boolean isSparkConfigUpdated) { "HttpOnly attribute of the HS2 generated cookie."), // binary transport settings + HIVE_SERVER2_THRIFT_AUTH_MAX_RETRIES("hive.server2.thrift.auth.max.retries", 1, + "Number of maximum retries to authenticate HS2 server or HMS server against Kerberos service.\n" + + "This is to mitigate some false alarm auth issues, such that concurrent query executions\n" + + "against single HS2 server may fail to authenticate due to 'Request is a replay'."), HIVE_SERVER2_THRIFT_PORT("hive.server2.thrift.port", 10000, "Port number of HiveServer2 Thrift interface when hive.server2.transport.mode is 'binary'."), HIVE_SERVER2_THRIFT_SASL_QOP("hive.server2.thrift.sasl.qop", "auth", diff --git a/itests/hive-unit-hadoop2/src/test/java/org/apache/hadoop/hive/thrift/TestHadoopAuthBridge23.java b/itests/hive-unit-hadoop2/src/test/java/org/apache/hadoop/hive/thrift/TestHadoopAuthBridge23.java index ff56f80..4c981dc 100644 --- a/itests/hive-unit-hadoop2/src/test/java/org/apache/hadoop/hive/thrift/TestHadoopAuthBridge23.java +++ b/itests/hive-unit-hadoop2/src/test/java/org/apache/hadoop/hive/thrift/TestHadoopAuthBridge23.java @@ -19,6 +19,7 @@ import junit.framework.TestCase; + import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.conf.HiveConf; @@ -41,8 +42,13 @@ import org.apache.hadoop.security.token.delegation.DelegationKey; import org.apache.hadoop.util.StringUtils; import org.apache.thrift.transport.TSaslServerTransport; +import org.apache.thrift.transport.TTransport; import org.apache.thrift.transport.TTransportException; import 
org.apache.thrift.transport.TTransportFactory; +import org.junit.Test; +import org.mockito.Mockito; + +import static org.mockito.Mockito.*; import java.io.ByteArrayInputStream; import java.io.DataInputStream; @@ -78,7 +84,7 @@ public Server() throws TTransportException { super(); } @Override - public TTransportFactory createTransportFactory(Map saslProps) + public TTransportFactory createTransportFactory(Map saslProps, int authMaxRetries) throws TTransportException { TSaslServerTransport.Factory transFactory = new TSaslServerTransport.Factory(); @@ -87,7 +93,7 @@ public TTransportFactory createTransportFactory(Map saslProps) saslProps, new SaslDigestCallbackHandler(secretManager)); - return new TUGIAssumingTransportFactory(transFactory, realUgi); + return new TUGIAssumingTransportFactory(transFactory, realUgi, authMaxRetries); } static DelegationTokenStore TOKEN_STORE = new MemoryTokenStore(); @@ -234,6 +240,51 @@ public void testSaslWithHiveMetaStore() throws Exception { obtainTokenAndAddIntoUGI(clientUgi, "tokenForFooTablePartition"); } + /** + * Verifies that the expected result returned after 2 unsuccessful retries + * @throws Exception + */ + @Test + public void testRetryGetTransport() throws Exception { + TTransport inputTransport = Mockito.mock(TTransport.class); + TTransport expectedTransport = Mockito.mock(TTransport.class); + TTransportFactory mockWrapped = Mockito.mock(TTransportFactory.class); + UserGroupInformation ugi = UserGroupInformation.getCurrentUser(); + when(mockWrapped.getTransport(any(TTransport.class))) + .thenThrow(new RuntimeException(new TTransportException())) + .thenThrow(new RuntimeException(new TTransportException())) + .thenReturn(expectedTransport); + + TTransportFactory factory = new HadoopThriftAuthBridge.Server.TUGIAssumingTransportFactory(mockWrapped, ugi, 3); + TTransport transport = factory.getTransport(inputTransport); + + assertEquals(expectedTransport, transport); + verify(mockWrapped, 
times(3)).getTransport(any(TTransport.class)); + } + + /** + * Verifies exception is thrown after 3 unsuccessful retries + * @throws Exception + */ + @Test + public void testRetryGetTransport2() throws Exception { + Exception expectedException = new RuntimeException(new TTransportException()); + TTransport inputTransport = Mockito.mock(TTransport.class); + TTransportFactory mockWrapped = Mockito.mock(TTransportFactory.class); + UserGroupInformation ugi = UserGroupInformation.getCurrentUser(); + when(mockWrapped.getTransport(any(TTransport.class))) + .thenThrow(expectedException); + + try { + TTransportFactory factory = new HadoopThriftAuthBridge.Server.TUGIAssumingTransportFactory(mockWrapped, ugi, 3); + factory.getTransport(inputTransport); fail("getTransport() should have thrown after exhausting retries"); + } catch(Exception e) { + assertEquals(expectedException, e); + } finally { + verify(mockWrapped, times(3)).getTransport(any(TTransport.class)); + } + } + public void testMetastoreProxyUser() throws Exception { setup(); diff --git a/metastore/src/java/org/apache/hadoop/hive/metastore/HiveMetaStore.java b/metastore/src/java/org/apache/hadoop/hive/metastore/HiveMetaStore.java index 00602e1..1dd5e21 100644 --- a/metastore/src/java/org/apache/hadoop/hive/metastore/HiveMetaStore.java +++ b/metastore/src/java/org/apache/hadoop/hive/metastore/HiveMetaStore.java @@ -6003,8 +6003,9 @@ public static void startMetaStore(int port, HadoopThriftAuthBridge bridge, conf.getVar(HiveConf.ConfVars.METASTORE_KERBEROS_PRINCIPAL)); // start delegation token manager saslServer.startDelegationTokenSecretManager(conf, baseHandler, ServerMode.METASTORE); + int authMaxRetries = conf.getIntVar(ConfVars.HIVE_SERVER2_THRIFT_AUTH_MAX_RETRIES); transFactory = saslServer.createTransportFactory( - MetaStoreUtils.getMetaStoreSaslProperties(conf)); + MetaStoreUtils.getMetaStoreSaslProperties(conf), authMaxRetries); processor = saslServer.wrapProcessor( new ThriftHiveMetastore.Processor(handler)); LOG.info("Starting DB backed MetaStore Server in Secure Mode"); 
diff --git a/service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java b/service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java index 3471f12..0c7c9ee 100644 --- a/service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java +++ b/service/src/java/org/apache/hive/service/auth/HiveAuthFactory.java @@ -140,7 +140,8 @@ public TTransportFactory getAuthTransFactory() throws LoginException { TTransportFactory transportFactory; if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())) { try { - transportFactory = saslServer.createTransportFactory(getSaslProperties()); + int authMaxRetries = conf.getIntVar(ConfVars.HIVE_SERVER2_THRIFT_AUTH_MAX_RETRIES); + transportFactory = saslServer.createTransportFactory(getSaslProperties(), authMaxRetries); } catch (TTransportException e) { throw new LoginException(e.getMessage()); } diff --git a/shims/common/src/main/java/org/apache/hadoop/hive/thrift/HadoopThriftAuthBridge.java b/shims/common/src/main/java/org/apache/hadoop/hive/thrift/HadoopThriftAuthBridge.java index d2b47be..7d514ab 100644 --- a/shims/common/src/main/java/org/apache/hadoop/hive/thrift/HadoopThriftAuthBridge.java +++ b/shims/common/src/main/java/org/apache/hadoop/hive/thrift/HadoopThriftAuthBridge.java @@ -26,6 +26,7 @@ import java.security.PrivilegedExceptionAction; import java.util.Locale; import java.util.Map; +import java.util.Random; import javax.security.auth.callback.Callback; import javax.security.auth.callback.CallbackHandler; @@ -44,7 +45,6 @@ import org.slf4j.LoggerFactory; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.hive.shims.ShimLoader; import org.apache.hadoop.hive.shims.Utils; import org.apache.hadoop.hive.thrift.client.TUGIAssumingTransport; import org.apache.hadoop.security.SaslRpcServer; @@ -367,7 +367,7 @@ protected Server(String keytabFile, String principalConf) * @param saslProps Map of SASL properties */ - public TTransportFactory 
createTransportFactory(Map saslProps) + public TTransportFactory createTransportFactory(Map saslProps, int authMaxRetries) throws TTransportException { // Parse out the kerberos principal, host, realm. String kerberosName = realUgi.getUserName(); @@ -386,7 +386,7 @@ public TTransportFactory createTransportFactory(Map saslProps) null, SaslRpcServer.SASL_DEFAULT_REALM, saslProps, new SaslDigestCallbackHandler(secretManager)); - return new TUGIAssumingTransportFactory(transFactory, realUgi); + return new TUGIAssumingTransportFactory(transFactory, realUgi, authMaxRetries); } /** @@ -724,12 +724,14 @@ public Boolean run() { static class TUGIAssumingTransportFactory extends TTransportFactory { private final UserGroupInformation ugi; private final TTransportFactory wrapped; + private final int authMaxRetries; - public TUGIAssumingTransportFactory(TTransportFactory wrapped, UserGroupInformation ugi) { + public TUGIAssumingTransportFactory(TTransportFactory wrapped, UserGroupInformation ugi, int authMaxRetries) { assert wrapped != null; assert ugi != null; this.wrapped = wrapped; this.ugi = ugi; + this.authMaxRetries = authMaxRetries; } @@ -738,7 +740,28 @@ public TTransport getTransport(final TTransport trans) { return ugi.doAs(new PrivilegedAction() { @Override public TTransport run() { - return wrapped.getTransport(trans); + // Retry the authentication after sleeping for random milliseconds + int numRetries = 0; + Random rand = new Random(); + + while (true) { + try { + return wrapped.getTransport(trans); + } catch(RuntimeException e) { + if (e.getCause() instanceof TTransportException) { + if (++numRetries < authMaxRetries) { + LOG.warn("Failed to open SASL transport, retrying", e); + try { + Thread.sleep(rand.nextInt(1000)); + } catch (InterruptedException ie) { Thread.currentThread().interrupt(); throw e; + } + continue; + } + } + + throw e; + } + } } }); }