diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncRpcChannelImpl.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncRpcChannelImpl.java index 2b9000a..2e0eda4 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncRpcChannelImpl.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/AsyncRpcChannelImpl.java @@ -231,6 +231,26 @@ public class AsyncRpcChannelImpl implements AsyncRpcChannel { } } + private void startConnectionWithEncryption(Channel ch) { + // for rpc encryption, the order of ChannelInboundHandler should be: + // LengthFieldBasedFrameDecoder->SaslClientHandler->LengthFieldBasedFrameDecoder + // Don't skip the first 4 bytes for length in beforeUnwrapDecoder, + // SaslClientHandler will handler this + ch.pipeline().addFirst("beforeUnwrapDecoder", + new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 0)); + ch.pipeline().addLast("afterUnwrapDecoder", + new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)); + ch.pipeline().addLast(new AsyncServerResponseHandler(this)); + List callsToWrite; + synchronized (pendingCalls) { + connected = true; + callsToWrite = new ArrayList(pendingCalls.values()); + } + for (AsyncCall call : callsToWrite) { + writeRequest(call); + } + } + /** * Get SASL handler * @param bootstrap to reconnect to @@ -243,6 +263,7 @@ public class AsyncRpcChannelImpl implements AsyncRpcChannel { client.fallbackAllowed, client.conf.get("hbase.rpc.protection", SaslUtil.QualityOfProtection.AUTHENTICATION.name().toLowerCase()), + getChannelHeaderBytes(authMethod), new SaslClientHandler.SaslExceptionHandler() { @Override public void handle(int retryCount, Random random, Throwable cause) { @@ -261,6 +282,11 @@ public class AsyncRpcChannelImpl implements AsyncRpcChannel { public void onSuccess(Channel channel) { startHBaseConnection(channel); } + + @Override + public void onSaslProtectionSucess(Channel channel) { + startConnectionWithEncryption(channel); + } }); } @@ -341,6 +367,25 @@ public class AsyncRpcChannelImpl implements AsyncRpcChannel { * @throws java.io.IOException on failure to write */ private ChannelFuture writeChannelHeader(Channel channel) throws IOException { + RPCProtos.ConnectionHeader header = getChannelHeader(authMethod); + int totalSize = IPCUtil.getTotalSizeWhenWrittenDelimited(header); + ByteBuf b = channel.alloc().directBuffer(totalSize); + + b.writeInt(header.getSerializedSize()); + b.writeBytes(header.toByteArray()); + + return channel.writeAndFlush(b); + } + + private byte[] getChannelHeaderBytes(AuthMethod authMethod) { + RPCProtos.ConnectionHeader header = getChannelHeader(authMethod); + ByteBuffer b = ByteBuffer.allocate(header.getSerializedSize() + 4); + b.putInt(header.getSerializedSize()); + b.put(header.toByteArray()); + return b.array(); + } + + private RPCProtos.ConnectionHeader getChannelHeader(AuthMethod authMethod) { RPCProtos.ConnectionHeader.Builder headerBuilder = RPCProtos.ConnectionHeader.newBuilder() .setServiceName(serviceName); @@ -357,16 +402,7 @@ public class AsyncRpcChannelImpl implements AsyncRpcChannel { } headerBuilder.setVersionInfo(ProtobufUtil.getVersionInfo()); - RPCProtos.ConnectionHeader header = headerBuilder.build(); - - int totalSize = IPCUtil.getTotalSizeWhenWrittenDelimited(header); - - ByteBuf b = channel.alloc().directBuffer(totalSize); - - b.writeInt(header.getSerializedSize()); - b.writeBytes(header.toByteArray()); - - return channel.writeAndFlush(b); + return headerBuilder.build(); } /** diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslClientHandler.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslClientHandler.java index bfb625b..8fa7ebe 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslClientHandler.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/SaslClientHandler.java @@ -39,6 +39,7 @@ import javax.security.sasl.SaslClient; import javax.security.sasl.SaslException; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.security.PrivilegedExceptionAction; import java.util.Map; @@ -63,6 +64,7 @@ public class SaslClientHandler extends ChannelDuplexHandler { private final SaslExceptionHandler exceptionHandler; private final SaslSuccessfulConnectHandler successfulConnectHandler; private byte[] saslToken; + private byte[] connectionHeader; private boolean firstRead = true; private int retryCount = 0; @@ -80,10 +82,11 @@ public class SaslClientHandler extends ChannelDuplexHandler { */ public SaslClientHandler(UserGroupInformation ticket, AuthMethod method, Token token, String serverPrincipal, boolean fallbackAllowed, - String rpcProtection, SaslExceptionHandler exceptionHandler, + String rpcProtection, byte[] connectionHeader, SaslExceptionHandler exceptionHandler, SaslSuccessfulConnectHandler successfulConnectHandler) throws IOException { this.ticket = ticket; this.fallbackAllowed = fallbackAllowed; + this.connectionHeader = connectionHeader; this.exceptionHandler = exceptionHandler; this.successfulConnectHandler = successfulConnectHandler; @@ -225,8 +228,13 @@ public class SaslClientHandler extends ChannelDuplexHandler { if (!useWrap) { ctx.pipeline().remove(this); + successfulConnectHandler.onSuccess(ctx.channel()); + } else { + byte[] wrappedCH = saslClient.wrap(connectionHeader, 0, connectionHeader.length); + // write connection header + writeSaslToken(ctx, wrappedCH); + successfulConnectHandler.onSaslProtectionSucess(ctx.channel()); } - successfulConnectHandler.onSuccess(ctx.channel()); } } // Normal wrapped reading @@ -303,9 +311,19 @@ public class SaslClientHandler extends ChannelDuplexHandler { super.write(ctx, msg, promise); } else { ByteBuf in = (ByteBuf) msg; + int length = in.readInt(); + ByteBuffer b = ByteBuffer.allocate(4); + b.putInt(length); + + byte[] content = new byte[length]; + in.readBytes(content); + + byte[] unwrapped = new byte[length + 4]; + System.arraycopy(b.array(), 0, unwrapped, 0, 4); + System.arraycopy(content, 0, unwrapped, 4, content.length); try { - saslToken = saslClient.wrap(in.array(), in.readerIndex(), in.readableBytes()); + saslToken = saslClient.wrap(unwrapped, 0, unwrapped.length); } catch (SaslException se) { try { saslClient.dispose(); @@ -355,5 +373,12 @@ public class SaslClientHandler extends ChannelDuplexHandler { * @param channel which is successfully authenticated */ public void onSuccess(Channel channel); + + /** + * Runs on success if data protection used in Sasl + * + * @param channel which is successfully authenticated + */ + public void onSaslProtectionSucess(Channel channel); } } diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/security/AbstractTestSecureIPC.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/security/AbstractTestSecureIPC.java index 7e99cc0..385b7b0 100644 --- a/hbase-server/src/test/java/org/apache/hadoop/hbase/security/AbstractTestSecureIPC.java +++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/security/AbstractTestSecureIPC.java @@ -36,6 +36,7 @@ import java.util.concurrent.ThreadLocalRandom; import com.google.protobuf.RpcController; import com.google.protobuf.ServiceException; +import org.apache.commons.lang.RandomStringUtils; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.CommonConfigurationKeys; import org.apache.hadoop.hbase.Cell; @@ -217,6 +218,12 @@ public abstract class AbstractTestSecureIPC { setRpcProtection("integrity,authentication", "privacy,authentication"); callRpcService(User.create(ugi)); + + setRpcProtection("integrity,authentication", "integrity,authentication"); + callRpcService(User.create(ugi)); + + setRpcProtection("privacy,authentication", "privacy,authentication"); + callRpcService(User.create(ugi)); } @Test @@ -302,18 +309,17 @@ public abstract class AbstractTestSecureIPC { @Override public void run() { - String result; try { - result = stub.echo(null, TestProtos.EchoRequestProto.newBuilder().setMessage(String.valueOf( - ThreadLocalRandom.current().nextInt())).build()).getMessage(); - } catch (ServiceException e) { - throw new RuntimeException(e); - } - if (results != null) { - synchronized (results) { - results.add(result); - } + int[] messageSize = new int[] {100, 1000, 10000}; + for (int i = 0; i < messageSize.length; i++) { + String input = RandomStringUtils.random(messageSize[i]); + String result = stub.echo(null, TestProtos.EchoRequestProto.newBuilder() + .setMessage(input).build()).getMessage(); + assertEquals(input, result); } + } catch (ServiceException e) { + throw new RuntimeException(e); + } } } }