diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java index 8cead2a..b90c542 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java @@ -20,7 +20,6 @@ package org.apache.hadoop.hbase.ipc; import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHORIZATION; -import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.DataOutputStream; import java.io.IOException; @@ -33,7 +32,6 @@ import java.net.SocketException; import java.net.UnknownHostException; import java.nio.ByteBuffer; import java.nio.channels.CancelledKeyException; -import java.nio.channels.Channels; import java.nio.channels.ClosedChannelException; import java.nio.channels.GatheringByteChannel; import java.nio.channels.ReadableByteChannel; @@ -274,7 +272,7 @@ public class RpcServer implements RpcServerInterface { private UserProvider userProvider; private final BoundedByteBufferPool reservoir; - + private final BoundedByteBufferPool reqBufPool; /** * Datastructure that holds all necessary to a method invocation and then afterward, carries @@ -587,6 +585,7 @@ public class RpcServer implements RpcServerInterface { private ExecutorService readPool; + public Listener(final String name) throws IOException { super(name); backlogLength = conf.getInt("hbase.ipc.server.listen.queue.size", 128); @@ -890,6 +889,7 @@ public class RpcServer implements RpcServerInterface { ". Number of active connections: " + numConnections); } closeConnection(c); + key.cancel(); } } @@ -1234,9 +1234,6 @@ public class RpcServer implements RpcServerInterface { private AuthMethod authMethod; private boolean saslContextEstablished; private boolean skipInitialSaslHandshake; - private ByteBuffer unwrappedData; - // When is this set? FindBugs wants to know! Says NP - private ByteBuffer unwrappedDataLengthBuffer = ByteBuffer.allocate(4); boolean useSasl; SaslServer saslServer; private boolean useWrap = false; @@ -1341,18 +1338,19 @@ public class RpcServer implements RpcServerInterface { } } - private void saslReadAndProcess(byte[] saslToken) throws IOException, + private void saslReadAndProcess(ByteBuffer saslToken) throws IOException, InterruptedException { if (saslContextEstablished) { if (LOG.isTraceEnabled()) - LOG.trace("Have read input token of size " + saslToken.length + LOG.trace("Have read input token of size " + saslToken.remaining() + " for processing by saslServer.unwrap()"); if (!useWrap) { processOneRpc(saslToken); } else { - byte [] plaintextData = saslServer.unwrap(saslToken, 0, saslToken.length); - processUnwrappedData(plaintextData); + byte[] plaintextData = + saslServer.unwrap(saslToken.array(), saslToken.position(), saslToken.remaining()); + processUnwrappedData(ByteBuffer.wrap(plaintextData)); } } else { byte[] replyToken; @@ -1400,10 +1398,12 @@ public class RpcServer implements RpcServerInterface { } } if (LOG.isDebugEnabled()) { - LOG.debug("Have read input token of size " + saslToken.length + LOG.debug("Have read input token of size " + saslToken.remaining() + " for processing by saslServer.evaluateResponse()"); } - replyToken = saslServer.evaluateResponse(saslToken); + + replyToken = saslServer.evaluateResponse( + new Bytes(saslToken.array(), saslToken.position(), saslToken.remaining()).copyBytes()); } catch (IOException e) { IOException sendToClient = e; Throwable cause = e; @@ -1594,8 +1594,12 @@ public class RpcServer implements RpcServerInterface { throw new IllegalArgumentException("Unexpected data length " + dataLength + "!! from " + getHostAddress()); } - data = ByteBuffer.allocate(dataLength); - + data = reqBufPool.getBuffer(); + if (data.capacity() < dataLength) { + data = ByteBuffer.allocate(dataLength); + } else { + data.limit(dataLength); + } // Increment the rpc count. This counter will be decreased when we write // the response. If we want the connection to be detected as idle properly, we // need to keep the inc / dec correct. @@ -1623,14 +1627,15 @@ public class RpcServer implements RpcServerInterface { } if (useSasl) { - saslReadAndProcess(data.array()); + saslReadAndProcess(data); } else { - processOneRpc(data.array()); + processOneRpc(data); } } finally { dataLengthBuffer.clear(); // Clean for the next call - data = null; // For the GC + reqBufPool.putBuffer(data); + data = null; } } @@ -1654,8 +1659,11 @@ public class RpcServer implements RpcServerInterface { } // Reads the connection header following version - private void processConnectionHeader(byte[] buf) throws IOException { - this.connectionHeader = ConnectionHeader.parseFrom(buf); + private void processConnectionHeader(ByteBuffer buf) throws IOException { + ConnectionHeader.Builder builder = ConnectionHeader.newBuilder(); + ProtobufUtil.mergeFrom(builder, buf.array(), buf.position(), buf.remaining()); + buf.position(buf.position() + buf.remaining()); + this.connectionHeader = builder.build(); String serviceName = connectionHeader.getServiceName(); if (serviceName == null) throw new EmptyServiceNameException(); this.service = getService(services, serviceName); @@ -1726,45 +1734,41 @@ public class RpcServer implements RpcServerInterface { } } - private void processUnwrappedData(byte[] inBuf) throws IOException, + private void processUnwrappedData(ByteBuffer inBuf) throws IOException, InterruptedException { - ReadableByteChannel ch = Channels.newChannel(new ByteArrayInputStream(inBuf)); + int originalLimit = inBuf.limit(); + IOException exceptions = null; // Read all RPCs contained in the inBuf, even partial ones while (true) { - int count; - if (unwrappedDataLengthBuffer.remaining() > 0) { - count = channelRead(ch, unwrappedDataLengthBuffer); - if (count <= 0 || unwrappedDataLengthBuffer.remaining() > 0) - return; - } + if (inBuf.remaining() < 4) break; - if (unwrappedData == null) { - unwrappedDataLengthBuffer.flip(); - int unwrappedDataLength = unwrappedDataLengthBuffer.getInt(); + int dataLength = inBuf.getInt(); - if (unwrappedDataLength == RpcClient.PING_CALL_ID) { - if (LOG.isDebugEnabled()) - LOG.debug("Received ping message"); - unwrappedDataLengthBuffer.clear(); - continue; // ping message - } - unwrappedData = ByteBuffer.allocate(unwrappedDataLength); + if (dataLength == RpcClient.PING_CALL_ID) { + if (LOG.isDebugEnabled()) + LOG.debug("Received ping message"); + continue; // ping message } - count = channelRead(ch, unwrappedData); - if (count <= 0 || unwrappedData.remaining() > 0) - return; + if (inBuf.remaining() < dataLength) break; - if (unwrappedData.remaining() == 0) { - unwrappedDataLengthBuffer.clear(); - unwrappedData.flip(); - processOneRpc(unwrappedData.array()); - unwrappedData = null; + try{ + inBuf.limit(inBuf.position() + dataLength); + processOneRpc(inBuf); + } catch (Exception e) { + if (exceptions == null) + exceptions = new IOException("cause by following exceptions:"); + exceptions.addSuppressed(e); + inBuf.position(inBuf.limit()); // fix offset + } finally { + inBuf.limit(originalLimit); } } + + if (exceptions != null) throw exceptions; } - private void processOneRpc(byte[] buf) throws IOException, InterruptedException { + private void processOneRpc(ByteBuffer buf) throws IOException, InterruptedException { if (connectionHeaderRead) { processRequest(buf); } else { @@ -1785,18 +1789,20 @@ public class RpcServer implements RpcServerInterface { * @throws IOException * @throws InterruptedException */ - protected void processRequest(byte[] buf) throws IOException, InterruptedException { - long totalRequestSize = buf.length; - int offset = 0; + protected void processRequest(ByteBuffer buf) throws IOException, InterruptedException { + int totalRequestSize = buf.remaining(); + int initPos = buf.position(); // Here we read in the header. We avoid having pb // do its default 4k allocation for CodedInputStream. We force it to use backing array. - CodedInputStream cis = CodedInputStream.newInstance(buf, offset, buf.length); + CodedInputStream cis = + CodedInputStream.newInstance(buf.array(), buf.position(), buf.remaining()); int headerSize = cis.readRawVarint32(); - offset = cis.getTotalBytesRead(); + buf.position(initPos + cis.getTotalBytesRead()); Message.Builder builder = RequestHeader.newBuilder(); - ProtobufUtil.mergeFrom(builder, buf, offset, headerSize); + builder.mergeFrom(buf.array(), buf.position(), headerSize); RequestHeader header = (RequestHeader) builder.build(); - offset += headerSize; + cis.skipRawBytes(headerSize); + buf.position(initPos + cis.getTotalBytesRead()); int id = header.getCallId(); if (LOG.isTraceEnabled()) { LOG.trace("RequestHeader " + TextFormat.shortDebugString(header) + @@ -1824,19 +1830,18 @@ public class RpcServer implements RpcServerInterface { md = this.service.getDescriptorForType().findMethodByName(header.getMethodName()); if (md == null) throw new UnsupportedOperationException(header.getMethodName()); builder = this.service.getRequestPrototype(md).newBuilderForType(); - // To read the varint, I need an inputstream; might as well be a CIS. - cis = CodedInputStream.newInstance(buf, offset, buf.length); int paramSize = cis.readRawVarint32(); - offset += cis.getTotalBytesRead(); + buf.position(initPos + cis.getTotalBytesRead()); if (builder != null) { - ProtobufUtil.mergeFrom(builder, buf, offset, paramSize); + builder.mergeFrom(buf.array(), buf.position(), paramSize); param = builder.build(); } - offset += paramSize; + buf.position(buf.position() + paramSize); } if (header.hasCellBlockMeta()) { cellScanner = ipcUtil.createCellScanner(this.codec, this.compressionCodec, - buf, offset, buf.length); + buf.array(), buf.position(), buf.remaining()); + buf.position(buf.position() + buf.remaining()); } } catch (Throwable t) { String msg = getListenerAddress() + " is unable to read call parameter from client " + @@ -1861,9 +1866,9 @@ public class RpcServer implements RpcServerInterface { setupResponse(responseBuffer, readParamsFailedCall, t, msg + "; " + t.getMessage()); responder.doRespond(readParamsFailedCall); + buf.position(buf.limit()); return; } - TraceInfo traceInfo = header.hasTraceInfo() ? new TraceInfo(header.getTraceInfo().getTraceId(), header.getTraceInfo().getParentId()) : null; @@ -1989,6 +1994,7 @@ public class RpcServer implements RpcServerInterface { this.maxQueueSize = this.conf.getInt("hbase.ipc.server.max.callqueue.size", DEFAULT_MAX_CALLQUEUE_SIZE); this.readThreads = conf.getInt("hbase.ipc.server.read.threadpool.size", 10); + this.reqBufPool = new BoundedByteBufferPool(1024*1024, 16*1024, readThreads * 3); this.maxIdleTime = 2 * conf.getInt("hbase.ipc.client.connection.maxidletime", 1000); this.maxConnectionsToNuke = conf.getInt("hbase.ipc.client.kill.max", 10); this.thresholdIdleConnections = conf.getInt("hbase.ipc.client.idlethreshold", 4000); diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/security/TestSecureRPC.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/security/TestSecureRPC.java index 8eff063..1e01aa2 100644 --- a/hbase-server/src/test/java/org/apache/hadoop/hbase/security/TestSecureRPC.java +++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/security/TestSecureRPC.java @@ -105,8 +105,20 @@ public class TestSecureRPC { testRpcCallWithEnabledKerberosSaslAuth(AsyncRpcClient.class); } + @Test + public void testWrappedRpc() throws Exception { + Configuration conf = getSecuredConfiguration(); + conf.set("hbase.rpc.protection", SaslUtil.QualityOfProtection.PRIVACY.name().toLowerCase()); + testRpcCallWithEnabledKerberosSaslAuth(RpcClientImpl.class, conf); + } + private void testRpcCallWithEnabledKerberosSaslAuth(Class rpcImplClass) throws Exception { + testRpcCallWithEnabledKerberosSaslAuth(rpcImplClass, getSecuredConfiguration()); + } + + private void testRpcCallWithEnabledKerberosSaslAuth(Class rpcImplClass, Configuration conf) + throws Exception { String krbKeytab = getKeytabFileForTesting(); String krbPrincipal = getPrincipalForTesting(); @@ -122,7 +134,6 @@ public class TestSecureRPC { assertEquals(AuthenticationMethod.KERBEROS, ugi.getAuthenticationMethod()); assertEquals(krbPrincipal, ugi.getUserName()); - Configuration conf = getSecuredConfiguration(); conf.set(RpcClientFactory.CUSTOM_RPC_CLIENT_IMPL_CONF_KEY, rpcImplClass.getName()); SecurityInfo securityInfoMock = Mockito.mock(SecurityInfo.class); Mockito.when(securityInfoMock.getServerPrincipal())