diff --git a/conf/hbase-site.xml b/conf/hbase-site.xml index c516ac7..383da86 100644 --- a/conf/hbase-site.xml +++ b/conf/hbase-site.xml @@ -21,4 +21,32 @@ */ --> + + hbase.rootdir + file:///home/hbase/hbase-data + + + hbase.zookeeper.property.dataDir + /home/hbase/zookeeper-data + + + hbase.security.authentication + kerberos + + + hbase.master.kerberos.principal + hbase/23a03935850c@EXAMPLE.COM + + + hbase.regionserver.kerberos.principal + hbase/23a03935850c@EXAMPLE.COM + + + hbase.master.keytab.file + /usr/src/hbase/hbase-host.keytab + + + hbase.regionserver.keytab.file + /usr/src/hbase/hbase-host.keytab + diff --git a/conf/log4j.properties b/conf/log4j.properties index 74b13b1..51fb1e4 100644 --- a/conf/log4j.properties +++ b/conf/log4j.properties @@ -16,7 +16,7 @@ # Define some default values that can be overridden by system properties hbase.root.logger=INFO,console -hbase.security.logger=INFO,console +hbase.security.logger=DEBUG,console hbase.log.dir=. hbase.log.file=hbase.log @@ -95,6 +95,8 @@ log4j.appender.asyncconsole.target=System.err log4j.logger.org.apache.zookeeper=INFO #log4j.logger.org.apache.hadoop.fs.FSNamesystem=DEBUG log4j.logger.org.apache.hadoop.hbase=INFO +log4j.logger.org.apache.hadoop.hbase.security=DEBUG +log4j.logger.org.apache.hadoop.hbase.ipc=DEBUG log4j.logger.org.apache.hadoop.hbase.META=INFO # Make these two classes INFO-level. Make them DEBUG to see more zk debug. log4j.logger.org.apache.hadoop.hbase.zookeeper.ZKUtil=INFO diff --git a/conf/regionservers b/conf/regionservers index 2fbb50c..e69de29 100644 --- a/conf/regionservers +++ b/conf/regionservers @@ -1 +0,0 @@ -localhost diff --git a/hbase-native-client/connection/BUCK b/hbase-native-client/connection/BUCK index 19536d5..696dab6 100644 --- a/hbase-native-client/connection/BUCK +++ b/hbase-native-client/connection/BUCK @@ -22,6 +22,7 @@ cxx_library( exported_headers=[ "client-dispatcher.h", "client-handler.h", + "sasl-handler.h", "connection-factory.h", "connection-pool.h", "connection-id.h", @@ -50,8 +51,16 @@ cxx_library( "//third-party:wangle", ], compiler_flags=['-Weffc++'], + linker_flags = ['-L/usr/local/lib','-lsasl2', '-lkrb5'], + exported_linker_flags = ['-L/usr/local/lib','-lsasl2', '-lkrb5'], + #linker_flags = ['-L/usr/local/lib','-lsasl2'], + #exported_linker_flags = ['-L/usr/local/lib','-lsasl2'], visibility=['//core/...',],) cxx_test( name="connection-pool-test", srcs=["connection-pool-test.cc",], deps=[":connection",],) +cxx_test( + name="sasl-test", + srcs=["sasl-test.cc",], + deps=[":connection",],) diff --git a/hbase-native-client/connection/client-handler.cc b/hbase-native-client/connection/client-handler.cc index af84572..c138229 100644 --- a/hbase-native-client/connection/client-handler.cc +++ b/hbase-native-client/connection/client-handler.cc @@ -103,10 +103,9 @@ Future ClientHandler::write(Context *ctx, std::unique_ptr r) { // We need to send the header once. // So use call_once to make sure that only one thread wins this. std::call_once((*once_flag_), [ctx, this]() { - auto pre = serde_.Preamble(); auto header = serde_.Header(user_name_); - pre->appendChain(std::move(header)); - ctx->fireWrite(std::move(pre)); + //LOG(INFO) << "writing " << user_name_ << " "<< header->length(); + ctx->fireWrite(std::move(header)); }); // Now store the call id to response. @@ -115,5 +114,6 @@ Future ClientHandler::write(Context *ctx, std::unique_ptr r) { VLOG(1) << "Writing RPC Request with call_id:" << r->call_id(); // Send the data down the pipeline. - return ctx->fireWrite(serde_.Request(r->call_id(), r->method(), r->req_msg().get())); + std::unique_ptr iob = serde_.Request(r->call_id(), r->method(), r->req_msg().get()); + return ctx->fireWrite(std::move(iob)); } diff --git a/hbase-native-client/connection/connection-factory.cc b/hbase-native-client/connection/connection-factory.cc index 832b00f..3381198 100644 --- a/hbase-native-client/connection/connection-factory.cc +++ b/hbase-native-client/connection/connection-factory.cc @@ -18,21 +18,381 @@ */ #include "connection/connection-factory.h" +#include "connection/sasl-handler.h" #include #include "connection/client-dispatcher.h" +#include "connection/client-handler.h" #include "connection/pipeline.h" #include "connection/service.h" +#include "serde/rpc.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include using namespace folly; -using namespace hbase; +using namespace wangle; using std::chrono::milliseconds; using std::chrono::nanoseconds; +namespace hbase { +// the following is needed to avoid compilation error due to invalid casting +extern "C" { + int GetPluginPath(void *context, + const char **path) { + const char *searchpath = (const char *) context; + + if (!path) + return SASL_BADPARAM; + + if (searchpath) { + *path = searchpath; + } + + return SASL_OK; + } + int Simple(void *context, + int id, + const char **result, + unsigned *len) { + const char *value = (const char *)context; + + if (!result) + return SASL_BADPARAM; + + switch (id) { + case SASL_CB_USER: { + *result = value; + if (len) *len = value ? (unsigned) strlen(value) : 0; + } + break; + default: + return SASL_BADPARAM; + } + return SASL_OK; + } + int SaslLogFn(void *context __attribute__((unused)), + int priority, + const char *message) { + const char *label; + + if (! message) + return SASL_BADPARAM; + + switch (priority) { + case SASL_LOG_ERR: + LOG(ERROR) << ": SASL " << message; + break; + case SASL_LOG_NOTE: + LOG(INFO) << ": SASL " << message; + break; + default: + LOG(INFO) << ": SASL " << message; + break; + } + + return SASL_OK; + } +}; + +SaslHandler::SaslHandler(std::string host_name, std::string user_name, std::string service_name) + : user_name_(user_name), host_name_(host_name) + { + int rc; + /* Create new connection session. */ + char *searchpath = ::getenv("CYRUS_SASL_PLUGINS_DIR"); + + if (NULL == searchpath || 0 == strcmp(searchpath, "")) { + searchpath = "/usr/lib/sasl2"; + } + + std::string sasl_plugin_dir_path = searchpath; + LOG(INFO) << "service is " << service_name << " " << searchpath; + + sasl_callback_t callbacks[4]; + sasl_callback_t *callback; + + /* Fill in the callbacks that we're providing... */ + callback = callbacks; + + /* log */ + callback->id = SASL_CB_LOG; + callback->proc = (sasl_callback_ft) &SaslLogFn; + callback->context = NULL; + ++callback; + + callback->id = SASL_CB_GETPATH; + callback->proc = (sasl_callback_ft) &GetPluginPath; + callback->context = (void *)(sasl_plugin_dir_path.c_str()); + ++callback; + + /* user */ + callback->id = SASL_CB_USER; + callback->proc = (sasl_callback_ft) &Simple; + callback->context = const_cast(user_name_.c_str()); + ++callback; + + /* termination */ + callback->id = SASL_CB_LIST_END; + callback->proc = NULL; + callback->context = NULL; + ++callback; + rc = sasl_client_init(callbacks); + if (rc != SASL_OK) { + LOG(FATAL) << "Cannot initialize client ("<< rc << ") "; + throw std::runtime_error("Cannot initialize client"); + } + rc = sasl_client_new(service_name.c_str(), /* The service we are using*/ + host_name_.c_str(), + NULL, NULL, /* Local and remote IP address strings + (NULL disables mechanisms which require this info)*/ + NULL, /*connection-specific callbacks*/ + 0 /*security flags*/, &sconn_); + if (rc != SASL_OK) { + LOG(FATAL) << "Cannot create client ("<< rc << ") "; + throw std::runtime_error("Cannot create client"); + } + sasl_security_properties_t *props = new sasl_security_properties_t(); + ::memset(props, 0, sizeof(sasl_security_properties_t)); + //props->max_ssf = 2; + sasl_setprop(sconn_, SASL_SEC_PROPS, (void *)props); + + std::vector buffer(buffer_size + 1); + std::fill(buffer.begin(), buffer.end(), 0); + pBuf = (char *) &buffer[0]; + + LOG(INFO) << "sasl_client_new called " << sconn_ << " " << host_name_ << " " << user_name_; + } + +SaslHandler::~SaslHandler() { + if (nullptr != sconn_) { + sasl_dispose(&sconn_); + } + sconn_ = nullptr; + } + std::string SaslHandler::parseServiceName(std::shared_ptr conf) { + std::string svrPrincipal = conf->Get("hbase.regionserver.kerberos.principal", "hbase"); + std::size_t pos = svrPrincipal.find("/"); + if (pos == std::string::npos && svrPrincipal.find("@") != std::string::npos) { + pos = svrPrincipal.find("@"); + } + std::string service_name = "hbase"; + if (pos != std::string::npos) { + service_name = svrPrincipal.substr(0, pos); + } + return service_name; + } + void SaslHandler::performHandshake(Context* ctx) { + if (auth_state_ != kSucc) { + int rc = setupSaslConn(ctx, dummyQ); + if (rc != SASL_CONTINUE) { + LOG(FATAL)<< "Cannot start client ("<< rc << ") "; + return; + } + } + } + void SaslHandler::read(Context* ctx, folly::IOBufQueue& buf) { + LOG(INFO) << "in read, " << auth_state_; + if (auth_state_ == kSucc) ctx->fireRead(buf); + else { + setupSaslConn(ctx, buf); + } + } + folly::Future SaslHandler::write(Context* ctx, std::unique_ptr buf) { + LOG(INFO) << "in write, init state: " << auth_state_ << " " << buf->length(); + if (auth_state_ != kSucc) { + if (auth_state_ == kUninit) { + setupSaslConn(ctx, dummyQ); + } + char* buff = (char*)malloc(buf->length()); + ::memcpy(buff, buf->data(), buf->length()); + const char* dat = (const char*)buf->data(); + //LOG(INFO) << "initial " << toStringBinary(buff, buf->length()); + iobuf_.push_back(std::move(buf)); + std::unique_ptr empty = std::make_unique(); + return ctx->fireWrite(std::move(empty)); + } + return ctx->fireWrite(std::move(buf)); + } + void SaslHandler::writeSaslOutput(Context* ctx, const char *out, unsigned int outlen) { + int bufferSize = outlen + 4; + google::protobuf::uint8 *packet = new google::protobuf::uint8[bufferSize]; + ::memset(packet, '\0', bufferSize); + + google::protobuf::io::ArrayOutputStream aos(packet, bufferSize); + google::protobuf::io::CodedOutputStream *coded_output = new google::protobuf::io::CodedOutputStream(&aos); + unsigned int uiTotalSize = outlen; + SwapByteOrder(uiTotalSize); + coded_output->WriteRaw(&uiTotalSize, 4); + coded_output->WriteRaw(out, outlen); + std::unique_ptr iob = IOBuf::takeOwnership(packet, bufferSize); + ctx->fireWrite(std::move(iob)).get(); + //LOG(INFO) << "written iob " << outlen; + } + std::string SaslHandler::toStringBinary(char* temp, unsigned int uiTotalSize) { + char const hex_chars[16] = { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', + 'A', 'B', 'C', 'D', 'E', 'F' }; + std::string strResponse; + for (unsigned int i = 0; i < uiTotalSize; ++i ) + { + char const byte = temp[i]; + if (byte >= ' ' && byte <= '~' && byte != '\\') { + strResponse += byte; + } else { + strResponse += "\\x"; + strResponse += hex_chars[ ( byte & 0xF0 ) >> 4 ]; + strResponse += hex_chars[ ( byte & 0x0F ) >> 0 ]; + } + } + return strResponse; + } + SaslHandler::InitState SaslHandler::finishAuth(Context* ctx, folly::IOBufQueue& bufQueue) { + if (currRC_ == SASL_OK) { + std::unique_ptr iob; + if (!bufQueue.empty()) { + iob = bufQueue.pop_front(); + LOG(ERROR) << "Error in the final step of handshake " << iob->length(); + auth_state_ = kFailure; + return auth_state_; + } else { + auth_state_ = kSucc; + LOG(INFO) << "auth succeeded " << iobuf_.size(); + // write what we buffered + for (std::vector>::iterator it = iobuf_.begin(); + it < iobuf_.end(); it++) { + iob = std::move(*it); + LOG(INFO) << "writing client msg " << iob->length() << " " << + toStringBinary((char *)iob->data(), iob->length()); + ctx->fireWrite(std::move(iob)); + } + } + + if (auth_state_ == kSucc) { + const int *ssfp = nullptr; + + currRC_ = sasl_getprop(sconn_, SASL_SSF, (const void**) &ssfp); + if (currRC_ != SASL_OK) { + const char *err_msg = nullptr; + sasl_errstring(currRC_, NULL, &err_msg); + + LOG(ERROR) << err_msg; + } + if (ssfp != nullptr && *ssfp > 0) { + bool qop_enabled_ = true; + } + } + } + return auth_state_; + } + SaslHandler::InitState SaslHandler::setupSaslConn(Context* ctx, folly::IOBufQueue& bufQueue) { + const char *mechusing, *mechlist = "GSSAPI"; + const char *out; + unsigned int outlen; + + if (auth_state_ == kUninit) { + do { + currRC_ = sasl_client_start(sconn_, /* the same context from above */ + mechlist, /* the list of mechanisms + from the server */ + NULL, /* filled in if an + interaction is needed */ + &out, /* filled in on success */ + &outlen, /* filled in on success */ + &mechusing); + } while (currRC_ == SASL_INTERACT); /* the mechanism may ask us to fill + in things many times. result is SASL_CONTINUE on success */ + if (currRC_ != SASL_CONTINUE) { + LOG(FATAL)<< "Cannot start client ("<< currRC_ << ") "; + auth_state_ = kFailure; + return auth_state_; + } + LOG(INFO) << "client started " << currRC_ << " " << outlen; + writeSaslOutput(ctx, out, outlen); + auth_state_ = kRead; + return auth_state_; + } + int bytes_sent = 0; + int bytes_received = 0; + + if (currRC_ == SASL_OK) { + return finishAuth(ctx, bufQueue); + } else { + std::unique_ptr iob = bufQueue.pop_front(); + bytes_received = iob->length(); + pBuf = (char*)malloc(iob->length()+1); + ::memcpy(pBuf, iob->data(), iob->length()); + /* LOG(INFO) << "reading queue code: " << currRC_ << " bytes: " << bytes_received + << " queue empty: " << bufQueue.empty(); */ + if (bytes_received == 0) { + auth_state_ = kFailure; + return auth_state_; + } + } + unsigned int *uiReader = (unsigned int *)pBuf; + unsigned int status = *uiReader; + + char *p = (char *)pBuf + 4; + + uiReader = (unsigned int *)p; + unsigned int uiTotalSize = *uiReader; + SwapByteOrder(uiTotalSize); + //LOG(INFO) << "total size of server response " << uiTotalSize << " status " << status; + // p now points to the buffer to be sent to server + p = p + 4; + + size_t bs = uiTotalSize; + + bs = uiTotalSize; + //LOG(INFO) << "size of bytes " << bs; + + if (auth_state_ == kFailure || status != 0 /*Status 0 is success*/) { + //Assumption here is not more than 8 * 1024 + p[bs] = '\0'; + LOG(ERROR) << "Exception from server: " << p; + auth_state_ = kFailure; + return auth_state_; + } + + out = NULL; + outlen = 0; + + currRC_ = sasl_client_step(sconn_, /* our context */ + (char*) p, /* the data from the server */ + bs, /* it's length */ + NULL, /* this should be unallocated and NULL */ + &out, /* filled in on success */ + &outlen); /* filled in on success */ + delete pBuf; + pBuf = NULL; + if (currRC_ == SASL_OK || currRC_ == SASL_CONTINUE) { + writeSaslOutput(ctx, out, outlen); + } + LOG(INFO) << "got rc " << currRC_; + if (currRC_ == SASL_OK) { + return finishAuth(ctx, bufQueue); + } + return auth_state_; + } + + void SaslHandler::SwapByteOrder(unsigned int &ui) + { + ui = (ui >> 24) | + ((ui<<8) & 0x00FF0000) | + ((ui>>8) & 0x0000FF00) | + (ui << 24); + } + ConnectionFactory::ConnectionFactory(std::shared_ptr io_pool, std::shared_ptr codec, nanoseconds connect_timeout) - : connect_timeout_(connect_timeout), + : connect_timeout_(connect_timeout), user_util_(), codec_(codec), io_pool_(io_pool), pipeline_factory_(std::make_shared(codec)) {} @@ -48,15 +408,42 @@ std::shared_ptr> ConnectionFactory::M } std::shared_ptr ConnectionFactory::Connect( std::shared_ptr> client, const std::string &hostname, - int port) { + int port, std::shared_ptr user, std::shared_ptr conf) { // Yes this will block however it makes dealing with connection pool soooooo // much nicer. // TODO see about using shared promise for this. + LOG(INFO) << "connecting client " << hostname << " " << port << " " << conf; auto pipeline = client - ->connect(SocketAddress(hostname, port, true), - std::chrono::duration_cast(connect_timeout_)) - .get(); + ->connect(SocketAddress(hostname, port, true), + std::chrono::duration_cast(connect_timeout_)) + .get(); + auto serde_ = std::make_unique(codec_); + auto pre = serde_->Preamble(); + // write the Preamble + pipeline->getContext()->getTransport()->getEventBase() + ->runInEventBaseThreadAndWait([&]() mutable { + pipeline->getContext()->write(std::move(pre)); + }); + SaslHandler* sasl_handler = new SaslHandler(hostname, user_util_.user_name(), + SaslHandler::parseServiceName(conf)); + if (user->IsSecurityEnabled(*conf)) { + LOG(INFO) << "creating SaslHandler for " << user_util_.user_name(); + pipeline->getContext()->getTransport()->getEventBase() + ->runInEventBaseThreadAndWait([&]() mutable { + pipeline->remove(pipeline->getHandler()); + pipeline->remove(pipeline->getHandler()); + pipeline->addBack(sasl_handler); + pipeline->addBack(LengthFieldBasedFrameDecoder{}); + pipeline->addBack(ClientHandler{user_util_.user_name(), codec_}); + pipeline->finalize(); + }); + } auto dispatcher = std::make_shared(); - dispatcher->setPipeline(pipeline); + pipeline->getContext()->getTransport()->getEventBase() + ->runInEventBaseThreadAndWait([&]() mutable { + dispatcher->setPipeline(pipeline); + }); + LOG(INFO) << "finalized"; return dispatcher; } +} diff --git a/hbase-native-client/connection/connection-factory.h b/hbase-native-client/connection/connection-factory.h index 32d0bf7..6d524b2 100644 --- a/hbase-native-client/connection/connection-factory.h +++ b/hbase-native-client/connection/connection-factory.h @@ -28,6 +28,8 @@ #include "connection/request.h" #include "connection/response.h" #include "connection/service.h" +#include "security/user.h" +using hbase::security::User; using std::chrono::nanoseconds; @@ -61,10 +63,13 @@ class ConnectionFactory { */ virtual std::shared_ptr Connect( std::shared_ptr> client, - const std::string &hostname, int port); + const std::string &hostname, int port, std::shared_ptr user, + std::shared_ptr conf); private: nanoseconds connect_timeout_; + UserUtil user_util_; + std::shared_ptr codec_; std::shared_ptr io_pool_; std::shared_ptr pipeline_factory_; }; diff --git a/hbase-native-client/connection/connection-pool-test.cc b/hbase-native-client/connection/connection-pool-test.cc index 623ce3c..7c5b57f 100644 --- a/hbase-native-client/connection/connection-pool-test.cc +++ b/hbase-native-client/connection/connection-pool-test.cc @@ -61,11 +61,11 @@ class MockService : public MockServiceBase { }; TEST(TestConnectionPool, TestOnlyCreateOnce) { - auto hostname = std::string{"hostname"}; + auto hostname = std::string{"localhost"}; auto mock_boot = std::make_shared(); auto mock_service = std::make_shared(); auto mock_cf = std::make_shared(); - uint32_t port{999}; + uint32_t port{2181}; EXPECT_CALL((*mock_cf), Connect(_, _, _)).Times(1).WillRepeatedly(Return(mock_service)); EXPECT_CALL((*mock_cf), MakeBootstrap()).Times(1).WillRepeatedly(Return(mock_boot)); diff --git a/hbase-native-client/connection/connection-pool.cc b/hbase-native-client/connection/connection-pool.cc index 4fe4610..64ef90f 100644 --- a/hbase-native-client/connection/connection-pool.cc +++ b/hbase-native-client/connection/connection-pool.cc @@ -35,11 +35,12 @@ using folly::SharedMutexWritePriority; using folly::SocketAddress; ConnectionPool::ConnectionPool(std::shared_ptr io_executor, - std::shared_ptr codec, nanoseconds connect_timeout) + std::shared_ptr codec, std::shared_ptr conf, + nanoseconds connect_timeout) : cf_(std::make_shared(io_executor, codec, connect_timeout)), clients_(), connections_(), - map_mutex_() {} + map_mutex_(), conf_(conf) {} ConnectionPool::ConnectionPool(std::shared_ptr cf) : cf_(cf), clients_(), connections_(), map_mutex_() {} @@ -88,7 +89,8 @@ std::shared_ptr ConnectionPool::GetNewConnection( /* create new connection */ auto clientBootstrap = cf_->MakeBootstrap(); - auto dispatcher = cf_->Connect(clientBootstrap, remote_id->host(), remote_id->port()); + auto dispatcher = cf_->Connect(clientBootstrap, remote_id->host(), remote_id->port(), + remote_id->user(), conf_); auto connection = std::make_shared(remote_id, dispatcher); diff --git a/hbase-native-client/connection/connection-pool.h b/hbase-native-client/connection/connection-pool.h index 2a8f195..0582d9b 100644 --- a/hbase-native-client/connection/connection-pool.h +++ b/hbase-native-client/connection/connection-pool.h @@ -50,7 +50,8 @@ class ConnectionPool { public: /** Create connection pool wit default connection factory */ ConnectionPool(std::shared_ptr io_executor, - std::shared_ptr codec, nanoseconds connect_timeout = nanoseconds(0)); + std::shared_ptr codec, std::shared_ptr conf, + nanoseconds connect_timeout = nanoseconds(0)); /** * Constructor that allows specifiying the connetion factory. @@ -93,6 +94,7 @@ class ConnectionPool { clients_; folly::SharedMutexWritePriority map_mutex_; std::shared_ptr cf_; + std::shared_ptr conf_; }; } // namespace hbase diff --git a/hbase-native-client/connection/rpc-client.cc b/hbase-native-client/connection/rpc-client.cc index 5fa1138..57df66d 100644 --- a/hbase-native-client/connection/rpc-client.cc +++ b/hbase-native-client/connection/rpc-client.cc @@ -30,9 +30,10 @@ using hbase::RpcClient; namespace hbase { RpcClient::RpcClient(std::shared_ptr io_executor, - std::shared_ptr codec, nanoseconds connect_timeout) - : io_executor_(io_executor) { - cp_ = std::make_shared(io_executor_, codec, connect_timeout); + std::shared_ptr codec, std::shared_ptr conf, + nanoseconds connect_timeout) + : io_executor_(io_executor), conf_(conf) { + cp_ = std::make_shared(io_executor_, codec, conf, connect_timeout); } void RpcClient::Close() { io_executor_->stop(); } diff --git a/hbase-native-client/connection/rpc-client.h b/hbase-native-client/connection/rpc-client.h index d416ceb..4f02529 100644 --- a/hbase-native-client/connection/rpc-client.h +++ b/hbase-native-client/connection/rpc-client.h @@ -46,7 +46,7 @@ namespace hbase { class RpcClient { public: RpcClient(std::shared_ptr io_executor, std::shared_ptr codec, - nanoseconds connect_timeout = nanoseconds(0)); + std::shared_ptr conf, nanoseconds connect_timeout = nanoseconds(0)); virtual ~RpcClient() { Close(); } @@ -78,5 +78,6 @@ class RpcClient { private: std::shared_ptr cp_; std::shared_ptr io_executor_; + std::shared_ptr conf_; }; } // namespace hbase diff --git a/hbase-native-client/core/async-connection.cc b/hbase-native-client/core/async-connection.cc index b945e38..41eb48a 100644 --- a/hbase-native-client/core/async-connection.cc +++ b/hbase-native-client/core/async-connection.cc @@ -38,7 +38,7 @@ void AsyncConnectionImpl::Init() { LOG(WARNING) << "Not using RPC Cell Codec"; } rpc_client_ = - std::make_shared(io_executor_, codec, connection_conf_->connect_timeout()); + std::make_shared(io_executor_, codec, conf_, connection_conf_->connect_timeout()); location_cache_ = std::make_shared(conf_, cpu_executor_, rpc_client_->connection_pool()); caller_factory_ = std::make_shared(shared_from_this()); diff --git a/hbase-native-client/core/client-test.cc b/hbase-native-client/core/client-test.cc index ff4879a..142d474 100644 --- a/hbase-native-client/core/client-test.cc +++ b/hbase-native-client/core/client-test.cc @@ -27,6 +27,7 @@ #include "core/put.h" #include "core/result.h" #include "core/table.h" +#include "security/user.h" #include "serde/table-name.h" #include "test-util/test-util.h" @@ -85,7 +86,8 @@ class ClientTest : public ::testing::Test { static void SetUpTestCase() { google::InstallFailureSignalHandler(); test_util = std::make_unique(); - test_util->StartMiniCluster(2); + google::InstallFailureSignalHandler(); + //test_util->StartMiniCluster(2); } }; std::unique_ptr ClientTest::test_util = nullptr; @@ -116,8 +118,11 @@ TEST_F(ClientTest, DefaultConfiguration) { } TEST_F(ClientTest, PutGet) { - // Using TestUtil to populate test data - ClientTest::test_util->CreateTable("t", "d"); + // Using TestUtil to populate test data for insecure cluster + // table would be created manually for secure cluster + if (!User::IsSecurityEnabled(*ClientTest::test_util->conf())) { + ClientTest::test_util->CreateTable("t", "d"); + } // Create TableName and Row to be fetched from HBase auto tn = folly::to("t"); diff --git a/hbase-native-client/security/BUCK b/hbase-native-client/security/BUCK index 7383028..91b547f 100644 --- a/hbase-native-client/security/BUCK +++ b/hbase-native-client/security/BUCK @@ -23,4 +23,4 @@ cxx_library( srcs=[], deps=["//core:conf"], compiler_flags=['-Weffc++'], - visibility=['//core/...', '//connection/...'],) + visibility=['PUBLIC',],) diff --git a/hbase-native-client/security/user.h b/hbase-native-client/security/user.h index 035af31..af32915 100644 --- a/hbase-native-client/security/user.h +++ b/hbase-native-client/security/user.h @@ -21,6 +21,7 @@ #include #include #include "core/configuration.h" +#include namespace hbase { namespace security { @@ -35,6 +36,7 @@ class User { static std::shared_ptr defaultUser() { return std::make_shared("__drwho"); } static bool IsSecurityEnabled(const Configuration& conf) { + LOG(INFO) << "auth from conf: " << conf.Get("hbase.security.authentication", ""); return conf.Get("hbase.security.authentication", "").compare(kKerberos) == 0; } diff --git a/hbase-native-client/serde/rpc.cc b/hbase-native-client/serde/rpc.cc index e657a64..1fc18a2 100644 --- a/hbase-native-client/serde/rpc.cc +++ b/hbase-native-client/serde/rpc.cc @@ -31,8 +31,10 @@ #include "if/RPC.pb.h" #include "utils/version.h" +//#include "security/user.h" using namespace hbase; +//using namespace hbase::security; using folly::IOBuf; using folly::io::RWPrivateCursor; @@ -50,6 +52,7 @@ static const std::string PREAMBLE = "HBas"; static const std::string INTERFACE = "ClientService"; static const uint8_t RPC_VERSION = 0; static const uint8_t DEFAULT_AUTH_TYPE = 80; +static const uint8_t KERBEROS_AUTH_TYPE = 81; int RpcSerde::ParseDelimited(const IOBuf *buf, Message *msg) { if (buf == nullptr || msg == nullptr) { @@ -85,7 +88,14 @@ int RpcSerde::ParseDelimited(const IOBuf *buf, Message *msg) { return coded_stream.CurrentPosition(); } -RpcSerde::RpcSerde(std::shared_ptr codec) : auth_type_(DEFAULT_AUTH_TYPE), codec_(codec) {} +RpcSerde::RpcSerde(std::shared_ptr codec) : codec_(codec) { + auth_type_ = KERBEROS_AUTH_TYPE; +// if (User::isSecurityEnabled()) { +// auth_type_ = KERBEROS_AUTH_TYPE; +// } else { +// auth_type_ = DEFAULT_AUTH_TYPE; +// } +} unique_ptr RpcSerde::Preamble() { auto magic = IOBuf::copyBuffer(PREAMBLE, 0, 2); diff --git a/hbase-native-client/test-util/mini-cluster.cc b/hbase-native-client/test-util/mini-cluster.cc index 34da54c..bd3f3b1 100644 --- a/hbase-native-client/test-util/mini-cluster.cc +++ b/hbase-native-client/test-util/mini-cluster.cc @@ -258,6 +258,9 @@ jobject MiniCluster::StartCluster(int num_region_servers) { } void MiniCluster::StopCluster() { + if (cluster_ == nullptr) { + return; + } env(); jmethodID mid = env_->GetMethodID(testing_util_class_, "shutdownMiniCluster", "()V"); env_->CallVoidMethod(htu(), mid); diff --git a/hbase-native-client/test-util/mini-cluster.h b/hbase-native-client/test-util/mini-cluster.h index 4119cb5..49028a0 100644 --- a/hbase-native-client/test-util/mini-cluster.h +++ b/hbase-native-client/test-util/mini-cluster.h @@ -63,7 +63,7 @@ class MiniCluster { jmethodID move_mid_; jmethodID str_ctor_mid_; jobject htu_; - jobject cluster_; + jobject cluster_ = nullptr; pthread_mutex_t count_mutex_; JavaVM *jvm_; JNIEnv *CreateVM(JavaVM **jvm); diff --git a/hbase-native-client/test-util/test-util.cc b/hbase-native-client/test-util/test-util.cc index c4e6ed2..b420e40 100644 --- a/hbase-native-client/test-util/test-util.cc +++ b/hbase-native-client/test-util/test-util.cc @@ -44,7 +44,14 @@ std::string TestUtil::RandString(int len) { return s; } -TestUtil::TestUtil() : temp_dir_(TestUtil::RandString()) {} +TestUtil::TestUtil() : temp_dir_(TestUtil::RandString()) { + std::string quorum("localhost:"); + const std::string port = "2181"; // mini->GetConfValue("hbase.zookeeper.property.clientPort"); + conf()->Set("hbase.zookeeper.quorum", quorum + port); + conf()->Set("hbase.security.authentication", "kerberos"); + conf()->Set("hbase.regionserver.kerberos.principal", "hbase/23a03935850c@EXAMPLE.COM"); + conf()->Set(ZKUtil::kHBaseZookeeperClientPort_, port); +} TestUtil::~TestUtil() { if (mini_) StopMiniCluster(); diff --git a/hbase-native-client/utils/BUCK b/hbase-native-client/utils/BUCK index 04e2b67..a1704c4 100644 --- a/hbase-native-client/utils/BUCK +++ b/hbase-native-client/utils/BUCK @@ -28,6 +28,8 @@ cxx_library( srcs=["bytes-util.cc", "connection-util.cc", "user-util.cc"], deps=['//third-party:folly',], tests=[":user-util-test"], + linker_flags = ['-L/usr/local/lib','-lkrb5'], + exported_linker_flags = ['-L/usr/local/lib','-lkrb5'], visibility=['PUBLIC',], compiler_flags=['-Weffc++'],) cxx_test( diff --git a/hbase-native-client/utils/user-util.cc b/hbase-native-client/utils/user-util.cc index 9e170e0..7f37ef7 100644 --- a/hbase-native-client/utils/user-util.cc +++ b/hbase-native-client/utils/user-util.cc @@ -23,11 +23,13 @@ #include #include #include +#include using namespace hbase; using namespace std; -UserUtil::UserUtil() : once_flag_{}, user_name_{"drwho"} {} +UserUtil::UserUtil() : once_flag_{} {} + // user_name_{"hbase@EXAMPLE.COM"} string UserUtil::user_name() { std::call_once(once_flag_, [this]() { compute_user_name(); }); @@ -45,4 +47,29 @@ void UserUtil::compute_user_name() { if (passwd && passwd->pw_name) { user_name_ = string{passwd->pw_name}; } + krb5_context ctx; + krb5_error_code ret = krb5_init_context(&ctx); + if (ret != 0) LOG(INFO) << "cannot init krb ctx " << ret; + else { + krb5_ccache ccache; + ret = krb5_cc_default(ctx, &ccache); + if (ret != 0) LOG(INFO) << "cannot get default cache " << ret; + else { + krb5_principal princ; + ret = krb5_cc_get_principal(ctx, ccache, &princ); + if (ret != 0) LOG(INFO) << "cannot get default cache " << ret; + else { + user_name_ = princ->data->data; + if (krb5_princ_size(ctx, princ) >= 2) { + user_name_ += "/"; + krb5_data *d = princ->data+1; + user_name_ += (char*)d->data; + } + user_name_ += "@"; + user_name_ += princ->realm.data; + } + krb5_free_principal(ctx, princ); + } + } + LOG(INFO) << "computed username " << user_name_; } diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java index 73226aa..fbbdd69 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java @@ -1426,6 +1426,7 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { if (LOG.isDebugEnabled()) { LOG.debug("Will send token of size " + replyToken.length + " from saslServer."); + LOG.debug(Bytes.toStringBinary(replyToken)); } doRawSaslReply(SaslStatus.SUCCESS, new BytesWritable(replyToken), null, null); @@ -1507,6 +1508,7 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { } int version = versionAndAuthBytes.get(0); byte authbyte = versionAndAuthBytes.get(1); + LOG.debug("preamble auth " + authbyte + " " + isSecurityEnabled + " " + authMethod); this.authMethod = AuthMethod.valueOf(authbyte); if (version != CURRENT_VERSION) { String msg = getFatalConnectionString(version, authbyte); @@ -1707,6 +1709,7 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { ugi = UserGroupInformation.createProxyUser(protocolUser .getUserName(), realUser); // Now the user is a proxy user, set Authentication method Proxy. + LOG.debug("protocol user '" + protocolUser.getUserName() + "' ugi '" + ugi.getUserName() + "'" + protocolUser + " " + ugi); ugi.setAuthenticationMethod(AuthenticationMethod.PROXY); } } diff --git a/hbase-native-client/connection/sasl-handler.h b/hbase-native-client/connection/sasl-handler.h new file mode 100644 index 0000000..a23a5e5 --- /dev/null +++ b/hbase-native-client/connection/sasl-handler.h @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include "connection/service.h" +#include "security/user.h" +using hbase::security::User; + +using std::chrono::nanoseconds; + +namespace hbase { + +/** + * Class to perform SASL handshake with server + */ +class SaslHandler + : public wangle::HandlerAdapter>{ + public: + SaslHandler(std::string host_name, std::string user_name, std::string service_name); + ~SaslHandler(); + + enum InitState { + kUninit, // uninitialized + kRead, // waiting to read response from server + kSucc, // handshake is successful + kFailure // handshake fails + }; + + // from HandlerAdapter + void read(Context* ctx, folly::IOBufQueue& buf) override; + folly::Future write(Context* ctx, std::unique_ptr buf) override; + + // parse the service name from given Configuration + static std::string parseServiceName(std::shared_ptr conf); + + private: + // empty IOBufQueue which is used to answer client write() calls + folly::IOBufQueue dummyQ; + // used by Cyrus + sasl_conn_t *sconn_ = NULL; + const int buffer_size = 8 * 1024; + char *pBuf; + int currRC_ = 0; + std::string user_name_; + std::string host_name_; + InitState auth_state_ = kUninit; + // vector of folly::IOBuf which buffers client writes before handshake is complete + std::vector> iobuf_; + + // writes the output returned by sasl_client_XX to server + void writeSaslOutput(Context* ctx, const char *out, unsigned int outlen); + void performHandshake(Context* ctx); + void SwapByteOrder(unsigned int &ui); + SaslHandler::InitState finishAuth(Context* ctx, folly::IOBufQueue& bufQueue); + InitState setupSaslConn(Context* ctx, folly::IOBufQueue& buf); + InitState getState() { + return auth_state_; + } + static std::string toStringBinary(char* temp, unsigned int uiTotalSize); +}; +} // namespace hbase