diff --git a/hbase-native-client/Dockerfile b/hbase-native-client/Dockerfile index 8b56590..d750808 100644 --- a/hbase-native-client/Dockerfile +++ b/hbase-native-client/Dockerfile @@ -24,6 +24,14 @@ ARG CXXFLAGS="-D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -g -fno-omit-frame-pointer -O2 -p ENV JAVA_HOME="/usr/lib/jvm/java-8-openjdk-amd64/" +RUN wget ftp://ftp.cyrusimap.org/cyrus-sasl/cyrus-sasl-2.1.26.tar.gz && \ + tar zxf cyrus-sasl-2.1.26.tar.gz && \ + cd cyrus-sasl-2.1.26 && \ + ./configure && \ + make && \ + make install && \ + mkdir -p /usr/lib/sasl2 && cp /usr/local/lib/sasl2/* /usr/lib/sasl2/ + RUN apt-get install -y vim maven inetutils-ping python-pip doxygen graphviz clang-format && \ pip install yapf && \ apt-get -qq clean && \ diff --git a/hbase-native-client/connection/BUCK b/hbase-native-client/connection/BUCK index 19536d5..29befa0 100644 --- a/hbase-native-client/connection/BUCK +++ b/hbase-native-client/connection/BUCK @@ -22,6 +22,7 @@ cxx_library( exported_headers=[ "client-dispatcher.h", "client-handler.h", + "sasl-handler.h", "connection-factory.h", "connection-pool.h", "connection-id.h", @@ -40,6 +41,7 @@ cxx_library( "pipeline.cc", "request.cc", "rpc-client.cc", + "sasl-handler.cc", ], deps=[ "//if:if", @@ -50,6 +52,8 @@ cxx_library( "//third-party:wangle", ], compiler_flags=['-Weffc++'], + linker_flags = ['-L/usr/local/lib','-lsasl2', '-lkrb5'], + exported_linker_flags = ['-L/usr/local/lib','-lsasl2', '-lkrb5'], visibility=['//core/...',],) cxx_test( name="connection-pool-test", diff --git a/hbase-native-client/connection/client-handler.cc b/hbase-native-client/connection/client-handler.cc index af84572..3a9785a 100644 --- a/hbase-native-client/connection/client-handler.cc +++ b/hbase-native-client/connection/client-handler.cc @@ -36,9 +36,10 @@ using hbase::pb::ResponseHeader; using hbase::pb::GetResponse; using google::protobuf::Message; -ClientHandler::ClientHandler(std::string user_name, std::shared_ptr codec) +ClientHandler::ClientHandler(std::string user_name, std::shared_ptr
codec, + std::shared_ptr conf) : user_name_(user_name), - serde_(codec), + serde_(codec, conf), once_flag_(std::make_unique()), resp_msgs_( make_unique>>( @@ -103,10 +104,9 @@ Future ClientHandler::write(Context *ctx, std::unique_ptr r) { // We need to send the header once. // So use call_once to make sure that only one thread wins this. std::call_once((*once_flag_), [ctx, this]() { - auto pre = serde_.Preamble(); auto header = serde_.Header(user_name_); - pre->appendChain(std::move(header)); - ctx->fireWrite(std::move(pre)); + VLOG(3) << "Writing connection header for user " << user_name_; + ctx->fireWrite(std::move(header)); }); // Now store the call id to response. @@ -115,5 +115,6 @@ Future ClientHandler::write(Context *ctx, std::unique_ptr r) { VLOG(1) << "Writing RPC Request with call_id:" << r->call_id(); // Send the data down the pipeline. - return ctx->fireWrite(serde_.Request(r->call_id(), r->method(), r->req_msg().get())); + auto iob = serde_.Request(r->call_id(), r->method(), r->req_msg().get()); + return ctx->fireWrite(std::move(iob)); } diff --git a/hbase-native-client/connection/client-handler.h b/hbase-native-client/connection/client-handler.h index afb8e62..f4216f3 100644 --- a/hbase-native-client/connection/client-handler.h +++ b/hbase-native-client/connection/client-handler.h @@ -59,7 +59,10 @@ class ClientHandler * Create the handler * @param user_name the user name of the user running this process. + * @param codec codec handed to the RPC serde. + * @param conf configuration handed to the RPC serde. */ - explicit ClientHandler(std::string user_name, std::shared_ptr codec); + explicit ClientHandler(std::string user_name, std::shared_ptr codec, + std::shared_ptr conf); /** * Get bytes from the wire.
diff --git a/hbase-native-client/connection/connection-factory.cc b/hbase-native-client/connection/connection-factory.cc index 832b00f..47b6a24 100644 --- a/hbase-native-client/connection/connection-factory.cc +++ b/hbase-native-client/connection/connection-factory.cc @@ -18,23 +18,39 @@ */ #include "connection/connection-factory.h" +#include "connection/sasl-handler.h" #include #include "connection/client-dispatcher.h" +#include "connection/client-handler.h" #include "connection/pipeline.h" #include "connection/service.h" +#include "serde/rpc.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include using namespace folly; -using namespace hbase; +using namespace wangle; using std::chrono::milliseconds; using std::chrono::nanoseconds; +namespace hbase { + ConnectionFactory::ConnectionFactory(std::shared_ptr io_pool, - std::shared_ptr codec, nanoseconds connect_timeout) - : connect_timeout_(connect_timeout), - io_pool_(io_pool), - pipeline_factory_(std::make_shared(codec)) {} + std::shared_ptr codec, std::shared_ptr conf, + nanoseconds connect_timeout) + : connect_timeout_(connect_timeout), user_util_(), codec_(codec), + conf_(conf), io_pool_(io_pool), + pipeline_factory_(std::make_shared(codec, conf)) {} std::shared_ptr> ConnectionFactory::MakeBootstrap() { auto client = std::make_shared>(); @@ -48,15 +64,24 @@ std::shared_ptr> ConnectionFactory::M } std::shared_ptr ConnectionFactory::Connect( std::shared_ptr> client, const std::string &hostname, - int port) { + int port, std::shared_ptr user, std::shared_ptr conf) { // Yes this will block however it makes dealing with connection pool soooooo // much nicer. // TODO see about using shared promise for this.
auto pipeline = client - ->connect(SocketAddress(hostname, port, true), - std::chrono::duration_cast(connect_timeout_)) - .get(); + ->connect(SocketAddress(hostname, port, true), + std::chrono::duration_cast(connect_timeout_)) + .get(); + RpcSerde serde{codec_, conf}; + auto pre = serde.Preamble(); + // Queue the preamble on the event base (and wait for that), so it goes out before the connection header ClientHandler writes with the first request. + pipeline->getContext()->getTransport()->getEventBase() + ->runInEventBaseThreadAndWait([&]() mutable { + pipeline->getContext()->write(std::move(pre)); + }); auto dispatcher = std::make_shared(); dispatcher->setPipeline(pipeline); + VLOG(1) << "Connected to " << hostname << ":" << port; return dispatcher; } +} // namespace hbase diff --git a/hbase-native-client/connection/connection-factory.h b/hbase-native-client/connection/connection-factory.h index 32d0bf7..6f7efa2 100644 --- a/hbase-native-client/connection/connection-factory.h +++ b/hbase-native-client/connection/connection-factory.h @@ -28,6 +28,7 @@ #include "connection/request.h" #include "connection/response.h" #include "connection/service.h" +#include "security/user.h" using std::chrono::nanoseconds; @@ -44,7 +45,8 @@ class ConnectionFactory { * There should only be one ConnectionFactory per client.
*/ ConnectionFactory(std::shared_ptr io_pool, - std::shared_ptr codec, nanoseconds connect_timeout = nanoseconds(0)); + std::shared_ptr codec, std::shared_ptr conf, + nanoseconds connect_timeout = nanoseconds(0)); /** Default Destructor */ virtual ~ConnectionFactory() = default; @@ -61,10 +63,14 @@ class ConnectionFactory { */ virtual std::shared_ptr Connect( std::shared_ptr> client, - const std::string &hostname, int port); + const std::string &hostname, int port, std::shared_ptr user, + std::shared_ptr conf); private: nanoseconds connect_timeout_; + UserUtil user_util_; + std::shared_ptr codec_; + std::shared_ptr conf_; std::shared_ptr io_pool_; std::shared_ptr pipeline_factory_; }; diff --git a/hbase-native-client/connection/connection-pool-test.cc b/hbase-native-client/connection/connection-pool-test.cc index 623ce3c..f45df5f 100644 --- a/hbase-native-client/connection/connection-pool-test.cc +++ b/hbase-native-client/connection/connection-pool-test.cc @@ -36,11 +36,12 @@ using hbase::ConnectionId; class MockConnectionFactory : public ConnectionFactory { public: - MockConnectionFactory() : ConnectionFactory(nullptr, nullptr) {} + MockConnectionFactory() : ConnectionFactory(nullptr, nullptr, nullptr) {} MOCK_METHOD0(MakeBootstrap, std::shared_ptr>()); - MOCK_METHOD3(Connect, std::shared_ptr( + MOCK_METHOD5(Connect, std::shared_ptr( std::shared_ptr>, - const std::string &hostname, int port)); + const std::string &hostname, int port, std::shared_ptr user, + std::shared_ptr conf)); }; class MockBootstrap : public wangle::ClientBootstrap {}; @@ -67,7 +68,7 @@ TEST(TestConnectionPool, TestOnlyCreateOnce) { auto mock_cf = std::make_shared(); uint32_t port{999}; - EXPECT_CALL((*mock_cf), Connect(_, _, _)).Times(1).WillRepeatedly(Return(mock_service)); + EXPECT_CALL((*mock_cf), Connect(_, _, _, _, _)).Times(1).WillRepeatedly(Return(mock_service)); EXPECT_CALL((*mock_cf), MakeBootstrap()).Times(1).WillRepeatedly(Return(mock_boot)); ConnectionPool cp{mock_cf}; @@ -86,7
+87,7 @@ TEST(TestConnectionPool, TestOnlyCreateMultipleDispose) { auto mock_service = std::make_shared(); auto mock_cf = std::make_shared(); - EXPECT_CALL((*mock_cf), Connect(_, _, _)).Times(2).WillRepeatedly(Return(mock_service)); + EXPECT_CALL((*mock_cf), Connect(_, _, _, _, _)).Times(2).WillRepeatedly(Return(mock_service)); EXPECT_CALL((*mock_cf), MakeBootstrap()).Times(2).WillRepeatedly(Return(mock_boot)); ConnectionPool cp{mock_cf}; diff --git a/hbase-native-client/connection/connection-pool.cc b/hbase-native-client/connection/connection-pool.cc index 4fe4610..faa60e8 100644 --- a/hbase-native-client/connection/connection-pool.cc +++ b/hbase-native-client/connection/connection-pool.cc @@ -23,6 +23,7 @@ #include #include +#include #include #include @@ -35,11 +36,12 @@ using folly::SharedMutexWritePriority; using folly::SocketAddress; ConnectionPool::ConnectionPool(std::shared_ptr io_executor, - std::shared_ptr codec, nanoseconds connect_timeout) - : cf_(std::make_shared(io_executor, codec, connect_timeout)), + std::shared_ptr codec, std::shared_ptr conf, + nanoseconds connect_timeout) + : cf_(std::make_shared(io_executor, codec, conf, connect_timeout)), clients_(), connections_(), - map_mutex_() {} + map_mutex_(), conf_(conf) {} ConnectionPool::ConnectionPool(std::shared_ptr cf) : cf_(cf), clients_(), connections_(), map_mutex_() {} @@ -88,7 +90,8 @@ std::shared_ptr ConnectionPool::GetNewConnection( /* create new connection */ auto clientBootstrap = cf_->MakeBootstrap(); - auto dispatcher = cf_->Connect(clientBootstrap, remote_id->host(), remote_id->port()); + auto dispatcher = cf_->Connect(clientBootstrap, remote_id->host(), remote_id->port(), + remote_id->user(), conf_); auto connection = std::make_shared(remote_id, dispatcher); @@ -118,4 +121,8 @@ void ConnectionPool::Close() { } connections_.clear(); clients_.clear(); + // conf_ is null when the pool was built around an injected ConnectionFactory. + if (conf_ != nullptr && User::IsSecurityEnabled(*conf_)) { + sasl_client_done(); + } } diff --git
a/hbase-native-client/connection/connection-pool.h b/hbase-native-client/connection/connection-pool.h index 2a8f195..0582d9b 100644 --- a/hbase-native-client/connection/connection-pool.h +++ b/hbase-native-client/connection/connection-pool.h @@ -50,7 +50,8 @@ class ConnectionPool { public: /** Create connection pool wit default connection factory */ ConnectionPool(std::shared_ptr io_executor, - std::shared_ptr codec, nanoseconds connect_timeout = nanoseconds(0)); + std::shared_ptr codec, std::shared_ptr conf, + nanoseconds connect_timeout = nanoseconds(0)); /** * Constructor that allows specifiying the connetion factory. @@ -93,6 +94,7 @@ class ConnectionPool { clients_; folly::SharedMutexWritePriority map_mutex_; std::shared_ptr cf_; + std::shared_ptr conf_; }; } // namespace hbase diff --git a/hbase-native-client/connection/pipeline.cc b/hbase-native-client/connection/pipeline.cc index 00dc05c..af1d7df 100644 --- a/hbase-native-client/connection/pipeline.cc +++ b/hbase-native-client/connection/pipeline.cc @@ -25,21 +25,29 @@ #include #include "connection/client-handler.h" +#include "connection/sasl-handler.h" using namespace folly; using namespace hbase; using namespace wangle; -RpcPipelineFactory::RpcPipelineFactory(std::shared_ptr codec) - : user_util_(), codec_(codec) {} +RpcPipelineFactory::RpcPipelineFactory(std::shared_ptr codec, std::shared_ptr conf) + : user_util_(), codec_(codec), conf_(conf) {} SerializePipeline::Ptr RpcPipelineFactory::newPipeline( std::shared_ptr sock) { auto pipeline = SerializePipeline::create(); pipeline->addBack(AsyncSocketHandler{sock}); pipeline->addBack(EventBaseHandler{}); + auto secure = conf_ != nullptr && User::IsSecurityEnabled(*conf_); + if (secure) { + // addBack(H*) does not take ownership; hand the pipeline a shared_ptr so the handler is freed with it. + auto sasl_handler = std::make_shared<SaslHandler>(user_util_.user_name(secure), + SaslHandler::parseServiceName(conf_, kPrincipalKey)); + pipeline->addBack(sasl_handler); + } pipeline->addBack(LengthFieldBasedFrameDecoder{}); - pipeline->addBack(ClientHandler{user_util_.user_name(), codec_}); +
pipeline->addBack(ClientHandler{user_util_.user_name(secure), codec_, conf_}); pipeline->finalize(); return pipeline; } diff --git a/hbase-native-client/connection/pipeline.h b/hbase-native-client/connection/pipeline.h index ea40cfd..df83ee4 100644 --- a/hbase-native-client/connection/pipeline.h +++ b/hbase-native-client/connection/pipeline.h @@ -23,6 +23,7 @@ #include +#include "core/configuration.h" #include "connection/request.h" #include "connection/response.h" #include "serde/codec.h" @@ -41,7 +42,7 @@ class RpcPipelineFactory : public wangle::PipelineFactory { /** * Constructor. This will create user util. */ - explicit RpcPipelineFactory(std::shared_ptr codec); + explicit RpcPipelineFactory(std::shared_ptr codec, std::shared_ptr conf); /** * Create a new pipeline. @@ -57,5 +58,7 @@ class RpcPipelineFactory : public wangle::PipelineFactory { private: UserUtil user_util_; std::shared_ptr codec_; + std::shared_ptr conf_; + std::string kPrincipalKey = "hbase.regionserver.kerberos.principal"; }; } // namespace hbase diff --git a/hbase-native-client/connection/rpc-client.cc b/hbase-native-client/connection/rpc-client.cc index 5fa1138..57df66d 100644 --- a/hbase-native-client/connection/rpc-client.cc +++ b/hbase-native-client/connection/rpc-client.cc @@ -30,9 +30,10 @@ using hbase::RpcClient; namespace hbase { RpcClient::RpcClient(std::shared_ptr io_executor, - std::shared_ptr codec, nanoseconds connect_timeout) - : io_executor_(io_executor) { - cp_ = std::make_shared(io_executor_, codec, connect_timeout); + std::shared_ptr codec, std::shared_ptr conf, + nanoseconds connect_timeout) + : io_executor_(io_executor), conf_(conf) { + cp_ = std::make_shared(io_executor_, codec, conf, connect_timeout); } void RpcClient::Close() { io_executor_->stop(); } diff --git a/hbase-native-client/connection/rpc-client.h b/hbase-native-client/connection/rpc-client.h index d416ceb..4f02529 100644 --- a/hbase-native-client/connection/rpc-client.h +++
b/hbase-native-client/connection/rpc-client.h @@ -46,7 +46,7 @@ namespace hbase { class RpcClient { public: RpcClient(std::shared_ptr io_executor, std::shared_ptr codec, - nanoseconds connect_timeout = nanoseconds(0)); + std::shared_ptr conf, nanoseconds connect_timeout = nanoseconds(0)); virtual ~RpcClient() { Close(); } @@ -78,5 +78,6 @@ class RpcClient { private: std::shared_ptr cp_; std::shared_ptr io_executor_; + std::shared_ptr conf_; }; } // namespace hbase diff --git a/hbase-native-client/core/async-connection.cc b/hbase-native-client/core/async-connection.cc index b945e38..41eb48a 100644 --- a/hbase-native-client/core/async-connection.cc +++ b/hbase-native-client/core/async-connection.cc @@ -38,7 +38,7 @@ void AsyncConnectionImpl::Init() { LOG(WARNING) << "Not using RPC Cell Codec"; } rpc_client_ = - std::make_shared(io_executor_, codec, connection_conf_->connect_timeout()); + std::make_shared(io_executor_, codec, conf_, connection_conf_->connect_timeout()); location_cache_ = std::make_shared(conf_, cpu_executor_, rpc_client_->connection_pool()); caller_factory_ = std::make_shared(shared_from_this()); diff --git a/hbase-native-client/core/async-rpc-retrying-test.cc b/hbase-native-client/core/async-rpc-retrying-test.cc index 4956972..ce9463d 100644 --- a/hbase-native-client/core/async-rpc-retrying-test.cc +++ b/hbase-native-client/core/async-rpc-retrying-test.cc @@ -198,7 +198,7 @@ TEST(AsyncRpcRetryTest, TestGetBasic) { auto io_executor_ = std::make_shared(1); auto codec = std::make_shared(); - auto rpc_client = std::make_shared(io_executor_, codec); + auto rpc_client = std::make_shared(io_executor_, codec, test_util->conf()); /* init connection configuration */ auto connection_conf = std::make_shared( diff --git a/hbase-native-client/core/location-cache-test.cc b/hbase-native-client/core/location-cache-test.cc index 8d1ac5f..3253c56 100644 --- a/hbase-native-client/core/location-cache-test.cc +++ b/hbase-native-client/core/location-cache-test.cc @@ -52,7
+52,7 @@ TEST_F(LocationCacheTest, TestGetMetaNodeContents) { auto cpu = std::make_shared(4); auto io = std::make_shared(4); auto codec = std::make_shared(); - auto cp = std::make_shared(io, codec); + auto cp = std::make_shared(io, codec, LocationCacheTest::test_util_->conf()); LocationCache cache{LocationCacheTest::test_util_->conf(), cpu, cp}; auto f = cache.LocateMeta(); auto result = f.get(); @@ -68,7 +68,7 @@ TEST_F(LocationCacheTest, TestGetRegionLocation) { auto cpu = std::make_shared(4); auto io = std::make_shared(4); auto codec = std::make_shared(); - auto cp = std::make_shared(io, codec); + auto cp = std::make_shared(io, codec, LocationCacheTest::test_util_->conf()); LocationCache cache{LocationCacheTest::test_util_->conf(), cpu, cp}; // If there is no table this should throw an exception @@ -87,7 +87,7 @@ TEST_F(LocationCacheTest, TestCaching) { auto cpu = std::make_shared(4); auto io = std::make_shared(4); auto codec = std::make_shared(); - auto cp = std::make_shared(io, codec); + auto cp = std::make_shared(io, codec, LocationCacheTest::test_util_->conf()); LocationCache cache{LocationCacheTest::test_util_->conf(), cpu, cp}; auto tn_1 = folly::to("t1"); diff --git a/hbase-native-client/security/BUCK b/hbase-native-client/security/BUCK index 7383028..91b547f 100644 --- a/hbase-native-client/security/BUCK +++ b/hbase-native-client/security/BUCK @@ -23,4 +23,4 @@ cxx_library( srcs=[], deps=["//core:conf"], compiler_flags=['-Weffc++'], - visibility=['//core/...', '//connection/...'],) + visibility=['PUBLIC',],) diff --git a/hbase-native-client/security/user.h b/hbase-native-client/security/user.h index 035af31..41227df 100644 --- a/hbase-native-client/security/user.h +++ b/hbase-native-client/security/user.h @@ -21,6 +21,7 @@ #include #include #include "core/configuration.h" +#include namespace hbase { namespace security { diff --git a/hbase-native-client/serde/BUCK b/hbase-native-client/serde/BUCK index 38e7b4d..170636d 100644 --- 
a/hbase-native-client/serde/BUCK +++ b/hbase-native-client/serde/BUCK @@ -31,7 +31,8 @@ cxx_library( "rpc.cc", "zk.cc", ], - deps=["//if:if", "//third-party:folly", "//utils:utils"], + deps=["//if:if", "//third-party:folly", "//utils:utils", + "//security:security"], tests=[ ":client-deserializer-test", ":client-serializer-test", diff --git a/hbase-native-client/serde/client-deserializer-test.cc b/hbase-native-client/serde/client-deserializer-test.cc index 054684d..21ee4ac 100644 --- a/hbase-native-client/serde/client-deserializer-test.cc +++ b/hbase-native-client/serde/client-deserializer-test.cc @@ -30,12 +30,12 @@ using hbase::pb::RegionSpecifier; using hbase::pb::RegionSpecifier_RegionSpecifierType; TEST(TestRpcSerde, TestReturnFalseOnNullPtr) { - RpcSerde deser{nullptr}; + RpcSerde deser{nullptr, nullptr}; ASSERT_LT(deser.ParseDelimited(nullptr, nullptr), 0); } TEST(TestRpcSerde, TestReturnFalseOnBadInput) { - RpcSerde deser{nullptr}; + RpcSerde deser{nullptr, nullptr}; auto buf = IOBuf::copyBuffer("test"); GetRequest gr; @@ -44,8 +44,8 @@ TEST(TestRpcSerde, TestReturnFalseOnBadInput) { TEST(TestRpcSerde, TestGoodGetRequestFullRoundTrip) { GetRequest in; - RpcSerde ser{nullptr}; - RpcSerde deser{nullptr}; + RpcSerde ser{nullptr, nullptr}; + RpcSerde deser{nullptr, nullptr}; // fill up the GetRequest.
in.mutable_region()->set_value("test_region_id"); diff --git a/hbase-native-client/serde/client-serializer-test.cc b/hbase-native-client/serde/client-serializer-test.cc index 33c48f3..02dac20 100644 --- a/hbase-native-client/serde/client-serializer-test.cc +++ b/hbase-native-client/serde/client-serializer-test.cc @@ -32,7 +32,7 @@ using namespace folly; using namespace folly::io; TEST(RpcSerdeTest, PreambleIncludesHBas) { - RpcSerde ser{nullptr}; + RpcSerde ser{nullptr, nullptr}; auto buf = ser.Preamble(); const char *p = reinterpret_cast(buf->data()); // Take the first for chars and make sure they are the @@ -43,14 +43,14 @@ TEST(RpcSerdeTest, PreambleIncludesHBas) { } TEST(RpcSerdeTest, PreambleIncludesVersion) { - RpcSerde ser{nullptr}; + RpcSerde ser{nullptr, nullptr}; auto buf = ser.Preamble(); EXPECT_EQ(0, static_cast(buf->data())[4]); EXPECT_EQ(80, static_cast(buf->data())[5]); } TEST(RpcSerdeTest, TestHeaderLengthPrefixed) { - RpcSerde ser{nullptr}; + RpcSerde ser{nullptr, nullptr}; auto header = ser.Header("elliott"); // The header should be prefixed by 4 bytes of length.
@@ -65,7 +65,7 @@ TEST(RpcSerdeTest, TestHeaderLengthPrefixed) { } TEST(RpcSerdeTest, TestHeaderDecode) { - RpcSerde ser{nullptr}; + RpcSerde ser{nullptr, nullptr}; auto buf = ser.Header("elliott"); auto header_buf = buf->next(); ConnectionHeader h; diff --git a/hbase-native-client/serde/rpc.cc b/hbase-native-client/serde/rpc.cc index e657a64..cd1b95b 100644 --- a/hbase-native-client/serde/rpc.cc +++ b/hbase-native-client/serde/rpc.cc @@ -30,6 +30,7 @@ #include #include "if/RPC.pb.h" +#include "security/user.h" #include "utils/version.h" using namespace hbase; @@ -50,6 +51,7 @@ static const std::string PREAMBLE = "HBas"; static const std::string INTERFACE = "ClientService"; static const uint8_t RPC_VERSION = 0; static const uint8_t DEFAULT_AUTH_TYPE = 80; +static const uint8_t KERBEROS_AUTH_TYPE = 81; int RpcSerde::ParseDelimited(const IOBuf *buf, Message *msg) { if (buf == nullptr || msg == nullptr) { @@ -85,7 +87,14 @@ int RpcSerde::ParseDelimited(const IOBuf *buf, Message *msg) { return coded_stream.CurrentPosition(); } -RpcSerde::RpcSerde(std::shared_ptr codec) : auth_type_(DEFAULT_AUTH_TYPE), codec_(codec) {} +RpcSerde::RpcSerde(std::shared_ptr codec, std::shared_ptr conf) + // Negotiate Kerberos only when the configuration enables security. + // Otherwise, including when conf is null (as in the serde unit tests), + // keep the default auth type. + : auth_type_(conf != nullptr && security::User::IsSecurityEnabled(*conf) + ? KERBEROS_AUTH_TYPE + : DEFAULT_AUTH_TYPE), + codec_(codec) {} unique_ptr RpcSerde::Preamble() { auto magic = IOBuf::copyBuffer(PREAMBLE, 0, 2); @@ -94,7 +103,6 @@ unique_ptr RpcSerde::Preamble() { c.skip(4); // Version c.write(RPC_VERSION); - // Standard security aka Please don't lie to me.
c.write(auth_type_); return magic; } diff --git a/hbase-native-client/serde/rpc.h b/hbase-native-client/serde/rpc.h index abebe94..c7e85eb 100644 --- a/hbase-native-client/serde/rpc.h +++ b/hbase-native-client/serde/rpc.h @@ -21,6 +21,7 @@ #include #include +#include "core/configuration.h" #include "if/HBase.pb.h" #include "serde/cell-scanner.h" #include "serde/codec.h" @@ -48,7 +49,7 @@ class RpcSerde { /** - * Constructor assumes the default auth type. + * Constructor. Uses Kerberos auth when conf is non-null and enables security, otherwise the default auth type. */ - RpcSerde(std::shared_ptr codec); + RpcSerde(std::shared_ptr codec, std::shared_ptr conf); /** * Destructor. This is provided just for testing purposes. diff --git a/hbase-native-client/test-util/mini-cluster.cc b/hbase-native-client/test-util/mini-cluster.cc index 34da54c..bd3f3b1 100644 --- a/hbase-native-client/test-util/mini-cluster.cc +++ b/hbase-native-client/test-util/mini-cluster.cc @@ -258,6 +258,9 @@ jobject MiniCluster::StartCluster(int num_region_servers) { } void MiniCluster::StopCluster() { + if (cluster_ == nullptr) { + return; + } env(); jmethodID mid = env_->GetMethodID(testing_util_class_, "shutdownMiniCluster", "()V"); env_->CallVoidMethod(htu(), mid); diff --git a/hbase-native-client/test-util/mini-cluster.h b/hbase-native-client/test-util/mini-cluster.h index 4119cb5..49028a0 100644 --- a/hbase-native-client/test-util/mini-cluster.h +++ b/hbase-native-client/test-util/mini-cluster.h @@ -63,7 +63,7 @@ class MiniCluster { jmethodID move_mid_; jmethodID str_ctor_mid_; jobject htu_; - jobject cluster_; + jobject cluster_ = nullptr; pthread_mutex_t count_mutex_; JavaVM *jvm_; JNIEnv *CreateVM(JavaVM **jvm); diff --git a/hbase-native-client/utils/BUCK b/hbase-native-client/utils/BUCK index 04e2b67..a1704c4 100644 --- a/hbase-native-client/utils/BUCK +++ b/hbase-native-client/utils/BUCK @@ -28,6 +28,8 @@ cxx_library( srcs=["bytes-util.cc", "connection-util.cc", "user-util.cc"], deps=['//third-party:folly',], tests=[":user-util-test"], + linker_flags = 
['-L/usr/local/lib','-lkrb5'], +
exported_linker_flags = ['-L/usr/local/lib','-lkrb5'], visibility=['PUBLIC',], compiler_flags=['-Weffc++'],) cxx_test( diff --git a/hbase-native-client/utils/user-util-test.cc b/hbase-native-client/utils/user-util-test.cc index 7c11d8c..aa3fa45 100644 --- a/hbase-native-client/utils/user-util-test.cc +++ b/hbase-native-client/utils/user-util-test.cc @@ -28,7 +28,7 @@ using namespace hbase; TEST(TestUserUtil, TestGetSomething) { UserUtil u_util; - string name = u_util.user_name(); + string name = u_util.user_name(false); // TODO shell out to whoami to check this. ASSERT_GT(name.length(), 0); diff --git a/hbase-native-client/utils/user-util.cc b/hbase-native-client/utils/user-util.cc index 9e170e0..8d19a7d 100644 --- a/hbase-native-client/utils/user-util.cc +++ b/hbase-native-client/utils/user-util.cc @@ -23,18 +23,20 @@ #include #include #include +#include using namespace hbase; using namespace std; -UserUtil::UserUtil() : once_flag_{}, user_name_{"drwho"} {} +UserUtil::UserUtil() : + once_flag_{}, user_name_{"drwho"} {} -string UserUtil::user_name() { - std::call_once(once_flag_, [this]() { compute_user_name(); }); +string UserUtil::user_name(bool secure) { + std::call_once(once_flag_, [this, secure]() { compute_user_name(secure); }); return user_name_; } -void UserUtil::compute_user_name() { +void UserUtil::compute_user_name(bool secure) { // According to the man page of getpwuid // this should never be free'd // @@ -45,4 +47,40 @@ void UserUtil::compute_user_name() { if (passwd && passwd->pw_name) { user_name_ = string{passwd->pw_name}; } + if (!secure) return; + krb5_context ctx; + krb5_error_code ret = krb5_init_context(&ctx); + if (ret != 0) { + LOG(INFO) << "cannot init krb ctx " << ret; + return; + } + krb5_ccache ccache; + ret = krb5_cc_default(ctx, &ccache); + if (ret != 0) { + LOG(INFO) << "cannot get default cache " << ret; + krb5_free_context(ctx); + return; + } + // Here is sample principal: hbase/23a03935850c@EXAMPLE.COM + // There may be one (user) or two (user/host) components before the @ sign
+ krb5_principal princ; + ret = krb5_cc_get_principal(ctx, ccache, &princ); + krb5_cc_close(ctx, ccache); + if (ret != 0) { + LOG(INFO) << "cannot get principal from default cache " << ret; + krb5_free_context(ctx); + return; + } + // krb5_unparse_name yields the full user[/host]@REALM principal as a + // NUL-terminated string; the raw krb5_data fields are length-delimited. + char *princ_name = nullptr; + ret = krb5_unparse_name(ctx, princ, &princ_name); + if (ret == 0) { + user_name_ = princ_name; + krb5_free_unparsed_name(ctx, princ_name); + } else { + LOG(INFO) << "cannot unparse principal " << ret; + } + krb5_free_principal(ctx, princ); + krb5_free_context(ctx); } diff --git a/hbase-native-client/utils/user-util.h b/hbase-native-client/utils/user-util.h index 6f8fce1..6258c85 100644 --- a/hbase-native-client/utils/user-util.h +++ b/hbase-native-client/utils/user-util.h @@ -41,13 +41,15 @@ class UserUtil { - * Get the username of the user owning this process. This is thread safe and + * Get the username of the user owning this process or, when secure is true + * and a default Kerberos credential cache is usable, its principal. The name is + * computed once, so the first call's secure flag wins. This is thread safe and * lockless for every invocation other than the first one. */ - std::string user_name(); + std::string user_name(bool secure = false); private: /** * Compute the username. This will block. */ - void compute_user_name(); + void compute_user_name(bool secure); std::once_flag once_flag_; std::string user_name_; }; diff --git a/hbase-native-client/connection/sasl-handler.h b/hbase-native-client/connection/sasl-handler.h new file mode 100644 index 0000000..0044677 --- /dev/null +++ b/hbase-native-client/connection/sasl-handler.h @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include "connection/service.h" +#include "security/user.h" +using hbase::security::User; + +using std::chrono::nanoseconds; + +namespace hbase { + +/** + * Class to perform SASL handshake with server + * It is inserted between EventBaseHandler and LengthFieldBasedFrameDecoder in the pipeline + * SaslHandler would intercept writes to server by buffering the IOBuf's and start the handshake process + * (via sasl_client_XX calls provided by Cyrus) + * After handshake is complete, SaslHandler would send the buffered IOBuf's to server and + * act as pass-thru from then on + */ +class SaslHandler + : public wangle::HandlerAdapter>{ + public: + SaslHandler(std::string user_name, std::string service_name); + SaslHandler(const SaslHandler&) = delete; // owns sconn_; a copy would double-dispose it + ~SaslHandler(); + + enum AuthState { + kUninit, // uninitialized + kRead, // waiting to read response from server + kSucc, // handshake is successful + kFailure // handshake fails + }; + + // from HandlerAdapter + void read(Context* ctx, folly::IOBufQueue& buf) override; + folly::Future write(Context* ctx, std::unique_ptr buf) override; + // parse the service name from given Configuration + static std::string parseServiceName(std::shared_ptr conf, std::string key); + + private: + // used by Cyrus + sasl_conn_t *sconn_ = nullptr; + const int buffer_size = 8 * 1024; + char *pBuf = nullptr; + int currRC_ = 0; + std::string user_name_; + std::string service_name_; + std::string host_name_; + AuthState auth_state_ = kUninit; + // vector of folly::IOBuf which buffers client writes before handshake is complete + std::vector> iobuf_; + + // writes the output returned by sasl_client_XX to server + void writeSaslOutput(Context* ctx, const char *out, unsigned int outlen); + AuthState saslInit(Context* ctx); + AuthState finishAuth(Context* ctx, folly::IOBufQueue&
bufQueue); + AuthState continueSaslNegotiation(Context* ctx, folly::IOBufQueue& buf); + AuthState getState() { + return auth_state_; + } +}; +} // namespace hbase diff --git a/hbase-native-client/connection/sasl-handler.cc b/hbase-native-client/connection/sasl-handler.cc new file mode 100644 index 0000000..167d963 --- /dev/null +++ b/hbase-native-client/connection/sasl-handler.cc @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ * + */ +#include "connection/sasl-handler.h" + +#include <stdexcept> +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "connection/service.h" +#include "security/user.h" +using hbase::security::User; + +using std::chrono::nanoseconds; +using namespace folly; +using namespace wangle; + +namespace hbase { + +// SASL callbacks, kept file-local. Written as free functions to avoid compilation errors from invalid casts. + static int GetPluginPath(void *context, + const char **path) { + const char *searchpath = static_cast<const char *>(context); + + if (!path) + return SASL_BADPARAM; + + if (searchpath) { + *path = searchpath; + } + + return SASL_OK; + } + static int UserCallback(void *context, + int id, + const char **result, + unsigned *len) { + const char *value = static_cast<const char *>(context); + + if (!result) + return SASL_BADPARAM; + + switch (id) { + case SASL_CB_USER: { + *result = value; + if (len) *len = value ? static_cast<unsigned>(strlen(value)) : 0; + } + break; + default: + return SASL_BADPARAM; + } + return SASL_OK; + } + static int SaslLogFn(void *context __attribute__((unused)), + int priority, + const char *message) { + if (!
message) + return SASL_BADPARAM; + + switch (priority) { + case SASL_LOG_ERR: + LOG(ERROR) << ": SASL " << message; + break; + case SASL_LOG_NOTE: + LOG(INFO) << ": SASL " << message; + break; + default: + LOG(INFO) << ": SASL " << message; + break; + } + + return SASL_OK; + } + +SaslHandler::SaslHandler(std::string user_name, std::string service_name) + : user_name_(std::move(user_name)), service_name_(std::move(service_name)) {} + +SaslHandler::~SaslHandler() { + if (nullptr != sconn_) { + sasl_dispose(&sconn_); + } + sconn_ = nullptr; + } + + std::string SaslHandler::parseServiceName(std::shared_ptr conf, std::string key) { + std::string svrPrincipal = conf->Get(key.c_str(), ""); + // principal is of this form: hbase/23a03935850c@EXAMPLE.COM + // where 23a03935850c is the host (optional) + std::size_t pos = svrPrincipal.find('/'); + if (pos == std::string::npos) { + pos = svrPrincipal.find('@'); + } + if (pos == std::string::npos) { + throw std::runtime_error("Couldn't retrieve service principal from conf"); + } + std::string service_name = svrPrincipal.substr(0, pos); + return service_name; + } + + void SaslHandler::read(Context* ctx, folly::IOBufQueue& buf) { + if (auth_state_ == kSucc) ctx->fireRead(buf); + else { + continueSaslNegotiation(ctx, buf); + } + } + + folly::Future SaslHandler::write(Context* ctx, std::unique_ptr buf) { + if (host_name_.empty()) { + // assign hostname if it is not assigned + folly::SocketAddress address; + ctx->getPipeline()->getTransport()->getPeerAddress(&address); + host_name_ = address.getHostStr(); + } + if (auth_state_ != kSucc && auth_state_ != kFailure) { + if (auth_state_ == kUninit) { + // perform sasl initialization + saslInit(ctx); + } + // store IOBuf which is to be sent to server after SASL handshake + iobuf_.push_back(std::move(buf)); + std::unique_ptr empty = std::make_unique(); + return ctx->fireWrite(std::move(empty)); + } + return ctx->fireWrite(std::move(buf)); + } + + void
SaslHandler::writeSaslOutput(Context* ctx, const char *out, unsigned int outlen) { + int bufferSize = outlen + 4; + auto iob = IOBuf::create(bufferSize); + iob->append(bufferSize); + // Create the array output stream. + google::protobuf::io::ArrayOutputStream aos{iob->writableData(), static_cast(iob->length())}; + google::protobuf::io::CodedOutputStream *coded_output = + new google::protobuf::io::CodedOutputStream(&aos); + unsigned int uiTotalSize = outlen; + uiTotalSize = ntohl(uiTotalSize); + coded_output->WriteRaw(&uiTotalSize, 4); + coded_output->WriteRaw(out, outlen); + ctx->fireWrite(std::move(iob)); + } + SaslHandler::AuthState SaslHandler::finishAuth(Context* ctx, folly::IOBufQueue& bufQueue) { + if (currRC_ == SASL_OK) { + std::unique_ptr iob; + if (!bufQueue.empty()) { + iob = bufQueue.pop_front(); + LOG(ERROR) << "Error in the final step of handshake " << iob->length(); + auth_state_ = kFailure; + return auth_state_; + } else { + auth_state_ = kSucc; + VLOG(1) << "auth succeeded " << iobuf_.size(); + // write what we buffered + for (int i = 0; i < iobuf_.size(); i++) { + iob = std::move(iobuf_.at(i)); + ctx->fireWrite(std::move(iob)); + } + } + } + return auth_state_; + } + /* mutex functions */ + int mutex_cnt_ = 0; + pthread_mutex_t mutex_; + + typedef struct my_mutex_s { + int num; + } mutex_t; + + void *mutex_new(void) + { + mutex_t *ret = (mutex_t *)malloc(sizeof(mutex_t)); + ret->num = mutex_cnt_; + mutex_cnt_++; + + return ret; + } + + int mutex_lock(mutex_t *m) + { + pthread_mutex_lock(&mutex_); + return SASL_OK; + } + int mutex_unlock(mutex_t *m) + { + pthread_mutex_unlock(&mutex_); + return SASL_OK; + } + + void mutex_dispose(mutex_t *m) + { + if (m==NULL) return; + free(m); + } + SaslHandler::AuthState SaslHandler::saslInit(Context* ctx) { + int rc; + const char *mechusing, *mechlist = "GSSAPI"; + const char *out; + unsigned int outlen; + + sasl_set_mutex((sasl_mutex_alloc_t *) &mutex_new, + (sasl_mutex_lock_t *) &mutex_lock, + 
(sasl_mutex_unlock_t *) &mutex_unlock, + (sasl_mutex_free_t *) &mutex_dispose); + /* Create new connection session. */ + char *searchpath = ::getenv("CYRUS_SASL_PLUGINS_DIR"); + + if (NULL == searchpath || 0 == strcmp(searchpath, "")) { + searchpath = "/usr/lib/sasl2"; + } + + std::string sasl_plugin_dir_path = searchpath; + VLOG(1) << "service is " << service_name_ << " " << searchpath; + + sasl_callback_t callbacks[4]; + sasl_callback_t *callback; + + /* Fill in the callbacks that we're providing... */ + callback = callbacks; + + /* log */ + callback->id = SASL_CB_LOG; + callback->proc = (sasl_callback_ft) &SaslLogFn; + callback->context = NULL; + ++callback; + + callback->id = SASL_CB_GETPATH; + callback->proc = (sasl_callback_ft) &GetPluginPath; + callback->context = (void *)(sasl_plugin_dir_path.c_str()); + ++callback; + + /* user */ + callback->id = SASL_CB_USER; + callback->proc = (sasl_callback_ft) &UserCallback; + callback->context = const_cast(user_name_.c_str()); + ++callback; + + /* termination */ + callback->id = SASL_CB_LIST_END; + callback->proc = NULL; + callback->context = NULL; + ++callback; + rc = sasl_client_init(callbacks); + if (rc != SASL_OK) { + LOG(FATAL) << "Cannot initialize client ("<< rc << ") "; + throw std::runtime_error("Cannot initialize client"); + } + VLOG(1) << "init " << rc << " " << service_name_ << " " << host_name_; + rc = sasl_client_new(service_name_.c_str(), /* The service we are using*/ + host_name_.c_str(), + NULL, NULL, /* Local and remote IP address strings + (NULL disables mechanisms which require this info)*/ + NULL, /*connection-specific callbacks*/ + 0 /*security flags*/, &sconn_); + if (rc != SASL_OK) { + LOG(FATAL) << "Cannot create client ("<< rc << ") "; + throw std::runtime_error("Cannot create client"); + } + sasl_security_properties_t *props = new sasl_security_properties_t(); + ::memset(props, 0, sizeof(sasl_security_properties_t)); + props->max_ssf = 2; + sasl_setprop(sconn_, SASL_SEC_PROPS, (void 
*)props); + + auto buffer = std::vector(buffer_size+1, 0); + std::fill(buffer.begin(), buffer.end(), 0); + pBuf = (char *) &buffer[0]; + + VLOG(1) << "sasl_client_new called " << " " << host_name_ << " " << user_name_; + + do { + currRC_ = sasl_client_start(sconn_, /* the same context from above */ + mechlist, /* the list of mechanisms from the server */ + NULL, /* filled in if an interaction is needed */ + &out, /* filled in on success */ + &outlen, /* filled in on success */ + &mechusing); + } while (currRC_ == SASL_INTERACT); /* the mechanism may ask us to fill + in things many times. result is SASL_CONTINUE on success */ + if (currRC_ != SASL_CONTINUE) { + LOG(FATAL)<< "Cannot start client ("<< currRC_ << ") "; + auth_state_ = kFailure; + return auth_state_; + } + VLOG(1) << "client started " << currRC_ << " " << outlen; + writeSaslOutput(ctx, out, outlen); + auth_state_ = kRead; + return auth_state_; + } + + SaslHandler::AuthState SaslHandler::continueSaslNegotiation(Context* ctx, folly::IOBufQueue& bufQueue) { + const char *out; + unsigned int outlen; + + int bytes_sent = 0; + int bytes_received = 0; + + std::unique_ptr iob = nullptr; + if (currRC_ == SASL_OK) { + return finishAuth(ctx, bufQueue); + } else { + iob = bufQueue.pop_front(); + bytes_received = iob->length(); + pBuf = (char*)iob->data(); + if (bytes_received == 0) { + auth_state_ = kFailure; + return auth_state_; + } + } + unsigned int *uiReader = (unsigned int *)pBuf; + unsigned int status = *uiReader; + + char *p = (char *)pBuf + 4; + + uiReader = (unsigned int *)p; + unsigned int uiTotalSize = *uiReader; + uiTotalSize = ntohl(uiTotalSize); + // p now points to the buffer to be sent to server + p = p + 4; + + if (auth_state_ == kFailure || status != 0 /*Status 0 is success*/) { + //Assumption here is that the response from server is not more than 8 * 1024 + p[uiTotalSize] = '\0'; + LOG(ERROR) << "Exception from server: " << p; + auth_state_ = kFailure; + return auth_state_; + } + + out = NULL; + 
outlen = 0; + + currRC_ = sasl_client_step(sconn_, /* our context */ + (char*) p, /* the data from the server */ + uiTotalSize, /* its length */ + NULL, /* this should be unallocated and NULL */ + &out, /* filled in on success */ + &outlen); /* filled in on success */ + pBuf = NULL; + if (currRC_ == SASL_OK || currRC_ == SASL_CONTINUE) { + writeSaslOutput(ctx, out, outlen); + } + if (currRC_ == SASL_OK) { + return finishAuth(ctx, bufQueue); + } + return auth_state_; + } +} // namespace hbase