diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/client/HConnectionManager.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/client/HConnectionManager.java
index 7f514a8..464a92a 100644
--- a/hbase-client/src/main/java/org/apache/hadoop/hbase/client/HConnectionManager.java
+++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/client/HConnectionManager.java
@@ -40,6 +40,7 @@ import java.util.NavigableMap;
 import java.util.Set;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.CopyOnWriteArraySet;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
@@ -544,10 +545,8 @@ public class HConnectionManager {
     private RpcClientEngine rpcEngine;
 
     // Known region ServerName.toString() -> RegionClient/Admin
-    private final ConcurrentHashMap<String, Map<String, IpcProtocol>> servers =
-      new ConcurrentHashMap<String, Map<String, IpcProtocol>>();
-    private final ConcurrentHashMap<String, String> connectionLock =
-      new ConcurrentHashMap<String, String>();
+    private final ConcurrentMap<String, ProtocolMap> servers =
+      new ConcurrentHashMap<String, ProtocolMap>();
 
     /**
      * Map of table to table {@link HRegionLocation}s.  The table key is made
@@ -1438,7 +1437,7 @@ public class HConnectionManager {
     @Deprecated
     public ClientProtocol getClient(final String hostname, final int port)
     throws IOException {
-      return (ClientProtocol)getProtocol(hostname, port, clientClass);
+      return getProtocol(hostname, port, clientClass);
     }
 
     @Override
@@ -1447,8 +1446,7 @@ public class HConnectionManager {
       if (isDeadServer(serverName)){
         throw new RegionServerStoppedException("The server " + serverName + " is dead.");
       }
-      return (ClientProtocol)
-        getProtocol(serverName.getHostname(), serverName.getPort(), clientClass);
+      return getProtocol(serverName.getHostname(), serverName.getPort(), clientClass);
     }
 
     @Override
@@ -1456,7 +1454,7 @@ public class HConnectionManager {
     public AdminProtocol getAdmin(final String hostname, final int port,
       final boolean master)
     throws IOException {
-      return (AdminProtocol)getProtocol(hostname, port, adminClass);
+      return getProtocol(hostname, port, adminClass);
     }
 
     @Override
@@ -1465,8 +1463,7 @@ public class HConnectionManager {
       if (isDeadServer(serverName)){
         throw new RegionServerStoppedException("The server " + serverName + " is dead.");
       }
-      return (AdminProtocol)getProtocol(
-        serverName.getHostname(), serverName.getPort(), adminClass);
+      return getProtocol(serverName.getHostname(), serverName.getPort(), adminClass);
     }
 
     /**
@@ -1478,47 +1475,36 @@ public class HConnectionManager {
      * @return Proxy.
     * @throws IOException
      */
-    IpcProtocol getProtocol(final String hostname,
-        final int port, final Class<? extends IpcProtocol> protocolClass)
+    <T extends IpcProtocol> T getProtocol(final String hostname,
+        final int port, final Class<T> protocolClass)
     throws IOException {
       String rsName = Addressing.createHostAndPortStr(hostname, port);
       // See if we already have a connection (common case)
-      Map<String, IpcProtocol> protocols = this.servers.get(rsName);
+      ProtocolMap protocols = this.servers.get(rsName);
       if (protocols == null) {
-        protocols = new HashMap<String, IpcProtocol>();
-        Map<String, IpcProtocol> existingProtocols =
-          this.servers.putIfAbsent(rsName, protocols);
+        protocols = new ProtocolMap();
+        ProtocolMap existingProtocols = this.servers.putIfAbsent(rsName, protocols);
         if (existingProtocols != null) {
           protocols = existingProtocols;
         }
       }
-      String protocol = protocolClass.getName();
-      IpcProtocol server = protocols.get(protocol);
-      if (server == null) {
-        // create a unique lock for this RS + protocol (if necessary)
-        String lockKey = protocol + "@" + rsName;
-        this.connectionLock.putIfAbsent(lockKey, lockKey);
-        // get the RS lock
-        synchronized (this.connectionLock.get(lockKey)) {
-          // do one more lookup in case we were stalled above
-          server = protocols.get(protocol);
-          if (server == null) {
-            try {
-              // Only create isa when we need to.
-              InetSocketAddress address = new InetSocketAddress(hostname, port);
-              // definitely a cache miss. establish an RPC for this RS
-              server = HBaseClientRPC.waitForProxy(rpcEngine, protocolClass, address, this.conf,
-                this.maxRPCAttempts, this.rpcTimeout, this.rpcTimeout);
-              protocols.put(protocol, server);
-            } catch (RemoteException e) {
-              LOG.warn("RemoteException connecting to RS", e);
-              // Throw what the RemoteException was carrying.
-              throw e.unwrapRemoteException();
-            }
+
+      return protocols.get(protocolClass, new ProtocolMap.ProxyProvider<T>() {
+        @Override
+        public T get() throws IOException {
+          try {
+            // Only create isa when we need to.
+            InetSocketAddress address = new InetSocketAddress(hostname, port);
+            // definitely a cache miss. establish an RPC for this RS
+            return HBaseClientRPC.waitForProxy(rpcEngine, protocolClass, address, conf,
+              maxRPCAttempts, rpcTimeout, rpcTimeout);
+          } catch (RemoteException e) {
+            LOG.warn("RemoteException connecting to RS", e);
+            // Throw what the RemoteException was carrying.
+            throw e.unwrapRemoteException();
           }
         }
-      }
-      return server;
+      });
     }
 
     @Override
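
Aside (not part of the patch): the heart of the HConnectionManager change is the new generic signature of getProtocol(). The per-server cache is now keyed by the protocol's Class object instead of its class name, so callers get a typed proxy back and the (ClientProtocol)/(AdminProtocol) casts removed above become unnecessary. A minimal, self-contained sketch of that Class<T>-keyed lookup idea follows; TypedRegistrySketch, Protocol, Client and Admin are invented stand-ins, not HBase types.

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Sketch only: a Class<T>-keyed registry returning typed values without call-site casts.
public class TypedRegistrySketch {
  interface Protocol {}                    // stands in for IpcProtocol
  interface Client extends Protocol {}     // stands in for ClientProtocol
  interface Admin extends Protocol {}      // stands in for AdminProtocol

  private final Map<Class<? extends Protocol>, Protocol> instances =
      new ConcurrentHashMap<Class<? extends Protocol>, Protocol>();

  <T extends Protocol> void put(Class<T> key, T value) {
    instances.put(key, value);
  }

  <T extends Protocol> T get(Class<T> key) {
    // Class.cast() confines the one unchecked conversion to a single runtime-checked call.
    return key.cast(instances.get(key));
  }

  public static void main(String[] args) {
    TypedRegistrySketch registry = new TypedRegistrySketch();
    registry.put(Client.class, new Client() {});
    Client client = registry.get(Client.class);  // no explicit cast at the call site
    System.out.println(client != null);          // prints "true"
  }
}

The same protocol.cast(...) trick is what ProtocolMap.get() in the new file below relies on when handing back cached proxies.
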
diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/client/ProtocolMap.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/client/ProtocolMap.java
new file mode 100644
index 0000000..26146db
--- /dev/null
+++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/client/ProtocolMap.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.client;
+
+import java.io.IOException;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.CountDownLatch;
+
+import org.apache.hadoop.hbase.IpcProtocol;
+
+/**
+ * A special map from a protocol to a proxy of the protocol,
+ * providing the same proxy for equal protocols,
+ * where a protocol is represented by an instance of {@code Class}
+ * and a proxy is represented by an instance of {@code IpcProtocol}.
+ * <p>
+ * Thread safe.
+ */
+class ProtocolMap {
+  /**
+   * Provides a proxy for {@code T}.
+   * This is invoked by a single thread.
+   */
+  interface ProxyProvider<T extends IpcProtocol> {
+    T get() throws IOException;
+  }
+
+  private final ConcurrentMap<Class<? extends IpcProtocol>, ProxyHolder<?>> holderMap =
+    new ConcurrentHashMap<Class<? extends IpcProtocol>, ProxyHolder<?>>();
+
+  private static class ProxyHolder<T extends IpcProtocol> {
+    final CountDownLatch initLatch = new CountDownLatch(1);
+    T proxy;
+
+    T init(ProxyProvider<T> provider) throws IOException {
+      try {
+        return this.proxy = provider.get();
+      } finally {
+        initLatch.countDown();
+      }
+    }
+
+    /**
+     * @return null if its initialization failed
+     */
+    T get() {
+      boolean interrupted = false;
+      try {
+        while (true) {
+          try {
+            initLatch.await();
+            return proxy;
+          } catch (InterruptedException e) {
+            interrupted = true;
+          }
+        }
+      } finally {
+        if (interrupted) {
+          Thread.currentThread().interrupt();
+        }
+      }
+    }
+  }
+
+  /**
+   * Returns a proxy for the given {@code protocol}.
+   * You get the same proxy that you once got for the given {@code protocol},
+   * or you get a new proxy or an exception from the given {@code provider}.
+   *
+   * @param protocol a key to get a proxy
+   * @param provider a provider of a new proxy for the {@code protocol}
+   * @return a proxy for the given {@code protocol}
+   * @throws NullPointerException if {@code protocol} or {@code provider} is null
+   * @throws IOException thrown by the given {@code provider}
+   */
+  <T extends IpcProtocol> T get(
+      Class<T> protocol, ProxyProvider<T> provider) throws IOException {
+
+    if (protocol == null) {
+      throw new NullPointerException("protocol");
+    }
+    if (provider == null) {
+      throw new NullPointerException("provider");
+    }
+
+    ProxyHolder<?> holder = holderMap.get(protocol);
+    if (holder != null) {
+      IpcProtocol proxy = holder.get();
+      if (proxy != null) {
+        return protocol.cast(proxy);
+      }
+      holderMap.remove(protocol, holder);
+    }
+
+    ProxyHolder<T> newHolder = new ProxyHolder<T>();
+    while (true) {
+      ProxyHolder<?> existingHolder = holderMap.putIfAbsent(protocol, newHolder);
+      if (existingHolder == null) {
+        try {
+          return newHolder.init(provider);
+        } catch (IOException e) {
+          holderMap.remove(protocol, newHolder);
+          throw e;
+        }
+      }
+
+      IpcProtocol existingProxy = existingHolder.get();
+      if (existingProxy != null) {
+        return protocol.cast(existingProxy);
+      }
+      holderMap.remove(protocol, existingHolder);
+    }
+  }
+}
diff --git a/hbase-client/src/test/java/org/apache/hadoop/hbase/client/TestProtocolMap.java b/hbase-client/src/test/java/org/apache/hadoop/hbase/client/TestProtocolMap.java
new file mode 100644
index 0000000..244299e
--- /dev/null
+++ b/hbase-client/src/test/java/org/apache/hadoop/hbase/client/TestProtocolMap.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hbase.client;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CyclicBarrier;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.apache.hadoop.hbase.IpcProtocol;
+import org.apache.hadoop.hbase.SmallTests;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+
+@Category(SmallTests.class)
+public class TestProtocolMap {
+
+  private static interface MyProtocol extends IpcProtocol {}
+
+  private static class MyProtocolImpl implements MyProtocol {}
+
+  private static class ProxyProviderWithCallCounter
+      implements ProtocolMap.ProxyProvider<MyProtocol> {
+
+    final AtomicInteger callCount = new AtomicInteger();
+
+    @Override
+    public MyProtocol get() {
+      callCount.incrementAndGet();
+      return new MyProtocolImpl();
+    }
+  }
+
+  private static class MyIOException extends IOException {}
+
+  private enum ProxyProviderToThrowMyIOException
+      implements ProtocolMap.ProxyProvider<MyProtocol> {
+    INSTANCE;
+
+    @Override
+    public MyProtocol get() throws IOException {
+      throw new MyIOException();
+    }
+  }
+
+  @Test
+  public void testGet() throws Exception {
+    ProtocolMap map = new ProtocolMap();
+    ProxyProviderWithCallCounter provider = new ProxyProviderWithCallCounter();
+
+    MyProtocol p = map.get(MyProtocol.class, provider);
+    Assert.assertTrue(p instanceof MyProtocolImpl);
+    Assert.assertEquals(1, provider.callCount.get());
+
+    MyProtocol p2 = map.get(MyProtocol.class, provider);
+    Assert.assertSame(p, p2);
+    Assert.assertEquals(1, provider.callCount.get());
+  }
+
+  @Test(expected=MyIOException.class)
+  public void testGetThrowsException() throws Exception {
+    ProtocolMap map = new ProtocolMap();
+    map.get(MyProtocol.class, ProxyProviderToThrowMyIOException.INSTANCE);
+  }
+
+  @Test
+  public void testContention() throws Exception {
+    final int concurrentLevel = 100;
+    final CyclicBarrier readyBarrier = new CyclicBarrier(concurrentLevel + 1);
+    final ProtocolMap map = new ProtocolMap();
+    final ProxyProviderWithCallCounter provider = new ProxyProviderWithCallCounter();
+
+    ExecutorService service = Executors.newCachedThreadPool();
+    List<Future<MyProtocol>> futures = new ArrayList<Future<MyProtocol>>();
+
+    for (int i = 0; i < concurrentLevel; i++) {
+      futures.add(service.submit(new Callable<MyProtocol>() {
+        @Override
+        public MyProtocol call() throws Exception {
+          readyBarrier.await();
+          return map.get(MyProtocol.class, provider);
+        }
+      }));
+    }
+
+    service.shutdown();
+    readyBarrier.await();
+
+    if (!service.awaitTermination(100, TimeUnit.MILLISECONDS)) {
+      service.shutdownNow();
+      Assert.fail("100msec elapsed before termination");
+    }
+
+    Set<MyProtocol> resultSet = new HashSet<MyProtocol>();
+    for (Future<MyProtocol> future : futures) {
+      resultSet.add(future.get());
+    }
+
+    Assert.assertEquals(1, resultSet.size());
+    Assert.assertEquals(1, provider.callCount.get());
+  }
+}
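
Aside (not part of the patch): testContention() above pins down the central guarantee of ProtocolMap: under contention, the ProxyProvider runs once per protocol and every caller receives the same proxy. ProtocolMap achieves this with a putIfAbsent race plus a CountDownLatch that the losers wait on, replacing the connectionLock string-striping removed from HConnectionManager. Below is a minimal, self-contained sketch of that pattern, reduced to a generic compute-once cache; OncePerKeyCache and Factory are invented names, and the failure handling ProtocolMap adds (dropping a holder whose initialization threw so a later caller can retry) is deliberately left out.

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;

// Sketch only: compute each key's value at most once; later callers reuse it.
public class OncePerKeyCache<K, V> {

  public interface Factory<K, V> {
    V create(K key);  // invoked by at most one thread per key
  }

  private static class Holder<V> {
    final CountDownLatch ready = new CountDownLatch(1);
    volatile V value;
  }

  private final ConcurrentMap<K, Holder<V>> holders =
      new ConcurrentHashMap<K, Holder<V>>();

  public V get(K key, Factory<K, V> factory) throws InterruptedException {
    Holder<V> holder = new Holder<V>();
    Holder<V> existing = holders.putIfAbsent(key, holder);
    if (existing == null) {
      // This thread won the putIfAbsent race: run the slow construction exactly once.
      holder.value = factory.create(key);
      holder.ready.countDown();
      return holder.value;
    }
    // Another thread is (or was) initializing this key: wait for it instead of recomputing.
    existing.ready.await();
    return existing.value;
  }

  public static void main(String[] args) throws InterruptedException {
    OncePerKeyCache<String, String> cache = new OncePerKeyCache<String, String>();
    Factory<String, String> factory = new Factory<String, String>() {
      @Override
      public String create(String key) {
        return "proxy-for-" + key;  // stands in for an expensive RPC proxy setup
      }
    };
    System.out.println(cache.get("rs1.example.com:60020", factory));
    System.out.println(cache.get("rs1.example.com:60020", factory));  // cached value, factory not called again
  }
}

The winner of the putIfAbsent race builds the value without holding any shared lock, and waiting threads block only on the latch for their own key; ProtocolMap layers retry-after-failure and Class-keyed type safety on top of this skeleton.
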