Index: src/test/java/org/apache/hadoop/hbase/ipc/TestIPC.java
===================================================================
--- src/test/java/org/apache/hadoop/hbase/ipc/TestIPC.java	(revision 0)
+++ src/test/java/org/apache/hadoop/hbase/ipc/TestIPC.java	(revision 0)
@@ -0,0 +1,100 @@
+package org.apache.hadoop.hbase.ipc;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.net.Socket;
+import java.util.Random;
+import javax.net.SocketFactory;
+
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import org.junit.Test;
+
+import static org.mockito.Mockito.*;
+import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hbase.HBaseConfiguration;
+import org.apache.hadoop.hbase.monitoring.MonitoredRPCHandler;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.util.StringUtils;
+
+/**
+ * Tests for {@link HBaseClient} behavior when connection setup fails.
+ */
+public class TestIPC {
+  public static final Log LOG = LogFactory.getLog(TestIPC.class);
+  private static final Random RANDOM = new Random();
+
+  /** Minimal RPC server that echoes the request parameter back to the caller. */
+  private static class TestRpcServer extends HBaseServer {
+    TestRpcServer() throws IOException {
+      super("0.0.0.0", 0, LongWritable.class, 1, 1, HBaseConfiguration.create(), "TestRpcServer", 0);
+    }
+
+    @Override
+    public Writable call(Class protocol,
+        Writable param, long receiveTime, MonitoredRPCHandler status) throws IOException {
+      return param;
+    }
+  }
+
+  /**
+   * Verifies that a RuntimeException thrown while the client configures its
+   * socket (injected through a spied SocketFactory) reaches the caller, and that
+   * a second call on the same client fails with the same injected fault.
+   */
+  @Test
+  public void testRTEDuringConnectionSetup() throws Exception {
+    LOG.info("testRTEDuringConnectionSetup start ...");
+    Configuration conf = HBaseConfiguration.create();
+
+    LOG.info("Mock socket");
+    SocketFactory spyFactory = spy(NetUtils.getDefaultSocketFactory(conf));
+    Mockito.doAnswer(new Answer<Socket>() {
+      @Override
+      public Socket answer(InvocationOnMock invocation) throws Throwable {
+        // Hand out real sockets whose setSoTimeout() always throws.
+        Socket s = spy((Socket) invocation.callRealMethod());
+        doThrow(new RuntimeException("Injected fault")).when(s).setSoTimeout(anyInt());
+        return s;
+      }
+    }).when(spyFactory).createSocket();
+
+    LOG.info("rpcServer start");
+    TestRpcServer rpcServer = new TestRpcServer();
+    rpcServer.start();
+    try {
+      LOG.info("test Client");
+      HBaseClient client = new HBaseClient(
+          LongWritable.class,
+          conf,
+          spyFactory);
+      InetSocketAddress address = rpcServer.getListenerAddress();
+      try {
+        client.call(new LongWritable(RANDOM.nextLong()), address);
+        fail("1: Expected an exception to have been thrown!");
+      } catch (Exception e) {
+        LOG.info("1: Caught expected exception: " + e.toString());
+        assertTrue(StringUtils.stringifyException(e).contains("Injected fault"));
+      }
+
+      // A second call on the same client must fail with the same injected fault.
+      try {
+        client.call(new LongWritable(RANDOM.nextLong()), address);
+        fail("2: Expected an exception to have been thrown!");
+      } catch (Exception e) {
+        LOG.info("2: Caught expected exception: " + e.toString());
+        assertTrue(StringUtils.stringifyException(e).contains("Injected fault"));
+      }
+    } finally {
+      rpcServer.stop();
+    }
+  }
+}
Index: src/main/java/org/apache/hadoop/hbase/ipc/HBaseClient.java
===================================================================
--- src/main/java/org/apache/hadoop/hbase/ipc/HBaseClient.java	(revision 42830)
+++ src/main/java/org/apache/hadoop/hbase/ipc/HBaseClient.java	(working copy)
@@ -371,11 +371,21 @@
 
         // start the receiver thread after the socket connection has been set up
         start();
-      } catch (IOException e) {
-        markClosed(e);
-        close();
-
-        throw e;
+      } catch (Throwable t) {
+        // Catch Throwable, not just IOException, so that unchecked failures
+        // (e.g. a RuntimeException from socket configuration) also mark this
+        // connection closed. Non-IOExceptions are wrapped to keep the
+        // IOException contract; either way the failure is rethrown.
+        IOException e;
+        if (t instanceof IOException) {
+          e = (IOException) t;
+        } else {
+          e = new IOException("Couldn't set up IO Streams", t);
+        }
+        markClosed(e);
+        close();
+
+        throw e;
       }
     }
 