diff --git hbase-native-client/connection/client-dispatcher.cc hbase-native-client/connection/client-dispatcher.cc index 35a1f7d..ff850b7 100644 --- hbase-native-client/connection/client-dispatcher.cc +++ hbase-native-client/connection/client-dispatcher.cc @@ -25,16 +25,11 @@ using std::unique_ptr; namespace hbase { -ClientDispatcher::ClientDispatcher() : requests_(5000), current_call_id_(9) {} +ClientDispatcher::ClientDispatcher() : current_call_id_(9), requests_(5000) {} void ClientDispatcher::read(Context *ctx, unique_ptr in) { auto call_id = in->call_id(); - - auto search = requests_.find(call_id); - CHECK(search != requests_.end()); - auto p = std::move(search->second); - - requests_.erase(call_id); + auto p = requests_.find_and_erase(call_id); if (in->exception()) { p.setException(in->exception()); @@ -46,8 +41,9 @@ void ClientDispatcher::read(Context *ctx, unique_ptr in) { folly::Future> ClientDispatcher::operator()(unique_ptr arg) { auto call_id = current_call_id_++; arg->set_call_id(call_id); - requests_.insert(call_id, folly::Promise>{}); - auto &p = requests_.find(call_id)->second; + + auto &p = requests_[call_id]; + auto f = p.getFuture(); p.setInterruptHandler([call_id, this](const folly::exception_wrapper &e) { LOG(ERROR) << "e = " << call_id; diff --git hbase-native-client/connection/client-dispatcher.h hbase-native-client/connection/client-dispatcher.h index 857042c..1f8e6b3 100644 --- hbase-native-client/connection/client-dispatcher.h +++ hbase-native-client/connection/client-dispatcher.h @@ -19,16 +19,18 @@ #pragma once -#include #include #include #include +#include #include +#include #include "connection/pipeline.h" #include "connection/request.h" #include "connection/response.h" +#include "utils/concurrent-map.h" namespace hbase { /** @@ -51,7 +53,7 @@ class ClientDispatcher folly::Future close() override; private: - folly::AtomicHashMap>> requests_; + concurrent_map>> requests_; // Start at some number way above what could // be there for un-initialized call id counters. // diff --git hbase-native-client/connection/client-handler.cc hbase-native-client/connection/client-handler.cc index 894ecb3..e07b7c3 100644 --- hbase-native-client/connection/client-handler.cc +++ hbase-native-client/connection/client-handler.cc @@ -40,8 +40,9 @@ ClientHandler::ClientHandler(std::string user_name, std::shared_ptr codec serde_(codec), server_(server), once_flag_(std::make_unique()), - resp_msgs_(std::make_unique>>(5000)) { -} + resp_msgs_( + std::make_unique>>( + 5000)) {} void ClientHandler::read(Context *ctx, std::unique_ptr buf) { if (LIKELY(buf != nullptr)) { @@ -53,15 +54,7 @@ void ClientHandler::read(Context *ctx, std::unique_ptr buf) { VLOG(3) << "Read RPC ResponseHeader size=" << used_bytes << " call_id=" << header.call_id() << " has_exception=" << header.has_exception(); - // Get the response protobuf from the map - auto search = resp_msgs_->find(header.call_id()); - // It's an error if it's not there. - CHECK(search != resp_msgs_->end()); - auto resp_msg = search->second; - CHECK(resp_msg != nullptr); - - // Make sure we don't leak the protobuf - resp_msgs_->erase(header.call_id()); + auto resp_msg = resp_msgs_->find_and_erase(header.call_id()); // set the call_id. // This will be used to by the dispatcher to match up @@ -132,7 +125,7 @@ folly::Future ClientHandler::write(Context *ctx, std::unique_ptrDebugString(); // Now store the call id to response. - resp_msgs_->insert(r->call_id(), r->resp_msg()); + resp_msgs_->insert(std::make_pair(r->call_id(), r->resp_msg())); // Send the data down the pipeline. return ctx->fireWrite(serde_.Request(r->call_id(), r->method(), r->req_msg().get())); diff --git hbase-native-client/connection/client-handler.h hbase-native-client/connection/client-handler.h index 4c106e0..8de3a8b 100644 --- hbase-native-client/connection/client-handler.h +++ hbase-native-client/connection/client-handler.h @@ -18,7 +18,6 @@ */ #pragma once -#include #include #include @@ -30,6 +29,7 @@ #include "exceptions/exception.h" #include "serde/codec.h" #include "serde/rpc.h" +#include "utils/concurrent-map.h" // Forward decs. namespace hbase { @@ -81,7 +81,6 @@ class ClientHandler std::string server_; // for logging // in flight requests - std::unique_ptr>> - resp_msgs_; + std::unique_ptr>> resp_msgs_; }; } // namespace hbase diff --git hbase-native-client/utils/BUCK hbase-native-client/utils/BUCK index 788056b..96f02b8 100644 --- hbase-native-client/utils/BUCK +++ hbase-native-client/utils/BUCK @@ -20,6 +20,7 @@ cxx_library( exported_headers=[ "bytes-util.h", "connection-util.h", + "concurrent-map.h", "optional.h", "sys-util.h", "time-util.h", @@ -38,18 +39,26 @@ cxx_library( ], compiler_flags=['-Weffc++'],) cxx_test( - name="user-util-test", + name="bytes-util-test", srcs=[ - "user-util-test.cc", + "bytes-util-test.cc", ], deps=[ ":utils", ],) cxx_test( - name="bytes-util-test", + name="concurrent-map-test", srcs=[ - "bytes-util-test.cc", + "concurrent-map-test.cc", ], deps=[ ":utils", ],) +cxx_test( + name="user-util-test", + srcs=[ + "user-util-test.cc", + ], + deps=[ + ":utils", + ],) \ No newline at end of file diff --git hbase-native-client/utils/concurrent-map-test.cc hbase-native-client/utils/concurrent-map-test.cc new file mode 100644 index 0000000..588bd08 --- /dev/null +++ hbase-native-client/utils/concurrent-map-test.cc @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include + +#include "utils/concurrent-map.h" + +using hbase::concurrent_map; + +TEST(TestConcurrentMap, TestFindAndErase) { + concurrent_map map{500}; + + map.insert(std::make_pair("foo", "bar")); + auto prev = map.find_and_erase("foo"); + ASSERT_EQ("bar", prev); + + ASSERT_EQ(map.end(), map.find("foo")); +} diff --git hbase-native-client/utils/concurrent-map.h hbase-native-client/utils/concurrent-map.h new file mode 100644 index 0000000..e2d6e39 --- /dev/null +++ hbase-native-client/utils/concurrent-map.h @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace hbase { + +/** + * A concurrent version of std::unordered_map where we acquire a shared or exclusive + * lock for operations. This is NOT a highly-concurrent and scalable implementation + * since there is only one lock object. + * Replace this with tbb::concurrent_unordered_map or similar. + * + * Concurrency here is different than in Java. For example, the iterators returned from + * find() will not copy the key, value pairs. + */ +template +class concurrent_map { + public: + typedef K key_type; + typedef V mapped_type; + typedef std::pair value_type; + typedef typename std::unordered_map::iterator iterator; + typedef typename std::unordered_map::const_iterator const_iterator; + + concurrent_map() : map_(), mutex_() {} + explicit concurrent_map(int32_t n) : map_(n), mutex_() {} + + void insert(const value_type& value) { + std::unique_lock lock(mutex_); + map_.insert(value); + } + + /** + * Return the mapped object for this key. Be careful to not use the return reference + * to do assignment. I think it won't be thread safe + */ + mapped_type& at(const key_type& key) { + std::shared_lock lock(mutex_); + iterator where = map_.find(key); + if (where == end()) { + std::runtime_error("Key not found"); + } + return where->second; + } + + mapped_type& operator[](const key_type& key) { + std::shared_lock lock(mutex_); + iterator where = map_.find(key); + if (where == end()) { + return map_[key]; + } + return where->second; + } + + /** + * Atomically finds the entry and removes it from the map, returning + * the previously associated value. + */ + mapped_type find_and_erase(const K& key) { + std::unique_lock lock(mutex_); + auto search = map_.find(key); + // It's an error if it's not there. + CHECK(search != end()); + auto val = std::move(search->second); + map_.erase(key); + return val; + } + + void erase(const K& key) { + std::unique_lock lock(mutex_); + map_.erase(key); + } + + iterator begin() { return map_.begin(); } + + const_iterator begin() const { return map_.begin(); } + + const_iterator cbegin() const { return map_.begin(); } + + iterator end() { return map_.end(); } + + const_iterator end() const { return map_.end(); } + + const_iterator cend() const { return map_.end(); } + + iterator find(const K& key) { + std::shared_lock lock(mutex_); + return map_.find(key); + } + + // TODO: find(), at() returning const_iterator + + bool empty() const { + std::unique_lock lock(mutex_); + return map_.empty(); + } + + private: + std::shared_timed_mutex mutex_; + std::unordered_map map_; +}; +} /* namespace hbase */