diff --git a/conf/hbase-site.xml b/conf/hbase-site.xml
index c516ac7..383da86 100644
--- a/conf/hbase-site.xml
+++ b/conf/hbase-site.xml
@@ -21,4 +21,32 @@
*/
-->
+
+ hbase.rootdir
+ file:///home/hbase/hbase-data
+
+
+ hbase.zookeeper.property.dataDir
+ /home/hbase/zookeeper-data
+
+
+ hbase.security.authentication
+ kerberos
+
+
+ hbase.master.kerberos.principal
+ hbase/23a03935850c@EXAMPLE.COM
+
+
+ hbase.regionserver.kerberos.principal
+ hbase/23a03935850c@EXAMPLE.COM
+
+
+ hbase.master.keytab.file
+ /usr/src/hbase/hbase-host.keytab
+
+
+ hbase.regionserver.keytab.file
+ /usr/src/hbase/hbase-host.keytab
+
diff --git a/conf/log4j.properties b/conf/log4j.properties
index 74b13b1..f920377 100644
--- a/conf/log4j.properties
+++ b/conf/log4j.properties
@@ -95,6 +95,8 @@ log4j.appender.asyncconsole.target=System.err
log4j.logger.org.apache.zookeeper=INFO
#log4j.logger.org.apache.hadoop.fs.FSNamesystem=DEBUG
log4j.logger.org.apache.hadoop.hbase=INFO
+log4j.logger.org.apache.hadoop.hbase.security=DEBUG
+log4j.logger.org.apache.hadoop.hbase.ipc=DEBUG
log4j.logger.org.apache.hadoop.hbase.META=INFO
# Make these two classes INFO-level. Make them DEBUG to see more zk debug.
log4j.logger.org.apache.hadoop.hbase.zookeeper.ZKUtil=INFO
diff --git a/conf/regionservers b/conf/regionservers
index 2fbb50c..e69de29 100644
--- a/conf/regionservers
+++ b/conf/regionservers
@@ -1 +0,0 @@
-localhost
diff --git a/hbase-native-client/connection/BUCK b/hbase-native-client/connection/BUCK
index 19536d5..696dab6 100644
--- a/hbase-native-client/connection/BUCK
+++ b/hbase-native-client/connection/BUCK
@@ -22,6 +22,7 @@ cxx_library(
exported_headers=[
"client-dispatcher.h",
"client-handler.h",
+ "sasl-handler.h",
"connection-factory.h",
"connection-pool.h",
"connection-id.h",
@@ -50,8 +51,16 @@ cxx_library(
"//third-party:wangle",
],
compiler_flags=['-Weffc++'],
+ linker_flags = ['-L/usr/local/lib','-lsasl2', '-lkrb5'],
+ exported_linker_flags = ['-L/usr/local/lib','-lsasl2', '-lkrb5'],
+ #linker_flags = ['-L/usr/local/lib','-lsasl2'],
+ #exported_linker_flags = ['-L/usr/local/lib','-lsasl2'],
visibility=['//core/...',],)
cxx_test(
name="connection-pool-test",
srcs=["connection-pool-test.cc",],
deps=[":connection",],)
+cxx_test(
+ name="sasl-test",
+ srcs=["sasl-test.cc",],
+ deps=[":connection",],)
diff --git a/hbase-native-client/connection/client-handler.cc b/hbase-native-client/connection/client-handler.cc
index af84572..2f2645b 100644
--- a/hbase-native-client/connection/client-handler.cc
+++ b/hbase-native-client/connection/client-handler.cc
@@ -36,9 +36,10 @@ using hbase::pb::ResponseHeader;
using hbase::pb::GetResponse;
using google::protobuf::Message;
-ClientHandler::ClientHandler(std::string user_name, std::shared_ptr codec)
+ClientHandler::ClientHandler(std::string user_name, std::shared_ptr codec,
+ std::shared_ptr conf)
: user_name_(user_name),
- serde_(codec),
+ serde_(codec, conf),
once_flag_(std::make_unique()),
resp_msgs_(
make_unique>>(
@@ -103,10 +104,9 @@ Future ClientHandler::write(Context *ctx, std::unique_ptr r) {
// We need to send the header once.
// So use call_once to make sure that only one thread wins this.
std::call_once((*once_flag_), [ctx, this]() {
- auto pre = serde_.Preamble();
auto header = serde_.Header(user_name_);
- pre->appendChain(std::move(header));
- ctx->fireWrite(std::move(pre));
+ //LOG(INFO) << "writing " << user_name_ << " "<< header->length();
+ ctx->fireWrite(std::move(header));
});
// Now store the call id to response.
@@ -115,5 +115,6 @@ Future ClientHandler::write(Context *ctx, std::unique_ptr r) {
VLOG(1) << "Writing RPC Request with call_id:" << r->call_id();
// Send the data down the pipeline.
- return ctx->fireWrite(serde_.Request(r->call_id(), r->method(), r->req_msg().get()));
+ std::unique_ptr iob = serde_.Request(r->call_id(), r->method(), r->req_msg().get());
+ return ctx->fireWrite(std::move(iob));
}
diff --git a/hbase-native-client/connection/client-handler.h b/hbase-native-client/connection/client-handler.h
index afb8e62..f4216f3 100644
--- a/hbase-native-client/connection/client-handler.h
+++ b/hbase-native-client/connection/client-handler.h
@@ -59,7 +59,8 @@ class ClientHandler
* Create the handler
* @param user_name the user name of the user running this process.
*/
- explicit ClientHandler(std::string user_name, std::shared_ptr codec);
+ explicit ClientHandler(std::string user_name, std::shared_ptr codec,
+ std::shared_ptr conf);
/**
* Get bytes from the wire.
diff --git a/hbase-native-client/connection/connection-factory.cc b/hbase-native-client/connection/connection-factory.cc
index 832b00f..044cfde 100644
--- a/hbase-native-client/connection/connection-factory.cc
+++ b/hbase-native-client/connection/connection-factory.cc
@@ -18,21 +18,376 @@
*/
#include "connection/connection-factory.h"
+#include "connection/sasl-handler.h"
#include
#include "connection/client-dispatcher.h"
+#include "connection/client-handler.h"
#include "connection/pipeline.h"
#include "connection/service.h"
+#include "serde/rpc.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include <folly/io/Cursor.h>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <mutex>
+#include <stdexcept>
+#include <string>
using namespace folly;
-using namespace hbase;
+using namespace wangle;
using std::chrono::milliseconds;
using std::chrono::nanoseconds;
+namespace hbase {
+// the following is needed to avoid compilation error due to invalid casting
+// The callbacks below are handed to the Cyrus SASL C library through the generic
+// `sasl_callback_ft` function-pointer type, so they are declared with C language
+// linkage to avoid casting between C++ and C function types.
+extern "C" {
+
+/**
+ * SASL_CB_GETPATH callback: reports the directory to search for SASL plugins.
+ * @param context plugin directory as a NUL-terminated string; when null, *path is
+ *        left untouched.
+ * @param path out-parameter receiving the plugin directory.
+ * @return SASL_BADPARAM if `path` is null, SASL_OK otherwise.
+ */
+int GetPluginPath(void *context, const char **path) {
+  if (path == nullptr) {
+    return SASL_BADPARAM;
+  }
+  const char *searchpath = static_cast<const char *>(context);
+  if (searchpath != nullptr) {
+    *path = searchpath;
+  }
+  return SASL_OK;
+}
+
+/**
+ * Simple-value callback; only SASL_CB_USER is supported.
+ * @param context the value to report, as a NUL-terminated string (may be null).
+ * @param id the SASL callback id being requested.
+ * @param result out-parameter receiving the value.
+ * @param len optional out-parameter receiving the value's length (0 for null).
+ * @return SASL_OK for SASL_CB_USER; SASL_BADPARAM for a null `result` or any other id.
+ */
+int Simple(void *context, int id, const char **result, unsigned *len) {
+  if (result == nullptr) {
+    return SASL_BADPARAM;
+  }
+  const char *value = static_cast<const char *>(context);
+  switch (id) {
+    case SASL_CB_USER:
+      *result = value;
+      if (len != nullptr) {
+        *len = value != nullptr ? static_cast<unsigned>(strlen(value)) : 0;
+      }
+      return SASL_OK;
+    default:
+      return SASL_BADPARAM;
+  }
+}
+
+/**
+ * SASL_CB_LOG callback: forwards SASL library messages to glog. SASL_LOG_ERR is
+ * logged at ERROR; every other priority is logged at INFO.
+ * @return SASL_BADPARAM if `message` is null, SASL_OK otherwise.
+ */
+int SaslLogFn(void *context __attribute__((unused)), int priority, const char *message) {
+  if (message == nullptr) {
+    return SASL_BADPARAM;
+  }
+  if (priority == SASL_LOG_ERR) {
+    LOG(ERROR) << ": SASL " << message;
+  } else {
+    LOG(INFO) << ": SASL " << message;
+  }
+  return SASL_OK;
+}
+
+}  // extern "C"
+
+// Stores the SASL target host, the client user name and the SASL service name;
+// no SASL state is created until the first write() triggers saslInit().
+SaslHandler::SaslHandler(std::string host_name, std::string user_name,
+                         std::string service_name)
+    : user_name_(user_name), host_name_(host_name), service_name_(service_name) {}
+
+// Releases the SASL connection context, if one was created.
+SaslHandler::~SaslHandler() {
+  if (sconn_ != nullptr) {
+    sasl_dispose(&sconn_);
+  }
+  sconn_ = nullptr;
+}
+
+/**
+ * Derives the SASL service name from "hbase.regionserver.kerberos.principal".
+ * The service name is the text before the first '/' or, when there is no '/',
+ * before the first '@' (e.g. "hbase/host@REALM" -> "hbase"). Falls back to
+ * "hbase" when `conf` is null, the principal has neither separator, or the
+ * extracted name would be empty.
+ */
+std::string SaslHandler::parseServiceName(std::shared_ptr<Configuration> conf) {
+  static const std::string kDefaultServiceName = "hbase";
+  if (conf == nullptr) {
+    return kDefaultServiceName;
+  }
+  const std::string principal =
+      conf->Get("hbase.regionserver.kerberos.principal", kDefaultServiceName);
+  std::size_t end = principal.find('/');
+  if (end == std::string::npos) {
+    end = principal.find('@');
+  }
+  if (end == std::string::npos || end == 0) {
+    return kDefaultServiceName;
+  }
+  return principal.substr(0, end);
+}
+
+// Inbound path: until authentication has succeeded every read is consumed by the
+// SASL handshake; afterwards the queue is passed straight to the next handler.
+void SaslHandler::read(Context* ctx, folly::IOBufQueue& buf) {
+  LOG(INFO) << "in read, " << auth_state_;
+  if (auth_state_ != kSucc) {
+    setupSaslConn(ctx, buf);
+    return;
+  }
+  ctx->fireRead(buf);
+}
+
+/**
+ * Outbound path.
+ *
+ * The first write starts the SASL handshake via saslInit(). While the handshake
+ * is in progress, outgoing messages are held in iobuf_ (finishAuth() flushes them
+ * once authentication succeeds) and an empty buffer is written in their place.
+ * Once authenticated, messages pass straight through. After a failed handshake
+ * the write fails instead of sending unauthenticated data to the server.
+ */
+folly::Future<folly::Unit> SaslHandler::write(Context* ctx, std::unique_ptr<folly::IOBuf> buf) {
+  LOG(INFO) << "in write, init state: " << auth_state_ << " " << buf->length();
+  if (auth_state_ == kSucc) {
+    return ctx->fireWrite(std::move(buf));
+  }
+  if (auth_state_ == kUninit) {
+    saslInit(ctx);
+  }
+  if (auth_state_ == kFailure) {
+    return folly::makeFuture<folly::Unit>(
+        std::runtime_error("SASL authentication failed; refusing to send data"));
+  }
+  // Hold the message back until authentication completes.
+  iobuf_.push_back(std::move(buf));
+  return ctx->fireWrite(std::make_unique<folly::IOBuf>());
+}
+
+/**
+ * Sends one SASL token to the server, framed as a 4-byte big-endian length
+ * followed by the `outlen` token bytes, and waits for the write's future.
+ * @param out token bytes produced by the SASL library (may be null when outlen is 0).
+ * @param outlen number of token bytes.
+ */
+void SaslHandler::writeSaslOutput(Context* ctx, const char *out, unsigned int outlen) {
+  constexpr std::size_t kLengthPrefix = 4;
+  // IOBuf::create() allocates memory the IOBuf releases itself, unlike a new[]
+  // buffer handed to takeOwnership() (which frees with free()).
+  std::unique_ptr<folly::IOBuf> packet = folly::IOBuf::create(kLengthPrefix + outlen);
+  uint8_t *dst = packet->writableData();
+  // Encode the length big-endian explicitly instead of byte-swapping a native int,
+  // which only produced network order on little-endian hosts.
+  dst[0] = static_cast<uint8_t>((outlen >> 24) & 0xFF);
+  dst[1] = static_cast<uint8_t>((outlen >> 16) & 0xFF);
+  dst[2] = static_cast<uint8_t>((outlen >> 8) & 0xFF);
+  dst[3] = static_cast<uint8_t>(outlen & 0xFF);
+  if (outlen > 0) {
+    ::memcpy(dst + kLengthPrefix, out, outlen);
+  }
+  packet->append(kLengthPrefix + outlen);
+  // NOTE(review): get() blocks the calling thread until the write's future is
+  // fulfilled; when called from the pipeline's event base thread this could
+  // deadlock if the write cannot complete immediately — confirm.
+  ctx->fireWrite(std::move(packet)).get();
+}
+// Renders `uiTotalSize` bytes of `temp` for logging: printable ASCII other than
+// a backslash is copied as-is; every other byte becomes a "\xHH" escape using
+// uppercase hex digits.
+std::string SaslHandler::toStringBinary(char* temp, unsigned int uiTotalSize) {
+  static const char kHexDigits[] = "0123456789ABCDEF";
+  std::string rendered;
+  for (const char *cur = temp, *end = temp + uiTotalSize; cur != end; ++cur) {
+    const bool printable = (*cur >= ' ') && (*cur <= '~') && (*cur != '\\');
+    if (printable) {
+      rendered.push_back(*cur);
+      continue;
+    }
+    const unsigned char octet = static_cast<unsigned char>(*cur);
+    rendered.append("\\x");
+    rendered.push_back(kHexDigits[octet >> 4]);
+    rendered.push_back(kHexDigits[octet & 0x0F]);
+  }
+  return rendered;
+}
+/**
+ * Completes the handshake after the SASL library reported SASL_OK.
+ *
+ * If the server has already queued more bytes at this point, the handshake is
+ * treated as failed. Otherwise the messages held back by write() are flushed in
+ * order and the negotiated security strength factor (SSF) is queried.
+ * @return kSucc or kFailure; the current state unchanged if currRC_ != SASL_OK.
+ */
+SaslHandler::InitState SaslHandler::finishAuth(Context* ctx, folly::IOBufQueue& bufQueue) {
+  if (currRC_ != SASL_OK) {
+    return auth_state_;
+  }
+  if (!bufQueue.empty()) {
+    std::unique_ptr<folly::IOBuf> unexpected = bufQueue.pop_front();
+    LOG(ERROR) << "Error in the final step of handshake " << unexpected->length();
+    auth_state_ = kFailure;
+    return auth_state_;
+  }
+
+  auth_state_ = kSucc;
+  LOG(INFO) << "auth succeeded " << iobuf_.size();
+  // Write what we buffered, in the order write() received it. Payload dumps are
+  // verbose-only since they contain the connection header and RPC bodies.
+  for (auto &pending : iobuf_) {
+    VLOG(1) << "writing client msg " << pending->length() << " " <<
+        toStringBinary(const_cast<char *>(reinterpret_cast<const char *>(pending->data())),
+                       pending->length());
+    ctx->fireWrite(std::move(pending));
+  }
+  // Every element has been moved from; drop the now-empty slots.
+  iobuf_.clear();
+
+  const int *ssfp = nullptr;
+  // Use a local result so currRC_ keeps recording the handshake outcome.
+  const int rc = sasl_getprop(sconn_, SASL_SSF, reinterpret_cast<const void **>(&ssfp));
+  if (rc != SASL_OK) {
+    // sasl_errstring() returns the message; its third argument only reports the language.
+    LOG(ERROR) << sasl_errstring(rc, nullptr, nullptr);
+  } else if (ssfp != nullptr && *ssfp > 0) {
+    // NOTE(review): a positive SSF means a SASL security layer (integrity/privacy)
+    // was negotiated, which requires wrapping later traffic with sasl_encode()/
+    // sasl_decode(); this handler does not do that yet.
+    LOG(WARNING) << "SASL negotiated SSF " << *ssfp << " but QOP wrapping is not implemented";
+  }
+  return auth_state_;
+}
+namespace {
+
+// Cyrus SASL keeps the callback list given to sasl_client_init() by pointer, and
+// repeated sasl_client_init() calls do not install new callbacks. The list and
+// the strings its contexts point to therefore need static storage and must be
+// installed exactly once per process.
+std::once_flag g_sasl_init_once;
+int g_sasl_init_rc = SASL_OK;
+std::string g_sasl_plugin_dir;
+std::string g_sasl_user_name;
+sasl_callback_t g_sasl_callbacks[4];
+
+// Initializes the SASL client library once per process and returns the result
+// of sasl_client_init(). Only the `user_name` from the first call is reported for
+// SASL_CB_USER (assumes one client principal per process — confirm).
+int InitSaslClientLibrary(const std::string &user_name) {
+  std::call_once(g_sasl_init_once, [&user_name]() {
+    const char *searchpath = ::getenv("CYRUS_SASL_PLUGINS_DIR");
+    if (searchpath == nullptr || searchpath[0] == '\0') {
+      searchpath = "/usr/lib/sasl2";
+    }
+    g_sasl_plugin_dir = searchpath;
+    g_sasl_user_name = user_name;
+
+    sasl_callback_t *cb = g_sasl_callbacks;
+    cb->id = SASL_CB_LOG;  // log
+    cb->proc = reinterpret_cast<sasl_callback_ft>(&SaslLogFn);
+    cb->context = nullptr;
+    ++cb;
+    cb->id = SASL_CB_GETPATH;  // plugin search path
+    cb->proc = reinterpret_cast<sasl_callback_ft>(&GetPluginPath);
+    cb->context = const_cast<char *>(g_sasl_plugin_dir.c_str());
+    ++cb;
+    cb->id = SASL_CB_USER;  // user
+    cb->proc = reinterpret_cast<sasl_callback_ft>(&Simple);
+    cb->context = const_cast<char *>(g_sasl_user_name.c_str());
+    ++cb;
+    cb->id = SASL_CB_LIST_END;  // termination
+    cb->proc = nullptr;
+    cb->context = nullptr;
+
+    g_sasl_init_rc = sasl_client_init(g_sasl_callbacks);
+  });
+  return g_sasl_init_rc;
+}
+
+}  // namespace
+
+/**
+ * Starts the GSSAPI handshake.
+ *
+ * Initializes the SASL client library (once per process), creates the SASL
+ * connection for service_name_ on host_name_, and sends the client's first token.
+ * @return kRead when the first token was sent; kFailure if any SASL call failed.
+ *         Failures are logged and reported through auth_state_ rather than
+ *         aborting the process.
+ */
+SaslHandler::InitState SaslHandler::saslInit(Context* ctx) {
+  const char *mechlist = "GSSAPI";
+  const char *mechusing = nullptr;
+  const char *out = nullptr;
+  unsigned int outlen = 0;
+
+  int rc = InitSaslClientLibrary(user_name_);
+  LOG(INFO) << "service is " << service_name_ << " " << g_sasl_plugin_dir;
+  if (rc != SASL_OK) {
+    LOG(ERROR) << "Cannot initialize client (" << rc << ") ";
+    auth_state_ = kFailure;
+    return auth_state_;
+  }
+
+  rc = sasl_client_new(service_name_.c_str(), /* the service we are using */
+                       host_name_.c_str(),    /* server host name */
+                       nullptr, nullptr,      /* local/remote IP strings; NULL disables
+                                                 mechanisms that require them */
+                       nullptr,               /* use the global callbacks */
+                       0,                     /* security flags */
+                       &sconn_);
+  if (rc != SASL_OK) {
+    LOG(ERROR) << "Cannot create client (" << rc << ") ";
+    auth_state_ = kFailure;
+    return auth_state_;
+  }
+
+  // sasl_setprop() copies the properties into the connection, so a local suffices.
+  sasl_security_properties_t props;
+  ::memset(&props, 0, sizeof(props));
+  props.max_ssf = 2;
+  sasl_setprop(sconn_, SASL_SEC_PROPS, &props);
+
+  // No handshake read buffer is outstanding yet.
+  pBuf = nullptr;
+
+  LOG(INFO) << "sasl_client_new called " << sconn_ << " " << host_name_ << " " << user_name_;
+
+  // The mechanism may ask to be called again; success is reported as SASL_CONTINUE.
+  do {
+    currRC_ = sasl_client_start(sconn_, mechlist, nullptr, &out, &outlen, &mechusing);
+  } while (currRC_ == SASL_INTERACT);
+  if (currRC_ != SASL_CONTINUE) {
+    LOG(ERROR) << "Cannot start client (" << currRC_ << ") " << sasl_errdetail(sconn_);
+    auth_state_ = kFailure;
+    return auth_state_;
+  }
+  LOG(INFO) << "client started " << currRC_ << " " << outlen;
+  writeSaslOutput(ctx, out, outlen);
+  auth_state_ = kRead;
+  return auth_state_;
+}
+
+/**
+ * Handles one server reply during the SASL handshake.
+ *
+ * A reply is a 4-byte status (0 means success), a 4-byte big-endian length, and
+ * that many payload bytes. On success the payload is the server's SASL token;
+ * on failure the payload is logged as the error.
+ *
+ * Exactly one complete reply is consumed from `bufQueue`. A partial reply is
+ * left queued until the next read() delivers the rest.
+ * @return the current state while waiting or continuing, kFailure on any error,
+ *         or the result of finishAuth() once the SASL library reports SASL_OK.
+ */
+SaslHandler::InitState SaslHandler::setupSaslConn(Context* ctx, folly::IOBufQueue& bufQueue) {
+  if (currRC_ == SASL_OK) {
+    return finishAuth(ctx, bufQueue);
+  }
+
+  constexpr std::size_t kReplyHeaderLen = 8;  // status + payload length
+  const folly::IOBuf *head = bufQueue.front();
+  const std::size_t available = head != nullptr ? head->computeChainDataLength() : 0;
+  if (available == 0) {
+    auth_state_ = kFailure;
+    return auth_state_;
+  }
+  if (available < kReplyHeaderLen) {
+    return auth_state_;  // partial header; wait for more bytes
+  }
+
+  // Parse with bounds-checked cursor reads instead of raw pointer casts.
+  folly::io::Cursor cursor(head);
+  const uint32_t status = cursor.readBE<uint32_t>();
+  const uint32_t payload_len = cursor.readBE<uint32_t>();
+  if (available - kReplyHeaderLen < payload_len) {
+    return auth_state_;  // partial payload; wait for more bytes
+  }
+  const std::string payload = cursor.readFixedString(payload_len);
+  bufQueue.trimStart(kReplyHeaderLen + payload_len);
+
+  if (auth_state_ == kFailure || status != 0 /*Status 0 is success*/) {
+    LOG(ERROR) << "Exception from server: " << payload;
+    auth_state_ = kFailure;
+    return auth_state_;
+  }
+
+  const char *out = nullptr;
+  unsigned int outlen = 0;
+  currRC_ = sasl_client_step(sconn_,                                /* our context */
+                             payload.data(),                        /* data from the server */
+                             static_cast<unsigned>(payload.size()), /* its length */
+                             nullptr,                               /* no interaction support */
+                             &out, &outlen);                        /* filled in on success */
+  if (currRC_ == SASL_OK || currRC_ == SASL_CONTINUE) {
+    writeSaslOutput(ctx, out, outlen);
+  }
+  LOG(INFO) << "got rc " << currRC_;
+  if (currRC_ == SASL_OK) {
+    return finishAuth(ctx, bufQueue);
+  }
+  if (currRC_ != SASL_CONTINUE) {
+    // Without this the handshake would sit in kRead and held-back writes would never go out.
+    LOG(ERROR) << "SASL step failed (" << currRC_ << "): " << sasl_errdetail(sconn_);
+    auth_state_ = kFailure;
+  }
+  return auth_state_;
+}
+
+// Reverses the byte order of a 32-bit value in place (0x11223344 -> 0x44332211).
+void SaslHandler::SwapByteOrder(unsigned int &ui) {
+  const unsigned int byte0 = ui & 0x000000FFu;
+  const unsigned int byte1 = (ui >> 8) & 0x000000FFu;
+  const unsigned int byte2 = (ui >> 16) & 0x000000FFu;
+  const unsigned int byte3 = (ui >> 24) & 0x000000FFu;
+  ui = (byte0 << 24) | (byte1 << 16) | (byte2 << 8) | byte3;
+}
+
+// Builds a factory whose default pipelines use `codec`. The codec is also kept in
+// codec_ because Connect() needs it to serialize the connection preamble and to
+// build the ClientHandler it re-adds when reconfiguring a pipeline for SASL.
ConnectionFactory::ConnectionFactory(std::shared_ptr io_pool,
std::shared_ptr codec, nanoseconds connect_timeout)
- : connect_timeout_(connect_timeout),
+ : connect_timeout_(connect_timeout), user_util_(), codec_(codec),
io_pool_(io_pool),
pipeline_factory_(std::make_shared(codec)) {}
@@ -48,15 +403,42 @@ std::shared_ptr> ConnectionFactory::M
}
std::shared_ptr ConnectionFactory::Connect(
std::shared_ptr> client, const std::string &hostname,
- int port) {
+ int port, std::shared_ptr user, std::shared_ptr conf) {
// Yes this will block however it makes dealing with connection pool soooooo
// much nicer.
// TODO see about using shared promise for this.
+ LOG(INFO) << "connecting client " << hostname << " " << port << " " << conf;
auto pipeline = client
- ->connect(SocketAddress(hostname, port, true),
- std::chrono::duration_cast(connect_timeout_))
- .get();
+ ->connect(SocketAddress(hostname, port, true),
+ std::chrono::duration_cast(connect_timeout_))
+ .get();
+  auto event_base = pipeline->getContext()->getTransport()->getEventBase();
+
+  // Every connection starts with the preamble; its auth type depends on `conf`.
+  RpcSerde serde(codec_, conf);
+  auto pre = serde.Preamble();
+  event_base->runInEventBaseThreadAndWait([&]() mutable {
+    pipeline->getContext()->write(std::move(pre));
+  });
+
+  // Security is decided by `conf` alone (User::IsSecurityEnabled is static).
+  // `conf` may be null, e.g. for pools built around an injected factory, in
+  // which case the connection is treated as insecure.
+  (void)user;
+  if (conf != nullptr && User::IsSecurityEnabled(*conf)) {
+    LOG(INFO) << "creating SaslHandler for " << user_util_.user_name();
+    // Hand the pipeline a shared_ptr so it owns the handler; a raw pointer would leak.
+    auto sasl_handler = std::make_shared<SaslHandler>(
+        hostname, user_util_.user_name(), SaslHandler::parseServiceName(conf));
+    event_base->runInEventBaseThreadAndWait([&]() mutable {
+      // Remove two handlers installed by the pipeline factory (presumably the frame
+      // decoder and ClientHandler) and re-add them behind the SASL handler.
+      pipeline->remove(pipeline->getHandler());
+      pipeline->remove(pipeline->getHandler());
+      pipeline->addBack(sasl_handler);
+      pipeline->addBack(LengthFieldBasedFrameDecoder{});
+      pipeline->addBack(ClientHandler{user_util_.user_name(), codec_, conf});
+      pipeline->finalize();
+    });
+  }
auto dispatcher = std::make_shared();
- dispatcher->setPipeline(pipeline);
+  event_base->runInEventBaseThreadAndWait([&]() mutable {
+    dispatcher->setPipeline(pipeline);
+  });
+  LOG(INFO) << "finalized";
return dispatcher;
}
+}
diff --git a/hbase-native-client/connection/connection-factory.h b/hbase-native-client/connection/connection-factory.h
index 32d0bf7..6d524b2 100644
--- a/hbase-native-client/connection/connection-factory.h
+++ b/hbase-native-client/connection/connection-factory.h
@@ -28,6 +28,8 @@
#include "connection/request.h"
#include "connection/response.h"
#include "connection/service.h"
+#include "security/user.h"
+using hbase::security::User;
using std::chrono::nanoseconds;
@@ -61,10 +63,13 @@ class ConnectionFactory {
*/
virtual std::shared_ptr Connect(
std::shared_ptr> client,
- const std::string &hostname, int port);
+ const std::string &hostname, int port, std::shared_ptr user,
+ std::shared_ptr conf);
private:
nanoseconds connect_timeout_;
+ UserUtil user_util_;
+ std::shared_ptr codec_;
std::shared_ptr io_pool_;
std::shared_ptr pipeline_factory_;
};
diff --git a/hbase-native-client/connection/connection-pool-test.cc b/hbase-native-client/connection/connection-pool-test.cc
index 623ce3c..7c5b57f 100644
--- a/hbase-native-client/connection/connection-pool-test.cc
+++ b/hbase-native-client/connection/connection-pool-test.cc
@@ -61,11 +61,11 @@ class MockService : public MockServiceBase {
};
TEST(TestConnectionPool, TestOnlyCreateOnce) {
- auto hostname = std::string{"hostname"};
+ auto hostname = std::string{"localhost"};
auto mock_boot = std::make_shared();
auto mock_service = std::make_shared();
auto mock_cf = std::make_shared();
- uint32_t port{999};
+ uint32_t port{2181};
EXPECT_CALL((*mock_cf), Connect(_, _, _)).Times(1).WillRepeatedly(Return(mock_service));
EXPECT_CALL((*mock_cf), MakeBootstrap()).Times(1).WillRepeatedly(Return(mock_boot));
diff --git a/hbase-native-client/connection/connection-pool.cc b/hbase-native-client/connection/connection-pool.cc
index 4fe4610..64ef90f 100644
--- a/hbase-native-client/connection/connection-pool.cc
+++ b/hbase-native-client/connection/connection-pool.cc
@@ -35,11 +35,12 @@ using folly::SharedMutexWritePriority;
using folly::SocketAddress;
+// Builds a pool backed by a default ConnectionFactory. `conf` is stored in conf_
+// and passed to ConnectionFactory::Connect() for every new connection.
ConnectionPool::ConnectionPool(std::shared_ptr io_executor,
- std::shared_ptr codec, nanoseconds connect_timeout)
+ std::shared_ptr codec, std::shared_ptr conf,
+ nanoseconds connect_timeout)
: cf_(std::make_shared(io_executor, codec, connect_timeout)),
clients_(),
connections_(),
- map_mutex_() {}
+ map_mutex_(), conf_(conf) {}
+// NOTE(review): this constructor does not set conf_, so GetNewConnection() passes
+// a null configuration to Connect() for pools built with an explicit factory.
ConnectionPool::ConnectionPool(std::shared_ptr cf)
: cf_(cf), clients_(), connections_(), map_mutex_() {}
@@ -88,7 +89,8 @@ std::shared_ptr ConnectionPool::GetNewConnection(
/* create new connection */
auto clientBootstrap = cf_->MakeBootstrap();
- auto dispatcher = cf_->Connect(clientBootstrap, remote_id->host(), remote_id->port());
+ auto dispatcher = cf_->Connect(clientBootstrap, remote_id->host(), remote_id->port(),
+ remote_id->user(), conf_);
auto connection = std::make_shared(remote_id, dispatcher);
diff --git a/hbase-native-client/connection/connection-pool.h b/hbase-native-client/connection/connection-pool.h
index 2a8f195..0582d9b 100644
--- a/hbase-native-client/connection/connection-pool.h
+++ b/hbase-native-client/connection/connection-pool.h
@@ -50,7 +50,8 @@ class ConnectionPool {
public:
/** Create connection pool wit default connection factory */
ConnectionPool(std::shared_ptr io_executor,
- std::shared_ptr codec, nanoseconds connect_timeout = nanoseconds(0));
+ std::shared_ptr codec, std::shared_ptr conf,
+ nanoseconds connect_timeout = nanoseconds(0));
/**
* Constructor that allows specifiying the connetion factory.
@@ -93,6 +94,7 @@ class ConnectionPool {
clients_;
folly::SharedMutexWritePriority map_mutex_;
std::shared_ptr cf_;
+ std::shared_ptr conf_;
};
} // namespace hbase
diff --git a/hbase-native-client/connection/pipeline.cc b/hbase-native-client/connection/pipeline.cc
index 00dc05c..71d267a 100644
--- a/hbase-native-client/connection/pipeline.cc
+++ b/hbase-native-client/connection/pipeline.cc
@@ -39,7 +39,7 @@ SerializePipeline::Ptr RpcPipelineFactory::newPipeline(
pipeline->addBack(AsyncSocketHandler{sock});
pipeline->addBack(EventBaseHandler{});
pipeline->addBack(LengthFieldBasedFrameDecoder{});
- pipeline->addBack(ClientHandler{user_util_.user_name(), codec_});
+ pipeline->addBack(ClientHandler{user_util_.user_name(), codec_, nullptr});
pipeline->finalize();
return pipeline;
}
diff --git a/hbase-native-client/connection/rpc-client.cc b/hbase-native-client/connection/rpc-client.cc
index 5fa1138..57df66d 100644
--- a/hbase-native-client/connection/rpc-client.cc
+++ b/hbase-native-client/connection/rpc-client.cc
@@ -30,9 +30,10 @@ using hbase::RpcClient;
namespace hbase {
RpcClient::RpcClient(std::shared_ptr io_executor,
- std::shared_ptr codec, nanoseconds connect_timeout)
- : io_executor_(io_executor) {
- cp_ = std::make_shared(io_executor_, codec, connect_timeout);
+ std::shared_ptr codec, std::shared_ptr conf,
+ nanoseconds connect_timeout)
+ : io_executor_(io_executor), conf_(conf) {
+ cp_ = std::make_shared(io_executor_, codec, conf, connect_timeout);
}
void RpcClient::Close() { io_executor_->stop(); }
diff --git a/hbase-native-client/connection/rpc-client.h b/hbase-native-client/connection/rpc-client.h
index d416ceb..4f02529 100644
--- a/hbase-native-client/connection/rpc-client.h
+++ b/hbase-native-client/connection/rpc-client.h
@@ -46,7 +46,7 @@ namespace hbase {
class RpcClient {
public:
RpcClient(std::shared_ptr io_executor, std::shared_ptr codec,
- nanoseconds connect_timeout = nanoseconds(0));
+ std::shared_ptr conf, nanoseconds connect_timeout = nanoseconds(0));
virtual ~RpcClient() { Close(); }
@@ -78,5 +78,6 @@ class RpcClient {
private:
std::shared_ptr cp_;
std::shared_ptr io_executor_;
+ std::shared_ptr conf_;
};
} // namespace hbase
diff --git a/hbase-native-client/core/async-connection.cc b/hbase-native-client/core/async-connection.cc
index b945e38..41eb48a 100644
--- a/hbase-native-client/core/async-connection.cc
+++ b/hbase-native-client/core/async-connection.cc
@@ -38,7 +38,7 @@ void AsyncConnectionImpl::Init() {
LOG(WARNING) << "Not using RPC Cell Codec";
}
rpc_client_ =
- std::make_shared(io_executor_, codec, connection_conf_->connect_timeout());
+ std::make_shared(io_executor_, codec, conf_, connection_conf_->connect_timeout());
location_cache_ =
std::make_shared(conf_, cpu_executor_, rpc_client_->connection_pool());
caller_factory_ = std::make_shared(shared_from_this());
diff --git a/hbase-native-client/core/client-test.cc b/hbase-native-client/core/client-test.cc
index ff4879a..142d474 100644
--- a/hbase-native-client/core/client-test.cc
+++ b/hbase-native-client/core/client-test.cc
@@ -27,6 +27,7 @@
#include "core/put.h"
#include "core/result.h"
#include "core/table.h"
+#include "security/user.h"
#include "serde/table-name.h"
#include "test-util/test-util.h"
@@ -85,7 +86,8 @@ class ClientTest : public ::testing::Test {
static void SetUpTestCase() {
google::InstallFailureSignalHandler();
test_util = std::make_unique();
- test_util->StartMiniCluster(2);
+ google::InstallFailureSignalHandler();
+ //test_util->StartMiniCluster(2);
}
};
std::unique_ptr ClientTest::test_util = nullptr;
@@ -116,8 +118,11 @@ TEST_F(ClientTest, DefaultConfiguration) {
}
TEST_F(ClientTest, PutGet) {
- // Using TestUtil to populate test data
- ClientTest::test_util->CreateTable("t", "d");
+ // Using TestUtil to populate test data for insecure cluster
+ // table would be created manually for secure cluster
+ if (!User::IsSecurityEnabled(*ClientTest::test_util->conf())) {
+ ClientTest::test_util->CreateTable("t", "d");
+ }
// Create TableName and Row to be fetched from HBase
auto tn = folly::to("t");
diff --git a/hbase-native-client/security/BUCK b/hbase-native-client/security/BUCK
index 7383028..91b547f 100644
--- a/hbase-native-client/security/BUCK
+++ b/hbase-native-client/security/BUCK
@@ -23,4 +23,4 @@ cxx_library(
srcs=[],
deps=["//core:conf"],
compiler_flags=['-Weffc++'],
- visibility=['//core/...', '//connection/...'],)
+ visibility=['PUBLIC',],)
diff --git a/hbase-native-client/security/user.h b/hbase-native-client/security/user.h
index 035af31..af32915 100644
--- a/hbase-native-client/security/user.h
+++ b/hbase-native-client/security/user.h
@@ -21,6 +21,7 @@
#include
#include
#include "core/configuration.h"
+#include
namespace hbase {
namespace security {
@@ -35,6 +36,7 @@ class User {
static std::shared_ptr defaultUser() { return std::make_shared("__drwho"); }
static bool IsSecurityEnabled(const Configuration& conf) {
+ LOG(INFO) << "auth from conf: " << conf.Get("hbase.security.authentication", "");
return conf.Get("hbase.security.authentication", "").compare(kKerberos) == 0;
}
diff --git a/hbase-native-client/serde/BUCK b/hbase-native-client/serde/BUCK
index 38e7b4d..170636d 100644
--- a/hbase-native-client/serde/BUCK
+++ b/hbase-native-client/serde/BUCK
@@ -31,7 +31,8 @@ cxx_library(
"rpc.cc",
"zk.cc",
],
- deps=["//if:if", "//third-party:folly", "//utils:utils"],
+ deps=["//if:if", "//third-party:folly", "//utils:utils",
+ "//security:security"],
tests=[
":client-deserializer-test",
":client-serializer-test",
diff --git a/hbase-native-client/serde/rpc.cc b/hbase-native-client/serde/rpc.cc
index e657a64..1b3a9a5 100644
--- a/hbase-native-client/serde/rpc.cc
+++ b/hbase-native-client/serde/rpc.cc
@@ -30,9 +30,11 @@
#include
#include "if/RPC.pb.h"
+#include "security/user.h"
#include "utils/version.h"
using namespace hbase;
+//using namespace hbase::security;
using folly::IOBuf;
using folly::io::RWPrivateCursor;
@@ -50,6 +52,7 @@ static const std::string PREAMBLE = "HBas";
static const std::string INTERFACE = "ClientService";
static const uint8_t RPC_VERSION = 0;
static const uint8_t DEFAULT_AUTH_TYPE = 80;
+static const uint8_t KERBEROS_AUTH_TYPE = 81;
int RpcSerde::ParseDelimited(const IOBuf *buf, Message *msg) {
if (buf == nullptr || msg == nullptr) {
@@ -85,7 +88,14 @@ int RpcSerde::ParseDelimited(const IOBuf *buf, Message *msg) {
return coded_stream.CurrentPosition();
}
-RpcSerde::RpcSerde(std::shared_ptr codec) : auth_type_(DEFAULT_AUTH_TYPE), codec_(codec) {}
+// Selects the auth type this serde uses (presumably written into the connection
+// preamble): KERBEROS_AUTH_TYPE when `conf` enables Kerberos, DEFAULT_AUTH_TYPE
+// otherwise, including when `conf` is null.
+RpcSerde::RpcSerde(std::shared_ptr<Codec> codec, std::shared_ptr<Configuration> conf)
+    : auth_type_(conf != nullptr && security::User::IsSecurityEnabled(*conf)
+                     ? KERBEROS_AUTH_TYPE
+                     : DEFAULT_AUTH_TYPE),
+      codec_(codec) {}
unique_ptr RpcSerde::Preamble() {
auto magic = IOBuf::copyBuffer(PREAMBLE, 0, 2);
diff --git a/hbase-native-client/serde/rpc.h b/hbase-native-client/serde/rpc.h
index abebe94..c7e85eb 100644
--- a/hbase-native-client/serde/rpc.h
+++ b/hbase-native-client/serde/rpc.h
@@ -21,6 +21,7 @@
#include
#include
+#include "core/configuration.h"
#include "if/HBase.pb.h"
#include "serde/cell-scanner.h"
#include "serde/codec.h"
@@ -48,7 +49,7 @@ class RpcSerde {
/**
* Constructor assumes the default auth type.
*/
- RpcSerde(std::shared_ptr codec);
+ RpcSerde(std::shared_ptr codec, std::shared_ptr conf);
/**
* Destructor. This is provided just for testing purposes.
diff --git a/hbase-native-client/test-util/mini-cluster.cc b/hbase-native-client/test-util/mini-cluster.cc
index 34da54c..bd3f3b1 100644
--- a/hbase-native-client/test-util/mini-cluster.cc
+++ b/hbase-native-client/test-util/mini-cluster.cc
@@ -258,6 +258,9 @@ jobject MiniCluster::StartCluster(int num_region_servers) {
}
void MiniCluster::StopCluster() {
+ if (cluster_ == nullptr) {
+ return;
+ }
env();
jmethodID mid = env_->GetMethodID(testing_util_class_, "shutdownMiniCluster", "()V");
env_->CallVoidMethod(htu(), mid);
diff --git a/hbase-native-client/test-util/mini-cluster.h b/hbase-native-client/test-util/mini-cluster.h
index 4119cb5..49028a0 100644
--- a/hbase-native-client/test-util/mini-cluster.h
+++ b/hbase-native-client/test-util/mini-cluster.h
@@ -63,7 +63,7 @@ class MiniCluster {
jmethodID move_mid_;
jmethodID str_ctor_mid_;
jobject htu_;
- jobject cluster_;
+ jobject cluster_ = nullptr;
pthread_mutex_t count_mutex_;
JavaVM *jvm_;
JNIEnv *CreateVM(JavaVM **jvm);
diff --git a/hbase-native-client/test-util/test-util.cc b/hbase-native-client/test-util/test-util.cc
index c4e6ed2..b420e40 100644
--- a/hbase-native-client/test-util/test-util.cc
+++ b/hbase-native-client/test-util/test-util.cc
@@ -44,7 +44,14 @@ std::string TestUtil::RandString(int len) {
return s;
}
-TestUtil::TestUtil() : temp_dir_(TestUtil::RandString()) {}
+TestUtil::TestUtil() : temp_dir_(TestUtil::RandString()) {
+ // Point the client at ZooKeeper on localhost:2181 and enable Kerberos auth,
+ // presumably for a cluster started outside the test.
+ // NOTE(review): these values are applied to every TestUtil instance, and the
+ // principal names one specific container host (23a03935850c); consider making
+ // them opt-in/configurable rather than hard-coded.
+ std::string quorum("localhost:");
+ const std::string port = "2181"; // mini->GetConfValue("hbase.zookeeper.property.clientPort");
+ conf()->Set("hbase.zookeeper.quorum", quorum + port);
+ conf()->Set("hbase.security.authentication", "kerberos");
+ conf()->Set("hbase.regionserver.kerberos.principal", "hbase/23a03935850c@EXAMPLE.COM");
+ conf()->Set(ZKUtil::kHBaseZookeeperClientPort_, port);
+}
TestUtil::~TestUtil() {
if (mini_) StopMiniCluster();
diff --git a/hbase-native-client/utils/BUCK b/hbase-native-client/utils/BUCK
index 04e2b67..a1704c4 100644
--- a/hbase-native-client/utils/BUCK
+++ b/hbase-native-client/utils/BUCK
@@ -28,6 +28,8 @@ cxx_library(
srcs=["bytes-util.cc", "connection-util.cc", "user-util.cc"],
deps=['//third-party:folly',],
tests=[":user-util-test"],
+ linker_flags = ['-L/usr/local/lib','-lkrb5'],
+ exported_linker_flags = ['-L/usr/local/lib','-lkrb5'],
visibility=['PUBLIC',],
compiler_flags=['-Weffc++'],)
cxx_test(
diff --git a/hbase-native-client/utils/user-util.cc b/hbase-native-client/utils/user-util.cc
index 9e170e0..7f37ef7 100644
--- a/hbase-native-client/utils/user-util.cc
+++ b/hbase-native-client/utils/user-util.cc
@@ -23,11 +23,13 @@
#include
#include
#include
+#include
using namespace hbase;
using namespace std;
-UserUtil::UserUtil() : once_flag_{}, user_name_{"drwho"} {}
+UserUtil::UserUtil() : once_flag_{} {}
+ // user_name_{"hbase@EXAMPLE.COM"}
string UserUtil::user_name() {
std::call_once(once_flag_, [this]() { compute_user_name(); });
@@ -45,4 +47,29 @@ void UserUtil::compute_user_name() {
if (passwd && passwd->pw_name) {
user_name_ = string{passwd->pw_name};
}
+ krb5_context ctx;
+ krb5_error_code ret = krb5_init_context(&ctx);
+ if (ret != 0) LOG(INFO) << "cannot init krb ctx " << ret;
+ else {
+ krb5_ccache ccache;
+ ret = krb5_cc_default(ctx, &ccache);
+ if (ret != 0) LOG(INFO) << "cannot get default cache " << ret;
+ else {
+ krb5_principal princ;
+ ret = krb5_cc_get_principal(ctx, ccache, &princ);
+ if (ret != 0) LOG(INFO) << "cannot get default cache " << ret;
+ else {
+ user_name_ = princ->data->data;
+ if (krb5_princ_size(ctx, princ) >= 2) {
+ user_name_ += "/";
+ krb5_data *d = princ->data+1;
+ user_name_ += (char*)d->data;
+ }
+ user_name_ += "@";
+ user_name_ += princ->realm.data;
+ }
+ krb5_free_principal(ctx, princ);
+ }
+ }
+ LOG(INFO) << "computed username " << user_name_;
}
diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java
index 73226aa..fbbdd69 100644
--- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java
+++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java
@@ -1507,6 +1507,10 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver {
}
int version = versionAndAuthBytes.get(0);
byte authbyte = versionAndAuthBytes.get(1);
this.authMethod = AuthMethod.valueOf(authbyte);
+      // Logged after the assignment so it shows the auth method parsed from authbyte.
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("preamble auth " + authbyte + " " + isSecurityEnabled + " " + authMethod);
+      }
if (version != CURRENT_VERSION) {
String msg = getFatalConnectionString(version, authbyte);
@@ -1707,6 +1711,10 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver {
ugi = UserGroupInformation.createProxyUser(protocolUser
.getUserName(), realUser);
+            if (LOG.isDebugEnabled()) {
+              LOG.debug("protocol user '" + protocolUser.getUserName() + "' ugi '"
+                  + ugi.getUserName() + "' " + protocolUser + " " + ugi);
+            }
// Now the user is a proxy user, set Authentication method Proxy.
ugi.setAuthenticationMethod(AuthenticationMethod.PROXY);
}
}
diff --git a/hbase-native-client/connection/sasl-handler.h b/hbase-native-client/connection/sasl-handler.h
new file mode 100644
index 0000000..792c024
--- /dev/null
+++ b/hbase-native-client/connection/sasl-handler.h
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+#pragma once
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+#include "connection/service.h"
+#include "security/user.h"
+// NOTE(review): these using-declarations sit at global scope in a header, so
+// every includer inherits them, and this header itself uses neither name.
+// Consider moving them into sasl-handler.cc.
+using hbase::security::User;
+
+using std::chrono::nanoseconds;
+
+namespace hbase {
+
+/**
+ * Class to perform SASL handshake with server.
+ *
+ * Implemented as a wangle HandlerAdapter that drives the Cyrus SASL client
+ * API (sasl_client_XX calls). Client writes issued before the handshake is
+ * complete are buffered in iobuf_; auth_state_ tracks progress (InitState).
+ */
+class SaslHandler
+    : public wangle::HandlerAdapter>{
+ public:
+  SaslHandler(std::string host_name, std::string user_name, std::string service_name);
+  // NOTE(review): sconn_ and pBuf are raw pointers, and the class declares a
+  // destructor but no copy/move operations. If ~SaslHandler releases them,
+  // the implicit copy constructor would lead to a double release; confirm
+  // against sasl-handler.cc.
+  ~SaslHandler();
+
+  // Handshake states stored in auth_state_.
+  enum InitState {
+    kUninit,  // uninitialized
+    kRead,    // waiting to read response from server
+    kSucc,    // handshake is successful
+    kFailure  // handshake fails
+  };
+
+  // from HandlerAdapter
+  void read(Context* ctx, folly::IOBufQueue& buf) override;
+  folly::Future write(Context* ctx, std::unique_ptr buf) override;
+
+  // parse the service name from given Configuration
+  static std::string parseServiceName(std::shared_ptr conf);
+
+ private:
+  // used by Cyrus; starts out null
+  sasl_conn_t *sconn_ = nullptr;
+  const int buffer_size = 8 * 1024;
+  // Explicitly null so the pointer is never indeterminate.
+  char *pBuf = nullptr;
+  int currRC_ = 0;
+  std::string user_name_;
+  std::string host_name_;
+  std::string service_name_;
+  InitState auth_state_ = kUninit;
+  // vector of folly::IOBuf which buffers client writes before handshake is complete
+  std::vector> iobuf_;
+
+  // writes the output returned by sasl_client_XX to server
+  void writeSaslOutput(Context* ctx, const char *out, unsigned int outlen);
+  InitState saslInit(Context* ctx);
+  void SwapByteOrder(unsigned int &ui);
+  InitState finishAuth(Context* ctx, folly::IOBufQueue& bufQueue);
+  InitState setupSaslConn(Context* ctx, folly::IOBufQueue& buf);
+  // Returns the current handshake state.
+  InitState getState() const {
+    return auth_state_;
+  }
+  static std::string toStringBinary(char* temp, unsigned int uiTotalSize);
+};
+}  // namespace hbase