diff --git a/hbase-common/src/main/java/org/apache/hadoop/hbase/util/Bytes.java b/hbase-common/src/main/java/org/apache/hadoop/hbase/util/Bytes.java index 3d709a5..f40ebbd 100644 --- a/hbase-common/src/main/java/org/apache/hadoop/hbase/util/Bytes.java +++ b/hbase-common/src/main/java/org/apache/hadoop/hbase/util/Bytes.java @@ -20,6 +20,7 @@ package org.apache.hadoop.hbase.util; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkPositionIndex; +import com.google.protobuf.ByteString; import java.io.DataInput; import java.io.DataOutput; @@ -64,7 +65,7 @@ import org.apache.hadoop.hbase.util.Bytes.LexicographicalComparerHolder.UnsafeCo @SuppressWarnings("restriction") @InterfaceAudience.Public @InterfaceStability.Stable -public class Bytes { +public class Bytes implements Comparable { //HConstants.UTF8_ENCODING should be updated if this changed /** When we encode strings, we always specify UTF8 encoding */ private static final String UTF8_ENCODING = "UTF-8"; @@ -132,7 +133,7 @@ public class Bytes { // SizeOf which uses java.lang.instrument says 24 bytes. (3 longs?) public static final int ESTIMATED_HEAP_TAX = 16; - + /** * Returns length of the byte array, returning 0 if the array is null. * Useful for calculating sizes. @@ -143,6 +144,190 @@ public class Bytes { return b == null ? 0 : b.length; } + private byte[] bytes; + private int offset; + private int length; + + /** + * Create a zero-size sequence. + */ + public Bytes() { + super(); + } + + /** + * Create a Bytes using the byte array as the initial value. + * @param bytes This array becomes the backing storage for the object. + */ + public Bytes(byte[] bytes) { + this(bytes, 0, bytes.length); + } + + /** + * Set the new Bytes to the contents of the passed + * ibw. + * @param ibw the value to set this Bytes to. + */ + public Bytes(final Bytes ibw) { + this(ibw.get(), ibw.getOffset(), ibw.getLength()); + } + + /** + * Set the value to a given byte range + * @param bytes the new byte range to set to + * @param offset the offset in newData to start at + * @param length the number of bytes in the range + */ + public Bytes(final byte[] bytes, final int offset, + final int length) { + this.bytes = bytes; + this.offset = offset; + this.length = length; + } + + /** + * Copy bytes from ByteString instance. + * @param byteString copy from + */ + public Bytes(final ByteString byteString) { + this(byteString.toByteArray()); + } + + /** + * Get the data from the Bytes. + * @return The data is only valid between offset and offset+length. + */ + public byte [] get() { + if (this.bytes == null) { + throw new IllegalStateException("Uninitialiized. Null constructor " + + "called w/o accompaying readFields invocation"); + } + return this.bytes; + } + + /** + * @param b Use passed bytes as backing array for this instance. + */ + public void set(final byte [] b) { + set(b, 0, b.length); + } + + /** + * @param b Use passed bytes as backing array for this instance. + * @param offset + * @param length + */ + public void set(final byte [] b, final int offset, final int length) { + this.bytes = b; + this.offset = offset; + this.length = length; + } + + /** + * @return the number of valid bytes in the buffer + * @deprecated use {@link #getLength()} instead + */ + @Deprecated + public int getSize() { + if (this.bytes == null) { + throw new IllegalStateException("Uninitialiized. Null constructor " + + "called w/o accompaying readFields invocation"); + } + return this.length; + } + + /** + * @return the number of valid bytes in the buffer + */ + public int getLength() { + if (this.bytes == null) { + throw new IllegalStateException("Uninitialiized. Null constructor " + + "called w/o accompaying readFields invocation"); + } + return this.length; + } + + /** + * @return offset + */ + public int getOffset(){ + return this.offset; + } + + public ByteString toByteString() { + return ByteString.copyFrom(this.bytes, this.offset, this.length); + } + + @Override + public int hashCode() { + return Bytes.hashCode(bytes, offset, length); + } + + /** + * Define the sort order of the Bytes. + * @param that The other bytes writable + * @return Positive if left is bigger than right, 0 if they are equal, and + * negative if left is smaller than right. + */ + public int compareTo(Bytes that) { + return BYTES_RAWCOMPARATOR.compare( + this.bytes, this.offset, this.length, + that.bytes, that.offset, that.length); + } + + /** + * Compares the bytes in this object to the specified byte array + * @param that + * @return Positive if left is bigger than right, 0 if they are equal, and + * negative if left is smaller than right. + */ + public int compareTo(final byte [] that) { + return BYTES_RAWCOMPARATOR.compare( + this.bytes, this.offset, this.length, + that, 0, that.length); + } + + /** + * @see Object#equals(Object) + */ + @Override + public boolean equals(Object right_obj) { + if (right_obj instanceof byte []) { + return compareTo((byte [])right_obj) == 0; + } + if (right_obj instanceof Bytes) { + return compareTo((Bytes)right_obj) == 0; + } + return false; + } + + /** + * @see Object#toString() + */ + @Override + public String toString() { + return Bytes.toString(bytes, offset, length); + } + + /** + * @param array List of byte []. + * @return Array of byte []. + */ + public static byte [][] toArray(final List array) { + // List#toArray doesn't work on lists of byte []. + byte[][] results = new byte[array.size()][]; + for (int i = 0; i < array.size(); i++) { + results[i] = array.get(i); + } + return results; + } + + /** + * Returns a copy of the bytes referred to by this writable + */ + public byte[] copyBytes() { + return Arrays.copyOfRange(bytes, offset, offset+length); + } /** * Byte array comparator class. */ @@ -369,7 +554,7 @@ public class Bytes { final byte [] b2) { return toString(b1, 0, b1.length) + sep + toString(b2, 0, b2.length); } - + /** * This method will convert utf8 encoded bytes into a string. If the given byte array is null, * this method will return null. diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java index 51de6af..5d620de 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java @@ -20,7 +20,6 @@ package org.apache.hadoop.hbase.ipc; import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHORIZATION; -import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.DataOutputStream; import java.io.IOException; @@ -33,7 +32,6 @@ import java.net.SocketException; import java.net.UnknownHostException; import java.nio.ByteBuffer; import java.nio.channels.CancelledKeyException; -import java.nio.channels.Channels; import java.nio.channels.ClosedChannelException; import java.nio.channels.GatheringByteChannel; import java.nio.channels.ReadableByteChannel; @@ -277,6 +275,7 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { private UserProvider userProvider; private final BoundedByteBufferPool reservoir; + private final BoundedByteBufferPool reqBufPool; private volatile boolean allowFallbackToSimpleAuth; @@ -554,6 +553,7 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { private ExecutorService readPool; + public Listener(final String name) throws IOException { super(name); backlogLength = conf.getInt("hbase.ipc.server.listen.queue.size", 128); @@ -859,6 +859,7 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { ". Number of active connections: " + numConnections); } closeConnection(c); + key.cancel(); } } @@ -1196,9 +1197,6 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { private AuthMethod authMethod; private boolean saslContextEstablished; private boolean skipInitialSaslHandshake; - private ByteBuffer unwrappedData; - // When is this set? FindBugs wants to know! Says NP - private ByteBuffer unwrappedDataLengthBuffer = ByteBuffer.allocate(4); boolean useSasl; SaslServer saslServer; private boolean useWrap = false; @@ -1311,18 +1309,19 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { return authorizedUgi; } - private void saslReadAndProcess(byte[] saslToken) throws IOException, + private void saslReadAndProcess(ByteBuffer saslToken) throws IOException, InterruptedException { if (saslContextEstablished) { if (LOG.isTraceEnabled()) - LOG.trace("Have read input token of size " + saslToken.length + LOG.trace("Have read input token of size " + saslToken.remaining() + " for processing by saslServer.unwrap()"); if (!useWrap) { processOneRpc(saslToken); } else { - byte [] plaintextData = saslServer.unwrap(saslToken, 0, saslToken.length); - processUnwrappedData(plaintextData); + byte[] plaintextData = + saslServer.unwrap(saslToken.array(), saslToken.position(), saslToken.remaining()); + processUnwrappedData(ByteBuffer.wrap(plaintextData)); } } else { byte[] replyToken; @@ -1370,10 +1369,12 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { } } if (LOG.isDebugEnabled()) { - LOG.debug("Have read input token of size " + saslToken.length + LOG.debug("Have read input token of size " + saslToken.remaining() + " for processing by saslServer.evaluateResponse()"); } - replyToken = saslServer.evaluateResponse(saslToken); + + replyToken = saslServer.evaluateResponse( + new Bytes(saslToken.array(), saslToken.position(), saslToken.remaining()).copyBytes()); } catch (IOException e) { IOException sendToClient = e; Throwable cause = e; @@ -1569,8 +1570,12 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { throw new IllegalArgumentException("Unexpected data length " + dataLength + "!! from " + getHostAddress()); } - data = ByteBuffer.allocate(dataLength); - + data = reqBufPool.getBuffer(); + if (data.capacity() < dataLength) { + data = ByteBuffer.allocate(dataLength); + } else { + data.limit(dataLength); + } // Increment the rpc count. This counter will be decreased when we write // the response. If we want the connection to be detected as idle properly, we // need to keep the inc / dec correct. @@ -1598,14 +1603,15 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { } if (useSasl) { - saslReadAndProcess(data.array()); + saslReadAndProcess(data); } else { - processOneRpc(data.array()); + processOneRpc(data); } } finally { dataLengthBuffer.clear(); // Clean for the next call - data = null; // For the GC + reqBufPool.putBuffer(data); + data = null; } } @@ -1629,8 +1635,11 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { } // Reads the connection header following version - private void processConnectionHeader(byte[] buf) throws IOException { - this.connectionHeader = ConnectionHeader.parseFrom(buf); + private void processConnectionHeader(ByteBuffer buf) throws IOException { + ConnectionHeader.Builder builder = ConnectionHeader.newBuilder(); + ProtobufUtil.mergeFrom(builder, buf.array(), buf.position(), buf.remaining()); + buf.position(buf.position() + buf.remaining()); + this.connectionHeader = builder.build(); String serviceName = connectionHeader.getServiceName(); if (serviceName == null) throw new EmptyServiceNameException(); this.service = getService(services, serviceName); @@ -1712,45 +1721,41 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { } } - private void processUnwrappedData(byte[] inBuf) throws IOException, + private void processUnwrappedData(ByteBuffer inBuf) throws IOException, InterruptedException { - ReadableByteChannel ch = Channels.newChannel(new ByteArrayInputStream(inBuf)); + int originalLimit = inBuf.limit(); + IOException exceptions = null; // Read all RPCs contained in the inBuf, even partial ones while (true) { - int count; - if (unwrappedDataLengthBuffer.remaining() > 0) { - count = channelRead(ch, unwrappedDataLengthBuffer); - if (count <= 0 || unwrappedDataLengthBuffer.remaining() > 0) - return; - } + if (inBuf.remaining() < 4) break; - if (unwrappedData == null) { - unwrappedDataLengthBuffer.flip(); - int unwrappedDataLength = unwrappedDataLengthBuffer.getInt(); + int dataLength = inBuf.getInt(); - if (unwrappedDataLength == RpcClient.PING_CALL_ID) { - if (LOG.isDebugEnabled()) - LOG.debug("Received ping message"); - unwrappedDataLengthBuffer.clear(); - continue; // ping message - } - unwrappedData = ByteBuffer.allocate(unwrappedDataLength); + if (dataLength == RpcClient.PING_CALL_ID) { + if (LOG.isDebugEnabled()) + LOG.debug("Received ping message"); + continue; // ping message } - count = channelRead(ch, unwrappedData); - if (count <= 0 || unwrappedData.remaining() > 0) - return; + if (inBuf.remaining() < dataLength) break; - if (unwrappedData.remaining() == 0) { - unwrappedDataLengthBuffer.clear(); - unwrappedData.flip(); - processOneRpc(unwrappedData.array()); - unwrappedData = null; + try{ + inBuf.limit(inBuf.position() + dataLength); + processOneRpc(inBuf); + } catch (Exception e) { + if (exceptions == null) + exceptions = new IOException("cause by following exceptions:"); + exceptions.addSuppressed(e); + inBuf.position(inBuf.limit()); // fix offset + } finally { + inBuf.limit(originalLimit); } } + + if (exceptions != null) throw exceptions; } - private void processOneRpc(byte[] buf) throws IOException, InterruptedException { + private void processOneRpc(ByteBuffer buf) throws IOException, InterruptedException { if (connectionHeaderRead) { processRequest(buf); } else { @@ -1772,18 +1777,20 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { * @throws IOException * @throws InterruptedException */ - protected void processRequest(byte[] buf) throws IOException, InterruptedException { - long totalRequestSize = buf.length; - int offset = 0; + protected void processRequest(ByteBuffer buf) throws IOException, InterruptedException { + int totalRequestSize = buf.remaining(); + int initPos = buf.position(); // Here we read in the header. We avoid having pb // do its default 4k allocation for CodedInputStream. We force it to use backing array. - CodedInputStream cis = CodedInputStream.newInstance(buf, offset, buf.length); + CodedInputStream cis = + CodedInputStream.newInstance(buf.array(), buf.position(), buf.remaining()); int headerSize = cis.readRawVarint32(); - offset = cis.getTotalBytesRead(); + buf.position(initPos + cis.getTotalBytesRead()); Message.Builder builder = RequestHeader.newBuilder(); - ProtobufUtil.mergeFrom(builder, buf, offset, headerSize); + builder.mergeFrom(buf.array(), buf.position(), headerSize); RequestHeader header = (RequestHeader) builder.build(); - offset += headerSize; + cis.skipRawBytes(headerSize); + buf.position(initPos + cis.getTotalBytesRead()); int id = header.getCallId(); if (LOG.isTraceEnabled()) { LOG.trace("RequestHeader " + TextFormat.shortDebugString(header) + @@ -1812,19 +1819,18 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { md = this.service.getDescriptorForType().findMethodByName(header.getMethodName()); if (md == null) throw new UnsupportedOperationException(header.getMethodName()); builder = this.service.getRequestPrototype(md).newBuilderForType(); - // To read the varint, I need an inputstream; might as well be a CIS. - cis = CodedInputStream.newInstance(buf, offset, buf.length); int paramSize = cis.readRawVarint32(); - offset += cis.getTotalBytesRead(); + buf.position(initPos + cis.getTotalBytesRead()); if (builder != null) { - ProtobufUtil.mergeFrom(builder, buf, offset, paramSize); + builder.mergeFrom(buf.array(), buf.position(), paramSize); param = builder.build(); } - offset += paramSize; + buf.position(buf.position() + paramSize); } if (header.hasCellBlockMeta()) { cellScanner = ipcUtil.createCellScanner(this.codec, this.compressionCodec, - buf, offset, buf.length); + buf.array(), buf.position(), buf.remaining()); + buf.position(buf.position() + buf.remaining()); } } catch (Throwable t) { InetSocketAddress address = getListenerAddress(); @@ -1850,9 +1856,9 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { setupResponse(responseBuffer, readParamsFailedCall, t, msg + "; " + t.getMessage()); responder.doRespond(readParamsFailedCall); + buf.position(buf.limit()); return; } - TraceInfo traceInfo = header.hasTraceInfo() ? new TraceInfo(header.getTraceInfo().getTraceId(), header.getTraceInfo().getParentId()) : null; @@ -1998,7 +2004,7 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { } else { reservoir = null; } - + this.server = server; this.services = services; this.bindAddress = bindAddress; @@ -2007,6 +2013,7 @@ public class RpcServer implements RpcServerInterface, ConfigurationObserver { this.maxQueueSize = this.conf.getInt("hbase.ipc.server.max.callqueue.size", DEFAULT_MAX_CALLQUEUE_SIZE); this.readThreads = conf.getInt("hbase.ipc.server.read.threadpool.size", 10); + this.reqBufPool = new BoundedByteBufferPool(1024*1024, 16*1024, readThreads * 3); this.maxIdleTime = 2 * conf.getInt("hbase.ipc.client.connection.maxidletime", 1000); this.maxConnectionsToNuke = conf.getInt("hbase.ipc.client.kill.max", 10); this.thresholdIdleConnections = conf.getInt("hbase.ipc.client.idlethreshold", 4000); diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/security/TestSecureRPC.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/security/TestSecureRPC.java index a5700d0..2463a62 100644 --- a/hbase-server/src/test/java/org/apache/hadoop/hbase/security/TestSecureRPC.java +++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/security/TestSecureRPC.java @@ -170,8 +170,20 @@ public class TestSecureRPC { testRpcFallbackToSimpleAuth(AsyncRpcClient.class); } + @Test + public void testWrappedRpc() throws Exception { + Configuration conf = getSecuredConfiguration(); + conf.set("hbase.rpc.protection", SaslUtil.QualityOfProtection.PRIVACY.name().toLowerCase()); + testRpcCallWithEnabledKerberosSaslAuth(RpcClientImpl.class, conf); + } + private void testRpcCallWithEnabledKerberosSaslAuth(Class rpcImplClass) throws Exception { + testRpcCallWithEnabledKerberosSaslAuth(rpcImplClass, getSecuredConfiguration()); + } + + private void testRpcCallWithEnabledKerberosSaslAuth(Class rpcImplClass, Configuration conf) + throws Exception { String krbKeytab = getKeytabFileForTesting(); String krbPrincipal = getPrincipalForTesting();