From 4342719db6ab8c78693101fa9a2b13954212bd7e Mon Sep 17 00:00:00 2001 From: Xiaobing Zhou Date: Wed, 26 Jul 2017 17:04:17 -0700 Subject: [PATCH] HBASE-18078. [C++] Harden RPC by handling various communication abnormalities --- .../connection/connection-factory.cc | 38 ++++-- .../connection/connection-factory.h | 13 +- hbase-native-client/connection/connection-pool.cc | 70 ++++++++-- hbase-native-client/connection/connection-pool.h | 9 ++ hbase-native-client/connection/rpc-client.cc | 16 ++- hbase-native-client/connection/rpc-client.h | 3 +- .../connection/rpc-test-server-handler.cc | 2 +- hbase-native-client/connection/rpc-test-server.cc | 36 ++++- hbase-native-client/connection/rpc-test-server.h | 14 +- hbase-native-client/connection/rpc-test.cc | 146 ++++++++++++++++----- hbase-native-client/core/configuration.cc | 2 +- hbase-native-client/exceptions/exception.h | 14 +- 12 files changed, 297 insertions(+), 66 deletions(-) diff --git a/hbase-native-client/connection/connection-factory.cc b/hbase-native-client/connection/connection-factory.cc index a0c7f96118..d91dd26791 100644 --- a/hbase-native-client/connection/connection-factory.cc +++ b/hbase-native-client/connection/connection-factory.cc @@ -22,11 +22,15 @@ #include +#include +#include +#include #include "connection/client-dispatcher.h" #include "connection/connection-factory.h" #include "connection/pipeline.h" #include "connection/sasl-handler.h" #include "connection/service.h" +#include "exceptions/exception.h" using std::chrono::milliseconds; using std::chrono::nanoseconds; @@ -56,15 +60,29 @@ std::shared_ptr> ConnectionFactory::M std::shared_ptr ConnectionFactory::Connect( std::shared_ptr> client, const std::string &hostname, uint16_t port) { - // Yes this will block however it makes dealing with connection pool soooooo - // much nicer. - // TODO see about using shared promise for this. - auto pipeline = client - ->connect(folly::SocketAddress(hostname, port, true), - std::chrono::duration_cast(connect_timeout_)) - .get(); - auto dispatcher = std::make_shared(); - dispatcher->setPipeline(pipeline); - return dispatcher; + return AsyncConnect(client, hostname, port).get(); +} + +folly::Future> ConnectionFactory::AsyncConnect( + std::shared_ptr> client, const std::string &hostname, + uint16_t port) { + folly::Promise> promise; + auto future = promise.getFuture(); + + try { + /* any connection error (e.g. timeout) will be folly::AsyncSocketException */ + auto pipeline = client + ->connect(folly::SocketAddress(hostname, port, true), + std::chrono::duration_cast(connect_timeout_)) + .get(); + auto dispatcher = std::make_shared(); + dispatcher->setPipeline(pipeline); + promise.setValue(dispatcher); + } catch (const folly::AsyncSocketException &e) { + promise.setException(folly::make_exception_wrapper( + folly::make_exception_wrapper(e))); + } + + return future; } } // namespace hbase diff --git a/hbase-native-client/connection/connection-factory.h b/hbase-native-client/connection/connection-factory.h index c96087d1dc..f94bb16d29 100644 --- a/hbase-native-client/connection/connection-factory.h +++ b/hbase-native-client/connection/connection-factory.h @@ -18,6 +18,7 @@ */ #pragma once +#include #include #include @@ -55,7 +56,7 @@ class ConnectionFactory { virtual std::shared_ptr> MakeBootstrap(); /** - * Connect a ClientBootstrap to a server and return the pipeline. + * Connect a ClientBootstrap to a server and return the wangle::Service. * * This is mostly visible so that mocks can override socket connections. */ @@ -63,6 +64,16 @@ class ConnectionFactory { std::shared_ptr> client, const std::string &hostname, uint16_t port); + /** + * Asynchronously Connect a ClientBootstrap to a server and return the wangle::Service. + * + * This async function makes it easy to propagate exceptions in a controlled way with + * help of folly::Future/Promise. + */ + virtual folly::Future> AsyncConnect( + std::shared_ptr> client, + const std::string &hostname, uint16_t port); + private: std::chrono::nanoseconds connect_timeout_; std::shared_ptr conf_; diff --git a/hbase-native-client/connection/connection-pool.cc b/hbase-native-client/connection/connection-pool.cc index e98759d2fc..cd2efc3ed8 100644 --- a/hbase-native-client/connection/connection-pool.cc +++ b/hbase-native-client/connection/connection-pool.cc @@ -20,13 +20,17 @@ #include "connection/connection-pool.h" #include +#include #include #include #include #include +#include "exceptions/exception.h" using std::chrono::nanoseconds; +using namespace folly; +using namespace hbase; namespace hbase { @@ -45,24 +49,29 @@ ConnectionPool::~ConnectionPool() { Close(); } std::shared_ptr ConnectionPool::GetConnection( std::shared_ptr remote_id) { - // Try and get th cached connection. + /** + * Try and get the cached connection, if there's no connection then create it. + */ auto found_ptr = GetCachedConnection(remote_id); + return found_ptr == nullptr ? GetNewConnection(remote_id) : found_ptr; +} - // If there's no connection then create it. - if (found_ptr == nullptr) { - found_ptr = GetNewConnection(remote_id); - } - return found_ptr; +folly::Future> ConnectionPool::AsyncGetConnection( + std::shared_ptr remote_id) { + /** + * Try and get the cached connection, if there's no connection then create it. + */ + auto found_ptr = GetCachedConnection(remote_id); + return found_ptr == nullptr + ? AsyncGetNewConnection(remote_id) + : folly::makeFuture>(std::move(found_ptr)); } std::shared_ptr ConnectionPool::GetCachedConnection( std::shared_ptr remote_id) { folly::SharedMutexWritePriority::ReadHolder holder(map_mutex_); auto found = connections_.find(remote_id); - if (found == connections_.end()) { - return nullptr; - } - return found->second; + return found == connections_.end() ? nullptr : found->second; } std::shared_ptr ConnectionPool::GetNewConnection( @@ -91,11 +100,50 @@ std::shared_ptr ConnectionPool::GetNewConnection( connections_.insert(std::make_pair(remote_id, connection)); clients_.insert(std::make_pair(remote_id, clientBootstrap)); - return connection; } } +folly::Future> ConnectionPool::AsyncGetNewConnection( + std::shared_ptr remote_id) { + // Grab the upgrade lock. While we are double checking other readers can + // continue on + SharedMutexWritePriority::UpgradeHolder u_holder{map_mutex_}; + + folly::Promise> promise; + auto future = promise.getFuture(); + + // Now check if someone else created the connection before we got the lock + // This is safe since we hold the upgrade lock. + // upgrade lock is more power than the reader lock. + auto found = connections_.find(remote_id); + if (found != connections_.end() && found->second != nullptr) { + promise.setValue(found->second); + } else { + // Yeah it looks a lot like there's no connection + SharedMutexWritePriority::WriteHolder w_holder{std::move(u_holder)}; + + // Make double sure there are not stale connections hanging around. + connections_.erase(remote_id); + + /* create new connection */ + auto clientBootstrap = cf_->MakeBootstrap(); + try { + auto dispatcher = cf_->Connect(clientBootstrap, remote_id->host(), remote_id->port()); + auto connection = std::make_shared(remote_id, dispatcher); + promise.setValue(connection); + + connections_.insert(std::make_pair(remote_id, connection)); + clients_.insert(std::make_pair(remote_id, clientBootstrap)); + } catch (const hbase::ConnectionException &e) { + /* propagating ConnectionException up */ + promise.setException(folly::make_exception_wrapper(e)); + } + } + + return future; +} + void ConnectionPool::Close(std::shared_ptr remote_id) { folly::SharedMutexWritePriority::WriteHolder holder{map_mutex_}; DLOG(INFO) << "Closing RPC Connection to host:" << remote_id->host() diff --git a/hbase-native-client/connection/connection-pool.h b/hbase-native-client/connection/connection-pool.h index c7c4246e2a..0c8da358cf 100644 --- a/hbase-native-client/connection/connection-pool.h +++ b/hbase-native-client/connection/connection-pool.h @@ -19,6 +19,7 @@ #pragma once #include +#include #include #include #include @@ -66,6 +67,12 @@ class ConnectionPool { std::shared_ptr GetConnection(std::shared_ptr remote_id); /** + * Asynchronously get connection by ConnectionId. + */ + folly::Future> AsyncGetConnection( + std::shared_ptr remote_id); + + /** * Close/remove a connection. */ void Close(std::shared_ptr remote_id); @@ -78,6 +85,8 @@ class ConnectionPool { private: std::shared_ptr GetCachedConnection(std::shared_ptr remote_id); std::shared_ptr GetNewConnection(std::shared_ptr remote_id); + folly::Future> AsyncGetNewConnection( + std::shared_ptr remote_id); std::unordered_map, std::shared_ptr, ConnectionIdHash, ConnectionIdEquals> connections_; diff --git a/hbase-native-client/connection/rpc-client.cc b/hbase-native-client/connection/rpc-client.cc index 10faa7a84e..80b161c3b7 100644 --- a/hbase-native-client/connection/rpc-client.cc +++ b/hbase-native-client/connection/rpc-client.cc @@ -22,6 +22,7 @@ #include #include #include +#include "exceptions/exception.h" using hbase::security::User; using std::chrono::nanoseconds; @@ -55,7 +56,7 @@ folly::Future> RpcClient::AsyncCall(const std::string& std::unique_ptr req, std::shared_ptr ticket) { auto remote_id = std::make_shared(host, port, ticket); - return GetConnection(remote_id)->SendRequest(std::move(req)); + return CallForResult(remote_id, std::move(req)); } folly::Future> RpcClient::AsyncCall(const std::string& host, @@ -64,10 +65,17 @@ folly::Future> RpcClient::AsyncCall(const std::string& std::shared_ptr ticket, const std::string& service_name) { auto remote_id = std::make_shared(host, port, ticket, service_name); - return GetConnection(remote_id)->SendRequest(std::move(req)); + return CallForResult(remote_id, std::move(req)); } -std::shared_ptr RpcClient::GetConnection(std::shared_ptr remote_id) { - return cp_->GetConnection(remote_id); +folly::Future> RpcClient::CallForResult( + std::shared_ptr remote_id, std::unique_ptr req) { + try { + auto connection = cp_->AsyncGetConnection(remote_id).get(); + return connection->SendRequest(std::move(req)); + } catch (const hbase::ConnectionException& e) { + return folly::makeFuture>( + folly::make_exception_wrapper(e)); + } } } // namespace hbase diff --git a/hbase-native-client/connection/rpc-client.h b/hbase-native-client/connection/rpc-client.h index 0ecde5b775..37cfc01481 100644 --- a/hbase-native-client/connection/rpc-client.h +++ b/hbase-native-client/connection/rpc-client.h @@ -64,7 +64,8 @@ class RpcClient { std::shared_ptr connection_pool() const { return cp_; } private: - std::shared_ptr GetConnection(std::shared_ptr remote_id); + folly::Future> CallForResult(std::shared_ptr remote_id, + std::unique_ptr req); private: std::shared_ptr cp_; diff --git a/hbase-native-client/connection/rpc-test-server-handler.cc b/hbase-native-client/connection/rpc-test-server-handler.cc index 7d2f407d55..4fc5562fc2 100644 --- a/hbase-native-client/connection/rpc-test-server-handler.cc +++ b/hbase-native-client/connection/rpc-test-server-handler.cc @@ -55,7 +55,7 @@ folly::Future RpcTestServerSerializeHandler::write(Context* ctx, std::unique_ptr RpcTestServerSerializeHandler::CreateReceivedRequest( const std::string& method_name) { std::unique_ptr result = nullptr; - ; + if (method_name == "ping") { result = std::make_unique(std::make_shared(), std::make_shared(), method_name); diff --git a/hbase-native-client/connection/rpc-test-server.cc b/hbase-native-client/connection/rpc-test-server.cc index d3a30b104c..57461f8e0a 100644 --- a/hbase-native-client/connection/rpc-test-server.cc +++ b/hbase-native-client/connection/rpc-test-server.cc @@ -22,6 +22,8 @@ #include #include +#include +#include #include "connection/rpc-test-server-handler.h" #include "connection/rpc-test-server.h" #include "if/test.pb.h" @@ -30,19 +32,35 @@ namespace hbase { RpcTestServerSerializePipeline::Ptr RpcTestServerPipelineFactory::newPipeline( std::shared_ptr sock) { + if (service_ == nullptr) { + initService(sock); + } + CHECK(service_ != nullptr); + auto pipeline = RpcTestServerSerializePipeline::create(); pipeline->addBack(AsyncSocketHandler(sock)); // ensure we can write from any thread pipeline->addBack(EventBaseHandler()); pipeline->addBack(LengthFieldBasedFrameDecoder()); pipeline->addBack(RpcTestServerSerializeHandler()); - pipeline->addBack( - MultiplexServerDispatcher, std::unique_ptr>(&service_)); + pipeline->addBack(MultiplexServerDispatcher, std::unique_ptr>( + service_.get())); pipeline->finalize(); return pipeline; } +void RpcTestServerPipelineFactory::initService(std::shared_ptr sock) { + /* get server address */ + SocketAddress localAddress; + sock->getLocalAddress(&localAddress); + + /* init service with server address */ + service_ = std::make_shared, std::unique_ptr>>( + std::make_shared(1), + std::make_shared(std::make_shared(localAddress))); +} + Future> RpcTestService::operator()(std::unique_ptr request) { /* build Response */ auto response = std::make_unique(); @@ -52,17 +70,29 @@ Future> RpcTestService::operator()(std::unique_ptr(); response->set_resp_msg(pb_resp_msg); + VLOG(1) << "RPC server:" + << " ping called."; + } else if (method_name == "echo") { auto pb_resp_msg = std::make_shared(); + /* get msg from client */ auto pb_req_msg = std::static_pointer_cast(request->req_msg()); pb_resp_msg->set_message(pb_req_msg->message()); response->set_resp_msg(pb_resp_msg); + VLOG(1) << "RPC server:" + << " echo called, " << pb_req_msg->message(); + } else if (method_name == "error") { // TODO: + } else if (method_name == "pause") { // TODO: } else if (method_name == "addr") { - // TODO: + auto pb_resp_msg = std::make_shared(); + pb_resp_msg->set_addr(socket_address_->describe()); + response->set_resp_msg(pb_resp_msg); + VLOG(1) << "RPC server:" + << " addr called, " << socket_address_->describe(); } return folly::makeFuture>(std::move(response)); diff --git a/hbase-native-client/connection/rpc-test-server.h b/hbase-native-client/connection/rpc-test-server.h index c3225ff573..910e96e727 100644 --- a/hbase-native-client/connection/rpc-test-server.h +++ b/hbase-native-client/connection/rpc-test-server.h @@ -17,6 +17,7 @@ * */ #pragma once +#include #include #include #include @@ -33,9 +34,13 @@ using RpcTestServerSerializePipeline = wangle::Pipeline, std::unique_ptr> { public: - RpcTestService() {} + RpcTestService(std::shared_ptr socket_address) + : socket_address_(socket_address) {} virtual ~RpcTestService() = default; Future> operator()(std::unique_ptr request) override; + + private: + std::shared_ptr socket_address_; }; class RpcTestServerPipelineFactory : public PipelineFactory { @@ -44,7 +49,10 @@ class RpcTestServerPipelineFactory : public PipelineFactory sock) override; private: - ExecutorFilter, std::unique_ptr> service_{ - std::make_shared(1), std::make_shared()}; + void initService(std::shared_ptr sock); + + private: + std::shared_ptr, std::unique_ptr>> service_{ + nullptr}; }; } // end of namespace hbase diff --git a/hbase-native-client/connection/rpc-test.cc b/hbase-native-client/connection/rpc-test.cc index d4cd89f0a4..0fc58692e6 100644 --- a/hbase-native-client/connection/rpc-test.cc +++ b/hbase-native-client/connection/rpc-test.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include "connection/rpc-client.h" #include "if/test.pb.h" @@ -38,48 +39,133 @@ using namespace wangle; using namespace folly; using namespace hbase; +using namespace std::chrono; DEFINE_int32(port, 0, "test server port"); +typedef ServerBootstrap ServerTestBootstrap; +typedef std::shared_ptr ServerPtr; -TEST(RpcTestServer, echo) { - /* create conf */ +std::shared_ptr CreateConf() { auto conf = std::make_shared(); conf->Set(RpcSerde::HBASE_CLIENT_RPC_TEST_MODE, "true"); + return conf; +} +ServerPtr CreateRpcServer() { /* create rpc test server */ - auto server = std::make_shared>(); + auto server = std::make_shared(); server->childPipeline(std::make_shared()); server->bind(FLAGS_port); - folly::SocketAddress server_addr; - server->getSockets()[0]->getAddress(&server_addr); + return server; +} + +std::shared_ptr GetRpcServerAddress(ServerPtr server) { + auto addr = std::make_shared(); + server->getSockets()[0]->getAddress(addr.get()); + return addr; +} + +std::shared_ptr CreateRpcClient(std::shared_ptr conf) { + auto io_executor = std::make_shared(1); + auto client = std::make_shared(io_executor, nullptr, conf); + return client; +} - /* create RpcClient */ +std::shared_ptr CreateRpcClient(std::shared_ptr conf, + std::chrono::nanoseconds connect_timeout) { auto io_executor = std::make_shared(1); + auto client = std::make_shared(io_executor, nullptr, conf, connect_timeout); + return client; +} + +/** + * test ping + */ +TEST(RpcTestServer, Ping) { + auto conf = CreateConf(); + auto server = CreateRpcServer(); + auto server_addr = GetRpcServerAddress(server); + auto client = CreateRpcClient(conf); + + auto request = std::make_unique(std::make_shared(), + std::make_shared(), "ping"); + + /* sending out request */ + client + ->AsyncCall(server_addr->getAddressStr(), server_addr->getPort(), std::move(request), + hbase::security::User::defaultUser()) + .then([=](std::unique_ptr response) { + auto pb_resp = std::static_pointer_cast(response->resp_msg()); + EXPECT_TRUE(pb_resp != nullptr); + VLOG(1) << "RPC ping returned."; + }) + .onError([](const folly::exception_wrapper& e) { + FAIL() << "Shouldn't get here, no exception is expected for RPC ping."; + }); + + server->stop(); + server->join(); +} + +/** + * test echo + */ +TEST(RpcTestServer, Echo) { + auto conf = CreateConf(); + auto server = CreateRpcServer(); + auto server_addr = GetRpcServerAddress(server); + auto client = CreateRpcClient(conf); + + std::string greetings = "hello, hbase server!"; + auto request = std::make_unique(std::make_shared(), + std::make_shared(), "echo"); + auto pb_msg = std::static_pointer_cast(request->req_msg()); + pb_msg->set_message(greetings); + + /* sending out request */ + client + ->AsyncCall(server_addr->getAddressStr(), server_addr->getPort(), std::move(request), + hbase::security::User::defaultUser()) + .then([=](std::unique_ptr response) { + auto pb_resp = std::static_pointer_cast(response->resp_msg()); + EXPECT_TRUE(pb_resp != nullptr); + VLOG(1) << "RPC echo returned: " + pb_resp->message(); + EXPECT_EQ(greetings, pb_resp->message()); + }) + .onError([](const folly::exception_wrapper& e) { + FAIL() << "Shouldn't get here, no exception is expected for RPC echo."; + }); + + server->stop(); + server->join(); +} + +/** + * test addr + */ +TEST(RpcTestServer, Addr) { + auto conf = CreateConf(); + auto server = CreateRpcServer(); + auto server_addr = GetRpcServerAddress(server); + auto client = CreateRpcClient(conf); - auto rpc_client = std::make_shared(io_executor, nullptr, conf); - - /** - * test echo - */ - try { - std::string greetings = "hello, hbase server!"; - auto request = std::make_unique(std::make_shared(), - std::make_shared(), "echo"); - auto pb_msg = std::static_pointer_cast(request->req_msg()); - pb_msg->set_message(greetings); - - /* sending out request */ - rpc_client - ->AsyncCall(server_addr.getAddressStr(), server_addr.getPort(), std::move(request), - hbase::security::User::defaultUser()) - .then([=](std::unique_ptr response) { - auto pb_resp = std::static_pointer_cast(response->resp_msg()); - VLOG(1) << "message returned: " + pb_resp->message(); - EXPECT_EQ(greetings, pb_resp->message()); - }); - } catch (const std::exception& e) { - throw e; - } + auto request = std::make_unique(std::make_shared(), + std::make_shared(), "addr"); + /* sending out request */ + client + ->AsyncCall(server_addr->getAddressStr(), server_addr->getPort(), std::move(request), + hbase::security::User::defaultUser()) + .then([=](std::unique_ptr response) { + auto pb_resp = std::static_pointer_cast(response->resp_msg()); + EXPECT_TRUE(pb_resp != nullptr); + VLOG(1) << "RPC addr returned: " + pb_resp->addr(); + folly::SocketAddress addr_returned{}; + addr_returned.setFromIpPort(pb_resp->addr()); + EXPECT_EQ(server_addr->getPort(), addr_returned.getPort()); + }) + .onError([](const folly::exception_wrapper& e) { + FAIL() << "Shouldn't get here, no exception is expected for RPC addr."; + }); server->stop(); server->join(); diff --git a/hbase-native-client/core/configuration.cc b/hbase-native-client/core/configuration.cc index f4fc46d3ac..1fd2851559 100644 --- a/hbase-native-client/core/configuration.cc +++ b/hbase-native-client/core/configuration.cc @@ -24,8 +24,8 @@ #include #include -#include #include +#include namespace hbase { diff --git a/hbase-native-client/exceptions/exception.h b/hbase-native-client/exceptions/exception.h index bdedff4068..2e79055625 100644 --- a/hbase-native-client/exceptions/exception.h +++ b/hbase-native-client/exceptions/exception.h @@ -59,7 +59,7 @@ class IOException : public std::logic_error { IOException(const std::string& what, bool do_not_retry) : logic_error(what), do_not_retry_(do_not_retry) {} - IOException(const std::string& what, folly::exception_wrapper cause) + IOException(const std::string& what, const folly::exception_wrapper& cause) : logic_error(what), cause_(cause), do_not_retry_(false) {} IOException(const std::string& what, folly::exception_wrapper cause, bool do_not_retry) @@ -115,6 +115,18 @@ class RetriesExhaustedException : public IOException { int32_t num_retries_; }; +class ConnectionException : public IOException { + public: + ConnectionException() {} + + ConnectionException(const std::string& what) : IOException(what) {} + + ConnectionException(const folly::exception_wrapper& cause) : IOException("", cause) {} + + ConnectionException(const std::string& what, const folly::exception_wrapper& cause) + : IOException(what, cause) {} +}; + class RemoteException : public IOException { public: RemoteException() : IOException(), port_(0) {} -- 2.11.0 (Apple Git-81)