diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/BufferChain.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/BufferChain.java index 26bc56c..bd0515a 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/BufferChain.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/BufferChain.java @@ -32,11 +32,13 @@ class BufferChain { private final ByteBuffer[] buffers; private int remaining = 0; private int bufferOffset = 0; + private int size; BufferChain(ByteBuffer[] buffers) { for (ByteBuffer b : buffers) { this.remaining += b.remaining(); } + this.size = remaining; this.buffers = buffers; } @@ -108,4 +110,12 @@ class BufferChain { } } } + + int size() { + return size; + } + + ByteBuffer[] getBuffers() { + return this.buffers; + } } diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServer.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServer.java new file mode 100644 index 0000000..be55378 --- /dev/null +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcServer.java @@ -0,0 +1,540 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hbase.ipc; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.epoll.EpollServerSocketChannel; +import io.netty.channel.group.ChannelGroup; +import io.netty.channel.group.DefaultChannelGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import io.netty.util.concurrent.GlobalEventExecutor; + +import java.io.IOException; +import java.io.InterruptedIOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CountDownLatch; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hbase.CellScanner; +import org.apache.hadoop.hbase.HConstants; +import org.apache.hadoop.hbase.Server; +import org.apache.hadoop.hbase.classification.InterfaceStability; +import org.apache.hadoop.hbase.monitoring.MonitoredRPCHandler; +import org.apache.hadoop.hbase.nio.ByteBuff; +import org.apache.hadoop.hbase.nio.SingleByteBuff; +import org.apache.hadoop.hbase.security.AccessDeniedException; +import org.apache.hadoop.hbase.security.AuthMethod; +import org.apache.hadoop.hbase.security.HBasePolicyProvider; +import org.apache.hadoop.hbase.security.SaslStatus; +import org.apache.hadoop.hbase.security.SaslUtil; +import org.apache.hadoop.hbase.shaded.com.google.protobuf.BlockingService; +import org.apache.hadoop.hbase.shaded.com.google.protobuf.Descriptors.MethodDescriptor; +import org.apache.hadoop.hbase.shaded.com.google.protobuf.Message; +import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.RequestHeader; +import org.apache.hadoop.hbase.util.Bytes; +import org.apache.hadoop.hbase.util.JVM; +import org.apache.hadoop.hbase.util.Pair; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.security.authorize.ServiceAuthorizationManager; +import org.apache.htrace.TraceInfo; + +/** + * An RPC server with Netty4 implementation. + * + */ +public class NettyRpcServer extends RpcServer { + + public static final Log LOG = LogFactory.getLog(NettyRpcServer.class); + + protected final InetSocketAddress bindAddress; + + private final CountDownLatch closed = new CountDownLatch(1); + private final Channel serverChannel; + private final ChannelGroup allChannels = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);; + + public NettyRpcServer(final Server server, final String name, + final List services, + final InetSocketAddress bindAddress, Configuration conf, + RpcScheduler scheduler) throws IOException { + super(server, name, services, bindAddress, conf, scheduler); + this.bindAddress = bindAddress; + boolean useEpoll = useEpoll(conf); + int workerCount = conf.getInt("hbase.netty.rpc.server.worker.count", + Runtime.getRuntime().availableProcessors() / 4); + EventLoopGroup bossGroup = null; + EventLoopGroup workerGroup = null; + if (useEpoll) { + bossGroup = new EpollEventLoopGroup(1); + workerGroup = new EpollEventLoopGroup(workerCount); + } else { + bossGroup = new NioEventLoopGroup(1); + workerGroup = new NioEventLoopGroup(workerCount); + } + ServerBootstrap bootstrap = new ServerBootstrap(); + bootstrap.group(bossGroup, workerGroup); + if (useEpoll) { + bootstrap.channel(EpollServerSocketChannel.class); + } else { + bootstrap.channel(NioServerSocketChannel.class); + } + bootstrap.childOption(ChannelOption.TCP_NODELAY, tcpNoDelay); + bootstrap.childOption(ChannelOption.SO_KEEPALIVE, tcpKeepAlive); + bootstrap.childOption(ChannelOption.ALLOCATOR, + PooledByteBufAllocator.DEFAULT); + bootstrap.childHandler(new Initializer(maxRequestSize)); + + try { + serverChannel = bootstrap.bind(this.bindAddress).sync().channel(); + LOG.info("NettyRpcServer bind to address=" + serverChannel.localAddress() + + ", hbase.netty.rpc.server.worker.count=" + workerCount + + ", useEpoll=" + useEpoll); + allChannels.add(serverChannel); + } catch (InterruptedException e) { + throw new InterruptedIOException(e.getMessage()); + } + initReconfigurable(conf); + this.scheduler.init(new RpcSchedulerContext(this)); + } + + private static boolean useEpoll(Configuration conf) { + // Config to enable native transport. + boolean epollEnabled = conf.getBoolean("hbase.rpc.server.nativetransport", + true); + // Use the faster native epoll transport mechanism on linux if enabled + return epollEnabled && JVM.isLinux() && JVM.isAmd64(); + } + + @Override + public synchronized void start() { + if (started) { + return; + } + authTokenSecretMgr = createSecretManager(); + if (authTokenSecretMgr != null) { + setSecretManager(authTokenSecretMgr); + authTokenSecretMgr.start(); + } + this.authManager = new ServiceAuthorizationManager(); + HBasePolicyProvider.init(conf, authManager); + scheduler.start(); + started = true; + } + + @Override + public synchronized void stop() { + if (!running) { + return; + } + LOG.info("Stopping server on " + this.bindAddress.getPort()); + if (authTokenSecretMgr != null) { + authTokenSecretMgr.stop(); + authTokenSecretMgr = null; + } + allChannels.close().awaitUninterruptibly(); + serverChannel.close(); + scheduler.stop(); + closed.countDown(); + running = false; + } + + @Override + public synchronized void join() throws InterruptedException { + closed.await(); + } + + @Override + public synchronized InetSocketAddress getListenerAddress() { + return ((InetSocketAddress) serverChannel.localAddress()); + } + + public class NettyConnection extends RpcServer.Connection { + + protected Channel channel; + + NettyConnection(Channel channel) { + super(); + this.channel = channel; + InetSocketAddress inetSocketAddress = ((InetSocketAddress) channel.remoteAddress()); + this.addr = inetSocketAddress.getAddress(); + if (addr == null) { + this.hostAddress = "*Unknown*"; + } else { + this.hostAddress = inetSocketAddress.getAddress().getHostAddress(); + } + this.remotePort = inetSocketAddress.getPort(); + this.saslCall = new Call(SASL_CALLID, null, null, null, null, null, this, + 0, null, null, 0, null); + this.setConnectionHeaderResponseCall = new Call( + CONNECTION_HEADER_RESPONSE_CALLID, null, null, null, null, null, + this, 0, null, null, 0, null); + this.authFailedCall = new Call(AUTHORIZATION_FAILED_CALLID, null, null, + null, null, null, this, 0, null, null, 0, null); + } + + void readPreamble(ByteBuf buffer) throws IOException { + byte[] rpcHead = + { buffer.readByte(), buffer.readByte(), buffer.readByte(), buffer.readByte() }; + if (!Arrays.equals(HConstants.RPC_HEADER, rpcHead)) { + doBadPreambleHandling("Expected HEADER=" + + Bytes.toStringBinary(HConstants.RPC_HEADER) + " but received HEADER=" + + Bytes.toStringBinary(rpcHead) + " from " + toString()); + return; + } + // Now read the next two bytes, the version and the auth to use. + int version = buffer.readByte(); + byte authbyte = buffer.readByte(); + this.authMethod = AuthMethod.valueOf(authbyte); + if (version != CURRENT_VERSION) { + String msg = getFatalConnectionString(version, authbyte); + doBadPreambleHandling(msg, new WrongVersionException(msg)); + return; + } + if (authMethod == null) { + String msg = getFatalConnectionString(version, authbyte); + doBadPreambleHandling(msg, new BadAuthException(msg)); + return; + } + if (isSecurityEnabled && authMethod == AuthMethod.SIMPLE) { + if (allowFallbackToSimpleAuth) { + metrics.authenticationFallback(); + authenticatedWithFallback = true; + } else { + AccessDeniedException ae = new AccessDeniedException( + "Authentication is required"); + setupResponse(authFailedResponse, authFailedCall, ae, ae.getMessage()); + ((Call) authFailedCall) + .sendResponseIfReady(ChannelFutureListener.CLOSE); + return; + } + } + if (!isSecurityEnabled && authMethod != AuthMethod.SIMPLE) { + doRawSaslReply(SaslStatus.SUCCESS, new IntWritable(SaslUtil.SWITCH_TO_SIMPLE_AUTH), null, + null); + authMethod = AuthMethod.SIMPLE; + // client has already sent the initial Sasl message and we + // should ignore it. Both client and server should fall back + // to simple auth from now on. + skipInitialSaslHandshake = true; + } + if (authMethod != AuthMethod.SIMPLE) { + useSasl = true; + } + connectionPreambleRead = true; + } + + private void doBadPreambleHandling(final String msg) throws IOException { + doBadPreambleHandling(msg, new FatalConnectionException(msg)); + } + + private void doBadPreambleHandling(final String msg, final Exception e) throws IOException { + LOG.warn(msg); + Call fakeCall = new Call(-1, null, null, null, null, null, this, -1, + null, null, 0, null); + setupResponse(null, fakeCall, e, msg); + // closes out the connection. + fakeCall.sendResponseIfReady(ChannelFutureListener.CLOSE); + } + + void process(final ByteBuf buf) throws IOException, InterruptedException { + if (connectionHeaderRead) { + this.callCleanup = new RpcServer.CallCleanup() { + @Override + public void run() { + buf.release(); + } + }; + process(new SingleByteBuff(buf.nioBuffer())); + } else { + byte[] data = new byte[buf.readableBytes()]; + buf.readBytes(data, 0, data.length); + ByteBuffer connectionHeader = ByteBuffer.wrap(data); + buf.release(); + process(connectionHeader); + } + } + + void process(ByteBuffer buf) throws IOException, InterruptedException { + process(new SingleByteBuff(buf)); + } + + void process(ByteBuff buf) throws IOException, InterruptedException { + try { + if (skipInitialSaslHandshake) { + skipInitialSaslHandshake = false; + if (callCleanup != null) { + callCleanup.run(); + } + return; + } + + if (useSasl) { + saslReadAndProcess(buf); + } else { + processOneRpc(buf); + } + } catch (Exception e) { + if (callCleanup != null) { + callCleanup.run(); + } + throw e; + } finally { + this.callCleanup = null; + } + } + + @Override + public synchronized void close() { + disposeSasl(); + channel.close(); + callCleanup = null; + } + + @Override + public boolean isConnectionOpen() { + return channel.isOpen(); + } + + @Override + public RpcServer.Call createCall(int id, final BlockingService service, + final MethodDescriptor md, RequestHeader header, Message param, + CellScanner cellScanner, RpcServer.Connection connection, long size, + TraceInfo tinfo, final InetAddress remoteAddress, int timeout, + CallCleanup reqCleanup) { + return new Call(id, service, md, header, param, cellScanner, connection, + size, tinfo, remoteAddress, timeout, reqCleanup); + } + } + + /** + * Datastructure that holds all necessary to a method invocation and then afterward, carries the + * result. + */ + @InterfaceStability.Evolving + public class Call extends RpcServer.Call { + + Call(int id, final BlockingService service, final MethodDescriptor md, + RequestHeader header, Message param, CellScanner cellScanner, + RpcServer.Connection connection, long size, TraceInfo tinfo, + final InetAddress remoteAddress, int timeout, CallCleanup reqCleanup) { + super(id, service, md, header, param, cellScanner, + connection, size, tinfo, remoteAddress, timeout, reqCleanup); + } + + @Override + public long disconnectSince() { + if (!getConnection().isConnectionOpen()) { + return System.currentTimeMillis() - timestamp; + } else { + return -1L; + } + } + + NettyConnection getConnection() { + return (NettyConnection) this.connection; + } + + /** + * If we have a response, and delay is not set, then respond immediately. Otherwise, do not + * respond to client. This is called by the RPC code in the context of the Handler thread. + */ + @Override + public synchronized void sendResponseIfReady() throws IOException { + getConnection().channel.writeAndFlush(this); + } + + public synchronized void sendResponseIfReady(ChannelFutureListener listener) throws IOException { + getConnection().channel.writeAndFlush(this).addListener(listener); + } + + } + + private class Initializer extends ChannelInitializer { + + final int maxRequestSize; + + Initializer(int maxRequestSize) { + this.maxRequestSize = maxRequestSize; + } + + @Override + protected void initChannel(SocketChannel channel) throws Exception { + ChannelPipeline pipeline = channel.pipeline(); + pipeline.addLast("header", new ConnectionHeaderHandler()); + pipeline.addLast("frameDecoder", new LengthFieldBasedFrameDecoder( + maxRequestSize, 0, 4, 0, 4, true)); + pipeline.addLast("decoder", new MessageDecoder()); + pipeline.addLast("encoder", new MessageEncoder()); + } + + } + + private class ConnectionHeaderHandler extends ByteToMessageDecoder { + private NettyConnection connection; + + ConnectionHeaderHandler() { + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf byteBuf, + List out) throws Exception { + if (byteBuf.readableBytes() < 6) { + return; + } + connection = new NettyConnection(ctx.channel()); + connection.readPreamble(byteBuf); + ((MessageDecoder) ctx.pipeline().get("decoder")) + .setConnection(connection); + ctx.pipeline().remove(this); + } + + } + + private class MessageDecoder extends ChannelInboundHandlerAdapter { + + private NettyConnection connection; + + void setConnection(NettyConnection connection) { + this.connection = connection; + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + allChannels.add(ctx.channel()); + if (LOG.isDebugEnabled()) { + LOG.debug("Connection from " + ctx.channel().remoteAddress() + + "; # active connections: " + getNumOpenConnections()); + } + super.channelActive(ctx); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) + throws Exception { + ByteBuf input = (ByteBuf) msg; + // 4 bytes length field + metrics.receivedBytes(input.readableBytes() + 4); + connection.process(input); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + allChannels.remove(ctx.channel()); + if (LOG.isDebugEnabled()) { + LOG.debug("Disconnecting client: " + ctx.channel().remoteAddress() + + ". Number of active connections: " + getNumOpenConnections()); + } + super.channelInactive(ctx); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable e) { + allChannels.remove(ctx.channel()); + if (LOG.isDebugEnabled()) { + LOG.debug("Connection from " + ctx.channel().remoteAddress() + + " catch unexpected exception from downstream.", e.getCause()); + } + ctx.channel().close(); + } + + } + + private class MessageEncoder extends ChannelOutboundHandlerAdapter { + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + final Call call = (Call) msg; + ByteBuf response = Unpooled.wrappedBuffer(call.response.getBuffers()); + ctx.write(response, promise).addListener(new CallWriteListener(call)); + } + + } + + private class CallWriteListener implements ChannelFutureListener { + + private Call call; + + CallWriteListener(Call call) { + this.call = call; + } + + @Override + public void operationComplete(ChannelFuture future) throws Exception { + call.done(); + if (future.isSuccess()) { + metrics.sentBytes(call.response.size()); + } + } + + } + + @Override + public void setSocketSendBufSize(int size) { + } + + @Override + public int getNumOpenConnections() { + // allChannels also contains the server channel, so exclude that from the count. + return allChannels.size() - 1; + } + + @Override + public Pair call(BlockingService service, + MethodDescriptor md, Message param, CellScanner cellScanner, + long receiveTime, MonitoredRPCHandler status) throws IOException { + return call(service, md, param, cellScanner, receiveTime, status, + System.currentTimeMillis(), 0); + } + + @Override + public Pair call(BlockingService service, + MethodDescriptor md, Message param, CellScanner cellScanner, + long receiveTime, MonitoredRPCHandler status, long startTime, int timeout) + throws IOException { + Call fakeCall = new Call(-1, service, md, null, param, cellScanner, null, + -1, null, null, timeout, null); + fakeCall.setReceiveTime(receiveTime); + return call(fakeCall, status); + } + +} diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java index d6a137b..f8d3e09 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/RpcServer.java @@ -20,12 +20,20 @@ package org.apache.hadoop.hbase.ipc; import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHORIZATION; +import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; +import java.io.Closeable; +import java.io.DataOutputStream; import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; import java.security.GeneralSecurityException; +import java.security.PrivilegedExceptionAction; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -34,6 +42,7 @@ import java.util.Properties; import java.util.concurrent.atomic.LongAdder; import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; import javax.security.sasl.SaslServer; import org.apache.commons.crypto.cipher.CryptoCipherFactory; @@ -50,11 +59,13 @@ import org.apache.hadoop.hbase.HConstants; import org.apache.hadoop.hbase.Server; import org.apache.hadoop.hbase.classification.InterfaceAudience; import org.apache.hadoop.hbase.classification.InterfaceStability; +import org.apache.hadoop.hbase.client.VersionInfoUtil; import org.apache.hadoop.hbase.codec.Codec; import org.apache.hadoop.hbase.conf.ConfigurationObserver; import org.apache.hadoop.hbase.exceptions.RegionMovedException; import org.apache.hadoop.hbase.exceptions.RequestTooBigException; import org.apache.hadoop.hbase.io.ByteBufferListOutputStream; +import org.apache.hadoop.hbase.io.ByteBufferOutputStream; import org.apache.hadoop.hbase.io.ByteBufferPool; import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES; import org.apache.hadoop.hbase.monitoring.MonitoredRPCHandler; @@ -66,6 +77,9 @@ import org.apache.hadoop.hbase.regionserver.RSRpcServices; import org.apache.hadoop.hbase.security.AccessDeniedException; import org.apache.hadoop.hbase.security.AuthMethod; import org.apache.hadoop.hbase.security.HBaseSaslRpcServer; +import org.apache.hadoop.hbase.security.HBaseSaslRpcServer.SaslDigestCallbackHandler; +import org.apache.hadoop.hbase.security.HBaseSaslRpcServer.SaslGssCallbackHandler; +import org.apache.hadoop.hbase.security.SaslStatus; import org.apache.hadoop.hbase.security.SaslUtil; import org.apache.hadoop.hbase.security.User; import org.apache.hadoop.hbase.security.UserProvider; @@ -73,11 +87,13 @@ import org.apache.hadoop.hbase.security.token.AuthenticationTokenSecretManager; import org.apache.hadoop.hbase.shaded.com.google.protobuf.BlockingService; import org.apache.hadoop.hbase.shaded.com.google.protobuf.ByteInput; import org.apache.hadoop.hbase.shaded.com.google.protobuf.ByteString; +import org.apache.hadoop.hbase.shaded.com.google.protobuf.CodedInputStream; import org.apache.hadoop.hbase.shaded.com.google.protobuf.CodedOutputStream; import org.apache.hadoop.hbase.shaded.com.google.protobuf.Descriptors.MethodDescriptor; import org.apache.hadoop.hbase.shaded.com.google.protobuf.Message; import org.apache.hadoop.hbase.shaded.com.google.protobuf.ServiceException; import org.apache.hadoop.hbase.shaded.com.google.protobuf.TextFormat; +import org.apache.hadoop.hbase.shaded.com.google.protobuf.UnsafeByteOperations; import org.apache.hadoop.hbase.shaded.protobuf.ProtobufUtil; import org.apache.hadoop.hbase.shaded.protobuf.generated.ClientProtos; import org.apache.hadoop.hbase.shaded.protobuf.generated.HBaseProtos.VersionInfo; @@ -91,12 +107,18 @@ import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.UserInformati import org.apache.hadoop.hbase.util.ByteBufferUtils; import org.apache.hadoop.hbase.util.Bytes; import org.apache.hadoop.hbase.util.Pair; +import org.apache.hadoop.io.BytesWritable; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.WritableUtils; import org.apache.hadoop.io.compress.CompressionCodec; import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod; import org.apache.hadoop.security.authorize.AuthorizationException; import org.apache.hadoop.security.authorize.PolicyProvider; +import org.apache.hadoop.security.authorize.ProxyUsers; import org.apache.hadoop.security.authorize.ServiceAuthorizationManager; import org.apache.hadoop.security.token.SecretManager; +import org.apache.hadoop.security.token.SecretManager.InvalidToken; import org.apache.hadoop.security.token.TokenIdentifier; import org.apache.hadoop.util.StringUtils; import org.apache.htrace.TraceInfo; @@ -459,7 +481,7 @@ public abstract class RpcServer implements RpcServerInterface, } } - private void setExceptionResponse(Throwable t, String errorMsg, + protected void setExceptionResponse(Throwable t, String errorMsg, ResponseHeader.Builder headerBuilder) { ExceptionResponse.Builder exceptionBuilder = ExceptionResponse.newBuilder(); exceptionBuilder.setExceptionClassName(t.getClass().getName()); @@ -477,7 +499,7 @@ public abstract class RpcServer implements RpcServerInterface, headerBuilder.setException(exceptionBuilder.build()); } - private ByteBuffer createHeaderAndMessageBytes(Message result, Message header, + protected ByteBuffer createHeaderAndMessageBytes(Message result, Message header, int cellBlockSize, List cellBlock) throws IOException { // Organize the response as a set of bytebuffers rather than collect it all together inside // one big byte array; save on allocations. @@ -550,7 +572,7 @@ public abstract class RpcServer implements RpcServerInterface, return pbBuf; } - private BufferChain wrapWithSasl(BufferChain bc) + protected BufferChain wrapWithSasl(BufferChain bc) throws IOException { if (!this.connection.useSasl) return bc; // Looks like no way around this; saslserver wants a byte array. I have to make it one. @@ -712,7 +734,7 @@ public abstract class RpcServer implements RpcServerInterface, @edu.umd.cs.findbugs.annotations.SuppressWarnings( value="VO_VOLATILE_INCREMENT", justification="False positive according to http://sourceforge.net/p/findbugs/bugs/1032/") - public abstract class Connection { + public abstract class Connection implements Closeable { // If initial preamble with version and magic has been read or not. protected boolean connectionPreambleRead = false; // If the connection header has been read or not. @@ -740,7 +762,9 @@ public abstract class RpcServer implements RpcServerInterface, protected AuthMethod authMethod; protected boolean saslContextEstablished; protected boolean skipInitialSaslHandshake; - + private ByteBuffer unwrappedData; + // When is this set? FindBugs wants to know! Says NP + private ByteBuffer unwrappedDataLengthBuffer = ByteBuffer.allocate(4); protected boolean useSasl; protected SaslServer saslServer; protected CryptoAES cryptoAES; @@ -748,14 +772,15 @@ public abstract class RpcServer implements RpcServerInterface, protected boolean useCryptoAesWrap = false; // Fake 'call' for failed authorization response protected static final int AUTHORIZATION_FAILED_CALLID = -1; - + protected Call authFailedCall; protected ByteArrayOutputStream authFailedResponse = new ByteArrayOutputStream(); // Fake 'call' for SASL context setup protected static final int SASL_CALLID = -33; - + protected Call saslCall; // Fake 'call' for connection header response protected static final int CONNECTION_HEADER_RESPONSE_CALLID = -34; + protected Call setConnectionHeaderResponseCall; // was authentication allowed with a fallback to simple auth protected boolean authenticatedWithFallback; @@ -946,8 +971,496 @@ public abstract class RpcServer implements RpcServerInterface, return ugi; } + protected void disposeSasl() { + if (saslServer != null) { + try { + saslServer.dispose(); + saslServer = null; + } catch (SaslException ignored) { + // Ignored. This is being disposed of anyway. + } + } + } + + /** + * No protobuf encoding of raw sasl messages + */ + protected void doRawSaslReply(SaslStatus status, Writable rv, + String errorClass, String error) throws IOException { + ByteBufferOutputStream saslResponse = null; + DataOutputStream out = null; + try { + // In my testing, have noticed that sasl messages are usually + // in the ballpark of 100-200. That's why the initial capacity is 256. + saslResponse = new ByteBufferOutputStream(256); + out = new DataOutputStream(saslResponse); + out.writeInt(status.state); // write status + if (status == SaslStatus.SUCCESS) { + rv.write(out); + } else { + WritableUtils.writeString(out, errorClass); + WritableUtils.writeString(out, error); + } + saslCall.setSaslTokenResponse(saslResponse.getByteBuffer()); + saslCall.sendResponseIfReady(); + } finally { + if (saslResponse != null) { + saslResponse.close(); + } + if (out != null) { + out.close(); + } + } + } + + public void saslReadAndProcess(ByteBuff saslToken) throws IOException, + InterruptedException { + if (saslContextEstablished) { + if (LOG.isTraceEnabled()) + LOG.trace("Have read input token of size " + saslToken.limit() + + " for processing by saslServer.unwrap()"); + + if (!useWrap) { + processOneRpc(saslToken); + } else { + byte[] b = saslToken.hasArray() ? saslToken.array() : saslToken.toBytes(); + byte [] plaintextData; + if (useCryptoAesWrap) { + // unwrap with CryptoAES + plaintextData = cryptoAES.unwrap(b, 0, b.length); + } else { + plaintextData = saslServer.unwrap(b, 0, b.length); + } + processUnwrappedData(plaintextData); + } + } else { + byte[] replyToken; + try { + if (saslServer == null) { + switch (authMethod) { + case DIGEST: + if (secretManager == null) { + throw new AccessDeniedException( + "Server is not configured to do DIGEST authentication."); + } + saslServer = Sasl.createSaslServer(AuthMethod.DIGEST + .getMechanismName(), null, SaslUtil.SASL_DEFAULT_REALM, + HBaseSaslRpcServer.getSaslProps(), new SaslDigestCallbackHandler( + secretManager, this)); + break; + default: + UserGroupInformation current = UserGroupInformation.getCurrentUser(); + String fullName = current.getUserName(); + if (LOG.isDebugEnabled()) { + LOG.debug("Kerberos principal name is " + fullName); + } + final String names[] = SaslUtil.splitKerberosName(fullName); + if (names.length != 3) { + throw new AccessDeniedException( + "Kerberos principal name does NOT have the expected " + + "hostname part: " + fullName); + } + current.doAs(new PrivilegedExceptionAction() { + @Override + public Object run() throws SaslException { + saslServer = Sasl.createSaslServer(AuthMethod.KERBEROS + .getMechanismName(), names[0], names[1], + HBaseSaslRpcServer.getSaslProps(), new SaslGssCallbackHandler()); + return null; + } + }); + } + if (saslServer == null) + throw new AccessDeniedException( + "Unable to find SASL server implementation for " + + authMethod.getMechanismName()); + if (LOG.isDebugEnabled()) { + LOG.debug("Created SASL server with mechanism = " + authMethod.getMechanismName()); + } + } + if (LOG.isDebugEnabled()) { + LOG.debug("Have read input token of size " + saslToken.limit() + + " for processing by saslServer.evaluateResponse()"); + } + replyToken = saslServer + .evaluateResponse(saslToken.hasArray() ? saslToken.array() : saslToken.toBytes()); + } catch (IOException e) { + IOException sendToClient = e; + Throwable cause = e; + while (cause != null) { + if (cause instanceof InvalidToken) { + sendToClient = (InvalidToken) cause; + break; + } + cause = cause.getCause(); + } + doRawSaslReply(SaslStatus.ERROR, null, sendToClient.getClass().getName(), + sendToClient.getLocalizedMessage()); + metrics.authenticationFailure(); + String clientIP = this.toString(); + // attempting user could be null + AUDITLOG.warn(AUTH_FAILED_FOR + clientIP + ":" + attemptingUser); + throw e; + } + if (replyToken != null) { + if (LOG.isDebugEnabled()) { + LOG.debug("Will send token of size " + replyToken.length + + " from saslServer."); + } + doRawSaslReply(SaslStatus.SUCCESS, new BytesWritable(replyToken), null, + null); + } + if (saslServer.isComplete()) { + String qop = (String) saslServer.getNegotiatedProperty(Sasl.QOP); + useWrap = qop != null && !"auth".equalsIgnoreCase(qop); + ugi = getAuthorizedUgi(saslServer.getAuthorizationID()); + if (LOG.isDebugEnabled()) { + LOG.debug("SASL server context established. Authenticated client: " + + ugi + ". Negotiated QoP is " + + saslServer.getNegotiatedProperty(Sasl.QOP)); + } + metrics.authenticationSuccess(); + AUDITLOG.info(AUTH_SUCCESSFUL_FOR + ugi); + saslContextEstablished = true; + } + } + } + + private void processUnwrappedData(byte[] inBuf) throws IOException, + InterruptedException { + ReadableByteChannel ch = Channels.newChannel(new ByteArrayInputStream(inBuf)); + // Read all RPCs contained in the inBuf, even partial ones + while (true) { + int count; + if (unwrappedDataLengthBuffer.remaining() > 0) { + count = channelRead(ch, unwrappedDataLengthBuffer); + if (count <= 0 || unwrappedDataLengthBuffer.remaining() > 0) + return; + } + + if (unwrappedData == null) { + unwrappedDataLengthBuffer.flip(); + int unwrappedDataLength = unwrappedDataLengthBuffer.getInt(); + + if (unwrappedDataLength == RpcClient.PING_CALL_ID) { + if (LOG.isDebugEnabled()) + LOG.debug("Received ping message"); + unwrappedDataLengthBuffer.clear(); + continue; // ping message + } + unwrappedData = ByteBuffer.allocate(unwrappedDataLength); + } + + count = channelRead(ch, unwrappedData); + if (count <= 0 || unwrappedData.remaining() > 0) + return; + + if (unwrappedData.remaining() == 0) { + unwrappedDataLengthBuffer.clear(); + unwrappedData.flip(); + processOneRpc(new SingleByteBuff(unwrappedData)); + unwrappedData = null; + } + } + } + + public void processOneRpc(ByteBuff buf) throws IOException, + InterruptedException { + if (connectionHeaderRead) { + processRequest(buf); + } else { + processConnectionHeader(buf); + this.connectionHeaderRead = true; + if (!authorizeConnection()) { + // Throw FatalConnectionException wrapping ACE so client does right thing and closes + // down the connection instead of trying to read non-existent retun. + throw new AccessDeniedException("Connection from " + this + " for service " + + connectionHeader.getServiceName() + " is unauthorized for user: " + ugi); + } + this.user = userProvider.create(this.ugi); + } + } + + protected boolean authorizeConnection() throws IOException { + try { + // If auth method is DIGEST, the token was obtained by the + // real user for the effective user, therefore not required to + // authorize real user. doAs is allowed only for simple or kerberos + // authentication + if (ugi != null && ugi.getRealUser() != null + && (authMethod != AuthMethod.DIGEST)) { + ProxyUsers.authorize(ugi, this.getHostAddress(), conf); + } + authorize(ugi, connectionHeader, getHostInetAddress()); + metrics.authorizationSuccess(); + } catch (AuthorizationException ae) { + if (LOG.isDebugEnabled()) { + LOG.debug("Connection authorization failed: " + ae.getMessage(), ae); + } + metrics.authorizationFailure(); + setupResponse(authFailedResponse, authFailedCall, + new AccessDeniedException(ae), ae.getMessage()); + authFailedCall.sendResponseIfReady(); + return false; + } + return true; + } + + // Reads the connection header following version + protected void processConnectionHeader(ByteBuff buf) throws IOException { + if (buf.hasArray()) { + this.connectionHeader = ConnectionHeader.parseFrom(buf.array()); + } else { + CodedInputStream cis = UnsafeByteOperations + .unsafeWrap(new ByteBuffByteInput(buf, 0, buf.limit()), 0, buf.limit()).newCodedInput(); + cis.enableAliasing(true); + this.connectionHeader = ConnectionHeader.parseFrom(cis); + } + String serviceName = connectionHeader.getServiceName(); + if (serviceName == null) throw new EmptyServiceNameException(); + this.service = getService(services, serviceName); + if (this.service == null) throw new UnknownServiceException(serviceName); + setupCellBlockCodecs(this.connectionHeader); + RPCProtos.ConnectionHeaderResponse.Builder chrBuilder = + RPCProtos.ConnectionHeaderResponse.newBuilder(); + setupCryptoCipher(this.connectionHeader, chrBuilder); + responseConnectionHeader(chrBuilder); + UserGroupInformation protocolUser = createUser(connectionHeader); + if (!useSasl) { + ugi = protocolUser; + if (ugi != null) { + ugi.setAuthenticationMethod(AuthMethod.SIMPLE.authenticationMethod); + } + // audit logging for SASL authenticated users happens in saslReadAndProcess() + if (authenticatedWithFallback) { + LOG.warn("Allowed fallback to SIMPLE auth for " + ugi + + " connecting from " + getHostAddress()); + } + AUDITLOG.info(AUTH_SUCCESSFUL_FOR + ugi); + } else { + // user is authenticated + ugi.setAuthenticationMethod(authMethod.authenticationMethod); + //Now we check if this is a proxy user case. If the protocol user is + //different from the 'user', it is a proxy user scenario. However, + //this is not allowed if user authenticated with DIGEST. + if ((protocolUser != null) + && (!protocolUser.getUserName().equals(ugi.getUserName()))) { + if (authMethod == AuthMethod.DIGEST) { + // Not allowed to doAs if token authentication is used + throw new AccessDeniedException("Authenticated user (" + ugi + + ") doesn't match what the client claims to be (" + + protocolUser + ")"); + } else { + // Effective user can be different from authenticated user + // for simple auth or kerberos auth + // The user is the real user. Now we create a proxy user + UserGroupInformation realUser = ugi; + ugi = UserGroupInformation.createProxyUser(protocolUser + .getUserName(), realUser); + // Now the user is a proxy user, set Authentication method Proxy. + ugi.setAuthenticationMethod(AuthenticationMethod.PROXY); + } + } + } + if (connectionHeader.hasVersionInfo()) { + // see if this connection will support RetryImmediatelyException + retryImmediatelySupported = VersionInfoUtil.hasMinimumVersion(getVersionInfo(), 1, 2); + + AUDITLOG.info("Connection from " + this.hostAddress + " port: " + this.remotePort + + " with version info: " + + TextFormat.shortDebugString(connectionHeader.getVersionInfo())); + } else { + AUDITLOG.info("Connection from " + this.hostAddress + " port: " + this.remotePort + + " with unknown version info"); + } + } + + private void responseConnectionHeader(RPCProtos.ConnectionHeaderResponse.Builder chrBuilder) + throws FatalConnectionException { + // Response the connection header if Crypto AES is enabled + if (!chrBuilder.hasCryptoCipherMeta()) return; + try { + byte[] connectionHeaderResBytes = chrBuilder.build().toByteArray(); + // encrypt the Crypto AES cipher meta data with sasl server, and send to client + byte[] unwrapped = new byte[connectionHeaderResBytes.length + 4]; + Bytes.putBytes(unwrapped, 0, Bytes.toBytes(connectionHeaderResBytes.length), 0, 4); + Bytes.putBytes(unwrapped, 4, connectionHeaderResBytes, 0, connectionHeaderResBytes.length); + + doConnectionHeaderResponse(saslServer.wrap(unwrapped, 0, unwrapped.length)); + } catch (IOException ex) { + throw new UnsupportedCryptoException(ex.getMessage(), ex); + } + } + + /** + * Send the response for connection header + */ + private void doConnectionHeaderResponse(byte[] wrappedCipherMetaData) + throws IOException { + ByteBufferOutputStream response = null; + DataOutputStream out = null; + try { + response = new ByteBufferOutputStream(wrappedCipherMetaData.length + 4); + out = new DataOutputStream(response); + out.writeInt(wrappedCipherMetaData.length); + out.write(wrappedCipherMetaData); + + setConnectionHeaderResponseCall.setConnectionHeaderResponse(response + .getByteBuffer()); + setConnectionHeaderResponseCall.sendResponseIfReady(); + } finally { + if (out != null) { + out.close(); + } + if (response != null) { + response.close(); + } + } + } + + /** + * @param buf + * Has the request header and the request param and optionally + * encoded data buffer all in this one array. + * @throws IOException + * @throws InterruptedException + */ + protected void processRequest(ByteBuff buf) throws IOException, + InterruptedException { + long totalRequestSize = buf.limit(); + int offset = 0; + // Here we read in the header. We avoid having pb + // do its default 4k allocation for CodedInputStream. We force it to use + // backing array. + CodedInputStream cis; + if (buf.hasArray()) { + cis = UnsafeByteOperations.unsafeWrap(buf.array(), 0, buf.limit()) + .newCodedInput(); + } else { + cis = UnsafeByteOperations.unsafeWrap( + new ByteBuffByteInput(buf, 0, buf.limit()), 0, buf.limit()) + .newCodedInput(); + } + cis.enableAliasing(true); + int headerSize = cis.readRawVarint32(); + offset = cis.getTotalBytesRead(); + Message.Builder builder = RequestHeader.newBuilder(); + ProtobufUtil.mergeFrom(builder, cis, headerSize); + RequestHeader header = (RequestHeader) builder.build(); + offset += headerSize; + int id = header.getCallId(); + if (LOG.isTraceEnabled()) { + LOG.trace("RequestHeader " + TextFormat.shortDebugString(header) + + " totalRequestSize: " + totalRequestSize + " bytes"); + } + // Enforcing the call queue size, this triggers a retry in the client + // This is a bit late to be doing this check - we have already read in the + // total request. + if ((totalRequestSize + callQueueSizeInBytes.sum()) > maxQueueSizeInBytes) { + final RpcServer.Call callTooBig = createCall(id, this.service, null, + null, null, null, this, totalRequestSize, null, null, 0, + this.callCleanup); + ByteArrayOutputStream responseBuffer = new ByteArrayOutputStream(); + metrics.exception(CALL_QUEUE_TOO_BIG_EXCEPTION); + setupResponse(responseBuffer, callTooBig, CALL_QUEUE_TOO_BIG_EXCEPTION, + "Call queue is full on " + server.getServerName() + + ", is hbase.ipc.server.max.callqueue.size too small?"); + callTooBig.sendResponseIfReady(); + return; + } + MethodDescriptor md = null; + Message param = null; + CellScanner cellScanner = null; + try { + if (header.hasRequestParam() && header.getRequestParam()) { + md = this.service.getDescriptorForType().findMethodByName( + header.getMethodName()); + if (md == null) + throw new UnsupportedOperationException(header.getMethodName()); + builder = this.service.getRequestPrototype(md).newBuilderForType(); + cis.resetSizeCounter(); + int paramSize = cis.readRawVarint32(); + offset += cis.getTotalBytesRead(); + if (builder != null) { + ProtobufUtil.mergeFrom(builder, cis, paramSize); + param = builder.build(); + } + offset += paramSize; + } else { + // currently header must have request param, so we directly throw + // exception here + String msg = "Invalid request header: " + + TextFormat.shortDebugString(header) + + ", should have param set in it"; + LOG.warn(msg); + throw new DoNotRetryIOException(msg); + } + if (header.hasCellBlockMeta()) { + buf.position(offset); + ByteBuff dup = buf.duplicate(); + dup.limit(offset + header.getCellBlockMeta().getLength()); + cellScanner = cellBlockBuilder.createCellScannerReusingBuffers( + this.codec, this.compressionCodec, dup); + } + } catch (Throwable t) { + InetSocketAddress address = getListenerAddress(); + String msg = (address != null ? address : "(channel closed)") + + " is unable to read call parameter from client " + + getHostAddress(); + LOG.warn(msg, t); + + metrics.exception(t); + + // probably the hbase hadoop version does not match the running hadoop + // version + if (t instanceof LinkageError) { + t = new DoNotRetryIOException(t); + } + // If the method is not present on the server, do not retry. + if (t instanceof UnsupportedOperationException) { + t = new DoNotRetryIOException(t); + } + + final RpcServer.Call readParamsFailedCall = createCall(id, + this.service, null, null, null, null, this, totalRequestSize, null, + null, 0, this.callCleanup); + ByteArrayOutputStream responseBuffer = new ByteArrayOutputStream(); + setupResponse(responseBuffer, readParamsFailedCall, t, + msg + "; " + t.getMessage()); + readParamsFailedCall.sendResponseIfReady(); + return; + } + + TraceInfo traceInfo = header.hasTraceInfo() ? new TraceInfo(header + .getTraceInfo().getTraceId(), header.getTraceInfo().getParentId()) + : null; + int timeout = 0; + if (header.hasTimeout() && header.getTimeout() > 0) { + timeout = Math.max(minClientRequestTimeout, header.getTimeout()); + } + RpcServer.Call call = createCall(id, this.service, md, header, param, + cellScanner, this, totalRequestSize, traceInfo, this.addr, timeout, + this.callCleanup); + + if (!scheduler.dispatch(new CallRunner(RpcServer.this, call))) { + callQueueSizeInBytes.add(-1 * call.getSize()); + + ByteArrayOutputStream responseBuffer = new ByteArrayOutputStream(); + metrics.exception(CALL_QUEUE_TOO_BIG_EXCEPTION); + setupResponse(responseBuffer, call, CALL_QUEUE_TOO_BIG_EXCEPTION, + "Call queue is full on " + server.getServerName() + + ", too many items queued ?"); + call.sendResponseIfReady(); + } + } + public abstract boolean isConnectionOpen(); + public abstract Call createCall(int id, final BlockingService service, + final MethodDescriptor md, RequestHeader header, Message param, + CellScanner cellScanner, Connection connection, long size, + TraceInfo tinfo, final InetAddress remoteAddress, int timeout, + CallCleanup reqCleanup); } /** @@ -1070,6 +1583,20 @@ public abstract class RpcServer implements RpcServerInterface, } } + /** + * Setup response for the RPC Call. + * + * @param response buffer to serialize the response into + * @param call {@link Call} to which we are setting up the response + * @param error error message, if the call failed + * @throws IOException + */ + protected void setupResponse(ByteArrayOutputStream response, Call call, Throwable t, String error) + throws IOException { + if (response != null) response.reset(); + call.setResponse(null, null, t, error); + } + Configuration getConf() { return conf; } @@ -1276,6 +1803,77 @@ public abstract class RpcServer implements RpcServerInterface, } /** + * When the read or write buffer size is larger than this limit, i/o will be + * done in chunks of this size. Most RPC requests and responses would be + * be smaller. + */ + protected static final int NIO_BUFFER_LIMIT = 64 * 1024; //should not be more than 64KB. + + /** + * This is a wrapper around {@link java.nio.channels.ReadableByteChannel#read(java.nio.ByteBuffer)}. + * If the amount of data is large, it writes to channel in smaller chunks. + * This is to avoid jdk from creating many direct buffers as the size of + * ByteBuffer increases. There should not be any performance degredation. + * + * @param channel writable byte channel to write on + * @param buffer buffer to write + * @return number of bytes written + * @throws java.io.IOException e + * @see java.nio.channels.ReadableByteChannel#read(java.nio.ByteBuffer) + */ + protected int channelRead(ReadableByteChannel channel, + ByteBuffer buffer) throws IOException { + + int count = (buffer.remaining() <= NIO_BUFFER_LIMIT) ? + channel.read(buffer) : channelIO(channel, null, buffer); + if (count > 0) { + metrics.receivedBytes(count); + } + return count; + } + + /** + * Helper for {@link #channelRead(java.nio.channels.ReadableByteChannel, java.nio.ByteBuffer)} + * and {@link #channelWrite(GatheringByteChannel, BufferChain)}. Only + * one of readCh or writeCh should be non-null. + * + * @param readCh read channel + * @param writeCh write channel + * @param buf buffer to read or write into/out of + * @return bytes written + * @throws java.io.IOException e + * @see #channelRead(java.nio.channels.ReadableByteChannel, java.nio.ByteBuffer) + * @see #channelWrite(GatheringByteChannel, BufferChain) + */ + private static int channelIO(ReadableByteChannel readCh, + WritableByteChannel writeCh, + ByteBuffer buf) throws IOException { + + int originalLimit = buf.limit(); + int initialRemaining = buf.remaining(); + int ret = 0; + + while (buf.remaining() > 0) { + try { + int ioSize = Math.min(buf.remaining(), NIO_BUFFER_LIMIT); + buf.limit(buf.position() + ioSize); + + ret = (readCh == null) ? writeCh.write(buf) : readCh.read(buf); + + if (ret < ioSize) { + break; + } + + } finally { + buf.limit(originalLimit); + } + } + + int nBytes = initialRemaining - buf.remaining(); + return (nBytes > 0) ? nBytes : ret; + } + + /** * This is extracted to a static method for better unit testing. We try to get buffer(s) from pool * as much as possible. * @@ -1485,4 +2083,4 @@ public abstract class RpcServer implements RpcServerInterface, return this.length; } } -} \ No newline at end of file +} diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleRpcServer.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleRpcServer.java index 01d45cd..35a98c0 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleRpcServer.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleRpcServer.java @@ -18,9 +18,6 @@ package org.apache.hadoop.hbase.ipc; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.DataOutputStream; import java.io.IOException; import java.io.InputStream; import java.net.BindException; @@ -32,7 +29,6 @@ import java.net.SocketException; import java.net.UnknownHostException; import java.nio.ByteBuffer; import java.nio.channels.CancelledKeyException; -import java.nio.channels.Channels; import java.nio.channels.ClosedChannelException; import java.nio.channels.GatheringByteChannel; import java.nio.channels.ReadableByteChannel; @@ -40,8 +36,6 @@ import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; -import java.nio.channels.WritableByteChannel; -import java.security.PrivilegedExceptionAction; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -60,9 +54,6 @@ import java.util.concurrent.atomic.LongAdder; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; -import javax.security.sasl.Sasl; -import javax.security.sasl.SaslException; - import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hbase.CellScanner; import org.apache.hadoop.hbase.DoNotRetryIOException; @@ -73,42 +64,26 @@ import org.apache.hadoop.hbase.classification.InterfaceAudience; import org.apache.hadoop.hbase.classification.InterfaceStability; import org.apache.hadoop.hbase.client.VersionInfoUtil; import org.apache.hadoop.hbase.exceptions.RequestTooBigException; -import org.apache.hadoop.hbase.io.ByteBufferOutputStream; import org.apache.hadoop.hbase.monitoring.MonitoredRPCHandler; import org.apache.hadoop.hbase.nio.ByteBuff; import org.apache.hadoop.hbase.nio.SingleByteBuff; import org.apache.hadoop.hbase.security.AccessDeniedException; import org.apache.hadoop.hbase.security.AuthMethod; import org.apache.hadoop.hbase.security.HBasePolicyProvider; -import org.apache.hadoop.hbase.security.HBaseSaslRpcServer; -import org.apache.hadoop.hbase.security.HBaseSaslRpcServer.SaslDigestCallbackHandler; -import org.apache.hadoop.hbase.security.HBaseSaslRpcServer.SaslGssCallbackHandler; import org.apache.hadoop.hbase.security.SaslStatus; import org.apache.hadoop.hbase.security.SaslUtil; import org.apache.hadoop.hbase.shaded.com.google.protobuf.BlockingService; import org.apache.hadoop.hbase.shaded.com.google.protobuf.CodedInputStream; import org.apache.hadoop.hbase.shaded.com.google.protobuf.Descriptors.MethodDescriptor; import org.apache.hadoop.hbase.shaded.com.google.protobuf.Message; -import org.apache.hadoop.hbase.shaded.com.google.protobuf.TextFormat; -import org.apache.hadoop.hbase.shaded.com.google.protobuf.UnsafeByteOperations; import org.apache.hadoop.hbase.shaded.protobuf.ProtobufUtil; -import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos; -import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.ConnectionHeader; import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.RequestHeader; import org.apache.hadoop.hbase.util.Bytes; import org.apache.hadoop.hbase.util.Pair; import org.apache.hadoop.hbase.util.Threads; -import org.apache.hadoop.io.BytesWritable; import org.apache.hadoop.io.IOUtils; import org.apache.hadoop.io.IntWritable; -import org.apache.hadoop.io.Writable; -import org.apache.hadoop.io.WritableUtils; -import org.apache.hadoop.security.UserGroupInformation; -import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod; -import org.apache.hadoop.security.authorize.AuthorizationException; -import org.apache.hadoop.security.authorize.ProxyUsers; import org.apache.hadoop.security.authorize.ServiceAuthorizationManager; -import org.apache.hadoop.security.token.SecretManager.InvalidToken; import org.apache.hadoop.util.StringUtils; import org.apache.htrace.TraceInfo; @@ -165,8 +140,9 @@ public class SimpleRpcServer extends RpcServer { justification="Can't figure why this complaint is happening... see below") Call(int id, final BlockingService service, final MethodDescriptor md, RequestHeader header, Message param, CellScanner cellScanner, - Connection connection, Responder responder, long size, TraceInfo tinfo, - final InetAddress remoteAddress, int timeout, CallCleanup reqCleanup) { + RpcServer.Connection connection, long size, TraceInfo tinfo, + final InetAddress remoteAddress, int timeout, CallCleanup reqCleanup, + Responder responder) { super(id, service, md, header, param, cellScanner, connection, size, tinfo, remoteAddress, timeout, reqCleanup); this.responder = responder; @@ -178,6 +154,7 @@ public class SimpleRpcServer extends RpcServer { */ @edu.umd.cs.findbugs.annotations.SuppressWarnings(value="IS2_INCONSISTENT_SYNC", justification="Presume the lock on processing request held by caller is protection enough") + @Override void done() { super.done(); this.getConnection().decRpcCount(); // Say that we're done with this call. @@ -192,6 +169,7 @@ public class SimpleRpcServer extends RpcServer { } } + @Override public synchronized void sendResponseIfReady() throws IOException { // set param null to reduce memory pressure this.param = null; @@ -769,19 +747,6 @@ public class SimpleRpcServer extends RpcServer { private long lastContact; protected Socket socket; - private ByteBuffer unwrappedData; - // When is this set? FindBugs wants to know! Says NP - private ByteBuffer unwrappedDataLengthBuffer = ByteBuffer.allocate(4); - - private final Call authFailedCall = new Call(AUTHORIZATION_FAILED_CALLID, null, null, null, - null, null, this, null, 0, null, null, 0, null); - - private final Call saslCall = new Call(SASL_CALLID, null, null, null, null, null, this, null, - 0, null, null, 0, null); - - private final Call setConnectionHeaderResponseCall = new Call(CONNECTION_HEADER_RESPONSE_CALLID, - null, null, null, null, null, this, null, 0, null, null, 0, null); - public Connection(SocketChannel channel, long lastContact) { super(); this.channel = channel; @@ -804,6 +769,13 @@ public class SimpleRpcServer extends RpcServer { socketSendBufferSize); } } + this.saslCall = new Call(SASL_CALLID, null, null, null, null, null, this, + 0, null, null, 0, null, responder); + this.setConnectionHeaderResponseCall = new Call( + CONNECTION_HEADER_RESPONSE_CALLID, null, null, null, null, null, + this, 0, null, null, 0, null, responder); + this.authFailedCall = new Call(AUTHORIZATION_FAILED_CALLID, null, null, + null, null, null, this, 0, null, null, 0, null, responder); } public void setLastContact(long lastContact) { @@ -829,187 +801,6 @@ public class SimpleRpcServer extends RpcServer { rpcCount.increment(); } - private void saslReadAndProcess(ByteBuff saslToken) throws IOException, - InterruptedException { - if (saslContextEstablished) { - if (LOG.isTraceEnabled()) - LOG.trace("Have read input token of size " + saslToken.limit() - + " for processing by saslServer.unwrap()"); - - if (!useWrap) { - processOneRpc(saslToken); - } else { - byte[] b = saslToken.hasArray() ? saslToken.array() : saslToken.toBytes(); - byte [] plaintextData; - if (useCryptoAesWrap) { - // unwrap with CryptoAES - plaintextData = cryptoAES.unwrap(b, 0, b.length); - } else { - plaintextData = saslServer.unwrap(b, 0, b.length); - } - processUnwrappedData(plaintextData); - } - } else { - byte[] replyToken; - try { - if (saslServer == null) { - switch (authMethod) { - case DIGEST: - if (secretManager == null) { - throw new AccessDeniedException( - "Server is not configured to do DIGEST authentication."); - } - saslServer = Sasl.createSaslServer(AuthMethod.DIGEST - .getMechanismName(), null, SaslUtil.SASL_DEFAULT_REALM, - HBaseSaslRpcServer.getSaslProps(), new SaslDigestCallbackHandler( - secretManager, this)); - break; - default: - UserGroupInformation current = UserGroupInformation.getCurrentUser(); - String fullName = current.getUserName(); - if (LOG.isDebugEnabled()) { - LOG.debug("Kerberos principal name is " + fullName); - } - final String names[] = SaslUtil.splitKerberosName(fullName); - if (names.length != 3) { - throw new AccessDeniedException( - "Kerberos principal name does NOT have the expected " - + "hostname part: " + fullName); - } - current.doAs(new PrivilegedExceptionAction() { - @Override - public Object run() throws SaslException { - saslServer = Sasl.createSaslServer(AuthMethod.KERBEROS - .getMechanismName(), names[0], names[1], - HBaseSaslRpcServer.getSaslProps(), new SaslGssCallbackHandler()); - return null; - } - }); - } - if (saslServer == null) - throw new AccessDeniedException( - "Unable to find SASL server implementation for " - + authMethod.getMechanismName()); - if (LOG.isDebugEnabled()) { - LOG.debug("Created SASL server with mechanism = " + authMethod.getMechanismName()); - } - } - if (LOG.isDebugEnabled()) { - LOG.debug("Have read input token of size " + saslToken.limit() - + " for processing by saslServer.evaluateResponse()"); - } - replyToken = saslServer - .evaluateResponse(saslToken.hasArray() ? saslToken.array() : saslToken.toBytes()); - } catch (IOException e) { - IOException sendToClient = e; - Throwable cause = e; - while (cause != null) { - if (cause instanceof InvalidToken) { - sendToClient = (InvalidToken) cause; - break; - } - cause = cause.getCause(); - } - doRawSaslReply(SaslStatus.ERROR, null, sendToClient.getClass().getName(), - sendToClient.getLocalizedMessage()); - metrics.authenticationFailure(); - String clientIP = this.toString(); - // attempting user could be null - AUDITLOG.warn(AUTH_FAILED_FOR + clientIP + ":" + attemptingUser); - throw e; - } - if (replyToken != null) { - if (LOG.isDebugEnabled()) { - LOG.debug("Will send token of size " + replyToken.length - + " from saslServer."); - } - doRawSaslReply(SaslStatus.SUCCESS, new BytesWritable(replyToken), null, - null); - } - if (saslServer.isComplete()) { - String qop = (String) saslServer.getNegotiatedProperty(Sasl.QOP); - useWrap = qop != null && !"auth".equalsIgnoreCase(qop); - ugi = getAuthorizedUgi(saslServer.getAuthorizationID()); - if (LOG.isDebugEnabled()) { - LOG.debug("SASL server context established. Authenticated client: " - + ugi + ". Negotiated QoP is " - + saslServer.getNegotiatedProperty(Sasl.QOP)); - } - metrics.authenticationSuccess(); - AUDITLOG.info(AUTH_SUCCESSFUL_FOR + ugi); - saslContextEstablished = true; - } - } - } - - /** - * No protobuf encoding of raw sasl messages - */ - private void doRawSaslReply(SaslStatus status, Writable rv, - String errorClass, String error) throws IOException { - ByteBufferOutputStream saslResponse = null; - DataOutputStream out = null; - try { - // In my testing, have noticed that sasl messages are usually - // in the ballpark of 100-200. That's why the initial capacity is 256. - saslResponse = new ByteBufferOutputStream(256); - out = new DataOutputStream(saslResponse); - out.writeInt(status.state); // write status - if (status == SaslStatus.SUCCESS) { - rv.write(out); - } else { - WritableUtils.writeString(out, errorClass); - WritableUtils.writeString(out, error); - } - saslCall.setSaslTokenResponse(saslResponse.getByteBuffer()); - saslCall.responder = responder; - saslCall.sendResponseIfReady(); - } finally { - if (saslResponse != null) { - saslResponse.close(); - } - if (out != null) { - out.close(); - } - } - } - - /** - * Send the response for connection header - */ - private void doConnectionHeaderResponse(byte[] wrappedCipherMetaData) throws IOException { - ByteBufferOutputStream response = null; - DataOutputStream out = null; - try { - response = new ByteBufferOutputStream(wrappedCipherMetaData.length + 4); - out = new DataOutputStream(response); - out.writeInt(wrappedCipherMetaData.length); - out.write(wrappedCipherMetaData); - - setConnectionHeaderResponseCall.setConnectionHeaderResponse(response.getByteBuffer()); - setConnectionHeaderResponseCall.responder = responder; - setConnectionHeaderResponseCall.sendResponseIfReady(); - } finally { - if (out != null) { - out.close(); - } - if (response != null) { - response.close(); - } - } - } - - private void disposeSasl() { - if (saslServer != null) { - try { - saslServer.dispose(); - saslServer = null; - } catch (SaslException ignored) { - // Ignored. This is being disposed of anyway. - } - } - } - private int readPreamble() throws IOException { int count; // Check for 'HBas' magic. @@ -1044,7 +835,7 @@ public class SimpleRpcServer extends RpcServer { } else { AccessDeniedException ae = new AccessDeniedException("Authentication is required"); setupResponse(authFailedResponse, authFailedCall, ae, ae.getMessage()); - responder.doRespond(authFailedCall); + authFailedCall.sendResponseIfReady(); throw ae; } } @@ -1150,8 +941,8 @@ public class SimpleRpcServer extends RpcServer { RequestHeader header = (RequestHeader) builder.build(); // Notify the client about the offending request - Call reqTooBig = new Call(header.getCallId(), this.service, null, null, null, - null, this, responder, 0, null, this.addr, 0, null); + Call reqTooBig = new Call(header.getCallId(), this.service, null, + null, null, null, this, 0, null, this.addr, 0, null, responder); metrics.exception(REQUEST_TOO_BIG_EXCEPTION); // Make sure the client recognizes the underlying exception // Otherwise, throw a DoNotRetryIOException. @@ -1252,303 +1043,16 @@ public class SimpleRpcServer extends RpcServer { private int doBadPreambleHandling(final String msg, final Exception e) throws IOException { LOG.warn(msg); - Call fakeCall = new Call(-1, null, null, null, null, null, this, responder, -1, null, null, 0, - null); + Call fakeCall = new Call(-1, null, null, null, null, null, this, -1, + null, null, 0, null, responder); setupResponse(null, fakeCall, e, msg); responder.doRespond(fakeCall); // Returning -1 closes out the connection. return -1; } - // Reads the connection header following version - private void processConnectionHeader(ByteBuff buf) throws IOException { - if (buf.hasArray()) { - this.connectionHeader = ConnectionHeader.parseFrom(buf.array()); - } else { - CodedInputStream cis = UnsafeByteOperations - .unsafeWrap(new ByteBuffByteInput(buf, 0, buf.limit()), 0, buf.limit()).newCodedInput(); - cis.enableAliasing(true); - this.connectionHeader = ConnectionHeader.parseFrom(cis); - } - String serviceName = connectionHeader.getServiceName(); - if (serviceName == null) throw new EmptyServiceNameException(); - this.service = getService(services, serviceName); - if (this.service == null) throw new UnknownServiceException(serviceName); - setupCellBlockCodecs(this.connectionHeader); - RPCProtos.ConnectionHeaderResponse.Builder chrBuilder = - RPCProtos.ConnectionHeaderResponse.newBuilder(); - setupCryptoCipher(this.connectionHeader, chrBuilder); - responseConnectionHeader(chrBuilder); - UserGroupInformation protocolUser = createUser(connectionHeader); - if (!useSasl) { - ugi = protocolUser; - if (ugi != null) { - ugi.setAuthenticationMethod(AuthMethod.SIMPLE.authenticationMethod); - } - // audit logging for SASL authenticated users happens in saslReadAndProcess() - if (authenticatedWithFallback) { - LOG.warn("Allowed fallback to SIMPLE auth for " + ugi - + " connecting from " + getHostAddress()); - } - AUDITLOG.info(AUTH_SUCCESSFUL_FOR + ugi); - } else { - // user is authenticated - ugi.setAuthenticationMethod(authMethod.authenticationMethod); - //Now we check if this is a proxy user case. If the protocol user is - //different from the 'user', it is a proxy user scenario. However, - //this is not allowed if user authenticated with DIGEST. - if ((protocolUser != null) - && (!protocolUser.getUserName().equals(ugi.getUserName()))) { - if (authMethod == AuthMethod.DIGEST) { - // Not allowed to doAs if token authentication is used - throw new AccessDeniedException("Authenticated user (" + ugi - + ") doesn't match what the client claims to be (" - + protocolUser + ")"); - } else { - // Effective user can be different from authenticated user - // for simple auth or kerberos auth - // The user is the real user. Now we create a proxy user - UserGroupInformation realUser = ugi; - ugi = UserGroupInformation.createProxyUser(protocolUser - .getUserName(), realUser); - // Now the user is a proxy user, set Authentication method Proxy. - ugi.setAuthenticationMethod(AuthenticationMethod.PROXY); - } - } - } - if (connectionHeader.hasVersionInfo()) { - // see if this connection will support RetryImmediatelyException - retryImmediatelySupported = VersionInfoUtil.hasMinimumVersion(getVersionInfo(), 1, 2); - - AUDITLOG.info("Connection from " + this.hostAddress + " port: " + this.remotePort - + " with version info: " - + TextFormat.shortDebugString(connectionHeader.getVersionInfo())); - } else { - AUDITLOG.info("Connection from " + this.hostAddress + " port: " + this.remotePort - + " with unknown version info"); - } - } - - private void responseConnectionHeader(RPCProtos.ConnectionHeaderResponse.Builder chrBuilder) - throws FatalConnectionException { - // Response the connection header if Crypto AES is enabled - if (!chrBuilder.hasCryptoCipherMeta()) return; - try { - byte[] connectionHeaderResBytes = chrBuilder.build().toByteArray(); - // encrypt the Crypto AES cipher meta data with sasl server, and send to client - byte[] unwrapped = new byte[connectionHeaderResBytes.length + 4]; - Bytes.putBytes(unwrapped, 0, Bytes.toBytes(connectionHeaderResBytes.length), 0, 4); - Bytes.putBytes(unwrapped, 4, connectionHeaderResBytes, 0, connectionHeaderResBytes.length); - - doConnectionHeaderResponse(saslServer.wrap(unwrapped, 0, unwrapped.length)); - } catch (IOException ex) { - throw new UnsupportedCryptoException(ex.getMessage(), ex); - } - } - - private void processUnwrappedData(byte[] inBuf) throws IOException, - InterruptedException { - ReadableByteChannel ch = Channels.newChannel(new ByteArrayInputStream(inBuf)); - // Read all RPCs contained in the inBuf, even partial ones - while (true) { - int count; - if (unwrappedDataLengthBuffer.remaining() > 0) { - count = channelRead(ch, unwrappedDataLengthBuffer); - if (count <= 0 || unwrappedDataLengthBuffer.remaining() > 0) - return; - } - - if (unwrappedData == null) { - unwrappedDataLengthBuffer.flip(); - int unwrappedDataLength = unwrappedDataLengthBuffer.getInt(); - - if (unwrappedDataLength == RpcClient.PING_CALL_ID) { - if (LOG.isDebugEnabled()) - LOG.debug("Received ping message"); - unwrappedDataLengthBuffer.clear(); - continue; // ping message - } - unwrappedData = ByteBuffer.allocate(unwrappedDataLength); - } - - count = channelRead(ch, unwrappedData); - if (count <= 0 || unwrappedData.remaining() > 0) - return; - - if (unwrappedData.remaining() == 0) { - unwrappedDataLengthBuffer.clear(); - unwrappedData.flip(); - processOneRpc(new SingleByteBuff(unwrappedData)); - unwrappedData = null; - } - } - } - - private void processOneRpc(ByteBuff buf) throws IOException, InterruptedException { - if (connectionHeaderRead) { - processRequest(buf); - } else { - processConnectionHeader(buf); - this.connectionHeaderRead = true; - if (!authorizeConnection()) { - // Throw FatalConnectionException wrapping ACE so client does right thing and closes - // down the connection instead of trying to read non-existent retun. - throw new AccessDeniedException("Connection from " + this + " for service " + - connectionHeader.getServiceName() + " is unauthorized for user: " + ugi); - } - this.user = userProvider.create(this.ugi); - } - } - - /** - * @param buf Has the request header and the request param and optionally encoded data buffer - * all in this one array. - * @throws IOException - * @throws InterruptedException - */ - protected void processRequest(ByteBuff buf) throws IOException, InterruptedException { - long totalRequestSize = buf.limit(); - int offset = 0; - // Here we read in the header. We avoid having pb - // do its default 4k allocation for CodedInputStream. We force it to use backing array. - CodedInputStream cis; - if (buf.hasArray()) { - cis = UnsafeByteOperations.unsafeWrap(buf.array(), 0, buf.limit()).newCodedInput(); - } else { - cis = UnsafeByteOperations - .unsafeWrap(new ByteBuffByteInput(buf, 0, buf.limit()), 0, buf.limit()).newCodedInput(); - } - cis.enableAliasing(true); - int headerSize = cis.readRawVarint32(); - offset = cis.getTotalBytesRead(); - Message.Builder builder = RequestHeader.newBuilder(); - ProtobufUtil.mergeFrom(builder, cis, headerSize); - RequestHeader header = (RequestHeader) builder.build(); - offset += headerSize; - int id = header.getCallId(); - if (LOG.isTraceEnabled()) { - LOG.trace("RequestHeader " + TextFormat.shortDebugString(header) + - " totalRequestSize: " + totalRequestSize + " bytes"); - } - // Enforcing the call queue size, this triggers a retry in the client - // This is a bit late to be doing this check - we have already read in the total request. - if ((totalRequestSize + callQueueSizeInBytes.sum()) > maxQueueSizeInBytes) { - final Call callTooBig = - new Call(id, this.service, null, null, null, null, this, - responder, totalRequestSize, null, null, 0, this.callCleanup); - ByteArrayOutputStream responseBuffer = new ByteArrayOutputStream(); - metrics.exception(CALL_QUEUE_TOO_BIG_EXCEPTION); - setupResponse(responseBuffer, callTooBig, CALL_QUEUE_TOO_BIG_EXCEPTION, - "Call queue is full on " + server.getServerName() + - ", is hbase.ipc.server.max.callqueue.size too small?"); - responder.doRespond(callTooBig); - return; - } - MethodDescriptor md = null; - Message param = null; - CellScanner cellScanner = null; - try { - if (header.hasRequestParam() && header.getRequestParam()) { - md = this.service.getDescriptorForType().findMethodByName(header.getMethodName()); - if (md == null) throw new UnsupportedOperationException(header.getMethodName()); - builder = this.service.getRequestPrototype(md).newBuilderForType(); - cis.resetSizeCounter(); - int paramSize = cis.readRawVarint32(); - offset += cis.getTotalBytesRead(); - if (builder != null) { - ProtobufUtil.mergeFrom(builder, cis, paramSize); - param = builder.build(); - } - offset += paramSize; - } else { - // currently header must have request param, so we directly throw exception here - String msg = "Invalid request header: " + TextFormat.shortDebugString(header) - + ", should have param set in it"; - LOG.warn(msg); - throw new DoNotRetryIOException(msg); - } - if (header.hasCellBlockMeta()) { - buf.position(offset); - ByteBuff dup = buf.duplicate(); - dup.limit(offset + header.getCellBlockMeta().getLength()); - cellScanner = cellBlockBuilder.createCellScannerReusingBuffers(this.codec, - this.compressionCodec, dup); - } - } catch (Throwable t) { - InetSocketAddress address = getListenerAddress(); - String msg = (address != null ? address : "(channel closed)") + - " is unable to read call parameter from client " + getHostAddress(); - LOG.warn(msg, t); - - metrics.exception(t); - - // probably the hbase hadoop version does not match the running hadoop version - if (t instanceof LinkageError) { - t = new DoNotRetryIOException(t); - } - // If the method is not present on the server, do not retry. - if (t instanceof UnsupportedOperationException) { - t = new DoNotRetryIOException(t); - } - - final Call readParamsFailedCall = - new Call(id, this.service, null, null, null, null, this, - responder, totalRequestSize, null, null, 0, this.callCleanup); - ByteArrayOutputStream responseBuffer = new ByteArrayOutputStream(); - setupResponse(responseBuffer, readParamsFailedCall, t, - msg + "; " + t.getMessage()); - responder.doRespond(readParamsFailedCall); - return; - } - - TraceInfo traceInfo = header.hasTraceInfo() - ? new TraceInfo(header.getTraceInfo().getTraceId(), header.getTraceInfo().getParentId()) - : null; - int timeout = 0; - if (header.hasTimeout() && header.getTimeout() > 0){ - timeout = Math.max(minClientRequestTimeout, header.getTimeout()); - } - Call call = new Call(id, this.service, md, header, param, cellScanner, this, responder, - totalRequestSize, traceInfo, this.addr, timeout, this.callCleanup); - - if (!scheduler.dispatch(new CallRunner(SimpleRpcServer.this, call))) { - callQueueSizeInBytes.add(-1 * call.getSize()); - - ByteArrayOutputStream responseBuffer = new ByteArrayOutputStream(); - metrics.exception(CALL_QUEUE_TOO_BIG_EXCEPTION); - setupResponse(responseBuffer, call, CALL_QUEUE_TOO_BIG_EXCEPTION, - "Call queue is full on " + server.getServerName() + - ", too many items queued ?"); - responder.doRespond(call); - } - } - - private boolean authorizeConnection() throws IOException { - try { - // If auth method is DIGEST, the token was obtained by the - // real user for the effective user, therefore not required to - // authorize real user. doAs is allowed only for simple or kerberos - // authentication - if (ugi != null && ugi.getRealUser() != null - && (authMethod != AuthMethod.DIGEST)) { - ProxyUsers.authorize(ugi, this.getHostAddress(), conf); - } - authorize(ugi, connectionHeader, getHostInetAddress()); - metrics.authorizationSuccess(); - } catch (AuthorizationException ae) { - if (LOG.isDebugEnabled()) { - LOG.debug("Connection authorization failed: " + ae.getMessage(), ae); - } - metrics.authorizationFailure(); - setupResponse(authFailedResponse, authFailedCall, - new AccessDeniedException(ae), ae.getMessage()); - responder.doRespond(authFailedCall); - return false; - } - return true; - } - - protected synchronized void close() { + @Override + public synchronized void close() { disposeSasl(); data = null; callCleanup = null; @@ -1575,6 +1079,16 @@ public class SimpleRpcServer extends RpcServer { public boolean isConnectionOpen() { return channel.isOpen(); } + + @Override + public RpcServer.Call createCall(int id, final BlockingService service, + final MethodDescriptor md, RequestHeader header, Message param, + CellScanner cellScanner, RpcServer.Connection connection, long size, + TraceInfo tinfo, final InetAddress remoteAddress, int timeout, + CallCleanup reqCleanup) { + return new Call(id, service, md, header, param, cellScanner, connection, + size, tinfo, remoteAddress, timeout, reqCleanup, responder); + } } @@ -1619,20 +1133,6 @@ public class SimpleRpcServer extends RpcServer { return new Connection(channel, time); } - /** - * Setup response for the RPC Call. - * - * @param response buffer to serialize the response into - * @param call {@link Call} to which we are setting up the response - * @param error error message, if the call failed - * @throws IOException - */ - private void setupResponse(ByteArrayOutputStream response, Call call, Throwable t, String error) - throws IOException { - if (response != null) response.reset(); - call.setResponse(null, null, t, error); - } - protected void closeConnection(Connection connection) { connectionManager.close(connection); } @@ -1702,30 +1202,25 @@ public class SimpleRpcServer extends RpcServer { return listener.getAddress(); } + @Override public Pair call(BlockingService service, MethodDescriptor md, Message param, CellScanner cellScanner, long receiveTime, MonitoredRPCHandler status) throws IOException { return call(service, md, param, cellScanner, receiveTime, status, System.currentTimeMillis(),0); } - public Pair call(BlockingService service, MethodDescriptor md, Message param, - CellScanner cellScanner, long receiveTime, MonitoredRPCHandler status, long startTime, - int timeout) + @Override + public Pair call(BlockingService service, + MethodDescriptor md, Message param, CellScanner cellScanner, + long receiveTime, MonitoredRPCHandler status, long startTime, int timeout) throws IOException { - Call fakeCall = new Call(-1, service, md, null, param, cellScanner, null, null, -1, null, null, timeout, - null); + Call fakeCall = new Call(-1, service, md, null, param, cellScanner, null, + -1, null, null, timeout, null, null); fakeCall.setReceiveTime(receiveTime); return call(fakeCall, status); } /** - * When the read or write buffer size is larger than this limit, i/o will be - * done in chunks of this size. Most RPC requests and responses would be - * be smaller. - */ - private static int NIO_BUFFER_LIMIT = 64 * 1024; //should not be more than 64KB. - - /** * This is a wrapper around {@link java.nio.channels.WritableByteChannel#write(java.nio.ByteBuffer)}. * If the amount of data is large, it writes to channel in smaller chunks. * This is to avoid jdk from creating many direct buffers as the size of @@ -1747,70 +1242,6 @@ public class SimpleRpcServer extends RpcServer { } /** - * This is a wrapper around {@link java.nio.channels.ReadableByteChannel#read(java.nio.ByteBuffer)}. - * If the amount of data is large, it writes to channel in smaller chunks. - * This is to avoid jdk from creating many direct buffers as the size of - * ByteBuffer increases. There should not be any performance degredation. - * - * @param channel writable byte channel to write on - * @param buffer buffer to write - * @return number of bytes written - * @throws java.io.IOException e - * @see java.nio.channels.ReadableByteChannel#read(java.nio.ByteBuffer) - */ - protected int channelRead(ReadableByteChannel channel, - ByteBuffer buffer) throws IOException { - - int count = (buffer.remaining() <= NIO_BUFFER_LIMIT) ? - channel.read(buffer) : channelIO(channel, null, buffer); - if (count > 0) { - metrics.receivedBytes(count); - } - return count; - } - - /** - * Helper for {@link #channelRead(java.nio.channels.ReadableByteChannel, java.nio.ByteBuffer)} - * and {@link #channelWrite(GatheringByteChannel, BufferChain)}. Only - * one of readCh or writeCh should be non-null. - * - * @param readCh read channel - * @param writeCh write channel - * @param buf buffer to read or write into/out of - * @return bytes written - * @throws java.io.IOException e - * @see #channelRead(java.nio.channels.ReadableByteChannel, java.nio.ByteBuffer) - * @see #channelWrite(GatheringByteChannel, BufferChain) - */ - protected static int channelIO(ReadableByteChannel readCh, - WritableByteChannel writeCh, - ByteBuffer buf) throws IOException { - - int originalLimit = buf.limit(); - int initialRemaining = buf.remaining(); - int ret = 0; - - while (buf.remaining() > 0) { - try { - int ioSize = Math.min(buf.remaining(), NIO_BUFFER_LIMIT); - buf.limit(buf.position() + ioSize); - - ret = (readCh == null) ? writeCh.write(buf) : readCh.read(buf); - - if (ret < ioSize) { - break; - } - - } finally { - buf.limit(originalLimit); - } - } - - int nBytes = initialRemaining - buf.remaining(); - return (nBytes > 0) ? nBytes : ret; - } - - /** * A convenience method to bind to a given address and report * better exceptions if the address is not a valid host. * @param socket the socket to bind diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/AbstractTestIPC.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/AbstractTestIPC.java index a1a73c1..581e50e 100644 --- a/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/AbstractTestIPC.java +++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/AbstractTestIPC.java @@ -321,7 +321,7 @@ public abstract class AbstractTestIPC { } @Override - protected void processRequest(ByteBuff buf) throws IOException, InterruptedException { + public void processRequest(ByteBuff buf) throws IOException, InterruptedException { // this will throw exception after the connection header is read, and an RPC is sent // from client throw new DoNotRetryIOException("Failing for test"); diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestNettyRpcServer.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestNettyRpcServer.java new file mode 100644 index 0000000..81be74d --- /dev/null +++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestNettyRpcServer.java @@ -0,0 +1,109 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.ipc; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.hadoop.hbase.HBaseTestingUtility; +import org.apache.hadoop.hbase.TableName; +import org.apache.hadoop.hbase.client.Put; +import org.apache.hadoop.hbase.client.Result; +import org.apache.hadoop.hbase.client.ResultScanner; +import org.apache.hadoop.hbase.client.Scan; +import org.apache.hadoop.hbase.client.Table; +import org.apache.hadoop.hbase.testclassification.RPCTests; +import org.apache.hadoop.hbase.testclassification.SmallTests; +import org.apache.hadoop.hbase.util.Bytes; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; + +@Category({ RPCTests.class, SmallTests.class }) +public class TestNettyRpcServer { + @Rule + public TestName name = new TestName(); + private static HBaseTestingUtility TEST_UTIL; + + private static TableName TABLE; + private static byte[] FAMILY = Bytes.toBytes("f1"); + private static byte[] PRIVATE_COL = Bytes.toBytes("private"); + private static byte[] PUBLIC_COL = Bytes.toBytes("public"); + + @Before + public void setup() { + TABLE = TableName.valueOf(name.getMethodName()); + } + + @BeforeClass + public static void setupBeforeClass() throws Exception { + TEST_UTIL = new HBaseTestingUtility(); + TEST_UTIL.getConfiguration().set( + RpcServerFactory.CUSTOM_RPC_SERVER_IMPL_CONF_KEY, + NettyRpcServer.class.getName()); + TEST_UTIL.startMiniCluster(); + } + + @AfterClass + public static void tearDownAfterClass() throws Exception { + TEST_UTIL.shutdownMiniCluster(); + } + + @Test(timeout = 180000) + public void testNettyRpcServer() throws Exception { + final Table table = TEST_UTIL.createTable(TABLE, FAMILY); + try { + // put some test data + List puts = new ArrayList(100); + for (int i = 0; i < 100; i++) { + Put p = new Put(Bytes.toBytes(i)); + p.addColumn(FAMILY, PRIVATE_COL, Bytes.toBytes("secret " + i)); + p.addColumn(FAMILY, PUBLIC_COL, Bytes.toBytes("info " + i)); + puts.add(p); + } + table.put(puts); + + // read to verify it. + Scan scan = new Scan(); + scan.setCaching(16); + ResultScanner rs = table.getScanner(scan); + int rowcnt = 0; + for (Result r : rs) { + rowcnt++; + int rownum = Bytes.toInt(r.getRow()); + assertTrue(r.containsColumn(FAMILY, PRIVATE_COL)); + assertEquals("secret " + rownum, + Bytes.toString(r.getValue(FAMILY, PRIVATE_COL))); + assertTrue(r.containsColumn(FAMILY, PUBLIC_COL)); + assertEquals("info " + rownum, + Bytes.toString(r.getValue(FAMILY, PUBLIC_COL))); + } + assertEquals("Expected 100 rows returned", 100, rowcnt); + } finally { + table.close(); + } + } + +} diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestProtoBufRpc.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestProtoBufRpc.java index b039003..f21359c 100644 --- a/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestProtoBufRpc.java +++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestProtoBufRpc.java @@ -24,6 +24,8 @@ import static org.junit.Assert.fail; import java.io.IOException; import java.net.InetSocketAddress; +import java.util.Arrays; +import java.util.Collection; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hbase.HBaseConfiguration; @@ -40,6 +42,10 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; import com.google.common.collect.Lists; @@ -48,6 +54,7 @@ import com.google.common.collect.Lists; * of types in src/test/protobuf/test.proto and protobuf service definition from * src/test/protobuf/test_rpc_service.proto */ +@RunWith(Parameterized.class) @Category({ RPCTests.class, MediumTests.class }) public class TestProtoBufRpc { public final static String ADDRESS = "localhost"; @@ -56,9 +63,20 @@ public class TestProtoBufRpc { private Configuration conf; private RpcServerInterface server; + @Parameters(name = "{index}: rpcServerImpl={0}") + public static Collection parameters() { + return Arrays.asList(new Object[] { SimpleRpcServer.class.getName() }, + new Object[] { NettyRpcServer.class.getName() }); + } + + @Parameter(0) + public String rpcServerImpl; + @Before public void setUp() throws IOException { // Setup server for both protocols this.conf = HBaseConfiguration.create(); + this.conf.set(RpcServerFactory.CUSTOM_RPC_SERVER_IMPL_CONF_KEY, + rpcServerImpl); Logger log = Logger.getLogger("org.apache.hadoop.ipc.HBaseServer"); log.setLevel(Level.DEBUG); log = Logger.getLogger("org.apache.hadoop.ipc.HBaseServer.trace"); diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestRpcHandlerException.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestRpcHandlerException.java index 449899f..c12331f 100644 --- a/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestRpcHandlerException.java +++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestRpcHandlerException.java @@ -23,11 +23,14 @@ import static org.mockito.Mockito.mock; import java.io.IOException; import java.net.InetSocketAddress; +import java.util.Arrays; +import java.util.Collection; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hbase.Abortable; import org.apache.hadoop.hbase.HBaseConfiguration; import org.apache.hadoop.hbase.ipc.RpcServer.BlockingServiceAndInterface; +import org.apache.hadoop.hbase.shaded.com.google.protobuf.BlockingService; import org.apache.hadoop.hbase.shaded.ipc.protobuf.generated.TestProtos.EchoRequestProto; import org.apache.hadoop.hbase.shaded.ipc.protobuf.generated.TestRpcServiceProtos.TestProtobufRpcProto.BlockingInterface; import org.apache.hadoop.hbase.testclassification.RPCTests; @@ -35,11 +38,14 @@ import org.apache.hadoop.hbase.testclassification.SmallTests; import org.junit.Ignore; import org.junit.Test; import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; import com.google.common.collect.Lists; -import org.apache.hadoop.hbase.shaded.com.google.protobuf.BlockingService; - +@RunWith(Parameterized.class) @Category({ RPCTests.class, SmallTests.class }) public class TestRpcHandlerException { @@ -64,6 +70,15 @@ public class TestRpcHandlerException { } } + @Parameters(name = "{index}: rpcServerImpl={0}") + public static Collection parameters() { + return Arrays.asList(new Object[] { SimpleRpcServer.class.getName() }, + new Object[] { NettyRpcServer.class.getName() }); + } + + @Parameter(0) + public String rpcServerImpl; + /* * This is a unit test to make sure to abort region server when the number of Rpc handler thread * caught errors exceeds the threshold. Client will hang when RS aborts. @@ -73,6 +88,7 @@ public class TestRpcHandlerException { public void testRpcScheduler() throws IOException, InterruptedException { PriorityFunction qosFunction = mock(PriorityFunction.class); Abortable abortable = new AbortServer(); + CONF.set(RpcServerFactory.CUSTOM_RPC_SERVER_IMPL_CONF_KEY, rpcServerImpl); RpcScheduler scheduler = new SimpleRpcScheduler(CONF, 2, 0, 0, qosFunction, abortable, 0); RpcServer rpcServer = RpcServerFactory.createRpcServer(null, "testRpcServer", Lists.newArrayList(new BlockingServiceAndInterface((BlockingService) SERVICE, null)), diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/security/TestSecureIPC.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/security/TestSecureIPC.java index c848250..85a14f2 100644 --- a/hbase-server/src/test/java/org/apache/hadoop/hbase/security/TestSecureIPC.java +++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/security/TestSecureIPC.java @@ -35,6 +35,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.List; import javax.security.sasl.SaslException; @@ -46,11 +47,14 @@ import org.apache.hadoop.hbase.HConstants; import org.apache.hadoop.hbase.ipc.BlockingRpcClient; import org.apache.hadoop.hbase.ipc.FifoRpcScheduler; import org.apache.hadoop.hbase.ipc.NettyRpcClient; +import org.apache.hadoop.hbase.ipc.NettyRpcServer; import org.apache.hadoop.hbase.ipc.RpcClient; import org.apache.hadoop.hbase.ipc.RpcClientFactory; import org.apache.hadoop.hbase.ipc.RpcServer; -import org.apache.hadoop.hbase.ipc.RpcServerInterface; import org.apache.hadoop.hbase.ipc.RpcServerFactory; +import org.apache.hadoop.hbase.ipc.RpcServerInterface; +import org.apache.hadoop.hbase.ipc.SimpleRpcServer; +import org.apache.hadoop.hbase.shaded.com.google.protobuf.BlockingService; import org.apache.hadoop.hbase.shaded.ipc.protobuf.generated.TestProtos; import org.apache.hadoop.hbase.shaded.ipc.protobuf.generated.TestRpcServiceProtos.TestProtobufRpcProto.BlockingInterface; import org.apache.hadoop.hbase.testclassification.SecurityTests; @@ -72,7 +76,6 @@ import org.junit.runners.Parameterized.Parameters; import org.mockito.Mockito; import com.google.common.collect.Lists; -import org.apache.hadoop.hbase.shaded.com.google.protobuf.BlockingService; @RunWith(Parameterized.class) @Category({ SecurityTests.class, SmallTests.class }) @@ -96,15 +99,27 @@ public class TestSecureIPC { @Rule public ExpectedException exception = ExpectedException.none(); - @Parameters(name = "{index}: rpcClientImpl={0}") + @Parameters(name = "{index}: rpcClientImpl={0}, rpcServerImpl={1}") public static Collection parameters() { - return Arrays.asList(new Object[]{BlockingRpcClient.class.getName()}, - new Object[]{NettyRpcClient.class.getName()}); + List params = new ArrayList<>(); + List rpcClientImpls = Arrays.asList( + BlockingRpcClient.class.getName(), NettyRpcClient.class.getName()); + List rpcServerImpls = Arrays.asList( + SimpleRpcServer.class.getName(), NettyRpcServer.class.getName()); + for (String rpcClientImpl : rpcClientImpls) { + for (String rpcServerImpl : rpcServerImpls) { + params.add(new Object[] { rpcClientImpl, rpcServerImpl }); + } + } + return params; } - @Parameter + @Parameter(0) public String rpcClientImpl; + @Parameter(1) + public String rpcServerImpl; + @BeforeClass public static void setUp() throws Exception { KDC = TEST_UTIL.setupMiniKdc(KEYTAB_FILE); @@ -129,6 +144,8 @@ public class TestSecureIPC { clientConf = getSecuredConfiguration(); clientConf.set(RpcClientFactory.CUSTOM_RPC_CLIENT_IMPL_CONF_KEY, rpcClientImpl); serverConf = getSecuredConfiguration(); + serverConf.set(RpcServerFactory.CUSTOM_RPC_SERVER_IMPL_CONF_KEY, + rpcServerImpl); } @Test diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/security/token/TestTokenAuthentication.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/security/token/TestTokenAuthentication.java index 92eaecc..5186235 100644 --- a/hbase-server/src/test/java/org/apache/hadoop/hbase/security/token/TestTokenAuthentication.java +++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/security/token/TestTokenAuthentication.java @@ -27,6 +27,8 @@ import java.io.IOException; import java.io.InterruptedIOException; import java.net.InetSocketAddress; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; import java.util.List; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutorService; @@ -49,11 +51,13 @@ import org.apache.hadoop.hbase.client.ConnectionFactory; import org.apache.hadoop.hbase.client.Table; import org.apache.hadoop.hbase.coprocessor.RegionCoprocessorEnvironment; import org.apache.hadoop.hbase.ipc.FifoRpcScheduler; +import org.apache.hadoop.hbase.ipc.NettyRpcServer; import org.apache.hadoop.hbase.ipc.RpcServer; import org.apache.hadoop.hbase.ipc.RpcServer.BlockingServiceAndInterface; import org.apache.hadoop.hbase.ipc.RpcServerFactory; import org.apache.hadoop.hbase.ipc.RpcServerInterface; import org.apache.hadoop.hbase.ipc.ServerRpcController; +import org.apache.hadoop.hbase.ipc.SimpleRpcServer; import org.apache.hadoop.hbase.protobuf.generated.AuthenticationProtos; import org.apache.hadoop.hbase.regionserver.HRegion; import org.apache.hadoop.hbase.regionserver.RegionServerServices; @@ -79,10 +83,14 @@ import org.apache.hadoop.security.authorize.Service; import org.apache.hadoop.security.token.SecretManager; import org.apache.hadoop.security.token.Token; import org.apache.hadoop.security.token.TokenIdentifier; -import org.junit.AfterClass; -import org.junit.BeforeClass; +import org.junit.After; +import org.junit.Before; import org.junit.Test; import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; import com.google.protobuf.BlockingService; import com.google.protobuf.RpcController; @@ -96,6 +104,7 @@ import com.google.protobuf.ServiceException; // RpcServer is all about shaded protobuf whereas the Token Service is a CPEP which does non-shaded // protobufs. Since hbase-2.0.0, we added convertion from shaded to non-shaded so this test keeps // working. +@RunWith(Parameterized.class) @Category({SecurityTests.class, MediumTests.class}) public class TestTokenAuthentication { static { @@ -115,6 +124,7 @@ public class TestTokenAuthentication { AuthenticationProtos.AuthenticationService.BlockingInterface, Runnable, Server { private static final Log LOG = LogFactory.getLog(TokenServer.class); private Configuration conf; + private HBaseTestingUtility TEST_UTIL; private RpcServerInterface rpcServer; private InetSocketAddress isa; private ZooKeeperWatcher zookeeper; @@ -124,8 +134,10 @@ public class TestTokenAuthentication { private boolean stopped = false; private long startcode; - public TokenServer(Configuration conf) throws IOException { + public TokenServer(Configuration conf, HBaseTestingUtility TEST_UTIL) + throws IOException { this.conf = conf; + this.TEST_UTIL = TEST_UTIL; this.startcode = EnvironmentEdgeManager.currentTime(); // Server to handle client requests. String hostname = @@ -387,14 +399,23 @@ public class TestTokenAuthentication { } } - private static HBaseTestingUtility TEST_UTIL; - private static TokenServer server; - private static Thread serverThread; - private static AuthenticationTokenSecretManager secretManager; - private static ClusterId clusterId = new ClusterId(); + @Parameters(name = "{index}: rpcServerImpl={0}") + public static Collection parameters() { + return Arrays.asList(new Object[] { SimpleRpcServer.class.getName() }, + new Object[] { NettyRpcServer.class.getName() }); + } + + @Parameter(0) + public String rpcServerImpl; + + private HBaseTestingUtility TEST_UTIL; + private TokenServer server; + private Thread serverThread; + private AuthenticationTokenSecretManager secretManager; + private ClusterId clusterId = new ClusterId(); - @BeforeClass - public static void setupBeforeClass() throws Exception { + @Before + public void setUp() throws Exception { TEST_UTIL = new HBaseTestingUtility(); TEST_UTIL.startMiniZKCluster(); // register token type for protocol @@ -406,7 +427,8 @@ public class TestTokenAuthentication { conf.set("hadoop.security.authentication", "kerberos"); conf.set("hbase.security.authentication", "kerberos"); conf.setBoolean(HADOOP_SECURITY_AUTHORIZATION, true); - server = new TokenServer(conf); + conf.set(RpcServerFactory.CUSTOM_RPC_SERVER_IMPL_CONF_KEY, rpcServerImpl); + server = new TokenServer(conf, TEST_UTIL); serverThread = new Thread(server); Threads.setDaemonThreadRunning(serverThread, "TokenServer:"+server.getServerName().toString()); // wait for startup @@ -428,8 +450,8 @@ public class TestTokenAuthentication { } } - @AfterClass - public static void tearDownAfterClass() throws Exception { + @After + public void tearDown() throws Exception { server.stop("Test complete"); Threads.shutdown(serverThread); TEST_UTIL.shutdownMiniZKCluster();