commit aef87b496f988355b60a3ef9e22665c7620c3f6a
Author: Todd Lipcon
Date:   Fri May 4 18:05:01 2012 -0700

    kill buffer copies in HBaseServer

    more reducing copies
    fix hinting one more place
    Improve server serialization

diff --git src/main/java/org/apache/hadoop/hbase/io/HbaseObjectWritable.java src/main/java/org/apache/hadoop/hbase/io/HbaseObjectWritable.java
index e138199..464c07e 100644
--- src/main/java/org/apache/hadoop/hbase/io/HbaseObjectWritable.java
+++ src/main/java/org/apache/hadoop/hbase/io/HbaseObjectWritable.java
@@ -98,6 +98,7 @@ import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.io.WritableFactories;
 import org.apache.hadoop.io.WritableUtils;
 
+import com.google.protobuf.CodedOutputStream;
 import com.google.protobuf.Message;
 import com.google.protobuf.RpcController;
 
@@ -423,6 +424,12 @@ public class HbaseObjectWritable implements Writable, WritableWithSize, Configur
       // one extra class code for writable instance.
       return r.getWritableSize() + size + Bytes.SIZEOF_BYTE;
     }
+    if (instance instanceof Message) {
+      int pbSize = ((Message)instance).getSerializedSize();
+      return CodedOutputStream.computeRawVarint32Size(pbSize) +
+        pbSize +
+        4 + instance.getClass().getName().length();
+    }
     return 0L; // no hint is the default.
   }
   /**
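The new Message branch above hints the cost of a delimited protobuf write: a vint32 length prefix plus the message body, with a rough "4 + className.length()" allowance for HbaseObjectWritable's class framing. A minimal sketch of why the first two terms are right, assuming protobuf-java 2.x; msg stands in for any generated message:

    import java.io.ByteArrayOutputStream;
    import java.io.IOException;

    import com.google.protobuf.CodedOutputStream;
    import com.google.protobuf.Message;

    public class DelimitedSizeCheck {
      // writeDelimitedTo emits varint(len) followed by len message bytes,
      // which matches the first two terms of the new getWritableSize() hint.
      static void check(Message msg) throws IOException {
        int pbSize = msg.getSerializedSize();
        int expected = CodedOutputStream.computeRawVarint32Size(pbSize) + pbSize;
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        msg.writeDelimitedTo(baos);
        assert baos.size() == expected;
      }
    }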
diff --git src/main/java/org/apache/hadoop/hbase/ipc/BufferChain.java src/main/java/org/apache/hadoop/hbase/ipc/BufferChain.java
new file mode 100644
index 0000000..f2997d6
--- /dev/null
+++ src/main/java/org/apache/hadoop/hbase/ipc/BufferChain.java
@@ -0,0 +1,71 @@
+package org.apache.hadoop.hbase.ipc;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.GatheringByteChannel;
+
+class BufferChain {
+  private static final int NIO_BUFFER_LIMIT = 8192;
+  private ByteBuffer[] buffers;
+  int remaining;
+  int bufferOffset = 0;
+
+  BufferChain(ByteBuffer ... buffers) {
+    this.buffers = buffers;
+
+    remaining = 0;
+    for (ByteBuffer b : buffers) {
+      remaining += b.remaining();
+    }
+  }
+
+  public boolean hasRemaining() {
+    return remaining > 0;
+  }
+
+  public long writeChunk(GatheringByteChannel channel, int chunkSize) throws IOException {
+    int chunkRemaining = chunkSize;
+    ByteBuffer lastBuffer = null;
+    int bufCount = 0;
+    int restoreLimit = -1;
+
+    while (chunkRemaining > 0 && bufferOffset + bufCount < buffers.length) {
+      lastBuffer = buffers[bufferOffset + bufCount];
+      if (!lastBuffer.hasRemaining()) {
+        bufferOffset++;
+        continue;
+      }
+      bufCount++;
+
+      if (lastBuffer.remaining() > chunkRemaining) {
+        restoreLimit = lastBuffer.limit();
+        lastBuffer.limit(lastBuffer.position() + chunkRemaining);
+        chunkRemaining = 0;
+        break;
+      } else {
+        chunkRemaining -= lastBuffer.remaining();
+      }
+    }
+    assert lastBuffer != null;
+    if (chunkRemaining == chunkSize) {
+      assert !hasRemaining();
+      // no data left to write
+      return 0;
+    }
+
+    try {
+      long ret = channel.write(buffers, bufferOffset, bufCount);
+      if (ret > 0) {
+        remaining -= ret;
+      }
+      return ret;
+    } finally {
+      if (restoreLimit >= 0) {
+        lastBuffer.limit(restoreLimit);
+      }
+    }
+  }
+}
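For reference, a minimal sketch of how a caller drives the new class (BufferChain and writeChunk are package-private, so this assumes code living in org.apache.hadoop.hbase.ipc; the output path is a placeholder). Each writeChunk call gathers at most chunkSize bytes from the chain into a single vectored channel.write, so the header and payload buffers below go to the wire without being copied into one array:

    package org.apache.hadoop.hbase.ipc;

    import java.io.FileOutputStream;
    import java.io.IOException;
    import java.nio.ByteBuffer;
    import java.nio.channels.GatheringByteChannel;

    public class BufferChainDemo {
      public static void main(String[] args) throws IOException {
        ByteBuffer header = ByteBuffer.wrap(new byte[] { 0x0a, 0x0b });
        ByteBuffer payload = ByteBuffer.wrap("payload".getBytes("UTF-8"));
        BufferChain chain = new BufferChain(header, payload);

        FileOutputStream fos = new FileOutputStream("/tmp/chain-demo");
        GatheringByteChannel ch = fos.getChannel();
        try {
          while (chain.hasRemaining()) {
            // 8192 mirrors NIO_BUFFER_LIMIT; on a non-blocking socket a
            // return of 0 just means the kernel buffer is full, so a real
            // caller would re-register for OP_WRITE instead of spinning.
            long n = chain.writeChunk(ch, 8192);
            if (n <= 0) {
              break;
            }
          }
        } finally {
          fos.close();
        }
      }
    }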
diff --git src/main/java/org/apache/hadoop/hbase/ipc/HBaseServer.java src/main/java/org/apache/hadoop/hbase/ipc/HBaseServer.java
index 7829e0a..e7456c4 100644
--- src/main/java/org/apache/hadoop/hbase/ipc/HBaseServer.java
+++ src/main/java/org/apache/hadoop/hbase/ipc/HBaseServer.java
@@ -35,6 +35,7 @@ import java.net.UnknownHostException;
 import java.nio.ByteBuffer;
 import java.nio.channels.CancelledKeyException;
 import java.nio.channels.ClosedChannelException;
+import java.nio.channels.GatheringByteChannel;
 import java.nio.channels.ReadableByteChannel;
 import java.nio.channels.SelectionKey;
 import java.nio.channels.Selector;
@@ -58,7 +59,6 @@ import java.util.concurrent.LinkedBlockingQueue;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.hbase.io.DataOutputOutputStream;
 import org.apache.hadoop.hbase.HConstants;
 import org.apache.hadoop.hbase.io.HbaseObjectWritable;
 import org.apache.hadoop.hbase.io.WritableWithSize;
@@ -69,8 +69,6 @@ import org.apache.hadoop.hbase.protobuf.generated.RPCProtos.RpcException;
 import org.apache.hadoop.hbase.monitoring.MonitoredRPCHandler;
 import org.apache.hadoop.hbase.monitoring.TaskMonitor;
 import org.apache.hadoop.hbase.security.User;
-import org.apache.hadoop.hbase.util.ByteBufferOutputStream;
-import org.apache.hadoop.hbase.util.Bytes;
 import org.apache.hadoop.io.DataOutputBuffer;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.ipc.RPC.VersionMismatch;
@@ -80,6 +78,8 @@ import org.apache.hadoop.util.StringUtils;
 import com.google.common.base.Function;
 import com.google.common.util.concurrent.ThreadFactoryBuilder;
 import com.google.protobuf.ByteString;
+import com.google.protobuf.CodedOutputStream;
+import com.google.protobuf.WireFormat;
 
 import org.cliffc.high_scale_lib.Counter;
@@ -270,7 +270,7 @@ public abstract class HBaseServer implements RpcServer {
     protected Connection connection;      // connection to client
     protected long timestamp;             // the time received when response is null
                                           // the time served when response is not null
-    protected ByteBuffer response;        // the response for this call
+    protected BufferChain response;       // the response for this call
     protected boolean delayResponse;
     protected Responder responder;
     protected boolean delayReturnValue;   // if the return value should be
@@ -305,7 +305,84 @@ public abstract class HBaseServer implements RpcServer {
         return;
       if (errorClass != null) {
         this.isError = true;
+        this.response = makeResponseForError(errorClass, error);
+      } else {
+        try {
+          this.response = makeResponseForSuccess(value);
+        } catch (IOException ioe) {
+          errorClass = ioe.getClass().getName();
+          error = StringUtils.stringifyException(ioe);
+          this.isError = true;
+          this.response = makeResponseForError(errorClass, error);
+        }
       }
+    }
+
+    private BufferChain makeResponseForError(String errorClass, String error) {
+      RpcResponse response = RpcResponse.newBuilder()
+        .setCallId(this.id)
+        .setError(true)
+        .setException(RpcException.newBuilder()
+          .setExceptionName(errorClass)
+          .setStackTrace(error))
+        .build();
+
+      ByteBuffer buf = response.toByteString().asReadOnlyByteBuffer();
+      return new BufferChain(buf);
+    }
+
+    private BufferChain makeResponseForSuccess(Object value) throws IOException {
+      // Success responses are much more common than error responses, so
+      // we do a bit of manual protobuf serialization to avoid buffer copies.
+      // Because protobufs use vint length prefixes, this is slightly tricky.
+      // The end result looks like:
+      //
+      //   <length prefix>               - vint32 for the whole message
+      //                                   ("delimiter" in PB terms)
+      //   <RpcResponse.callId tag>
+      //   <RpcResponse.error tag>
+      //   <RpcResponse.response tag>
+      //   <RpcResponse.response length> - vint32
+      //   <serialized response object>  - the HbaseObjectWritable
+      //
+      // Because we don't always know the length of the response object until
+      // we serialize it, we can't fill in the length prefix ahead of time.
+      // So, we split the above into two buffers and return a BufferChain, so
+      // that they're written together to the wire using writev().
+
+      // Serialize the actual result
+      ByteBuffer valueBuf = resultToByteBuffer(value);
+      int valueLen = valueBuf.remaining();
+
+      int headerLen =
+        CodedOutputStream.computeInt32Size(RpcResponse.CALLID_FIELD_NUMBER, this.id) +
+        CodedOutputStream.computeBoolSize(RpcResponse.ERROR_FIELD_NUMBER, false) +
+        CodedOutputStream.computeTagSize(RpcResponse.RESPONSE_FIELD_NUMBER) +
+        CodedOutputStream.computeRawVarint32Size(valueLen);
+      int totalLen = headerLen + valueLen;
+      int delimiterLen = CodedOutputStream.computeRawVarint32Size(totalLen);
+
+      byte[] header = new byte[delimiterLen + headerLen];
+      CodedOutputStream cos = CodedOutputStream.newInstance(header);
+
+      // Write delimiter for whole message
+      cos.writeRawVarint32(totalLen);
+
+      // Set call ID and error flag
+      cos.writeInt32(RpcResponse.CALLID_FIELD_NUMBER, this.id);
+      cos.writeBool(RpcResponse.ERROR_FIELD_NUMBER, false);
+
+      // Write header for the actual response data
+      cos.writeTag(RpcResponse.RESPONSE_FIELD_NUMBER, WireFormat.WIRETYPE_LENGTH_DELIMITED);
+      cos.writeRawVarint32(valueLen);
+      cos.flush();
+      cos.checkNoSpaceLeft();
+
+      return new BufferChain(
+        ByteBuffer.wrap(header),
+        valueBuf);
+    }
+
+    private ByteBuffer resultToByteBuffer(Object value) throws IOException {
       Writable result = null;
       if (value instanceof Writable) {
         result = (Writable) value;
       }
@@ -316,50 +393,38 @@ public abstract class HBaseServer implements RpcServer {
           result = new HbaseObjectWritable(value);
         }
       }
-
+
+      boolean hinted = false;
       int size = BUFFER_INITIAL_SIZE;
       if (result instanceof WritableWithSize) {
         // get the size hint.
         WritableWithSize ohint = (WritableWithSize) result;
-        long hint = ohint.getWritableSize() + Bytes.SIZEOF_BYTE +
-          (2 * Bytes.SIZEOF_INT);
+        long hint = ohint.getWritableSize();
         if (hint > Integer.MAX_VALUE) {
-          // oops, new problem.
-          IOException ioe =
-            new IOException("Result buffer size too large: " + hint);
-          errorClass = ioe.getClass().getName();
-          error = StringUtils.stringifyException(ioe);
-        } else {
+          throw new IOException("Result buffer size too large: " + hint);
+        } else if (hint > 0) {
           size = (int)hint;
+          hinted = true;
         }
       }
-
-      ByteBufferOutputStream buf = new ByteBufferOutputStream(size);
-      DataOutputStream out = new DataOutputStream(buf);
-      try {
-        RpcResponse.Builder builder = RpcResponse.newBuilder();
-        // Call id.
-        builder.setCallId(this.id);
-        builder.setError(error != null);
-        if (error != null) {
-          RpcException.Builder b = RpcException.newBuilder();
-          b.setExceptionName(errorClass);
-          b.setStackTrace(error);
-          builder.setException(b.build());
-        } else {
-          DataOutputBuffer d = new DataOutputBuffer(size);
-          result.write(d);
-          byte[] response = d.getData();
-          builder.setResponse(ByteString.copyFrom(response));
+
+      DataOutputBuffer buf = new DataOutputBuffer(size);
+      result.write(buf);
+
+      // Debug logs if the hint was too small (in which case we paid an
+      // extra copy to expand the buffer) or too big (in which case we
+      // wasted allocation space)
+      if (hinted && LOG.isDebugEnabled()) {
+        if (buf.getLength() > size) {
+          LOG.debug("Hint for value " + value + " too small: " +
+            "hint=" + size + " actual=" + buf.getLength());
+        } else if (buf.getLength() < size - 8) {
+          LOG.debug("Hint for value " + value + " was much too big: " +
+            "hint=" + size + " actual=" + buf.getLength());
         }
-        builder.build().writeDelimitedTo(
-          DataOutputOutputStream.constructOutputStream(out));
-      } catch (IOException e) {
-        LOG.warn("Exception while creating response " + e);
       }
-      ByteBuffer bb = buf.getByteBuffer();
-      bb.position(0);
-      this.response = bb;
+
+      return ByteBuffer.wrap(buf.getData(), 0, buf.getLength());
     }
 
     @Override
@@ -933,7 +998,7 @@ public abstract class HBaseServer implements RpcServer {
       //
       // Send as much data as we can in the non-blocking fashion
       //
-      int numBytes = channelWrite(channel, call.response);
+      long numBytes = channelWrite(channel, call.response);
       if (numBytes < 0) {
         return true;
       }
@@ -1686,11 +1751,10 @@ public abstract class HBaseServer implements RpcServer {
    * @throws java.io.IOException e
    * @see java.nio.channels.WritableByteChannel#write(java.nio.ByteBuffer)
    */
-  protected int channelWrite(WritableByteChannel channel,
-                             ByteBuffer buffer) throws IOException {
-    int count = (buffer.remaining() <= NIO_BUFFER_LIMIT) ?
-      channel.write(buffer) : channelIO(null, channel, buffer);
+  protected long channelWrite(GatheringByteChannel channel,
+                              BufferChain buffer) throws IOException {
+    long count = buffer.writeChunk(channel, NIO_BUFFER_LIMIT);
     if (count > 0) {
       rpcMetrics.sentBytes.inc(count);
     }
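Since makeResponseForSuccess hand-rolls the delimited framing instead of calling writeDelimitedTo, a useful sanity check is that the two concatenated buffers still parse with the stock delimited parser. A sketch, assuming the generated RPCProtos.RpcResponse exposes the usual getCallId/getError/getResponse accessors; header and value stand for the backing bytes of the two BufferChain buffers:

    import java.io.ByteArrayInputStream;
    import java.io.IOException;

    import org.apache.hadoop.hbase.protobuf.generated.RPCProtos.RpcResponse;

    public class FramingCheck {
      // Re-parse the hand-built frame with the standard delimited parser
      // and verify the fields round-trip.
      static void check(byte[] header, byte[] value, int callId) throws IOException {
        byte[] wire = new byte[header.length + value.length];
        System.arraycopy(header, 0, wire, 0, header.length);
        System.arraycopy(value, 0, wire, header.length, value.length);

        RpcResponse parsed =
          RpcResponse.parseDelimitedFrom(new ByteArrayInputStream(wire));
        assert parsed.getCallId() == callId;
        assert !parsed.getError();
        assert parsed.getResponse().size() == value.length;
      }
    }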
diff --git src/test/java/org/apache/hadoop/hbase/ipc/TestBufferChain.java src/test/java/org/apache/hadoop/hbase/ipc/TestBufferChain.java
new file mode 100644
index 0000000..7cc91af
--- /dev/null
+++ src/test/java/org/apache/hadoop/hbase/ipc/TestBufferChain.java
@@ -0,0 +1,129 @@
+package org.apache.hadoop.hbase.ipc;
+
+import static org.junit.Assert.*;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import com.google.common.base.Charsets;
+import com.google.common.io.Files;
+
+public class TestBufferChain {
+  private File tmpFile;
+
+  private static final byte[][] HELLO_WORLD_CHUNKS = new byte[][] {
+    "hello".getBytes(Charsets.UTF_8),
+    " ".getBytes(Charsets.UTF_8),
+    "world".getBytes(Charsets.UTF_8)
+  };
+
+  @Before
+  public void setup() throws IOException {
+    tmpFile = File.createTempFile("TestBufferChain", "txt");
+  }
+
+  @After
+  public void teardown() {
+    tmpFile.delete();
+  }
+
+  @Test
+  public void testChainChunkBiggerThanWholeArray() throws IOException {
+    ByteBuffer[] bufs = wrapArrays(HELLO_WORLD_CHUNKS);
+    BufferChain chain = new BufferChain(bufs);
+    writeAndVerify(chain, "hello world", 8192);
+    assertNoRemaining(bufs);
+  }
+
+  @Test
+  public void testChainChunkBiggerThanSomeArrays() throws IOException {
+    ByteBuffer[] bufs = wrapArrays(HELLO_WORLD_CHUNKS);
+    BufferChain chain = new BufferChain(bufs);
+    writeAndVerify(chain, "hello world", 3);
+    assertNoRemaining(bufs);
+  }
+
+  @Test
+  public void testLimitOffset() throws IOException {
+    ByteBuffer[] bufs = new ByteBuffer[] {
+      stringBuf("XXXhelloYYY", 3, 5),
+      stringBuf(" ", 0, 1),
+      stringBuf("XXXXworldY", 4, 5) };
+    BufferChain chain = new BufferChain(bufs);
+    writeAndVerify(chain, "hello world", 3);
+    assertNoRemaining(bufs);
+  }
+
+  @Test
+  public void testWithSpy() throws IOException {
+    ByteBuffer[] bufs = new ByteBuffer[] {
+      stringBuf("XXXhelloYYY", 3, 5),
+      stringBuf(" ", 0, 1),
+      stringBuf("XXXXworldY", 4, 5) };
+    BufferChain chain = new BufferChain(bufs);
+    FileOutputStream fos = new FileOutputStream(tmpFile);
+    FileChannel ch = Mockito.spy(fos.getChannel());
+    try {
+      chain.writeChunk(ch, 2);
+      assertEquals("he", Files.toString(tmpFile, Charsets.UTF_8));
+      chain.writeChunk(ch, 2);
+      assertEquals("hell", Files.toString(tmpFile, Charsets.UTF_8));
+      chain.writeChunk(ch, 3);
+      assertEquals("hello w", Files.toString(tmpFile, Charsets.UTF_8));
+      chain.writeChunk(ch, 8);
+      assertEquals("hello world", Files.toString(tmpFile, Charsets.UTF_8));
+    } finally {
+      ch.close();
+    }
+  }
+
+  private ByteBuffer stringBuf(String string, int position, int length) {
+    ByteBuffer buf = ByteBuffer.wrap(string.getBytes(Charsets.UTF_8));
+    buf.position(position);
+    buf.limit(position + length);
+    assertTrue(buf.hasRemaining());
+    return buf;
+  }
+
+  private void assertNoRemaining(ByteBuffer[] bufs) {
+    for (ByteBuffer buf : bufs) {
+      assertFalse(buf.hasRemaining());
+    }
+  }
+
+  private ByteBuffer[] wrapArrays(byte[][] arrays) {
+    ByteBuffer[] ret = new ByteBuffer[arrays.length];
+    for (int i = 0; i < arrays.length; i++) {
+      ret[i] = ByteBuffer.wrap(arrays[i]);
+    }
+    return ret;
+  }
+
+  private void writeAndVerify(BufferChain chain, String string, int chunkSize)
+      throws IOException {
+    FileOutputStream fos = new FileOutputStream(tmpFile);
+    FileChannel ch = fos.getChannel();
+    try {
+      long remaining = string.length();
+      while (chain.hasRemaining()) {
+        long n = chain.writeChunk(ch, chunkSize);
+        assertTrue(n == chunkSize || n == remaining);
+        remaining -= n;
+      }
+      assertEquals(0, remaining);
+    } finally {
+      fos.close();
+    }
+    assertFalse(chain.hasRemaining());
+    assertEquals(string, Files.toString(tmpFile, Charsets.UTF_8));
+  }
+}
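One invariant the tests above exercise only indirectly: when a chunk boundary falls inside a buffer, writeChunk shrinks that buffer's limit for the duration of the write and restores it in the finally block, while the position records the actual progress. A small same-package sketch of that behavior (the file path is a placeholder):

    package org.apache.hadoop.hbase.ipc;

    import java.io.FileOutputStream;
    import java.io.IOException;
    import java.nio.ByteBuffer;
    import java.nio.channels.GatheringByteChannel;

    public class LimitRestoreCheck {
      public static void main(String[] args) throws IOException {
        ByteBuffer buf = ByteBuffer.wrap("hello world".getBytes("UTF-8"));
        int limitBefore = buf.limit();        // 11
        BufferChain chain = new BufferChain(buf);

        FileOutputStream fos = new FileOutputStream("/tmp/limit-check");
        try {
          GatheringByteChannel ch = fos.getChannel();
          chain.writeChunk(ch, 5);            // writes "hello" only
          assert buf.limit() == limitBefore;  // limit restored after the write
          assert buf.position() == 5;         // position reflects progress
        } finally {
          fos.close();
        }
      }
    }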