diff --git a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java index 69fda45..a005357 100644 --- a/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java +++ b/common/src/java/org/apache/hadoop/hive/conf/HiveConf.java @@ -1852,6 +1852,9 @@ public void setSparkConfigUpdated(boolean isSparkConfigUpdated) { new TimeValidator(TimeUnit.MILLISECONDS), "Time that HiveServer2 will wait before responding to asynchronous calls that use long polling"), + HIVE_SESSION_IMPL_CLASSNAME("hive.session.impl.classname", null, "Classname for custom implementation of hive session"), + HIVE_SESSION_IMPL_WITH_UGI_CLASSNAME("hive.session.impl.withugi.classname", null, "Classname for custom implementation of hive session with UGI"), + // HiveServer2 auth configuration HIVE_SERVER2_AUTHENTICATION("hive.server2.authentication", "NONE", new StringSet("NOSASL", "NONE", "LDAP", "KERBEROS", "PAM", "CUSTOM"), diff --git a/service/src/java/org/apache/hive/service/cli/session/SessionManager.java b/service/src/java/org/apache/hive/service/cli/session/SessionManager.java index 36a30b1..01927ea 100644 --- a/service/src/java/org/apache/hive/service/cli/session/SessionManager.java +++ b/service/src/java/org/apache/hive/service/cli/session/SessionManager.java @@ -20,6 +20,7 @@ import java.io.File; import java.io.IOException; +import java.lang.reflect.Constructor; import java.util.ArrayList; import java.util.Date; import java.util.List; @@ -67,6 +68,8 @@ private volatile boolean shutdown; // The HiveServer2 instance running this service private final HiveServer2 hiveServer2; + private String sessionImplWithUGIclassName; + private String sessionImplclassName; public SessionManager(HiveServer2 hiveServer2) { super(SessionManager.class.getSimpleName()); @@ -82,9 +85,15 @@ public synchronized void init(HiveConf hiveConf) { } createBackgroundOperationPool(); addService(operationManager); + initSessionImplClassName(); super.init(hiveConf); } + private void initSessionImplClassName() { + this.sessionImplclassName = hiveConf.getVar(ConfVars.HIVE_SESSION_IMPL_CLASSNAME); + this.sessionImplWithUGIclassName = hiveConf.getVar(ConfVars.HIVE_SESSION_IMPL_WITH_UGI_CLASSNAME); + } + private void createBackgroundOperationPool() { int poolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS); LOG.info("HiveServer2: Background operation thread pool size: " + poolSize); @@ -245,12 +254,35 @@ public SessionHandle openSession(TProtocolVersion protocol, String username, Str // If doAs is set to true for HiveServer2, we will create a proxy object for the session impl. // Within the proxy object, we wrap the method call in a UserGroupInformation#doAs if (withImpersonation) { - HiveSessionImplwithUGI sessionWithUGI = new HiveSessionImplwithUGI(protocol, username, password, - hiveConf, ipAddress, delegationToken); - session = HiveSessionProxy.getProxy(sessionWithUGI, sessionWithUGI.getSessionUgi()); - sessionWithUGI.setProxySession(session); + HiveSessionImplwithUGI hiveSessionUgi; + if (sessionImplWithUGIclassName == null) { + hiveSessionUgi = new HiveSessionImplwithUGI(protocol, username, password, + hiveConf, ipAddress, delegationToken); + } else { + try { + Class clazz = Class.forName(sessionImplWithUGIclassName); + Constructor constructor = clazz.getConstructor(String.class, String.class, Map.class, String.class); + hiveSessionUgi = (HiveSessionImplwithUGI) constructor.newInstance(new Object[] + {username, password, sessionConf, delegationToken}); + } catch (Exception e) { + throw new HiveSQLException("Cannot initilize session class:" + sessionImplWithUGIclassName); + } + } + session = HiveSessionProxy.getProxy(hiveSessionUgi, hiveSessionUgi.getSessionUgi()); + hiveSessionUgi.setProxySession(session); } else { - session = new HiveSessionImpl(protocol, username, password, hiveConf, ipAddress); + if (sessionImplclassName == null) { + session = new HiveSessionImpl(protocol, username, password, hiveConf, ipAddress); + } else { + try { + Class clazz = Class.forName(sessionImplclassName); + Constructor constructor = clazz.getConstructor(String.class, String.class, Map.class); + session = (HiveSession) constructor.newInstance(new Object[] + {username, password, sessionConf}); + } catch (Exception e) { + throw new HiveSQLException("Cannot initilize session class:" + sessionImplclassName); + } + } } session.setSessionManager(this); session.setOperationManager(operationManager); diff --git a/service/src/test/org/apache/hive/service/cli/session/TestPluggableHiveSessionImpl.java b/service/src/test/org/apache/hive/service/cli/session/TestPluggableHiveSessionImpl.java new file mode 100644 index 0000000..b3cace9 --- /dev/null +++ b/service/src/test/org/apache/hive/service/cli/session/TestPluggableHiveSessionImpl.java @@ -0,0 +1,220 @@ +package org.apache.hive.service.cli.session; + +import junit.framework.Assert; +import junit.framework.TestCase; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.IMetaStoreClient; +import org.apache.hadoop.hive.ql.session.SessionState; +import org.apache.hive.service.auth.HiveAuthFactory; +import org.apache.hive.service.cli.*; +import org.apache.hive.service.cli.operation.OperationManager; +import org.apache.hive.service.cli.thrift.EmbeddedThriftBinaryCLIService; +import org.apache.hive.service.cli.thrift.TProtocolVersion; +import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.util.List; +import java.util.Map; + +public class TestPluggableHiveSessionImpl extends TestCase { + + private HiveConf hiveConf; + private ThriftCLIServiceClient client; + private EmbeddedThriftBinaryCLIService service; + + @Override + @Before + public void setUp() { + hiveConf = new HiveConf(); + hiveConf.setVar(HiveConf.ConfVars.HIVE_SESSION_IMPL_CLASSNAME, TestHiveSessionImpl.class.getName()); + service = new EmbeddedThriftBinaryCLIService(); + service.init(hiveConf); + client = new ThriftCLIServiceClient(service); + } + + + @Test + public void testSessionImpl() { + SessionHandle sessionHandle = null; + try { + sessionHandle = client.openSession("tom", "password"); + Assert.assertEquals(TestHiveSessionImpl.class.getName(), + service.getHiveConf().getVar(HiveConf.ConfVars.HIVE_SESSION_IMPL_CLASSNAME)); + client.closeSession(sessionHandle); + } catch (HiveSQLException e) { + e.printStackTrace(); + } + } + + class TestHiveSessionImpl implements HiveSession { + + @Override public void open(Map sessionConfMap) throws Exception { + + } + + @Override public IMetaStoreClient getMetaStoreClient() throws HiveSQLException { + return null; + } + + @Override public GetInfoValue getInfo(GetInfoType getInfoType) throws HiveSQLException { + return null; + } + + @Override public OperationHandle executeStatement(String statement, Map confOverlay) + throws HiveSQLException { + return null; + } + + @Override public OperationHandle executeStatementAsync(String statement, Map confOverlay) + throws HiveSQLException { + return null; + } + + @Override public OperationHandle getTypeInfo() throws HiveSQLException { + return null; + } + + @Override public OperationHandle getCatalogs() throws HiveSQLException { + return null; + } + + @Override public OperationHandle getSchemas(String catalogName, String schemaName) throws HiveSQLException { + return null; + } + + @Override + public OperationHandle getTables(String catalogName, String schemaName, String tableName, List tableTypes) + throws HiveSQLException { + return null; + } + + @Override public OperationHandle getTableTypes() throws HiveSQLException { + return null; + } + + @Override + public OperationHandle getColumns(String catalogName, String schemaName, String tableName, String columnName) + throws HiveSQLException { + return null; + } + + @Override public OperationHandle getFunctions(String catalogName, String schemaName, String functionName) + throws HiveSQLException { + return null; + } + + @Override public void close() throws HiveSQLException { + + } + + @Override public void cancelOperation(OperationHandle opHandle) throws HiveSQLException { + + } + + @Override public void closeOperation(OperationHandle opHandle) throws HiveSQLException { + + } + + @Override public TableSchema getResultSetMetadata(OperationHandle opHandle) throws HiveSQLException { + return null; + } + + @Override + public RowSet fetchResults(OperationHandle opHandle, FetchOrientation orientation, long maxRows, + FetchType fetchType) + throws HiveSQLException { + return null; + } + + @Override public String getDelegationToken(HiveAuthFactory authFactory, String owner, String renewer) + throws HiveSQLException { + return null; + } + + @Override public void cancelDelegationToken(HiveAuthFactory authFactory, String tokenStr) throws HiveSQLException { + + } + + @Override public void renewDelegationToken(HiveAuthFactory authFactory, String tokenStr) throws HiveSQLException { + + } + + @Override public void closeExpiredOperations() { + + } + + @Override public long getNoOperationTime() { + return 0; + } + + @Override public TProtocolVersion getProtocolVersion() { + return null; + } + + @Override public void setSessionManager(SessionManager sessionManager) { + + } + + @Override public SessionManager getSessionManager() { + return null; + } + + @Override public void setOperationManager(OperationManager operationManager) { + + } + + @Override public boolean isOperationLogEnabled() { + return false; + } + + @Override public File getOperationLogSessionDir() { + return null; + } + + @Override public void setOperationLogSessionDir(File operationLogRootDir) { + + } + + @Override public SessionHandle getSessionHandle() { + return null; + } + + @Override public String getUsername() { + return null; + } + + @Override public String getPassword() { + return null; + } + + @Override public HiveConf getHiveConf() { + return null; + } + + @Override public SessionState getSessionState() { + return null; + } + + @Override public String getUserName() { + return null; + } + + @Override public void setUserName(String userName) { + + } + + @Override public String getIpAddress() { + return null; + } + + @Override public void setIpAddress(String ipAddress) { + + } + + @Override public long getLastAccessTime() { + return 0; + } + } +}